pytorch中可視化之hook鉤子
一、hook
在PyTorch中,提供了一個(gè)專用的接口使得網(wǎng)絡(luò)在前向傳播過(guò)程中能夠獲取到特征圖,這個(gè)接口的名稱非常形象,叫做hook。
可以想象這樣的場(chǎng)景,數(shù)據(jù)通過(guò)網(wǎng)絡(luò)向前傳播,網(wǎng)絡(luò)某一層我們預(yù)先設(shè)置了一個(gè)鉤子,數(shù)據(jù)傳播過(guò)后鉤子上會(huì)留下數(shù)據(jù)在這一層的樣子,讀取鉤子的信息就是這一層的特征圖。
具體實(shí)現(xiàn)如下:
1.1 什么是hook,什么情況下使用?
首先,明確一下,為什么需要用hook,假設(shè)有這么一個(gè)函數(shù)
需要通過(guò)梯度下降法求最小值,其實(shí)現(xiàn)方法如下:
import torch x = torch.tensor(3.0, requires_grad=True) y = (x-2) z = ((y-x) ** 2) z.backward() print("x.grad:",x.requires_grad,x.grad) print("y.grad:",y.requires_grad,y.grad) print("z.grad:",z.requires_grad,z.grad)
結(jié)果如下:
x.grad: True tensor(0.)
y.grad: True None
z.grad: True None
注意:在使用訓(xùn)練PyTorch訓(xùn)練模型時(shí),只有葉節(jié)點(diǎn)(即直接指定數(shù)值的變量,而不是由其他變量計(jì)算得到的,比如網(wǎng)絡(luò)輸入)的梯度會(huì)保留,其余中間節(jié)點(diǎn)梯度在反向傳播完成后就會(huì)自動(dòng)釋放以節(jié)省顯存。 因此y.requires_grad的返回值為True,y.grad卻為None。
可以看到上面的requires_grad方法都顯示True,但是grad沒(méi)有返回值。當(dāng)然pytorch也提供某種方法保留非葉子節(jié)點(diǎn)的梯度信息。
使用 retain_grad() 方法可以保留非葉子節(jié)點(diǎn)的梯度,使用 retain_grad 保留的grad會(huì)占用顯存,具體操作如下:
x = torch.tensor(3.0, requires_grad=True) y = (x-2) z = ((y-x) ** 2) y.retain_grad() z.retain_grad() z.backward() print("x.grad:",x.requires_grad,x.grad) print("y.grad:",y.requires_grad,y.grad) print("z.grad:",z.requires_grad,z.grad)
out:
x.grad: True tensor(0.) y.grad: True tensor(-4.) z.grad: True tensor(1.)
** 重申一次** 使用retain_grad方法會(huì)占用顯存,如果不想要占用顯存,就使用到了hook方法。
對(duì)于中間節(jié)點(diǎn)的變量a,可以使用a.register_hook(hook_fn)對(duì)其grad進(jìn)行操作。 而hook_fn是一個(gè)自定義的函數(shù),其聲明為hook_fn(grad) -> Tensor or None
1.2 hook在變量中的使用
1.2.1 hook的打印功能
# 自定義hook方法,其傳入?yún)?shù)為grad,打印出使用鉤子的節(jié)點(diǎn)梯度 def hook_fn(grad): print(grad) x = torch.tensor(3.0, requires_grad=True) y = (x-2) z = ((y-x) ** 2) y.register_hook(hook_fn) z.register_hook(hook_fn) print("backward前") z.backward() print("backward后\n") print("x.grad:",x.requires_grad,x.grad) print("y.grad:",y.requires_grad,y.grad) print("z.grad:",z.requires_grad,z.grad)
out:
backward前 tensor(1.) tensor(-4.) backward后 x.grad: True tensor(0.) y.grad: True None z.grad: True None
可以看到綁定hook后,backward打印的時(shí)候打印了y和z的梯度,調(diào)用grad的時(shí)候沒(méi)有保留grad值,已經(jīng)釋放掉內(nèi)存。注意,打印出來(lái)的結(jié)果是反向傳播,所以先打印z的梯度,再打印y的梯度。
1.2.2 使用hook改變grad的功能
對(duì)標(biāo)記的節(jié)點(diǎn),梯度加2
def hook_fn(grad): grad += 2 print(grad) return grad x = torch.tensor(3.0, requires_grad=True) y = (x-2) z = ((y-x) ** 2) y.register_hook(hook_fn) z.register_hook(hook_fn) print("backward前") z.backward() print("backward后\n") print("x.grad:",x.requires_grad,x.grad) print("y.grad:",x.requires_grad,y.grad) print("z.grad:",x.requires_grad,z.grad)
out:
backward前 tensor(3.) tensor(-10.) backward后 x.grad: True tensor(2.) y.grad: True None z.grad: True None
可以看到梯度教上面的已經(jīng)發(fā)生的改變。
1.3 hook在模型中的使用:
PyTorch中使用register_forward_hook和register_backward_hook獲取Module輸入和輸出的feature_map和grad。使用結(jié)構(gòu)如下: hook_fn(module, input, output) -> Tensor or None
模型中使用hook一點(diǎn)要帶有這三個(gè)參數(shù)module, grad_input, grad_output
1.3.1 register_forward_hook的使用
import torch.nn as nn def hook_forward_fn(model,put,out): print("model:",model) print("input:",put) print("output:",out) # 定義一個(gè)model class Net(nn.Module): def __init__(self): super(Net,self).__init__() self.conv = nn.Conv2d(3, 1, 1) self.bn = nn.BatchNorm2d(1) #self.conv.register_forward_hook(hook_forward_fn) #self.bn.register_forward_hook(hook_forward_fn) def forward(self, x): x = self.conv(x) x = self.bn(x) return torch.relu(x) net = Net() # 對(duì)模型中的具體某一層使用hook net.conv.register_forward_hook(hook_forward_fn) net.bn.register_forward_hook(hook_forward_fn) x = torch.rand(1, 3, 2, 2, requires_grad=True) y = net(x).mean()
注意:該方法不需要使用。backword就能輸出結(jié)果,是記錄前向傳播的鉤子。
結(jié)果如下:
model: Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1)) input: (tensor([[[[0.4570, 0.6791], [0.0197, 0.5040]], [[0.8883, 0.1808], [0.6289, 0.9386]], [[0.8772, 0.5290], [0.0014, 0.3728]]]], requires_grad=True),) output: tensor([[[[-0.4909, -0.1122], [-0.6301, -0.5649]]]], grad_fn=<ConvolutionBackward0>) model: BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) input: (tensor([[[[-0.4909, -0.1122], [-0.6301, -0.5649]]]], grad_fn=<ConvolutionBackward0>),) output: tensor([[[[-0.2060, 1.6790], [-0.8987, -0.5743]]]], grad_fn=<NativeBatchNormBackward0>)
1.3.2 register_backward_hook的使用
使用上面相同的Net模型
def hook_backward_fn(module, grad_input, grad_output): print(f"module: {module}") print(f"grad_output: {grad_output}") print(f"grad_input: {grad_input}") print("*"*20) net = Net() net.conv.register_backward_hook(hook_backward_fn) net.bn.register_backward_hook(hook_backward_fn) x = x = torch.rand(1, 3, 2, 2, requires_grad=True) y = net(x).mean() y.backward()
out:
module: BatchNorm2d(1, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) grad_output: (tensor([[[[0.2500, 0.2500], [0.0000, 0.0000]]]]),) grad_input: (tensor([[[[ 0.6586, -0.3360], [-0.3009, -0.0218]]]]), tensor([0.4575]), tensor([0.5000])) ******************** module: Conv2d(3, 1, kernel_size=(1, 1), stride=(1, 1)) grad_output: (tensor([[[[ 0.6586, -0.3360], [-0.3009, -0.0218]]]]),) grad_input: (tensor([[[[-0.2974, 0.1517], [ 0.1359, 0.0098]], [[ 0.0270, -0.0138], [-0.0123, -0.0009]], [[ 0.2918, -0.1489], [-0.1333, -0.0096]]]]), tensor([[[[0.4331]], [[0.1386]], [[0.4292]]]]), tensor([-1.4156e-07])) ********************
其結(jié)果是逆向輸出各節(jié)點(diǎn)層的梯度信息。
1.3.3 hook中使用展示卷積層
隨便畫一張圖,圖片張這個(gè)樣子:
使用讀取圖片發(fā)現(xiàn)是個(gè)4通道的圖像,我們轉(zhuǎn)成單通道并可視化:
import matplotlib.pyplot as plt import matplotlib.image as mping img=mping.imread("./test1.png") print(img.shape) img = torch.tensor(img[:,:,0]).view(1,1,228,226) plt.imshow(img[0][0])
接下來(lái)創(chuàng)建一個(gè)只有卷積層的模型
class Net(nn.Module): def __init__(self): super(Net,self).__init__() self.conv = nn.Sequential(nn.Conv2d(1,1,7), nn.ReLU() ) def forward(self, x): x=self.conv(x) return x
使用我們的鉤子hook對(duì)卷積層的輸出進(jìn)行可視化
def hook_forward_fn(model,put,out): print("inputshape:",put[0].shape) # 打印出輸入圖片的維度 print("outputshape:",out[0][0].shape) # 經(jīng)過(guò)卷積之后的維度 # 可視化,因?yàn)榫矸e之后帶有g(shù)rad梯度信息,所以需要使用detach().numpy()方法,否則會(huì)報(bào)錯(cuò) plt.imshow(out[0][0].detach().numpy())
具體完整實(shí)現(xiàn)以及可視化代碼如下:
import matplotlib.pyplot as plt import matplotlib.image as mping import numpy as np img=mping.imread("./test1.png") img = torch.tensor(img[:,:,0]).view(1,1,228,226) def hook_forward_fn(model,put,out): print("inputshape:",put[0].shape) print("outputshape:",out[0][0].shape) plt.imshow(out[0][0].detach().numpy()) class Net(nn.Module): def __init__(self): super(Net,self).__init__() self.conv = nn.Sequential(nn.Conv2d(1,1,7), nn.ReLU() ) def forward(self, x): x=self.conv(x) return x model = Net() model.conv.register_forward_hook(hook_forward_fn) y=model(img)
到此這篇關(guān)于pytorch中可視化之hook鉤子的文章就介紹到這了,更多相關(guān)pytorch hook鉤子內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python學(xué)習(xí)之字典的創(chuàng)建和使用
這篇文章主要為大家介紹了Python中的字典的創(chuàng)建與使用,包括使用字典(添加、刪除、修改等操作),感興趣的小伙伴可以跟隨小編一起學(xué)習(xí)一下2022-06-06django將圖片上傳數(shù)據(jù)庫(kù)后在前端顯式的方法
今天小編就為大家分享一篇django將圖片上傳數(shù)據(jù)庫(kù)后在前端顯式的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-05-05Python制作簡(jiǎn)易注冊(cè)登錄系統(tǒng)
這篇文章主要為大家詳細(xì)介紹了Python簡(jiǎn)易注冊(cè)登錄系統(tǒng)的制作方法,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2016-12-12Python計(jì)算兩個(gè)矩形重合面積代碼實(shí)例
這篇文章主要介紹了Python 實(shí)現(xiàn)兩個(gè)矩形重合面積代碼實(shí)例,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-09-09通過(guò)實(shí)例淺析Python對(duì)比C語(yǔ)言的編程思想差異
這篇文章主要介紹了通過(guò)實(shí)例淺析Python對(duì)比C語(yǔ)言的編程思想差異,作為面向?qū)ο蠛兔嫦蜻^(guò)程的編程語(yǔ)言代表,二者的對(duì)比可謂經(jīng)典,需要的朋友可以參考下2015-08-08利用Python正則表達(dá)式過(guò)濾敏感詞的方法
今天小編就為大家分享一篇利用Python正則表達(dá)式過(guò)濾敏感詞的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-01-01python datatable庫(kù)大型數(shù)據(jù)集和多核數(shù)據(jù)處理使用探索
這篇文章主要介紹了python datatable庫(kù)大型數(shù)據(jù)集和多核數(shù)據(jù)處理使用探索,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2024-01-01Python語(yǔ)法糖for?else循環(huán)語(yǔ)句里的break使用詳解
這篇文章主要介紹了Python語(yǔ)法糖之for?else循環(huán)語(yǔ)句里的break使用詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-05-05