pytorch保存和加載模型的方法及如何load部分參數(shù)
本文總結(jié)了pytorch中保存和加載模型的方法,以及在保存的模型文件與新定義的模型的參數(shù)不一一對(duì)應(yīng)時(shí),我們?cè)撊绾渭虞d模型參數(shù)。
1. 模型保存和加載的基本方式
在PyTorch中,模型可以通過(guò)兩種方式保存和加載:保存整個(gè)模型(包括模型架構(gòu)和參數(shù))或僅保存模型的參數(shù)(state_dict)。
保存整個(gè)模型: 保存模型的架構(gòu)和所有的權(quán)重參數(shù)。這樣做的好處是可以直接加載使用,無(wú)需再定義模型架構(gòu),但是無(wú)法再對(duì)模型做出調(diào)整,不夠靈活。
python import torch import torchvision.models as models # 實(shí)例化一個(gè)預(yù)訓(xùn)練的resnet模型 model = models.resnet18(pretrained=True) # 保存整個(gè)模型 torch.save(model, 'model.pth')
# 加載整個(gè)模型 model = torch.load('model.pth')
僅保存模型參數(shù)
通常推薦此方式,因?yàn)樗?strong>僅保存權(quán)重參數(shù),體積更小,更靈活,需要時(shí)可用新定義的模型結(jié)構(gòu)加載參數(shù)。
保存的參數(shù)通過(guò)model.state_dict()獲取,得到一個(gè)有序字典類型:collections.OrderedDict,其中key是參數(shù)名稱,value是保存了參數(shù)數(shù)值的tensor類型。
OrderedDict是 Python 標(biāo)準(zhǔn)庫(kù) collections 模塊中的一種字典(dict)類的子類。和普通的字典相比,OrderedDict 繳存了元素插入的順序,所以當(dāng)對(duì)其進(jìn)行迭代時(shí),鍵值對(duì)會(huì)按照添加的先后次序返回,而不是基于鍵的散列值。
保存模型參數(shù)示例:
# 保存模型的state_dict torch.save(model.state_dict(), 'model_state_dict.pth')
加載模型參數(shù)示例:
# 首先需要重新定義模型的結(jié)構(gòu),這里假設(shè)我們已經(jīng)有了一模一樣的模型定義 model = models.resnet18(pretrained=False) # 取消預(yù)訓(xùn)練權(quán)重 # 加載模型參數(shù) model.load_state_dict(torch.load('model_state_dict.pth'))
2. 保存的模型文件和當(dāng)前定義的模型參數(shù)不完全一致時(shí)
有時(shí)候我們會(huì)對(duì)一個(gè)pretrained model的若干層進(jìn)行一些修改,涉及到層的添加和減少,同時(shí)未改變的那些層想要load pretrained model的參數(shù)。
假設(shè)新定義的模型是new_net, pretrained模型是old_net, 以下兩種方式適用于以下所有場(chǎng)景:
1. old_net的參數(shù)是new_net的子集
2. new_net的參數(shù)是old_net的子集
3. new_net和old_net的參數(shù)有交集
strict=False
一個(gè)直接的方式是在load_state_dict時(shí)strict=False,這樣在load參數(shù)時(shí)pytorch會(huì)匹配兩個(gè)模型中參數(shù)名字相同的參數(shù)進(jìn)行導(dǎo)入。
net_2.load_state_dict(torch.load("net_1.pth"), strict=False)
一種更靈活的方式,可自行添加更多的規(guī)則
def load(save_path, model): pretraind_dict = torch.load(save_path) model_dict = model.state_dict() # 只將pretraind_dict中那些在model_dict中的參數(shù),提取出來(lái) state_dict = {k:v for k,v in pretraind_dict.items() if k in model_dict.keys()} # 將提取出來(lái)的參數(shù)更新到model_dict中,而model_dict有的,而state_dict沒(méi)有的參數(shù),不會(huì)被更新 model_dict.update(state_dict) model.load_state_dict(model_dict)
可利用上面的代碼自行設(shè)計(jì)一些規(guī)則,比如如果不要laod某個(gè)參數(shù),就可以在上面的代碼中修改:
state_dict = {k:v for k,v in pretraind_dict.items() if k in model_dict.keys() and k != 'conv1.weight'}
3. 驗(yàn)證代碼
import torch from torch import nn as nn class model_2_convs(nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = nn.Conv2d(3, 64, 3) self.relu = nn.ReLU() self.conv2 = nn.Conv2d(64, 32, 3) self.mlp = nn.Linear(32, 10) def forward(self, x): x = self.conv1(x) x = self.relu(x) x = self.conv2(x) x = self.relu(x) return x class model_3_convs(nn.Module): def __init__(self) -> None: super().__init__() self.conv1 = nn.Conv2d(3, 64, 3) self.relu = nn.ReLU() self.conv2 = nn.Conv2d(64, 32, 3) self.conv3 = nn.Conv2d(32, 64, 3) def forward(self, x): x = self.conv1(x) x = self.relu(x) x = self.conv2(x) x = self.relu(x) return x def load(save_path, model): pretraind_dict = torch.load(save_path) model_dict = model.state_dict() # 只將pretraind_dict中那些在model_dict中的參數(shù),提取出來(lái) state_dict = {k:v for k,v in pretraind_dict.items() if k in model_dict.keys()} # print(state_dict.keys()) # 將提取出來(lái)的參數(shù)更新到model_dict中,而model_dict有的,而state_dict沒(méi)有的參數(shù),不會(huì)被更新 model_dict.update(state_dict) model.load_state_dict(model_dict) def load_weight_from_3_conv_to_2_conv(use_strict=False): net_1 = model_3_convs() net_2 = model_2_convs() torch.save(net_1.state_dict(), "net_1.pth") if use_strict: net_2.load_state_dict(torch.load("net_1.pth"), strict=False) else: load("net_1.pth", net_2) for key, para in net_2.state_dict().items(): print(key) if key in net_1.state_dict().keys(): print(torch.equal(para, net_1.state_dict()[key])) def load_weight_from_2_conv_to_3_conv(use_strict=False): net_1 = model_3_convs() net_2 = model_2_convs() torch.save(net_2.state_dict(), "net_2.pth") if use_strict: net_1.load_state_dict(torch.load("net_2.pth"), strict=False) else: load("net_2.pth", net_1) for key, para in net_1.state_dict().items(): print(key) if key in net_2.state_dict().keys(): print(torch.equal(para, net_2.state_dict()[key])) if __name__ == "__main__": load_weight_from_3_conv_to_2_conv(use_strict=True) load_weight_from_3_conv_to_2_conv(use_strict=False) load_weight_from_2_conv_to_3_conv(use_strict=True) load_weight_from_2_conv_to_3_conv(use_strict=False)
到此這篇關(guān)于pytorch保存和加載模型以及如何load部分參數(shù)的文章就介紹到這了,更多相關(guān)pytorch保存和加載模型內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
解決c++調(diào)用python中文亂碼問(wèn)題
這篇文章主要介紹了c++調(diào)用python中文亂碼問(wèn)題,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2020-07-07win10子系統(tǒng)python開發(fā)環(huán)境準(zhǔn)備及kenlm和nltk的使用教程
這篇文章主要介紹了win10子系統(tǒng)python開發(fā)環(huán)境準(zhǔn)備及kenlm和nltk的使用教程,非常不錯(cuò),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-10-10用python wxpy管理微信公眾號(hào)并利用微信獲取自己的開源數(shù)據(jù)
這篇文章主要介紹了用python wxpy管理微信公眾號(hào)并利用微信獲取自己的開源數(shù)據(jù),本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-07-07Python3多目標(biāo)賦值及共享引用注意事項(xiàng)
這篇文章主要介紹了Python3多目標(biāo)賦值及共享引用注意事項(xiàng),本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2019-05-05