Pytorch加載部分預(yù)訓(xùn)練模型的參數(shù)實(shí)例
前言
自從從深度學(xué)習(xí)框架caffe轉(zhuǎn)到Pytorch之后,感覺Pytorch的優(yōu)點(diǎn)妙不可言,各種設(shè)計(jì)簡潔,方便研究網(wǎng)絡(luò)結(jié)構(gòu)修改,容易上手,比TensorFlow的臃腫好多了。對于深度學(xué)習(xí)的初學(xué)者,Pytorch值得推薦。今天主要主要談?wù)凱ytorch是如何加載預(yù)訓(xùn)練模型的參數(shù)以及代碼的實(shí)現(xiàn)過程。
直接加載預(yù)選臉模型
如果我們使用的模型和預(yù)訓(xùn)練模型完全一樣,那么我們就可以直接加載別人的模型,還有一種情況,我們在訓(xùn)練自己模型的過程中,突然中斷了,但只要我們保存了之前的模型的參數(shù)也可以使用下面的代碼直接加載我們保存的模型繼續(xù)訓(xùn)練,不用從頭開始。
model=DPN(*args, **kwargs) model.load_state_dict(torch.load("DPN.pth"))
這樣的加載方式是基于Pytorch使用的模型存儲方法:
torch.save(DPN.state_dict(), "DPN.pth")
加載部分預(yù)訓(xùn)練模型參數(shù)
其實(shí)大多數(shù)時(shí)候我們根據(jù)自己的任物所提出的模型是在一些公開模型的基礎(chǔ)上改變而來,其中公開模型的參數(shù)我們沒有必要在從頭開始訓(xùn)練,只要加載其訓(xùn)練好的模型參數(shù)即可,這樣有助于提高訓(xùn)練的準(zhǔn)確率和我們模型的泛化能力。
model = DPN(num_init_features=64, k_R=96, G=32, k_sec=(3,4,20,3), inc_sec=(16,32,24,128), num_classes=1,decoder=args.decoder) http = {'url': 'http://data.lip6.fr/cadene/pretrainedmodels/dpn92_extra-b040e4a9b.pth'} pretrained_dict=model_zoo.load_url(http['url']) model_dict = model.state_dict() pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}#filter out unnecessary keys model_dict.update(pretrained_dict) model.load_state_dict(model_dict) model = torch.nn.DataParallel(model).cuda()
因?yàn)樾枰獎(jiǎng)h除預(yù)訓(xùn)練模型中不匹配的的鍵,也就是層的名字。
以上這篇Pytorch加載部分預(yù)訓(xùn)練模型的參數(shù)實(shí)例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python Dataframe 指定多列去重、求差集的方法
今天小編就為大家分享一篇Python Dataframe 指定多列去重、求差集的方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-07-07python逆向微信指數(shù)爬取實(shí)現(xiàn)步驟
這篇文章主要為大家介紹了python逆向微信指數(shù)爬取的實(shí)現(xiàn)步驟,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步早日升職加薪2022-02-02python如何實(shí)現(xiàn)內(nèi)容寫在圖片上
這篇文章主要為大家詳細(xì)介紹了python如何實(shí)現(xiàn)內(nèi)容寫在圖片上,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下2018-03-03Python實(shí)現(xiàn)Mysql全量數(shù)據(jù)同步的腳本分享
這篇文章主要為大家詳細(xì)介紹了基于Python如何實(shí)現(xiàn)Mysql全量數(shù)據(jù)同步的功能,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以跟隨小編一起了解一下2023-06-06Python 網(wǎng)絡(luò)編程之TCP客戶端/服務(wù)端功能示例【基于socket套接字】
這篇文章主要介紹了Python 網(wǎng)絡(luò)編程之TCP客戶端/服務(wù)端功能,結(jié)合實(shí)例形式分析了Python使用socket套接字實(shí)現(xiàn)TCP協(xié)議下的客戶端與服務(wù)器端數(shù)據(jù)傳輸操作技巧,需要的朋友可以參考下2019-10-10python opencv 讀取本地視頻文件 修改ffmpeg的方法
今天小編就為大家分享一篇python opencv 讀取本地視頻文件 修改ffmpeg的方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-01-01