欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch如何使用訓(xùn)練好的模型預(yù)測新數(shù)據(jù)

 更新時(shí)間:2023年06月15日 09:03:45   作者:Xiuxiu_Law  
這篇文章主要介紹了pytorch如何使用訓(xùn)練好的模型預(yù)測新數(shù)據(jù)問題,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

pytorch使用訓(xùn)練好的模型預(yù)測新數(shù)據(jù)

神經(jīng)網(wǎng)絡(luò)在進(jìn)行完訓(xùn)練和測試后,如果達(dá)到了較高的正確率的話,我們可以嘗試將模型用于預(yù)測新數(shù)據(jù)。

總共需要兩大部分:神經(jīng)網(wǎng)絡(luò)、預(yù)測函數(shù)(新圖片的加載,傳入模型、得出結(jié)果)。

完整代碼

import torch, glob, cv2
from torchvision import transforms
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):  # 神經(jīng)網(wǎng)絡(luò)部分用你自己的
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(1, 32, 3, 2, 1)  # nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding)
        self.conv2 = nn.Conv2d(32, 64, 3, 2, 1)
        self.conv3 = nn.Conv2d(64, 128, 3, 1)
        self.dropout1 = nn.Dropout2d(0.25)
        self.dropout2 = nn.Dropout2d(0.5)
        self.fc1 = nn.Linear(6272, 128)  # 6272=128*7*7
        self.fc2 = nn.Linear(128, 8)
    def forward(self, x):
        x = self.conv1(x)
        x = F.relu(x)
        x = self.conv2(x)
        x = F.relu(x)
        x = self.conv3(x)
        x = F.relu(x)
        x = F.max_pool2d(x, 2)
        x = self.dropout1(x)
        x = torch.flatten(x, 1)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.dropout2(x)
        x = self.fc2(x)
        self.output = F.log_softmax(x, dim=1)
        out1 = x
        return self.output,out1
def predict():
    model = Net()
    model.load_state_dict(torch.load('test.pt'))
    torch.no_grad()
    imgfile = glob.glob(r"")  # 輸入要預(yù)測的圖片所在路徑
    print(len(imgfile), imgfile)
    for i in imgfile:
        imgfile1 = i.replace("\\", "/")
        img = cv2.imdecode(np.fromfile(imgfile1, dtype=np.uint8), cv2.IMREAD_GRAYSCALE)
        img = cv2.resize(img, (64, 64))  # 是否需要resize取決于新圖片格式與訓(xùn)練時(shí)的是否一致
        tran = transforms.ToTensor()
        img = img.reshape((*img.shape, -1))
        img = tran(img)
        img = img.unsqueeze(0)
        outputs, out1 = model(img)  # outputs,out1修改為你的網(wǎng)絡(luò)的輸出
        predicted, index  = torch.max(out1, 1)
        degre = int(index[0])
        list = [0, 45, -45, -90, 90, 135, -135, 180]
        print(predicted, list[degre])
if __name__ == '__main__':
    predict()

神經(jīng)網(wǎng)絡(luò)部分復(fù)制你在訓(xùn)練時(shí)定義的神經(jīng)網(wǎng)絡(luò)即可,如果模型保存為字典,則需要

model.load_state_dict(torch.load('test.pt'))

新圖片的格式需要與訓(xùn)練測試時(shí)的圖片格式保持一致,所以需要resize,如果新圖片為相同格式略過。

最后的list是你樣本類別的list,每一類的索引需要與label保持一致,例如:

list = ['褲子', '套衫', '連衣裙', '外套', '涼鞋', '襯衫', '運(yùn)動(dòng)鞋', '短靴']

結(jié)果分析

tensor([7.0595], grad_fn=<MaxBackward0>) 45
tensor([11.9538], grad_fn=<MaxBackward0>) -45
tensor([5.8450], grad_fn=<MaxBackward0>) 135

前面的張量tensor代表了各個(gè)類別的“概率”中最大的那一個(gè),然后根據(jù)最大“概率”所在的位置(index)來找到list所對應(yīng)的類別,然后輸出。

pytorch框架--簡單模型預(yù)測

模型預(yù)測示例

使用訓(xùn)練好的模型進(jìn)行預(yù)測

import torchvision
from model import Tudui
import torch
from PIL import Image
# 讀取圖像
img = Image.open("./data/train/Dog/9.jpg")
# 數(shù)據(jù)預(yù)處理
# 縮放
transform = torchvision.transforms.Compose([torchvision.transforms.Resize((32, 32)),
                                            torchvision.transforms.ToTensor()])
image = transform(img)
print(image.shape)
# 根據(jù)保存方式加載
model = torch.load("tudui_99.pth", map_location=torch.device('cpu'))
# 注意維度轉(zhuǎn)換,單張圖片
image1 = torch.reshape(image, (1, 3, 32, 32))
# 測試開關(guān)
model.eval()
# 節(jié)約性能
with torch.no_grad():
    output = model(image1)
print(output)
# print(output.argmax(1))
# 定義類別對應(yīng)字典
dist = {0: "飛機(jī)", 1: "汽車", 2: "鳥", 3: "貓", 4: "鹿", 5: "狗", 6: "青蛙", 7: "馬", 8: "船", 9: "卡車"}
# 轉(zhuǎn)numpy格式,列表內(nèi)取第一個(gè)
a = dist[output.argmax(1).numpy()[0]]
img.show()
print(a)

總結(jié)

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python pip install如何修改默認(rèn)下載路徑

    Python pip install如何修改默認(rèn)下載路徑

    這篇文章主要介紹了Python pip install如何修改默認(rèn)下載路徑,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2020-04-04
  • 樹莓派升級python的具體步驟

    樹莓派升級python的具體步驟

    在本篇文章里小編給大家整理的是關(guān)于樹莓派升級python的具體步驟,需要的朋友們可以參考下。
    2020-07-07
  • django 使用全局搜索功能的實(shí)例詳解

    django 使用全局搜索功能的實(shí)例詳解

    今天小編就為大家分享一篇django 使用全局搜索功能的實(shí)例詳解,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-07-07
  • Python?Melt函數(shù)將寬格式的數(shù)據(jù)表轉(zhuǎn)換為長格式

    Python?Melt函數(shù)將寬格式的數(shù)據(jù)表轉(zhuǎn)換為長格式

    在數(shù)據(jù)處理和清洗中,melt函數(shù)是Pandas庫中一個(gè)強(qiáng)大而靈活的工具,它的主要功能是將寬格式的數(shù)據(jù)表轉(zhuǎn)換為長格式,從而更方便進(jìn)行分析和可視化,本文將深入探討melt函數(shù)的用法、參數(shù)解析以及實(shí)際應(yīng)用場景
    2023-12-12
  • Python heapq使用詳解及實(shí)例代碼

    Python heapq使用詳解及實(shí)例代碼

    這篇文章主要介紹了Python heapq使用詳解及實(shí)例代碼的相關(guān)資料,需要的朋友可以參考下
    2017-01-01
  • linux平臺(tái)使用Python制作BT種子并獲取BT種子信息的方法

    linux平臺(tái)使用Python制作BT種子并獲取BT種子信息的方法

    這篇文章主要介紹了linux平臺(tái)使用Python制作BT種子并獲取BT種子信息的方法,結(jié)合實(shí)例形式詳細(xì)分析了Python BT模塊的安裝及針對BT種子文件的相關(guān)操作技巧,需要的朋友可以參考下
    2017-01-01
  • 利用python繪制正態(tài)分布曲線

    利用python繪制正態(tài)分布曲線

    這篇文章主要介紹了如何利用python繪制正態(tài)分布曲線,幫助大家更好的利用python進(jìn)行數(shù)據(jù)分析,感興趣的朋友可以了解下
    2021-01-01
  • Python做屏幕錄制工具的實(shí)現(xiàn)示例

    Python做屏幕錄制工具的實(shí)現(xiàn)示例

    本文主要介紹了Python做屏幕錄制工具的實(shí)現(xiàn)示例,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2022-06-06
  • python和shell變量互相傳遞的幾種方法

    python和shell變量互相傳遞的幾種方法

    這篇文章主要介紹了python和shell變量互相傳遞方法,使用了環(huán)境變量、管道等方法
    2013-11-11
  • Python中生成ndarray實(shí)例講解

    Python中生成ndarray實(shí)例講解

    在本篇文章里小編給大家整理的是一篇關(guān)于Python中生成ndarray實(shí)例講解內(nèi)容,有興趣的朋友們可以學(xué)習(xí)參考下。
    2021-02-02

最新評論