欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

基于Pytorch的神經網絡之Regression的實現(xiàn)

 更新時間:2022年03月15日 10:15:49   作者:ZDDWLIG  
本文主要介紹了基于Pytorch的神經網絡之Regression的實現(xiàn),文中通過示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下

1.引言

我們之前已經介紹了神經網絡的基本知識,神經網絡的主要作用就是預測與分類,現(xiàn)在讓我們來搭建第一個用于擬合回歸的神經網絡吧。

2.神經網絡搭建

2.1 準備工作

要搭建擬合神經網絡并繪圖我們需要使用python的幾個庫。

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
 
x = torch.unsqueeze(torch.linspace(-5, 5, 100), dim=1)
y = x.pow(3) + 0.2 * torch.rand(x.size())

 既然是擬合,我們當然需要一些數據啦,我選取了在區(qū)間 [-5,5] 內的100個等間距點,并將它們排列成三次函數的圖像。

2.2 搭建網絡

我們定義一個類,繼承了封裝在torch中的一個模塊,我們先分別確定輸入層、隱藏層、輸出層的神經元數目,繼承父類后再使用torch中的.nn.Linear()函數進行輸入層到隱藏層的線性變換,隱藏層也進行線性變換后傳入輸出層predict,接下來定義前向傳播的函數forward(),使用relu()作為激活函數,最后輸出predict()結果即可。

class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)
    def forward(self, x):
        x = F.relu(self.hidden(x))
        return self.predict(x)
net = Net(1, 20, 1)
print(net)
optimizer = torch.optim.Adam(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss()

網絡的框架搭建完了,然后我們傳入三層對應的神經元數目再定義優(yōu)化器,這里我選取了Adam而隨機梯度下降(SGD),因為它是SGD的優(yōu)化版本,效果在大部分情況下比SGD好,我們要傳入這個神經網絡的參數(parameters),并定義學習率(learning rate),學習率通常選取小于1的數,需要憑借經驗并不斷調試。最后我們選取均方差法(MSE)來計算損失(loss)。

2.3 訓練網絡

接下來我們要對我們搭建好的神經網絡進行訓練,我訓練了2000輪(epoch),先更新結果prediction再計算損失,接著清零梯度,然后根據loss反向傳播(backward),最后進行優(yōu)化,找出最優(yōu)的擬合曲線。

for t in range(2000):
    prediction = net(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

3.效果

使用如下繪圖的代碼展示效果。

for t in range(2000):
    prediction = net(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 5 == 0:
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy(), s=10)
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2)
        plt.text(2, -100, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 10, 'color': 'red'})
        plt.pause(0.1)
plt.ioff()
plt.show()

最后的結果: 

4. 完整代碼

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
 
x = torch.unsqueeze(torch.linspace(-5, 5, 100), dim=1)
y = x.pow(3) + 0.2 * torch.rand(x.size())
class Net(torch.nn.Module):
    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)
        self.predict = torch.nn.Linear(n_hidden, n_output)
    def forward(self, x):
        x = F.relu(self.hidden(x))
        return self.predict(x)
net = Net(1, 20, 1)
print(net)
optimizer = torch.optim.Adam(net.parameters(), lr=0.2)
loss_func = torch.nn.MSELoss()
plt.ion()
for t in range(2000):
    prediction = net(x)
    loss = loss_func(prediction, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if t % 5 == 0:
        plt.cla()
        plt.scatter(x.data.numpy(), y.data.numpy(), s=10)
        plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=2)
        plt.text(2, -100, 'Loss=%.4f' % loss.data.numpy(), fontdict={'size': 10, 'color': 'red'})
        plt.pause(0.1)
plt.ioff()
plt.show()

到此這篇關于基于Pytorch的神經網絡之Regression的實現(xiàn)的文章就介紹到這了,更多相關 Pytorch Regression內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!

相關文章

  • Python3安裝Scrapy的方法步驟

    Python3安裝Scrapy的方法步驟

    本篇文章主要介紹了Python3安裝Scrapy的方法步驟,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2017-11-11
  • python使用pycharm環(huán)境調用opencv庫

    python使用pycharm環(huán)境調用opencv庫

    這篇文章主要介紹了python使用pycharm環(huán)境調用opencv庫,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2018-02-02
  • Python報表自動化之從數據到可視化一站式指南

    Python報表自動化之從數據到可視化一站式指南

    在現(xiàn)代數據驅動的世界中,生成清晰、有用的報表對于業(yè)務決策至關重要,Python作為一門強大的編程語言,提供了豐富的庫和工具,使得報表自動化變得輕而易舉,本文將詳細介紹如何利用Python從數據處理到可視化,實現(xiàn)報表自動化的全過程
    2024-01-01
  • LyScript實現(xiàn)對內存堆棧掃描的方法詳解

    LyScript實現(xiàn)對內存堆棧掃描的方法詳解

    LyScript插件中提供了三種基本的堆棧操作方法,其中push_stack用于入棧,pop_stack用于出棧,peek_stac可用于檢查指定堆棧位置處的內存參數。所以本文將利用這一特性實現(xiàn)對內存堆棧掃描,感興趣的可以了解一下
    2022-08-08
  • Python裝飾器代碼詳解

    Python裝飾器代碼詳解

    這篇文章主要介紹了python 一篇文章搞懂裝飾器所有用法,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2021-10-10
  • python獲取本機mac地址和ip地址的方法

    python獲取本機mac地址和ip地址的方法

    這篇文章主要介紹了python獲取本機mac地址和ip地址的方法,涉及Python獲取系統(tǒng)相關信息的技巧,需要的朋友可以參考下
    2015-04-04
  • python中用shutil.move移動文件或目錄的方法實例

    python中用shutil.move移動文件或目錄的方法實例

    在python操作中大家對os,shutil,sys,等通用庫一定不陌生,下面這篇文章主要給大家介紹了關于python中用shutil.move移動文件或目錄的相關資料,需要的朋友可以參考下
    2022-12-12
  • python中的sort方法使用詳解

    python中的sort方法使用詳解

    這篇文章主要介紹了python中的sort方法,需要的朋友可以參考下
    2014-07-07
  • Python matplotlib繪圖風格詳解

    Python matplotlib繪圖風格詳解

    從matplotlib的角度來說,繪圖風格也算是圖像類型的一部分,所以這篇文章小編想帶大家了解一下Python中matplotlib的繪圖風格,有需要的可以參考下
    2023-09-09
  • Pycharm+Python工程,引用子模塊的實現(xiàn)

    Pycharm+Python工程,引用子模塊的實現(xiàn)

    這篇文章主要介紹了Pycharm+Python工程,引用子模塊的實現(xiàn),具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-03-03

最新評論