pytorch常用函數(shù)定義及resnet模型修改實(shí)例
模型定義常用函數(shù)
利用nn.Parameter()設(shè)計(jì)新的層
import torch from torch import nn class MyLinear(nn.Module): def __init__(self, in_features, out_features): super().__init__() self.weight = nn.Parameter(torch.randn(in_features, out_features)) self.bias = nn.Parameter(torch.randn(out_features)) def forward(self, input): return (input @ self.weight) + self.bias
nn.Sequential
一個(gè)有序的容器,神經(jīng)網(wǎng)絡(luò)模塊將按照在傳入構(gòu)造器的順序依次被添加到計(jì)算圖中執(zhí)行,同時(shí)以神經(jīng)網(wǎng)絡(luò)模塊為元素的有序字典也可以作為傳入?yún)?shù)。Sequential適用于快速驗(yàn)證結(jié)果,簡(jiǎn)單易讀,但使用Sequential也會(huì)使得模型定義喪失靈活性,比如需要在模型中間加入一個(gè)外部輸入時(shí)就不適合用Sequential的方式實(shí)現(xiàn)。
net = nn.Sequential( ('fc1',MyLinear(4, 3)), ('act',nn.ReLU()), ('fc2',MyLinear(3, 1)) )
nn.ModuleList()
ModuleList 接收一個(gè)子模塊(或?qū)?,需屬于nn.Module類)的列表作為輸入,然后也可以類似List那樣進(jìn)行append和extend操作。同時(shí),子模塊或?qū)拥臋?quán)重也會(huì)自動(dòng)添加到網(wǎng)絡(luò)中來。
net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()]) net.append(nn.Linear(256, 10)) # # 類似List的append操作 print(net[-1]) # 類似List的索引訪問 print(net)
Linear(in_features=256, out_features=10, bias=True) ModuleList( (0): Linear(in_features=784, out_features=256, bias=True) (1): ReLU() (2): Linear(in_features=256, out_features=10, bias=True) )
要特別注意的是,nn.ModuleList 并沒有定義一個(gè)網(wǎng)絡(luò),它只是將不同的模塊儲(chǔ)存在一起。
ModuleList中元素的先后順序并不代表其在網(wǎng)絡(luò)中的真實(shí)位置順序,需要經(jīng)過forward函數(shù)指定各個(gè)層的先后順序后才算完成了模型的定義。
具體實(shí)現(xiàn)時(shí)用for循環(huán)即可完成:
class model(nn.Module): def __init__(self, ...): super().__init__() self.modulelist = ... ... def forward(self, x): for layer in self.modulelist: x = layer(x) return x
nn.ModuleDict()
ModuleDict和ModuleList的作用類似,只是ModuleDict能夠更方便地為神經(jīng)網(wǎng)絡(luò)的層添加名稱。
net = nn.ModuleDict({ 'linear': nn.Linear(784, 256), 'act': nn.ReLU(), }) net['output'] = nn.Linear(256, 10) # 添加 print(net['linear']) # 訪問 print(net.output) print(net)
Linear(in_features=784, out_features=256, bias=True) Linear(in_features=256, out_features=10, bias=True) ModuleDict( (act): ReLU() (linear): Linear(in_features=784, out_features=256, bias=True) (output): Linear(in_features=256, out_features=10, bias=True) )
ModuleList和ModuleDict在某個(gè)完全相同的層需要重復(fù)出現(xiàn)多次時(shí),非常方便實(shí)現(xiàn),可以”一行頂多行“;當(dāng)我們需要之前層的信息的時(shí)候,比如 ResNets 中的殘差計(jì)算,當(dāng)前層的結(jié)果需要和之前層中的結(jié)果進(jìn)行融合,一般使用 ModuleList/ModuleDict 比較方便。
nn.Flatten
展平輸入的張量: 28x28 -> 784
input = torch.randn(32, 1, 5, 5) m = nn.Sequential( nn.Conv2d(1, 32, 5, 1, 1), nn.Flatten() ) output = m(input) output.size()
模型修改案例
有了上面的一些常用方法,我們可以修改現(xiàn)有的一些開源模型,這里通過介紹修改模型層、添加額外輸入的案例來幫助我們更好地理解。
修改模型層
以pytorch官方視覺庫(kù)torchvision預(yù)定義好的模型ResNet50為例,探索如何修改模型的某一層或者某幾層。
我們先看看模型的定義:
import torchvision.models as models net = models.resnet50() print(net)
ResNet( (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False) (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True) (relu): ReLU(inplace=True) (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False) (layer1): Sequential( (0): Bottleneck( (conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False) .............. (avgpool): AdaptiveAvgPool2d(output_size=(1, 1)) (fc): Linear(in_features=2048, out_features=1000, bias=True) )
為了適配ImageNet,fc層輸出是1000,若需要用這個(gè)resnet模型去做一個(gè)10分類的問題,就應(yīng)該修改模型的fc層,將其輸出節(jié)點(diǎn)數(shù)替換為10。另外,我們覺得一層全連接層可能太少了,想再加一層。
可以做如下修改:
from collections import OrderedDict classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2048, 128)), ('relu1', nn.ReLU()), ('dropout1',nn.Dropout(0.5)), ('fc2', nn.Linear(128, 10)), ('output', nn.Softmax(dim=1)) ])) net.fc = classifier # 將模型(net)最后名稱為“fc”的層替換成了我們自己定義的名稱為“classifier”的結(jié)構(gòu)
添加外部輸入
有時(shí)候在模型訓(xùn)練中,除了已有模型的輸入之外,還需要輸入額外的信息。比如在CNN網(wǎng)絡(luò)中,我們除了輸入圖像,還需要同時(shí)輸入圖像對(duì)應(yīng)的其他信息,這時(shí)候就需要在已有的CNN網(wǎng)絡(luò)中添加額外的輸入變量。
基本思路是:將原模型添加輸入位置前的部分作為一個(gè)整體,同時(shí)在forward中定義好原模型不變的部分、添加的輸入和后續(xù)層之間的連接關(guān)系,從而完成模型的修改。
我們以torchvision的resnet50模型為基礎(chǔ),任務(wù)還是10分類任務(wù)。不同點(diǎn)在于,我們希望利用已有的模型結(jié)構(gòu),在倒數(shù)第二層增加一個(gè)額外的輸入變量add_variable來輔助預(yù)測(cè)。
具體實(shí)現(xiàn)如下:
class Model(nn.Module): def __init__(self, net): super(Model, self).__init__() self.net = net self.relu = nn.ReLU() self.dropout = nn.Dropout(0.5) self.fc_add = nn.Linear(1001, 10, bias=True) self.output = nn.Softmax(dim=1) def forward(self, x, add_variable): x = self.net(x) # add_variable (batch_size, )->(batch_size, 1) x = torch.cat((self.dropout(self.relu(x)), add_variable.unsqueeze(1)),1) x = self.fc_add(x) x = self.output(x) return x
修改好的模型結(jié)構(gòu)進(jìn)行實(shí)例化,就可以使用
import torchvision.models as models net = models.resnet50() model = Model(net).cuda() # 使用時(shí)輸入兩個(gè)inputs outputs = model(inputs, add_var)
參考資料:
Pytorch模型定義與深度學(xué)習(xí)自查手冊(cè)
以上就是pytorch常用函數(shù)定義及resnet模型修改實(shí)例的詳細(xì)內(nèi)容,更多關(guān)于pytorch函數(shù)resnet模型修改的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
python 定義n個(gè)變量方法 (變量聲明自動(dòng)化)
今天小編就為大家分享一篇python 定義n個(gè)變量方法 (變量聲明自動(dòng)化),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-11-11flask框架使用orm連接數(shù)據(jù)庫(kù)的方法示例
這篇文章主要介紹了flask框架使用orm連接數(shù)據(jù)庫(kù)的方法,結(jié)合實(shí)例形式分析了flask框架使用flask_sqlalchemy包進(jìn)行mysql數(shù)據(jù)庫(kù)連接操作的具體步驟與相關(guān)實(shí)現(xiàn)技巧,需要的朋友可以參考下2018-07-07python用tkinter實(shí)現(xiàn)一個(gè)gui的翻譯工具
這篇文章主要介紹了python用tkinter實(shí)現(xiàn)一個(gè)gui的翻譯工具,幫助大家更好的理解和使用python,感興趣的朋友可以了解下 +2020-10-10解決python中os.listdir()函數(shù)讀取文件夾下文件的亂序和排序問題
今天小編就為大家分享一篇解決python中os.listdir()函數(shù)讀取文件夾下文件的亂序和排序問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-10-10關(guān)于python爬蟲應(yīng)用urllib庫(kù)作用分析
這篇文章主要介紹了關(guān)于python爬蟲應(yīng)用urllib庫(kù)作用分析,想要進(jìn)行python爬蟲首先我們需要先將網(wǎng)頁(yè)上面的信息給獲取下來,這就是utllib庫(kù)的作用,有需要的朋友可以借鑒參考下2021-09-09Ranorex通過Python將報(bào)告發(fā)送到郵箱的方法
這篇文章主要介紹了Ranorex通過Python將報(bào)告發(fā)送到郵箱的方法,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-01-01Python利用GDAL模塊實(shí)現(xiàn)讀取柵格數(shù)據(jù)并對(duì)指定數(shù)據(jù)加以篩選掩膜
這篇文章主要為大家詳細(xì)介紹了如何基于Python語(yǔ)言中g(shù)dal模塊,對(duì)遙感影像數(shù)據(jù)進(jìn)行柵格讀取與計(jì)算,同時(shí)基于QA波段對(duì)像元加以篩選、掩膜的操作,需要的可以參考一下2023-02-02Python檢查和同步本地時(shí)間(北京時(shí)間)的實(shí)現(xiàn)方法
這篇文章主要介紹了Python檢查和同步本地時(shí)間(北京時(shí)間)的實(shí)現(xiàn)方法,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2018-12-12