PyTorch之關于hook機制
PyTorch: hook機制
在訓練神經(jīng)網(wǎng)絡的時候我們有時需要輸出網(wǎng)絡中間層,一般來說我們有兩種處理方法:
一種是在model的forward中保存中間層的變量,然后再return的時候將其和結果一起返回;
另一種是使用pytorch自帶的register_forward_hook,即hook機制
register_forward_hook
register_forward_hook(hook)
- 返回module中的一個前向的hook,這個hook每次在執(zhí)行forward的時候都會被調(diào)用
- hook:
hook(module, input, output)
可能不是很好理解,我們直接用一個例子來說明,如下所示,首先我們將hook包裝在類SaveValues中,我們現(xiàn)在想要獲取模型Net中的l1的輸入和輸出,因此將model.l1存入到類中:value = SaveValues(model.l1)
,在類中定義一個hook_fn_act函數(shù),此函數(shù)的作用是隨著我們的register_forward_hook
函數(shù)獲取Net的某一層的名字,輸入以及輸出,在這里對應的就是model.l1, 他的輸入和輸出,最終我們將他獲取的網(wǎng)絡層的名字、輸入以及輸出保存到類SaveValues中方便我們輸出
注意:hook_fn_act函數(shù)必須有三個參數(shù),分別對應module,input以及output
import torch import torch.nn as nn class SaveValues(): def __init__(self, layer): self.model = None self.input = None self.output = None self.grad_input = None self.grad_output = None self.forward_hook = layer.register_forward_hook(self.hook_fn_act) self.backward_hook = layer.register_full_backward_hook(self.hook_fn_grad) def hook_fn_act(self, module, input, output): self.model = module self.input = input[0] self.output = output def hook_fn_grad(self, module, grad_input, grad_output): self.grad_input = grad_input[0] self.grad_output = grad_output[0] def remove(self): self.forward_hook.remove() self.backward_hook.remove() class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.l1 = nn.Linear(2, 5) self.l2 = nn.Linear(5, 10) def forward(self, x): x = self.l1(x) x = self.l2(x) return x l1loss = nn.L1Loss() model = Net() value = SaveValues(model.l2) gt = torch.ones((10,), dtype=torch.float32, requires_grad=False) x = torch.ones((2,), dtype=torch.float32, requires_grad=False) y = model(x) loss = l1loss(y, gt) loss.backward() x += 1.2 value.remove()
運行上述程序,當我們運行到y = model(x)
這一行時,我們看一下value中的值(圖左),當我們運行完y = model(x)
時,我們看一下value中的值(圖右),這是因為在執(zhí)行net中的forward函數(shù)時,我們的hook機制會從中提取出網(wǎng)絡的輸入和輸出,不執(zhí)行forward就不會提取
注意:
當我們不想在提取網(wǎng)絡中間層時,我們調(diào)用value.remove()即可,即刪除了網(wǎng)絡中的hook。
但是在訓練網(wǎng)絡時我們可能需要輸出每個epoch的中間層信息,那么在for循環(huán)中就不需要刪除hook啦
register_full_backward_hook
好像這個反向hook很少用到?
register_forward_hook(hook)
- 返回module中的一個反向的hook,這個hook每次在執(zhí)行forward的時候都會被調(diào)用
- hook:
hook(module, grad_input, grad_output)
繼續(xù)上述的代碼,這次我們運行到loss.backward()
之前與之后查看value中存儲的grad的變化,如下所示,可以發(fā)現(xiàn)在沒有反向傳播之前grad為None,當我們執(zhí)行反向傳播之后grad就有值了
注意:
這里將layer換成了l2,因為第一層l1經(jīng)過backward之后依然是左圖不變,可能是第一層沒有梯度?
value = SaveValues(model.l2) # modify here: model.l1--->model.l2
remove
關于remove其實如果顯存足夠可以不用remove,雖然每個epoch的時候hook的值都會變化,但是只占用一個hook的內(nèi)存,除非開銷很大可以考慮remove
visual
當我們的SaveValues類提取出特征圖之后,就可以對value.output進行可視化啦
當然如果有需要也可以用input、output或者grad進行相應的操作
使用Pytorch的hook機制提取特征時踩的一個坑
因為項目需求,需要用DenseNet模型提取圖片特征,在使用Pytorch的hook機制提取特征,調(diào)試的時候發(fā)現(xiàn)提取出來的特征數(shù)值上全部大于等于0。
很明顯提取出來的特征是經(jīng)過ReLU的?,F(xiàn)在來看一下筆者是怎么定義hook的:
fmap_block = [] # 注冊hook def forward_hook(module, input, output): fmap_block.append(output) get_feature_model = densenet121(num_classes=2, pretrained=False) model_dict = torch.load(model_weight_path) get_feature_model = nn.DataParallel(get_feature_model.cuda()) get_feature_model.module.features.register_forward_hook(forward_hook)
模型定義的時候因項目需求,筆者并沒有使用預訓練模型。而是自己訓練了一個DenseNet121模型,并且使用了DataParallel進行包裝。這里有兩點需要注意:
1.大部分的官方模型都會分成兩部分,分別是特征層和分類層。
def forward(self, x): features = self.features(x) out = F.relu(features, inplace=True) out = F.avg_pool2d(out, kernel_size=7, stride=1).view(features.size(0), -1) out = self.classifier(out) return out
這是DenseNet模型前向傳播代碼,很明顯就是筆者上訴說的那樣。所以在使用Pytorch的hook進行提取特征的時候可以很方便的定義成這個樣子:
DenseNet類實例.features.register_forward_hook(forward_hook)
2.眼尖的讀者可以發(fā)現(xiàn)筆者的代碼里并不是這樣定義的,多了一個.module(這也算是一個小小的坑)。這是因為筆者使用了DataParallel進行包裝模型,使之可以使用多GPU訓練,下面來看一下DataParallel的源碼:
def __init__(self, module, device_ids=None, output_device=None, dim=0): super(DataParallel, self).__init__() if not torch.cuda.is_available(): self.module = module self.device_ids = []
可以看到初始化DataParallel類的時候,將model作為一個參數(shù)傳給了module,所以得多加一個.module才能定位到我們需要的feature。
看到這里,估計很多人已經(jīng)發(fā)現(xiàn)問題在哪里了,沒錯,問題出現(xiàn)了前向傳播部分,更準確的來說是relu函數(shù)。
out = F.relu(features, inplace=True)
inplace表示原地修改張量,所以經(jīng)過relu層時提前放在列表中的特征張量就會被修改。兩種解決方法:
將inplace置為False,這樣就不會原地修改張量了。修改hook函數(shù)
def forward_hook(module, input, output): fmap_block.append(output.detach().cpu())
總結
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
關于keras中keras.layers.merge的用法說明
這篇文章主要介紹了關于keras中keras.layers.merge的用法說明,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-05-05python 實現(xiàn)查找文件并輸出滿足某一條件的數(shù)據(jù)項方法
今天小編就為大家分享一篇python 實現(xiàn)查找文件并輸出滿足某一條件的數(shù)據(jù)項方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-06-06Python項目 基于Scapy實現(xiàn)SYN泛洪攻擊的方法
今天小編就為大家分享一篇Python項目 基于Scapy實現(xiàn)SYN泛洪攻擊的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-07-07python+django+selenium搭建簡易自動化測試
這篇文章主要介紹了python+django+selenium搭建簡易自動化測試,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2020-08-08python linecache 處理固定格式文本數(shù)據(jù)的方法
今天小編就為大家分享一篇python linecache 處理固定格式文本數(shù)據(jù)的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-01-01