PyTorch深度學(xué)習(xí)模型的保存和加載流程詳解
一、模型參數(shù)的保存和加載
-
torch.save(module.state_dict(), path)
:使用module.state_dict()
函數(shù)獲取各層已經(jīng)訓(xùn)練好的參數(shù)和緩沖區(qū),然后將參數(shù)和緩沖區(qū)保存到path
所指定的文件存放路徑(常用文件格式為.pt
、.pth
或.pkl
)。 torch.nn.Module.load_state_dict(state_dict)
:從state_dict
中加載參數(shù)和緩沖區(qū)到Module
及其子類中 。torch.nn.Module.state_dict()
函數(shù)返回python
中的一個(gè)OrderedDict
類型字典對(duì)象,該對(duì)象將每一層與它的對(duì)應(yīng)參數(shù)和緩沖區(qū)建立映射關(guān)系,字典的鍵值是參數(shù)或緩沖區(qū)的名稱。只有那些參數(shù)可以訓(xùn)練的層才會(huì)被保存到OrderedDict
中,例如:卷積層、線性層等。Python
中的字典類以“鍵:值
”方式存取數(shù)據(jù),OrderedDict
是它的一個(gè)子類,實(shí)現(xiàn)了對(duì)字典對(duì)象中元素的排序(OrderedDict
根據(jù)放入元素的先后順序進(jìn)行排序)。由于進(jìn)行了排序,所以順序不同的兩個(gè)OrderedDict
字典對(duì)象會(huì)被當(dāng)做是兩個(gè)不同的對(duì)象。- 示例:
import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 2, 3) self.pool1 = nn.MaxPool2d(2, 2) def forward(self, x): x = self.conv1(x) x = self.pool1(x) return x # 初始化網(wǎng)絡(luò) net = Net() net.conv1.weight[0].detach().fill_(1) net.conv1.weight[1].detach().fill_(2) net.conv1.bias.data.detach().zero_() # 獲取state_dict state_dict = net.state_dict() # 字典的遍歷默認(rèn)是遍歷key,所以param_tensor實(shí)際上是鍵值 for param_tensor in state_dict: print(param_tensor,':\n',state_dict[param_tensor]) # 保存模型參數(shù) torch.save(state_dict,"net_params.pth") # 通過(guò)加載state_dict獲取模型參數(shù) net.load_state_dict(state_dict)
輸出:
二、完整模型的保存和加載
-
torch.save(module, path)
:將訓(xùn)練完的整個(gè)網(wǎng)絡(luò)模型module
保存到path
所指定的文件存放路徑(常用文件格式為.pt
或.pth
)。 torch.load(path)
:加載保存到path
中的整個(gè)神經(jīng)網(wǎng)絡(luò)模型。- 示例:
import torch import torch.nn as nn class Net(nn.Module): def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 2, 3) self.pool1 = nn.MaxPool2d(2, 2) def forward(self, x): x = self.conv1(x) x = self.pool1(x) return x # 初始化網(wǎng)絡(luò) net = Net() net.conv1.weight[0].detach().fill_(1) net.conv1.weight[1].detach().fill_(2) net.conv1.bias.data.detach().zero_() # 保存整個(gè)網(wǎng)絡(luò) torch.save(net,"net.pth") # 加載網(wǎng)絡(luò) net = torch.load("net.pth")
到此這篇關(guān)于PyTorch深度學(xué)習(xí)模型的保存和加載流程詳解的文章就介紹到這了,更多相關(guān)PyTorch 模型的保存 內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python爬取代理IP并進(jìn)行有效的IP測(cè)試實(shí)現(xiàn)
這篇文章主要介紹了python爬取代理IP并進(jìn)行有效的IP測(cè)試實(shí)現(xiàn),文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2020-10-10解決使用pycharm提交代碼時(shí)沖突之后文件丟失找回的方法
這篇文章主要介紹了解決使用pycharm提交代碼時(shí)沖突之后文件丟失找回的方法 ,需要的朋友可以參考下2018-08-08Python中xml和dict格式轉(zhuǎn)換的示例代碼
最近在做APP的接口,遇到XML格式的請(qǐng)求數(shù)據(jù),費(fèi)了很大勁來(lái)解決,下面小編給大家分享下Python中xml和dict格式轉(zhuǎn)換問(wèn)題,感興趣的朋友跟隨小編一起看看吧2019-11-11Python深度學(xué)習(xí)pytorch神經(jīng)網(wǎng)絡(luò)多輸入多輸出通道
這篇文章主要為大家介紹了Python深度學(xué)習(xí)中pytorch神經(jīng)網(wǎng)絡(luò)多輸入多輸出通道的詳解有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步2021-10-10Python 格式化輸出_String Formatting_控制小數(shù)點(diǎn)位數(shù)的實(shí)例詳解
在本篇文章里小編給大家整理了關(guān)于Python 格式化輸出_String Formatting_控制小數(shù)點(diǎn)位數(shù)的實(shí)例內(nèi)容,需要的朋友們參考下。2020-02-02VPS CENTOS 上配置python,mysql,nginx,uwsgi,django的方法詳解
這篇文章主要介紹了VPS CENTOS 上配置python,mysql,nginx,uwsgi,django的方法,較為詳細(xì)的分析了VPS CENTOS 上配置python,mysql,nginx,uwsgi,django的具體步驟、相關(guān)命令與操作注意事項(xiàng),需要的朋友可以參考下2019-07-07Python實(shí)現(xiàn)一行代碼自動(dòng)繪制藝術(shù)畫(huà)
DiscoArt?是一個(gè)很牛的開(kāi)源模塊,它能根據(jù)你給定的關(guān)鍵詞自動(dòng)繪畫(huà)。本文就將利用這一模塊實(shí)現(xiàn)一行代碼自動(dòng)繪制藝術(shù)畫(huà),需要的可以參考一下2022-12-12Python?PaddleNLP開(kāi)源實(shí)現(xiàn)快遞單信息抽取
這篇文章主要為大家介紹了Python?PaddleNLP開(kāi)源項(xiàng)目實(shí)現(xiàn)對(duì)快遞單信息抽取,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-06-06國(guó)產(chǎn)化設(shè)備鯤鵬CentOS7上源碼安裝Python3.7的過(guò)程詳解
這篇文章主要介紹了國(guó)產(chǎn)化設(shè)備鯤鵬CentOS7上源碼安裝Python3.7,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2022-05-05