欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch 中retain_graph的用法詳解

 更新時間:2020年01月07日 14:41:39   作者:DaneAI  
今天小編就為大家分享一篇Pytorch 中retain_graph的用法詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

用法分析

在查看SRGAN源碼時有如下?lián)p失函數(shù),其中設(shè)置了retain_graph=True,其作用是什么?

		############################
    # (1) Update D network: maximize D(x)-1-D(G(z))
    ###########################
    real_img = Variable(target)
    if torch.cuda.is_available():
      real_img = real_img.cuda()
    z = Variable(data)
    if torch.cuda.is_available():
      z = z.cuda()
    fake_img = netG(z)

    netD.zero_grad()
    real_out = netD(real_img).mean()
    fake_out = netD(fake_img).mean()
    d_loss = 1 - real_out + fake_out
    d_loss.backward(retain_graph=True) #####
    optimizerD.step()

    ############################
    # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
    ###########################
    netG.zero_grad()
    g_loss = generator_criterion(fake_out, fake_img, real_img)
    g_loss.backward()
    optimizerG.step()
    fake_img = netG(z)
    fake_out = netD(fake_img).mean()

    g_loss = generator_criterion(fake_out, fake_img, real_img)
    running_results['g_loss'] += g_loss.data[0] * batch_size
    d_loss = 1 - real_out + fake_out
    running_results['d_loss'] += d_loss.data[0] * batch_size
    running_results['d_score'] += real_out.data[0] * batch_size
    running_results['g_score'] += fake_out.data[0] * batch_size

在更新D網(wǎng)絡(luò)時的loss反向傳播過程中使用了retain_graph=True,目的為是為保留該過程中計算的梯度,后續(xù)G網(wǎng)絡(luò)更新時使用;

其實retain_graph這個參數(shù)在平常中我們是用不到的,但是在特殊的情況下我們會用到它,

如下代碼:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward()
output2.backward()

輸出如下錯誤信息:

---------------------------------------------------------------------------
RuntimeError               Traceback (most recent call last)
<ipython-input-19-8ad6b0658906> in <module>()
----> 1 output1.backward()
   2 output2.backward()

D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)
   91         products. Defaults to ``False``.
   92     """
---> 93     torch.autograd.backward(self, gradient, retain_graph, create_graph)
   94 
   95   def register_hook(self, hook):

D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
   88   Variable._execution_engine.run_backward(
   89     tensors, grad_tensors, retain_graph, create_graph,
---> 90     allow_unreachable=True) # allow_unreachable flag
   91 
   92 

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

修改成如下正確:

import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward(retain_graph=True)
output2.backward()
# 假如你有兩個Loss,先執(zhí)行第一個的backward,再執(zhí)行第二個backward
loss1.backward(retain_graph=True)
loss2.backward() # 執(zhí)行完這個后,所有中間變量都會被釋放,以便下一次的循環(huán)
optimizer.step() # 更新參數(shù)

Variable 類源代碼

class Variable(_C._VariableBase):
 
  """
  Attributes:
    data: 任意類型的封裝好的張量。
    grad: 保存與data類型和位置相匹配的梯度,此屬性難以分配并且不能重新分配。
    requires_grad: 標(biāo)記變量是否已經(jīng)由一個需要調(diào)用到此變量的子圖創(chuàng)建的bool值。只能在葉子變量上進(jìn)行修改。
    volatile: 標(biāo)記變量是否能在推理模式下應(yīng)用(如不保存歷史記錄)的bool值。只能在葉變量上更改。
    is_leaf: 標(biāo)記變量是否是圖葉子(如由用戶創(chuàng)建的變量)的bool值.
    grad_fn: Gradient function graph trace.
 
  Parameters:
    data (any tensor class): 要包裝的張量.
    requires_grad (bool): bool型的標(biāo)記值. **Keyword only.**
    volatile (bool): bool型的標(biāo)記值. **Keyword only.**
  """
 
  def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None):
    """計算關(guān)于當(dāng)前圖葉子變量的梯度,圖使用鏈?zhǔn)椒▌t導(dǎo)致分化
    如果Variable是一個標(biāo)量(例如它包含一個單元素數(shù)據(jù)),你無需對backward()指定任何參數(shù)
    如果變量不是標(biāo)量(包含多個元素數(shù)據(jù)的矢量)且需要梯度,函數(shù)需要額外的梯度;
    需要指定一個和tensor的形狀匹配的grad_output參數(shù)(y在指定方向投影對x的導(dǎo)數(shù));
    可以是一個類型和位置相匹配且包含與自身相關(guān)的不同函數(shù)梯度的張量。
    函數(shù)在葉子上累積梯度,調(diào)用前需要對該葉子進(jìn)行清零。
 
    Arguments:
      grad_variables (Tensor, Variable or None):
              變量的梯度,如果是一個張量,除非“create_graph”是True,否則會自動轉(zhuǎn)換成volatile型的變量。
              可以為標(biāo)量變量或不需要grad的值指定None值。如果None值可接受,則此參數(shù)可選。
      retain_graph (bool, optional): 如果為False,用來計算梯度的圖將被釋放。
                      在幾乎所有情況下,將此選項設(shè)置為True不是必需的,通??梢砸愿行У姆绞浇鉀Q。
                      默認(rèn)值為create_graph的值。
      create_graph (bool, optional): 為True時,會構(gòu)造一個導(dǎo)數(shù)的圖,用來計算出更高階導(dǎo)數(shù)結(jié)果。
                      默認(rèn)為False,除非``gradient``是一個volatile變量。
    """
    torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
 
 
  def register_hook(self, hook):
    """Registers a backward hook.
 
    每當(dāng)與variable相關(guān)的梯度被計算時調(diào)用hook,hook的申明:hook(grad)->Variable or None
    不能對hook的參數(shù)進(jìn)行修改,但可以選擇性地返回一個新的梯度以用在`grad`的相應(yīng)位置。
 
    函數(shù)返回一個handle,其``handle.remove()``方法用于將hook從模塊中移除。
 
    Example:
      >>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
      >>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
      >>> v.backward(torch.Tensor([1, 1, 1]))
      >>> v.grad.data
       2
       2
       2
      [torch.FloatTensor of size 3]
      >>> h.remove() # removes the hook
    """
    if self.volatile:
      raise RuntimeError("cannot register a hook on a volatile variable")
    if not self.requires_grad:
      raise RuntimeError("cannot register a hook on a variable that "
                "doesn't require gradient")
    if self._backward_hooks is None:
      self._backward_hooks = OrderedDict()
      if self.grad_fn is not None:
        self.grad_fn._register_hook_dict(self)
    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle
 
  def reinforce(self, reward):
    """Registers a reward obtained as a result of a stochastic process.
    區(qū)分隨機節(jié)點需要為他們提供reward值。如果圖表中包含任何的隨機操作,都應(yīng)該在其輸出上調(diào)用此函數(shù),否則會出現(xiàn)錯誤。
    Parameters:
      reward(Tensor): 帶有每個元素獎賞的張量,必須與Variable數(shù)據(jù)的設(shè)備位置和形狀相匹配。
    """
    if not isinstance(self.grad_fn, StochasticFunction):
      raise RuntimeError("reinforce() can be only called on outputs "
                "of stochastic functions")
    self.grad_fn._reinforce(reward)
 
  def detach(self):
    """返回一個從當(dāng)前圖分離出來的心變量。
    結(jié)果不需要梯度,如果輸入是volatile,則輸出也是volatile。
 
    .. 注意::
     返回變量使用與原始變量相同的數(shù)據(jù)張量,并且可以看到其中任何一個的就地修改,并且可能會觸發(fā)正確性檢查中的錯誤。
    """
    result = NoGrad()(self) # this is needed, because it merges version counters
    result._grad_fn = None
    return result
 
  def detach_(self):
    """從創(chuàng)建它的圖中分離出變量并作為該圖的一個葉子"""
    self._grad_fn = None
    self.requires_grad = False
 
  def retain_grad(self):
    """Enables .grad attribute for non-leaf Variables."""
    if self.grad_fn is None: # no-op for leaves
      return
    if not self.requires_grad:
      raise RuntimeError("can't retain_grad on Variable that has requires_grad=False")
    if hasattr(self, 'retains_grad'):
      return
    weak_self = weakref.ref(self)
 
    def retain_grad_hook(grad):
      var = weak_self()
      if var is None:
        return
      if var._grad is None:
        var._grad = grad.clone()
      else:
        var._grad = var._grad + grad
 
    self.register_hook(retain_grad_hook)
    self.retains_grad = True

以上這篇Pytorch 中retain_graph的用法詳解就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Pyecharts?繪制3種常用的圖形

    Pyecharts?繪制3種常用的圖形

    這篇文章主要介紹了Pyecharts?繪制3種常用的圖形,上下組合圖、左右組合圖、一軸多圖,下文繪制過程幾介紹,需要的小伙伴可以參考一下
    2022-02-02
  • python實現(xiàn)隱馬爾科夫模型HMM

    python實現(xiàn)隱馬爾科夫模型HMM

    這篇文章主要為大家詳細(xì)介紹了python實現(xiàn)隱馬爾科夫模型HMM,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2018-03-03
  • Python數(shù)據(jù)分析之獲取雙色球歷史信息的方法示例

    Python數(shù)據(jù)分析之獲取雙色球歷史信息的方法示例

    這篇文章主要介紹了Python數(shù)據(jù)分析之獲取雙色球歷史信息的方法,涉及Python網(wǎng)頁抓取、正則匹配、文件讀寫及數(shù)值運算等相關(guān)操作技巧,需要的朋友可以參考下
    2018-02-02
  • 好的Python培訓(xùn)機構(gòu)應(yīng)該具備哪些條件

    好的Python培訓(xùn)機構(gòu)應(yīng)該具備哪些條件

    python是現(xiàn)在開發(fā)的熱潮,大家應(yīng)該如何學(xué)習(xí)呢?許多人選擇自學(xué),還有人會選擇去培訓(xùn)結(jié)構(gòu)學(xué)習(xí),那么好的培訓(xùn)機構(gòu)的標(biāo)準(zhǔn)是什么樣的呢?下面跟隨腳本之家小編一起通過本文學(xué)習(xí)吧
    2018-05-05
  • Pygame Rect區(qū)域位置的使用(圖文)

    Pygame Rect區(qū)域位置的使用(圖文)

    在 Pygame 中我們使用 Rect() 方法來創(chuàng)建一個指定位置,大小的矩形區(qū)域。本文主要就來介紹一下如何使用,具有一定的參考價值,感興趣的可以了解一下
    2021-11-11
  • Python中尋找數(shù)據(jù)異常值的3種方法

    Python中尋找數(shù)據(jù)異常值的3種方法

    這篇文章主要介紹了Python中尋找數(shù)據(jù)異常值的3種方法,文章圍繞主題展開詳細(xì)的內(nèi)容介紹,具有一定的參考價值,需要的小伙伴可以參考一下
    2022-08-08
  • PyTorch使用GPU加速計算的實現(xiàn)

    PyTorch使用GPU加速計算的實現(xiàn)

    PyTorch利用NVIDIA CUDA庫提供的底層接口來實現(xiàn)GPU加速計算,本文就來介紹一下PyTorch使用GPU加速計算的實現(xiàn),具有一定的參考價值,感興趣的可以了解一下
    2024-02-02
  • python實現(xiàn)雨滴下落到地面效果

    python實現(xiàn)雨滴下落到地面效果

    這篇文章主要為大家詳細(xì)介紹了python實現(xiàn)雨滴下落到地面效果,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2018-06-06
  • python爬蟲爬取某網(wǎng)站視頻的示例代碼

    python爬蟲爬取某網(wǎng)站視頻的示例代碼

    這篇文章主要介紹了python爬蟲爬取某網(wǎng)站視頻的示例代碼,代碼簡單易懂,對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2021-02-02
  • Python XML轉(zhuǎn)Json之XML2Dict的使用方法

    Python XML轉(zhuǎn)Json之XML2Dict的使用方法

    今天小編就為大家分享一篇Python XML轉(zhuǎn)Json之XML2Dict的使用方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-01-01

最新評論