解決pytorch多GPU訓(xùn)練保存的模型,在單GPU環(huán)境下加載出錯(cuò)問題
背景
在公司用多卡訓(xùn)練模型,得到權(quán)值文件后保存,然后回到實(shí)驗(yàn)室,沒有多卡的環(huán)境,用單卡訓(xùn)練,加載模型時(shí)出錯(cuò),因?yàn)閱慰C(jī)器上,沒有使用DataParallel來加載模型,所以會出現(xiàn)加載錯(cuò)誤。
原因
DataParallel包裝的模型在保存時(shí),權(quán)值參數(shù)前面會帶有module字符,然而自己在單卡環(huán)境下,沒有用DataParallel包裝的模型權(quán)值參數(shù)不帶module。本質(zhì)上保存的權(quán)值文件是一個(gè)有序字典。
解決方法
1.在單卡環(huán)境下,用DataParallel包裝模型。
2.自己重寫Load函數(shù),靈活。
from collections import OrderedDict def myOwnLoad(model, check): modelState = model.state_dict() tempState = OrderedDict() for i in range(len(check.keys())-2): print modelState.keys()[i], check.keys()[i] tempState[modelState.keys()[i]] = check[check.keys()[i]] temp = [[0.02]*1024 for i in range(200)] # mean=0, std=0.02 tempState['myFc.weight'] = torch.normal(mean=0, std=torch.FloatTensor(temp)).cuda() tempState['myFc.bias'] = torch.normal(mean=0, std=torch.FloatTensor([0]*200)).cuda() model.load_state_dict(tempState) return model
補(bǔ)充知識:Pytorch:多GPU訓(xùn)練網(wǎng)絡(luò)與單GPU訓(xùn)練網(wǎng)絡(luò)保存模型的區(qū)別
測試環(huán)境:Python3.6 + Pytorch0.4
在pytorch中,使用多GPU訓(xùn)練網(wǎng)絡(luò)需要用到 【nn.DataParallel】:
gpu_ids = [0, 1, 2, 3] device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能單GPU運(yùn)行 net = LeNet() if len(gpu_ids) > 1: net = nn.DataParallel(net, device_ids=gpu_ids) net = net.to(device)
而使用單GPU訓(xùn)練網(wǎng)絡(luò):
device = t.device("cuda:0" if t.cuda.is_available() else "cpu") # 只能單GPU運(yùn)行
net = LeNet().to(device)
由于多GPU訓(xùn)練使用了 nn.DataParallel(net, device_ids=gpu_ids) 對網(wǎng)絡(luò)進(jìn)行封裝,因此在原始網(wǎng)絡(luò)結(jié)構(gòu)中添加了一層module。網(wǎng)絡(luò)結(jié)構(gòu)如下:
DataParallel( (module): LeNet( (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1)) (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) (fc1): Linear(in_features=400, out_features=120, bias=True) (fc2): Linear(in_features=120, out_features=84, bias=True) (fc3): Linear(in_features=84, out_features=10, bias=True) ) )
而不使用多GPU訓(xùn)練的網(wǎng)絡(luò)結(jié)構(gòu)如下:
LeNet( (conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1)) (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) (fc1): Linear(in_features=400, out_features=120, bias=True) (fc2): Linear(in_features=120, out_features=84, bias=True) (fc3): Linear(in_features=84, out_features=10, bias=True) )
由于在測試模型時(shí)不需要用到多GPU測試,因此在保存模型時(shí)應(yīng)該把module層去掉。如下:
if len(gpu_ids) > 1: t.save(net.module.state_dict(), "model.pth") else: t.save(net.state_dict(), "model.pth")
以上這篇解決pytorch多GPU訓(xùn)練保存的模型,在單GPU環(huán)境下加載出錯(cuò)問題就是小編分享給大家的全部內(nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python小游戲之300行代碼實(shí)現(xiàn)俄羅斯方塊
這篇文章主要給大家介紹了關(guān)于Python小游戲之300行代碼實(shí)現(xiàn)俄羅斯方塊的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面來一起看看吧2019-01-01Python中BaseHTTPRequestHandler實(shí)現(xiàn)簡單的API接口
本文主要介紹了Python中BaseHTTPRequestHandler實(shí)現(xiàn)簡單的API接口,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-07-07深度學(xué)習(xí)中shape[0]、shape[1]、shape[2]的區(qū)別詳解
本文主要介紹了深度學(xué)習(xí)中shape[0]、shape[1]、shape[2]的區(qū)別詳解,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2022-07-07python基于FTP實(shí)現(xiàn)文件傳輸相關(guān)功能代碼實(shí)例
這篇文章主要介紹了python基于FTP實(shí)現(xiàn)文件傳輸相關(guān)功能代碼實(shí)例,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-09-09解決Atom安裝Hydrogen無法運(yùn)行python3的問題
今天小編就為大家分享一篇解決Atom安裝Hydrogen無法運(yùn)行python3的問題,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-08-08詳談套接字中SO_REUSEPORT和SO_REUSEADDR的區(qū)別
下面小編就為大家分享一篇詳談套接字中SO_REUSEPORT和SO_REUSEADDR的區(qū)別,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-04-04Django實(shí)現(xiàn)jquery select2帶搜索的下拉框
最近在開發(fā)一個(gè)web應(yīng)用中需要用到帶搜索功能下拉框,本文實(shí)現(xiàn)Django實(shí)現(xiàn)jquery select2帶搜索的下拉框,感興趣的小伙伴們可以參考一下2021-06-06Python數(shù)據(jù)結(jié)構(gòu)與算法中的棧詳解(1)
這篇文章主要為大家詳細(xì)介紹了Python中的棧,文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下,希望能夠給你帶來幫助2022-03-03