欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

關(guān)于pytorch訓(xùn)練分類器

 更新時間:2023年09月14日 10:51:58   作者:bujbujbiu  
這篇文章主要介紹了關(guān)于pytorch訓(xùn)練分類器問題,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教

Training a Classifier

前面學(xué)習(xí)到如何定義神經(jīng)網(wǎng)絡(luò),計算損失并且對網(wǎng)絡(luò)權(quán)重進行更新

What about data?

通常,當(dāng)你必須處理圖像,文本,音頻或視頻時,你可以使用能將數(shù)據(jù)加載到numpy數(shù)組的標(biāo)準(zhǔn)python包,然后將該數(shù)組轉(zhuǎn)化成 torch.*Tensor

  • 圖像:Pillow, OpenCV
  • 音頻:scipy,librosa
  • 文本:基于python或cython的原始加載,或者NLTK和SpaCy

專門針對視覺,創(chuàng)建了名為 torchvision 的包,包含常見數(shù)據(jù)集(ImageNet, CIFAR10, MNIST)的加載器,以及用于圖像的數(shù)據(jù)轉(zhuǎn)換器( torchvision.datasets torch.utils.data.DataLoader

提供極大便利,避免編寫樣板代碼

使用CIFAR10數(shù)據(jù)集,有分類:‘airplane’, ‘automobile’, ‘bird’, ‘cat’, ‘deer’,‘dog’, ‘frog’, ‘horse’, ‘ship’, ‘truck’。CIFAR-10中的圖像尺寸為 3x32x32 ,即尺寸為 32x32 像素的3通道彩色圖像

圖像的4D張量為(B,C,H,W)

  • B:batch size
  • C:channel
  • H:height
  • W:width

Training an image classifier

  • 1.使用torchvision加載并標(biāo)準(zhǔn)化CIFAR10訓(xùn)練和測試數(shù)據(jù)集
  • 2.定義卷積神經(jīng)網(wǎng)絡(luò)
  • 3.定義損失函數(shù)
  • 4.基于訓(xùn)練數(shù)據(jù)訓(xùn)練網(wǎng)絡(luò)
  • 5.基于測試數(shù)據(jù)測試網(wǎng)絡(luò)

1.加載并標(biāo)準(zhǔn)化CIFAR10

torchvision 庫包括數(shù)據(jù)集,模型以及針對計算機視覺的圖像轉(zhuǎn)換器,是pytorch的一個圖形。

torchvision 包括以下:

  • torchvision.datasets : 一些加載數(shù)據(jù)的函數(shù)及常用的數(shù)據(jù)集接口
  • torchvision.models :包含常用的模型結(jié)構(gòu)(含預(yù)訓(xùn)練模型),例如AlexNet、VGG、ResNet等
  • torchvision.transforms :常用的圖片變換,例如裁剪、旋轉(zhuǎn)等
  • torchvision.utils :其他的一些有用的方法
import torch
import torchvision
import torchvision.transforms as transforms

torchvision 數(shù)據(jù)集輸出是[0,1]范圍的PILImage圖像,需要轉(zhuǎn)換為標(biāo)準(zhǔn)化范圍的[-1,1]張量

torchvision.transforms.Compose 合并多個圖像變換的操作,常見transforms操作有:

  • ToTensor:把灰度范圍從0-255變換到0-1之間
  • Normalize:用均值和標(biāo)準(zhǔn)差歸一化張量圖像
  • CenterCrop:在圖片的中間區(qū)域進行裁剪

Python圖像庫PIL(Python Image Library)是python的第三方圖像處理庫

PyTorch中數(shù)據(jù)讀取的一個重要接口是 torch.utils.data.DataLoader ,該接口定義在dataloader.py腳本中,只要是用PyTorch來訓(xùn)練模型基本都會用到該接口,該接口主要用來將自定義的數(shù)據(jù)讀取接口的輸出或者PyTorch已有的數(shù)據(jù)讀取接口的輸入按照batch size封裝成Tensor,后續(xù)只需要再包裝成Variable即可作為模型的輸入,因此該接口有點承上啟下的作用,比較重要。

import ssl
ssl._create_default_https_context = ssl._create_unverified_context
# 取消證書驗證
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
# transforms.Normalize(mean,std),圖像尺寸為3*32*32,保持一致
batch_size = 4
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=batch_size,shuffle=True, num_workers=2)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=batch_size,shuffle=False, num_workers=2)
# DataLoader數(shù)據(jù)迭代器,用來封裝數(shù)據(jù),num_workers讀取數(shù)據(jù)的線程數(shù),shuffle設(shè)置為True表示在每個epoch重新洗牌數(shù)據(jù)
classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
Files already downloaded and verified
Files already downloaded and verified

Let us show some of the training images, for fun.

迭代是Python最強大的功能之一,是訪問集合元素的一種方式。字符串,列表,元組都可以用于創(chuàng)建迭代器。迭代器對象從集合的第一個元素開始訪問,直到所有的元素被訪問完結(jié)束包括兩種方法:

  • iter() 創(chuàng)建一個迭代器
  • next() 返回迭代器的下一個項目。
list1=[1,2,3,4]
it=iter(list1)
for x in it:
    print(x,end=' ')

1 2 3 4 

展示一些訓(xùn)練圖像

import matplotlib.pyplot as plt
import numpy as np
# functions to show an image
def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()  # PIL image轉(zhuǎn)換成numpy array
    plt.imshow(np.transpose(npimg, (1, 2, 0))) # np.transpose反轉(zhuǎn)或置換數(shù)組的軸
    plt.show()
# get some random training images
# trainloader相當(dāng)于一個包含images和labels的列表,前面shuffle設(shè)置為True,因此每次運行都會結(jié)果不同
dataiter = iter(trainloader)
images, labels = dataiter.next()
# show images
imshow(torchvision.utils.make_grid(images))
# torchvision.utils.make_grid將若干張圖像拼成一張網(wǎng)格
# print labels
print(' '.join(f'{classes[labels[j]]:5s}' for j in range(batch_size)))

horse bird deer truck

2.定義卷積神經(jīng)網(wǎng)絡(luò)

從之前部分復(fù)制神經(jīng)網(wǎng)絡(luò)代碼,將圖像改為3通道

nn.Conv2d :在由多個輸入平面組成的輸入信號上應(yīng)用二維卷積

nn.Conv2d(in_channels,out_channels,kernel_size)

nn.MaxPool2d :在由幾個輸入平面組成的輸入信號上應(yīng)用一個2D max池

nn.MaxPool2d(kernel_size,stride)

nn.Linear :對輸入的數(shù)據(jù)應(yīng)用線性轉(zhuǎn)換 y = x A T + b y=xA^T+b y=xAT+b

nn.Linear(in_features,out_features)

在這里插入圖片描述

import torch.nn as nn
import torch.nn.functional as F
class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)# 卷積計算
        # 3channel的32*32原始圖像經(jīng)過6個5*5的filters卷積計算后變成6channel的28*28圖像
        self.pool = nn.MaxPool2d(2, 2)# 池化
        # 6channel的28*28圖像以2*2進行pooling操作變?yōu)?4*14,stride=kernel_size表示沒有重復(fù)部分,28/2=14
        self.conv2 = nn.Conv2d(6, 16, 5)# 卷積計算
        # 6channel的14*14圖像經(jīng)過16個5*5的filters卷積計算后變成16channel的10*10圖像
        #self.pool = nn.MaxPool2d(2, 2)
        # 16channel的10*10圖像以2*2進行pooling變?yōu)?*5,10/2=5
        self.fc1 = nn.Linear(16 * 5 * 5, 120)# 線性變換
        # 16channel的5*5平鋪即16 * 5 * 5,作為FC首層的輸入F5
        self.fc2 = nn.Linear(120, 84)
        # FC第二層F6
        self.fc3 = nn.Linear(84, 10)
        # FC第三層高斯層output
    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))# 卷積->激活->池化
        x = self.pool(F.relu(self.conv2(x)))# 卷積->激活->池化
        x = torch.flatten(x, 1) # 除了batch維度均平鋪
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)# 最后一層為高斯連接
        return x
net = Net()

3.定義損失函數(shù)和優(yōu)化器

使用分類交叉熵?fù)p失和動量SGD

torch.nn.CrossEntropyLoss :計算輸入與目標(biāo)值間的交叉熵?fù)p失,適合帶有C個類別的分類問題,輸入是每個類原始無標(biāo)準(zhǔn)化的分?jǐn)?shù)

import torch.optim as optim
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

criterion type:<class ‘torch.nn.modules.loss.MSELoss’>

loss type:<class ‘torch.Tensor’>

4.訓(xùn)練網(wǎng)絡(luò)

遍歷數(shù)據(jù)迭代器,將輸入饋送到網(wǎng)絡(luò)并進行優(yōu)化

enumerate() 函數(shù)用于將一個可遍歷的數(shù)據(jù)對象(如列表、元組或字符串)組合為一個索引序列,同時列出數(shù)據(jù)和數(shù)據(jù)下標(biāo),一般用在for循環(huán)當(dāng)中,語法 enumerate(sequence, [start=0])

  • sequence: 一個序列、迭代器或其他支持迭代對象
  • start:下標(biāo)起始為止
# 普通for循環(huán)
i=0
sequence=['one','two','three']
for e in sequence:
    print(i,sequence[i])
    i+=1
# 使用enumerate的for循環(huán)
for i,e in enumerate(sequence,0):
    print(i,e)

0 one
1 two
2 three
0 one
1 two
2 three

for epoch in range(2):
    run_loss = 0.0 # 計算平均誤差
    # 獲取inputs,data是一個列表[inputs,labels]
    for i,data in enumerate(trainloader,0):
        inputs,labels = data
        # 梯度清0
        optimizer.zero_grad()
        # forward+loss+backward+optimize
        outputs = net(inputs)
        loss = criterion(outputs,labels)
        loss.backward()
        optimizer.step()
        # 如果使用run_loss+=run_loss,會導(dǎo)致內(nèi)存爆炸,此處loss是變量
        run_loss += loss.item()
        if i%2000 == 1999: # 輸出每2000個mini-batches
            print(f'[{epoch+1},{i+1:5d}],loss:{run_loss/2000:.3f}')
            run_loss = 0.0

[1, 2000],loss:2.268
[1, 4000],loss:2.029
[1, 6000],loss:1.834
[1, 8000],loss:1.666
[1,10000],loss:1.598
[1,12000],loss:1.517
[2, 2000],loss:1.459
[2, 4000],loss:1.418
[2, 6000],loss:1.373
[2, 8000],loss:1.355
[2,10000],loss:1.349
[2,12000],loss:1.306

保存訓(xùn)練過的模型

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

5.基于測試數(shù)據(jù)測試網(wǎng)絡(luò)

基于訓(xùn)練數(shù)據(jù)對網(wǎng)絡(luò)進行2次訓(xùn)練,為了檢測網(wǎng)絡(luò)性能,通過預(yù)測將神經(jīng)網(wǎng)絡(luò)輸出的類別標(biāo)簽并且與實際對比,如果預(yù)測正確,將該樣本添加到正確預(yù)測表中

首先,顯示幾張測試集中的圖像

dataitertest =iter(testloader)
images,labels = dataitertest.next()
print(labels)
# 此處的labels是數(shù)字代表的類別
imshow(torchvision.utils.make_grid(images))
print('groundtruth:',' '.join(f'{classes[labels[j]]:5s}' for j in range(4)))

tensor([3, 8, 8, 0])

groundtruth: cat   ship  ship  plane

接下來重新加載保存的模型(實際不需要,此處展示如何保存)

net = Net()
net.load_state_dict(torch.load(PATH))
<All keys matched successfully>

現(xiàn)在看看神經(jīng)網(wǎng)絡(luò)對以上樣例的預(yù)測

outputs = net(images)
print(outputs)
tensor([[-0.5511, -1.2592,  1.0451,  1.7341,  0.2255,  1.0719,  0.3474, -0.0722,
         -0.7703, -1.7738],
        [ 2.9296,  4.5538, -0.4796, -1.7549, -2.4294, -2.7830, -3.4919, -3.0665,
          4.3148,  2.5193],
        [ 2.0322,  2.4424,  0.4408, -1.1508, -1.1923, -1.9300, -2.9568, -1.5784,
          2.8175,  2.0967],
        [ 3.1805,  2.2340,  0.1468, -1.6451, -0.8934, -2.9459, -3.4108, -2.2368,
          4.2390,  2.2832]], grad_fn=<AddmmBackward0>)

輸出是4張圖像10個類別的能量,某個類的能量越高,代表網(wǎng)絡(luò)傾向于認(rèn)為該圖像屬于該類別,因此讓我們獲取最高能量的指數(shù)

torch.max(input, dim, keepdim=False, out=None) :返回輸入tensor中所有元素的最大值

torch.max(tensor,0) :返回每一列(1行)中最大值的那個元素,且返回索引(返回最大元素在這一列的行索引)

_,predicted = torch.max(outputs,1)
print("predicted:",''.join('%5s'%classes[predicted[j]] for j in range(4)))
predicted:   cat  car ship ship

正確率75%

接下來看網(wǎng)絡(luò)在整個數(shù)據(jù)集上的表現(xiàn)

totalnum = 0
correctnum = 0
# 沒有訓(xùn)練,因此不需要計算輸出的梯度
with torch.no_grad():
    for data in testloader:
        images,labels = data
        # 前向傳播
        outputs = net(images)
        _,predicted = torch.max(outputs,1)
        # totalnum所有測試圖像數(shù)量,correctnum預(yù)測準(zhǔn)確圖像數(shù)量
        totalnum += labels.size(0)
        correctnum += (predicted==labels).sum().item()
print("Accuracy of the network on the 10000 test images:%d %%"%(100*correctnum/totalnum))
Accuracy of the network on the 10000 test images:55 %

隨機選擇一個類,準(zhǔn)確率為10%,因此神經(jīng)網(wǎng)絡(luò)訓(xùn)練比隨機更好。接下來分析網(wǎng)絡(luò)在哪些類表現(xiàn)好,哪些類表現(xiàn)不好

zip([iterable,...]) 函數(shù)用于將可迭代的對象作為參數(shù),將對象中對應(yīng)的元素打包成一個個元組,然后返回由這些元組組成的對象,這樣做的好處是節(jié)約了不少的內(nèi)存。使用 list() 轉(zhuǎn)換來輸出列表

ex=[1,2,3]
ex1=[4,5,6]
m=zip(ex,ex1)
print(list(m))
# 出現(xiàn)list is not callable,表明有變量名被命名成了list,注意命名規(guī)范!

[(1, 4), (2, 5), (3, 6)]

# 字典存儲每個類別預(yù)測正確的數(shù)量和總數(shù)量
correct_pred = {classname:0 for classname in classes}
total_pred = {classname:0 for classname in classes}
# 預(yù)測并計數(shù)
with torch.no_grad():
    for data in testloader:
        images,labels = data
        outputs = net(images)
        _,predictions = torch.max(outputs,1)
        for label,prediction in zip(labels,predictions):
            if label == prediction:
                correct_pred[classes[label]] += 1
            total_pred[classes[label]] += 1
for classname,correct_count in correct_pred.items():
    accuracy = 100*float(correct_count)/total_pred[classname]
    print('accuracy of %5s:%2d %%'%(classname,accuracy))

accuracy of plane:54 %
accuracy of   car:74 %
accuracy of  bird:49 %
accuracy of   cat:31 %
accuracy of  deer:53 %
accuracy of   dog:47 %
accuracy of  frog:60 %
accuracy of horse:58 %
accuracy of  ship:69 %
accuracy of truck:54 %

Training on GPU

GPU圖像處理器:專門做圖像和圖形相關(guān)運算工作的微處理器。就像張量可以轉(zhuǎn)移到GPU一樣,神經(jīng)網(wǎng)絡(luò)也可以,此處沒有CUDA設(shè)備無法實現(xiàn)

總結(jié)

以上為個人經(jīng)驗,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python讀csv文件去掉一列后再寫入新的文件實例

    Python讀csv文件去掉一列后再寫入新的文件實例

    下面小編就為大家分享一篇Python讀csv文件去掉一列后再寫入新的文件實例,具有很的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2017-12-12
  • Django migrations 默認(rèn)目錄修改的方法教程

    Django migrations 默認(rèn)目錄修改的方法教程

    這篇文章主要介紹了Django migrations 默認(rèn)目錄修改的方法教程,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2018-09-09
  • Python開發(fā)畢設(shè)案例之桌面學(xué)生信息管理程序

    Python開發(fā)畢設(shè)案例之桌面學(xué)生信息管理程序

    畢業(yè)設(shè)計必備案例:Python開發(fā)桌面程序
    2021-11-11
  • 詳解flask入門模板引擎

    詳解flask入門模板引擎

    這篇文章主要介紹了詳解flask入門模板引擎,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2018-07-07
  • python中創(chuàng)建一個包并引用使用的操作方法

    python中創(chuàng)建一個包并引用使用的操作方法

    python包在開發(fā)中十分常見,一般通過導(dǎo)入包含特定功能的python模塊包進行使用。當(dāng)然,也可以自己創(chuàng)建打包模塊,然后發(fā)布,安裝使用,這篇文章主要介紹了python中如何創(chuàng)建一個包并引用使用,需要的朋友可以參考下
    2022-08-08
  • Python可視化程序調(diào)用流程解析

    Python可視化程序調(diào)用流程解析

    這篇文章主要為大家介紹了可視化Python程序調(diào)用流程解析,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪
    2022-08-08
  • Python?頁面解析Beautiful?Soup庫的使用方法

    Python?頁面解析Beautiful?Soup庫的使用方法

    Beautiful?Soup?簡稱?BS4(其中?4?表示版本號)是一個?Python?中常用的頁面解析庫,它可以從?HTML?或?XML?文檔中快速地提取指定的數(shù)據(jù),這篇文章主要介紹了springboot?集成?docsify?實現(xiàn)隨身文檔?,需要的朋友可以參考下
    2022-09-09
  • 詳解在Python和IPython中使用Docker

    詳解在Python和IPython中使用Docker

    這篇文章主要介紹了詳解在Python和IPython中使用Docker,Docker是一個吸引人的新系統(tǒng),可以用來建立有趣的新技術(shù)應(yīng)用,特別是云服務(wù)相關(guān)的,需要的朋友可以參考下
    2015-04-04
  • python 時間處理之月份加減問題

    python 時間處理之月份加減問題

    這篇文章主要介紹了python 時間處理之月份加減問題,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教
    2022-11-11
  • Pandas中運行速度優(yōu)化的常用方法介紹

    Pandas中運行速度優(yōu)化的常用方法介紹

    這篇文章主要為大家詳細(xì)介紹了幾種pandas中常用到的方法,對于這些方法使用存在哪些需要注意的問題,以及如何對它們進行速度提升,需要的小伙伴可以參考下
    2025-03-03

最新評論