PyTorch之關(guān)于hook機(jī)制
PyTorch: hook機(jī)制
在訓(xùn)練神經(jīng)網(wǎng)絡(luò)的時(shí)候我們有時(shí)需要輸出網(wǎng)絡(luò)中間層,一般來說我們有兩種處理方法:
一種是在model的forward中保存中間層的變量,然后再return的時(shí)候?qū)⑵浜徒Y(jié)果一起返回;
另一種是使用pytorch自帶的register_forward_hook,即hook機(jī)制
register_forward_hook
register_forward_hook(hook)
- 返回module中的一個(gè)前向的hook,這個(gè)hook每次在執(zhí)行forward的時(shí)候都會(huì)被調(diào)用
- hook:
hook(module, input, output)
可能不是很好理解,我們直接用一個(gè)例子來說明,如下所示,首先我們將hook包裝在類SaveValues中,我們現(xiàn)在想要獲取模型Net中的l1的輸入和輸出,因此將model.l1存入到類中:value = SaveValues(model.l1)
,在類中定義一個(gè)hook_fn_act函數(shù),此函數(shù)的作用是隨著我們的register_forward_hook
函數(shù)獲取Net的某一層的名字,輸入以及輸出,在這里對應(yīng)的就是model.l1, 他的輸入和輸出,最終我們將他獲取的網(wǎng)絡(luò)層的名字、輸入以及輸出保存到類SaveValues中方便我們輸出
注意:hook_fn_act函數(shù)必須有三個(gè)參數(shù),分別對應(yīng)module,input以及output
import torch import torch.nn as nn class SaveValues(): def __init__(self, layer): self.model = None self.input = None self.output = None self.grad_input = None self.grad_output = None self.forward_hook = layer.register_forward_hook(self.hook_fn_act) self.backward_hook = layer.register_full_backward_hook(self.hook_fn_grad) def hook_fn_act(self, module, input, output): self.model = module self.input = input[0] self.output = output def hook_fn_grad(self, module, grad_input, grad_output): self.grad_input = grad_input[0] self.grad_output = grad_output[0] def remove(self): self.forward_hook.remove() self.backward_hook.remove() class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.l1 = nn.Linear(2, 5) self.l2 = nn.Linear(5, 10) def forward(self, x): x = self.l1(x) x = self.l2(x) return x l1loss = nn.L1Loss() model = Net() value = SaveValues(model.l2) gt = torch.ones((10,), dtype=torch.float32, requires_grad=False) x = torch.ones((2,), dtype=torch.float32, requires_grad=False) y = model(x) loss = l1loss(y, gt) loss.backward() x += 1.2 value.remove()
運(yùn)行上述程序,當(dāng)我們運(yùn)行到y = model(x)
這一行時(shí),我們看一下value中的值(圖左),當(dāng)我們運(yùn)行完y = model(x)
時(shí),我們看一下value中的值(圖右),這是因?yàn)樵趫?zhí)行net中的forward函數(shù)時(shí),我們的hook機(jī)制會(huì)從中提取出網(wǎng)絡(luò)的輸入和輸出,不執(zhí)行forward就不會(huì)提取
注意:
當(dāng)我們不想在提取網(wǎng)絡(luò)中間層時(shí),我們調(diào)用value.remove()即可,即刪除了網(wǎng)絡(luò)中的hook。
但是在訓(xùn)練網(wǎng)絡(luò)時(shí)我們可能需要輸出每個(gè)epoch的中間層信息,那么在for循環(huán)中就不需要?jiǎng)h除hook啦
register_full_backward_hook
好像這個(gè)反向hook很少用到?
register_forward_hook(hook)
- 返回module中的一個(gè)反向的hook,這個(gè)hook每次在執(zhí)行forward的時(shí)候都會(huì)被調(diào)用
- hook:
hook(module, grad_input, grad_output)
繼續(xù)上述的代碼,這次我們運(yùn)行到loss.backward()
之前與之后查看value中存儲(chǔ)的grad的變化,如下所示,可以發(fā)現(xiàn)在沒有反向傳播之前grad為None,當(dāng)我們執(zhí)行反向傳播之后grad就有值了
注意:
這里將layer換成了l2,因?yàn)榈谝粚觢1經(jīng)過backward之后依然是左圖不變,可能是第一層沒有梯度?
value = SaveValues(model.l2) # modify here: model.l1--->model.l2
remove
關(guān)于remove其實(shí)如果顯存足夠可以不用remove,雖然每個(gè)epoch的時(shí)候hook的值都會(huì)變化,但是只占用一個(gè)hook的內(nèi)存,除非開銷很大可以考慮remove
visual
當(dāng)我們的SaveValues類提取出特征圖之后,就可以對value.output進(jìn)行可視化啦
當(dāng)然如果有需要也可以用input、output或者grad進(jìn)行相應(yīng)的操作
使用Pytorch的hook機(jī)制提取特征時(shí)踩的一個(gè)坑
因?yàn)轫?xiàng)目需求,需要用DenseNet模型提取圖片特征,在使用Pytorch的hook機(jī)制提取特征,調(diào)試的時(shí)候發(fā)現(xiàn)提取出來的特征數(shù)值上全部大于等于0。
很明顯提取出來的特征是經(jīng)過ReLU的?,F(xiàn)在來看一下筆者是怎么定義hook的:
fmap_block = [] # 注冊hook def forward_hook(module, input, output): fmap_block.append(output) get_feature_model = densenet121(num_classes=2, pretrained=False) model_dict = torch.load(model_weight_path) get_feature_model = nn.DataParallel(get_feature_model.cuda()) get_feature_model.module.features.register_forward_hook(forward_hook)
模型定義的時(shí)候因項(xiàng)目需求,筆者并沒有使用預(yù)訓(xùn)練模型。而是自己訓(xùn)練了一個(gè)DenseNet121模型,并且使用了DataParallel進(jìn)行包裝。這里有兩點(diǎn)需要注意:
1.大部分的官方模型都會(huì)分成兩部分,分別是特征層和分類層。
def forward(self, x): features = self.features(x) out = F.relu(features, inplace=True) out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) out = self.classifier(out) return out
這是DenseNet模型前向傳播代碼,很明顯就是筆者上訴說的那樣。所以在使用Pytorch的hook進(jìn)行提取特征的時(shí)候可以很方便的定義成這個(gè)樣子:
DenseNet類實(shí)例.features.register_forward_hook(forward_hook)
2.眼尖的讀者可以發(fā)現(xiàn)筆者的代碼里并不是這樣定義的,多了一個(gè).module(這也算是一個(gè)小小的坑)。這是因?yàn)楣P者使用了DataParallel進(jìn)行包裝模型,使之可以使用多GPU訓(xùn)練,下面來看一下DataParallel的源碼:
def __init__(self, module, device_ids=None, output_device=None, dim=0): super(DataParallel, self).__init__() if not torch.cuda.is_available(): self.module = module self.device_ids = []
可以看到初始化DataParallel類的時(shí)候,將model作為一個(gè)參數(shù)傳給了module,所以得多加一個(gè).module才能定位到我們需要的feature。
看到這里,估計(jì)很多人已經(jīng)發(fā)現(xiàn)問題在哪里了,沒錯(cuò),問題出現(xiàn)了前向傳播部分,更準(zhǔn)確的來說是relu函數(shù)。
out = F.relu(features, inplace=True)
inplace表示原地修改張量,所以經(jīng)過relu層時(shí)提前放在列表中的特征張量就會(huì)被修改。兩種解決方法:
將inplace置為False,這樣就不會(huì)原地修改張量了。修改hook函數(shù)
def forward_hook(module, input, output): fmap_block.append(output.detach().cpu())
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實(shí)現(xiàn)一個(gè)發(fā)送程序和接收程序
這篇文章主要介紹了Python實(shí)現(xiàn)一個(gè)發(fā)送程序和接收程序,文章圍繞主題展開詳細(xì)的內(nèi)容介紹,具有一定的參考價(jià)值,需要的小伙伴可以參考一下2022-09-09關(guān)于keras中keras.layers.merge的用法說明
這篇文章主要介紹了關(guān)于keras中keras.layers.merge的用法說明,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-05-05

python 實(shí)現(xiàn)查找文件并輸出滿足某一條件的數(shù)據(jù)項(xiàng)方法

Python項(xiàng)目 基于Scapy實(shí)現(xiàn)SYN泛洪攻擊的方法

python+django+selenium搭建簡易自動(dòng)化測試

python linecache 處理固定格式文本數(shù)據(jù)的方法