pytorch如何使用訓(xùn)練好的模型預(yù)測新數(shù)據(jù)
pytorch使用訓(xùn)練好的模型預(yù)測新數(shù)據(jù)
神經(jīng)網(wǎng)絡(luò)在進(jìn)行完訓(xùn)練和測試后,如果達(dá)到了較高的正確率的話,我們可以嘗試將模型用于預(yù)測新數(shù)據(jù)。
總共需要兩大部分:神經(jīng)網(wǎng)絡(luò)、預(yù)測函數(shù)(新圖片的加載,傳入模型、得出結(jié)果)。
完整代碼
import torch, glob, cv2 from torchvision import transforms import numpy as np import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): # 神經(jīng)網(wǎng)絡(luò)部分用你自己的 def __init__(self): super(Net, self).__init__() self.conv1 = nn.Conv2d(1, 32, 3, 2, 1) # nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding) self.conv2 = nn.Conv2d(32, 64, 3, 2, 1) self.conv3 = nn.Conv2d(64, 128, 3, 1) self.dropout1 = nn.Dropout2d(0.25) self.dropout2 = nn.Dropout2d(0.5) self.fc1 = nn.Linear(6272, 128) # 6272=128*7*7 self.fc2 = nn.Linear(128, 8) def forward(self, x): x = self.conv1(x) x = F.relu(x) x = self.conv2(x) x = F.relu(x) x = self.conv3(x) x = F.relu(x) x = F.max_pool2d(x, 2) x = self.dropout1(x) x = torch.flatten(x, 1) x = self.fc1(x) x = F.relu(x) x = self.dropout2(x) x = self.fc2(x) self.output = F.log_softmax(x, dim=1) out1 = x return self.output,out1 def predict(): model = Net() model.load_state_dict(torch.load('test.pt')) torch.no_grad() imgfile = glob.glob(r"") # 輸入要預(yù)測的圖片所在路徑 print(len(imgfile), imgfile) for i in imgfile: imgfile1 = i.replace("\\", "/") img = cv2.imdecode(np.fromfile(imgfile1, dtype=np.uint8), cv2.IMREAD_GRAYSCALE) img = cv2.resize(img, (64, 64)) # 是否需要resize取決于新圖片格式與訓(xùn)練時(shí)的是否一致 tran = transforms.ToTensor() img = img.reshape((*img.shape, -1)) img = tran(img) img = img.unsqueeze(0) outputs, out1 = model(img) # outputs,out1修改為你的網(wǎng)絡(luò)的輸出 predicted, index = torch.max(out1, 1) degre = int(index[0]) list = [0, 45, -45, -90, 90, 135, -135, 180] print(predicted, list[degre]) if __name__ == '__main__': predict()
神經(jīng)網(wǎng)絡(luò)部分復(fù)制你在訓(xùn)練時(shí)定義的神經(jīng)網(wǎng)絡(luò)即可,如果模型保存為字典,則需要
model.load_state_dict(torch.load('test.pt'))
新圖片的格式需要與訓(xùn)練測試時(shí)的圖片格式保持一致,所以需要resize,如果新圖片為相同格式略過。
最后的list是你樣本類別的list,每一類的索引需要與label保持一致,例如:
list = ['褲子', '套衫', '連衣裙', '外套', '涼鞋', '襯衫', '運(yùn)動(dòng)鞋', '短靴']
結(jié)果分析
tensor([7.0595], grad_fn=<MaxBackward0>) 45
tensor([11.9538], grad_fn=<MaxBackward0>) -45
tensor([5.8450], grad_fn=<MaxBackward0>) 135
前面的張量tensor代表了各個(gè)類別的“概率”中最大的那一個(gè),然后根據(jù)最大“概率”所在的位置(index)來找到list所對應(yīng)的類別,然后輸出。
pytorch框架--簡單模型預(yù)測
模型預(yù)測示例
使用訓(xùn)練好的模型進(jìn)行預(yù)測
import torchvision from model import Tudui import torch from PIL import Image # 讀取圖像 img = Image.open("./data/train/Dog/9.jpg") # 數(shù)據(jù)預(yù)處理 # 縮放 transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)), torchvision.transforms.ToTensor()]) image = transform(img) print(image.shape) # 根據(jù)保存方式加載 model = torch.load("tudui_99.pth", map_location=torch.device('cpu')) # 注意維度轉(zhuǎn)換,單張圖片 image1 = torch.reshape(image, (1, 3, 32, 32)) # 測試開關(guān) model.eval() # 節(jié)約性能 with torch.no_grad(): output = model(image1) print(output) # print(output.argmax(1)) # 定義類別對應(yīng)字典 dist = {0: "飛機(jī)", 1: "汽車", 2: "鳥", 3: "貓", 4: "鹿", 5: "狗", 6: "青蛙", 7: "馬", 8: "船", 9: "卡車"} # 轉(zhuǎn)numpy格式,列表內(nèi)取第一個(gè) a = dist[output.argmax(1).numpy()[0]] img.show() print(a)
總結(jié)
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python pip install如何修改默認(rèn)下載路徑
這篇文章主要介紹了Python pip install如何修改默認(rèn)下載路徑,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-04-04Python?Melt函數(shù)將寬格式的數(shù)據(jù)表轉(zhuǎn)換為長格式
在數(shù)據(jù)處理和清洗中,melt函數(shù)是Pandas庫中一個(gè)強(qiáng)大而靈活的工具,它的主要功能是將寬格式的數(shù)據(jù)表轉(zhuǎn)換為長格式,從而更方便進(jìn)行分析和可視化,本文將深入探討melt函數(shù)的用法、參數(shù)解析以及實(shí)際應(yīng)用場景2023-12-12linux平臺(tái)使用Python制作BT種子并獲取BT種子信息的方法
這篇文章主要介紹了linux平臺(tái)使用Python制作BT種子并獲取BT種子信息的方法,結(jié)合實(shí)例形式詳細(xì)分析了Python BT模塊的安裝及針對BT種子文件的相關(guān)操作技巧,需要的朋友可以參考下2017-01-01