欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch保存和加載模型的方法及如何load部分參數(shù)

 更新時(shí)間:2024年03月11日 14:55:32   作者:BigerBang  
本文總結(jié)了pytorch中保存和加載模型的方法,以及在保存的模型文件與新定義的模型的參數(shù)不一一對(duì)應(yīng)時(shí),我們?cè)撊绾渭虞d模型參數(shù),對(duì)pytorch保存和加載模型相關(guān)知識(shí)感興趣的朋友一起看看吧

本文總結(jié)了pytorch中保存和加載模型的方法,以及在保存的模型文件與新定義的模型的參數(shù)不一一對(duì)應(yīng)時(shí),我們?cè)撊绾渭虞d模型參數(shù)。

1. 模型保存和加載的基本方式

在PyTorch中,模型可以通過(guò)兩種方式保存和加載:保存整個(gè)模型(包括模型架構(gòu)和參數(shù))或僅保存模型的參數(shù)(state_dict)。

保存整個(gè)模型: 保存模型的架構(gòu)和所有的權(quán)重參數(shù)。這樣做的好處是可以直接加載使用,無(wú)需再定義模型架構(gòu),但是無(wú)法再對(duì)模型做出調(diào)整,不夠靈活。

python
import torch
import torchvision.models as models
# 實(shí)例化一個(gè)預(yù)訓(xùn)練的resnet模型
model = models.resnet18(pretrained=True)
# 保存整個(gè)模型
torch.save(model, 'model.pth')
# 加載整個(gè)模型
model = torch.load('model.pth')

僅保存模型參數(shù)
通常推薦此方式,因?yàn)樗?strong>僅保存權(quán)重參數(shù),體積更小,更靈活,需要時(shí)可用新定義的模型結(jié)構(gòu)加載參數(shù)。
保存的參數(shù)通過(guò)model.state_dict()獲取,得到一個(gè)有序字典類型:collections.OrderedDict,其中key是參數(shù)名稱,value是保存了參數(shù)數(shù)值的tensor類型。

OrderedDict是 Python 標(biāo)準(zhǔn)庫(kù) collections 模塊中的一種字典(dict)類的子類。和普通的字典相比,OrderedDict 繳存了元素插入的順序,所以當(dāng)對(duì)其進(jìn)行迭代時(shí),鍵值對(duì)會(huì)按照添加的先后次序返回,而不是基于鍵的散列值。

保存模型參數(shù)示例:

# 保存模型的state_dict
torch.save(model.state_dict(), 'model_state_dict.pth')

加載模型參數(shù)示例:

# 首先需要重新定義模型的結(jié)構(gòu),這里假設(shè)我們已經(jīng)有了一模一樣的模型定義
model = models.resnet18(pretrained=False) # 取消預(yù)訓(xùn)練權(quán)重
# 加載模型參數(shù)
model.load_state_dict(torch.load('model_state_dict.pth'))

2. 保存的模型文件和當(dāng)前定義的模型參數(shù)不完全一致時(shí)

有時(shí)候我們會(huì)對(duì)一個(gè)pretrained model的若干層進(jìn)行一些修改,涉及到層的添加和減少,同時(shí)未改變的那些層想要load pretrained model的參數(shù)。
假設(shè)新定義的模型是new_net, pretrained模型是old_net, 以下兩種方式適用于以下所有場(chǎng)景:
1. old_net的參數(shù)是new_net的子集
2. new_net的參數(shù)是old_net的子集
3. new_net和old_net的參數(shù)有交集

strict=False
一個(gè)直接的方式是在load_state_dict時(shí)strict=False,這樣在load參數(shù)時(shí)pytorch會(huì)匹配兩個(gè)模型中參數(shù)名字相同的參數(shù)進(jìn)行導(dǎo)入。

net_2.load_state_dict(torch.load("net_1.pth"), strict=False)

一種更靈活的方式,可自行添加更多的規(guī)則

def load(save_path, model):
    pretraind_dict = torch.load(save_path)
    model_dict =  model.state_dict()
    # 只將pretraind_dict中那些在model_dict中的參數(shù),提取出來(lái)
    state_dict = {k:v for k,v in pretraind_dict.items() if k in model_dict.keys()}
    # 將提取出來(lái)的參數(shù)更新到model_dict中,而model_dict有的,而state_dict沒(méi)有的參數(shù),不會(huì)被更新
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)

可利用上面的代碼自行設(shè)計(jì)一些規(guī)則,比如如果不要laod某個(gè)參數(shù),就可以在上面的代碼中修改:

state_dict = {k:v for k,v in pretraind_dict.items() if k in model_dict.keys() and k != 'conv1.weight'}

3. 驗(yàn)證代碼

import torch
from torch import nn as nn
class model_2_convs(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 32, 3)
        self.mlp = nn.Linear(32, 10)
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        return x
class model_3_convs(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.conv1 = nn.Conv2d(3, 64, 3)
        self.relu = nn.ReLU()
        self.conv2 = nn.Conv2d(64, 32, 3)
        self.conv3 = nn.Conv2d(32, 64, 3)
    def forward(self, x):
        x = self.conv1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.relu(x)
        return x
def load(save_path, model):
    pretraind_dict = torch.load(save_path)
    model_dict =  model.state_dict()
    # 只將pretraind_dict中那些在model_dict中的參數(shù),提取出來(lái)
    state_dict = {k:v for k,v in pretraind_dict.items() if k in model_dict.keys()}
    # print(state_dict.keys())
    # 將提取出來(lái)的參數(shù)更新到model_dict中,而model_dict有的,而state_dict沒(méi)有的參數(shù),不會(huì)被更新
    model_dict.update(state_dict)
    model.load_state_dict(model_dict)
def load_weight_from_3_conv_to_2_conv(use_strict=False):
    net_1 = model_3_convs()
    net_2 = model_2_convs()
    torch.save(net_1.state_dict(), "net_1.pth")
    if use_strict:
        net_2.load_state_dict(torch.load("net_1.pth"), strict=False)
    else:
        load("net_1.pth", net_2)
    for key, para in net_2.state_dict().items():
        print(key)
        if key in net_1.state_dict().keys():
            print(torch.equal(para, net_1.state_dict()[key]))
def load_weight_from_2_conv_to_3_conv(use_strict=False):
    net_1 = model_3_convs()
    net_2 = model_2_convs()
    torch.save(net_2.state_dict(), "net_2.pth")
    if use_strict:
        net_1.load_state_dict(torch.load("net_2.pth"), strict=False)
    else:
        load("net_2.pth", net_1)
    for key, para in net_1.state_dict().items():
        print(key)
        if key in net_2.state_dict().keys():
            print(torch.equal(para, net_2.state_dict()[key]))
if __name__ == "__main__":
    load_weight_from_3_conv_to_2_conv(use_strict=True)
    load_weight_from_3_conv_to_2_conv(use_strict=False)
    load_weight_from_2_conv_to_3_conv(use_strict=True)
    load_weight_from_2_conv_to_3_conv(use_strict=False)

到此這篇關(guān)于pytorch保存和加載模型以及如何load部分參數(shù)的文章就介紹到這了,更多相關(guān)pytorch保存和加載模型內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • 解決c++調(diào)用python中文亂碼問(wèn)題

    解決c++調(diào)用python中文亂碼問(wèn)題

    這篇文章主要介紹了c++調(diào)用python中文亂碼問(wèn)題,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2020-07-07
  • win10子系統(tǒng)python開發(fā)環(huán)境準(zhǔn)備及kenlm和nltk的使用教程

    win10子系統(tǒng)python開發(fā)環(huán)境準(zhǔn)備及kenlm和nltk的使用教程

    這篇文章主要介紹了win10子系統(tǒng)python開發(fā)環(huán)境準(zhǔn)備及kenlm和nltk的使用教程,非常不錯(cuò),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2019-10-10
  • Python繪制正二十面體圖形示例

    Python繪制正二十面體圖形示例

    正二十面體由20個(gè)小的正三角形面組成,每個(gè)頂點(diǎn)周圍有?5?個(gè)頂點(diǎn),下面這篇文章主要給大家介紹了關(guān)于Python繪制正二十面體圖形的相關(guān)資料,需要的朋友可以參考下
    2022-12-12
  • OpenCV圖像變換之傅里葉變換的一些應(yīng)用

    OpenCV圖像變換之傅里葉變換的一些應(yīng)用

    這篇文章主要給大家介紹了關(guān)于OpenCV圖像變換之傅里葉變換的相關(guān)資料,傅里葉變換可以將一幅圖片分解為正弦和余弦兩個(gè)分量,換而言之,他可以將一幅圖像從其空間域(spatial domain)轉(zhuǎn)換為頻域(frequency domain),需要的朋友可以參考下
    2021-07-07
  • 如何讓python的運(yùn)行速度得到提升

    如何讓python的運(yùn)行速度得到提升

    在本篇文章里小編給大家分享了關(guān)于如何讓python的運(yùn)行速度得到提升的方法和技巧,需要的朋友們可以學(xué)習(xí)下。
    2020-07-07
  • Python微信庫(kù):itchat的用法詳解

    Python微信庫(kù):itchat的用法詳解

    本篇文章主要介紹了Python微信庫(kù):itchat的用法詳解,小編覺(jué)得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧
    2017-08-08
  • 用python wxpy管理微信公眾號(hào)并利用微信獲取自己的開源數(shù)據(jù)

    用python wxpy管理微信公眾號(hào)并利用微信獲取自己的開源數(shù)據(jù)

    這篇文章主要介紹了用python wxpy管理微信公眾號(hào)并利用微信獲取自己的開源數(shù)據(jù),本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2019-07-07
  • Python3多目標(biāo)賦值及共享引用注意事項(xiàng)

    Python3多目標(biāo)賦值及共享引用注意事項(xiàng)

    這篇文章主要介紹了Python3多目標(biāo)賦值及共享引用注意事項(xiàng),本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2019-05-05
  • Python使用psutil獲取系統(tǒng)信息

    Python使用psutil獲取系統(tǒng)信息

    這篇文章介紹了Python使用psutil獲取系統(tǒng)信息的方法,文中通過(guò)示例代碼介紹的非常詳細(xì)。對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2022-05-05
  • Python模塊zipfile原理及使用方法詳解

    Python模塊zipfile原理及使用方法詳解

    這篇文章主要介紹了Python模塊zipfile原理及使用方法詳解,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2020-08-08

最新評(píng)論