pytorch模型的保存和加載、checkpoint操作
其實(shí)之前筆者寫(xiě)代碼的時(shí)候用到模型的保存和加載,需要用的時(shí)候就去度娘搜一下大致代碼,現(xiàn)在有時(shí)間就來(lái)整理下整個(gè)pytorch模型的保存和加載,開(kāi)始學(xué)習(xí)把~
pytorch的模型和參數(shù)是分開(kāi)的,可以分別保存或加載模型和參數(shù)。所以pytorch的保存和加載對(duì)應(yīng)存在兩種方式:
1. 直接保存加載模型
(1)保存和加載整個(gè)模型
# 保存模型 torch.save(model, 'model.pth\pkl\pt') #一般形式torch.save(net, PATH) # 加載模型 model = torch.load('model.pth\pkl\pt') #一般形式為model_dict=torch.load(PATH)
(2)僅保存和加載模型參數(shù)(推薦使用,需要提前手動(dòng)構(gòu)建模型)
速度快,占空間少
# 保存模型參數(shù) torch.save(model.state_dict(), 'model.pth\pkl\pt') #一般形式為torch.save(net.state_dict(),PATH) # 加載模型參數(shù) model.load_state_dict(torch.load('model.pth\pkl\pt') #一般形式為model_dict=model.load_state_dict(torch.load(PATH))
state_dict() 是一個(gè)Python字典,將每一層映射成它的參數(shù)張量。注意只有帶有可學(xué)習(xí)參數(shù)的層(卷積層、全連接層等),以及注冊(cè)的緩存(batchnorm的運(yùn)行平均值)在state_dict 中才有記錄。state_dict同樣包含優(yōu)化器對(duì)象,存儲(chǔ)了優(yōu)化器的狀態(tài),所使用到的超參數(shù)。
然而,在實(shí)驗(yàn)中往往需要保存更多的信息,比如優(yōu)化器的參數(shù),那么可以采取下面的方法保存:
torch.save({'epoch': epochID + 1, 'state_dict': model.state_dict(), 'best_loss': lossMIN, 'optimizer': optimizer.state_dict(),'alpha': loss.alpha, 'gamma': loss.gamma}, checkpoint_path + '/m-' + launchTimestamp + '-' + str("%.4f" % lossMIN) + '.pth.tar')
如下一個(gè)完整的使用model.state_dict()和optimizer.state_dict()例子:
# 定義模型 class TheModelClass(nn.Module): #定義一個(gè)神經(jīng)網(wǎng)絡(luò)模型 TheModelClass def __init__(self): super(TheModelClass, self).__init__() self.conv1 = nn.Conv2d(3, 6, 5) self.pool = nn.MaxPool2d(2, 2) self.conv2 = nn.Conv2d(6, 16, 5) self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): x = self.pool(F.relu(self.conv1(x))) x = self.pool(F.relu(self.conv2(x))) x = x.view(-1, 16 * 5 * 5) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x # 初始化模型 model = TheModelClass() # 初始化優(yōu)化器 optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9) # 打印模型的 state_dict print("Model's state_dict:") for param_tensor in model.state_dict(): # param_tensor 為參數(shù)名稱(chēng) print(param_tensor, "\t", model.state_dict()[param_tensor].size()) # 打印優(yōu)化器的 state_dict print("Optimizer's state_dict:") for var_name in optimizer.state_dict(): print(var_name, "\t", optimizer.state_dict()[var_name])
輸出結(jié)果:
Model's state_dict:
conv1.weight torch.Size([6, 3, 5, 5])
conv1.bias torch.Size([6])
conv2.weight torch.Size([16, 6, 5, 5])
conv2.bias torch.Size([16])
fc1.weight torch.Size([120, 400])
fc1.bias torch.Size([120])
fc2.weight torch.Size([84, 120])
fc2.bias torch.Size([84])
fc3.weight torch.Size([10, 84])
fc3.bias torch.Size([10])Optimizer's state_dict:
state {}
param_groups [{'lr': 0.001, 'momentum': 0.9, 'dampening': 0, 'weight_decay': 0, 'nesterov': False, 'params': [4675713712, 4675713784, 4675714000, 4675714072, 4675714216, 4675714288, 4675714432, 4675714504, 4675714648, 4675714720]}]
(3)load提供了很多重載的功能,其可以把在GPU上訓(xùn)練的權(quán)重加載到CPU上跑
torch.load('tensors.pt') # 強(qiáng)制所有GPU張量加載到CPU中 torch.load('tensors.pt', map_location=lambda storage, loc: storage) #或者model.load_state_dict(torch.load('model.pth', map_location='cpu')) # 把所有的張量加載到GPU 1中 torch.load('tensors.pt', map_location=lambda storage, loc: storage.cuda(1)) # 把張量從GPU 1 移動(dòng)到 GPU 0 torch.load('tensors.pt', map_location={'cuda:1':'cuda:0'})
上述代碼只有在模型在一個(gè)GPU上訓(xùn)練時(shí)才起作用。如果我在多個(gè)GPU上訓(xùn)練模型并且保存它,然后嘗試在CPU上加載,會(huì)得到錯(cuò)誤:KeyError: ‘unexpected key “module.conv1.weight” in state_dict' 如何解決?
因?yàn)榇藭r(shí)已經(jīng)使用模型保存了模型nn.DataParallel,該模型將模型存儲(chǔ)在該模型中module,而現(xiàn)在您正試圖加載模型DataParallel。您可以nn.DataParallel在網(wǎng)絡(luò)中暫時(shí)添加一個(gè)加載目的,也可以加載權(quán)重文件,創(chuàng)建一個(gè)沒(méi)有module前綴的新的有序字典,然后加載它??吹冢?)點(diǎn)
(4)通過(guò)DataParalle使用多GPU時(shí)的保存和加載
odel=DataParalle(model) #保存參數(shù) torch.save(model.module.state_dict(), 'model.pth')
由此看出多個(gè)GPU時(shí)多了一個(gè)該模型中module,加載再cpu時(shí),創(chuàng)建一個(gè)沒(méi)有module前綴的新的有序字典,然后加載它。
補(bǔ)充:一般來(lái)說(shuō),PyTorch的模型以.pt或者.pth文件格式保存。
2. 保存加載用于推理的常規(guī)Checkpoint/或繼續(xù)訓(xùn)練**
checkpoint檢查點(diǎn):不僅保存模型的參數(shù),優(yōu)化器參數(shù),還有l(wèi)oss,epoch等(相當(dāng)于一個(gè)保存模型的文件夾)
if (epoch+1) % checkpoint_interval == 0: checkpoint = {"model_state_dict": net.state_dict(), "optimizer_state_dict": optimizer.state_dict(), "epoch": epoch} path_checkpoint = "./checkpoint_{}_epoch.pkl".format(epoch) torch.save(checkpoint, path_checkpoint) #或者 #保存 torch.save({ 'epoch': epoch, 'model_state_dict': model.state_dict(), 'optimizer_state_dict': optimizer.state_dict(), 'loss': loss, ... }, PATH) #加載 model = TheModelClass(*args, **kwargs) optimizer = TheOptimizerClass(*args, **kwargs) checkpoint = torch.load(PATH) model.load_state_dict(checkpoint['model_state_dict']) optimizer.load_state_dict(checkpoint['optimizer_state_dict']) epoch = checkpoint['epoch'] loss = checkpoint['loss'] model.eval() # - 或者 - model.train()
注意:
在保存用于推理或者繼續(xù)訓(xùn)練的常規(guī)檢查點(diǎn)的時(shí)候,除了模型的state_dict之外,還必須保存其他參數(shù)。保存優(yōu)化器的state_dict也非常重要,因?yàn)樗四P驮谟?xùn)練時(shí)候優(yōu)化器的緩存和參數(shù)。除此之外,還可以保存停止訓(xùn)練時(shí)epoch數(shù),最新的模型損失,額外的torch.nn.Embedding層等。
要保存多個(gè)組件,則將它們放到一個(gè)字典中,然后使用torch.save()序列化這個(gè)字典。一般來(lái)說(shuō),使用.tar文件格式來(lái)保存這些檢查點(diǎn)。
加載各個(gè)組件,首先初始化模型和優(yōu)化器,然后使用torch.load()加載保存的字典,然后可以直接查詢字典中的值來(lái)獲取保存的組件。
同樣,評(píng)估模型的時(shí)候一定不要忘了調(diào)用model.eval()。
是不是很簡(jiǎn)單!!以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python實(shí)現(xiàn)猜數(shù)字小游戲
這篇文章介紹了Python實(shí)現(xiàn)猜數(shù)字小游戲,文中通過(guò)示例代碼介紹的非常詳細(xì)。對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以收藏下,方便下次瀏覽觀看2021-12-12Python3使用tracemalloc實(shí)現(xiàn)追蹤mmap內(nèi)存變化
這篇文章主要為大家詳細(xì)介紹了在Python3中如何使用tracemalloc實(shí)現(xiàn)追蹤mmap內(nèi)存變化,文中的示例代碼講解詳細(xì),感興趣的可以了解一下2023-03-03利用Python編寫(xiě)一個(gè)注冊(cè)機(jī)用于生成卡密
這篇文章主要為大家詳細(xì)介紹了如何利用Python編寫(xiě)一個(gè)注冊(cè)機(jī)用于生成卡密(兌換碼),并使用這些卡密登錄應(yīng)用程序,感興趣的小伙伴可以了解下2023-11-11pycharm-professional-2020.1下載與激活的教程
這篇文章主要介紹了pycharm-professional-2020.1下載與激活的教程,本文分為安裝和永久激活兩部分內(nèi)容,需要的朋友可以參考下2020-09-09如何利用python制作時(shí)間戳轉(zhuǎn)換工具詳解
這篇文章主要給大家介紹了關(guān)于如何利用python制作時(shí)間戳轉(zhuǎn)換工具的相關(guān)資料,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2018-09-09pytorch 實(shí)現(xiàn)L2和L1正則化regularization的操作
這篇文章主要介紹了pytorch 實(shí)現(xiàn)L2和L1正則化regularization的操作,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2021-03-03python計(jì)算導(dǎo)數(shù)并繪圖的實(shí)例
今天小編就為大家分享一篇python計(jì)算導(dǎo)數(shù)并繪圖的實(shí)例,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-02-02