pytorch如何使用訓(xùn)練好的模型預(yù)測新數(shù)據(jù)
pytorch使用訓(xùn)練好的模型預(yù)測新數(shù)據(jù)
神經(jīng)網(wǎng)絡(luò)在進(jìn)行完訓(xùn)練和測試后,如果達(dá)到了較高的正確率的話,我們可以嘗試將模型用于預(yù)測新數(shù)據(jù)。
總共需要兩大部分:神經(jīng)網(wǎng)絡(luò)、預(yù)測函數(shù)(新圖片的加載,傳入模型、得出結(jié)果)。
完整代碼
import torch, glob, cv2
from torchvision import transforms
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module): # 神經(jīng)網(wǎng)絡(luò)部分用你自己的
def __init__(self):
super(Net, self).__init__()
self.conv1 = nn.Conv2d(1, 32, 3, 2, 1) # nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
self.conv2 = nn.Conv2d(32, 64, 3, 2, 1)
self.conv3 = nn.Conv2d(64, 128, 3, 1)
self.dropout1 = nn.Dropout2d(0.25)
self.dropout2 = nn.Dropout2d(0.5)
self.fc1 = nn.Linear(6272, 128) # 6272=128*7*7
self.fc2 = nn.Linear(128, 8)
def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.conv2(x)
x = F.relu(x)
x = self.conv3(x)
x = F.relu(x)
x = F.max_pool2d(x, 2)
x = self.dropout1(x)
x = torch.flatten(x, 1)
x = self.fc1(x)
x = F.relu(x)
x = self.dropout2(x)
x = self.fc2(x)
self.output = F.log_softmax(x, dim=1)
out1 = x
return self.output,out1
def predict():
model = Net()
model.load_state_dict(torch.load('test.pt'))
torch.no_grad()
imgfile = glob.glob(r"") # 輸入要預(yù)測的圖片所在路徑
print(len(imgfile), imgfile)
for i in imgfile:
imgfile1 = i.replace("\\", "/")
img = cv2.imdecode(np.fromfile(imgfile1, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, (64, 64)) # 是否需要resize取決于新圖片格式與訓(xùn)練時的是否一致
tran = transforms.ToTensor()
img = img.reshape((*img.shape, -1))
img = tran(img)
img = img.unsqueeze(0)
outputs, out1 = model(img) # outputs,out1修改為你的網(wǎng)絡(luò)的輸出
predicted, index = torch.max(out1, 1)
degre = int(index[0])
list = [0, 45, -45, -90, 90, 135, -135, 180]
print(predicted, list[degre])
if __name__ == '__main__':
predict()神經(jīng)網(wǎng)絡(luò)部分復(fù)制你在訓(xùn)練時定義的神經(jīng)網(wǎng)絡(luò)即可,如果模型保存為字典,則需要
model.load_state_dict(torch.load('test.pt'))新圖片的格式需要與訓(xùn)練測試時的圖片格式保持一致,所以需要resize,如果新圖片為相同格式略過。
最后的list是你樣本類別的list,每一類的索引需要與label保持一致,例如:
list = ['褲子', '套衫', '連衣裙', '外套', '涼鞋', '襯衫', '運(yùn)動鞋', '短靴']
結(jié)果分析
tensor([7.0595], grad_fn=<MaxBackward0>) 45
tensor([11.9538], grad_fn=<MaxBackward0>) -45
tensor([5.8450], grad_fn=<MaxBackward0>) 135
前面的張量tensor代表了各個類別的“概率”中最大的那一個,然后根據(jù)最大“概率”所在的位置(index)來找到list所對應(yīng)的類別,然后輸出。
pytorch框架--簡單模型預(yù)測
模型預(yù)測示例
使用訓(xùn)練好的模型進(jìn)行預(yù)測
import torchvision
from model import Tudui
import torch
from PIL import Image
# 讀取圖像
img = Image.open("./data/train/Dog/9.jpg")
# 數(shù)據(jù)預(yù)處理
# 縮放
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
torchvision.transforms.ToTensor()])
image = transform(img)
print(image.shape)
# 根據(jù)保存方式加載
model = torch.load("tudui_99.pth", map_location=torch.device('cpu'))
# 注意維度轉(zhuǎn)換,單張圖片
image1 = torch.reshape(image, (1, 3, 32, 32))
# 測試開關(guān)
model.eval()
# 節(jié)約性能
with torch.no_grad():
output = model(image1)
print(output)
# print(output.argmax(1))
# 定義類別對應(yīng)字典
dist = {0: "飛機(jī)", 1: "汽車", 2: "鳥", 3: "貓", 4: "鹿", 5: "狗", 6: "青蛙", 7: "馬", 8: "船", 9: "卡車"}
# 轉(zhuǎn)numpy格式,列表內(nèi)取第一個
a = dist[output.argmax(1).numpy()[0]]
img.show()
print(a)總結(jié)
以上為個人經(jīng)驗(yàn),希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
Python pip install如何修改默認(rèn)下載路徑
這篇文章主要介紹了Python pip install如何修改默認(rèn)下載路徑,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2020-04-04
Python?Melt函數(shù)將寬格式的數(shù)據(jù)表轉(zhuǎn)換為長格式
在數(shù)據(jù)處理和清洗中,melt函數(shù)是Pandas庫中一個強(qiáng)大而靈活的工具,它的主要功能是將寬格式的數(shù)據(jù)表轉(zhuǎn)換為長格式,從而更方便進(jìn)行分析和可視化,本文將深入探討melt函數(shù)的用法、參數(shù)解析以及實(shí)際應(yīng)用場景2023-12-12
linux平臺使用Python制作BT種子并獲取BT種子信息的方法
這篇文章主要介紹了linux平臺使用Python制作BT種子并獲取BT種子信息的方法,結(jié)合實(shí)例形式詳細(xì)分析了Python BT模塊的安裝及針對BT種子文件的相關(guān)操作技巧,需要的朋友可以參考下2017-01-01

