欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch中的hook機(jī)制register_forward_hook

 更新時(shí)間:2022年03月09日 11:21:01   作者:機(jī)器學(xué)習(xí)入坑者  
這篇文章主要介紹了pytorch中的hook機(jī)制register_forward_hook,手動(dòng)在forward之前注冊(cè)hook,hook在forward執(zhí)行以后被自動(dòng)執(zhí)行,下面詳細(xì)的內(nèi)容介紹,需要的小伙伴可以參考一下

1、hook背景

Hook被成為鉤子機(jī)制,這不是pytorch的首創(chuàng),在Windows的編程中已經(jīng)被普遍采用,包括進(jìn)程內(nèi)鉤子和全局鉤子。按照自己的理解,hook的作用是通過系統(tǒng)來維護(hù)一個(gè)鏈表,使得用戶攔截(獲取)通信消息,用于處理事件。

pytorch中包含forwardbackward兩個(gè)鉤子注冊(cè)函數(shù),用于獲取forward和backward中輸入和輸出,按照自己不全面的理解,應(yīng)該目的是“不改變網(wǎng)絡(luò)的定義代碼,也不需要在forward函數(shù)中return某個(gè)感興趣層的輸出,這樣代碼太冗雜了”。

2、源碼閱讀

register_forward_hook()函數(shù)必須在forward()函數(shù)調(diào)用之前被使用,因?yàn)檫@個(gè)函數(shù)源碼注釋顯示這個(gè)函數(shù)“ it will not have effect on forward since this is called after :func:`forward` is called”,也就是這個(gè)函數(shù)在forward()之后就沒有作用了?。。。?/p>

作用:獲取forward過程中每層的輸入和輸出,用于對(duì)比hook是不是正確記錄。

def register_forward_hook(self, hook):
? ? ? ? r"""Registers a forward hook on the module.
? ? ? ? The hook will be called every time after :func:`forward` has computed an output.
? ? ? ? It should have the following signature::
? ? ? ? ? ? hook(module, input, output) -> None or modified output
? ? ? ? The hook can modify the output. It can modify the input inplace but
? ? ? ? it will not have effect on forward since this is called after
? ? ? ? :func:`forward` is called.

? ? ? ? Returns:
? ? ? ? ? ? :class:`torch.utils.hooks.RemovableHandle`:
? ? ? ? ? ? ? ? a handle that can be used to remove the added hook by calling
? ? ? ? ? ? ? ? ``handle.remove()``
? ? ? ? """
? ? ? ? handle = hooks.RemovableHandle(self._forward_hooks)
? ? ? ? self._forward_hooks[handle.id] = hook
? ? ? ? return handle

3、定義一個(gè)用于測試hooker的類

如果隨機(jī)的初始化每個(gè)層,那么就無法測試出自己獲取的輸入輸出是不是forward中的輸入輸出了,所以需要將每一層的權(quán)重和偏置設(shè)置為可識(shí)別的值(比如全部初始化為1)。網(wǎng)絡(luò)包含兩層(Linear有需要求導(dǎo)的參數(shù)被稱為一個(gè)層,而ReLU沒有需要求導(dǎo)的參數(shù)不被稱作一層),__init__()中調(diào)用initialize函數(shù)對(duì)所有層進(jìn)行初始化。

注意:在forward()函數(shù)返回各個(gè)層的輸出,但是ReLU6沒有返回,因?yàn)楹罄m(xù)測試的時(shí)候不對(duì)這一層進(jìn)行注冊(cè)hook。

class TestForHook(nn.Module):
? ? def __init__(self):
? ? ? ? super().__init__()

? ? ? ? self.linear_1 = nn.Linear(in_features=2, out_features=2)
? ? ? ? self.linear_2 = nn.Linear(in_features=2, out_features=1)
? ? ? ? self.relu = nn.ReLU()
? ? ? ? self.relu6 = nn.ReLU6()
? ? ? ? self.initialize()

? ? def forward(self, x):
? ? ? ? linear_1 = self.linear_1(x)
? ? ? ? linear_2 = self.linear_2(linear_1)
? ? ? ? relu = self.relu(linear_2)
? ? ? ? relu_6 = self.relu6(relu)
? ? ? ? layers_in = (x, linear_1, linear_2)
? ? ? ? layers_out = (linear_1, linear_2, relu)
? ? ? ? return relu_6, layers_in, layers_out
? ? def initialize(self):
? ? ? ? """ 定義特殊的初始化,用于驗(yàn)證是不是獲取了權(quán)重"""
? ? ? ? self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))
? ? ? ? self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))
? ? ? ? self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))
? ? ? ? self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))
? ? ? ? return True

4、定義hook函數(shù)

hook()函數(shù)是register_forward_hook()函數(shù)必須提供的參數(shù),好處是“用戶可以自行決定攔截了中間信息之后要做什么!”,比如自己想單純的記錄網(wǎng)絡(luò)的輸入輸出(也可以進(jìn)行修改等更加復(fù)雜的操作)。

首先定義幾個(gè)容器用于記錄:

定義用于獲取網(wǎng)絡(luò)各層輸入輸出tensor的容器:

# 并定義module_name用于記錄相應(yīng)的module名字
module_name = []
features_in_hook = []
features_out_hook = []
hook函數(shù)需要三個(gè)參數(shù),這三個(gè)參數(shù)是系統(tǒng)傳給hook函數(shù)的,自己不能修改這三個(gè)參數(shù):

hook函數(shù)負(fù)責(zé)將獲取的輸入輸出添加到feature列表中;并提供相應(yīng)的module名字

def hook(module, fea_in, fea_out):
????print("hooker working")
????module_name.append(module.__class__)
????features_in_hook.append(fea_in)
????features_out_hook.append(fea_out)
????return None

5、對(duì)需要的層注冊(cè)hook

注冊(cè)鉤子必須在forward()函數(shù)被執(zhí)行之前,也就是定義網(wǎng)絡(luò)進(jìn)行計(jì)算之前就要注冊(cè),下面的代碼對(duì)網(wǎng)絡(luò)除去ReLU6以外的層都進(jìn)行了注冊(cè)(也可以選定某些層進(jìn)行注冊(cè)):

注冊(cè)鉤子可以對(duì)某些層單獨(dú)進(jìn)行:

net = TestForHook()
net_chilren = net.children()
for child in net_chilren:
? ? if not isinstance(child, nn.ReLU6):
? ? ? ? child.register_forward_hook(hook=hook)

6、測試forward()返回的特征和hook記錄的是否一致

6.1 測試forward()提供的輸入輸出特征

由于前面的forward()函數(shù)返回了需要記錄的特征,這里可以直接測試:

out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)

得到下面的輸出是理所當(dāng)然的:

*****forward return features*****
(tensor([[0.1000, 0.1000],
        [0.1000, 0.1000]]), tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward>))
(tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
        [3.4000]], grad_fn=<ThresholdBackward0>))
*****forward return features*****

6.2 hook記錄的輸入特征和輸出特征

hook通過list結(jié)構(gòu)進(jìn)行記錄,所以可以直接print

測試features_in是不是存儲(chǔ)了輸入:

print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)

得到和forward一樣的結(jié)果:

*****hook record features*****
[(tensor([[0.1000, 0.1000],
        [0.1000, 0.1000]]),), (tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward>),), (tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward>),)]
[tensor([[1.2000, 1.2000],
        [1.2000, 1.2000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
        [3.4000]], grad_fn=<AddmmBackward>), tensor([[3.4000],
        [3.4000]], grad_fn=<ThresholdBackward0>)]
[<class 'torch.nn.modules.linear.Linear'>, 
<class 'torch.nn.modules.linear.Linear'>,
 <class 'torch.nn.modules.activation.ReLU'>]
*****hook record features*****

6.3 把hook記錄的和forward做減法

如果害怕會(huì)有小數(shù)點(diǎn)后面的數(shù)值不一致,或者數(shù)據(jù)類型的不匹配,可以對(duì)hook記錄的特征和forward記錄的特征做減法:

測試forward返回的feautes_in是不是和hook記錄的一致:

print("sub result'")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):
? ? print(forward_return-hook_record[0])

得到的全部都是0,說明hook沒問題:

sub result
tensor([[0., 0.],
? ? ? ? [0., 0.]])
tensor([[0., 0.],
? ? ? ? [0., 0.]], grad_fn=<SubBackward0>)
tensor([[0.],
? ? ? ? [0.]], grad_fn=<SubBackward0>)

7、完整代碼

import torch
import torch.nn as nn


class TestForHook(nn.Module):
? ? def __init__(self):
? ? ? ? super().__init__()

? ? ? ? self.linear_1 = nn.Linear(in_features=2, out_features=2)
? ? ? ? self.linear_2 = nn.Linear(in_features=2, out_features=1)
? ? ? ? self.relu = nn.ReLU()
? ? ? ? self.relu6 = nn.ReLU6()
? ? ? ? self.initialize()

? ? def forward(self, x):
? ? ? ? linear_1 = self.linear_1(x)
? ? ? ? linear_2 = self.linear_2(linear_1)
? ? ? ? relu = self.relu(linear_2)
? ? ? ? relu_6 = self.relu6(relu)
? ? ? ? layers_in = (x, linear_1, linear_2)
? ? ? ? layers_out = (linear_1, linear_2, relu)
? ? ? ? return relu_6, layers_in, layers_out

? ? def initialize(self):
? ? ? ? """ 定義特殊的初始化,用于驗(yàn)證是不是獲取了權(quán)重"""
? ? ? ? self.linear_1.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1], [1, 1]]))
? ? ? ? self.linear_1.bias = torch.nn.Parameter(torch.FloatTensor([1, 1]))
? ? ? ? self.linear_2.weight = torch.nn.Parameter(torch.FloatTensor([[1, 1]]))
? ? ? ? self.linear_2.bias = torch.nn.Parameter(torch.FloatTensor([1]))
? ? ? ? return True

定義用于獲取網(wǎng)絡(luò)各層輸入輸出tensor的容器,并定義module_name用于記錄相應(yīng)的module名字

module_name = []
features_in_hook = []
features_out_hook = []

hook函數(shù)負(fù)責(zé)將獲取的輸入輸出添加到feature列表中,并提供相應(yīng)的module名字

def hook(module, fea_in, fea_out):
? ? print("hooker working")
? ? module_name.append(module.__class__)
? ? features_in_hook.append(fea_in)
? ? features_out_hook.append(fea_out)
? ? return None

定義全部是1的輸入:

x = torch.FloatTensor([[0.1, 0.1], [0.1, 0.1]])

注冊(cè)鉤子可以對(duì)某些層單獨(dú)進(jìn)行:

net = TestForHook()
net_chilren = net.children()
for child in net_chilren:
? ? if not isinstance(child, nn.ReLU6):
? ? ? ? child.register_forward_hook(hook=hook)

測試網(wǎng)絡(luò)輸出:

out, features_in_forward, features_out_forward = net(x)
print("*"*5+"forward return features"+"*"*5)
print(features_in_forward)
print(features_out_forward)
print("*"*5+"forward return features"+"*"*5)

測試features_in是不是存儲(chǔ)了輸入:

print("*"*5+"hook record features"+"*"*5)
print(features_in_hook)
print(features_out_hook)
print(module_name)
print("*"*5+"hook record features"+"*"*5)

測試forward返回的feautes_in是不是和hook記錄的一致:

print("sub result")
for forward_return, hook_record in zip(features_in_forward, features_in_hook):
    print(forward_return-hook_record[0])

 到此這篇關(guān)于pytorch中的hook機(jī)制register_forward_hook的文章就介紹到這了,更多相關(guān)pytorch中的hook機(jī)制內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • Python+Tkinter實(shí)現(xiàn)軟件自動(dòng)更新與提醒

    Python+Tkinter實(shí)現(xiàn)軟件自動(dòng)更新與提醒

    這篇文章主要為大家詳細(xì)介紹了Python如何利用Tkinter編寫一個(gè)軟件自動(dòng)更新與提醒小程序,文中的示例代碼簡潔易懂,感興趣的小伙伴可以動(dòng)手嘗試一下
    2023-07-07
  • python實(shí)現(xiàn)給數(shù)組按片賦值的方法

    python實(shí)現(xiàn)給數(shù)組按片賦值的方法

    這篇文章主要介紹了python實(shí)現(xiàn)給數(shù)組按片賦值的方法,實(shí)例分析了Python在指定位置進(jìn)行賦值的相關(guān)技巧,需要的朋友可以參考下
    2015-07-07
  • pyinstaller使用大全

    pyinstaller使用大全

    這篇文章主要介紹了pyinstaller使用大全,pyinstaller可以方便地將腳本編譯成exe,本文結(jié)合實(shí)例代碼給大家詳細(xì)講解,需要的朋友可以參考下
    2023-02-02
  • Python基于argparse與ConfigParser庫進(jìn)行入?yún)⒔馕雠cini parser

    Python基于argparse與ConfigParser庫進(jìn)行入?yún)⒔馕雠cini parser

    這篇文章主要介紹了Python基于argparse與ConfigParser庫進(jìn)行入?yún)⒔馕雠cini parser,幫助大家更好的理解和使用python,感興趣的朋友可以了解下
    2021-02-02
  • Python異常處理try語句應(yīng)用技巧實(shí)例探究

    Python異常處理try語句應(yīng)用技巧實(shí)例探究

    異常處理在Python中是至關(guān)重要的,try-except是用于捕獲和處理異常的核心機(jī)制之一,本文就帶大家深入了解如何使用try-except,處理各種異常情況
    2024-01-01
  • Mac更新python3.12?解決pip3安裝報(bào)錯(cuò)問題小結(jié)

    Mac更新python3.12?解決pip3安裝報(bào)錯(cuò)問題小結(jié)

    Mac使用homebrew更新了python3.12,刪除了以前的版本和pip3安裝軟件時(shí)候報(bào)錯(cuò),下面小編給大家分享Mac更新python3.12?解決pip3安裝報(bào)錯(cuò)問題,感興趣的朋友跟隨小編一起看看吧
    2024-05-05
  • Python設(shè)計(jì)模式之觀察者模式簡單示例

    Python設(shè)計(jì)模式之觀察者模式簡單示例

    這篇文章主要介紹了Python設(shè)計(jì)模式之觀察者模式,簡單描述了觀察者模式的概念、原理,并結(jié)合實(shí)例形式分析了Python觀察者模式的相關(guān)定義與使用技巧,需要的朋友可以參考下
    2018-01-01
  • Python中的pandas模塊詳解

    Python中的pandas模塊詳解

    在Python中使用pandas模塊,需要先安裝pandas庫,pandas模塊是Python編程語言中用于數(shù)據(jù)處理和分析的強(qiáng)大模塊,它提供了許多用于數(shù)據(jù)操作和清洗的函數(shù),使得數(shù)據(jù)處理和分析變得更為簡單和直觀,本文給大家介紹Python pandas模塊,感興趣的朋友跟隨小編一起看看吧
    2023-10-10
  • Python字符編碼轉(zhuǎn)碼之GBK,UTF8互轉(zhuǎn)

    Python字符編碼轉(zhuǎn)碼之GBK,UTF8互轉(zhuǎn)

    說到python的編碼,一句話總結(jié),說多了都是淚啊,這個(gè)在以后的python的開發(fā)中絕對(duì)是一件令人頭疼的事情。所以有必要輸入理解
    2020-02-02
  • Python連續(xù)賦值需要注意的一些問題

    Python連續(xù)賦值需要注意的一些問題

    這篇文章主要介紹了Python連續(xù)賦值需要注意的一些問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2021-06-06

最新評(píng)論