Pytorch模型參數(shù)的保存和加載
一、前言
在模型訓(xùn)練完成后,我們需要保存模型參數(shù)值用于后續(xù)的測(cè)試過程。由于保存整個(gè)模型將耗費(fèi)大量的存儲(chǔ),故推薦的做法是只保存參數(shù),使用時(shí)只需在建好模型的基礎(chǔ)上加載。
通常來說,保存的對(duì)象包括網(wǎng)絡(luò)參數(shù)值、優(yōu)化器參數(shù)值、epoch值等。本文將簡單介紹保存和加載模型參數(shù)的方法,同時(shí)也給出保存整個(gè)模型的方法供大家參考。
二、參數(shù)保存
在這里我們使用 torch.save() 函數(shù)保存模型參數(shù):
import torch path = './model.pth' torch.save(model.state_dict(), path)
model——指定義的模型實(shí)例變量,如model=net( )
state_dict()——state_dict( )是一個(gè)可以輕松地保存、更新、修改和恢復(fù)的python字典對(duì)象, 對(duì)于model來說,表示模型的每一層的權(quán)重及偏置等參數(shù)信息;對(duì)于 optimizer 來說,其包含了優(yōu)化器的狀態(tài)以及被使用的超參數(shù)(如lr, momentum,weight_decay等)
path——path是保存參數(shù)的路徑,一般設(shè)置為 path='./model.pth' , path='./model.pkl'等形式。
此外,如果想保存某一次訓(xùn)練采用的optimizer、epochs等信息,可將這些信息組合起來構(gòu)成一個(gè)字典保存起來:
import torch path = './model.pth' state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch} torch.save(state, path)
三、參數(shù)的加載
使用 load_state_dict()函數(shù)加載參數(shù)到模型中, 當(dāng)僅保存了模型參數(shù),而沒有optimizer、epochs等信息時(shí):
model.load_state_dict(torch.load(path))
model——事先定義好的跟原模型一致的模型
path——之前保存的模型參數(shù)文件
如若保存了optimizer、epochs等信息,我們這樣載入信息:
# 使用torch.load()函數(shù)將文件中字典信息載入 state_dict 變量中 state_dict = torch.load(path) # 分布加載參數(shù)到模型和優(yōu)化器 model.load_state_dict(state_dict['model']) optimizer.load_state_dict(state_dict['optimizer']) epoch = state_dict(['epoch'])
我們還可以在每n個(gè)epoch后保存一次參數(shù),以觀察不同迭代次數(shù)模型的表現(xiàn)。此時(shí)我們可設(shè)置不同的path,如 path='./model' + str(epoch) +'.pth',這樣,不同epoch的參數(shù)就能保存在不同的文件中。
四、保存和加載整個(gè)模型
使用上文提到的方法即可:
torch.save(model, path) model = torch.load(path)
五、總結(jié)
pytorch中state_dict()和load_state_dict()函數(shù)配合使用可以實(shí)現(xiàn)狀態(tài)的獲取與重載,load()和save()函數(shù)配合使用可以實(shí)現(xiàn)參數(shù)的存儲(chǔ)與讀取。掌握對(duì)應(yīng)的函數(shù)使用方法就可以游刃有余地進(jìn)行運(yùn)用。
到此這篇關(guān)于Pytorch模型參數(shù)的保存和加載的文章就介紹到這了,更多相關(guān)Pytorch模型參數(shù)保存內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python 繪制北上廣深的地鐵路線動(dòng)態(tài)圖
這篇文章主要介紹了用python制作北上廣深——地鐵線路動(dòng)態(tài)圖,文中的示例代碼講解詳細(xì),對(duì)我們的工作或?qū)W習(xí)都有一定的價(jià)值,感興趣的同學(xué)可以學(xué)習(xí)一下2021-12-12一文搞懂Python中的進(jìn)程,線程和協(xié)程
并發(fā)編程是實(shí)現(xiàn)多任務(wù)協(xié)同處理,改善系統(tǒng)性能的方式。Python中實(shí)現(xiàn)并發(fā)編程主要依靠進(jìn)程、線程和協(xié)程,本文將通過示例詳解三者的區(qū)別,感興趣的可以了解一下2022-05-05python中import,from……import的使用詳解
這篇文章主要介紹了python中import,from……import的使用方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-02-02Python3實(shí)現(xiàn)監(jiān)控新型冠狀病毒肺炎疫情的示例代碼
這篇文章主要介紹了Python3實(shí)現(xiàn)監(jiān)控新型冠狀病毒肺炎疫情的示例代碼,代碼簡單易懂,非常不錯(cuò),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-02-02關(guān)于Python中幾個(gè)有趣的函數(shù)和推導(dǎo)式解析
這篇文章主要介紹了關(guān)于Python中幾個(gè)有趣的函數(shù)和推導(dǎo)式解析,推導(dǎo)式comprehensions,又稱解析式,是Python的一種獨(dú)有特性,推導(dǎo)式是可以從一個(gè)數(shù)據(jù)序列構(gòu)建另一個(gè)新的數(shù)據(jù)序列的結(jié)構(gòu)體,需要的朋友可以參考下2023-08-08Python3如何根據(jù)函數(shù)名動(dòng)態(tài)調(diào)用函數(shù)
這篇文章主要介紹了Python3如何根據(jù)函數(shù)名動(dòng)態(tài)調(diào)用函數(shù)問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-11-11python3中os.path模塊下常用的用法總結(jié)【推薦】
這篇文章主要介紹了python3中os.path模塊下常用的用法總結(jié) ,需要的朋友可以參考下2018-09-09