Pytorch之保存讀取模型實(shí)例
pytorch保存數(shù)據(jù)
pytorch保存數(shù)據(jù)的格式為.t7文件或者.pth文件,t7文件是沿用torch7中讀取模型權(quán)重的方式。而pth文件是python中存儲(chǔ)文件的常用格式。而在keras中則是使用.h5文件。
# 保存模型示例代碼 print('===> Saving models...') state = { 'state': model.state_dict(), 'epoch': epoch # 將epoch一并保存 } if not os.path.isdir('checkpoint'): os.mkdir('checkpoint') torch.save(state, './checkpoint/autoencoder.t7')
保存用到torch.save函數(shù),注意該函數(shù)第一個(gè)參數(shù)可以是單個(gè)值也可以是字典,字典可以存更多你要保存的參數(shù)(不僅僅是權(quán)重?cái)?shù)據(jù))。
pytorch讀取數(shù)據(jù)
pytorch讀取數(shù)據(jù)使用的方法和我們平時(shí)使用預(yù)訓(xùn)練參數(shù)所用的方法是一樣的,都是使用load_state_dict這個(gè)函數(shù)。
下方的代碼和上方的保存代碼可以搭配使用。
print('===> Try resume from checkpoint') if os.path.isdir('checkpoint'): try: checkpoint = torch.load('./checkpoint/autoencoder.t7') model.load_state_dict(checkpoint['state']) # 從字典中依次讀取 start_epoch = checkpoint['epoch'] print('===> Load last checkpoint data') except FileNotFoundError: print('Can\'t found autoencoder.t7') else: start_epoch = 0 print('===> Start from scratch')
以上是pytorch讀取的方法匯總,但是要注意,在使用官方的預(yù)處理模型進(jìn)行讀取時(shí),一般使用的格式是pth,使用官方的模型讀取命令會(huì)檢查你模型的格式是否正確,如果不是使用官方提供模型通過下面的函數(shù)強(qiáng)行讀取模型(將其他模型例如caffe模型轉(zhuǎn)過來的模型放到指定目錄下)會(huì)發(fā)生錯(cuò)誤。
def vgg19(pretrained=False, **kwargs): """VGG 19-layer model (configuration "E") Args: pretrained (bool): If True, returns a model pre-trained on ImageNet """ model = VGG(make_layers(cfg['E']), **kwargs) if pretrained: model.load_state_dict(model_zoo.load_url(model_urls['vgg19'])) return model
假如我們有從caffe模型轉(zhuǎn)過來的pytorch模型([0-255,BGR]),我們可以使用:
model_dir = '自己的模型地址' model = VGG() model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))
也就是pytorch的讀取函數(shù)進(jìn)行讀取即可。
以上這篇Pytorch之保存讀取模型實(shí)例就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
使用Python獲取網(wǎng)段IP個(gè)數(shù)以及地址清單的方法
今天小編就為大家分享一篇使用Python獲取網(wǎng)段IP個(gè)數(shù)以及地址清單的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-11-11Python實(shí)現(xiàn)給qq郵箱發(fā)送郵件的方法
這篇文章主要介紹了Python實(shí)現(xiàn)給qq郵箱發(fā)送郵件的方法,涉及Python郵件發(fā)送的相關(guān)技巧,需要的朋友可以參考下2015-05-05python通過urllib2爬網(wǎng)頁上種子下載示例
這篇文章主要介紹了通過urllib2、re模塊抓種子下載的示例,需要的朋友可以參考下2014-02-02Python安裝Scrapy庫(kù)的常見報(bào)錯(cuò)解決
本文主要介紹了Python安裝Scrapy庫(kù)的常見報(bào)錯(cuò)解決,文中通過圖文示例介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-11-11OpenCV物體跟蹤樹莓派視覺小車實(shí)現(xiàn)過程學(xué)習(xí)
這篇文章主要介紹了OpenCV物體跟蹤樹莓派視覺小車的實(shí)現(xiàn)過程學(xué)習(xí),有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步2021-10-10python中的import、from import及import as的區(qū)別解析
在Python中,如果import的語句比較長(zhǎng),導(dǎo)致后續(xù)引用不方便,可以使用as語法,這篇文章主要介紹了python中的import、from import以及import as的區(qū)別,需要的朋友可以參考下2022-10-10Python實(shí)現(xiàn)方便使用的級(jí)聯(lián)進(jìn)度信息實(shí)例
這篇文章主要介紹了Python實(shí)現(xiàn)方便使用的級(jí)聯(lián)進(jìn)度信息,實(shí)例分析了Python顯示級(jí)聯(lián)進(jìn)度信息的相關(guān)技巧,非常具有實(shí)用價(jià)值,需要的朋友可以參考下2015-05-05