基于Pytorch的神經(jīng)網(wǎng)絡(luò)之Regression的實(shí)現(xiàn)
1.引言
我們之前已經(jīng)介紹了神經(jīng)網(wǎng)絡(luò)的基本知識(shí),神經(jīng)網(wǎng)絡(luò)的主要作用就是預(yù)測(cè)與分類,現(xiàn)在讓我們來(lái)搭建第一個(gè)用于擬合回歸的神經(jīng)網(wǎng)絡(luò)吧。
2.神經(jīng)網(wǎng)絡(luò)搭建
2.1 準(zhǔn)備工作
要搭建擬合神經(jīng)網(wǎng)絡(luò)并繪圖我們需要使用python的幾個(gè)庫(kù)。
import torch import torch.nn.functional as F import matplotlib.pyplot as plt x = torch.unsqueeze(torch.linspace(-5, 5, 100), dim=1) y = x.pow(3) + 0.2 * torch.rand(x.size())
既然是擬合,我們當(dāng)然需要一些數(shù)據(jù)啦,我選取了在區(qū)間 內(nèi)的100個(gè)等間距點(diǎn),并將它們排列成三次函數(shù)的圖像。
2.2 搭建網(wǎng)絡(luò)
我們定義一個(gè)類,繼承了封裝在torch中的一個(gè)模塊,我們先分別確定輸入層、隱藏層、輸出層的神經(jīng)元數(shù)目,繼承父類后再使用torch中的.nn.Linear()函數(shù)進(jìn)行輸入層到隱藏層的線性變換,隱藏層也進(jìn)行線性變換后傳入輸出層predict,接下來(lái)定義前向傳播的函數(shù)forward(),使用relu()作為激活函數(shù),最后輸出predict()結(jié)果即可。
class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) self.predict = torch.nn.Linear(n_hidden, n_output) def forward(self, x): x = F.relu(self.hidden(x)) return self.predict(x) net = Net(1, 20, 1) print(net) optimizer = torch.optim.Adam(net.parameters(), lr=0.2) loss_func = torch.nn.MSELoss()
網(wǎng)絡(luò)的框架搭建完了,然后我們傳入三層對(duì)應(yīng)的神經(jīng)元數(shù)目再定義優(yōu)化器,這里我選取了Adam而隨機(jī)梯度下降(SGD),因?yàn)樗荢GD的優(yōu)化版本,效果在大部分情況下比SGD好,我們要傳入這個(gè)神經(jīng)網(wǎng)絡(luò)的參數(shù)(parameters),并定義學(xué)習(xí)率(learning rate),學(xué)習(xí)率通常選取小于1的數(shù),需要憑借經(jīng)驗(yàn)并不斷調(diào)試。最后我們選取均方差法(MSE)來(lái)計(jì)算損失(loss)。
2.3 訓(xùn)練網(wǎng)絡(luò)
接下來(lái)我們要對(duì)我們搭建好的神經(jīng)網(wǎng)絡(luò)進(jìn)行訓(xùn)練,我訓(xùn)練了2000輪(epoch),先更新結(jié)果prediction再計(jì)算損失,接著清零梯度,然后根據(jù)loss反向傳播(backward),最后進(jìn)行優(yōu)化,找出最優(yōu)的擬合曲線。
for t in range(2000): prediction = net(x) loss = loss_func(prediction, y) optimizer.zero_grad() loss.backward() optimizer.step()
3.效果
使用如下繪圖的代碼展示效果。
for t in range(2000): prediction = net(x) loss = loss_func(prediction, y) optimizer.zero_grad() loss.backward() optimizer.step() if t % 5 == 0: plt.cla() plt.scatter(x.data.numpy(), y.data.numpy(), s=10) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2) plt.text(2, -100, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 10, 'color': 'red'}) plt.pause(0.1) plt.ioff() plt.show()
最后的結(jié)果:
4. 完整代碼
import torch import torch.nn.functional as F import matplotlib.pyplot as plt x = torch.unsqueeze(torch.linspace(-5, 5, 100), dim=1) y = x.pow(3) + 0.2 * torch.rand(x.size()) class Net(torch.nn.Module): def __init__(self, n_feature, n_hidden, n_output): super(Net, self).__init__() self.hidden = torch.nn.Linear(n_feature, n_hidden) self.predict = torch.nn.Linear(n_hidden, n_output) def forward(self, x): x = F.relu(self.hidden(x)) return self.predict(x) net = Net(1, 20, 1) print(net) optimizer = torch.optim.Adam(net.parameters(), lr=0.2) loss_func = torch.nn.MSELoss() plt.ion() for t in range(2000): prediction = net(x) loss = loss_func(prediction, y) optimizer.zero_grad() loss.backward() optimizer.step() if t % 5 == 0: plt.cla() plt.scatter(x.data.numpy(), y.data.numpy(), s=10) plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2) plt.text(2, -100, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 10, 'color': 'red'}) plt.pause(0.1) plt.ioff() plt.show()
到此這篇關(guān)于基于Pytorch的神經(jīng)網(wǎng)絡(luò)之Regression的實(shí)現(xiàn)的文章就介紹到這了,更多相關(guān) Pytorch Regression內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python使用pycharm環(huán)境調(diào)用opencv庫(kù)
這篇文章主要介紹了python使用pycharm環(huán)境調(diào)用opencv庫(kù),小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2018-02-02Python報(bào)表自動(dòng)化之從數(shù)據(jù)到可視化一站式指南
在現(xiàn)代數(shù)據(jù)驅(qū)動(dòng)的世界中,生成清晰、有用的報(bào)表對(duì)于業(yè)務(wù)決策至關(guān)重要,Python作為一門(mén)強(qiáng)大的編程語(yǔ)言,提供了豐富的庫(kù)和工具,使得報(bào)表自動(dòng)化變得輕而易舉,本文將詳細(xì)介紹如何利用Python從數(shù)據(jù)處理到可視化,實(shí)現(xiàn)報(bào)表自動(dòng)化的全過(guò)程2024-01-01LyScript實(shí)現(xiàn)對(duì)內(nèi)存堆棧掃描的方法詳解
LyScript插件中提供了三種基本的堆棧操作方法,其中push_stack用于入棧,pop_stack用于出棧,peek_stac可用于檢查指定堆棧位置處的內(nèi)存參數(shù)。所以本文將利用這一特性實(shí)現(xiàn)對(duì)內(nèi)存堆棧掃描,感興趣的可以了解一下2022-08-08python中用shutil.move移動(dòng)文件或目錄的方法實(shí)例
在python操作中大家對(duì)os,shutil,sys,等通用庫(kù)一定不陌生,下面這篇文章主要給大家介紹了關(guān)于python中用shutil.move移動(dòng)文件或目錄的相關(guān)資料,需要的朋友可以參考下2022-12-12Pycharm+Python工程,引用子模塊的實(shí)現(xiàn)
這篇文章主要介紹了Pycharm+Python工程,引用子模塊的實(shí)現(xiàn),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-03-03