欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

在pytorch中對非葉節(jié)點的變量計算梯度實例

 更新時間:2020年01月10日 15:39:20   作者:FesianXu  
今天小編就為大家分享一篇在pytorch中對非葉節(jié)點的變量計算梯度實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

在pytorch中一般只對葉節(jié)點進(jìn)行梯度計算,也就是下圖中的d,e節(jié)點,而對非葉節(jié)點,也即是c,b節(jié)點則沒有顯式地去保留其中間計算過程中的梯度(因為一般來說只有葉節(jié)點才需要去更新),這樣可以節(jié)省很大部分的顯存,但是在調(diào)試過程中,有時候我們需要對中間變量梯度進(jìn)行監(jiān)控,以確保網(wǎng)絡(luò)的有效性,這個時候我們需要打印出非葉節(jié)點的梯度,為了實現(xiàn)這個目的,我們可以通過兩種手段進(jìn)行。

注冊hook函數(shù)

Tensor.register_hook[2] 可以注冊一個反向梯度傳導(dǎo)時的hook函數(shù),這個hook函數(shù)將會在每次計算 關(guān)于該張量 的時候 被調(diào)用,經(jīng)常用于調(diào)試的時候打印出非葉節(jié)點梯度。當(dāng)然,通過這個手段,你也可以自定義某一層的梯度更新方法。[3] 具體到這里的打印非葉節(jié)點的梯度,代碼如:

def hook_y(grad):
 print(grad)

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
z = y * y * 3

y.register_hook(hook_y) 

out = z.mean()
out.backward()

輸出如:

tensor([[4.5000, 4.5000],
  [4.5000, 4.5000]])

retain_grad()

Tensor.retain_grad()顯式地保存非葉節(jié)點的梯度,當(dāng)然代價就是會增加顯存的消耗,而用hook函數(shù)的方法則是在反向計算時直接打印,因此不會增加顯存消耗,但是使用起來retain_grad()要比hook函數(shù)方便一些。代碼如:

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
y.retain_grad()
z = y * y * 3
out = z.mean()
out.backward()
print(y.grad)

輸出如:

tensor([[4.5000, 4.5000],
  [4.5000, 4.5000]])

以上這篇在pytorch中對非葉節(jié)點的變量計算梯度實例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

最新評論