Image Recognition with PyTorch (Hands-On)
1. Code walkthrough
1.1 Importing libraries
import os.path
from os import listdir

import numpy as np
import pandas as pd
from PIL import Image

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import AdaptiveAvgPool2d
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset
import torchvision.transforms as transforms

from sklearn.model_selection import train_test_split
1.2 Normalization, transform, and device setup
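# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Per-channel mean and standard deviation (the standard ImageNet statistics)
normalize = transforms.Normalize(
    mean=[0.485, 0.456, 0.406],
    std=[0.229, 0.224, 0.225]
)

# Convert a PIL image to a tensor with values in [0, 1], then normalize it
transform = transforms.Compose([transforms.ToTensor(), normalize])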
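To make the effect of `transforms.Normalize` concrete, here is a minimal plain-Python sketch of the per-channel arithmetic it applies once `ToTensor` has scaled the pixel values to [0, 1]. The helper name `normalize_pixel` is hypothetical and exists only for illustration; the mean and standard deviation values are the ones used above.

```python
# Per-channel statistics copied from the transform above
MEAN = [0.485, 0.456, 0.406]
STD = [0.229, 0.224, 0.225]


def normalize_pixel(rgb):
    """Apply (value - mean) / std to one RGB pixel whose channels lie in [0, 1].

    Hypothetical helper: it mirrors the arithmetic of transforms.Normalize
    for a single pixel rather than for a whole image tensor.
    """
    return [(value - m) / s for value, m, s in zip(rgb, MEAN, STD)]


# A pixel equal to the channel means maps to zero in every channel
print(normalize_pixel([0.485, 0.456, 0.406]))

# A pure white pixel ends up between 2 and 3 in every channel
print(normalize_pixel([1.0, 1.0, 1.0]))
```

After this step the network sees inputs centered near zero instead of raw values in [0, 1], which is the usual motivation for normalizing before training.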
1.3 Preprocessing the data
class DogDataset(Dataset):
    # Store the image paths, their labels, and the target image size
    def __init__(self, img_paths, img_labels, size_of_images):
        self.img_paths = img_paths
        self.img_labels = img_labels
        self.size_of_images = size_of_images  # target size as (width, height)

    def __len__(self):
        return len(self.img_paths)

    # Open one image, resize it, and convert it to a normalized tensor
    def __getitem__(self, index):
        # convert('RGB') guarantees three channels, which the normalization expects
        PIL_IMAGE = Image.open(self.img_paths[index]).convert('RGB').resize(self.size_of_images)
        TENSOR_IMAGE = transform(PIL_IMAGE)
        label = self.img_labels[index]
        return TENSOR_IMAGE, label


data_dir = r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification'
train_dir = os.path.join(data_dir, 'train')

print(len(listdir(train_dir)))
print(len(pd.read_csv(os.path.join(data_dir, 'labels.csv'))))
print(len(listdir(os.path.join(data_dir, 'test'))))

# Training image paths. sorted() makes the order deterministic; pairing paths
# with labels by position assumes the rows of labels.csv follow the same order.
train_paths = []
for path in sorted(listdir(train_dir)):
    train_paths.append(os.path.join(train_dir, path))

# Read the labels
labels_data = pd.read_csv(os.path.join(data_dir, 'labels.csv'))

# Encode the string labels as integers from 0 to 119. The data contains 120
# dog breeds, and CrossEntropyLoss expects integer class indices, so passing
# the raw breed names to the model would raise an error.
size_mapping = {}
for value, key in enumerate(labels_data['breed'].value_counts().index):
    size_mapping[key] = value
# print(size_mapping)

labels = labels_data['breed'].map(size_mapping)
labels = list(labels)
# print(labels)
print(len(labels))

# Split into a training set and a test set
X_train, X_test, y_train, y_test = train_test_split(train_paths, labels, test_size=0.2)
train_set = DogDataset(X_train, y_train, (32, 32))
test_set = DogDataset(X_test, y_test, (32, 32))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)
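The label-encoding step can be illustrated without pandas. The sketch below is a plain-Python approximation under stated assumptions: `build_label_mapping` is a hypothetical helper, the breed names are made-up toy values rather than rows from labels.csv, and `Counter.most_common()` stands in for `value_counts()` (both order by descending count, though ties may be ordered differently).

```python
from collections import Counter


def build_label_mapping(breeds):
    """Map each distinct breed name to an integer, most frequent breed first.

    Hypothetical helper that imitates the value_counts()-based mapping.
    """
    counts = Counter(breeds)
    return {breed: index for index, (breed, _) in enumerate(counts.most_common())}


# Toy breed column (made-up values, not taken from labels.csv)
breeds = ['beagle', 'pug', 'beagle', 'collie', 'beagle', 'pug']

mapping = build_label_mapping(breeds)
labels = [mapping[breed] for breed in breeds]

print(mapping)
print(labels)
```

Any one-to-one mapping from breed names to 0..119 would work for training; ordering by frequency is simply what the code above happens to produce.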
1.4 Building the model
class LeNet(nn.Module):
    def __init__(self):
        super().__init__()
        # Two convolution + average-pooling stages
        self.features = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2),
            nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
            nn.ReLU(),
            nn.AvgPool2d(kernel_size=2, stride=2)
        )
        # Three fully connected layers; the last one outputs 120 class scores
        self.classifier = nn.Sequential(
            nn.Linear(16 * 5 * 5, 120),
            nn.ReLU(),
            nn.Linear(120, 84),
            nn.ReLU(),
            nn.Linear(84, 120)
        )

    def forward(self, x):
        batch_size = x.shape[0]
        x = self.features(x)
        x = x.view(batch_size, -1)  # flatten to (batch_size, 16 * 5 * 5)
        x = self.classifier(x)
        return x


model = LeNet().to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters())

TRAIN_LOSS = []  # loss
TRAIN_ACCURACY = []  # accuracy
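The `16 * 5 * 5` input size of the first linear layer follows from the 32 x 32 input resolution. As a rough check, the sketch below traces the spatial size through the feature extractor using the standard output-size formula for convolution and pooling layers; `conv_output_size` is a hypothetical helper written for this illustration, not a PyTorch function.

```python
def conv_output_size(size, kernel_size, stride=1, padding=0):
    """Spatial size after a convolution or pooling layer (floor division)."""
    return (size + 2 * padding - kernel_size) // stride + 1


size = 32                                               # input images are 32 x 32
size = conv_output_size(size, kernel_size=5)            # Conv2d(3, 6, 5)   -> 28
size = conv_output_size(size, kernel_size=2, stride=2)  # AvgPool2d(2, 2)   -> 14
size = conv_output_size(size, kernel_size=5)            # Conv2d(6, 16, 5)  -> 10
size = conv_output_size(size, kernel_size=2, stride=2)  # AvgPool2d(2, 2)   -> 5

flattened_features = 16 * size * size                   # 16 channels of 5 x 5
print(flattened_features)                               # 400
```

If the images were resized to anything other than 32 x 32, the first `nn.Linear` layer would need a different input size, which this calculation makes easy to recompute.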
1.5 Training the model
def train(epoch):
    model.train()
    epoch_loss = 0.0  # running sum of the per-batch mean losses
    correct = 0  # number of correctly classified samples
    for batch_index, (Data, Label) in enumerate(train_loader):
        # Move the batch to the selected device
        Data = Data.to(device)
        Label = Label.to(device)
        output_train = model(Data)
        # Compute the loss
        loss_train = criterion(output_train, Label)
        epoch_loss = epoch_loss + loss_train.item()
        # Count the correct predictions
        pred = torch.max(output_train, 1)[1]
        train_correct = (pred == Label).sum()
        correct = correct + train_correct.item()
        # Zero the gradients, backpropagate, and update the parameters
        optimizer.zero_grad()
        loss_train.backward()
        optimizer.step()
    # criterion returns the mean loss of a batch, so average over batches, not samples
    print('Epoch: ', epoch, 'Train_loss: ', epoch_loss / len(train_loader), 'Train correct: ', correct / len(train_set))
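A sanity check helps when reading the per-batch loss values. With its default settings, `nn.CrossEntropyLoss` averages `-log(softmax(logits)[label])` over the batch, so each `loss_train.item()` is a per-sample mean. The plain-Python sketch below computes that quantity for a single sample; `cross_entropy` is a hypothetical helper, and the logits are made-up values.

```python
import math


def cross_entropy(logits, label):
    """Cross-entropy of one sample: -log(softmax(logits)[label])."""
    highest = max(logits)  # subtract the maximum for numerical stability
    log_sum_exp = highest + math.log(sum(math.exp(z - highest) for z in logits))
    return log_sum_exp - logits[label]


num_classes = 120

# Equal logits mean a uniform prediction, so the loss equals ln(120)
uniform_loss = cross_entropy([0.0] * num_classes, label=7)
print(uniform_loss, math.log(num_classes))

# Raising the logit of the correct class lowers the loss
confident_logits = [0.0] * num_classes
confident_logits[7] = 5.0
confident_loss = cross_entropy(confident_logits, label=7)
print(confident_loss)
```

A model that has learned nothing yet predicts roughly uniformly, so per-batch losses near ln(120), about 4.79, are expected early on, and values well below that suggest the model is picking up signal.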
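1.6 Testing the model
Testing works much like training, except that the model runs in evaluation mode, gradients are not tracked, and no parameters are updated.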
def test():
    model.eval()
    correct = 0  # number of correctly classified samples
    test_loss = 0.0  # running sum of the per-batch mean losses
    with torch.no_grad():
        for Data, Label in test_loader:
            Data = Data.to(device)
            Label = Label.to(device)
            test_output = model(Data)
            loss = criterion(test_output, Label)
            pred = torch.max(test_output, 1)[1]
            test_correct = (pred == Label).sum()
            correct = correct + test_correct.item()
            test_loss = test_loss + loss.item()
    # criterion returns the mean loss of a batch, so average over batches, not samples
    print('Test_loss: ', test_loss / len(test_loader), 'Test correct: ', correct / len(test_set))
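The accuracy bookkeeping in both loops comes down to taking the index of the highest score for each sample, which is what `torch.max(output, 1)[1]` returns, and comparing it with the label. A minimal plain-Python sketch of that logic follows; the helpers `argmax` and `count_correct` and the toy scores are assumptions made for illustration.

```python
def argmax(values):
    """Index of the largest value (the first index wins on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def count_correct(batch_logits, batch_labels):
    """Number of samples whose highest-scoring class equals the label."""
    return sum(argmax(logits) == label for logits, label in zip(batch_logits, batch_labels))


# Two toy batches with three classes each (made-up scores)
batches = [
    ([[2.0, 0.1, 0.3], [0.2, 0.1, 1.5]], [0, 1]),  # first correct, second wrong
    ([[0.0, 3.0, 1.0]], [1]),                      # correct
]

correct = 0
total = 0
for batch_logits, batch_labels in batches:
    correct += count_correct(batch_logits, batch_labels)
    total += len(batch_labels)

accuracy = correct / total
print(correct, total, accuracy)
```

Dividing the number of correct predictions by the number of samples is what keeps the accuracy comparable across batches of different sizes. For reference, guessing uniformly at random among 120 breeds would give an accuracy of about 1/120, or roughly 0.8%.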
1.7 Results
epoch = 10
for n_epoch in range(epoch):
    train(n_epoch)
test()