pytorch模型保存方式
pytorch模型保存
保存模型主要分為兩類:
- 保存整個模型
- 只保存模型參數(shù)
1.保存加載整個模型(不推薦)
保存整個網(wǎng)絡(luò)模型,網(wǎng)絡(luò)結(jié)構(gòu)+權(quán)重參數(shù)
torch.save(model,'net.pth')
加載整個網(wǎng)絡(luò)模型(可能比較耗時)
model=torch.load('net.pth')
2.只保存加載模型參數(shù)(推薦)
保存模型的權(quán)重參數(shù)(速度快,占內(nèi)存少)
torch.save(model.state_dict(),'net_params.pth')
load 模型參數(shù)
因為我們只保存了 模型的參數(shù),所以需要先定義一個網(wǎng)絡(luò)對象,然后再加載模型參數(shù)。
model=myNet()
#將模型參數(shù)加載到新模型中,torch.load返回的是一個OrderedDict,說明.state_dict()只是把所有模型的參數(shù)都已OrderedDict的形式存下來。
state_dict=torch.load('net_params.pth') model.load_state_dict(state_dict)
Note:保存模型進(jìn)行推理測試時,只需保存訓(xùn)練好的模型的權(quán)重參數(shù),即推薦第二種方法。
load_state_dict的參數(shù)strict=False new_model.load_state_dict(state_dict,strict=False)
如果哪一天我們需要重新寫這個網(wǎng)絡(luò)的,比如使用new_model,如果直接load會出現(xiàn)unexpected key.
但是加上strict=False可以很容易地加載預(yù)訓(xùn)練的參數(shù)(注意檢查key是否匹配),直接忽略不匹配的key,對于匹配的key則進(jìn)行正常的賦值。
總結(jié)
以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
python?dataframe獲得指定行列實戰(zhàn)代碼
對于一個DataFrame,常常需要篩選出某列為指定值的行,下面這篇文章主要給大家介紹了關(guān)于python?dataframe獲得指定行列的相關(guān)資料,文中通過代碼介紹的非常詳細(xì),需要的朋友可以參考下2023-12-12利用Python實現(xiàn)自動化監(jiān)控文件夾完成服務(wù)部署
本篇文章將為大家詳細(xì)介紹如何利用Python語言實現(xiàn)監(jiān)控文件夾,以此輔助完成服務(wù)的部署動作,文中的示例代碼講解詳細(xì),感興趣的可以嘗試一下2022-07-07