欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch 狀態(tài)字典:state_dict使用詳解

 更新時(shí)間:2020年01月17日 17:12:28   作者:wzg2016  
今天小編就為大家分享一篇pytorch 狀態(tài)字典:state_dict使用詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧

pytorch 中的 state_dict 是一個(gè)簡(jiǎn)單的python的字典對(duì)象,將每一層與它的對(duì)應(yīng)參數(shù)建立映射關(guān)系.(如model的每一層的weights及偏置等等)

(注意,只有那些參數(shù)可以訓(xùn)練的layer才會(huì)被保存到模型的state_dict中,如卷積層,線性層等等)

優(yōu)化器對(duì)象Optimizer也有一個(gè)state_dict,它包含了優(yōu)化器的狀態(tài)以及被使用的超參數(shù)(如lr, momentum,weight_decay等)

備注:

1) state_dict是在定義了model或optimizer之后pytorch自動(dòng)生成的,可以直接調(diào)用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自動(dòng)具備的函數(shù),可以直接調(diào)用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因?yàn)?只有在執(zhí)行該命令后,"dropout層"及"batch normalization層"才會(huì)進(jìn)入 evalution 模態(tài). 而在"訓(xùn)練(training)模態(tài)"與"評(píng)估(evalution)模態(tài)"下,這兩層有不同的表現(xiàn)形式.

模態(tài)字典(state_dict)的保存(model是一個(gè)網(wǎng)絡(luò)結(jié)構(gòu)類的對(duì)象)

1.1)僅保存學(xué)習(xí)到的參數(shù),用以下命令

 torch.save(model.state_dict(), PATH)

1.2)加載model.state_dict,用以下命令

 model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

備注:model.load_state_dict的操作對(duì)象是 一個(gè)具體的對(duì)象,而不能是文件名

2.1)保存整個(gè)model的狀態(tài),用以下命令

torch.save(model,PATH)

2.2)加載整個(gè)model的狀態(tài),用以下命令:

   # Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一個(gè)python的字典格式,以字典的格式存儲(chǔ),然后以字典的格式被加載,而且只加載key匹配的項(xiàng)

如何僅加載某一層的訓(xùn)練的到的參數(shù)(某一層的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加載模型參數(shù)后,如何設(shè)置某層某參數(shù)的"是否需要訓(xùn)練"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作對(duì)象是tensor.

疑問(wèn):能否直接對(duì)某個(gè)層直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:經(jīng)測(cè)試,不可以.model.conv1 沒(méi)有requires_grad屬性.

全部測(cè)試代碼:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 僅僅加載某一層的參數(shù)
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上這篇pytorch 狀態(tài)字典:state_dict使用詳解就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • 教你python制作自己的模塊的基本步驟

    教你python制作自己的模塊的基本步驟

    這篇文章主要介紹了python如何制作自己的模塊,本文通過(guò)實(shí)例代碼給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2023-08-08
  • Python實(shí)現(xiàn)的人工神經(jīng)網(wǎng)絡(luò)算法示例【基于反向傳播算法】

    Python實(shí)現(xiàn)的人工神經(jīng)網(wǎng)絡(luò)算法示例【基于反向傳播算法】

    這篇文章主要介紹了Python實(shí)現(xiàn)的人工神經(jīng)網(wǎng)絡(luò)算法,結(jié)合實(shí)例形式分析了Python基于反向傳播算法實(shí)現(xiàn)的人工神經(jīng)網(wǎng)絡(luò)相關(guān)操作技巧,需要的朋友可以參考下
    2017-11-11
  • Python利用treap實(shí)現(xiàn)雙索引的方法

    Python利用treap實(shí)現(xiàn)雙索引的方法

    所遍歷的元素一定是遞增(小堆)或是遞減(大堆)關(guān)系,但是我們無(wú)法得知左子樹(shù)與右子樹(shù)兩部分節(jié)點(diǎn)的排序關(guān)系。本文就來(lái)講講算法和數(shù)據(jù)結(jié)構(gòu)共同滿足一組特性,感興趣的小伙伴請(qǐng)參考下面文章的內(nèi)容
    2021-09-09
  • python基礎(chǔ)入門之列表(一)

    python基礎(chǔ)入門之列表(一)

    在Python中,列表(list)是常用的數(shù)據(jù)類型。列表由一系列按照特定順序排列的項(xiàng)(item)組成。
    2021-06-06
  • python使用redis模塊來(lái)跟redis實(shí)現(xiàn)交互

    python使用redis模塊來(lái)跟redis實(shí)現(xiàn)交互

    這篇文章主要介紹了python使用redis模塊來(lái)跟redis實(shí)現(xiàn)交互,文章圍繞主題展開(kāi)詳細(xì)的內(nèi)容介紹,具有一定的參考價(jià)值,需要的小伙伴可以參考一下
    2022-06-06
  • python判斷、獲取一張圖片主色調(diào)的2個(gè)實(shí)例

    python判斷、獲取一張圖片主色調(diào)的2個(gè)實(shí)例

    一幅圖片,想通過(guò)程序判斷獲得其主要色調(diào),應(yīng)該怎么樣處理?本文通過(guò)python實(shí)現(xiàn)判斷、獲取一張圖片的主色調(diào)方法,需要的朋友可以參考下
    2014-04-04
  • 詳解Python?Flask?API?示例演示(附cookies和session)

    詳解Python?Flask?API?示例演示(附cookies和session)

    這篇文章主要為大家介紹了Python?Flask?API?示例演示(附cookies和session)詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2023-03-03
  • 利用Python將時(shí)間或時(shí)間間隔轉(zhuǎn)為ISO 8601格式方法示例

    利用Python將時(shí)間或時(shí)間間隔轉(zhuǎn)為ISO 8601格式方法示例

    國(guó)際標(biāo)準(zhǔn)化組織的國(guó)際標(biāo)準(zhǔn)ISO8601是日期和時(shí)間的表示方法,全稱為《數(shù)據(jù)存儲(chǔ)和交換形式·信息交換·日期和時(shí)間的表示方法》,下面這篇文章主要給大家介紹了關(guān)于利用Python將時(shí)間或時(shí)間間隔轉(zhuǎn)為ISO 8601格式的相關(guān)資料,需要的朋友可以參考下。
    2017-09-09
  • python計(jì)算階乘和的方法(1!+2!+3!+...+n!)

    python計(jì)算階乘和的方法(1!+2!+3!+...+n!)

    今天小編就為大家分享一篇python計(jì)算階乘和的方法(1!+2!+3!+...+n!),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2019-02-02
  • Python使用內(nèi)置函數(shù)setattr設(shè)置對(duì)象的屬性值

    Python使用內(nèi)置函數(shù)setattr設(shè)置對(duì)象的屬性值

    這篇文章主要介紹了Python使用內(nèi)置函數(shù)setattr設(shè)置對(duì)象的屬性值,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2020-10-10

最新評(píng)論