pytorch實現(xiàn)模型剪枝的操作方法
一,剪枝分類
所謂模型剪枝,其實是一種從神經(jīng)網(wǎng)絡中移除"不必要"權(quán)重或偏差(weigths/bias)的模型壓縮技術(shù)。關于什么參數(shù)才是“不必要的”,這是一個目前依然在研究的領域。
1.1,非結(jié)構(gòu)化剪枝
非結(jié)構(gòu)化剪枝(Unstructured Puning)是指修剪參數(shù)的單個元素,比如全連接層中的單個權(quán)重、卷積層中的單個卷積核參數(shù)元素或者自定義層中的浮點數(shù)(scaling floats)。其重點在于,剪枝權(quán)重對象是隨機的,沒有特定結(jié)構(gòu),因此被稱為非結(jié)構(gòu)化剪枝。
1.2,結(jié)構(gòu)化剪枝
與非結(jié)構(gòu)化剪枝相反,結(jié)構(gòu)化剪枝會剪枝整個參數(shù)結(jié)構(gòu)。比如,丟棄整行或整列的權(quán)重,或者在卷積層中丟棄整個過濾器(Filter
)。
1.3,本地與全局修剪
剪枝可以在每層(局部)或多層/所有層(全局)上進行。
二,PyTorch 的剪枝
目前 PyTorch 框架支持的權(quán)重剪枝方法有:
- Random: 簡單地修剪隨機參數(shù)。
- Magnitude: 修剪權(quán)重最小的參數(shù)(例如它們的 L2 范數(shù))
以上兩種方法實現(xiàn)簡單、計算容易,且可以在沒有任何數(shù)據(jù)的情況下應用。
2.1,pytorch 剪枝工作原理
剪枝功能在 torch.nn.utils.prune
類中實現(xiàn),代碼在文件 torch/nn/utils/prune.py 中,主要剪枝類如下圖所示。
剪枝原理是基于張量(Tensor)的掩碼(Mask)實現(xiàn)。掩碼是一個與張量形狀相同的布爾類型的張量,掩碼的值為 True 表示相應位置的權(quán)重需要保留,掩碼的值為 False 表示相應位置的權(quán)重可以被刪除。
Pytorch 將原始參數(shù) <param>
復制到名為 <param>_original
的參數(shù)中,并創(chuàng)建一個緩沖區(qū)來存儲剪枝掩碼 <param>_mask
。同時,其也會創(chuàng)建一個模塊級的 forward_pre_hook 回調(diào)函數(shù)(在模型前向傳播之前會被調(diào)用的回調(diào)函數(shù)),將剪枝掩碼應用于原始權(quán)重。
pytorch 剪枝的 api
和教程比較混亂,我個人將做了如下表格,希望能將 api 和剪枝方法及分類總結(jié)好。
pytorch 中進行模型剪枝的工作流程如下:
- 選擇剪枝方法(或者子類化 BasePruningMethod 實現(xiàn)自己的剪枝方法)。
- 指定剪枝模塊和參數(shù)名稱。
- 設置剪枝方法的參數(shù),比如剪枝比例等。
2.2,局部剪枝
Pytorch 框架中的局部剪枝有非結(jié)構(gòu)化和結(jié)構(gòu)化剪枝兩種類型,值得注意的是結(jié)構(gòu)化剪枝只支持局部不支持全局。
2.2.1,局部非結(jié)構(gòu)化剪枝
1,局部非結(jié)構(gòu)化剪枝(Locall Unstructured Pruning)對應函數(shù)原型如下:
def random_unstructured(module, name, amount)
1,函數(shù)功能:
用于對權(quán)重參數(shù)張量進行非結(jié)構(gòu)化剪枝。該方法會在張量中隨機選擇一些權(quán)重或連接進行剪枝,剪枝率由用戶指定。
2,函數(shù)參數(shù)定義:
module
(nn.Module): 需要剪枝的網(wǎng)絡層/模塊,例如 nn.Conv2d() 和 nn.Linear()。name
(str): 要剪枝的參數(shù)名稱,比如 "weight" 或 "bias"。amount
(int or float): 指定要剪枝的數(shù)量,如果是 0~1 之間的小數(shù),則表示剪枝比例;如果是證書,則直接剪去參數(shù)的絕對數(shù)量。比如amount=0.2
,表示將隨機選擇 20% 的元素進行剪枝。
3,下面是 random_unstructured
函數(shù)的使用示例。
import torch import torch.nn.utils.prune as prune conv = torch.nn.Conv2d(1, 1, 4) prune.random_unstructured(conv, name="weight", amount=0.5) conv.weight """ tensor([[[[-0.1703, 0.0000, -0.0000, 0.0690], [ 0.1411, 0.0000, -0.0000, -0.1031], [-0.0527, 0.0000, 0.0640, 0.1666], [ 0.0000, -0.0000, -0.0000, 0.2281]]]], grad_fn=<MulBackward0>) """
可以看書輸出的 conv 層中權(quán)重值有一半比例為 0
。
2.2.2,局部結(jié)構(gòu)化剪枝
局部結(jié)構(gòu)化剪枝(Locall Structured Pruning)有兩種函數(shù),對應函數(shù)原型如下:
def random_structured(module, name, amount, dim) def ln_structured(module, name, amount, n, dim, importance_scores=None)
1,函數(shù)功能
與非結(jié)構(gòu)化移除的是連接權(quán)重不同,結(jié)構(gòu)化剪枝移除的是整個通道權(quán)重。
2,參數(shù)定義
與局部非結(jié)構(gòu)化函數(shù)非常相似,唯一的區(qū)別是您必須定義 dim 參數(shù)(ln_structured 函數(shù)多了 n
參數(shù))。
n
表示剪枝的范數(shù),dim
表示剪枝的維度。
對于 torch.nn.Linear:
dim = 0
: 移除一個神經(jīng)元。dim = 1
:移除與一個輸入的所有連接。
對于 torch.nn.Conv2d:
dim = 0
(Channels) : 通道 channels 剪枝/過濾器 filters 剪枝dim = 1
(Neurons): 二維卷積核 kernel 剪枝,即與輸入通道相連接的 kernel
2.2.3,局部結(jié)構(gòu)化剪枝示例代碼
在寫示例代碼之前,我們先需要理解 Conv2d
函數(shù)參數(shù)、卷積核 shape、軸以及張量的關系。
首先,Conv2d 函數(shù)原型如下;
class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
而 pytorch 中常規(guī)卷積的卷積核權(quán)重 shape
都為(C_out, C_in, kernel_height, kernel_width
),所以在代碼中卷積層權(quán)重 shape
為 [3, 2, 3, 3]
,dim = 0 對應的是 shape [3, 2, 3, 3] 中的 3
。這里我們 dim 設定了哪個軸,那自然剪枝之后權(quán)重張量對應的軸機會發(fā)生變換。
理解了前面的關鍵概念,下面就可以實際使用了,dim=0
的示例如下所示。
conv = torch.nn.Conv2d(2, 3, 3) norm1 = torch.norm(conv.weight, p=1, dim=[1,2,3]) print(norm1) """ tensor([1.9384, 2.3780, 1.8638], grad_fn=<NormBackward1>) """ prune.ln_structured(conv, name="weight", amount=1, n=2, dim=0) print(conv.weight) """ tensor([[[[-0.0005, 0.1039, 0.0306], [ 0.1233, 0.1517, 0.0628], [ 0.1075, -0.0606, 0.1140]], [[ 0.2263, -0.0199, 0.1275], [-0.0455, -0.0639, -0.2153], [ 0.1587, -0.1928, 0.1338]]], [[[-0.2023, 0.0012, 0.1617], [-0.1089, 0.2102, -0.2222], [ 0.0645, -0.2333, -0.1211]], [[ 0.2138, -0.0325, 0.0246], [-0.0507, 0.1812, -0.2268], [-0.1902, 0.0798, 0.0531]]], [[[ 0.0000, -0.0000, -0.0000], [ 0.0000, -0.0000, -0.0000], [ 0.0000, -0.0000, 0.0000]], [[ 0.0000, 0.0000, 0.0000], [-0.0000, 0.0000, 0.0000], [-0.0000, -0.0000, -0.0000]]]], grad_fn=<MulBackward0>) """
從運行結(jié)果可以明顯看出,卷積層參數(shù)的最后一個通道參數(shù)張量被移除了(為 0
張量),其解釋參見下圖。
dim = 1
的情況:
conv = torch.nn.Conv2d(2, 3, 3) norm1 = torch.norm(conv.weight, p=1, dim=[0, 2,3]) print(norm1) """ tensor([3.1487, 3.9088], grad_fn=<NormBackward1>) """ prune.ln_structured(conv, name="weight", amount=1, n=2, dim=1) print(conv.weight) """ tensor([[[[ 0.0000, -0.0000, -0.0000], [-0.0000, 0.0000, 0.0000], [-0.0000, 0.0000, -0.0000]], [[-0.2140, 0.1038, 0.1660], [ 0.1265, -0.1650, -0.2183], [-0.0680, 0.2280, 0.2128]]], [[[-0.0000, 0.0000, 0.0000], [ 0.0000, 0.0000, -0.0000], [-0.0000, -0.0000, -0.0000]], [[-0.2087, 0.1275, 0.0228], [-0.1888, -0.1345, 0.1826], [-0.2312, -0.1456, -0.1085]]], [[[-0.0000, 0.0000, 0.0000], [ 0.0000, -0.0000, 0.0000], [ 0.0000, -0.0000, 0.0000]], [[-0.0891, 0.0946, -0.1724], [-0.2068, 0.0823, 0.0272], [-0.2256, -0.1260, -0.0323]]]], grad_fn=<MulBackward0>) """
很明顯,對于 dim=1
的維度,其第一個張量的 L2 范數(shù)更小,所以shape 為 [2, 3, 3] 的張量中,第一個 [3, 3] 張量參數(shù)會被移除(即張量為 0 矩陣) 。
2.3,全局非結(jié)構(gòu)化剪枝
前文的 local 剪枝的對象是特定網(wǎng)絡層,而 global 剪枝是將模型看作一個整體去移除指定比例(數(shù)量)的參數(shù),同時 global 剪枝結(jié)果會導致模型中每層的稀疏比例是不一樣的。
全局非結(jié)構(gòu)化剪枝函數(shù)原型如下:
# v1.4.0 版本 def global_unstructured(parameters, pruning_method, **kwargs) # v2.0.0-rc2版本 def global_unstructured(parameters, pruning_method, importance_scores=None, **kwargs):
1,函數(shù)功能:
隨機選擇全局所有參數(shù)(包括權(quán)重和偏置)的一部分進行剪枝,而不管它們屬于哪個層。
2,參數(shù)定義:
parameters
((Iterable of (module, name) tuples)): 修剪模型的參數(shù)列表,列表中的元素是 (module, name)。pruning_method
(function): 目前好像官方只支持 pruning_method=prune.L1Unstuctured,另外也可以是自己實現(xiàn)的非結(jié)構(gòu)化剪枝方法函數(shù)。importance_scores
: 表示每個參數(shù)的重要性得分,如果為 None,則使用默認得分。**kwargs
: 表示傳遞給特定剪枝方法的額外參數(shù)。比如amount
指定要剪枝的數(shù)量。
3,global_unstructured
函數(shù)的示例代碼如下所示。
device = torch.device("cuda" if torch.cuda.is_available() else "cpu") class LeNet(nn.Module): def __init__(self): super(LeNet, self).__init__() # 1 input image channel, 6 output channels, 3x3 square conv kernel self.conv1 = nn.Conv2d(1, 6, 3) self.conv2 = nn.Conv2d(6, 16, 3) self.fc1 = nn.Linear(16 * 5 * 5, 120) # 5x5 image dimension self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) x = F.max_pool2d(F.relu(self.conv2(x)), 2) x = x.view(-1, int(x.nelement() / x.shape[0])) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x model = LeNet().to(device=device) model = LeNet() parameters_to_prune = ( (model.conv1, 'weight'), (model.conv2, 'weight'), (model.fc1, 'weight'), (model.fc2, 'weight'), (model.fc3, 'weight'), ) prune.global_unstructured( parameters_to_prune, pruning_method=prune.L1Unstructured, amount=0.2, ) # 計算卷積層和整個模型的稀疏度 # 其實調(diào)用的是 Tensor.numel 內(nèi)內(nèi)函數(shù),返回輸入張量中元素的總數(shù) print( "Sparsity in conv1.weight: {:.2f}%".format( 100. * float(torch.sum(model.conv1.weight == 0)) / float(model.conv1.weight.nelement()) ) ) print( "Global sparsity: {:.2f}%".format( 100. * float( torch.sum(model.conv1.weight == 0) + torch.sum(model.conv2.weight == 0) + torch.sum(model.fc1.weight == 0) + torch.sum(model.fc2.weight == 0) + torch.sum(model.fc3.weight == 0) ) / float( model.conv1.weight.nelement() + model.conv2.weight.nelement() + model.fc1.weight.nelement() + model.fc2.weight.nelement() + model.fc3.weight.nelement() ) ) ) # 程序運行結(jié)果 """ Sparsity in conv1.weight: 3.70% Global sparsity: 20.00% """
運行結(jié)果表明,雖然模型整體(全局)的稀疏度是 20%
,但每個網(wǎng)絡層的稀疏度不一定是 20%。
三,總結(jié)
另外,pytorch 框架還提供了一些幫助函數(shù):
- torch.nn.utils.prune.is_pruned(module): 判斷模塊 是否被剪枝。
- torch.nn.utils.prune.remove(module, name): 用于將指定模塊中指定參數(shù)上的剪枝操作移除,從而恢復該參數(shù)的原始形狀和數(shù)值。
雖然 PyTorch 提供了內(nèi)置剪枝 API
,也支持了一些非結(jié)構(gòu)化和結(jié)構(gòu)化剪枝方法,但是 API
比較混亂,對應文檔描述也不清晰,所以后面我還會結(jié)合微軟的開源 nni
工具來實現(xiàn)模型剪枝功能。
參考資料
到此這篇關于pytorch實現(xiàn)模型剪枝的文章就介紹到這了,更多相關pytorch模型剪枝內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!
相關文章
selenium+opencv實現(xiàn)滑塊驗證碼的登陸
很多網(wǎng)站登錄登陸時都要用到滑塊驗證碼,本文主要介紹了selenium+opencv實現(xiàn)滑塊驗證碼的登陸,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2023-04-04PageFactory設計模式基于python實現(xiàn)
這篇文章主要介紹了PageFactory設計模式基于python實現(xiàn),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下2020-04-04Python集成開發(fā)環(huán)境Pycharm的使用及技巧
本文詳細講解了Python集成開發(fā)環(huán)境Pycharm的使用及技巧,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2022-06-06