欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch 狀態(tài)字典:state_dict使用詳解

 更新時間:2020年01月17日 17:12:28   作者:wzg2016  
今天小編就為大家分享一篇pytorch 狀態(tài)字典:state_dict使用詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

pytorch 中的 state_dict 是一個簡單的python的字典對象,將每一層與它的對應(yīng)參數(shù)建立映射關(guān)系.(如model的每一層的weights及偏置等等)

(注意,只有那些參數(shù)可以訓練的layer才會被保存到模型的state_dict中,如卷積層,線性層等等)

優(yōu)化器對象Optimizer也有一個state_dict,它包含了優(yōu)化器的狀態(tài)以及被使用的超參數(shù)(如lr, momentum,weight_decay等)

備注:

1) state_dict是在定義了model或optimizer之后pytorch自動生成的,可以直接調(diào)用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"

torch.save(model.state_dict(), PATH)

2) load_state_dict 也是model或optimizer之后pytorch自動具備的函數(shù),可以直接調(diào)用

model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load(PATH))
model.eval()

注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因為,只有在執(zhí)行該命令后,"dropout層"及"batch normalization層"才會進入 evalution 模態(tài). 而在"訓練(training)模態(tài)"與"評估(evalution)模態(tài)"下,這兩層有不同的表現(xiàn)形式.

模態(tài)字典(state_dict)的保存(model是一個網(wǎng)絡(luò)結(jié)構(gòu)類的對象)

1.1)僅保存學習到的參數(shù),用以下命令

 torch.save(model.state_dict(), PATH)

1.2)加載model.state_dict,用以下命令

 model = TheModelClass(*args, **kwargs)
 model.load_state_dict(torch.load(PATH))
 model.eval()

備注:model.load_state_dict的操作對象是 一個具體的對象,而不能是文件名

2.1)保存整個model的狀態(tài),用以下命令

torch.save(model,PATH)

2.2)加載整個model的狀態(tài),用以下命令:

   # Model class must be defined somewhere

 model = torch.load(PATH)

 model.eval()

state_dict 是一個python的字典格式,以字典的格式存儲,然后以字典的格式被加載,而且只加載key匹配的項

如何僅加載某一層的訓練的到的參數(shù)(某一層的state)

If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.

conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']

加載模型參數(shù)后,如何設(shè)置某層某參數(shù)的"是否需要訓練"(param.requires_grad)

for param in list(model.pretrained.parameters()):
 param.requires_grad = False

注意: requires_grad的操作對象是tensor.

疑問:能否直接對某個層直接之用requires_grad呢?例如:model.conv1.requires_grad=False

回答:經(jīng)測試,不可以.model.conv1 沒有requires_grad屬性.

全部測試代碼:

#-*-coding:utf-8-*-
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
 
 
 
# define model
class TheModelClass(nn.Module):
 def __init__(self):
  super(TheModelClass,self).__init__()
  self.conv1 = nn.Conv2d(3,6,5)
  self.pool = nn.MaxPool2d(2,2)
  self.conv2 = nn.Conv2d(6,16,5)
  self.fc1 = nn.Linear(16*5*5,120)
  self.fc2 = nn.Linear(120,84)
  self.fc3 = nn.Linear(84,10)
 
 def forward(self,x):
  x = self.pool(F.relu(self.conv1(x)))
  x = self.pool(F.relu(self.conv2(x)))
  x = x.view(-1,16*5*5)
  x = F.relu(self.fc1(x))
  x = F.relu(self.fc2(x))
  x = self.fc3(x)
  return x
 
# initial model
model = TheModelClass()
 
#initialize the optimizer
optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9)
 
# print the model's state_dict
print("model's state_dict:")
for param_tensor in model.state_dict():
 print(param_tensor,'\t',model.state_dict()[param_tensor].size())
 
print("\noptimizer's state_dict")
for var_name in optimizer.state_dict():
 print(var_name,'\t',optimizer.state_dict()[var_name])
 
print("\nprint particular param")
print('\n',model.conv1.weight.size())
print('\n',model.conv1.weight)
 
print("------------------------------------")
torch.save(model.state_dict(),'./model_state_dict.pt')
# model_2 = TheModelClass()
# model_2.load_state_dict(torch.load('./model_state_dict'))
# model.eval()
# print('\n',model_2.conv1.weight)
# print((model_2.conv1.weight == model.conv1.weight).size())
## 僅僅加載某一層的參數(shù)
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
print(conv1_weight_state==model.conv1.weight)
 
model_2 = TheModelClass()
model_2.load_state_dict(torch.load('./model_state_dict.pt'))
model_2.conv1.requires_grad=False
print(model_2.conv1.requires_grad)
print(model_2.conv1.bias.requires_grad)

以上這篇pytorch 狀態(tài)字典:state_dict使用詳解就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • 教你python制作自己的模塊的基本步驟

    教你python制作自己的模塊的基本步驟

    這篇文章主要介紹了python如何制作自己的模塊,本文通過實例代碼給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2023-08-08
  • Python實現(xiàn)的人工神經(jīng)網(wǎng)絡(luò)算法示例【基于反向傳播算法】

    Python實現(xiàn)的人工神經(jīng)網(wǎng)絡(luò)算法示例【基于反向傳播算法】

    這篇文章主要介紹了Python實現(xiàn)的人工神經(jīng)網(wǎng)絡(luò)算法,結(jié)合實例形式分析了Python基于反向傳播算法實現(xiàn)的人工神經(jīng)網(wǎng)絡(luò)相關(guān)操作技巧,需要的朋友可以參考下
    2017-11-11
  • Python利用treap實現(xiàn)雙索引的方法

    Python利用treap實現(xiàn)雙索引的方法

    所遍歷的元素一定是遞增(小堆)或是遞減(大堆)關(guān)系,但是我們無法得知左子樹與右子樹兩部分節(jié)點的排序關(guān)系。本文就來講講算法和數(shù)據(jù)結(jié)構(gòu)共同滿足一組特性,感興趣的小伙伴請參考下面文章的內(nèi)容
    2021-09-09
  • python基礎(chǔ)入門之列表(一)

    python基礎(chǔ)入門之列表(一)

    在Python中,列表(list)是常用的數(shù)據(jù)類型。列表由一系列按照特定順序排列的項(item)組成。
    2021-06-06
  • python使用redis模塊來跟redis實現(xiàn)交互

    python使用redis模塊來跟redis實現(xiàn)交互

    這篇文章主要介紹了python使用redis模塊來跟redis實現(xiàn)交互,文章圍繞主題展開詳細的內(nèi)容介紹,具有一定的參考價值,需要的小伙伴可以參考一下
    2022-06-06
  • python判斷、獲取一張圖片主色調(diào)的2個實例

    python判斷、獲取一張圖片主色調(diào)的2個實例

    一幅圖片,想通過程序判斷獲得其主要色調(diào),應(yīng)該怎么樣處理?本文通過python實現(xiàn)判斷、獲取一張圖片的主色調(diào)方法,需要的朋友可以參考下
    2014-04-04
  • 詳解Python?Flask?API?示例演示(附cookies和session)

    詳解Python?Flask?API?示例演示(附cookies和session)

    這篇文章主要為大家介紹了Python?Flask?API?示例演示(附cookies和session)詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪
    2023-03-03
  • 利用Python將時間或時間間隔轉(zhuǎn)為ISO 8601格式方法示例

    利用Python將時間或時間間隔轉(zhuǎn)為ISO 8601格式方法示例

    國際標準化組織的國際標準ISO8601是日期和時間的表示方法,全稱為《數(shù)據(jù)存儲和交換形式·信息交換·日期和時間的表示方法》,下面這篇文章主要給大家介紹了關(guān)于利用Python將時間或時間間隔轉(zhuǎn)為ISO 8601格式的相關(guān)資料,需要的朋友可以參考下。
    2017-09-09
  • python計算階乘和的方法(1!+2!+3!+...+n!)

    python計算階乘和的方法(1!+2!+3!+...+n!)

    今天小編就為大家分享一篇python計算階乘和的方法(1!+2!+3!+...+n!),具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-02-02
  • Python使用內(nèi)置函數(shù)setattr設(shè)置對象的屬性值

    Python使用內(nèi)置函數(shù)setattr設(shè)置對象的屬性值

    這篇文章主要介紹了Python使用內(nèi)置函數(shù)setattr設(shè)置對象的屬性值,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下
    2020-10-10

最新評論