pytorch使用resnet快速加載官方提供的預(yù)訓(xùn)練模型
使用resnet快速加載官方提供的預(yù)訓(xùn)練模型
在做神經(jīng)網(wǎng)絡(luò)的搭建過程,經(jīng)常使用pytorch中的resnet作為backbone,特別是resnet50,
比如下面的這個(gè)網(wǎng)絡(luò)設(shè)定:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision import models
class base_resnet(nn.Module):
def __init__(self):
super(base_resnet, self).__init__()
self.model = models.resnet50(pretrained=True)
#self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))
self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.model.conv1(x)
x = self.model.bn1(x)
x = self.model.relu(x)
x = self.model.maxpool(x)
x = self.model.layer1(x)
x = self.model.layer2(x)
x = self.model.layer3(x)
x = self.model.layer4(x)
x = self.model.avgpool(x)
# x = x.view(x.size(0), x.size(1))
return x該網(wǎng)絡(luò)相當(dāng)于繼承了resnet50的所有參數(shù)結(jié)構(gòu),只不過是在forward中,改變了數(shù)據(jù)的傳輸過程,沒有經(jīng)過最后的特征展開以及線性分類。
在下面的這行代碼中,是相當(dāng)于調(diào)用了pytoch中定義的resnet50網(wǎng)絡(luò),并且會(huì)自動(dòng)下載并且加載訓(xùn)練好的網(wǎng)絡(luò)參數(shù),如果調(diào)為 pretrained=False,則不會(huì)加載訓(xùn)練好的參數(shù),而是隨機(jī)進(jìn)行參數(shù)的賦值。
但是我在服務(wù)器上跑這一類代碼的時(shí)候發(fā)現(xiàn),每當(dāng)我重新跑一次程序,如果設(shè)置為True都會(huì)重新下載resnet50訓(xùn)練好的參數(shù),但是由于有時(shí)候網(wǎng)絡(luò)特別不好,導(dǎo)致我下載個(gè)基礎(chǔ)的resnet50就要耗費(fèi)我好長(zhǎng)時(shí)間,那么我就想能不能將這個(gè)resnet50的參數(shù)提前下載好,使用的時(shí)候直接加載呢。
當(dāng)然是能了。
self.model = models.resnet50(pretrained=True)
我們可以根據(jù)我們使用的結(jié)構(gòu),到對(duì)應(yīng)的地址下載對(duì)應(yīng)的模型到本地,常用的resnet的地址如下:
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
將其下載下來,然后將模型放入到和net.py同目錄的model文件夾下面,然后使用下面的代碼就可以避免每次都重新下載模型的問題了。
self.model = models.resnet50(pretrained=False)
self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))pytorch代碼規(guī)范之加載預(yù)訓(xùn)練模型
加載預(yù)訓(xùn)練模型,并去除需要再次訓(xùn)練的層
model=resnet()#自己構(gòu)建的模型,以resnet為例, 需要重新訓(xùn)練的層的名字要和之前的不同。
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)固定部分參數(shù)
#k是可訓(xùn)練參數(shù)的名字,v是包含可訓(xùn)練參數(shù)的一個(gè)實(shí)體 #可以先print(k),找到自己想進(jìn)行調(diào)整的層,并將該層的名字加入到if語句中: for k,v in model.named_parameters(): if k!='xxx.weight' and k!='xxx.bias' : v.requires_grad=False#固定參數(shù)
訓(xùn)練部分參數(shù)
#將要訓(xùn)練的參數(shù)放入優(yōu)化器 optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)
檢查是否固定
for k,v in model.named_parameters(): if k!='xxx.weight' and k!='xxx.bias' : print(v.requires_grad)#理想狀態(tài)下,所有值都是False
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Pycharm激活方法及詳細(xì)教程(詳細(xì)且實(shí)用)
這篇文章主要介紹了Pycharm激活方法及詳細(xì)教程,本文通過圖文并茂的形式給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友參考下吧2020-05-05
Python下opencv使用hough變換檢測(cè)直線與圓
在數(shù)字圖像中,往往存在著一些特殊形狀的幾何圖形,像檢測(cè)馬路邊一條直線,檢測(cè)人眼的圓形等等,有時(shí)我們需要把這些特定圖形檢測(cè)出來,本文就詳細(xì)的介紹了一下方法2021-06-06
python教程對(duì)函數(shù)中的參數(shù)進(jìn)行排序
這篇文章主要介紹了python教程對(duì)函數(shù)中的參數(shù)進(jìn)行排序的方法講解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2021-09-09
Datawhale練習(xí)之二手車價(jià)格預(yù)測(cè)
此篇文章是關(guān)于Datawhale練習(xí),代碼完整,但由于該數(shù)據(jù)集中數(shù)據(jù)特征較少(39維),以下可作為少量特征情況下的分析。當(dāng)特征數(shù)目過大(成千上萬)時(shí),需要繼續(xù)學(xué)習(xí)。需要的朋友可以參考下2021-04-04
詳解如何用TensorFlow訓(xùn)練和識(shí)別/分類自定義圖片
這篇文章主要介紹了詳解如何用TensorFlow訓(xùn)練和識(shí)別/分類自定義圖片,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-08-08
python安裝numpy&安裝matplotlib& scipy的教程
下面小編就為大家?guī)硪黄猵ython安裝numpy&安裝matplotlib& scipy的教程。小編覺得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2017-11-11
一篇文章帶你深入學(xué)習(xí)Python函數(shù)
這篇文章主要帶大家深入學(xué)習(xí)Python函數(shù),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來幫助2022-01-01

