Pytorch反向求導(dǎo)更新網(wǎng)絡(luò)參數(shù)的方法
方法一:手動計算變量的梯度,然后更新梯度
import torch from torch.autograd import Variable # 定義參數(shù) w1 = Variable(torch.FloatTensor([1,2,3]),requires_grad = True) # 定義輸出 d = torch.mean(w1) # 反向求導(dǎo) d.backward() # 定義學(xué)習(xí)率等參數(shù) lr = 0.001 # 手動更新參數(shù) w1.data.zero_() # BP求導(dǎo)更新參數(shù)之前,需先對導(dǎo)數(shù)置0 w1.data.sub_(lr*w1.grad.data)
一個網(wǎng)絡(luò)中通常有很多變量,如果按照上述的方法手動求導(dǎo),然后更新參數(shù),是很麻煩的,這個時候可以調(diào)用torch.optim
方法二:使用torch.optim
import torch from torch.autograd import Variable import torch.nn as nn import torch.optim as optim # 這里假設(shè)我們定義了一個網(wǎng)絡(luò),為net steps = 10000 # 定義一個optim對象 optimizer = optim.SGD(net.parameters(), lr = 0.01) # 在for循環(huán)中更新參數(shù) for i in range(steps): optimizer.zero_grad() # 對網(wǎng)絡(luò)中參數(shù)當前的導(dǎo)數(shù)置0 output = net(input) # 網(wǎng)絡(luò)前向計算 loss = criterion(output, target) # 計算損失 loss.backward() # 得到模型中參數(shù)對當前輸入的梯度 optimizer.step() # 更新參數(shù)
注意:torch.optim只用于參數(shù)更新和對參數(shù)的梯度置0,不能計算參數(shù)的梯度,在使用torch.optim進行參數(shù)更新之前,需要寫前向與反向傳播求導(dǎo)的代碼
以上這篇Pytorch反向求導(dǎo)更新網(wǎng)絡(luò)參數(shù)的方法就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
- 使用pytorch進行張量計算、自動求導(dǎo)和神經(jīng)網(wǎng)絡(luò)構(gòu)建功能
- pytorch如何定義新的自動求導(dǎo)函數(shù)
- 在?pytorch?中實現(xiàn)計算圖和自動求導(dǎo)
- Pytorch自動求導(dǎo)函數(shù)詳解流程以及與TensorFlow搭建網(wǎng)絡(luò)的對比
- 淺談Pytorch中的自動求導(dǎo)函數(shù)backward()所需參數(shù)的含義
- pytorch中的自定義反向傳播,求導(dǎo)實例
- 關(guān)于PyTorch 自動求導(dǎo)機制詳解
- 關(guān)于pytorch求導(dǎo)總結(jié)(torch.autograd)
相關(guān)文章
Python通用驗證碼識別OCR庫ddddocr的安裝使用教程
dddd_ocr是一個用于識別驗證碼的開源庫,又名帶帶弟弟ocr,下面這篇文章主要給大家介紹了關(guān)于Python通用驗證碼識別OCR庫ddddocr的安裝使用教程,文中通過示例代碼介紹的非常詳細,需要的朋友可以參考下2022-07-07Python數(shù)據(jù)結(jié)構(gòu)之樹的全面解讀
數(shù)據(jù)結(jié)構(gòu)中有很多樹的結(jié)構(gòu),其中包括二叉樹、二叉搜索樹、2-3樹、紅黑樹等等。本文中對數(shù)據(jù)結(jié)構(gòu)中常見的樹邏輯結(jié)構(gòu)和存儲結(jié)構(gòu)進行了匯總,不求嚴格精準,但求簡單易懂2021-11-11