PyTorch中model.zero_grad()和optimizer.zero_grad()用法
廢話不多說,直接上代碼吧~
model.zero_grad()
optimizer.zero_grad()
首先,這兩種方式都是把模型中參數(shù)的梯度設(shè)為0
當(dāng)optimizer = optim.Optimizer(net.parameters())時(shí),二者等效,其中Optimizer可以是Adam、SGD等優(yōu)化器
def zero_grad(self): """Sets gradients of all model parameters to zero.""" for p in self.parameters(): if p.grad is not None: p.grad.data.zero_()
補(bǔ)充知識(shí):Pytorch中的optimizer.zero_grad和loss和net.backward和optimizer.step的理解
引言
一般訓(xùn)練神經(jīng)網(wǎng)絡(luò),總是逃不開optimizer.zero_grad之后是loss(后面有的時(shí)候還會(huì)寫forward,看你網(wǎng)絡(luò)怎么寫了)之后是是net.backward之后是optimizer.step的這個(gè)過程。
real_a, real_b = batch[0].to(device), batch[1].to(device) fake_b = net_g(real_a) optimizer_d.zero_grad() # 判別器對(duì)虛假數(shù)據(jù)進(jìn)行訓(xùn)練 fake_ab = torch.cat((real_a, fake_b), 1) pred_fake = net_d.forward(fake_ab.detach()) loss_d_fake = criterionGAN(pred_fake, False) # 判別器對(duì)真實(shí)數(shù)據(jù)進(jìn)行訓(xùn)練 real_ab = torch.cat((real_a, real_b), 1) pred_real = net_d.forward(real_ab) loss_d_real = criterionGAN(pred_real, True) # 判別器損失 loss_d = (loss_d_fake + loss_d_real) * 0.5 loss_d.backward() optimizer_d.step()
上面這是一段cGAN的判別器訓(xùn)練過程。標(biāo)題中所涉及到的這些方法,其實(shí)整個(gè)神經(jīng)網(wǎng)絡(luò)的參數(shù)更新過程(特別是反向傳播),具體是怎么操作的,我們一起來探討一下。
參數(shù)更新和反向傳播
上圖為一個(gè)簡單的梯度下降示意圖。比如以SGD為例,是算一個(gè)batch計(jì)算一次梯度,然后進(jìn)行一次梯度更新。這里梯度值就是對(duì)應(yīng)偏導(dǎo)數(shù)的計(jì)算結(jié)果。顯然,我們進(jìn)行下一次batch梯度計(jì)算的時(shí)候,前一個(gè)batch的梯度計(jì)算結(jié)果,沒有保留的必要了。所以在下一次梯度更新的時(shí)候,先使用optimizer.zero_grad把梯度信息設(shè)置為0。
我們使用loss來定義損失函數(shù),是要確定優(yōu)化的目標(biāo)是什么,然后以目標(biāo)為頭,才可以進(jìn)行鏈?zhǔn)椒▌t和反向傳播。
調(diào)用loss.backward方法時(shí)候,Pytorch的autograd就會(huì)自動(dòng)沿著計(jì)算圖反向傳播,計(jì)算每一個(gè)葉子節(jié)點(diǎn)的梯度(如果某一個(gè)變量是由用戶創(chuàng)建的,則它為葉子節(jié)點(diǎn))。使用該方法,可以計(jì)算鏈?zhǔn)椒▌t求導(dǎo)之后計(jì)算的結(jié)果值。
optimizer.step用來更新參數(shù),就是圖片中下半部分的w和b的參數(shù)更新操作。
以上這篇PyTorch中model.zero_grad()和optimizer.zero_grad()用法就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
PyTorch使用GPU加速計(jì)算的實(shí)現(xiàn)
PyTorch利用NVIDIA CUDA庫提供的底層接口來實(shí)現(xiàn)GPU加速計(jì)算,本文就來介紹一下PyTorch使用GPU加速計(jì)算的實(shí)現(xiàn),具有一定的參考價(jià)值,感興趣的可以了解一下2024-02-02Python用for循環(huán)實(shí)現(xiàn)九九乘法表
本文通過實(shí)例代碼給大家介紹了Python用for循環(huán)實(shí)現(xiàn)九九乘法表的方法,代碼簡單易懂,非常不錯(cuò),具有一定的參考借鑒價(jià)值,需要的朋友參考下吧2018-05-05關(guān)于Django使用 django-celery-beat動(dòng)態(tài)添加定時(shí)任務(wù)的方法
本文給大家介紹Django使用 django-celery-beat動(dòng)態(tài)添加定時(shí)任務(wù)的方法,安裝對(duì)應(yīng)的是celery版本,文中給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友參考下吧2021-10-10Python3.7 dataclass使用指南小結(jié)
本文將帶你走進(jìn)python3.7的新特性dataclass,通過本文你將學(xué)會(huì)dataclass的使用并避免踏入某些陷阱。小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2019-02-02淺談Python 釘釘報(bào)警必備知識(shí)系統(tǒng)講解
這篇文章主要介紹了淺談Python 釘釘報(bào)警必備知識(shí)系統(tǒng)講解,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-08-08Pytorch反向傳播中的細(xì)節(jié)-計(jì)算梯度時(shí)的默認(rèn)累加操作
這篇文章主要介紹了Pytorch反向傳播中的細(xì)節(jié)-計(jì)算梯度時(shí)的默認(rèn)累加操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2021-06-06Python實(shí)現(xiàn)根據(jù)Excel生成Model和數(shù)據(jù)導(dǎo)入腳本
最近遇到一個(gè)需求,有幾十個(gè)Excel,每個(gè)的字段都不一樣,然后都差不多是第一行是表頭,后面幾千上萬的數(shù)據(jù),需要把這些Excel中的數(shù)據(jù)全都加入某個(gè)已經(jīng)上線的Django項(xiàng)目。所以我造了個(gè)自動(dòng)生成?Model和導(dǎo)入腳本的輪子,希望對(duì)大家有所幫助2022-11-11使用grappelli為django admin后臺(tái)添加模板
本文介紹了一款非常流行的Django模板系統(tǒng)--grappelli,以及如何給Django的admin后臺(tái)添加模板,非常的實(shí)用,這里推薦給大家。2014-11-11