pytorch模型保存與加載中的一些問題實戰(zhàn)記錄
前言
最近使用pytorch訓(xùn)練模型,保存模型后再次加載使用出現(xiàn)了一些問題。記錄一下解決方案!
一、torch中模型保存和加載的方式
1、模型參數(shù)和模型結(jié)構(gòu)保存和加載
torch.save(model,path) torch.load(path)
2、只保存模型的參數(shù)和加載——這種方式比較安全,但是比較稍微麻煩一點點
torch.save(model.state_dict(),path) model_state_dic = torch.load(path) model.load_state_dic(model_state_dic)
二、torch中模型保存和加載出現(xiàn)的問題
1、單卡模型下保存模型結(jié)構(gòu)和參數(shù)后加載出現(xiàn)的問題
模型保存的時候會把模型結(jié)構(gòu)定義文件路徑記錄下來,加載的時候就會根據(jù)路徑解析它然后裝載參數(shù);當(dāng)把模型定義文件路徑修改以后,使用torch.load(path)就會報錯。
把model文件夾修改為models后,再加載就會報錯。
import torch from model.TextRNN import TextRNN load_model = torch.load('experiment_model_save/textRNN.bin') print('load_model',load_model)
這種保存完整模型結(jié)構(gòu)和參數(shù)的方式,一定不要改動模型定義文件路徑。
2、多卡機器單卡訓(xùn)練模型保存后在單卡機器上加載會報錯
在多卡機器上有多張顯卡0號開始,現(xiàn)在模型在n>=1上的顯卡訓(xùn)練保存后,拷貝在單卡機器上加載
import torch from model.TextRNN import TextRNN load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin') print('load_model',load_model)
會出現(xiàn)cuda device不匹配的問題——你保存的模代碼段 小部件型是使用的cuda1,那么采用torch.load()打開的時候,會默認(rèn)的去尋找cuda1,然后把模型加載到該設(shè)備上。這個時候可以直接使用map_location來解決,把模型加載到CPU上即可。
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))
3、多卡訓(xùn)練模型保存模型結(jié)構(gòu)和參數(shù)后加載出現(xiàn)的問題
當(dāng)用多GPU同時訓(xùn)練模型之后,不管是采用模型結(jié)構(gòu)和參數(shù)一起保存還是單獨保存模型參數(shù),然后在單卡下加載都會出現(xiàn)問題
a、模型結(jié)構(gòu)和參數(shù)一起保然后在加載
torch.distributed.init_process_group(backend='nccl')
模型訓(xùn)練的時候采用上述多進程的方式,所以你在加載的時候也要聲明,不然就會報錯。
b、單獨保存模型參數(shù)
model = Transformer(num_encoder_layers=6,num_decoder_layers=6) state_dict = torch.load('train_model/clip/experiment.pt') model.load_state_dict(state_dict)
同樣會出現(xiàn)問題,不過這里出現(xiàn)的問題是參數(shù)字典的key和模型定義的key不一樣
原因是多GPU訓(xùn)練下,使用分布式訓(xùn)練的時候會給模型進行一個包裝,代碼如下:
model = torch.load('train_model/clip/Vtransformers_bert_6_layers_encoder_clip.bin') print(model) model.cuda(args.local_rank) 。。。。。。 model = nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank],find_unused_parameters=True) print('model',model)
包裝前的模型結(jié)構(gòu):
包裝后的模型
在外層多了DistributedDataParallel以及module,所以才會導(dǎo)致在單卡環(huán)境下加載模型權(quán)重的時候出現(xiàn)權(quán)重的keys不一致。
三、正確的保存模型和加載的方法
if gpu_count > 1: torch.save(model.module.state_dict(),save_path) else: torch.save(model.state_dict(),save_path) model = Transformer(num_encoder_layers=6,num_decoder_layers=6) state_dict = torch.load(save_path) model.load_state_dict(state_dict)
這樣就是比較好的范式,加載不會出錯。
總結(jié)
到此這篇關(guān)于pytorch模型保存與加載中的一些問題的文章就介紹到這了,更多相關(guān)pytorch模型保存與加載內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python實現(xiàn)利用最大公約數(shù)求三個正整數(shù)的最小公倍數(shù)示例
這篇文章主要介紹了Python實現(xiàn)利用最大公約數(shù)求三個正整數(shù)的最小公倍數(shù),涉及Python數(shù)學(xué)運算相關(guān)操作技巧,需要的朋友可以參考下2017-09-09matplotlib 對坐標(biāo)的控制,加圖例注釋的操作
這篇文章主要介紹了matplotlib 對坐標(biāo)的控制,加圖例注釋的操作,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-04-04Python退出While循環(huán)的3種方法舉例詳解
在每次循環(huán)結(jié)束后,我們需要檢查循環(huán)條件是否滿足。如果條件滿足,則繼續(xù)執(zhí)行循環(huán)體內(nèi)的代碼,否則退出循環(huán),這篇文章主要給大家介紹了關(guān)于Python退出While循環(huán)的3種方法,需要的朋友可以參考下2023-10-10