PyTorch加載模型model.load_state_dict()問題及解決
PyTorch加載模型model.load_state_dict()問題
希望將訓(xùn)練好的模型加載到新的網(wǎng)絡(luò)上。
如上面題目所描述的,PyTorch在加載之前保存的模型參數(shù)的時(shí)候,遇到了問題。
Unexpected key(s) in state_dict: "module.features. ...".,Expected ".features....". 直接原因是key值名字不對(duì)應(yīng)。
表明了加載過程中,期望獲得的key值為feature...,而不是module.features....。
這是由模型保存過程中導(dǎo)致的,模型應(yīng)該是在DataParallel模式下面,也就是采用了多GPU訓(xùn)練模型,然后直接保存的。
You probably saved the model using nn.DataParallel, which stores the model in module, and now you are trying to load it without . You can either add a nn.DataParallel temporarily in your network for loading purposes, or you can load the weights file, create a new ordered dict without the module prefix, and load it back.
解決上面的問題有三個(gè)辦法:
1. 對(duì)load的模型創(chuàng)建新的字典
去掉不需要的key值"module".
# original saved file with DataParallel state_dict = torch.load('checkpoint.pt') # 模型可以保存為pth文件,也可以為pt文件。 # create new OrderedDict that does not contain `module.` from collections import OrderedDict new_state_dict = OrderedDict() for k, v in state_dict.items(): name = k[7:] # remove `module.`,表面從第7個(gè)key值字符取到最后一個(gè)字符,正好去掉了module. new_state_dict[name] = v #新字典的key值對(duì)應(yīng)的value為一一對(duì)應(yīng)的值。 # load params model.load_state_dict(new_state_dict) # 從新加載這個(gè)模型。
2. 直接用空白''代替'module.'
model.load_state_dict({k.replace('module.',''):v for k,v in torch.load('checkpoint.pt').items()}) # 相當(dāng)于用''代替'module.'。 #直接使得需要的鍵名等于期望的鍵名。
3. 最簡(jiǎn)單的方法
加載模型之后,接著將模型DataParallel,此時(shí)就可以load_state_dict。
如果有多個(gè)GPU,將模型并行化,用DataParallel來操作。
這個(gè)過程會(huì)將key值加一個(gè)"module. ***"。
model = VGGNet() params=model.state_dict() #獲得模型的原始狀態(tài)以及參數(shù)。 for k,v in params.items(): print(k) #只打印key值,不打印具體參數(shù)。
4. 總結(jié)
從出錯(cuò)顯示的問題就可以看出,key值不匹配,因此可以選擇多種方法,將模型參數(shù)加載進(jìn)去。
這個(gè)方法通常會(huì)在load_state_dict過程中遇到。將訓(xùn)練好的一個(gè)網(wǎng)絡(luò)參數(shù),移植到另外一個(gè)網(wǎng)絡(luò)上面,繼續(xù)訓(xùn)練。
或者將訓(xùn)練好的網(wǎng)絡(luò)checkpoint加載進(jìn)模型,再次進(jìn)行訓(xùn)練??梢源蛴〕鰉odel state_dict來看出兩者的差別。
model = VGGNet() params=model.state_dict() #獲得模型的原始狀態(tài)以及參數(shù)。 for k,v in params.items(): print(k) #只打印key值,不打印具體參數(shù)。
features.0.0.weight
features.0.1.weight
features.1.conv.3.weight
features.1.conv.4.num_batches_tracked
model = VGGNet() checkpoint = torch.load('checkpoint.pt', map_location='cpu') # Load weights to resume from checkpoint。 # print('**************************************') # 這個(gè)方法能夠直接打印出你保存的checkpoint的鍵和值。 for k,v in checkpoint.items(): print(k) print("*****************************************")
輸出結(jié)果為:
module.features.0.0.weight",
"module.features.0.1.weight",
"module.features.0.1.bias
可以看出不匹配,模型的參數(shù)中,key值不同,多了module。
PS: 追加
在移植參數(shù)的過程中,對(duì)于出現(xiàn) .total_ops和.total_params結(jié)尾的參數(shù),可參考以下代碼:
from collections import OrderedDict checkpoint = torch.load( pretrained_model_file_path, map_location=(None if use_cuda and not remap_to_cpu else "cpu")) new_state_dict = OrderedDict() for k, v in checkpoint.items(): if not k.endswith('total_ops') and not k.endswith('total_params'): name = k[7:] new_state_dict[name] = v
最后
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
教你用Type Hint提高Python程序開發(fā)效率
本文通過介紹和實(shí)例教大家如何利用Type Hint來提升Python程序開發(fā)效率,對(duì)大家使用python開發(fā)很有幫助,有需要的參考學(xué)習(xí)。2016-08-08PyQt5基本控件使用詳解:單選按鈕、復(fù)選框、下拉框
這篇文章主要介紹了PyQt5基本控件使用:單選按鈕、復(fù)選框、下拉框,本文中的內(nèi)容和實(shí)例也基本回答了開篇提到的問題。需要的朋友可以參考下2019-08-08Python 找出出現(xiàn)次數(shù)超過數(shù)組長(zhǎng)度一半的元素實(shí)例
這篇文章主要介紹了Python 找出出現(xiàn)次數(shù)超過數(shù)組長(zhǎng)度一半的元素實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-05-05python時(shí)間日期函數(shù)與利用pandas進(jìn)行時(shí)間序列處理詳解
python標(biāo)準(zhǔn)庫包含于日期(date)和時(shí)間(time)數(shù)據(jù)的數(shù)據(jù)類型,datetime、time以及calendar模塊會(huì)被經(jīng)常用到,而pandas則可以對(duì)時(shí)間進(jìn)行序列化排序2018-03-03Python3多線程基礎(chǔ)知識(shí)點(diǎn)
在本篇內(nèi)容里小編給大家分享了關(guān)于Python3多線程基礎(chǔ)知識(shí)點(diǎn)內(nèi)容,需要的朋友們跟著學(xué)習(xí)參考下。2019-02-02Python使用struct處理二進(jìn)制的實(shí)例詳解
這篇文章主要介紹了Python使用struct處理二進(jìn)制的實(shí)例詳解的相關(guān)資料,希望通過本文大家能掌握這部分內(nèi)容,需要的朋友可以參考下2017-09-09從訓(xùn)練好的tensorflow模型中打印訓(xùn)練變量實(shí)例
今天小編就為大家分享一篇從訓(xùn)練好的tensorflow模型中打印訓(xùn)練變量實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-01-01Python時(shí)間轉(zhuǎn)化方法超全總結(jié)
在生活和工作中,我們每個(gè)人每天都在和時(shí)間打交道。本文就為大家總結(jié)了Python實(shí)現(xiàn)時(shí)間轉(zhuǎn)化的多種方法,快來跟隨小編一起學(xué)習(xí)一下吧2022-03-03