Pytorch 中retain_graph的用法詳解
用法分析
在查看SRGAN源碼時有如下?lián)p失函數(shù),其中設(shè)置了retain_graph=True,其作用是什么?
############################ # (1) Update D network: maximize D(x)-1-D(G(z)) ########################### real_img = Variable(target) if torch.cuda.is_available(): real_img = real_img.cuda() z = Variable(data) if torch.cuda.is_available(): z = z.cuda() fake_img = netG(z) netD.zero_grad() real_out = netD(real_img).mean() fake_out = netD(fake_img).mean() d_loss = 1 - real_out + fake_out d_loss.backward(retain_graph=True) ##### optimizerD.step() ############################ # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss ########################### netG.zero_grad() g_loss = generator_criterion(fake_out, fake_img, real_img) g_loss.backward() optimizerG.step() fake_img = netG(z) fake_out = netD(fake_img).mean() g_loss = generator_criterion(fake_out, fake_img, real_img) running_results['g_loss'] += g_loss.data[0] * batch_size d_loss = 1 - real_out + fake_out running_results['d_loss'] += d_loss.data[0] * batch_size running_results['d_score'] += real_out.data[0] * batch_size running_results['g_score'] += fake_out.data[0] * batch_size
在更新D網(wǎng)絡(luò)時的loss反向傳播過程中使用了retain_graph=True,目的為是為保留該過程中計算的梯度,后續(xù)G網(wǎng)絡(luò)更新時使用;
其實retain_graph這個參數(shù)在平常中我們是用不到的,但是在特殊的情況下我們會用到它,
如下代碼:
import torch y=x**2 z=y*4 output1=z.mean() output2=z.sum() output1.backward() output2.backward()
輸出如下錯誤信息:
--------------------------------------------------------------------------- RuntimeError Traceback (most recent call last) <ipython-input-19-8ad6b0658906> in <module>() ----> 1 output1.backward() 2 output2.backward() D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph) 91 products. Defaults to ``False``. 92 """ ---> 93 torch.autograd.backward(self, gradient, retain_graph, create_graph) 94 95 def register_hook(self, hook): D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables) 88 Variable._execution_engine.run_backward( 89 tensors, grad_tensors, retain_graph, create_graph, ---> 90 allow_unreachable=True) # allow_unreachable flag 91 92 RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
修改成如下正確:
import torch y=x**2 z=y*4 output1=z.mean() output2=z.sum() output1.backward(retain_graph=True) output2.backward()
# 假如你有兩個Loss,先執(zhí)行第一個的backward,再執(zhí)行第二個backward loss1.backward(retain_graph=True) loss2.backward() # 執(zhí)行完這個后,所有中間變量都會被釋放,以便下一次的循環(huán) optimizer.step() # 更新參數(shù)
Variable 類源代碼
class Variable(_C._VariableBase): """ Attributes: data: 任意類型的封裝好的張量。 grad: 保存與data類型和位置相匹配的梯度,此屬性難以分配并且不能重新分配。 requires_grad: 標(biāo)記變量是否已經(jīng)由一個需要調(diào)用到此變量的子圖創(chuàng)建的bool值。只能在葉子變量上進(jìn)行修改。 volatile: 標(biāo)記變量是否能在推理模式下應(yīng)用(如不保存歷史記錄)的bool值。只能在葉變量上更改。 is_leaf: 標(biāo)記變量是否是圖葉子(如由用戶創(chuàng)建的變量)的bool值. grad_fn: Gradient function graph trace. Parameters: data (any tensor class): 要包裝的張量. requires_grad (bool): bool型的標(biāo)記值. **Keyword only.** volatile (bool): bool型的標(biāo)記值. **Keyword only.** """ def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None): """計算關(guān)于當(dāng)前圖葉子變量的梯度,圖使用鏈?zhǔn)椒▌t導(dǎo)致分化 如果Variable是一個標(biāo)量(例如它包含一個單元素數(shù)據(jù)),你無需對backward()指定任何參數(shù) 如果變量不是標(biāo)量(包含多個元素數(shù)據(jù)的矢量)且需要梯度,函數(shù)需要額外的梯度; 需要指定一個和tensor的形狀匹配的grad_output參數(shù)(y在指定方向投影對x的導(dǎo)數(shù)); 可以是一個類型和位置相匹配且包含與自身相關(guān)的不同函數(shù)梯度的張量。 函數(shù)在葉子上累積梯度,調(diào)用前需要對該葉子進(jìn)行清零。 Arguments: grad_variables (Tensor, Variable or None): 變量的梯度,如果是一個張量,除非“create_graph”是True,否則會自動轉(zhuǎn)換成volatile型的變量。 可以為標(biāo)量變量或不需要grad的值指定None值。如果None值可接受,則此參數(shù)可選。 retain_graph (bool, optional): 如果為False,用來計算梯度的圖將被釋放。 在幾乎所有情況下,將此選項設(shè)置為True不是必需的,通??梢砸愿行У姆绞浇鉀Q。 默認(rèn)值為create_graph的值。 create_graph (bool, optional): 為True時,會構(gòu)造一個導(dǎo)數(shù)的圖,用來計算出更高階導(dǎo)數(shù)結(jié)果。 默認(rèn)為False,除非``gradient``是一個volatile變量。 """ torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables) def register_hook(self, hook): """Registers a backward hook. 每當(dāng)與variable相關(guān)的梯度被計算時調(diào)用hook,hook的申明:hook(grad)->Variable or None 不能對hook的參數(shù)進(jìn)行修改,但可以選擇性地返回一個新的梯度以用在`grad`的相應(yīng)位置。 函數(shù)返回一個handle,其``handle.remove()``方法用于將hook從模塊中移除。 Example: >>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True) >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient >>> v.backward(torch.Tensor([1, 1, 1])) >>> v.grad.data 2 2 2 [torch.FloatTensor of size 3] >>> h.remove() # removes the hook """ if self.volatile: raise RuntimeError("cannot register a hook on a volatile variable") if not self.requires_grad: raise RuntimeError("cannot register a hook on a variable that " "doesn't require gradient") if self._backward_hooks is None: self._backward_hooks = OrderedDict() if self.grad_fn is not None: self.grad_fn._register_hook_dict(self) handle = hooks.RemovableHandle(self._backward_hooks) self._backward_hooks[handle.id] = hook return handle def reinforce(self, reward): """Registers a reward obtained as a result of a stochastic process. 區(qū)分隨機節(jié)點需要為他們提供reward值。如果圖表中包含任何的隨機操作,都應(yīng)該在其輸出上調(diào)用此函數(shù),否則會出現(xiàn)錯誤。 Parameters: reward(Tensor): 帶有每個元素獎賞的張量,必須與Variable數(shù)據(jù)的設(shè)備位置和形狀相匹配。 """ if not isinstance(self.grad_fn, StochasticFunction): raise RuntimeError("reinforce() can be only called on outputs " "of stochastic functions") self.grad_fn._reinforce(reward) def detach(self): """返回一個從當(dāng)前圖分離出來的心變量。 結(jié)果不需要梯度,如果輸入是volatile,則輸出也是volatile。 .. 注意:: 返回變量使用與原始變量相同的數(shù)據(jù)張量,并且可以看到其中任何一個的就地修改,并且可能會觸發(fā)正確性檢查中的錯誤。 """ result = NoGrad()(self) # this is needed, because it merges version counters result._grad_fn = None return result def detach_(self): """從創(chuàng)建它的圖中分離出變量并作為該圖的一個葉子""" self._grad_fn = None self.requires_grad = False def retain_grad(self): """Enables .grad attribute for non-leaf Variables.""" if self.grad_fn is None: # no-op for leaves return if not self.requires_grad: raise RuntimeError("can't retain_grad on Variable that has requires_grad=False") if hasattr(self, 'retains_grad'): return weak_self = weakref.ref(self) def retain_grad_hook(grad): var = weak_self() if var is None: return if var._grad is None: var._grad = grad.clone() else: var._grad = var._grad + grad self.register_hook(retain_grad_hook) self.retains_grad = True
以上這篇Pytorch 中retain_graph的用法詳解就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python數(shù)據(jù)分析之獲取雙色球歷史信息的方法示例
這篇文章主要介紹了Python數(shù)據(jù)分析之獲取雙色球歷史信息的方法,涉及Python網(wǎng)頁抓取、正則匹配、文件讀寫及數(shù)值運算等相關(guān)操作技巧,需要的朋友可以參考下2018-02-02好的Python培訓(xùn)機構(gòu)應(yīng)該具備哪些條件
python是現(xiàn)在開發(fā)的熱潮,大家應(yīng)該如何學(xué)習(xí)呢?許多人選擇自學(xué),還有人會選擇去培訓(xùn)結(jié)構(gòu)學(xué)習(xí),那么好的培訓(xùn)機構(gòu)的標(biāo)準(zhǔn)是什么樣的呢?下面跟隨腳本之家小編一起通過本文學(xué)習(xí)吧2018-05-05Python XML轉(zhuǎn)Json之XML2Dict的使用方法
今天小編就為大家分享一篇Python XML轉(zhuǎn)Json之XML2Dict的使用方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-01-01