python PyTorch預(yù)訓(xùn)練示例
前言
最近使用PyTorch感覺妙不可言,有種當(dāng)初使用Keras的快感,而且速度還不慢。各種設(shè)計(jì)直接簡(jiǎn)潔,方便研究,比tensorflow的臃腫好多了。今天讓我們來談?wù)凱yTorch的預(yù)訓(xùn)練,主要是自己寫代碼的經(jīng)驗(yàn)以及論壇PyTorch Forums上的一些回答的總結(jié)整理。
直接加載預(yù)訓(xùn)練模型
如果我們使用的模型和原模型完全一樣,那么我們可以直接加載別人訓(xùn)練好的模型:
my_resnet = MyResNet(*args, **kwargs) my_resnet.load_state_dict(torch.load("my_resnet.pth"))
當(dāng)然這樣的加載方法是基于PyTorch推薦的存儲(chǔ)模型的方法:
torch.save(my_resnet.state_dict(), "my_resnet.pth")
還有第二種加載方法:
my_resnet = torch.load("my_resnet.pth")
加載部分預(yù)訓(xùn)練模型
其實(shí)大多數(shù)時(shí)候我們需要根據(jù)我們的任務(wù)調(diào)節(jié)我們的模型,所以很難保證模型和公開的模型完全一樣,但是預(yù)訓(xùn)練模型的參數(shù)確實(shí)有助于提高訓(xùn)練的準(zhǔn)確率,為了結(jié)合二者的優(yōu)點(diǎn),就需要我們加載部分預(yù)訓(xùn)練模型。
pretrained_dict = model_zoo.load_url(model_urls['resnet152']) model_dict = model.state_dict() # 將pretrained_dict里不屬于model_dict的鍵剔除掉 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 更新現(xiàn)有的model_dict model_dict.update(pretrained_dict) # 加載我們真正需要的state_dict model.load_state_dict(model_dict)
因?yàn)樾枰蕹P椭胁黄ヅ涞逆I,也就是層的名字,所以我們的新模型改變了的層需要和原模型對(duì)應(yīng)層的名字不一樣,比如:resnet最后一層的名字是fc(PyTorch中),那么我們修改過的resnet的最后一層就不能取這個(gè)名字,可以叫fc_
微改基礎(chǔ)模型預(yù)訓(xùn)練
對(duì)于改動(dòng)比較大的模型,我們可能需要自己實(shí)現(xiàn)一下再加載別人的預(yù)訓(xùn)練參數(shù)。但是,對(duì)于一些基本模型PyTorch中已經(jīng)有了,而且我只想進(jìn)行一些小的改動(dòng)那么怎么辦呢?難道我又去實(shí)現(xiàn)一遍嗎?當(dāng)然不是。
我們首先看看怎么進(jìn)行微改模型。
微改基礎(chǔ)模型
PyTorch中的torchvision里已經(jīng)有很多常用的模型了,可以直接調(diào)用:
- AlexNet
- VGG
- ResNet
- SqueezeNet
- DenseNet
import torchvision.models as models resnet18 = models.resnet18() alexnet = models.alexnet() squeezenet = models.squeezenet1_0() densenet = models.densenet_161()
但是對(duì)于我們的任務(wù)而言有些層并不是直接能用,需要我們微微改一下,比如,resnet最后的全連接層是分1000類,而我們只有21類;又比如,resnet第一層卷積接收的通道是3, 我們可能輸入圖片的通道是4,那么可以通過以下方法修改:
resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False) resnet.fc = nn.Linear(2048, 21)
簡(jiǎn)單預(yù)訓(xùn)練
模型已經(jīng)改完了,接下來我們就進(jìn)行簡(jiǎn)單預(yù)訓(xùn)練吧。
我們先從torchvision中調(diào)用基本模型,加載預(yù)訓(xùn)練模型,然后,重點(diǎn)來了,將其中的層直接替換為我們需要的層即可:
resnet = torchvision.models.resnet152(pretrained=True) # 原本為1000類,改為10類 resnet.fc = torch.nn.Linear(2048, 10)
其中使用了pretrained參數(shù),會(huì)直接加載預(yù)訓(xùn)練模型,內(nèi)部實(shí)現(xiàn)和前文提到的加載預(yù)訓(xùn)練的方法一樣。因?yàn)槭窍燃虞d的預(yù)訓(xùn)練參數(shù),相當(dāng)于模型中已經(jīng)有參數(shù)了,所以替換掉最后一層即可。OK!
以上就是本文的全部?jī)?nèi)容,希望對(duì)大家的學(xué)習(xí)有所幫助,也希望大家多多支持腳本之家。
- Python深度學(xué)習(xí)之Pytorch初步使用
- Python深度學(xué)習(xí)之使用Pytorch搭建ShuffleNetv2
- python 如何查看pytorch版本
- 簡(jiǎn)述python&pytorch 隨機(jī)種子的實(shí)現(xiàn)
- 淺談pytorch、cuda、python的版本對(duì)齊問題
- python、PyTorch圖像讀取與numpy轉(zhuǎn)換實(shí)例
- 基于python及pytorch中乘法的使用詳解
- python PyTorch參數(shù)初始化和Finetune
- Python機(jī)器學(xué)習(xí)之基于Pytorch實(shí)現(xiàn)貓狗分類
相關(guān)文章
python類參數(shù)定義及數(shù)據(jù)擴(kuò)展方式unsqueeze/expand
本文主要介紹了python類參數(shù)定義及數(shù)據(jù)擴(kuò)展方式unsqueeze/expand,文章通過圍繞主題展開詳細(xì)的內(nèi)容介紹,具有一定的參考價(jià)值,需要的小伙伴可以參考一下2022-08-08python實(shí)現(xiàn)播放音頻和錄音功能示例代碼
這篇文章主要給大家介紹了關(guān)于python播放音頻和錄音的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家學(xué)習(xí)或者使用python具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2018-12-12Python中處理字符串之islower()方法的使用簡(jiǎn)介
這篇文章主要介紹了Python中處理字符串之islower()方法的使用,是Python入門的基礎(chǔ)知識(shí),需要的朋友可以參考下2015-05-05Python協(xié)程異步爬取數(shù)據(jù)(asyncio+aiohttp)實(shí)例
這篇文章主要為大家介紹了Python協(xié)程異步爬取數(shù)據(jù)(asyncio+aiohttp)實(shí)現(xiàn)示例,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-08-08Python中使用第三方庫(kù)xlrd來寫入Excel文件示例
這篇文章主要介紹了Python中使用第三方庫(kù)xlrd來寫入Excel文件示例,本文講解了安裝xlwt、API介紹、使用xlwt寫入Excel文件實(shí)例,需要的朋友可以參考下2015-04-04Python基礎(chǔ)之列表常見操作經(jīng)典實(shí)例詳解
這篇文章主要介紹了Python基礎(chǔ)之列表常見操作,結(jié)合實(shí)例形式詳細(xì)分析了Python列表創(chuàng)建方式、內(nèi)置函數(shù)與相關(guān)使用技巧,需要的朋友可以參考下2020-02-02Django使用詳解:ORM 的反向查找(related_name)
今天小編就為大家分享一篇Django使用詳解:ORM 的反向查找(related_name),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-05-05