欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch中的模型訓(xùn)練(以CIFAR10數(shù)據(jù)集為例)

 更新時(shí)間:2023年06月15日 09:23:50   作者:MarkAssassin  
這篇文章主要介紹了pytorch中的模型訓(xùn)練(以CIFAR10數(shù)據(jù)集為例),具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教

在pytorch模型訓(xùn)練時(shí),基本的訓(xùn)練步驟可以大致地歸納為:

準(zhǔn)備數(shù)據(jù)集--->搭建神經(jīng)網(wǎng)絡(luò)--->創(chuàng)建網(wǎng)絡(luò)模型--->創(chuàng)建損失函數(shù)--->設(shè)置優(yōu)化器--->訓(xùn)練步驟開始--->測試步驟開始

本文以pytorch官網(wǎng)中torchvision中的CIFAR10數(shù)據(jù)集為例進(jìn)行講解。

需要用到的庫為(這里說一個(gè)小技巧,比如可以在沒有import對應(yīng)庫的情況下先輸入"torch",,之后將光標(biāo)移到torch處單擊,這時(shí)左邊就會出現(xiàn)一個(gè)紅色的小燈泡,點(diǎn)開它就可以import對應(yīng)的庫了):

import torch
import torchvision
from torch import nn
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

準(zhǔn)備數(shù)據(jù)集

數(shù)據(jù)集分為“訓(xùn)練數(shù)據(jù)集”+“測試數(shù)據(jù)集”。

CIFAR數(shù)據(jù)集是由50000訓(xùn)練集和10000測試集組成。

這里可以調(diào)用torchvision.datasets對CIFAR10數(shù)據(jù)集進(jìn)行獲?。?/p>

#訓(xùn)練數(shù)據(jù)集
train_data=torchvision.datasets.CIFAR10(root='./dataset',train=True,transform=torchvision.transforms.ToTensor(),download=True)
#測試數(shù)據(jù)集
test_data=torchvision.datasets.CIFAR10(root='./dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True)
#利用dataloader來加載數(shù)據(jù)集
train_data_loader=DataLoader(train_data,batch_size=64)
test_data_loader=DataLoader(test_data,batch_size=64)

root是數(shù)據(jù)集保存的路徑,這里筆者使用的是相對路徑;對于訓(xùn)練集train=True,而測試集train=False;transform是將數(shù)據(jù)集的類型轉(zhuǎn)換為tensor類型;download一般設(shè)置為True。

之后利用DataLoader對數(shù)據(jù)集進(jìn)行加載即可,其中batch_size表示單次傳遞給程序用以訓(xùn)練的數(shù)據(jù)(樣本)個(gè)數(shù)。

(這里可以再獲取下測試集數(shù)據(jù)的長度,這樣后面在測試步驟時(shí)就可以通過計(jì)算得到整體測試集上的準(zhǔn)確率

搭建神經(jīng)網(wǎng)絡(luò)

pytorch官網(wǎng)提供了一個(gè)神經(jīng)網(wǎng)絡(luò)的簡單實(shí)例:

import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(1, 20, 5)
        self.conv2 = nn.Conv2d(20, 20, 5)
    def forward(self, x):
        x = F.relu(self.conv1(x))
        return F.relu(self.conv2(x))

這個(gè)實(shí)例展示了神經(jīng)網(wǎng)絡(luò)的基本架構(gòu)。

從百度上我們可以搜索到CIFAR10數(shù)據(jù)集的網(wǎng)絡(luò)基本架構(gòu):

 從原理圖中可以看到,該網(wǎng)絡(luò)從輸入(inputs)到輸出(outputs)先后經(jīng)過了

卷積(Convolution)--->最大池化(Max-pooling)--->卷積--->最大池化--->卷積--->最大池化--->展平(Flatten)--->2次線性層

由此可以開始搭建神經(jīng)網(wǎng)絡(luò),筆者將網(wǎng)絡(luò)命名為MXC:

class Flatten(nn.Module):
    def forward(self, input):
        return input.view(input.size(0), -1)
#搭建神經(jīng)網(wǎng)絡(luò)
class MXC(nn.Module):
    def __init__(self):
        super(MXC, self).__init__()
        self.model=nn.Sequential(
            nn.Conv2d(3,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32,32,5,1,2),
            nn.MaxPool2d(2),
            nn.Conv2d(32,64,5,1,2),
            nn.MaxPool2d(2),
            Flatten(),
            nn.Linear(64*4*4,64),
            nn.Linear(64,10)
        )
    def forward(self,x):
        x=self.model(x)
        return x

注:這里Flatten類自己寫的原因是筆者使用的torch版本中沒有展平類,因此需要自己構(gòu)建,可以參考解決 ImportError: cannot import name ‘Flatten‘ from ‘torch.nn‘

創(chuàng)建網(wǎng)絡(luò)模型

mxc=MXC()

創(chuàng)建損失函數(shù)

loss_fn=nn.CrossEntropyLoss()

設(shè)置優(yōu)化器

learning_rate=1e-2
optimizer=torch.optim.SGD(params=mxc.parameters(),lr=learning_rate)

這里params是網(wǎng)絡(luò)模型;lr是學(xué)習(xí)速率,一般設(shè)置小一些(0.01)。

訓(xùn)練步驟+測試步驟開始

先設(shè)置一些參數(shù)

#設(shè)置訓(xùn)練網(wǎng)絡(luò)的一些參數(shù)
#記錄訓(xùn)練的次數(shù)
total_train_step=0
#記錄測試的次數(shù)
total_test_step=0
#記錄測試的準(zhǔn)確率
total_accuracy=0
#訓(xùn)練的輪數(shù)
epoch=10

將訓(xùn)練步驟和測試步驟放入一個(gè)大循環(huán)中,進(jìn)入循環(huán)開始訓(xùn)練:

for i in range(epoch):
    print("-----第{}輪訓(xùn)練開始------".format(i+1))
    #訓(xùn)練步驟開始
    for data in train_data_loader:
        imgs,targets=data
        output=mxc(imgs)
        loss=loss_fn(output,targets)
        #優(yōu)化器優(yōu)化模型
        optimizer.zero_grad()#梯度清零
        loss.backward()#反向傳播
        optimizer.step()#參數(shù)優(yōu)化
        total_train_step=total_train_step+1
        if total_train_step%100==0:
            print("訓(xùn)練次數(shù):{} , Loss:{}".format(total_train_step, loss))  # 更正規(guī)的可以寫成loss.item()
    #測試步驟開始
    total_test_loss=0
    with torch.no_grad():
        for data in test_data_loader:
            imgs,targets=data
            output=mxc(imgs)
            loss=loss_fn(output,targets)
            total_test_loss=total_test_loss+loss
            accuracy=(output.argmax(1)==targets).sum()
            total_accuracy=total_accuracy+accuracy
    print("整體測試集上的Loss:{}".format(total_test_loss))
    print("整體測試集上的準(zhǔn)確率:{}".format(total_accuracy/test_data_size))
    total_test_step=total_test_step+1#測試的次數(shù),其實(shí)就是第幾輪
    #保存模型
    torch.save(mxc,"mxc_cpu{}.pth".format(i+1))#mxc_1是cpu版的
    print("模型已保存")

這里每一個(gè)data中包含有圖片+標(biāo)簽,需要將圖片(imgs)放入之前搭建好的神經(jīng)網(wǎng)絡(luò)模型mxc中去。

筆者這里設(shè)置的是每訓(xùn)練100次打印1次,訓(xùn)練完一輪后會將模型進(jìn)行保存,注意文件類型是pth格式。

如果想讓訓(xùn)練后的結(jié)果可視化,有兩種方法:

1.在循環(huán)前調(diào)用SummaryWriter添加tensorboard:

#添加tensorboard
writer=SummaryWriter('./logs_train')

并在循環(huán)的適當(dāng)位置中插入writer.add_scalar

        if total_train_step%100==0:
            print("訓(xùn)練次數(shù):{} , Loss:{}".format(total_train_step, loss))  # 更正規(guī)的可以寫成loss.item()
            writer.add_scalar("train_loss",loss.item(),total_train_step)
print("整體測試集上的Loss:{}".format(total_test_loss))
print("整體測試集上的準(zhǔn)確率:{}".format(total_accuracy/test_data_size))
total_test_step=total_test_step+1#測試的次數(shù),其實(shí)就是第幾輪
writer.add_scalar("test_loss",total_test_loss,total_test_step)
writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)

在運(yùn)行結(jié)束后,打開Terminal輸入(注意,前面需要顯示pytorch,因?yàn)橹挥性趐ytorch環(huán)境下才可以,如果沒有顯示還要切換到pytorch才行,可以輸入activate pytorch):

tensorboard --logdir=logs_train --port=6007

這里的“logs_train”是在SummaryWriter中設(shè)置的保存路徑。運(yùn)行之后,就可以在tensorboard中查看隨著訓(xùn)練次數(shù)的增加測試集上的Loss和準(zhǔn)確率的趨勢圖像。

2.調(diào)用matlab庫自己進(jìn)行繪制

總結(jié)

本文簡要介紹了pytorch模型訓(xùn)練的一個(gè)基本流程,并以CIFAR10數(shù)據(jù)集進(jìn)行了演示。

但這種方法是在CPU(device=“cpu”)上進(jìn)行訓(xùn)練的,訓(xùn)練速度比較慢,如果數(shù)據(jù)集十分龐大不建議使用這種方法,應(yīng)該在GPU(device=“cuda”)上進(jìn)行訓(xùn)練。

以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • pandas groupby()的使用小結(jié)

    pandas groupby()的使用小結(jié)

    在數(shù)據(jù)分析中,經(jīng)常會用到分組,可用函數(shù)pandas中的groupby(),本文就來介紹一下pandas groupby()的使用小結(jié),具有一定的參考價(jià)值,感興趣的可以了解一下
    2023-11-11
  • Python之父談Python的未來形式

    Python之父談Python的未來形式

    這篇文章主要介紹了Python之父談Python的未來,需要的朋友可以參考下
    2016-07-07
  • pycharm中代碼回滾到指定版本的兩種實(shí)現(xiàn)方法(附帶截圖展示)

    pycharm中代碼回滾到指定版本的兩種實(shí)現(xiàn)方法(附帶截圖展示)

    在編寫代碼的時(shí)候,經(jīng)常會出現(xiàn)寫的代碼存在一些問題,但是比較難以發(fā)現(xiàn)具體存在的問題在哪里,需要將帶代碼恢復(fù)到指定的版本,下面這篇文章主要給大家介紹了關(guān)于pycharm中代碼回滾到指定版本的兩種實(shí)現(xiàn)方法,需要的朋友可以參考下
    2022-06-06
  • python數(shù)組過濾實(shí)現(xiàn)方法

    python數(shù)組過濾實(shí)現(xiàn)方法

    這篇文章主要介紹了python數(shù)組過濾實(shí)現(xiàn)方法,涉及Python針對數(shù)組的相關(guān)操作技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下
    2015-07-07
  • Python字符串模糊匹配工具TheFuzz的用法詳解

    Python字符串模糊匹配工具TheFuzz的用法詳解

    在處理文本數(shù)據(jù)時(shí),常常需要進(jìn)行模糊字符串匹配來找到相似的字符串,Python的TheFuzz庫提供了強(qiáng)大的方法用于解決這類問題,本文將深入介紹TheFuzz庫,探討其基本概念、常用方法和示例代碼,需要的朋友可以參考下
    2023-12-12
  • Python中條件判斷語句的簡單使用方法

    Python中條件判斷語句的簡單使用方法

    這篇文章主要介紹了Python中條件判斷語句的簡單使用方法,是Python入門學(xué)習(xí)中的基礎(chǔ)知識,需要的朋友可以參考下
    2015-08-08
  • 三步解決python PermissionError: [WinError 5]拒絕訪問的情況

    三步解決python PermissionError: [WinError 5]拒絕訪問的情況

    這篇文章主要介紹了三步解決python PermissionError: [WinError 5]拒絕訪問的情況,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-04-04
  • Python實(shí)現(xiàn)簡單的

    Python實(shí)現(xiàn)簡單的"導(dǎo)彈" 自動追蹤原理解析

    這篇文章主要介紹了Python實(shí)現(xiàn)簡單的"導(dǎo)彈" 自動追蹤原理解析,本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2021-03-03
  • Pytorch中torch.repeat_interleave()函數(shù)使用及說明

    Pytorch中torch.repeat_interleave()函數(shù)使用及說明

    這篇文章主要介紹了Pytorch中torch.repeat_interleave()函數(shù)使用及說明,具有很好的參考價(jià)值,希望對大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2023-01-01
  • 使用apiDoc實(shí)現(xiàn)python接口文檔編寫

    使用apiDoc實(shí)現(xiàn)python接口文檔編寫

    今天小編就為大家分享一篇使用apiDoc實(shí)現(xiàn)python接口文檔編寫,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-11-11

最新評論