欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch中實(shí)現(xiàn)只導(dǎo)入部分模型參數(shù)的方式

 更新時(shí)間:2020年01月02日 16:57:09   作者:咆哮的阿杰  
今天小編就為大家分享一篇Pytorch中實(shí)現(xiàn)只導(dǎo)入部分模型參數(shù)的方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧

我們?cè)谧鲞w移學(xué)習(xí),或者在分割,檢測(cè)等任務(wù)想使用預(yù)訓(xùn)練好的模型,同時(shí)又有自己修改之后的結(jié)構(gòu),使得模型文件保存的參數(shù),有一部分是不需要的(don't expected)。我們搭建的網(wǎng)絡(luò)對(duì)保存文件來說,有一部分參數(shù)也是沒有的(missed)。如果依舊使用torch.load(model.state_dict())的辦法,就會(huì)出現(xiàn) xxx expected,xxx missed類似的錯(cuò)誤。那么在這種情況下,該如何導(dǎo)入模型呢?

好在Pytorch中的模型參數(shù)使用字典保存的,鍵是參數(shù)的名稱,值是參數(shù)的具體數(shù)值。我們使用model.state_dict()獲得這個(gè)字典,之后就能利用參數(shù)名稱來實(shí)現(xiàn)導(dǎo)入。

請(qǐng)看下面的一個(gè)例子。

我們先搭建一個(gè)小小的網(wǎng)絡(luò)。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
class Net(Module):
  def __init__(self):
    super(Net,self).__init__()
    self.conv1 = nn.Conv2d(3,32,3,1)
    self.conv2 = nn.Conv2d(32,3,3,1)
    self.w = nn.Parameter(t.randn(3,10))
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out,(out.shape[2],out.shape[3]))
    out = F.linear(out,weight=self.w)
    return out

然后我們保存這個(gè)網(wǎng)絡(luò)的初始值。

model = Net()
t.save(model.state_dict(),'xxx.pth')

現(xiàn)在我們將Net修改一下,多加幾個(gè)卷積層,但并不加入到forward中,僅僅出于少些幾行的目的。

import torch as t
from torch.nn import Module
from torch import nn
from torch.nn import functional as F
 
 
class Net(Module):
  def __init__(self):
    super(Net, self).__init__()
    self.conv1 = nn.Conv2d(3, 32, 3, 1)
    self.conv2 = nn.Conv2d(32, 3, 3, 1)
    self.conv3 = nn.Conv2d(3,64,3,1)
    self.conv4 = nn.Conv2d(64,32,3,1)
    for p in self.children():
      nn.init.xavier_normal_(p.weight.data)
      nn.init.constant_(p.bias.data, 0)
 
    self.w = nn.Parameter(t.randn(3, 10))
  def forward(self, x):
    out = self.conv1(x)
    out = self.conv2(x)
 
    out = F.avg_pool2d(out, (out.shape[2], out.shape[3]))
    out = F.linear(out, weight=self.w)
    return out

我們現(xiàn)在試著導(dǎo)入之前保存的模型參數(shù)。

path = 'xxx.pth'
model = Net()
model.load_state_dict(t.load(path))
 
'''
RuntimeError: Error(s) in loading state_dict for Net:
 Missing key(s) in state_dict: "conv3.weight", "conv3.bias", "conv4.weight", "conv4.bias". 
'''

出現(xiàn)了沒有在模型文件中找到error中的關(guān)鍵字的錯(cuò)誤。

現(xiàn)在我們這樣導(dǎo)入模型

path = 'xxx.pth'
model = Net()
save_model = t.load(path)
model_dict = model.state_dict()
state_dict = {k:v for k,v in save_model.items() if k in model_dict.keys()}
print(state_dict.keys()) # dict_keys(['w', 'conv1.weight', 'conv1.bias', 'conv2.weight', 'conv2.bias'])
model_dict.update(state_dict)
model.load_state_dict(model_dict)

看看上面的代碼,很容易弄明白。其中model_dict.update的作用是更新代碼中搭建的模型參數(shù)字典。為啥更新我其實(shí)并不清楚,但這一步驟是必須的,否則還會(huì)報(bào)錯(cuò)。

為了弄清楚為什么要更新model_dict,我們不妨分別輸出state_dict和model_dict的關(guān)鍵值看一看。

for k in state_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
'''
for k in model_dict.keys():
  print(k)
 
'''
w
conv1.weight
conv1.bias
conv2.weight
conv2.bias
conv3.weight
conv3.bias
conv4.weight
conv4.bias
'''

這個(gè)結(jié)果也是預(yù)料之中的,所以我猜測(cè),update之后,model_dict和state_dict中具有相同鍵的值已經(jīng)同步了。updata的目的就是使model_dict帶有state_dict中都具有的那一部分參數(shù)的值,對(duì)于model_dict中有的,但是save_dict中沒有的參數(shù),值不改變,參數(shù)仍然使用初始值。

以上這篇Pytorch中實(shí)現(xiàn)只導(dǎo)入部分模型參數(shù)的方式就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • matplotlib繪制多子圖共享鼠標(biāo)光標(biāo)的方法示例

    matplotlib繪制多子圖共享鼠標(biāo)光標(biāo)的方法示例

    這篇文章主要介紹了matplotlib繪制多子圖共享鼠標(biāo)光標(biāo)的方法示例,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2021-01-01
  • python檢測(cè)服務(wù)器端口代碼實(shí)例

    python檢測(cè)服務(wù)器端口代碼實(shí)例

    這篇文章主要介紹了python檢測(cè)服務(wù)器端口代碼實(shí)例,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2019-08-08
  • 利用python/R語言繪制圣誕樹實(shí)例代碼

    利用python/R語言繪制圣誕樹實(shí)例代碼

    圣誕節(jié)快到了,分別用R和Python繪制了圣誕樹祝你們圣誕節(jié)快樂,所以下面這篇文章主要給大家介紹了關(guān)于如何利用python/R繪制圣誕樹的相關(guān)資料,需要的朋友可以參考下
    2021-12-12
  • python之如何將標(biāo)簽轉(zhuǎn)化為one-hot(獨(dú)熱編碼)

    python之如何將標(biāo)簽轉(zhuǎn)化為one-hot(獨(dú)熱編碼)

    這篇文章主要介紹了python之如何將標(biāo)簽轉(zhuǎn)化為one-hot(獨(dú)熱編碼)問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2023-06-06
  • Python 的迭代器與zip詳解

    Python 的迭代器與zip詳解

    本篇文章主要介紹Python 的迭代器與zip,可迭代對(duì)象的相關(guān)概念,有需要的小伙伴可以參考下,希望能夠給你帶來幫助
    2021-11-11
  • Python實(shí)現(xiàn)CET查分的方法

    Python實(shí)現(xiàn)CET查分的方法

    這篇文章主要介紹了Python實(shí)現(xiàn)CET查分的方法,實(shí)例分析了Python操作鏈接查詢的技巧,需要的朋友可以參考下
    2015-03-03
  • Python設(shè)計(jì)模式中的備忘錄模式

    Python設(shè)計(jì)模式中的備忘錄模式

    這篇文章主要為大家詳細(xì)介紹了Python設(shè)計(jì)模式中的備忘錄模式,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來幫助
    2022-02-02
  • 讓你相見恨晚的十個(gè)Python騷操作

    讓你相見恨晚的十個(gè)Python騷操作

    這篇文章主要給大家介紹了十個(gè)讓你相見恨晚的Python騷操作,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-11-11
  • Python實(shí)現(xiàn)簡(jiǎn)單狀態(tài)框架的方法

    Python實(shí)現(xiàn)簡(jiǎn)單狀態(tài)框架的方法

    這篇文章主要介紹了Python實(shí)現(xiàn)簡(jiǎn)單狀態(tài)框架的方法,涉及Python狀態(tài)框架的實(shí)現(xiàn)技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下
    2015-03-03
  • Python實(shí)時(shí)監(jiān)控網(wǎng)站瀏覽記錄實(shí)現(xiàn)過程詳解

    Python實(shí)時(shí)監(jiān)控網(wǎng)站瀏覽記錄實(shí)現(xiàn)過程詳解

    這篇文章主要介紹了Python實(shí)時(shí)監(jiān)控網(wǎng)站瀏覽記錄實(shí)現(xiàn)過程詳解,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2020-07-07

最新評(píng)論