PyTorch學(xué)習(xí)筆記之回歸實(shí)戰(zhàn)
本文主要是用PyTorch來實(shí)現(xiàn)一個簡單的回歸任務(wù)。
編輯器:spyder
1.引入相應(yīng)的包及生成偽數(shù)據(jù)
import torch import torch.nn.functional as F # 主要實(shí)現(xiàn)激活函數(shù) import matplotlib.pyplot as plt # 繪圖的工具 from torch.autograd import Variable # 生成偽數(shù)據(jù) x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1) y = x.pow(2) + 0.2 * torch.rand(x.size()) # 變?yōu)閂ariable x, y = Variable(x), Variable(y)
其中torch.linspace
是為了生成連續(xù)間斷的數(shù)據(jù),第一個參數(shù)表示起點(diǎn),第二個參數(shù)表示終點(diǎn),第三個參數(shù)表示將這個區(qū)間分成平均幾份,即生成幾個數(shù)據(jù)。因?yàn)閠orch只能處理二維的數(shù)據(jù),所以我們用torch.unsqueeze
給偽數(shù)據(jù)添加一個維度,dim表示添加在第幾維。torch.rand
返回的是[0,1)之間的均勻分布。
2.繪制數(shù)據(jù)圖像
在上述代碼后面加下面的代碼,然后運(yùn)行可得偽數(shù)據(jù)的圖形化表示:
# 繪制數(shù)據(jù)圖像 plt.scatter(x.data.numpy(), y.data.numpy()) plt.show()
3.建立神經(jīng)網(wǎng)絡(luò)
class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer self.predict = torch.nn.Linear(n_hidden, n_output) # output layer def forward(self, x): x = F.relu(self.hidden(x)) # activation function for hidden layer x = self.predict(x) # linear output return x net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network print(net) # net architecture
一般神經(jīng)網(wǎng)絡(luò)的類都繼承自torch.nn.Module
,__init__()和forward()
兩個函數(shù)是自定義類的主要函數(shù)。在__init__()
中都要添加一句super(Net, self).__init__(),
這是固定的標(biāo)準(zhǔn)寫法,用于繼承父類的初始化函數(shù)。__init__()
中只是對神經(jīng)網(wǎng)絡(luò)的模塊進(jìn)行了聲明,真正的搭建是在forwad()
中實(shí)現(xiàn)。自定義類中的成員都通過self指針來進(jìn)行訪問,所以參數(shù)列表中都包含了self。
如果想查看網(wǎng)絡(luò)結(jié)構(gòu),可以用print()
函數(shù)直接打印網(wǎng)絡(luò)。本文的網(wǎng)絡(luò)結(jié)構(gòu)輸出如下:
Net ( (hidden): Linear (1 -> 10) (predict): Linear (10 -> 1) )
4.訓(xùn)練網(wǎng)絡(luò)
# 訓(xùn)練100次 for t in range(100): prediction = net(x) # input x and predict based on x loss = loss_func(prediction, y) # 一定要是輸出在前,標(biāo)簽在后 (1. nn output, 2. target) optimizer.zero_grad() # clear gradients for next train loss.backward() # backpropagation, compute gradients optimizer.step() # apply gradients
訓(xùn)練網(wǎng)絡(luò)之前我們需要先定義優(yōu)化器和損失函數(shù)。torch.optim
包中包括了各種優(yōu)化器,這里我們選用最常見的SGD作為優(yōu)化器。因?yàn)槲覀円獙W(wǎng)絡(luò)的參數(shù)進(jìn)行優(yōu)化,所以我們要把網(wǎng)絡(luò)的參數(shù)net.parameters()
傳入優(yōu)化器中,并設(shè)置學(xué)習(xí)率(一般小于1)。
由于這里是回歸任務(wù),我們選擇torch.nn.MSELoss()
作為損失函數(shù)。
由于優(yōu)化器是基于梯度來優(yōu)化參數(shù)的,并且梯度會保存在其中。所以在每次優(yōu)化前要通過optimizer.zero_grad()
把梯度置零,然后再后向傳播及更新。
5.可視化訓(xùn)練過程
plt.ion() # something about plotting for t in range(100): ... if t % 5 == 0: # plot and show learning process plt.cla() plt.scatter(x.data.numpy(), y.data.numpy()) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5) plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'}) plt.pause(0.1) plt.ioff() plt.show()
6.運(yùn)行結(jié)果
以上就是本文的全部內(nèi)容,希望對大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
相關(guān)文章
python根據(jù)開頭和結(jié)尾字符串獲取中間字符串的方法
這篇文章主要介紹了python根據(jù)開頭和結(jié)尾字符串獲取中間字符串的方法,涉及Python操作字符串截取的相關(guān)技巧,具有一定參考借鑒價值,需要的朋友可以參考下2015-03-03ubuntu在線服務(wù)器python?Package安裝到離線服務(wù)器的過程
這篇文章主要介紹了ubuntu在線服務(wù)器python?Package安裝到離線服務(wù)器,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2023-04-04python啟動辦公軟件進(jìn)程(word、excel、ppt、以及wps的et、wps、wpp)
見如下源代碼,也可從附件中下載。2009-04-04Python Pandas數(shù)據(jù)結(jié)構(gòu)簡單介紹
這篇文章主要介紹了Python Pandas數(shù)據(jù)結(jié)構(gòu)簡單介紹的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2019-07-07基于Python實(shí)現(xiàn)開發(fā)釘釘通知機(jī)器人
在項(xiàng)目協(xié)同工作或自動化流程完成時,我們需要用一定的手段通知自己或他人。Telegram 非常好用,幾個步驟就能創(chuàng)建一個機(jī)器人,可惜在國內(nèi)無法使用。所以本文就來開發(fā)一個釘釘通知機(jī)器人吧2023-02-02