pytorch使用resnet快速加載官方提供的預(yù)訓(xùn)練模型
使用resnet快速加載官方提供的預(yù)訓(xùn)練模型
在做神經(jīng)網(wǎng)絡(luò)的搭建過(guò)程,經(jīng)常使用pytorch中的resnet作為backbone,特別是resnet50,
比如下面的這個(gè)網(wǎng)絡(luò)設(shè)定:
import torch import torch.nn as nn from torchvision import datasets, transforms from torchvision import models class base_resnet(nn.Module): def __init__(self): super(base_resnet, self).__init__() self.model = models.resnet50(pretrained=True) #self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth')) self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1)) def forward(self, x): x = self.model.conv1(x) x = self.model.bn1(x) x = self.model.relu(x) x = self.model.maxpool(x) x = self.model.layer1(x) x = self.model.layer2(x) x = self.model.layer3(x) x = self.model.layer4(x) x = self.model.avgpool(x) # x = x.view(x.size(0), x.size(1)) return x
該網(wǎng)絡(luò)相當(dāng)于繼承了resnet50的所有參數(shù)結(jié)構(gòu),只不過(guò)是在forward中,改變了數(shù)據(jù)的傳輸過(guò)程,沒(méi)有經(jīng)過(guò)最后的特征展開(kāi)以及線(xiàn)性分類(lèi)。
在下面的這行代碼中,是相當(dāng)于調(diào)用了pytoch中定義的resnet50網(wǎng)絡(luò),并且會(huì)自動(dòng)下載并且加載訓(xùn)練好的網(wǎng)絡(luò)參數(shù),如果調(diào)為 pretrained=False,則不會(huì)加載訓(xùn)練好的參數(shù),而是隨機(jī)進(jìn)行參數(shù)的賦值。
但是我在服務(wù)器上跑這一類(lèi)代碼的時(shí)候發(fā)現(xiàn),每當(dāng)我重新跑一次程序,如果設(shè)置為T(mén)rue都會(huì)重新下載resnet50訓(xùn)練好的參數(shù),但是由于有時(shí)候網(wǎng)絡(luò)特別不好,導(dǎo)致我下載個(gè)基礎(chǔ)的resnet50就要耗費(fèi)我好長(zhǎng)時(shí)間,那么我就想能不能將這個(gè)resnet50的參數(shù)提前下載好,使用的時(shí)候直接加載呢。
當(dāng)然是能了。
self.model = models.resnet50(pretrained=True)
我們可以根據(jù)我們使用的結(jié)構(gòu),到對(duì)應(yīng)的地址下載對(duì)應(yīng)的模型到本地,常用的resnet的地址如下:
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
將其下載下來(lái),然后將模型放入到和net.py同目錄的model文件夾下面,然后使用下面的代碼就可以避免每次都重新下載模型的問(wèn)題了。
self.model = models.resnet50(pretrained=False) self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))
pytorch代碼規(guī)范之加載預(yù)訓(xùn)練模型
加載預(yù)訓(xùn)練模型,并去除需要再次訓(xùn)練的層
model=resnet()#自己構(gòu)建的模型,以resnet為例, 需要重新訓(xùn)練的層的名字要和之前的不同。 model_dict = model.state_dict() pretrained_dict = torch.load('xxx.pkl') pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} model_dict.update(pretrained_dict) model.load_state_dict(model_dict)
固定部分參數(shù)
#k是可訓(xùn)練參數(shù)的名字,v是包含可訓(xùn)練參數(shù)的一個(gè)實(shí)體 #可以先print(k),找到自己想進(jìn)行調(diào)整的層,并將該層的名字加入到if語(yǔ)句中: for k,v in model.named_parameters(): if k!='xxx.weight' and k!='xxx.bias' : v.requires_grad=False#固定參數(shù)
訓(xùn)練部分參數(shù)
#將要訓(xùn)練的參數(shù)放入優(yōu)化器 optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)
檢查是否固定
for k,v in model.named_parameters(): if k!='xxx.weight' and k!='xxx.bias' : print(v.requires_grad)#理想狀態(tài)下,所有值都是False
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Pycharm激活方法及詳細(xì)教程(詳細(xì)且實(shí)用)
這篇文章主要介紹了Pycharm激活方法及詳細(xì)教程,本文通過(guò)圖文并茂的形式給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友參考下吧2020-05-05Python下opencv使用hough變換檢測(cè)直線(xiàn)與圓
在數(shù)字圖像中,往往存在著一些特殊形狀的幾何圖形,像檢測(cè)馬路邊一條直線(xiàn),檢測(cè)人眼的圓形等等,有時(shí)我們需要把這些特定圖形檢測(cè)出來(lái),本文就詳細(xì)的介紹了一下方法2021-06-06E: 無(wú)法定位軟件包 python3-pip問(wèn)題及解決
這篇文章主要介紹了E: 無(wú)法定位軟件包 python3-pip問(wèn)題及解決方案,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-02-02python教程對(duì)函數(shù)中的參數(shù)進(jìn)行排序
這篇文章主要介紹了python教程對(duì)函數(shù)中的參數(shù)進(jìn)行排序的方法講解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2021-09-09Datawhale練習(xí)之二手車(chē)價(jià)格預(yù)測(cè)
此篇文章是關(guān)于Datawhale練習(xí),代碼完整,但由于該數(shù)據(jù)集中數(shù)據(jù)特征較少(39維),以下可作為少量特征情況下的分析。當(dāng)特征數(shù)目過(guò)大(成千上萬(wàn))時(shí),需要繼續(xù)學(xué)習(xí)。需要的朋友可以參考下2021-04-04詳解如何用TensorFlow訓(xùn)練和識(shí)別/分類(lèi)自定義圖片
這篇文章主要介紹了詳解如何用TensorFlow訓(xùn)練和識(shí)別/分類(lèi)自定義圖片,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-08-08python安裝numpy&安裝matplotlib& scipy的教程
下面小編就為大家?guī)?lái)一篇python安裝numpy&安裝matplotlib& scipy的教程。小編覺(jué)得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧2017-11-11一篇文章帶你深入學(xué)習(xí)Python函數(shù)
這篇文章主要帶大家深入學(xué)習(xí)Python函數(shù),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來(lái)幫助2022-01-01