pytorch加載訓(xùn)練好的模型用來(lái)測(cè)試或者處理方式
1.直接加載預(yù)訓(xùn)練模型
如果我們使用的模型和原模型完全一樣,
那么我們可以直接加載別人訓(xùn)練好的模型:
import torchvision.models as models resnet50 = models.resnet50(pretrained=True)
如果只需要網(wǎng)絡(luò)結(jié)構(gòu),不需要用預(yù)訓(xùn)練模型的參數(shù)來(lái)初始化,
那么就是:
model =torchvision.models.resnet50(pretrained=False)
2.修改某一層
PyTorch中的torchvision里已經(jīng)有很多常用的模型了,
可以直接調(diào)用:
- AlexNet
- VGG
- ResNet
- SqueezeNet
- DenseNet
import torchvision.models as models resnet18 = models.resnet18() alexnet = models.alexnet() squeezenet = models.squeezenet1_0() densenet = models.densenet_161()
但是對(duì)于我們的任務(wù)而言有些層并不是直接能用,需要我們微微改一下,
比如,resnet最后的全連接層是分1000類,而我們只有21類;
又比如,resnet第一層卷積接收的通道是3, 我們可能輸入圖片的通道是4,
那么可以通過(guò)以下方法修改:
resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False) resnet.fc = nn.Linear(2048, 21)
3.加載部分預(yù)訓(xùn)練模型
其實(shí)大多數(shù)時(shí)候我們需要根據(jù)我們的任務(wù)調(diào)節(jié)我們的模型,所以很難保證模型和公開(kāi)的模型完全一樣,但是預(yù)訓(xùn)練模型的參數(shù)確實(shí)有助于提高訓(xùn)練的準(zhǔn)確率,為了結(jié)合二者的優(yōu)點(diǎn),就需要我們加載部分預(yù)訓(xùn)練模型。
#加載model,model是自己定義好的模型 resnet50 = models.resnet50(pretrained=True) model =Net(...) #讀取參數(shù) pretrained_dict =resnet50.state_dict() model_dict = model.state_dict() #將pretrained_dict里不屬于model_dict的鍵剔除掉 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 更新現(xiàn)有的model_dict model_dict.update(pretrained_dict) # 加載我們真正需要的state_dict model.load_state_dict(model_dict) # 加載我們真正需要的state_dict model.load_state_dict(model_dict)
4. 保存和加載自己的模型
pytorch保存模型的方式有兩種:
- 第一種:將整個(gè)網(wǎng)絡(luò)都都保存下來(lái)
- 第二種:僅保存和加載模型參數(shù)(推薦使用這樣的方法)
4.1 保存和加載整個(gè)模型
# 保存 torch.save(model_object, Path) # 加載 model = torch.load(Path)
4.2 僅保存和加載模型參數(shù)(推薦使用)
# ----------------保存模型參數(shù)-------------------------- torch.save(model.state_dict(), PATH) #example torch.save(resnet50.state_dict(),'ckp/model.pth') # ----------------加載模型參數(shù)-------------------------- model = ModelClass(*args, **kwargs) # 這是你后來(lái)設(shè)置的模型 model.load_state_dict(torch.load(PATH)) # 加載參數(shù) #example resnet=resnet50(pretrained=True) resnet.load_state_dict(torch.load('ckp/model.pth'))
4.3 每個(gè)epoch保存一個(gè)模型參數(shù)
for epoch in range(start_epoch, nEpochs + 1): train(training_data_loader, optimizer, model, criterion, epoch) save_checkpoint(model, epoch) def save_checkpoint(model, epoch): model_out_path = "checkpoint/" + "model_epoch_{}.pth".format(epoch) state = {"epoch": epoch ,"model": model} if not os.path.exists("checkpoint/"): os.makedirs("checkpoint/") torch.save(state, model_out_path) print("Checkpoint saved to {}".format(model_out_path))
上面的代碼中start_epoch是開(kāi)始保存模型的epoch,nEpochs是總共訓(xùn)練的次數(shù)。
train()里面的參數(shù),是訓(xùn)練的過(guò)程:一些訓(xùn)練數(shù)據(jù),優(yōu)化器,模型和訓(xùn)練標(biāo)準(zhǔn)。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Django cookie和session的應(yīng)用場(chǎng)景及如何使用
今天我們來(lái)重點(diǎn)看下Django中session和cookie的用法吧。我們會(huì)介紹cookie和session的工作原理,還會(huì)分享實(shí)際應(yīng)用的案例。2021-04-04Python進(jìn)階教程之創(chuàng)建本地PyPI倉(cāng)庫(kù)
pypi是一個(gè)python包的倉(cāng)庫(kù),里面有很多別人寫好的python庫(kù),你可以通過(guò)easy_install或者pip進(jìn)行安裝,下面這篇文章主要給大家介紹了關(guān)于Python進(jìn)階教程之創(chuàng)建本地PyPI倉(cāng)庫(kù)的相關(guān)資料,需要的朋友可以參考下2021-10-10Python標(biāo)準(zhǔn)庫(kù)re的使用舉例(正則化匹配)
正則表達(dá)式re是內(nèi)置函數(shù),通過(guò)一定的匹配規(guī)則獲取指定的數(shù)據(jù),下面這篇文章主要給大家介紹了關(guān)于Python標(biāo)準(zhǔn)庫(kù)re的使用舉例,文中通過(guò)實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-10-10Python實(shí)現(xiàn)解析與生成JSON數(shù)據(jù)
JSON文件是一種輕量級(jí)的數(shù)據(jù)交換格式,它采用了一種類似于JavaScript語(yǔ)法的結(jié)構(gòu),可以方便地在不同平臺(tái)和編程語(yǔ)言之間進(jìn)行數(shù)據(jù)交換,下面我們就來(lái)學(xué)習(xí)一下Python如何使用內(nèi)置的json模塊來(lái)讀取和寫入JSON文件吧2023-12-12TensorFlow打印tensor值的實(shí)現(xiàn)方法
今天小編就為大家分享一篇TensorFlow打印tensor值的實(shí)現(xiàn)方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2018-07-07Keras目標(biāo)檢測(cè)mtcnn?facenet搭建人臉識(shí)別平臺(tái)
這篇文章主要為大家介紹了Keras目標(biāo)檢測(cè)mtcnn?facenet搭建人臉識(shí)別平臺(tái),有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-05-05如何用PyMongo在Python中操作MongoDB的超完整指南
本文詳細(xì)介紹了如何使用Python的PyMongo庫(kù)操作MongoDB,涵蓋了數(shù)據(jù)庫(kù)連接、文檔創(chuàng)建、數(shù)據(jù)操作和高級(jí)功能的使用,通過(guò)這些知識(shí)點(diǎn),開(kāi)發(fā)者可以高效地管理和操作MongoDB數(shù)據(jù)庫(kù),需要的朋友可以參考下2024-11-11python讀出當(dāng)前時(shí)間精度到秒的代碼
在本文里小編給各位分享了一篇關(guān)于python怎么讀出當(dāng)前時(shí)間精度到秒的內(nèi)容,對(duì)此有需要的朋友們可以學(xué)習(xí)參考下。2019-07-07Python實(shí)現(xiàn)對(duì)特定列表進(jìn)行從小到大排序操作示例
這篇文章主要介紹了Python實(shí)現(xiàn)對(duì)特定列表進(jìn)行從小到大排序操作,涉及Python文件讀取、計(jì)算、正則匹配、排序等相關(guān)操作技巧,需要的朋友可以參考下2019-02-02