pytorch中retain_graph==True的作用說(shuō)明
pytorch retain_graph==True的作用說(shuō)明
總的來(lái)說(shuō)進(jìn)行一次backward之后,各個(gè)節(jié)點(diǎn)的值會(huì)清除,這樣進(jìn)行第二次backward會(huì)報(bào)錯(cuò),如果加上retain_graph==True后,可以再來(lái)一次backward。
retain_graph參數(shù)的作用
官方定義:
retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.
大意是如果設(shè)置為False,計(jì)算圖中的中間變量在計(jì)算完后就會(huì)被釋放。
但是在平時(shí)的使用中這個(gè)參數(shù)默認(rèn)都為False從而提高效率,和creat_graph的值一樣。
具體看一個(gè)例子理解
假設(shè)一個(gè)我們有一個(gè)輸入x,y = x **2, z = y*4,然后我們有兩個(gè)輸出,一個(gè)output_1 = z.mean(),另一個(gè)output_2 = z.sum()。
然后我們對(duì)兩個(gè)output執(zhí)行backward。
import torch x = torch.randn((1,4),dtype=torch.float32,requires_grad=True) y = x ** 2 z = y * 4 print(x) print(y) print(z) loss1 = z.mean() loss2 = z.sum() print(loss1,loss2) loss1.backward() ? ?# 這個(gè)代碼執(zhí)行正常,但是執(zhí)行完中間變量都free了,所以下一個(gè)出現(xiàn)了問(wèn)題 print(loss1,loss2) loss2.backward() ? ?# 這時(shí)會(huì)引發(fā)錯(cuò)誤
程序正常執(zhí)行到第12行,所有的變量正常保存。
但是在第13行報(bào)錯(cuò):
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
分析:計(jì)算節(jié)點(diǎn)數(shù)值保存了,但是計(jì)算圖x-y-z結(jié)構(gòu)被釋放了,而計(jì)算loss2的backward仍然試圖利用x-y-z的結(jié)構(gòu),因此會(huì)報(bào)錯(cuò)。
因此需要retain_graph參數(shù)為T(mén)rue去保留中間參數(shù)從而兩個(gè)loss的backward()不會(huì)相互影響。
正確的代碼應(yīng)當(dāng)把第11行以及之后改成
- 1 # 假如你需要執(zhí)行兩次backward,先執(zhí)行第一個(gè)的backward,再執(zhí)行第二個(gè)backward
- 2 loss1.backward(retain_graph=True)# 這里參數(shù)表明保留backward后的中間參數(shù)。
- 3 loss2.backward() # 執(zhí)行完這個(gè)后,所有中間變量都會(huì)被釋放,以便下一次的循環(huán)
- 4 #如果是在訓(xùn)練網(wǎng)絡(luò)optimizer.step() # 更新參數(shù)
create_graph參數(shù)比較簡(jiǎn)單,參考官方定義:
create_graph (bool, optional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to False.
Pytorch retain_graph=True錯(cuò)誤信息
(Pytorch:RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time)
具有多個(gè)loss值
retain_graph設(shè)置True,一般多用于兩次backward
# 假如有兩個(gè)Loss,先執(zhí)行第一個(gè)的backward,再執(zhí)行第二個(gè)backward loss1.backward(retain_graph=True) # 這樣計(jì)算圖就不會(huì)立即釋放 loss2.backward() # 執(zhí)行完這個(gè)后,所有中間變量都會(huì)被釋放,以便下一次的循環(huán) optimizer.step() # 更新參數(shù)
retain_graph設(shè)置True后一定要知道釋放,否則顯卡會(huì)占用越來(lái)越多,代碼速度也會(huì)跑的越來(lái)越慢。
有的時(shí)候我明明僅有一個(gè)模型的也會(huì)出現(xiàn)這種錯(cuò)誤
第一種是輸入的原因。
// Example x = torch.randn((100,1), requires_grad = True) y = 1 + 2 * x + 0.3 * torch.randn(100,1) x_train, y_train = x[:70], y[:70] x_val, y_val = x[70:], y[70:] for epoch in range(n_epochs): ?? ?... ?? ?prediction = model(x_train) ?? ?loss.backward() ?? ?...
在多次循環(huán)的過(guò)程中,input的梯度沒(méi)有清除,而且我們也不需要計(jì)算輸入的梯度,因此將x的require_grad設(shè)置為False就可以解決問(wèn)題。
第二種是我在訓(xùn)練LSTM時(shí)候發(fā)現(xiàn)的。
class LSTMpred(nn.Module): ? ? def __init__(self, input_size, hidden_dim): ? ? ?? ?self.hidden = self.init_hidden() ? ? ? ?... ? ? def init_hidden(self):?? ?#這里我們是需要個(gè)隱層參數(shù)的 ? ? ? ? return (torch.zeros(1, 1, self.hidden_dim, requires_grad=True), ? ? ? ? ? ? ? ? torch.zeros(1, 1, self.hidden_dim, requires_grad=True)) ? ? def forward(self, seq): ? ? ? ? ...
這里面的self.hidden我們?cè)诿恳淮斡?xùn)練的時(shí)候都要重新初始化隱層參數(shù):
for epoch in range(Epoch): ?? ?... ?? ?model.hidden = model.init_hidden() ?? ?modout = model(seq) ? ? ...
3. 我的看法
其實(shí),想想這幾種情況都是一回事,都是網(wǎng)絡(luò)在反向傳播中不允許多個(gè)backward(),也就是梯度下降反饋的時(shí)候,有多個(gè)循環(huán)過(guò)程中共用了同一個(gè)需要計(jì)算梯度的變量,在前一個(gè)循環(huán)清除梯度后,后面一個(gè)循環(huán)過(guò)程就會(huì)在這個(gè)變量上栽跟頭(個(gè)人想法)。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python定時(shí)采集攝像頭圖像上傳ftp服務(wù)器功能實(shí)現(xiàn)
本文程序?qū)崿F(xiàn)python定時(shí)采集攝像頭圖像上傳ftp服務(wù)器功能,大家參考使用吧2013-12-12Django shell調(diào)試models輸出的SQL語(yǔ)句方法
今天小編就為大家分享一篇Django shell調(diào)試models輸出的SQL語(yǔ)句方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-08-08Python-jenkins 獲取job構(gòu)建信息方式
這篇文章主要介紹了Python-jenkins 獲取job構(gòu)建信息方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-05-05詳解Python 2.6 升級(jí)至 Python 2.7 的實(shí)踐心得
本篇文章主要介紹了詳解Python 2.6 升級(jí)至 Python 2.7 的實(shí)踐心得,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2017-04-04利用python實(shí)現(xiàn)xml與數(shù)據(jù)庫(kù)讀取轉(zhuǎn)換的方法
這篇文章主要給大家介紹了關(guān)于利用python實(shí)現(xiàn)xml與數(shù)據(jù)庫(kù)讀取轉(zhuǎn)換的方法,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面來(lái)一起看看吧。2017-06-06python 讀txt文件,按‘,’分割每行數(shù)據(jù)操作
這篇文章主要介紹了python 讀txt文件,按‘,’分割每行數(shù)據(jù)操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-07-07