PyTorch實(shí)現(xiàn)線性回歸詳細(xì)過程
一、實(shí)現(xiàn)步驟
1、準(zhǔn)備數(shù)據(jù)
x_data = torch.tensor([[1.0],[2.0],[3.0]]) y_data = torch.tensor([[2.0],[4.0],[6.0]])
2、設(shè)計(jì)模型
class LinearModel(torch.nn.Module): ? ? def __init__(self): ? ? ? ? super(LinearModel,self).__init__() ? ? ? ? self.linear = torch.nn.Linear(1,1) ? ? ? ?? ? ? def forward(self, x): ? ? ? ? y_pred = self.linear(x) ? ? ? ? return y_pred ? ? ? ?? model = LinearModel() ?
3、構(gòu)造損失函數(shù)和優(yōu)化器
criterion = torch.nn.MSELoss(reduction='sum') optimizer = torch.optim.SGD(model.parameters(),lr=0.01)
4、訓(xùn)練過程
epoch_list = [] loss_list = [] w_list = [] b_list = [] for epoch in range(1000): ? ? y_pred = model(x_data)?? ??? ??? ??? ??? ? ?# 計(jì)算預(yù)測(cè)值 ? ? loss = criterion(y_pred, y_data)?? ?# 計(jì)算損失 ? ? print(epoch,loss) ? ?? ? ? epoch_list.append(epoch) ? ? loss_list.append(loss.data.item()) ? ? w_list.append(model.linear.weight.item()) ? ? b_list.append(model.linear.bias.item()) ? ?? ? ? optimizer.zero_grad() ? # 梯度歸零 ? ? loss.backward() ? ? ? ? # 反向傳播 ? ? optimizer.step() ? ? ? ?# 更新
5、結(jié)果展示
展示最終的權(quán)重和偏置:
# 輸出權(quán)重和偏置 print('w = ',model.linear.weight.item()) print('b = ',model.linear.bias.item())
結(jié)果為:
w = 1.9998501539230347
b = 0.0003405189490877092
模型測(cè)試:
# 測(cè)試模型 x_test = torch.tensor([[4.0]]) y_test = model(x_test) print('y_pred = ',y_test.data) y_pred = ?tensor([[7.9997]])
分別繪制損失值隨迭代次數(shù)變化的二維曲線圖和其隨權(quán)重與偏置變化的三維散點(diǎn)圖:
# 二維曲線圖 plt.plot(epoch_list,loss_list,'b') plt.xlabel('epoch') plt.ylabel('loss') plt.show() # 三維散點(diǎn)圖 fig = plt.figure() ax = fig.add_subplot(111, projection='3d') ax.scatter(w_list,b_list,loss_list,c='r') #設(shè)置坐標(biāo)軸 ax.set_xlabel('weight') ax.set_ylabel('bias') ax.set_zlabel('loss') plt.show()
結(jié)果如下圖所示:
到此這篇關(guān)于PyTorch實(shí)現(xiàn)線性回歸詳細(xì)過程的文章就介紹到這了,更多相關(guān)PyTorch線性回歸內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
二、參考文獻(xiàn)
相關(guān)文章
淺析python表達(dá)式4+0.5值的數(shù)據(jù)類型
在本篇文章里小編給大家整理的是一篇關(guān)于python表達(dá)式4+0.5值的數(shù)據(jù)類型的知識(shí)點(diǎn)內(nèi)容,需要的的朋友們學(xué)習(xí)下。2020-02-02Python3使用Selenium獲取session和token方法詳解
這篇文章主要介紹了Python3使用Selenium獲取session和token方法詳解,需要的朋友可以參考下2021-02-02在 Python 中如何使用 Re 模塊的正則表達(dá)式通配符
這篇文章主要介紹了在 Python 中如何使用 Re 模塊的正則表達(dá)式通配符,本文詳細(xì)解釋了如何在 Python 中使用帶有通配符的 re.sub() 來匹配字符串與正則表達(dá)式,需要的朋友可以參考下2023-06-06Python算法之棧(stack)的實(shí)現(xiàn)
這篇文章主要介紹了Python算法之棧(stack)的實(shí)現(xiàn),非常實(shí)用,需要的朋友可以參考下2014-08-08Python+pyecharts繪制雙動(dòng)態(tài)曲線教程詳解
pyecharts 是一個(gè)用于生成 Echarts 圖表的類庫。Echarts 是百度開源的一個(gè)數(shù)據(jù)可視化 JS 庫。用 Echarts 生成的圖可視化效果非常棒。本文將用pyecharts繪制雙動(dòng)態(tài)曲線,需要的可以參考一下2022-06-06推薦一款高效的python數(shù)據(jù)框處理工具Sidetable
這篇文章主要為大家介紹推薦一款高效的python數(shù)據(jù)框處理工具Sidetable,文章詳細(xì)的講解了Sidetable的安裝及用法,有需要的朋友可以借鑒參考下,希望能夠有所幫助2021-11-11Python解決MySQL數(shù)據(jù)處理從SQL批量刪除報(bào)錯(cuò)
這篇文章主要為大家介紹了Python解決MySQL數(shù)據(jù)處理從SQL批量刪除報(bào)錯(cuò),有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-12-12基于python的opencv圖像處理實(shí)現(xiàn)對(duì)斑馬線的檢測(cè)示例
這篇文章主要介紹了基于python的opencv圖像處理實(shí)現(xiàn)對(duì)斑馬線的檢測(cè)示例,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-11-11