欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch之關(guān)于hook機(jī)制

 更新時(shí)間:2023年08月02日 15:35:56   作者:harry_tea  
這篇文章主要介紹了PyTorch之關(guān)于hook機(jī)制的理解,具有很好的參考價(jià)值,希望對大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

PyTorch: hook機(jī)制

在訓(xùn)練神經(jīng)網(wǎng)絡(luò)的時(shí)候我們有時(shí)需要輸出網(wǎng)絡(luò)中間層,一般來說我們有兩種處理方法:

一種是在model的forward中保存中間層的變量,然后再return的時(shí)候?qū)⑵浜徒Y(jié)果一起返回;

另一種是使用pytorch自帶的register_forward_hook,即hook機(jī)制

register_forward_hook

register_forward_hook(hook)
  • 返回module中的一個(gè)前向的hook,這個(gè)hook每次在執(zhí)行forward的時(shí)候都會(huì)被調(diào)用
  • hook: hook(module, input, output)

可能不是很好理解,我們直接用一個(gè)例子來說明,如下所示,首先我們將hook包裝在類SaveValues中,我們現(xiàn)在想要獲取模型Net中的l1的輸入和輸出,因此將model.l1存入到類中:value = SaveValues(model.l1),在類中定義一個(gè)hook_fn_act函數(shù),此函數(shù)的作用是隨著我們的register_forward_hook函數(shù)獲取Net的某一層的名字,輸入以及輸出,在這里對應(yīng)的就是model.l1, 他的輸入和輸出,最終我們將他獲取的網(wǎng)絡(luò)層的名字、輸入以及輸出保存到類SaveValues中方便我們輸出

注意:hook_fn_act函數(shù)必須有三個(gè)參數(shù),分別對應(yīng)module,input以及output

import torch
import torch.nn as nn
class SaveValues():
    def __init__(self, layer):
        self.model  = None
        self.input  = None
        self.output = None
        self.grad_input  = None
        self.grad_output = None
        self.forward_hook  = layer.register_forward_hook(self.hook_fn_act)
        self.backward_hook = layer.register_full_backward_hook(self.hook_fn_grad)
    def hook_fn_act(self, module, input, output):
        self.model  = module
        self.input  = input[0]
        self.output = output
    def hook_fn_grad(self, module, grad_input, grad_output):
        self.grad_input  = grad_input[0]
        self.grad_output = grad_output[0]
    def remove(self):
        self.forward_hook.remove()
        self.backward_hook.remove()
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.l1 = nn.Linear(2, 5)
        self.l2 = nn.Linear(5, 10)
    def forward(self, x):
        x = self.l1(x)
        x = self.l2(x)
        return x
l1loss = nn.L1Loss()
model  = Net()
value  = SaveValues(model.l2)
gt = torch.ones((10,), dtype=torch.float32, requires_grad=False)
x  = torch.ones((2,), dtype=torch.float32, requires_grad=False)
y = model(x)
loss  = l1loss(y, gt)
loss.backward()
x += 1.2
value.remove()

運(yùn)行上述程序,當(dāng)我們運(yùn)行到y = model(x)這一行時(shí),我們看一下value中的值(圖左),當(dāng)我們運(yùn)行完y = model(x)時(shí),我們看一下value中的值(圖右),這是因?yàn)樵趫?zhí)行net中的forward函數(shù)時(shí),我們的hook機(jī)制會(huì)從中提取出網(wǎng)絡(luò)的輸入和輸出,不執(zhí)行forward就不會(huì)提取

注意:

當(dāng)我們不想在提取網(wǎng)絡(luò)中間層時(shí),我們調(diào)用value.remove()即可,即刪除了網(wǎng)絡(luò)中的hook。

但是在訓(xùn)練網(wǎng)絡(luò)時(shí)我們可能需要輸出每個(gè)epoch的中間層信息,那么在for循環(huán)中就不需要?jiǎng)h除hook啦

register_full_backward_hook

好像這個(gè)反向hook很少用到?

register_forward_hook(hook)
  • 返回module中的一個(gè)反向的hook,這個(gè)hook每次在執(zhí)行forward的時(shí)候都會(huì)被調(diào)用
  • hook: hook(module, grad_input, grad_output)

繼續(xù)上述的代碼,這次我們運(yùn)行到loss.backward()之前與之后查看value中存儲(chǔ)的grad的變化,如下所示,可以發(fā)現(xiàn)在沒有反向傳播之前grad為None,當(dāng)我們執(zhí)行反向傳播之后grad就有值了

注意:

這里將layer換成了l2,因?yàn)榈谝粚觢1經(jīng)過backward之后依然是左圖不變,可能是第一層沒有梯度?

value  = SaveValues(model.l2)  # modify here: model.l1--->model.l2

remove

關(guān)于remove其實(shí)如果顯存足夠可以不用remove,雖然每個(gè)epoch的時(shí)候hook的值都會(huì)變化,但是只占用一個(gè)hook的內(nèi)存,除非開銷很大可以考慮remove

visual

當(dāng)我們的SaveValues類提取出特征圖之后,就可以對value.output進(jìn)行可視化啦

當(dāng)然如果有需要也可以用input、output或者grad進(jìn)行相應(yīng)的操作

使用Pytorch的hook機(jī)制提取特征時(shí)踩的一個(gè)坑

因?yàn)轫?xiàng)目需求,需要用DenseNet模型提取圖片特征,在使用Pytorch的hook機(jī)制提取特征,調(diào)試的時(shí)候發(fā)現(xiàn)提取出來的特征數(shù)值上全部大于等于0。

很明顯提取出來的特征是經(jīng)過ReLU的?,F(xiàn)在來看一下筆者是怎么定義hook的:

fmap_block = []
# 注冊hook
def forward_hook(module, input, output):
    fmap_block.append(output)
get_feature_model = densenet121(num_classes=2, pretrained=False)
model_dict = torch.load(model_weight_path)
get_feature_model = nn.DataParallel(get_feature_model.cuda())
get_feature_model.module.features.register_forward_hook(forward_hook)

模型定義的時(shí)候因項(xiàng)目需求,筆者并沒有使用預(yù)訓(xùn)練模型。而是自己訓(xùn)練了一個(gè)DenseNet121模型,并且使用了DataParallel進(jìn)行包裝。這里有兩點(diǎn)需要注意:

1.大部分的官方模型都會(huì)分成兩部分,分別是特征層和分類層。

    def forward(self, x):
        features = self.features(x)
        out = F.relu(features, inplace=True)
        out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1)
        out = self.classifier(out)
        return out

這是DenseNet模型前向傳播代碼,很明顯就是筆者上訴說的那樣。所以在使用Pytorch的hook進(jìn)行提取特征的時(shí)候可以很方便的定義成這個(gè)樣子:

DenseNet類實(shí)例.features.register_forward_hook(forward_hook)

2.眼尖的讀者可以發(fā)現(xiàn)筆者的代碼里并不是這樣定義的,多了一個(gè).module(這也算是一個(gè)小小的坑)。這是因?yàn)楣P者使用了DataParallel進(jìn)行包裝模型,使之可以使用多GPU訓(xùn)練,下面來看一下DataParallel的源碼:

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()
        if not torch.cuda.is_available():
            self.module = module
            self.device_ids = []

可以看到初始化DataParallel類的時(shí)候,將model作為一個(gè)參數(shù)傳給了module,所以得多加一個(gè).module才能定位到我們需要的feature。

看到這里,估計(jì)很多人已經(jīng)發(fā)現(xiàn)問題在哪里了,沒錯(cuò),問題出現(xiàn)了前向傳播部分,更準(zhǔn)確的來說是relu函數(shù)。

out = F.relu(features, inplace=True)

inplace表示原地修改張量,所以經(jīng)過relu層時(shí)提前放在列表中的特征張量就會(huì)被修改。兩種解決方法:

將inplace置為False,這樣就不會(huì)原地修改張量了。修改hook函數(shù)

def forward_hook(module, input, output):
    fmap_block.append(output.detach().cpu())

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • python 實(shí)現(xiàn)查找文件并輸出滿足某一條件的數(shù)據(jù)項(xiàng)方法

    python 實(shí)現(xiàn)查找文件并輸出滿足某一條件的數(shù)據(jù)項(xiàng)方法

    今天小編就為大家分享一篇python 實(shí)現(xiàn)查找文件并輸出滿足某一條件的數(shù)據(jù)項(xiàng)方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-06-06
  • Python WSGI 規(guī)范簡介

    Python WSGI 規(guī)范簡介

    這篇文章主要介紹了Python WSGI 規(guī)范的相關(guān)資料,幫助大家更好的理解和學(xué)習(xí)使用python,感興趣的朋友可以了解下
    2021-04-04
  • Python項(xiàng)目 基于Scapy實(shí)現(xiàn)SYN泛洪攻擊的方法

    Python項(xiàng)目 基于Scapy實(shí)現(xiàn)SYN泛洪攻擊的方法

    今天小編就為大家分享一篇Python項(xiàng)目 基于Scapy實(shí)現(xiàn)SYN泛洪攻擊的方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-07-07
  • python+django+selenium搭建簡易自動(dòng)化測試

    python+django+selenium搭建簡易自動(dòng)化測試

    這篇文章主要介紹了python+django+selenium搭建簡易自動(dòng)化測試,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-08-08
  • Python構(gòu)造函數(shù)屬性示例魔法解析

    Python構(gòu)造函數(shù)屬性示例魔法解析

    Python構(gòu)造函數(shù)和屬性魔法是面向?qū)ο缶幊讨械年P(guān)鍵概念,它們允許在類定義中執(zhí)行特定操作,以控制對象的初始化和屬性訪問,本文將深入學(xué)習(xí)Python中的構(gòu)造函數(shù)和屬性魔法,包括構(gòu)造函數(shù)__init__、屬性的@property和@attribute.setter等,以及它們的實(shí)際應(yīng)用
    2023-12-12
  • python linecache 處理固定格式文本數(shù)據(jù)的方法

    python linecache 處理固定格式文本數(shù)據(jù)的方法

    今天小編就為大家分享一篇python linecache 處理固定格式文本數(shù)據(jù)的方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-01-01
  • python添加菜單圖文講解

    python添加菜單圖文講解

    在本篇文章中小編給大家整理的是關(guān)于python添加菜單圖文講解以及步驟分析,需要的朋友們學(xué)習(xí)下吧。
    2019-06-06
  • 最新評論