欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch模型保存與加載中的一些問題實(shí)戰(zhàn)記錄

 更新時(shí)間:2022年10月28日 12:37:20   作者:colourmind  
一般來說,保存模型是把參數(shù)全部用model.cpu().state_dict(),然后加載模型時(shí)一般用model.load_state_dict(torch.load(model_path)),下面這篇文章主要給大家介紹了關(guān)于pytorch模型保存與加載中的一些問題實(shí)戰(zhàn)記錄,需要的朋友可以參考下

前言

最近使用pytorch訓(xùn)練模型,保存模型后再次加載使用出現(xiàn)了一些問題。記錄一下解決方案!

一、torch中模型保存和加載的方式

1、模型參數(shù)和模型結(jié)構(gòu)保存和加載

torch.save(model,path)
torch.load(path)

2、只保存模型的參數(shù)和加載——這種方式比較安全,但是比較稍微麻煩一點(diǎn)點(diǎn)

torch.save(model.state_dict(),path)
model_state_dic = torch.load(path)
model.load_state_dic(model_state_dic)

二、torch中模型保存和加載出現(xiàn)的問題

1、單卡模型下保存模型結(jié)構(gòu)和參數(shù)后加載出現(xiàn)的問題

模型保存的時(shí)候會(huì)把模型結(jié)構(gòu)定義文件路徑記錄下來,加載的時(shí)候就會(huì)根據(jù)路徑解析它然后裝載參數(shù);當(dāng)把模型定義文件路徑修改以后,使用torch.load(path)就會(huì)報(bào)錯(cuò)。

把model文件夾修改為models后,再加載就會(huì)報(bào)錯(cuò)。

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN.bin')
print('load_model',load_model)

這種保存完整模型結(jié)構(gòu)和參數(shù)的方式,一定不要改動(dòng)模型定義文件路徑。

2、多卡機(jī)器單卡訓(xùn)練模型保存后在單卡機(jī)器上加載會(huì)報(bào)錯(cuò)

在多卡機(jī)器上有多張顯卡0號(hào)開始,現(xiàn)在模型在n>=1上的顯卡訓(xùn)練保存后,拷貝在單卡機(jī)器上加載

import torch
from model.TextRNN import TextRNN
 
load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin')
print('load_model',load_model)

會(huì)出現(xiàn)cuda device不匹配的問題——你保存的模代碼段 小部件型是使用的cuda1,那么采用torch.load()打開的時(shí)候,會(huì)默認(rèn)的去尋找cuda1,然后把模型加載到該設(shè)備上。這個(gè)時(shí)候可以直接使用map_location來解決,把模型加載到CPU上即可。

load_model = torch.load('experiment_model_save/textRNN_cuda_1.bin',map_location=torch.device('cpu'))

3、多卡訓(xùn)練模型保存模型結(jié)構(gòu)和參數(shù)后加載出現(xiàn)的問題

當(dāng)用多GPU同時(shí)訓(xùn)練模型之后,不管是采用模型結(jié)構(gòu)和參數(shù)一起保存還是單獨(dú)保存模型參數(shù),然后在單卡下加載都會(huì)出現(xiàn)問題

a、模型結(jié)構(gòu)和參數(shù)一起保然后在加載

torch.distributed.init_process_group(backend='nccl')

模型訓(xùn)練的時(shí)候采用上述多進(jìn)程的方式,所以你在加載的時(shí)候也要聲明,不然就會(huì)報(bào)錯(cuò)。

b、單獨(dú)保存模型參數(shù)

model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
state_dict = torch.load('train_model/clip/experiment.pt')
model.load_state_dict(state_dict)

同樣會(huì)出現(xiàn)問題,不過這里出現(xiàn)的問題是參數(shù)字典的key和模型定義的key不一樣

原因是多GPU訓(xùn)練下,使用分布式訓(xùn)練的時(shí)候會(huì)給模型進(jìn)行一個(gè)包裝,代碼如下:

model = torch.load('train_model/clip/Vtransformers_bert_6_layers_encoder_clip.bin')
print(model)
model.cuda(args.local_rank)
。。。。。。
model = nn.parallel.DistributedDataParallel(model,device_ids=[args.local_rank],find_unused_parameters=True)
print('model',model)

包裝前的模型結(jié)構(gòu):

包裝后的模型

在外層多了DistributedDataParallel以及module,所以才會(huì)導(dǎo)致在單卡環(huán)境下加載模型權(quán)重的時(shí)候出現(xiàn)權(quán)重的keys不一致。

三、正確的保存模型和加載的方法

    if gpu_count > 1:
        torch.save(model.module.state_dict(),save_path)
    else:
        torch.save(model.state_dict(),save_path)
    model = Transformer(num_encoder_layers=6,num_decoder_layers=6)
    state_dict = torch.load(save_path)
    model.load_state_dict(state_dict)

這樣就是比較好的范式,加載不會(huì)出錯(cuò)。

總結(jié)

到此這篇關(guān)于pytorch模型保存與加載中的一些問題的文章就介紹到這了,更多相關(guān)pytorch模型保存與加載內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • Python通過Pygame繪制移動(dòng)的矩形實(shí)例代碼

    Python通過Pygame繪制移動(dòng)的矩形實(shí)例代碼

    這篇文章主要介紹了Python通過Pygame繪制移動(dòng)的矩形實(shí)例代碼,具有一定借鑒價(jià)值,需要的朋友可以參考下
    2018-01-01
  • Python實(shí)現(xiàn)利用最大公約數(shù)求三個(gè)正整數(shù)的最小公倍數(shù)示例

    Python實(shí)現(xiàn)利用最大公約數(shù)求三個(gè)正整數(shù)的最小公倍數(shù)示例

    這篇文章主要介紹了Python實(shí)現(xiàn)利用最大公約數(shù)求三個(gè)正整數(shù)的最小公倍數(shù),涉及Python數(shù)學(xué)運(yùn)算相關(guān)操作技巧,需要的朋友可以參考下
    2017-09-09
  • python實(shí)現(xiàn)報(bào)表自動(dòng)化詳解

    python實(shí)現(xiàn)報(bào)表自動(dòng)化詳解

    這篇文章主要介紹了python實(shí)現(xiàn)報(bào)表自動(dòng)化詳解,涉及python讀,寫excel—xlwt常用功能,xlutils 常用功能,xlwt寫Excel時(shí)公式的應(yīng)用等相關(guān)內(nèi)容,具有一定參考價(jià)值,需要的朋友可以了解下。
    2017-11-11
  • 利用python list完成最簡(jiǎn)單的DB連接池方法

    利用python list完成最簡(jiǎn)單的DB連接池方法

    這篇文章主要介紹了利用python list完成最簡(jiǎn)單的DB連接池方法,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2019-08-08
  • python 日志模塊logging的使用場(chǎng)景及示例

    python 日志模塊logging的使用場(chǎng)景及示例

    這篇文章主要介紹了python 日志模塊logging的使用場(chǎng)景及示例,幫助大家更好的理解和使用python,感興趣的朋友可以了解下
    2021-01-01
  • matplotlib 對(duì)坐標(biāo)的控制,加圖例注釋的操作

    matplotlib 對(duì)坐標(biāo)的控制,加圖例注釋的操作

    這篇文章主要介紹了matplotlib 對(duì)坐標(biāo)的控制,加圖例注釋的操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2020-04-04
  • 如何使用Python逆向抓取APP數(shù)據(jù)

    如何使用Python逆向抓取APP數(shù)據(jù)

    今天給大伙分享一下 Python 爬蟲的教程,這次主要涉及到的是關(guān)于某 APP 的逆向分析并抓取數(shù)據(jù),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2021-05-05
  • Python 多進(jìn)程并發(fā)操作中進(jìn)程池Pool的實(shí)例

    Python 多進(jìn)程并發(fā)操作中進(jìn)程池Pool的實(shí)例

    下面小編就為大家?guī)硪黄狿ython 多進(jìn)程并發(fā)操作中進(jìn)程池Pool的實(shí)例。小編覺得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧
    2017-11-11
  • Django實(shí)現(xiàn)CAS+OAuth2的方法示例

    Django實(shí)現(xiàn)CAS+OAuth2的方法示例

    這篇文章主要介紹了Django實(shí)現(xiàn)CAS+OAuth2的方法示例,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2019-10-10
  • Python退出While循環(huán)的3種方法舉例詳解

    Python退出While循環(huán)的3種方法舉例詳解

    在每次循環(huán)結(jié)束后,我們需要檢查循環(huán)條件是否滿足。如果條件滿足,則繼續(xù)執(zhí)行循環(huán)體內(nèi)的代碼,否則退出循環(huán),這篇文章主要給大家介紹了關(guān)于Python退出While循環(huán)的3種方法,需要的朋友可以參考下
    2023-10-10

最新評(píng)論