欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch使用resnet快速加載官方提供的預訓練模型

 更新時間:2023年09月09日 09:34:03   作者:Tchunren  
這篇文章主要介紹了pytorch使用resnet快速加載官方提供的預訓練模型方式,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教

使用resnet快速加載官方提供的預訓練模型

在做神經(jīng)網(wǎng)絡的搭建過程,經(jīng)常使用pytorch中的resnet作為backbone,特別是resnet50,

比如下面的這個網(wǎng)絡設定:

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision import models
class base_resnet(nn.Module):
    def __init__(self):
        super(base_resnet, self).__init__()
        self.model = models.resnet50(pretrained=True)
        #self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))
        self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        # x = x.view(x.size(0), x.size(1))
        return x

該網(wǎng)絡相當于繼承了resnet50的所有參數(shù)結(jié)構(gòu),只不過是在forward中,改變了數(shù)據(jù)的傳輸過程,沒有經(jīng)過最后的特征展開以及線性分類。

在下面的這行代碼中,是相當于調(diào)用了pytoch中定義的resnet50網(wǎng)絡,并且會自動下載并且加載訓練好的網(wǎng)絡參數(shù),如果調(diào)為 pretrained=False,則不會加載訓練好的參數(shù),而是隨機進行參數(shù)的賦值。

但是我在服務器上跑這一類代碼的時候發(fā)現(xiàn),每當我重新跑一次程序,如果設置為True都會重新下載resnet50訓練好的參數(shù),但是由于有時候網(wǎng)絡特別不好,導致我下載個基礎的resnet50就要耗費我好長時間,那么我就想能不能將這個resnet50的參數(shù)提前下載好,使用的時候直接加載呢。

當然是能了。

self.model = models.resnet50(pretrained=True)

我們可以根據(jù)我們使用的結(jié)構(gòu),到對應的地址下載對應的模型到本地,常用的resnet的地址如下:

 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',

將其下載下來,然后將模型放入到和net.py同目錄的model文件夾下面,然后使用下面的代碼就可以避免每次都重新下載模型的問題了。

self.model = models.resnet50(pretrained=False)
self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))

pytorch代碼規(guī)范之加載預訓練模型

加載預訓練模型,并去除需要再次訓練的層

model=resnet()#自己構(gòu)建的模型,以resnet為例, 需要重新訓練的層的名字要和之前的不同。
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

固定部分參數(shù)

#k是可訓練參數(shù)的名字,v是包含可訓練參數(shù)的一個實體
#可以先print(k),找到自己想進行調(diào)整的層,并將該層的名字加入到if語句中:
for k,v in model.named_parameters():
if k!='xxx.weight' and k!='xxx.bias' :
v.requires_grad=False#固定參數(shù)

訓練部分參數(shù)

#將要訓練的參數(shù)放入優(yōu)化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

檢查是否固定

for k,v in model.named_parameters():
if k!='xxx.weight' and k!='xxx.bias' :
print(v.requires_grad)#理想狀態(tài)下,所有值都是False

總結(jié)

以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關文章

  • Pycharm激活方法及詳細教程(詳細且實用)

    Pycharm激活方法及詳細教程(詳細且實用)

    這篇文章主要介紹了Pycharm激活方法及詳細教程,本文通過圖文并茂的形式給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友參考下吧
    2020-05-05
  • python通過ElementTree操作XML

    python通過ElementTree操作XML

    這篇文章介紹了python通過ElementTree操作XML的方法,文中通過示例代碼介紹的非常詳細。對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2022-07-07
  • Python下opencv使用hough變換檢測直線與圓

    Python下opencv使用hough變換檢測直線與圓

    在數(shù)字圖像中,往往存在著一些特殊形狀的幾何圖形,像檢測馬路邊一條直線,檢測人眼的圓形等等,有時我們需要把這些特定圖形檢測出來,本文就詳細的介紹了一下方法
    2021-06-06
  • E: 無法定位軟件包 python3-pip問題及解決

    E: 無法定位軟件包 python3-pip問題及解決

    這篇文章主要介紹了E: 無法定位軟件包 python3-pip問題及解決方案,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教
    2023-02-02
  • Python實現(xiàn)二叉堆

    Python實現(xiàn)二叉堆

    二叉堆是一種特殊的堆,二叉堆是完全二元樹(二叉樹)或者是近似完全二元樹(二叉樹)。二叉堆有兩種:最大堆和最小堆。最大堆:父結(jié)點的鍵值總是大于或等于任何一個子節(jié)點的鍵值;最小堆:父結(jié)點的鍵值總是小于或等于任何一個子節(jié)點的鍵值。
    2016-02-02
  • python教程對函數(shù)中的參數(shù)進行排序

    python教程對函數(shù)中的參數(shù)進行排序

    這篇文章主要介紹了python教程對函數(shù)中的參數(shù)進行排序的方法講解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪
    2021-09-09
  • Datawhale練習之二手車價格預測

    Datawhale練習之二手車價格預測

    此篇文章是關于Datawhale練習,代碼完整,但由于該數(shù)據(jù)集中數(shù)據(jù)特征較少(39維),以下可作為少量特征情況下的分析。當特征數(shù)目過大(成千上萬)時,需要繼續(xù)學習。需要的朋友可以參考下
    2021-04-04
  • 詳解如何用TensorFlow訓練和識別/分類自定義圖片

    詳解如何用TensorFlow訓練和識別/分類自定義圖片

    這篇文章主要介紹了詳解如何用TensorFlow訓練和識別/分類自定義圖片,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2019-08-08
  • python安裝numpy&安裝matplotlib& scipy的教程

    python安裝numpy&安裝matplotlib& scipy的教程

    下面小編就為大家?guī)硪黄猵ython安裝numpy&安裝matplotlib& scipy的教程。小編覺得挺不錯的,現(xiàn)在就分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2017-11-11
  • 一篇文章帶你深入學習Python函數(shù)

    一篇文章帶你深入學習Python函數(shù)

    這篇文章主要帶大家深入學習Python函數(shù),具有一定的參考價值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來幫助
    2022-01-01

最新評論