欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch使用resnet快速加載官方提供的預(yù)訓(xùn)練模型

 更新時(shí)間:2023年09月09日 09:34:03   作者:Tchunren  
這篇文章主要介紹了pytorch使用resnet快速加載官方提供的預(yù)訓(xùn)練模型方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

使用resnet快速加載官方提供的預(yù)訓(xùn)練模型

在做神經(jīng)網(wǎng)絡(luò)的搭建過(guò)程,經(jīng)常使用pytorch中的resnet作為backbone,特別是resnet50,

比如下面的這個(gè)網(wǎng)絡(luò)設(shè)定:

import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torchvision import models
class base_resnet(nn.Module):
    def __init__(self):
        super(base_resnet, self).__init__()
        self.model = models.resnet50(pretrained=True)
        #self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))
        self.model.avgpool = nn.AdaptiveAvgPool2d((1, 1))
    def forward(self, x):
        x = self.model.conv1(x)
        x = self.model.bn1(x)
        x = self.model.relu(x)
        x = self.model.maxpool(x)
        x = self.model.layer1(x)
        x = self.model.layer2(x)
        x = self.model.layer3(x)
        x = self.model.layer4(x)
        x = self.model.avgpool(x)
        # x = x.view(x.size(0), x.size(1))
        return x

該網(wǎng)絡(luò)相當(dāng)于繼承了resnet50的所有參數(shù)結(jié)構(gòu),只不過(guò)是在forward中,改變了數(shù)據(jù)的傳輸過(guò)程,沒(méi)有經(jīng)過(guò)最后的特征展開(kāi)以及線(xiàn)性分類(lèi)。

在下面的這行代碼中,是相當(dāng)于調(diào)用了pytoch中定義的resnet50網(wǎng)絡(luò),并且會(huì)自動(dòng)下載并且加載訓(xùn)練好的網(wǎng)絡(luò)參數(shù),如果調(diào)為 pretrained=False,則不會(huì)加載訓(xùn)練好的參數(shù),而是隨機(jī)進(jìn)行參數(shù)的賦值。

但是我在服務(wù)器上跑這一類(lèi)代碼的時(shí)候發(fā)現(xiàn),每當(dāng)我重新跑一次程序,如果設(shè)置為T(mén)rue都會(huì)重新下載resnet50訓(xùn)練好的參數(shù),但是由于有時(shí)候網(wǎng)絡(luò)特別不好,導(dǎo)致我下載個(gè)基礎(chǔ)的resnet50就要耗費(fèi)我好長(zhǎng)時(shí)間,那么我就想能不能將這個(gè)resnet50的參數(shù)提前下載好,使用的時(shí)候直接加載呢。

當(dāng)然是能了。

self.model = models.resnet50(pretrained=True)

我們可以根據(jù)我們使用的結(jié)構(gòu),到對(duì)應(yīng)的地址下載對(duì)應(yīng)的模型到本地,常用的resnet的地址如下:

 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth',
 'resnet34': 'https://download.pytorch.org/models/resnet34-333f7ec4.pth',
 'resnet50': 'https://download.pytorch.org/models/resnet50-19c8e357.pth',
 'resnet101': 'https://download.pytorch.org/models/resnet101-5d3b4d8f.pth',
 'resnet152': 'https://download.pytorch.org/models/resnet152-b121ed2d.pth',

將其下載下來(lái),然后將模型放入到和net.py同目錄的model文件夾下面,然后使用下面的代碼就可以避免每次都重新下載模型的問(wèn)題了。

self.model = models.resnet50(pretrained=False)
self.model.load_state_dict(torch.load('./model/resnet50-19c8e357.pth'))

pytorch代碼規(guī)范之加載預(yù)訓(xùn)練模型

加載預(yù)訓(xùn)練模型,并去除需要再次訓(xùn)練的層

model=resnet()#自己構(gòu)建的模型,以resnet為例, 需要重新訓(xùn)練的層的名字要和之前的不同。
model_dict = model.state_dict()
pretrained_dict = torch.load('xxx.pkl')
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
model_dict.update(pretrained_dict)
model.load_state_dict(model_dict)

固定部分參數(shù)

#k是可訓(xùn)練參數(shù)的名字,v是包含可訓(xùn)練參數(shù)的一個(gè)實(shí)體
#可以先print(k),找到自己想進(jìn)行調(diào)整的層,并將該層的名字加入到if語(yǔ)句中:
for k,v in model.named_parameters():
if k!='xxx.weight' and k!='xxx.bias' :
v.requires_grad=False#固定參數(shù)

訓(xùn)練部分參數(shù)

#將要訓(xùn)練的參數(shù)放入優(yōu)化器
optimizer2=torch.optim.Adam(params=[model.xxx.weight,model.xxx.bias],lr=learning_rate,betas=(0.9,0.999),weight_decay=1e-5)

檢查是否固定

for k,v in model.named_parameters():
if k!='xxx.weight' and k!='xxx.bias' :
print(v.requires_grad)#理想狀態(tài)下,所有值都是False

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Pycharm激活方法及詳細(xì)教程(詳細(xì)且實(shí)用)

    Pycharm激活方法及詳細(xì)教程(詳細(xì)且實(shí)用)

    這篇文章主要介紹了Pycharm激活方法及詳細(xì)教程,本文通過(guò)圖文并茂的形式給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友參考下吧
    2020-05-05
  • python通過(guò)ElementTree操作XML

    python通過(guò)ElementTree操作XML

    這篇文章介紹了python通過(guò)ElementTree操作XML的方法,文中通過(guò)示例代碼介紹的非常詳細(xì)。對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2022-07-07
  • Python下opencv使用hough變換檢測(cè)直線(xiàn)與圓

    Python下opencv使用hough變換檢測(cè)直線(xiàn)與圓

    在數(shù)字圖像中,往往存在著一些特殊形狀的幾何圖形,像檢測(cè)馬路邊一條直線(xiàn),檢測(cè)人眼的圓形等等,有時(shí)我們需要把這些特定圖形檢測(cè)出來(lái),本文就詳細(xì)的介紹了一下方法
    2021-06-06
  • E: 無(wú)法定位軟件包 python3-pip問(wèn)題及解決

    E: 無(wú)法定位軟件包 python3-pip問(wèn)題及解決

    這篇文章主要介紹了E: 無(wú)法定位軟件包 python3-pip問(wèn)題及解決方案,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2023-02-02
  • Python實(shí)現(xiàn)二叉堆

    Python實(shí)現(xiàn)二叉堆

    二叉堆是一種特殊的堆,二叉堆是完全二元樹(shù)(二叉樹(shù))或者是近似完全二元樹(shù)(二叉樹(shù))。二叉堆有兩種:最大堆和最小堆。最大堆:父結(jié)點(diǎn)的鍵值總是大于或等于任何一個(gè)子節(jié)點(diǎn)的鍵值;最小堆:父結(jié)點(diǎn)的鍵值總是小于或等于任何一個(gè)子節(jié)點(diǎn)的鍵值。
    2016-02-02
  • python教程對(duì)函數(shù)中的參數(shù)進(jìn)行排序

    python教程對(duì)函數(shù)中的參數(shù)進(jìn)行排序

    這篇文章主要介紹了python教程對(duì)函數(shù)中的參數(shù)進(jìn)行排序的方法講解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2021-09-09
  • Datawhale練習(xí)之二手車(chē)價(jià)格預(yù)測(cè)

    Datawhale練習(xí)之二手車(chē)價(jià)格預(yù)測(cè)

    此篇文章是關(guān)于Datawhale練習(xí),代碼完整,但由于該數(shù)據(jù)集中數(shù)據(jù)特征較少(39維),以下可作為少量特征情況下的分析。當(dāng)特征數(shù)目過(guò)大(成千上萬(wàn))時(shí),需要繼續(xù)學(xué)習(xí)。需要的朋友可以參考下
    2021-04-04
  • 詳解如何用TensorFlow訓(xùn)練和識(shí)別/分類(lèi)自定義圖片

    詳解如何用TensorFlow訓(xùn)練和識(shí)別/分類(lèi)自定義圖片

    這篇文章主要介紹了詳解如何用TensorFlow訓(xùn)練和識(shí)別/分類(lèi)自定義圖片,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧
    2019-08-08
  • python安裝numpy&安裝matplotlib& scipy的教程

    python安裝numpy&安裝matplotlib& scipy的教程

    下面小編就為大家?guī)?lái)一篇python安裝numpy&安裝matplotlib& scipy的教程。小編覺(jué)得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧
    2017-11-11
  • 一篇文章帶你深入學(xué)習(xí)Python函數(shù)

    一篇文章帶你深入學(xué)習(xí)Python函數(shù)

    這篇文章主要帶大家深入學(xué)習(xí)Python函數(shù),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來(lái)幫助
    2022-01-01

最新評(píng)論