pytorch使用resnet快速加載官方提供的預訓練模型
使用resnet快速加載官方提供的預訓練模型
在做神經(jīng)網(wǎng)絡的搭建過程,經(jīng)常使用pytorch中的resnet作為backbone,特別是resnet50,
比如下面的這個網(wǎng)絡設定:
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision import models
class base_resnet(nn.Module):
def __init__(self):
super(base_resnet, self).__init__()
self.model = models.resnet50(pretrained=True)
#self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))
self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, x):
x = self.model.conv1(x)
x = self.model.bn1(x)
x = self.model.relu(x)
x = self.model.maxpool(x)
x = self.model.layer1(x)
x = self.model.layer2(x)
x = self.model.layer3(x)
x = self.model.layer4(x)
x = self.model.avgpool(x)
# x = x.view(x.size(0), x.size(1))
return x該網(wǎng)絡相當于繼承了resnet50的所有參數(shù)結(jié)構(gòu),只不過是在forward中,改變了數(shù)據(jù)的傳輸過程,沒有經(jīng)過最后的特征展開以及線性分類。
在下面的這行代碼中,是相當于調(diào)用了pytoch中定義的resnet50網(wǎng)絡,并且會自動下載并且加載訓練好的網(wǎng)絡參數(shù),如果調(diào)為 pretrained=False,則不會加載訓練好的參數(shù),而是隨機進行參數(shù)的賦值。
但是我在服務器上跑這一類代碼的時候發(fā)現(xiàn),每當我重新跑一次程序,如果設置為True都會重新下載resnet50訓練好的參數(shù),但是由于有時候網(wǎng)絡特別不好,導致我下載個基礎的resnet50就要耗費我好長時間,那么我就想能不能將這個resnet50的參數(shù)提前下載好,使用的時候直接加載呢。
當然是能了。
self.model = models.resnet50(pretrained=True)
我們可以根據(jù)我們使用的結(jié)構(gòu),到對應的地址下載對應的模型到本地,常用的resnet的地址如下:
'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth', 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth', 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth', 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',
將其下載下來,然后將模型放入到和net.py同目錄的model文件夾下面,然后使用下面的代碼就可以避免每次都重新下載模型的問題了。
self.model = models.resnet50(pretrained=False)
self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))pytorch代碼規(guī)范之加載預訓練模型
加載預訓練模型,并去除需要再次訓練的層
model=resnet()#自己構(gòu)建的模型,以resnet為例, 需要重新訓練的層的名字要和之前的不同。
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)固定部分參數(shù)
#k是可訓練參數(shù)的名字,v是包含可訓練參數(shù)的一個實體 #可以先print(k),找到自己想進行調(diào)整的層,并將該層的名字加入到if語句中: for k,v in model.named_parameters(): if k!='xxx.weight' and k!='xxx.bias' : v.requires_grad=False#固定參數(shù)
訓練部分參數(shù)
#將要訓練的參數(shù)放入優(yōu)化器 optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)
檢查是否固定
for k,v in model.named_parameters(): if k!='xxx.weight' and k!='xxx.bias' : print(v.requires_grad)#理想狀態(tài)下,所有值都是False
總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
python安裝numpy&安裝matplotlib& scipy的教程
下面小編就為大家?guī)硪黄猵ython安裝numpy&安裝matplotlib& scipy的教程。小編覺得挺不錯的,現(xiàn)在就分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2017-11-11

