欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch之pytorch?hook和關(guān)于pytorch?backward過程問題

 更新時(shí)間:2023年09月08日 09:40:24   作者:u012436149  
這篇文章主要介紹了pytorch之pytorch?hook和關(guān)于pytorch?backward過程問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

pytorch 的 hook 機(jī)制

在看 pytorch 官方文檔的時(shí)候,發(fā)現(xiàn)在 nn.Module 部分和 Variable 部分均有 hook 的身影。

感到很神奇,因?yàn)樵谑褂?tensorflow 的時(shí)候沒有碰到過這個(gè)詞。所以打算一探究竟。

Variable 的 hook

register_hook(hook)

注冊(cè)一個(gè) backward 鉤子。

每次 gradients 被計(jì)算的時(shí)候,這個(gè) hook 都被調(diào)用。 hook 應(yīng)該擁有以下簽名:

hook(grad) -> Variable or None

hook 不應(yīng)該修改它的輸入,但是它可以返回一個(gè)替代當(dāng)前梯度的新梯度。

這個(gè)函數(shù)返回一個(gè) 句柄( handle )。它有一個(gè)方法 handle.remove() ,可以用這個(gè)方法將 hook module 移除。

例子:

import torch
v = torch.tensor([0, 0, 0], requires_grad=True, dtype=torch.float32)
h = v.register_hook(lambda grad: grad * 2)  # double the gradient
v.backward(torch.tensor([1, 1, 1], dtype=torch.float32))
# 先計(jì)算原始梯度,再進(jìn)hook,獲得一個(gè)新梯度。
print(v.grad.data)
h.remove()  # removes the hook
tensor([2., 2., 2.])

nn.Module的hook

register_forward_hook(hook)

module 上注冊(cè)一個(gè) forward hook 。

這里要注意的是,hook 只能注冊(cè)到 Module 上,即,僅僅是簡單的 op 包裝的 Module,而不是我們繼承 Module時(shí)寫的那個(gè)類,我們繼承 Module寫的類叫做 Container。

每次調(diào)用 forward() 計(jì)算輸出的時(shí)候,這個(gè) hook 就會(huì)被調(diào)用。

它應(yīng)該擁有以下簽名:

hook(module, input, output) -> None

hook 不應(yīng)該修改 input output 的值。 這個(gè)函數(shù)返回一個(gè) 句柄( handle )。它有一個(gè)方法 handle.remove() ,可以用這個(gè)方法將 hook module 移除。

看這個(gè)解釋可能有點(diǎn)蒙逼,但是如果要看一下 nn.Module 的源碼怎么使用 hook 的話,那就烏云盡散了。

先看 register_forward_hook

def register_forward_hook(self, hook):
       handle = hooks.RemovableHandle(self._forward_hooks)
       self._forward_hooks[handle.id] = hook
       return handle

這個(gè)方法的作用是在此 module 上注冊(cè)一個(gè) hook ,函數(shù)中第一句就沒必要在意了,主要看第二句,是把注冊(cè)的 hook 保存在 _forward_hooks 字典里。

再看 nn.Module __call__ 方法(被閹割了,只留下需要關(guān)注的部分):

def __call__(self, *input, **kwargs):
   result = self.forward(*input, **kwargs)
   for hook in self._forward_hooks.values():
       #將注冊(cè)的hook拿出來用
       hook_result = hook(self, input, result)
   ...
   return result

可以看到,當(dāng)我們執(zhí)行 model(x) 的時(shí)候,底層干了以下幾件事:

  • 調(diào)用 forward 方法計(jì)算結(jié)果
  • 判斷有沒有注冊(cè) forward_hook ,有的話,就將 forward 的輸入及結(jié)果作為 hook 的實(shí)參。然后讓 hook 自己干一些不可告人的事情。

看到這,我們就明白 hook 簽名的意思了,還有為什么 hook 不能修改 input output 的原因。

小例子:

import torch
from torch import nn
import torch.functional as F
from torch.autograd import Variable
def for_hook(module, input, output):
    print(module)
    for val in input:
        print("input val:",val)
    for out_val in output:
        print("output val:", out_val)
class Model(nn.Module):
    def __init__(self):
        super(Model, self).__init__()
    def forward(self, x):
        return x+1
model = Model()
x = Variable(torch.FloatTensor([1]), requires_grad=True)
handle = model.register_forward_hook(for_hook)
print(model(x))
handle.remove()

register_backward_hook

module 上注冊(cè)一個(gè) bachward hook 。此方法目前只能用在 Module 上,不能用在 Container 上,當(dāng) Module 的forward函數(shù)中只有一個(gè) Function 的時(shí)候,稱為 Module ,如果 Module 包含其它 Module ,稱之為 Container

每次計(jì)算 module inputs 的梯度的時(shí)候,這個(gè) hook 會(huì)被調(diào)用。 hook 應(yīng)該擁有下面的 signature 。

hook(module, grad_input, grad_output) -> Tensor or None

如果 module 有多個(gè)輸入輸出的話,那么 grad_input grad_output 將會(huì)是個(gè) tuple 。 hook 不應(yīng)該修改它的 arguments ,但是它可以選擇性的返回關(guān)于輸入的梯度,這個(gè)返回的梯度在后續(xù)的計(jì)算中會(huì)替代 grad_input

這個(gè)函數(shù)返回一個(gè) 句柄( handle )。它有一個(gè)方法 handle.remove() ,可以用這個(gè)方法將 hook module 移除。

從上邊描述來看, backward hook 似乎可以幫助我們處理一下計(jì)算完的梯度??聪旅?nn.Module register_backward_hook 方法的實(shí)現(xiàn),和 register_forward_hook 方法的實(shí)現(xiàn)幾乎一樣,都是用字典把注冊(cè)的 hook 保存起來。

def register_backward_hook(self, hook):
    handle = hooks.RemovableHandle(self._backward_hooks)
    self._backward_hooks[handle.id] = hook
    return handle

先看個(gè)例子來看一下 hook 的參數(shù)代表了什么:

import torch
from torch.autograd import Variable
from torch.nn import Parameter
import torch.nn as nn
import math
def bh(m,gi,go):
    print("Grad Input")
    print(gi)
    print("Grad Output")
    print(go)
    return gi[0]*0,gi[1]*0
class Linear(nn.Module):
    def __init__(self, in_features, out_features, bias=True):
        super(Linear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()
    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)
    def forward(self, input):
        if self.bias is None:
            return self._backend.Linear()(input, self.weight)
        else:
            return self._backend.Linear()(input, self.weight, self.bias)
x=Variable(torch.FloatTensor([[1, 2, 3]]),requires_grad=True)
mod=Linear(3, 1, bias=False)
mod.register_backward_hook(bh) # 在這里給module注冊(cè)了backward hook
out=mod(x)
out.register_hook(lambda grad: 0.1*grad) #在這里給variable注冊(cè)了 hook
out.backward()
print(['*']*20)
print("x.grad", x.grad)
print(mod.weight.grad)
Grad Input
(Variable containing:
1.00000e-02 *
  5.1902 -2.3778 -4.4071
[torch.FloatTensor of size 1x3]
, Variable containing:
 0.1000  0.2000  0.3000
[torch.FloatTensor of size 1x3]
)
Grad Output
(Variable containing:
 0.1000
[torch.FloatTensor of size 1x1]
,)
['*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*', '*']
x.grad Variable containing:
 0 -0 -0
[torch.FloatTensor of size 1x3]
Variable containing:
 0  0  0
[torch.FloatTensor of size 1x3]

可以看出, grad_in 保存的是,此模塊 Function 方法的輸入的值的梯度。 grad_out 保存的是,此模塊 forward 方法返回值的梯度。我們不能在 grad_in 上直接修改,但是我們可以返回一個(gè)新的 new_grad_in 作為 Function 方法 inputs 的梯度。

上述代碼對(duì) variable module 同時(shí)注冊(cè)了 backward hook ,這里要注意的是,無論是 module hook 還是 variable hook ,最終還是注冊(cè)到 Function 上的。這點(diǎn)通過查看 Varible register_hook 源碼和 Module __call__ 源碼得知。

Module的register_backward_hook的行為在未來的幾個(gè)版本可能會(huì)改變

BP過程中 Function 中的動(dòng)作可能是這樣的

class Function:
	def __init__(self):
		...
	def forward(self, inputs):
		...
		return outputs
	def backward(self, grad_outs):
		...
		return grad_ins
	def _backward(self, grad_outs):
		hooked_grad_outs = grad_outs
		for hook in hook_in_outputs:
			hooked_grad_outs = hook(hooked_grad_outs)
		grad_ins = self.backward(hooked_grad_outs)
		hooked_grad_ins = grad_ins
		for hook in hooks_in_module:
			hooked_grad_ins = hook(hooked_grad_ins)
		return hooked_grad_ins

關(guān)于 pytorch run_backward() 的可能實(shí)現(xiàn)猜測為。

def run_backward(variable, gradient):
	creator = variable.creator
	if creator is None:
		variable.grad = variable.hook(gradient)
		return 
	grad_ins = creator._backward(gradient)
	vars = creator.saved_variables
	for var, grad in zip(vars, grad_ins):
		run_backward(var, var.grad)

中間Variable的梯度在BP的過程中是保存到GradBuffer中的(C++源碼中可以看到), BP完會(huì)釋放. 如果retain_grads=True的話,就不會(huì)被釋放

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

最新評(píng)論