欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

在pytorch中對非葉節(jié)點(diǎn)的變量計(jì)算梯度實(shí)例

 更新時(shí)間:2020年01月10日 15:39:20   作者:FesianXu  
今天小編就為大家分享一篇在pytorch中對非葉節(jié)點(diǎn)的變量計(jì)算梯度實(shí)例,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧

在pytorch中一般只對葉節(jié)點(diǎn)進(jìn)行梯度計(jì)算,也就是下圖中的d,e節(jié)點(diǎn),而對非葉節(jié)點(diǎn),也即是c,b節(jié)點(diǎn)則沒有顯式地去保留其中間計(jì)算過程中的梯度(因?yàn)橐话銇碚f只有葉節(jié)點(diǎn)才需要去更新),這樣可以節(jié)省很大部分的顯存,但是在調(diào)試過程中,有時(shí)候我們需要對中間變量梯度進(jìn)行監(jiān)控,以確保網(wǎng)絡(luò)的有效性,這個(gè)時(shí)候我們需要打印出非葉節(jié)點(diǎn)的梯度,為了實(shí)現(xiàn)這個(gè)目的,我們可以通過兩種手段進(jìn)行。

注冊hook函數(shù)

Tensor.register_hook[2] 可以注冊一個(gè)反向梯度傳導(dǎo)時(shí)的hook函數(shù),這個(gè)hook函數(shù)將會在每次計(jì)算 關(guān)于該張量 的時(shí)候 被調(diào)用,經(jīng)常用于調(diào)試的時(shí)候打印出非葉節(jié)點(diǎn)梯度。當(dāng)然,通過這個(gè)手段,你也可以自定義某一層的梯度更新方法。[3] 具體到這里的打印非葉節(jié)點(diǎn)的梯度,代碼如:

def hook_y(grad):
 print(grad)

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
z = y * y * 3

y.register_hook(hook_y) 

out = z.mean()
out.backward()

輸出如:

tensor([[4.5000, 4.5000],
  [4.5000, 4.5000]])

retain_grad()

Tensor.retain_grad()顯式地保存非葉節(jié)點(diǎn)的梯度,當(dāng)然代價(jià)就是會增加顯存的消耗,而用hook函數(shù)的方法則是在反向計(jì)算時(shí)直接打印,因此不會增加顯存消耗,但是使用起來retain_grad()要比hook函數(shù)方便一些。代碼如:

x = Variable(torch.ones(2, 2), requires_grad=True)
y = x + 2
y.retain_grad()
z = y * y * 3
out = z.mean()
out.backward()
print(y.grad)

輸出如:

tensor([[4.5000, 4.5000],
  [4.5000, 4.5000]])

以上這篇在pytorch中對非葉節(jié)點(diǎn)的變量計(jì)算梯度實(shí)例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

最新評論