Pytorch搭建簡單的卷積神經(jīng)網(wǎng)絡(luò)(CNN)實(shí)現(xiàn)MNIST數(shù)據(jù)集分類任務(wù)
關(guān)于一些代碼里的解釋,可以看我上一篇發(fā)布的文章,里面有很詳細(xì)的介紹?。?!
可以依次把下面的代碼段合在一起運(yùn)行,也可以通過jupyter notebook分次運(yùn)行
第一步:基本庫的導(dǎo)入
import numpy as np import torch import torch.nn as nn import torchvision import torchvision.transforms as transforms from torch.utils.data import DataLoader import matplotlib.pyplot as plt import time np.random.seed(1234)
第二步:引用MNIST數(shù)據(jù)集,這里采用的是torchvision自帶的MNIST數(shù)據(jù)集
#這里用的是torchvision已經(jīng)封裝好的MINST數(shù)據(jù)集 trainset=torchvision.datasets.MNIST( root='MNIST', #root是下載MNIST數(shù)據(jù)集保存的路徑,可以自行修改 train=True, transform=torchvision.transforms.ToTensor(), download=True ) testset=torchvision.datasets.MNIST( root='MNIST', train=False, transform=torchvision.transforms.ToTensor(), download=True ) trainloader = DataLoader(dataset=trainset, batch_size=100, shuffle=True) #DataLoader是一個很好地能夠幫助整理數(shù)據(jù)集的類,可以用來分批次,打亂以及多線程等操作 testloader = DataLoader(dataset=testset, batch_size=100, shuffle=True)
下載之后利用DataLoader實(shí)例化為適合遍歷的訓(xùn)練集和測試集,我們把其中的某一批數(shù)據(jù)進(jìn)行可視化,下面是可視化的代碼,其實(shí)就是利用subplot畫了子圖。
#可視化某一批數(shù)據(jù) train_img,train_label=next(iter(trainloader)) #iter迭代器,可以用來便利trainloader里面每一個數(shù)據(jù),這里只迭代一次來進(jìn)行可視化 fig, axes = plt.subplots(10, 10, figsize=(10, 10)) axes_list = [] #輸入到網(wǎng)絡(luò)的圖像 for i in range(axes.shape[0]): for j in range(axes.shape[1]): axes[i, j].imshow(train_img[i*10+j,0,:,:],cmap="gray") #這里畫出來的就是我們想輸入到網(wǎng)絡(luò)里訓(xùn)練的圖像,與之對應(yīng)的標(biāo)簽用來進(jìn)行最后分類結(jié)果損失函數(shù)的計(jì)算 axes[i, j].axis("off") #對應(yīng)的標(biāo)簽 print(train_label)
第三步:用pytorch搭建簡單的卷積神經(jīng)網(wǎng)絡(luò)(CNN)
這里把卷積模塊單獨(dú)拿出來作為一個類,看上去會舒服一點(diǎn)。
#卷積模塊,由卷積核和激活函數(shù)組成 class conv_block(nn.Module): def __init__(self,ks,ch_in,ch_out): super(conv_block,self).__init__() self.conv = nn.Sequential( nn.Conv2d(ch_in, ch_out, kernel_size=ks,stride=1,padding=1,bias=True), #二維卷積核,用于提取局部的圖像信息 nn.ReLU(inplace=True), #這里用ReLU作為激活函數(shù) nn.Conv2d(ch_out, ch_out, kernel_size=ks,stride=1,padding=1,bias=True), nn.ReLU(inplace=True), ) def forward(self,x): return self.conv(x)
下面是CNN主體部分,由上面的卷積模塊和全連接分類器組合而成。這里只用了簡單的幾個卷積塊進(jìn)行堆疊,沒有采用池化以及dropout的操作。主要目的是給大家簡單搭建一下以便學(xué)習(xí)。
#常規(guī)CNN模塊(由幾個卷積模塊堆疊而成) class CNN(nn.Module): def __init__(self,kernel_size,in_ch,out_ch): super(CNN, self).__init__() feature_list = [16,32,64,128,256] #代表每一層網(wǎng)絡(luò)的特征數(shù),擴(kuò)大特征空間有助于挖掘更多的局部信息 self.conv1 = conv_block(kernel_size,in_ch,feature_list[0]) self.conv2 = conv_block(kernel_size,feature_list[0],feature_list[1]) self.conv3 = conv_block(kernel_size,feature_list[1],feature_list[2]) self.conv4 = conv_block(kernel_size,feature_list[2],feature_list[3]) self.conv5 = conv_block(kernel_size,feature_list[3],feature_list[4]) self.fc = nn.Sequential( #全連接層主要用來進(jìn)行分類,整合采集的局部信息以及全局信息 nn.Linear(feature_list[4] * 28 * 28, 1024), #此處28為MINST一張圖片的維度 nn.ReLU(), nn.Linear(1024, 512), nn.ReLU(), nn.Linear(512, 10) ) def forward(self,x): device = x.device x1 = self.conv1(x ) x2 = self.conv2(x1) x3 = self.conv3(x2) x4 = self.conv4(x3) x5 = self.conv5(x4) x5 = x5.view(x5.size()[0], -1) #全連接層相當(dāng)于做了矩陣乘法,所以這里需要將維度降維來實(shí)現(xiàn)矩陣的運(yùn)算 out = self.fc(x5) return out
第四步:訓(xùn)練以及模型保存
先是一些網(wǎng)絡(luò)參數(shù)的定義,包括優(yōu)化器,迭代輪數(shù),學(xué)習(xí)率,運(yùn)行硬件等等的確定。
#網(wǎng)絡(luò)參數(shù)定義 device = torch.device("cuda:4") #此處根據(jù)電腦配置進(jìn)行選擇,如果沒有cuda就用cpu #device = torch.device("cpu") net = CNN(3,1,1).to(device = device,dtype = torch.float32) epochs = 50 #訓(xùn)練輪次 optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-8) #使用Adam優(yōu)化器 criterion = nn.CrossEntropyLoss() #分類任務(wù)常用的交叉熵?fù)p失函數(shù) train_loss = []
然后是每一輪訓(xùn)練的主體:
# Begin training MinTrainLoss = 999 for epoch in range(1,epochs+1): total_train_loss = [] net.train() start = time.time() for input_img,label in trainloader: input_img = input_img.to(device = device,dtype=torch.float32) #我們同樣地,需要將我們?nèi)〕鰜淼挠?xùn)練集數(shù)據(jù)進(jìn)行torch能夠運(yùn)算的格式轉(zhuǎn)換 label = label.to(device = device,dtype=torch.float32) #輸入和輸出的格式都保持一致才能進(jìn)行運(yùn)算 optimizer.zero_grad() #每一次算loss前需要將之前的梯度清零,這樣才不會影響后面的更新 pred_img = net(input_img) loss = criterion(pred_img,label.long()) loss.backward() optimizer.step() total_train_loss.append(loss.item()) train_loss.append(np.mean(total_train_loss)) #將一個minibatch里面的損失取平均作為這一輪的loss end = time.time() #打印當(dāng)前的loss print("epochs[%3d/%3d] current loss: %.5f, time: %.3f"%(epoch,epochs,train_loss[-1],(end-start))) #打印每一輪訓(xùn)練的結(jié)果 if train_loss[-1]<MinTrainLoss: torch.save(net.state_dict(), "./model_min_train.pth") #保存loss最小的模型 MinTrainLoss = train_loss[-1]
以下是迭代過程:
第五步:導(dǎo)入網(wǎng)絡(luò)模型,輸入某一批測試數(shù)據(jù),查看結(jié)果
我們先來看某一批測試數(shù)據(jù)
#測試機(jī)某一批數(shù)據(jù) test_img,test_label=next(iter(testloader)) fig, axes = plt.subplots(10, 10, figsize=(10, 10)) axes_list = [] #輸入到網(wǎng)絡(luò)的圖像 for i in range(axes.shape[0]): for j in range(axes.shape[1]): axes[i, j].imshow(test_img[i*10+j,0,:,:],cmap="gray") axes[i, j].axis("off")
然后將其輸入到訓(xùn)練好的模型進(jìn)行預(yù)測
#預(yù)測我拿出來的那一批數(shù)據(jù)進(jìn)行展示 cnn = CNN(3,1,1).to(device = device,dtype = torch.float32) cnn.load_state_dict(torch.load("./model_min_train.pth", map_location=device)) #導(dǎo)入我們之前已經(jīng)訓(xùn)練好的模型 cnn.eval() #評估模式 test_img = test_img.to(device = device,dtype = torch.float32) test_label = test_label.to(device = device,dtype = torch.float32) pred_test = cnn(test_img) #記住,輸出的結(jié)果是一個長度為10的tensor test_pred = np.argmax(pred_test.cpu().data.numpy(), axis=1) #所以我們需要對其進(jìn)行最大值對應(yīng)索引的處理,從而得到我們想要的預(yù)測結(jié)果 #預(yù)測結(jié)果以及標(biāo)簽 print("預(yù)測結(jié)果") print(test_pred) print("標(biāo)簽") print(test_label.cpu().data.numpy())
從預(yù)測的結(jié)果我們可以看到,整體上這么一個簡單的CNN搭配全連接分類器對MNIST這一批數(shù)據(jù)分類的效果還不錯。當(dāng)然,我這里只用了交叉熵?fù)p失函數(shù),并且沒有計(jì)算準(zhǔn)確率,僅供大家對于CNN學(xué)習(xí)和參考。
到此這篇關(guān)于Pytorch搭建簡單的卷積神經(jīng)網(wǎng)絡(luò)(CNN)實(shí)現(xiàn)MNIST數(shù)據(jù)集分類任務(wù)的文章就介紹到這了,更多相關(guān)Pytorch卷積神經(jīng)網(wǎng)絡(luò)內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
python 控制Asterisk AMI接口外呼電話的例子
今天小編就為大家分享一篇python 控制Asterisk AMI接口外呼電話的例子,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2019-08-08Tensorflow訓(xùn)練模型越來越慢的2種解決方案
今天小編就為大家分享一篇Tensorflow訓(xùn)練模型越來越慢的2種解決方案,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-02-02python實(shí)現(xiàn)bilibili動畫下載視頻批量改名功能
這篇文章主要介紹了python實(shí)現(xiàn)bilibili動畫下載視頻批量改名,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下2021-11-11