對(duì)pytorch中的梯度更新方法詳解
背景
使用pytorch時(shí),有一個(gè)yolov3的bug,我認(rèn)為涉及到學(xué)習(xí)率的調(diào)整。收集到tencent yolov3和mxnet開(kāi)源的yolov3,兩個(gè)優(yōu)化器中的學(xué)習(xí)率設(shè)置不一樣,而且使用GPU數(shù)目和batch的更新也不太一樣。據(jù)此,我簡(jiǎn)單的了解了下pytorch的權(quán)重梯度的更新策略,看看能否一窺究竟。
對(duì)代碼說(shuō)明
共三個(gè)實(shí)驗(yàn),分布寫(xiě)在代碼中的(一)(二)(三)三個(gè)地方。運(yùn)行實(shí)驗(yàn)時(shí)注釋掉其他兩個(gè)
實(shí)驗(yàn)及其結(jié)果
實(shí)驗(yàn)(三):
不使用zero_grad()時(shí),grad累加在一起,官網(wǎng)是使用accumulate 來(lái)表述的,所以不太清楚是取的和還是均值(這兩種最有可能)。
不使用zero_grad()時(shí),是直接疊加add的方式累加的。
tensor([[[ 1., 1.],……torch.Size([2, 2, 2]) 0 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * tensor([[[ 2., 2.],…… torch.Size([2, 2, 2]) 1 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * tensor([[[ 3., 3.],…… torch.Size([2, 2, 2]) 2 2 * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * *
實(shí)驗(yàn)(二):
單卡上不同的batchsize對(duì)梯度是怎么作用的。 mini-batch SGD中的batch是加快訓(xùn)練,同時(shí)保持一定的噪聲。但設(shè)置不同的batchsize的權(quán)重的梯度是怎么計(jì)算的呢。
設(shè)置運(yùn)行實(shí)驗(yàn)(二),可以看到結(jié)果如下:所以單卡batchsize計(jì)算梯度是取均值的
tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])
實(shí)驗(yàn)(一):
多gpu情況下,梯度怎么合并在一起的。
在《training imagenet in 1 hours》中提到grad是allreduce的,是累加的形式。但是當(dāng)設(shè)置g=2,實(shí)驗(yàn)一運(yùn)行時(shí),結(jié)果也是取均值的,類(lèi)同于實(shí)驗(yàn)(二)
tensor([[[ 3., 3.],…… torch.Size([2, 2, 2])
實(shí)驗(yàn)代碼
import torch import torch.nn as nn from torch.autograd import Variable class model(nn.Module): def __init__(self, w): super(model, self).__init__() self.w = w def forward(self, xx): b, c, _, _ = xx.shape # extra = xx.device.index + 1 ## 實(shí)驗(yàn)(一) y = xx.reshape(b, -1).mm(self.w.cuda(xx.device).reshape(-1, 2) * extra) return y.reshape(len(xx), -1) g = 1 x = Variable(torch.ones(2, 1, 2, 2)) # x[1] += 1 ## 實(shí)驗(yàn)(二) w = Variable(torch.ones(2, 2, 2) * 2, requires_grad=True) # optim = torch.optim.SGD({'params': x}, lr = 0.01 momentum = 0.9 M = model(w) M = torch.nn.DataParallel(M, device_ids=range(g)) for i in range(3): b = len(x) z = M(x) zz = z.sum(1) l = (zz - Variable(torch.ones(b).cuda())).mean() # zz.backward(Variable(torch.ones(b).cuda())) l.backward() print(w.grad, w.grad.shape) # w.grad.zero_() ## 實(shí)驗(yàn)(三) print(i, b, '* * ' * 20)
以上這篇對(duì)pytorch中的梯度更新方法詳解就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python爬蟲(chóng)實(shí)例扒取2345天氣預(yù)報(bào)
本篇文章給大家詳細(xì)分析了通過(guò)Python爬蟲(chóng)如何采集到2345的天氣預(yù)報(bào)信息,有興趣的朋友參考學(xué)習(xí)下吧。2018-03-03音頻處理 windows10下python三方庫(kù)librosa安裝教程
這篇文章主要介紹了音頻處理 windows10下python三方庫(kù)librosa安裝方法,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-06-06pycharm連接虛擬機(jī)的實(shí)現(xiàn)步驟
本文主要介紹了pycharm連接虛擬機(jī)的實(shí)現(xiàn)步驟,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2023-12-12Python實(shí)現(xiàn)Harbor私有鏡像倉(cāng)庫(kù)垃圾自動(dòng)化清理詳情
這篇文章主要介紹了Python實(shí)現(xiàn)Harbor私有鏡像倉(cāng)庫(kù)垃圾自動(dòng)化清理詳情,文章圍繞主題分享相關(guān)詳細(xì)代碼,需要的小伙伴可以參考一下2022-05-05詳解Python網(wǎng)絡(luò)爬蟲(chóng)功能的基本寫(xiě)法
這篇文章主要介紹了Python網(wǎng)絡(luò)爬蟲(chóng)功能的基本寫(xiě)法,網(wǎng)絡(luò)爬蟲(chóng),即Web Spider,是一個(gè)很形象的名字。把互聯(lián)網(wǎng)比喻成一個(gè)蜘蛛網(wǎng),那么Spider就是在網(wǎng)上爬來(lái)爬去的蜘蛛,對(duì)網(wǎng)絡(luò)爬蟲(chóng)感興趣的朋友可以參考本文2016-01-01