pytorch之pytorch?hook和關(guān)于pytorch?backward過程問題
pytorch 的 hook 機(jī)制
在看 pytorch
官方文檔的時(shí)候,發(fā)現(xiàn)在 nn.Module
部分和 Variable
部分均有 hook
的身影。
感到很神奇,因?yàn)樵谑褂?tensorflow
的時(shí)候沒有碰到過這個(gè)詞。所以打算一探究竟。
Variable 的 hook
register_hook(hook)
注冊(cè)一個(gè) backward
鉤子。
每次 gradients
被計(jì)算的時(shí)候,這個(gè) hook
都被調(diào)用。 hook
應(yīng)該擁有以下簽名:
hook(grad) -> Variable or None
hook
不應(yīng)該修改它的輸入,但是它可以返回一個(gè)替代當(dāng)前梯度的新梯度。
這個(gè)函數(shù)返回一個(gè) 句柄( handle
)。它有一個(gè)方法 handle.remove()
,可以用這個(gè)方法將 hook
從 module
移除。
例子:
import torch v = torch.tensor([0, 0, 0], requires_grad=True, dtype=torch.float32) h = v.register_hook(lambda grad: grad * 2) # double the gradient v.backward(torch.tensor([1, 1, 1], dtype=torch.float32)) # 先計(jì)算原始梯度,再進(jìn)hook,獲得一個(gè)新梯度。 print(v.grad.data) h.remove() # removes the hook
tensor([2., 2., 2.])
nn.Module的hook
register_forward_hook(hook)
在 module
上注冊(cè)一個(gè) forward hook
。
這里要注意的是,hook 只能注冊(cè)到 Module 上,即,僅僅是簡單的 op
包裝的 Module,而不是我們繼承 Module時(shí)寫的那個(gè)類,我們繼承 Module寫的類叫做 Container。
每次調(diào)用 forward()
計(jì)算輸出的時(shí)候,這個(gè) hook
就會(huì)被調(diào)用。
它應(yīng)該擁有以下簽名:
hook(module, input, output) -> None
hook
不應(yīng)該修改 input
和 output
的值。 這個(gè)函數(shù)返回一個(gè) 句柄( handle
)。它有一個(gè)方法 handle.remove()
,可以用這個(gè)方法將 hook
從 module
移除。
看這個(gè)解釋可能有點(diǎn)蒙逼,但是如果要看一下 nn.Module
的源碼怎么使用 hook
的話,那就烏云盡散了。
先看 register_forward_hook
def register_forward_hook(self, hook): handle = hooks.RemovableHandle(self._forward_hooks) self._forward_hooks[handle.id] = hook return handle
這個(gè)方法的作用是在此 module
上注冊(cè)一個(gè) hook
,函數(shù)中第一句就沒必要在意了,主要看第二句,是把注冊(cè)的 hook
保存在 _forward_hooks
字典里。
再看 nn.Module
的 __call__
方法(被閹割了,只留下需要關(guān)注的部分):
def __call__(self, *input, **kwargs): result = self.forward(*input, **kwargs) for hook in self._forward_hooks.values(): #將注冊(cè)的hook拿出來用 hook_result = hook(self, input, result) ... return result
可以看到,當(dāng)我們執(zhí)行 model(x)
的時(shí)候,底層干了以下幾件事:
- 調(diào)用
forward
方法計(jì)算結(jié)果 - 判斷有沒有注冊(cè)
forward_hook
,有的話,就將forward
的輸入及結(jié)果作為hook
的實(shí)參。然后讓hook
自己干一些不可告人的事情。
看到這,我們就明白 hook
簽名的意思了,還有為什么 hook
不能修改 input
的 output
的原因。
小例子:
import torch from torch import nn import torch.functional as F from torch.autograd import Variable def for_hook(module, input, output): print(module) for val in input: print("input val:",val) for out_val in output: print("output val:", out_val) class Model(nn.Module): def __init__(self): super(Model, self).__init__() def forward(self, x): return x+1 model = Model() x = Variable(torch.FloatTensor([1]), requires_grad=True) handle = model.register_forward_hook(for_hook) print(model(x)) handle.remove()
register_backward_hook
在 module
上注冊(cè)一個(gè) bachward hook
。此方法目前只能用在 Module
上,不能用在 Container
上,當(dāng) Module
的forward函數(shù)中只有一個(gè) Function
的時(shí)候,稱為 Module
,如果 Module
包含其它 Module
,稱之為 Container
每次計(jì)算 module
的 inputs
的梯度的時(shí)候,這個(gè) hook
會(huì)被調(diào)用。 hook
應(yīng)該擁有下面的 signature
。
hook(module, grad_input, grad_output) -> Tensor or None
如果 module
有多個(gè)輸入輸出的話,那么 grad_input
grad_output
將會(huì)是個(gè) tuple
。 hook
不應(yīng)該修改它的 arguments
,但是它可以選擇性的返回關(guān)于輸入的梯度,這個(gè)返回的梯度在后續(xù)的計(jì)算中會(huì)替代 grad_input
。
這個(gè)函數(shù)返回一個(gè) 句柄( handle
)。它有一個(gè)方法 handle.remove()
,可以用這個(gè)方法將 hook
從 module
移除。
從上邊描述來看, backward hook
似乎可以幫助我們處理一下計(jì)算完的梯度??聪旅?nn.Module
中 register_backward_hook
方法的實(shí)現(xiàn),和 register_forward_hook
方法的實(shí)現(xiàn)幾乎一樣,都是用字典把注冊(cè)的 hook
保存起來。
def register_backward_hook(self, hook): handle = hooks.RemovableHandle(self._backward_hooks) self._backward_hooks[handle.id] = hook return handle
先看個(gè)例子來看一下 hook
的參數(shù)代表了什么:
import torch from torch.autograd import Variable from torch.nn import Parameter import torch.nn as nn import math def bh(m,gi,go): print("Grad Input") print(gi) print("Grad Output") print(go) return gi[0]*0,gi[1]*0 class Linear(nn.Module): def __init__(self, in_features, out_features, bias=True): super(Linear, self).__init__() self.in_features = in_features self.out_features = out_features self.weight = Parameter(torch.Tensor(out_features, in_features)) if bias: self.bias = Parameter(torch.Tensor(out_features)) else: self.register_parameter('bias', None) self.reset_parameters() def reset_parameters(self): stdv = 1. / math.sqrt(self.weight.size(1)) self.weight.data.uniform_(-stdv, stdv) if self.bias is not None: self.bias.data.uniform_(-stdv, stdv) def forward(self, input): if self.bias is None: return self._backend.Linear()(input, self.weight) else: return self._backend.Linear()(input, self.weight, self.bias) x=Variable(torch.FloatTensor([[1, 2, 3]]),requires_grad=True) mod=Linear(3, 1, bias=False) mod.register_backward_hook(bh) # 在這里給module注冊(cè)了backward hook out=mod(x) out.register_hook(lambda grad: 0.1*grad) #在這里給variable注冊(cè)了 hook out.backward() print(['*']*20) print("x.grad", x.grad) print(mod.weight.grad)
Grad Input (Variable containing: 1.00000e-02 * 5.1902 -2.3778 -4.4071 [torch.FloatTensor of size 1x3] , Variable containing: 0.1000 0.2000 0.3000 [torch.FloatTensor of size 1x3] ) Grad Output (Variable containing: 0.1000 [torch.FloatTensor of size 1x1] ,) ['*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*'] x.grad Variable containing: 0 -0 -0 [torch.FloatTensor of size 1x3] Variable containing: 0 0 0 [torch.FloatTensor of size 1x3]
可以看出, grad_in
保存的是,此模塊 Function
方法的輸入的值的梯度。 grad_out
保存的是,此模塊 forward
方法返回值的梯度。我們不能在 grad_in
上直接修改,但是我們可以返回一個(gè)新的 new_grad_in
作為 Function
方法 inputs
的梯度。
上述代碼對(duì) variable
和 module
同時(shí)注冊(cè)了 backward hook
,這里要注意的是,無論是 module hook
還是 variable hook
,最終還是注冊(cè)到 Function
上的。這點(diǎn)通過查看 Varible
的 register_hook
源碼和 Module
的 __call__
源碼得知。
Module的register_backward_hook的行為在未來的幾個(gè)版本可能會(huì)改變
BP過程中 Function
中的動(dòng)作可能是這樣的
class Function: def __init__(self): ... def forward(self, inputs): ... return outputs def backward(self, grad_outs): ... return grad_ins def _backward(self, grad_outs): hooked_grad_outs = grad_outs for hook in hook_in_outputs: hooked_grad_outs = hook(hooked_grad_outs) grad_ins = self.backward(hooked_grad_outs) hooked_grad_ins = grad_ins for hook in hooks_in_module: hooked_grad_ins = hook(hooked_grad_ins) return hooked_grad_ins
關(guān)于 pytorch run_backward()
的可能實(shí)現(xiàn)猜測為。
def run_backward(variable, gradient): creator = variable.creator if creator is None: variable.grad = variable.hook(gradient) return grad_ins = creator._backward(gradient) vars = creator.saved_variables for var, grad in zip(vars, grad_ins): run_backward(var, var.grad)
中間Variable的梯度在BP的過程中是保存到GradBuffer中的(C++源碼中可以看到), BP完會(huì)釋放. 如果retain_grads=True的話,就不會(huì)被釋放
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python3.6連接Oracle數(shù)據(jù)庫的方法詳解
這篇文章主要介紹了Python3.6連接Oracle數(shù)據(jù)庫的方法,較為詳細(xì)的分析了cx_Oracle模塊安裝及Python3.6使用cx_Oracle模塊操作Oracle數(shù)據(jù)庫的具體操作步驟與相關(guān)注意事項(xiàng),需要的朋友可以參考下2018-05-05Python數(shù)據(jù)類型之Number數(shù)字操作實(shí)例詳解
這篇文章主要介紹了Python數(shù)據(jù)類型之Number數(shù)字操作,結(jié)合實(shí)例形式詳細(xì)分析了Python數(shù)字類型的概念、功能、分類及常用數(shù)學(xué)函數(shù)相關(guān)使用技巧,需要的朋友可以參考下2019-05-05利用Python實(shí)現(xiàn)手機(jī)短信監(jiān)控通知的方法
今天小編就為大家分享一篇利用Python實(shí)現(xiàn)手機(jī)短信監(jiān)控通知的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-07-07點(diǎn)云地面點(diǎn)濾波(Cloth Simulation Filter, CSF)
這篇文章主要介紹了點(diǎn)云地面點(diǎn)濾波(Cloth Simulation Filter, CSF)“布料”濾波算法介紹,本文從基本思想到實(shí)現(xiàn)思路一步步給大家講解的非常詳細(xì),需要的朋友可以參考下2021-08-08詳解基于django實(shí)現(xiàn)的webssh簡單例子
這篇文章主要介紹了基于 django 實(shí)現(xiàn)的 webssh 簡單例子,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2018-07-07運(yùn)用TensorFlow進(jìn)行簡單實(shí)現(xiàn)線性回歸、梯度下降示例
這篇文章主要介紹了運(yùn)用TensorFlow進(jìn)行簡單實(shí)現(xiàn)線性回歸、梯度下降示例,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-03-03