PyTorch之關(guān)于hook機(jī)制
PyTorch: hook機(jī)制
在訓(xùn)練神經(jīng)網(wǎng)絡(luò)的時(shí)候我們有時(shí)需要輸出網(wǎng)絡(luò)中間層,一般來(lái)說(shuō)我們有兩種處理方法:
一種是在model的forward中保存中間層的變量,然后再return的時(shí)候?qū)⑵浜徒Y(jié)果一起返回;
另一種是使用pytorch自帶的register_forward_hook,即hook機(jī)制
register_forward_hook
register_forward_hook(hook)
- 返回module中的一個(gè)前向的hook,這個(gè)hook每次在執(zhí)行forward的時(shí)候都會(huì)被調(diào)用
- hook:
hook(module, input, output)
可能不是很好理解,我們直接用一個(gè)例子來(lái)說(shuō)明,如下所示,首先我們將hook包裝在類SaveValues中,我們現(xiàn)在想要獲取模型Net中的l1的輸入和輸出,因此將model.l1存入到類中:value = SaveValues(model.l1),在類中定義一個(gè)hook_fn_act函數(shù),此函數(shù)的作用是隨著我們的register_forward_hook函數(shù)獲取Net的某一層的名字,輸入以及輸出,在這里對(duì)應(yīng)的就是model.l1, 他的輸入和輸出,最終我們將他獲取的網(wǎng)絡(luò)層的名字、輸入以及輸出保存到類SaveValues中方便我們輸出
注意:hook_fn_act函數(shù)必須有三個(gè)參數(shù),分別對(duì)應(yīng)module,input以及output
import torch
import torch.nn as nn
class SaveValues():
def __init__(self, layer):
self.model = None
self.input = None
self.output = None
self.grad_input = None
self.grad_output = None
self.forward_hook = layer.register_forward_hook(self.hook_fn_act)
self.backward_hook = layer.register_full_backward_hook(self.hook_fn_grad)
def hook_fn_act(self, module, input, output):
self.model = module
self.input = input[0]
self.output = output
def hook_fn_grad(self, module, grad_input, grad_output):
self.grad_input = grad_input[0]
self.grad_output = grad_output[0]
def remove(self):
self.forward_hook.remove()
self.backward_hook.remove()
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.l1 = nn.Linear(2, 5)
self.l2 = nn.Linear(5, 10)
def forward(self, x):
x = self.l1(x)
x = self.l2(x)
return x
l1loss = nn.L1Loss()
model = Net()
value = SaveValues(model.l2)
gt = torch.ones((10,), dtype=torch.float32, requires_grad=False)
x = torch.ones((2,), dtype=torch.float32, requires_grad=False)
y = model(x)
loss = l1loss(y, gt)
loss.backward()
x += 1.2
value.remove()運(yùn)行上述程序,當(dāng)我們運(yùn)行到y = model(x)這一行時(shí),我們看一下value中的值(圖左),當(dāng)我們運(yùn)行完y = model(x)時(shí),我們看一下value中的值(圖右),這是因?yàn)樵趫?zhí)行net中的forward函數(shù)時(shí),我們的hook機(jī)制會(huì)從中提取出網(wǎng)絡(luò)的輸入和輸出,不執(zhí)行forward就不會(huì)提取
注意:
當(dāng)我們不想在提取網(wǎng)絡(luò)中間層時(shí),我們調(diào)用value.remove()即可,即刪除了網(wǎng)絡(luò)中的hook。
但是在訓(xùn)練網(wǎng)絡(luò)時(shí)我們可能需要輸出每個(gè)epoch的中間層信息,那么在for循環(huán)中就不需要?jiǎng)h除hook啦


register_full_backward_hook
好像這個(gè)反向hook很少用到?
register_forward_hook(hook)
- 返回module中的一個(gè)反向的hook,這個(gè)hook每次在執(zhí)行forward的時(shí)候都會(huì)被調(diào)用
- hook:
hook(module, grad_input, grad_output)
繼續(xù)上述的代碼,這次我們運(yùn)行到loss.backward()之前與之后查看value中存儲(chǔ)的grad的變化,如下所示,可以發(fā)現(xiàn)在沒(méi)有反向傳播之前grad為None,當(dāng)我們執(zhí)行反向傳播之后grad就有值了
注意:
這里將layer換成了l2,因?yàn)榈谝粚觢1經(jīng)過(guò)backward之后依然是左圖不變,可能是第一層沒(méi)有梯度?
value = SaveValues(model.l2) # modify here: model.l1--->model.l2


remove
關(guān)于remove其實(shí)如果顯存足夠可以不用remove,雖然每個(gè)epoch的時(shí)候hook的值都會(huì)變化,但是只占用一個(gè)hook的內(nèi)存,除非開(kāi)銷很大可以考慮remove
visual
當(dāng)我們的SaveValues類提取出特征圖之后,就可以對(duì)value.output進(jìn)行可視化啦
當(dāng)然如果有需要也可以用input、output或者grad進(jìn)行相應(yīng)的操作
使用Pytorch的hook機(jī)制提取特征時(shí)踩的一個(gè)坑
因?yàn)轫?xiàng)目需求,需要用DenseNet模型提取圖片特征,在使用Pytorch的hook機(jī)制提取特征,調(diào)試的時(shí)候發(fā)現(xiàn)提取出來(lái)的特征數(shù)值上全部大于等于0。

很明顯提取出來(lái)的特征是經(jīng)過(guò)ReLU的?,F(xiàn)在來(lái)看一下筆者是怎么定義hook的:
fmap_block = []
# 注冊(cè)hook
def forward_hook(module, input, output):
fmap_block.append(output)
get_feature_model = densenet121(num_classes=2, pretrained=False)
model_dict = torch.load(model_weight_path)
get_feature_model = nn.DataParallel(get_feature_model.cuda())
get_feature_model.module.features.register_forward_hook(forward_hook)模型定義的時(shí)候因項(xiàng)目需求,筆者并沒(méi)有使用預(yù)訓(xùn)練模型。而是自己訓(xùn)練了一個(gè)DenseNet121模型,并且使用了DataParallel進(jìn)行包裝。這里有兩點(diǎn)需要注意:
1.大部分的官方模型都會(huì)分成兩部分,分別是特征層和分類層。
def forward(self, x):
features = self.features(x)
out = F.relu(features, inplace=True)
out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1)
out = self.classifier(out)
return out這是DenseNet模型前向傳播代碼,很明顯就是筆者上訴說(shuō)的那樣。所以在使用Pytorch的hook進(jìn)行提取特征的時(shí)候可以很方便的定義成這個(gè)樣子:
DenseNet類實(shí)例.features.register_forward_hook(forward_hook)
2.眼尖的讀者可以發(fā)現(xiàn)筆者的代碼里并不是這樣定義的,多了一個(gè).module(這也算是一個(gè)小小的坑)。這是因?yàn)楣P者使用了DataParallel進(jìn)行包裝模型,使之可以使用多GPU訓(xùn)練,下面來(lái)看一下DataParallel的源碼:
def __init__(self, module, device_ids=None, output_device=None, dim=0):
super(DataParallel, self).__init__()
if not torch.cuda.is_available():
self.module = module
self.device_ids = []可以看到初始化DataParallel類的時(shí)候,將model作為一個(gè)參數(shù)傳給了module,所以得多加一個(gè).module才能定位到我們需要的feature。
看到這里,估計(jì)很多人已經(jīng)發(fā)現(xiàn)問(wèn)題在哪里了,沒(méi)錯(cuò),問(wèn)題出現(xiàn)了前向傳播部分,更準(zhǔn)確的來(lái)說(shuō)是relu函數(shù)。
out = F.relu(features, inplace=True)
inplace表示原地修改張量,所以經(jīng)過(guò)relu層時(shí)提前放在列表中的特征張量就會(huì)被修改。兩種解決方法:
將inplace置為False,這樣就不會(huì)原地修改張量了。修改hook函數(shù)
def forward_hook(module, input, output):
fmap_block.append(output.detach().cpu())總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實(shí)現(xiàn)一個(gè)發(fā)送程序和接收程序
這篇文章主要介紹了Python實(shí)現(xiàn)一個(gè)發(fā)送程序和接收程序,文章圍繞主題展開(kāi)詳細(xì)的內(nèi)容介紹,具有一定的參考價(jià)值,需要的小伙伴可以參考一下2022-09-09
關(guān)于keras中keras.layers.merge的用法說(shuō)明
這篇文章主要介紹了關(guān)于keras中keras.layers.merge的用法說(shuō)明,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-05-05
python 實(shí)現(xiàn)查找文件并輸出滿足某一條件的數(shù)據(jù)項(xiàng)方法
Python項(xiàng)目 基于Scapy實(shí)現(xiàn)SYN泛洪攻擊的方法
python+django+selenium搭建簡(jiǎn)易自動(dòng)化測(cè)試
python linecache 處理固定格式文本數(shù)據(jù)的方法

