pytorch 狀態(tài)字典:state_dict使用詳解
pytorch 中的 state_dict 是一個簡單的python的字典對象,將每一層與它的對應(yīng)參數(shù)建立映射關(guān)系.(如model的每一層的weights及偏置等等)
(注意,只有那些參數(shù)可以訓練的layer才會被保存到模型的state_dict中,如卷積層,線性層等等)
優(yōu)化器對象Optimizer也有一個state_dict,它包含了優(yōu)化器的狀態(tài)以及被使用的超參數(shù)(如lr, momentum,weight_decay等)
備注:
1) state_dict是在定義了model或optimizer之后pytorch自動生成的,可以直接調(diào)用.常用的保存state_dict的格式是".pt"或'.pth'的文件,即下面命令的 PATH="./***.pt"
torch.save(model.state_dict(), PATH)
2) load_state_dict 也是model或optimizer之后pytorch自動具備的函數(shù),可以直接調(diào)用
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
注意:model.eval() 的重要性,在2)中最后用到了model.eval(),是因為,只有在執(zhí)行該命令后,"dropout層"及"batch normalization層"才會進入 evalution 模態(tài). 而在"訓練(training)模態(tài)"與"評估(evalution)模態(tài)"下,這兩層有不同的表現(xiàn)形式.
模態(tài)字典(state_dict)的保存(model是一個網(wǎng)絡(luò)結(jié)構(gòu)類的對象)
1.1)僅保存學習到的參數(shù),用以下命令
torch.save(model.state_dict(), PATH)
1.2)加載model.state_dict,用以下命令
model = TheModelClass(*args, **kwargs) model.load_state_dict(torch.load(PATH)) model.eval()
備注:model.load_state_dict的操作對象是 一個具體的對象,而不能是文件名
2.1)保存整個model的狀態(tài),用以下命令
torch.save(model,PATH)
2.2)加載整個model的狀態(tài),用以下命令:
# Model class must be defined somewhere model = torch.load(PATH) model.eval()
state_dict 是一個python的字典格式,以字典的格式存儲,然后以字典的格式被加載,而且只加載key匹配的項
如何僅加載某一層的訓練的到的參數(shù)(某一層的state)
If you want to load parameters from one layer to another, but some keys do not match, simply change the name of the parameter keys in the state_dict that you are loading to match the keys in the model that you are loading into.
conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight']
加載模型參數(shù)后,如何設(shè)置某層某參數(shù)的"是否需要訓練"(param.requires_grad)
for param in list(model.pretrained.parameters()): param.requires_grad = False
注意: requires_grad的操作對象是tensor.
疑問:能否直接對某個層直接之用requires_grad呢?例如:model.conv1.requires_grad=False
回答:經(jīng)測試,不可以.model.conv1 沒有requires_grad屬性.
全部測試代碼:
#-*-coding:utf-8-*- import torch import torch.nn as nn import torch.nn.functional as F import torch.optim as optim # define model class TheModelClass(nn.Module): def __init__(self): super(TheModelClass,self).__init__() self.conv1 = nn.Conv2d(3,6,5) self.pool = nn.MaxPool2d(2,2) self.conv2 = nn.Conv2d(6,16,5) self.fc1 = nn.Linear(16*5*5,120) self.fc2 = nn.Linear(120,84) self.fc3 = nn.Linear(84,10) def forward(self,x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1,16*5*5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # initial model model = TheModelClass() #initialize the optimizer optimizer = optim.SGD(model.parameters(),lr=0.001,momentum=0.9) # print the model's state_dict print("model's state_dict:") for param_tensor in model.state_dict(): print(param_tensor,'\t',model.state_dict()[param_tensor].size()) print("\noptimizer's state_dict") for var_name in optimizer.state_dict(): print(var_name,'\t',optimizer.state_dict()[var_name]) print("\nprint particular param") print('\n',model.conv1.weight.size()) print('\n',model.conv1.weight) print("------------------------------------") torch.save(model.state_dict(),'./model_state_dict.pt') # model_2 = TheModelClass() # model_2.load_state_dict(torch.load('./model_state_dict')) # model.eval() # print('\n',model_2.conv1.weight) # print((model_2.conv1.weight == model.conv1.weight).size()) ## 僅僅加載某一層的參數(shù) conv1_weight_state = torch.load('./model_state_dict.pt')['conv1.weight'] print(conv1_weight_state==model.conv1.weight) model_2 = TheModelClass() model_2.load_state_dict(torch.load('./model_state_dict.pt')) model_2.conv1.requires_grad=False print(model_2.conv1.requires_grad) print(model_2.conv1.bias.requires_grad)
以上這篇pytorch 狀態(tài)字典:state_dict使用詳解就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實現(xiàn)的人工神經(jīng)網(wǎng)絡(luò)算法示例【基于反向傳播算法】
這篇文章主要介紹了Python實現(xiàn)的人工神經(jīng)網(wǎng)絡(luò)算法,結(jié)合實例形式分析了Python基于反向傳播算法實現(xiàn)的人工神經(jīng)網(wǎng)絡(luò)相關(guān)操作技巧,需要的朋友可以參考下2017-11-11python使用redis模塊來跟redis實現(xiàn)交互
這篇文章主要介紹了python使用redis模塊來跟redis實現(xiàn)交互,文章圍繞主題展開詳細的內(nèi)容介紹,具有一定的參考價值,需要的小伙伴可以參考一下2022-06-06詳解Python?Flask?API?示例演示(附cookies和session)
這篇文章主要為大家介紹了Python?Flask?API?示例演示(附cookies和session)詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪2023-03-03利用Python將時間或時間間隔轉(zhuǎn)為ISO 8601格式方法示例
國際標準化組織的國際標準ISO8601是日期和時間的表示方法,全稱為《數(shù)據(jù)存儲和交換形式·信息交換·日期和時間的表示方法》,下面這篇文章主要給大家介紹了關(guān)于利用Python將時間或時間間隔轉(zhuǎn)為ISO 8601格式的相關(guān)資料,需要的朋友可以參考下。2017-09-09python計算階乘和的方法(1!+2!+3!+...+n!)
今天小編就為大家分享一篇python計算階乘和的方法(1!+2!+3!+...+n!),具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-02-02Python使用內(nèi)置函數(shù)setattr設(shè)置對象的屬性值
這篇文章主要介紹了Python使用內(nèi)置函數(shù)setattr設(shè)置對象的屬性值,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下2020-10-10