pytorch中的模型訓(xùn)練(以CIFAR10數(shù)據(jù)集為例)
在pytorch模型訓(xùn)練時(shí),基本的訓(xùn)練步驟可以大致地歸納為:
準(zhǔn)備數(shù)據(jù)集--->搭建神經(jīng)網(wǎng)絡(luò)--->創(chuàng)建網(wǎng)絡(luò)模型--->創(chuàng)建損失函數(shù)--->設(shè)置優(yōu)化器--->訓(xùn)練步驟開(kāi)始--->測(cè)試步驟開(kāi)始
本文以pytorch官網(wǎng)中torchvision中的CIFAR10數(shù)據(jù)集為例進(jìn)行講解。
需要用到的庫(kù)為(這里說(shuō)一個(gè)小技巧,比如可以在沒(méi)有import對(duì)應(yīng)庫(kù)的情況下先輸入"torch",,之后將光標(biāo)移到torch處單擊,這時(shí)左邊就會(huì)出現(xiàn)一個(gè)紅色的小燈泡,點(diǎn)開(kāi)它就可以import對(duì)應(yīng)的庫(kù)了):
import torch import torchvision from torch import nn from torch.utils.data import DataLoader from torch.utils.tensorboard import SummaryWriter
準(zhǔn)備數(shù)據(jù)集
數(shù)據(jù)集分為“訓(xùn)練數(shù)據(jù)集”+“測(cè)試數(shù)據(jù)集”。
CIFAR數(shù)據(jù)集是由50000訓(xùn)練集和10000測(cè)試集組成。
這里可以調(diào)用torchvision.datasets對(duì)CIFAR10數(shù)據(jù)集進(jìn)行獲取:
#訓(xùn)練數(shù)據(jù)集 train_data=torchvision.datasets.CIFAR10(root='./dataset',train=True,transform=torchvision.transforms.ToTensor(),download=True) #測(cè)試數(shù)據(jù)集 test_data=torchvision.datasets.CIFAR10(root='./dataset',train=False,transform=torchvision.transforms.ToTensor(),download=True) #利用dataloader來(lái)加載數(shù)據(jù)集 train_data_loader=DataLoader(train_data,batch_size=64) test_data_loader=DataLoader(test_data,batch_size=64)
root是數(shù)據(jù)集保存的路徑,這里筆者使用的是相對(duì)路徑;對(duì)于訓(xùn)練集train=True,而測(cè)試集train=False;transform是將數(shù)據(jù)集的類型轉(zhuǎn)換為tensor類型;download一般設(shè)置為T(mén)rue。
之后利用DataLoader對(duì)數(shù)據(jù)集進(jìn)行加載即可,其中batch_size表示單次傳遞給程序用以訓(xùn)練的數(shù)據(jù)(樣本)個(gè)數(shù)。
(這里可以再獲取下測(cè)試集數(shù)據(jù)的長(zhǎng)度,這樣后面在測(cè)試步驟時(shí)就可以通過(guò)計(jì)算得到整體測(cè)試集上的準(zhǔn)確率)
搭建神經(jīng)網(wǎng)絡(luò)
pytorch官網(wǎng)提供了一個(gè)神經(jīng)網(wǎng)絡(luò)的簡(jiǎn)單實(shí)例:
import torch.nn as nn import torch.nn.functional as F class Model(nn.Module): def __init__(self): super().__init__() self.conv1 = nn.Conv2d(1, 20, 5) self.conv2 = nn.Conv2d(20, 20, 5) def forward(self, x): x = F.relu(self.conv1(x)) return F.relu(self.conv2(x))
這個(gè)實(shí)例展示了神經(jīng)網(wǎng)絡(luò)的基本架構(gòu)。
從百度上我們可以搜索到CIFAR10數(shù)據(jù)集的網(wǎng)絡(luò)基本架構(gòu):
從原理圖中可以看到,該網(wǎng)絡(luò)從輸入(inputs)到輸出(outputs)先后經(jīng)過(guò)了
卷積(Convolution)--->最大池化(Max-pooling)--->卷積--->最大池化--->卷積--->最大池化--->展平(Flatten)--->2次線性層
由此可以開(kāi)始搭建神經(jīng)網(wǎng)絡(luò),筆者將網(wǎng)絡(luò)命名為MXC:
class Flatten(nn.Module): def forward(self, input): return input.view(input.size(0), -1)
#搭建神經(jīng)網(wǎng)絡(luò) class MXC(nn.Module): def __init__(self): super(MXC, self).__init__() self.model=nn.Sequential( nn.Conv2d(3,32,5,1,2), nn.MaxPool2d(2), nn.Conv2d(32,32,5,1,2), nn.MaxPool2d(2), nn.Conv2d(32,64,5,1,2), nn.MaxPool2d(2), Flatten(), nn.Linear(64*4*4,64), nn.Linear(64,10) ) def forward(self,x): x=self.model(x) return x
注:這里Flatten類自己寫(xiě)的原因是筆者使用的torch版本中沒(méi)有展平類,因此需要自己構(gòu)建,可以參考解決 ImportError: cannot import name ‘Flatten‘ from ‘torch.nn‘
創(chuàng)建網(wǎng)絡(luò)模型
mxc=MXC()
創(chuàng)建損失函數(shù)
loss_fn=nn.CrossEntropyLoss()
設(shè)置優(yōu)化器
learning_rate=1e-2 optimizer=torch.optim.SGD(params=mxc.parameters(),lr=learning_rate)
這里params是網(wǎng)絡(luò)模型;lr是學(xué)習(xí)速率,一般設(shè)置小一些(0.01)。
訓(xùn)練步驟+測(cè)試步驟開(kāi)始
先設(shè)置一些參數(shù)
#設(shè)置訓(xùn)練網(wǎng)絡(luò)的一些參數(shù) #記錄訓(xùn)練的次數(shù) total_train_step=0 #記錄測(cè)試的次數(shù) total_test_step=0 #記錄測(cè)試的準(zhǔn)確率 total_accuracy=0 #訓(xùn)練的輪數(shù) epoch=10
將訓(xùn)練步驟和測(cè)試步驟放入一個(gè)大循環(huán)中,進(jìn)入循環(huán)開(kāi)始訓(xùn)練:
for i in range(epoch): print("-----第{}輪訓(xùn)練開(kāi)始------".format(i+1)) #訓(xùn)練步驟開(kāi)始 for data in train_data_loader: imgs,targets=data output=mxc(imgs) loss=loss_fn(output,targets) #優(yōu)化器優(yōu)化模型 optimizer.zero_grad()#梯度清零 loss.backward()#反向傳播 optimizer.step()#參數(shù)優(yōu)化 total_train_step=total_train_step+1 if total_train_step%100==0: print("訓(xùn)練次數(shù):{} , Loss:{}".format(total_train_step, loss)) # 更正規(guī)的可以寫(xiě)成loss.item() #測(cè)試步驟開(kāi)始 total_test_loss=0 with torch.no_grad(): for data in test_data_loader: imgs,targets=data output=mxc(imgs) loss=loss_fn(output,targets) total_test_loss=total_test_loss+loss accuracy=(output.argmax(1)==targets).sum() total_accuracy=total_accuracy+accuracy print("整體測(cè)試集上的Loss:{}".format(total_test_loss)) print("整體測(cè)試集上的準(zhǔn)確率:{}".format(total_accuracy/test_data_size)) total_test_step=total_test_step+1#測(cè)試的次數(shù),其實(shí)就是第幾輪 #保存模型 torch.save(mxc,"mxc_cpu{}.pth".format(i+1))#mxc_1是cpu版的 print("模型已保存")
這里每一個(gè)data中包含有圖片+標(biāo)簽,需要將圖片(imgs)放入之前搭建好的神經(jīng)網(wǎng)絡(luò)模型mxc中去。
筆者這里設(shè)置的是每訓(xùn)練100次打印1次,訓(xùn)練完一輪后會(huì)將模型進(jìn)行保存,注意文件類型是pth格式。
如果想讓訓(xùn)練后的結(jié)果可視化,有兩種方法:
1.在循環(huán)前調(diào)用SummaryWriter添加tensorboard:
#添加tensorboard writer=SummaryWriter('./logs_train')
并在循環(huán)的適當(dāng)位置中插入writer.add_scalar:
if total_train_step%100==0: print("訓(xùn)練次數(shù):{} , Loss:{}".format(total_train_step, loss)) # 更正規(guī)的可以寫(xiě)成loss.item() writer.add_scalar("train_loss",loss.item(),total_train_step)
print("整體測(cè)試集上的Loss:{}".format(total_test_loss)) print("整體測(cè)試集上的準(zhǔn)確率:{}".format(total_accuracy/test_data_size)) total_test_step=total_test_step+1#測(cè)試的次數(shù),其實(shí)就是第幾輪 writer.add_scalar("test_loss",total_test_loss,total_test_step) writer.add_scalar("test_accuracy",total_accuracy/test_data_size,total_test_step)
在運(yùn)行結(jié)束后,打開(kāi)Terminal輸入(注意,前面需要顯示pytorch,因?yàn)橹挥性趐ytorch環(huán)境下才可以,如果沒(méi)有顯示還要切換到pytorch才行,可以輸入activate pytorch):
tensorboard --logdir=logs_train --port=6007
這里的“logs_train”是在SummaryWriter中設(shè)置的保存路徑。運(yùn)行之后,就可以在tensorboard中查看隨著訓(xùn)練次數(shù)的增加測(cè)試集上的Loss和準(zhǔn)確率的趨勢(shì)圖像。
2.調(diào)用matlab庫(kù)自己進(jìn)行繪制
總結(jié)
本文簡(jiǎn)要介紹了pytorch模型訓(xùn)練的一個(gè)基本流程,并以CIFAR10數(shù)據(jù)集進(jìn)行了演示。
但這種方法是在CPU(device=“cpu”)上進(jìn)行訓(xùn)練的,訓(xùn)練速度比較慢,如果數(shù)據(jù)集十分龐大不建議使用這種方法,應(yīng)該在GPU(device=“cuda”)上進(jìn)行訓(xùn)練。
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
- Pytorch搭建簡(jiǎn)單的卷積神經(jīng)網(wǎng)絡(luò)(CNN)實(shí)現(xiàn)MNIST數(shù)據(jù)集分類任務(wù)
- Pytorch卷積神經(jīng)網(wǎng)絡(luò)遷移學(xué)習(xí)的目標(biāo)及好處
- Pytorch深度學(xué)習(xí)經(jīng)典卷積神經(jīng)網(wǎng)絡(luò)resnet模塊訓(xùn)練
- Pytorch卷積神經(jīng)網(wǎng)絡(luò)resent網(wǎng)絡(luò)實(shí)踐
- pytorch加載的cifar10數(shù)據(jù)集過(guò)程詳解
- Pytorch使用卷積神經(jīng)網(wǎng)絡(luò)對(duì)CIFAR10圖片進(jìn)行分類方式
相關(guān)文章
pycharm中代碼回滾到指定版本的兩種實(shí)現(xiàn)方法(附帶截圖展示)
在編寫(xiě)代碼的時(shí)候,經(jīng)常會(huì)出現(xiàn)寫(xiě)的代碼存在一些問(wèn)題,但是比較難以發(fā)現(xiàn)具體存在的問(wèn)題在哪里,需要將帶代碼恢復(fù)到指定的版本,下面這篇文章主要給大家介紹了關(guān)于pycharm中代碼回滾到指定版本的兩種實(shí)現(xiàn)方法,需要的朋友可以參考下2022-06-06python數(shù)組過(guò)濾實(shí)現(xiàn)方法
這篇文章主要介紹了python數(shù)組過(guò)濾實(shí)現(xiàn)方法,涉及Python針對(duì)數(shù)組的相關(guān)操作技巧,具有一定參考借鑒價(jià)值,需要的朋友可以參考下2015-07-07Python中條件判斷語(yǔ)句的簡(jiǎn)單使用方法
這篇文章主要介紹了Python中條件判斷語(yǔ)句的簡(jiǎn)單使用方法,是Python入門(mén)學(xué)習(xí)中的基礎(chǔ)知識(shí),需要的朋友可以參考下2015-08-08三步解決python PermissionError: [WinError 5]拒絕訪問(wèn)的情況
這篇文章主要介紹了三步解決python PermissionError: [WinError 5]拒絕訪問(wèn)的情況,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2020-04-04Python實(shí)現(xiàn)簡(jiǎn)單的"導(dǎo)彈" 自動(dòng)追蹤原理解析
這篇文章主要介紹了Python實(shí)現(xiàn)簡(jiǎn)單的"導(dǎo)彈" 自動(dòng)追蹤原理解析,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2021-03-03Pytorch中torch.repeat_interleave()函數(shù)使用及說(shuō)明
這篇文章主要介紹了Pytorch中torch.repeat_interleave()函數(shù)使用及說(shuō)明,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2023-01-01使用apiDoc實(shí)現(xiàn)python接口文檔編寫(xiě)
今天小編就為大家分享一篇使用apiDoc實(shí)現(xiàn)python接口文檔編寫(xiě),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-11-11