欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch預訓練的實現

 更新時間:2019年09月18日 11:12:51   作者:算法學習者  
這篇文章主要介紹了PyTorch預訓練的實現,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧

前言

最近使用PyTorch感覺妙不可言,有種當初使用Keras的快感,而且速度還不慢。各種設計直接簡潔,方便研究,比tensorflow的臃腫好多了。今天讓我們來談談PyTorch的預訓練,主要是自己寫代碼的經驗以及論壇PyTorch Forums上的一些回答的總結整理。

直接加載預訓練模型

如果我們使用的模型和原模型完全一樣,那么我們可以直接加載別人訓練好的模型:

my_resnet = MyResNet(*args, **kwargs)
my_resnet.load_state_dict(torch.load("my_resnet.pth"))

當然這樣的加載方法是基于PyTorch推薦的存儲模型的方法:

torch.save(my_resnet.state_dict(), "my_resnet.pth")

還有第二種加載方法:

my_resnet = torch.load("my_resnet.pth")

加載部分預訓練模型

其實大多數時候我們需要根據我們的任務調節(jié)我們的模型,所以很難保證模型和公開的模型完全一樣,但是預訓練模型的參數確實有助于提高訓練的準確率,為了結合二者的優(yōu)點,就需要我們加載部分預訓練模型。

pretrained_dict = model_zoo.load_url(model_urls['resnet152'])
model_dict = model.state_dict()
# 將pretrained_dict里不屬于model_dict的鍵剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新現有的model_dict
model_dict.update(pretrained_dict)
# 加載我們真正需要的state_dict
model.load_state_dict(model_dict)

因為需要剔除原模型中不匹配的鍵,也就是層的名字,所以我們的新模型改變了的層需要和原模型對應層的名字不一樣,比如:resnet最后一層的名字是fc(PyTorch中),那么我們修改過的resnet的最后一層就不能取這個名字,可以叫fc_

微改基礎模型預訓練

對于改動比較大的模型,我們可能需要自己實現一下再加載別人的預訓練參數。但是,對于一些基本模型PyTorch中已經有了,而且我只想進行一些小的改動那么怎么辦呢?難道我又去實現一遍嗎?當然不是。

我們首先看看怎么進行微改模型。

微改基礎模型

PyTorch中的torchvision里已經有很多常用的模型了,可以直接調用:

  • AlexNet
  • VGG
  • ResNet
  • SqueezeNet
  • DenseNet
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()

但是對于我們的任務而言有些層并不是直接能用,需要我們微微改一下,比如,resnet最后的全連接層是分1000類,而我們只有21類;又比如,resnet第一層卷積接收的通道是3, 我們可能輸入圖片的通道是4,那么可以通過以下方法修改:

resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
resnet.fc = nn.Linear(2048, 21)

簡單預訓練

模型已經改完了,接下來我們就進行簡單預訓練吧。
我們先從torchvision中調用基本模型,加載預訓練模型,然后,重點來了,將其中的層直接替換為我們需要的層即可:

resnet = torchvision.models.resnet152(pretrained=True)
# 原本為1000類,改為10類
resnet.fc = torch.nn.Linear(2048, 10)

其中使用了pretrained參數,會直接加載預訓練模型,內部實現和前文提到的加載預訓練的方法一樣。因為是先加載的預訓練參數,相當于模型中已經有參數了,所以替換掉最后一層即可。OK!

以上就是本文的全部內容,希望對大家的學習有所幫助,也希望大家多多支持腳本之家。

相關文章

  • 用Python批量把文件復制到另一個文件夾的實現方法

    用Python批量把文件復制到另一個文件夾的實現方法

    這篇文章主要介紹了用Python批量把文件復制到另一個文件夾的實現方法,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2019-08-08
  • python+opencv實現移動偵測(幀差法)

    python+opencv實現移動偵測(幀差法)

    這篇文章主要為大家詳細介紹了python+opencv實現移動偵測,文中示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2020-03-03
  • 對Python新手編程過程中如何規(guī)避一些常見問題的建議

    對Python新手編程過程中如何規(guī)避一些常見問題的建議

    這篇文章中作者對Python新手編程過程中如何規(guī)避一些常見問題給出了建議,主要著眼于初學者對于一些常用函數方法在平時的使用習慣中的問題給出建議,需要的朋友可以參考下
    2015-04-04
  • python和pygame實現簡單俄羅斯方塊游戲

    python和pygame實現簡單俄羅斯方塊游戲

    這篇文章主要為大家詳細介紹了python和pygame實現簡單俄羅斯方塊游戲,文中示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2018-06-06
  • python返回昨天日期的方法

    python返回昨天日期的方法

    這篇文章主要介紹了python返回昨天日期的方法,涉及Python日期操作的相關技巧,需要的朋友可以參考下
    2015-05-05
  • Pycharm學習教程(3) 代碼運行調試

    Pycharm學習教程(3) 代碼運行調試

    這篇文章主要為大家詳細介紹了最全的Pycharm學習教程第三篇代碼運行調試,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2017-05-05
  • python3 mmh3安裝及使用方法

    python3 mmh3安裝及使用方法

    這篇文章主要介紹了python3 mmh3安裝及使用方法,本文給大家介紹的非常詳細,具有一定的參考借鑒價值,需要的朋友可以參考下
    2019-10-10
  • Tensorflow 實現釋放內存

    Tensorflow 實現釋放內存

    今天小編就為大家分享一篇Tensorflow 實現釋放內存,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-02-02
  • Python Numpy 數組的初始化和基本操作

    Python Numpy 數組的初始化和基本操作

    Python 是一種高級的,動態(tài)的,多泛型的編程語言。接下來通過本文給大家介紹Python Numpy 數組的初始化和基本操作,感興趣的朋友一起看看吧
    2018-03-03
  • Python獲取時光網電影數據的實例代碼

    Python獲取時光網電影數據的實例代碼

    這篇文章主要介紹了Python獲取時光網電影數據,基本原理是先通過requests庫,通過時光網自帶的電影數據API接口,獲取到指定的電影數據,本文結合示例代碼給大家介紹的非常詳細,需要的朋友可以參考下
    2022-09-09

最新評論