欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch實(shí)現(xiàn)梯度下降和反向傳播圖文詳細(xì)講解

 更新時(shí)間:2023年04月24日 10:18:10   作者:瘋狂的小強(qiáng)呀  
這篇文章主要介紹了pytorch實(shí)現(xiàn)梯度下降和反向傳播,反向傳播的目的是計(jì)算成本函數(shù)C對(duì)網(wǎng)絡(luò)中任意w或b的偏導(dǎo)數(shù)。一旦我們有了這些偏導(dǎo)數(shù),我們將通過(guò)一些常數(shù)α的乘積和該數(shù)量相對(duì)于成本函數(shù)的偏導(dǎo)數(shù)來(lái)更新網(wǎng)絡(luò)中的權(quán)重和偏差

反向傳播

這里說(shuō)一下我的理解,反向傳播是相對(duì)于前向計(jì)算的,以公式J(a,b,c)=3(a+bc)為例,前向計(jì)算相當(dāng)于向右計(jì)算J(a,b,c)的值,反向傳播相當(dāng)于反過(guò)來(lái)通過(guò)y求變量a,b,c的導(dǎo)數(shù),如下圖

手動(dòng)完成線性回歸

import torch
import numpy as np
from matplotlib import pyplot as plt
"""
假設(shè)模型為y=w*x+b
我們給出的訓(xùn)練數(shù)據(jù)是通過(guò)y=3*x+1,得到的,其中w=3,b=1
通過(guò)訓(xùn)練y=w*x+b觀察訓(xùn)練結(jié)果是否接近于w=3,b=1
"""
# 設(shè)置學(xué)習(xí)率
learning_rate=0.01
#準(zhǔn)備數(shù)據(jù)
x=torch.rand(500,1) #隨機(jī)生成500個(gè)x作為訓(xùn)練數(shù)據(jù)
y_true=x*3+1 #根據(jù)模型得到x對(duì)應(yīng)的y的實(shí)際值
#初始化參數(shù)
w=torch.rand([1,1],requires_grad=True) #初始化w
b=torch.rand(1,requires_grad=True,dtype=torch.float32) #初始化b
#通過(guò)循環(huán),反向傳播,更新參數(shù)
for i in range(2000):
    # 通過(guò)模型計(jì)算y_predict
    y_predict=torch.matmul(x,w)+b #根據(jù)模型得到預(yù)測(cè)值
    #計(jì)算loss
    loss=(y_true-y_predict).pow(2).mean()
    #防止梯度累加,每次計(jì)算梯度前都將其置為0
    if w.grad is not None:
        w.grad.data.zero_()
    if b.grad is not None:
        b.grad.data.zero_()
    #通過(guò)反向傳播,記錄梯度
    loss.backward()
    #更新參數(shù)
    w.data=w.data-learning_rate*w.grad
    b.data=b.data-learning_rate*b.grad
    # 這里打印部分值看一看變化
    if i%50==0:
        print("w,b,loss:",w.item(),b.item(),loss.item())
#設(shè)置圖像的大小
plt.figure(figsize=(20,8))
#將真實(shí)值用散點(diǎn)表示出來(lái)
plt.scatter(x.numpy().reshape(-1),y_true.numpy().reshape(-1))
#將預(yù)測(cè)值用直線表示出來(lái)
y_predict=torch.matmul(x,w)+b
plt.plot(x.numpy().reshape(-1),y_predict.detach().numpy().reshape(-1),c="r")
#顯示圖像
plt.show()

pytorch API完成線性回歸

優(yōu)化器類

優(yōu)化器(optimizer),可以理解為torch為我們封裝的用來(lái)進(jìn)行更新參數(shù)的方法,比如常見(jiàn)的隨機(jī)梯度下降(stochastic gradient descent,SGD)

優(yōu)化器類都是由torch.optim提供的,例如

  • torch.optim.SGD(參數(shù),學(xué)習(xí)率)
  • torch.optim.Adam(參數(shù),學(xué)習(xí)率)

注意:

  • 參數(shù)可以使用model.parameters()來(lái)獲取,獲取模型中所有requires_grad=True的參數(shù)
  • 優(yōu)化類的使用方法

①實(shí)例化

②所有參數(shù)的梯度,將其置為0

③反向傳播計(jì)算梯度

④更新參數(shù)值

實(shí)現(xiàn)

import torch
from torch import nn
from torch import optim
from matplotlib import pyplot as plt
import numpy as np
# 1.定義數(shù)據(jù),給出x
x=torch.rand(50,1)
# 假定模型為y=w*x+b,根據(jù)模型給出真實(shí)值y=x*3+0.8
y=x*3+0.8
# print(x)
#2.定義模型
class Lr(torch.nn.Module):
    def __init__(self):
        super(Lr, self).__init__()
        self.linear = torch.nn.Linear(1, 1)
    def forward(self, x):
        out = self.linear(x)
        return out
# 3.實(shí)例化模型、loss、優(yōu)化器
model=Lr()
criterion=nn.MSELoss()
# print(list(model.parameters()))
optimizer=optim.SGD(model.parameters(),lr=1e-3)
# 4.訓(xùn)練模型
for i in range(30000):
    out=model(x) #獲取預(yù)測(cè)值
    loss=criterion(y,out) #計(jì)算損失
    optimizer.zero_grad() #梯度歸零
    loss.backward() #計(jì)算梯度
    optimizer.step() #更新梯度
    if (i+1)%100 ==0:
        print('Epoch[{}/{}],loss:{:.6f}'.format(i,30000,loss.data))
# 5.模型評(píng)估
model.eval() #設(shè)置模型為評(píng)估模式,即預(yù)測(cè)模式
predict=model(x)
predict=predict.data.numpy()
plt.scatter(x.data.numpy(),y.data.numpy(),c="r")
plt.plot(x.data.numpy(),predict)
plt.show()

到此這篇關(guān)于pytorch實(shí)現(xiàn)梯度下降和反向傳播圖文詳細(xì)講解的文章就介紹到這了,更多相關(guān)pytorch梯度下降和反向傳播內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • python通過(guò)socket搭建極簡(jiǎn)web服務(wù)器的實(shí)現(xiàn)代碼

    python通過(guò)socket搭建極簡(jiǎn)web服務(wù)器的實(shí)現(xiàn)代碼

    python的web框架眾多,常見(jiàn)的如django、flask、tornado等,其底層是什么還是有些許的疑問(wèn),所以查找相關(guān)資料,實(shí)現(xiàn)瀏覽器訪問(wèn),并返回相關(guān)信息,本文將給大家介紹python通過(guò)socket搭建極簡(jiǎn)web服務(wù)器,需要的朋友可以參考下
    2023-10-10
  • 詳解Python自建logging模塊

    詳解Python自建logging模塊

    本篇文章給大家詳細(xì)分析了Python自建logging模塊的方法和代碼分享,有需要的朋友參考學(xué)習(xí)下吧。
    2018-01-01
  • python?操作?mongodb?數(shù)據(jù)庫(kù)詳情

    python?操作?mongodb?數(shù)據(jù)庫(kù)詳情

    這篇文章主要介紹了python?操作?mongodb?數(shù)據(jù)庫(kù)詳情,通過(guò)鏈接數(shù)據(jù)庫(kù),創(chuàng)建數(shù)據(jù)庫(kù)展開(kāi)內(nèi)容詳細(xì),具有一定的參考價(jià)值,需要的的小伙伴可以參考一下
    2022-04-04
  • Python中的匿名函數(shù)使用簡(jiǎn)介

    Python中的匿名函數(shù)使用簡(jiǎn)介

    這篇文章主要介紹了Python中的匿名函數(shù)的使用,lambda是各個(gè)現(xiàn)代編程語(yǔ)言中的重要功能,需要的朋友可以參考下
    2015-04-04
  • 如何通過(guò)神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn)線性回歸的擬合

    如何通過(guò)神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn)線性回歸的擬合

    這篇文章主要介紹了如何通過(guò)神經(jīng)網(wǎng)絡(luò)實(shí)現(xiàn)線性回歸的擬合問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2023-05-05
  • python原類、類的創(chuàng)建過(guò)程與方法詳解

    python原類、類的創(chuàng)建過(guò)程與方法詳解

    在本篇文章里小編給各位分享了關(guān)于python原類、類的創(chuàng)建過(guò)程與方法的相關(guān)知識(shí)點(diǎn)內(nèi)容,有興趣的朋友們跟著學(xué)習(xí)參考下。
    2019-07-07
  • 使用Python實(shí)現(xiàn)SSH隧道界面功能

    使用Python實(shí)現(xiàn)SSH隧道界面功能

    這篇文章主要介紹了使用Python實(shí)現(xiàn)一個(gè)SSH隧道界面功能,界面使用tkinter實(shí)現(xiàn),左邊是輸入隧道的信息,右邊為歷史列表,本文通過(guò)示例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友參考下吧
    2022-02-02
  • python pandas生成時(shí)間列表

    python pandas生成時(shí)間列表

    這篇文章主要介紹了python pandas生成時(shí)間列表,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2019-06-06
  • python如何通過(guò)twisted實(shí)現(xiàn)數(shù)據(jù)庫(kù)異步插入

    python如何通過(guò)twisted實(shí)現(xiàn)數(shù)據(jù)庫(kù)異步插入

    這篇文章主要為大家詳細(xì)介紹了python如何通過(guò)twisted實(shí)現(xiàn)數(shù)據(jù)庫(kù)異步插入,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2018-03-03
  • Python強(qiáng)大的自省機(jī)制詳解

    Python強(qiáng)大的自省機(jī)制詳解

    這篇文章主要為大家介紹了Python強(qiáng)大的自省機(jī)制,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來(lái)幫助
    2021-11-11

最新評(píng)論