pytorch加載訓(xùn)練好的模型用來測試或者處理方式
1.直接加載預(yù)訓(xùn)練模型
如果我們使用的模型和原模型完全一樣,
那么我們可以直接加載別人訓(xùn)練好的模型:
import torchvision.models as models resnet50 = models.resnet50(pretrained=True)
如果只需要網(wǎng)絡(luò)結(jié)構(gòu),不需要用預(yù)訓(xùn)練模型的參數(shù)來初始化,
那么就是:
model =torchvision.models.resnet50(pretrained=False)
2.修改某一層
PyTorch中的torchvision里已經(jīng)有很多常用的模型了,
可以直接調(diào)用:
- AlexNet
- VGG
- ResNet
- SqueezeNet
- DenseNet
import torchvision.models as models resnet18 = models.resnet18() alexnet = models.alexnet() squeezenet = models.squeezenet1_0() densenet = models.densenet_161()
但是對于我們的任務(wù)而言有些層并不是直接能用,需要我們微微改一下,
比如,resnet最后的全連接層是分1000類,而我們只有21類;
又比如,resnet第一層卷積接收的通道是3, 我們可能輸入圖片的通道是4,
那么可以通過以下方法修改:
resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False) resnet.fc = nn.Linear(2048, 21)
3.加載部分預(yù)訓(xùn)練模型
其實(shí)大多數(shù)時(shí)候我們需要根據(jù)我們的任務(wù)調(diào)節(jié)我們的模型,所以很難保證模型和公開的模型完全一樣,但是預(yù)訓(xùn)練模型的參數(shù)確實(shí)有助于提高訓(xùn)練的準(zhǔn)確率,為了結(jié)合二者的優(yōu)點(diǎn),就需要我們加載部分預(yù)訓(xùn)練模型。
#加載model,model是自己定義好的模型 resnet50 = models.resnet50(pretrained=True) model =Net(...) #讀取參數(shù) pretrained_dict =resnet50.state_dict() model_dict = model.state_dict() #將pretrained_dict里不屬于model_dict的鍵剔除掉 pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict} # 更新現(xiàn)有的model_dict model_dict.update(pretrained_dict) # 加載我們真正需要的state_dict model.load_state_dict(model_dict) # 加載我們真正需要的state_dict model.load_state_dict(model_dict)
4. 保存和加載自己的模型
pytorch保存模型的方式有兩種:
- 第一種:將整個(gè)網(wǎng)絡(luò)都都保存下來
- 第二種:僅保存和加載模型參數(shù)(推薦使用這樣的方法)
4.1 保存和加載整個(gè)模型
# 保存 torch.save(model_object, Path) # 加載 model = torch.load(Path)
4.2 僅保存和加載模型參數(shù)(推薦使用)
# ----------------保存模型參數(shù)-------------------------- torch.save(model.state_dict(), PATH) #example torch.save(resnet50.state_dict(),'ckp/model.pth') # ----------------加載模型參數(shù)-------------------------- model = ModelClass(*args, **kwargs) # 這是你后來設(shè)置的模型 model.load_state_dict(torch.load(PATH)) # 加載參數(shù) #example resnet=resnet50(pretrained=True) resnet.load_state_dict(torch.load('ckp/model.pth'))
4.3 每個(gè)epoch保存一個(gè)模型參數(shù)
for epoch in range(start_epoch, nEpochs + 1): train(training_data_loader, optimizer, model, criterion, epoch) save_checkpoint(model, epoch) def save_checkpoint(model, epoch): model_out_path = "checkpoint/" + "model_epoch_{}.pth".format(epoch) state = {"epoch": epoch ,"model": model} if not os.path.exists("checkpoint/"): os.makedirs("checkpoint/") torch.save(state, model_out_path) print("Checkpoint saved to {}".format(model_out_path))
上面的代碼中start_epoch是開始保存模型的epoch,nEpochs是總共訓(xùn)練的次數(shù)。
train()里面的參數(shù),是訓(xùn)練的過程:一些訓(xùn)練數(shù)據(jù),優(yōu)化器,模型和訓(xùn)練標(biāo)準(zhǔn)。
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Django cookie和session的應(yīng)用場景及如何使用
今天我們來重點(diǎn)看下Django中session和cookie的用法吧。我們會介紹cookie和session的工作原理,還會分享實(shí)際應(yīng)用的案例。2021-04-04Python進(jìn)階教程之創(chuàng)建本地PyPI倉庫
pypi是一個(gè)python包的倉庫,里面有很多別人寫好的python庫,你可以通過easy_install或者pip進(jìn)行安裝,下面這篇文章主要給大家介紹了關(guān)于Python進(jìn)階教程之創(chuàng)建本地PyPI倉庫的相關(guān)資料,需要的朋友可以參考下2021-10-10Python標(biāo)準(zhǔn)庫re的使用舉例(正則化匹配)
正則表達(dá)式re是內(nèi)置函數(shù),通過一定的匹配規(guī)則獲取指定的數(shù)據(jù),下面這篇文章主要給大家介紹了關(guān)于Python標(biāo)準(zhǔn)庫re的使用舉例,文中通過實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-10-10Python實(shí)現(xiàn)解析與生成JSON數(shù)據(jù)
JSON文件是一種輕量級的數(shù)據(jù)交換格式,它采用了一種類似于JavaScript語法的結(jié)構(gòu),可以方便地在不同平臺和編程語言之間進(jìn)行數(shù)據(jù)交換,下面我們就來學(xué)習(xí)一下Python如何使用內(nèi)置的json模塊來讀取和寫入JSON文件吧2023-12-12TensorFlow打印tensor值的實(shí)現(xiàn)方法
今天小編就為大家分享一篇TensorFlow打印tensor值的實(shí)現(xiàn)方法,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-07-07Keras目標(biāo)檢測mtcnn?facenet搭建人臉識別平臺
這篇文章主要為大家介紹了Keras目標(biāo)檢測mtcnn?facenet搭建人臉識別平臺,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-05-05如何用PyMongo在Python中操作MongoDB的超完整指南
本文詳細(xì)介紹了如何使用Python的PyMongo庫操作MongoDB,涵蓋了數(shù)據(jù)庫連接、文檔創(chuàng)建、數(shù)據(jù)操作和高級功能的使用,通過這些知識點(diǎn),開發(fā)者可以高效地管理和操作MongoDB數(shù)據(jù)庫,需要的朋友可以參考下2024-11-11python讀出當(dāng)前時(shí)間精度到秒的代碼
在本文里小編給各位分享了一篇關(guān)于python怎么讀出當(dāng)前時(shí)間精度到秒的內(nèi)容,對此有需要的朋友們可以學(xué)習(xí)參考下。2019-07-07Python實(shí)現(xiàn)對特定列表進(jìn)行從小到大排序操作示例
這篇文章主要介紹了Python實(shí)現(xiàn)對特定列表進(jìn)行從小到大排序操作,涉及Python文件讀取、計(jì)算、正則匹配、排序等相關(guān)操作技巧,需要的朋友可以參考下2019-02-02