欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch搭建簡單的卷積神經(jīng)網(wǎng)絡(luò)(CNN)實(shí)現(xiàn)MNIST數(shù)據(jù)集分類任務(wù)

 更新時間:2023年03月23日 10:04:59   作者:無知的吱屋  
這篇文章主要介紹了Pytorch搭建簡單的卷積神經(jīng)網(wǎng)絡(luò)(CNN)實(shí)現(xiàn)MNIST數(shù)據(jù)集分類任務(wù),本文給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下

關(guān)于一些代碼里的解釋,可以看我上一篇發(fā)布的文章,里面有很詳細(xì)的介紹?。?!

可以依次把下面的代碼段合在一起運(yùn)行,也可以通過jupyter notebook分次運(yùn)行

第一步:基本庫的導(dǎo)入

import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time
np.random.seed(1234)

第二步:引用MNIST數(shù)據(jù)集,這里采用的是torchvision自帶的MNIST數(shù)據(jù)集

#這里用的是torchvision已經(jīng)封裝好的MINST數(shù)據(jù)集
trainset=torchvision.datasets.MNIST(
    root='MNIST',  #root是下載MNIST數(shù)據(jù)集保存的路徑,可以自行修改
    train=True,
    transform=torchvision.transforms.ToTensor(),
    download=True
)
 
testset=torchvision.datasets.MNIST(
    root='MNIST',
    train=False,
    transform=torchvision.transforms.ToTensor(),
    download=True
)
 
trainloader = DataLoader(dataset=trainset, batch_size=100, shuffle=True)   #DataLoader是一個很好地能夠幫助整理數(shù)據(jù)集的類,可以用來分批次,打亂以及多線程等操作
testloader = DataLoader(dataset=testset, batch_size=100, shuffle=True)

下載之后利用DataLoader實(shí)例化為適合遍歷的訓(xùn)練集和測試集,我們把其中的某一批數(shù)據(jù)進(jìn)行可視化,下面是可視化的代碼,其實(shí)就是利用subplot畫了子圖。

#可視化某一批數(shù)據(jù)
train_img,train_label=next(iter(trainloader))   #iter迭代器,可以用來便利trainloader里面每一個數(shù)據(jù),這里只迭代一次來進(jìn)行可視化
fig, axes = plt.subplots(10, 10, figsize=(10, 10))
axes_list = []
#輸入到網(wǎng)絡(luò)的圖像
for i in range(axes.shape[0]):
    for j in range(axes.shape[1]):
        axes[i, j].imshow(train_img[i*10+j,0,:,:],cmap="gray")    #這里畫出來的就是我們想輸入到網(wǎng)絡(luò)里訓(xùn)練的圖像,與之對應(yīng)的標(biāo)簽用來進(jìn)行最后分類結(jié)果損失函數(shù)的計(jì)算
        axes[i, j].axis("off")
#對應(yīng)的標(biāo)簽
print(train_label)

 第三步:用pytorch搭建簡單的卷積神經(jīng)網(wǎng)絡(luò)(CNN)

 這里把卷積模塊單獨(dú)拿出來作為一個類,看上去會舒服一點(diǎn)。

#卷積模塊,由卷積核和激活函數(shù)組成
class conv_block(nn.Module):
    def __init__(self,ks,ch_in,ch_out):
        super(conv_block,self).__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(ch_in, ch_out, kernel_size=ks,stride=1,padding=1,bias=True),  #二維卷積核,用于提取局部的圖像信息
            nn.ReLU(inplace=True), #這里用ReLU作為激活函數(shù)
            nn.Conv2d(ch_out, ch_out, kernel_size=ks,stride=1,padding=1,bias=True),
            nn.ReLU(inplace=True),
        )
    def forward(self,x):
        return self.conv(x)

下面是CNN主體部分,由上面的卷積模塊和全連接分類器組合而成。這里只用了簡單的幾個卷積塊進(jìn)行堆疊,沒有采用池化以及dropout的操作。主要目的是給大家簡單搭建一下以便學(xué)習(xí)。

#常規(guī)CNN模塊(由幾個卷積模塊堆疊而成)
class CNN(nn.Module):
    def __init__(self,kernel_size,in_ch,out_ch):
        super(CNN, self).__init__()
        feature_list = [16,32,64,128,256]   #代表每一層網(wǎng)絡(luò)的特征數(shù),擴(kuò)大特征空間有助于挖掘更多的局部信息
        self.conv1 = conv_block(kernel_size,in_ch,feature_list[0])
        self.conv2 = conv_block(kernel_size,feature_list[0],feature_list[1])
        self.conv3 = conv_block(kernel_size,feature_list[1],feature_list[2])
        self.conv4 = conv_block(kernel_size,feature_list[2],feature_list[3])
        self.conv5 = conv_block(kernel_size,feature_list[3],feature_list[4])
        self.fc =  nn.Sequential(           #全連接層主要用來進(jìn)行分類,整合采集的局部信息以及全局信息
            nn.Linear(feature_list[4] * 28 * 28, 1024),  #此處28為MINST一張圖片的維度
            nn.ReLU(),
            nn.Linear(1024, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )
 
    def forward(self,x):
        device = x.device
        x1 = self.conv1(x )
        x2 = self.conv2(x1)
        x3 = self.conv3(x2)
        x4 = self.conv4(x3)
        x5 = self.conv5(x4)
        x5 = x5.view(x5.size()[0], -1)  #全連接層相當(dāng)于做了矩陣乘法,所以這里需要將維度降維來實(shí)現(xiàn)矩陣的運(yùn)算
        out = self.fc(x5)
        return out

第四步:訓(xùn)練以及模型保存

先是一些網(wǎng)絡(luò)參數(shù)的定義,包括優(yōu)化器,迭代輪數(shù),學(xué)習(xí)率,運(yùn)行硬件等等的確定。

#網(wǎng)絡(luò)參數(shù)定義
device = torch.device("cuda:4")  #此處根據(jù)電腦配置進(jìn)行選擇,如果沒有cuda就用cpu
#device = torch.device("cpu")
net = CNN(3,1,1).to(device = device,dtype = torch.float32)
epochs = 50  #訓(xùn)練輪次
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-8)  #使用Adam優(yōu)化器
criterion = nn.CrossEntropyLoss()  #分類任務(wù)常用的交叉熵?fù)p失函數(shù)
train_loss = []

然后是每一輪訓(xùn)練的主體:

# Begin training
MinTrainLoss = 999
for epoch in range(1,epochs+1):
    total_train_loss = []      
    net.train()
    start = time.time()
    for input_img,label in trainloader:
        input_img = input_img.to(device = device,dtype=torch.float32)  #我們同樣地,需要將我們?nèi)〕鰜淼挠?xùn)練集數(shù)據(jù)進(jìn)行torch能夠運(yùn)算的格式轉(zhuǎn)換
        label = label.to(device = device,dtype=torch.float32)          #輸入和輸出的格式都保持一致才能進(jìn)行運(yùn)算
        optimizer.zero_grad()  #每一次算loss前需要將之前的梯度清零,這樣才不會影響后面的更新
        pred_img = net(input_img) 
        loss = criterion(pred_img,label.long())
        loss.backward()
        optimizer.step()
        total_train_loss.append(loss.item())
    train_loss.append(np.mean(total_train_loss))    #將一個minibatch里面的損失取平均作為這一輪的loss
    end = time.time()
    #打印當(dāng)前的loss
    print("epochs[%3d/%3d] current loss: %.5f, time: %.3f"%(epoch,epochs,train_loss[-1],(end-start)))   #打印每一輪訓(xùn)練的結(jié)果
    
    if train_loss[-1]<MinTrainLoss:
        torch.save(net.state_dict(), "./model_min_train.pth")  #保存loss最小的模型
        MinTrainLoss = train_loss[-1]

以下是迭代過程:

 第五步:導(dǎo)入網(wǎng)絡(luò)模型,輸入某一批測試數(shù)據(jù),查看結(jié)果

我們先來看某一批測試數(shù)據(jù)

#測試機(jī)某一批數(shù)據(jù)
test_img,test_label=next(iter(testloader))
fig, axes = plt.subplots(10, 10, figsize=(10, 10))
axes_list = []
#輸入到網(wǎng)絡(luò)的圖像
for i in range(axes.shape[0]):
    for j in range(axes.shape[1]):
        axes[i, j].imshow(test_img[i*10+j,0,:,:],cmap="gray")
        axes[i, j].axis("off")

然后將其輸入到訓(xùn)練好的模型進(jìn)行預(yù)測

#預(yù)測我拿出來的那一批數(shù)據(jù)進(jìn)行展示
cnn = CNN(3,1,1).to(device = device,dtype = torch.float32)
cnn.load_state_dict(torch.load("./model_min_train.pth", map_location=device)) #導(dǎo)入我們之前已經(jīng)訓(xùn)練好的模型
cnn.eval()   #評估模式
 
test_img = test_img.to(device = device,dtype = torch.float32)
test_label = test_label.to(device = device,dtype = torch.float32)
 
pred_test = cnn(test_img)  #記住,輸出的結(jié)果是一個長度為10的tensor
test_pred = np.argmax(pred_test.cpu().data.numpy(), axis=1)  #所以我們需要對其進(jìn)行最大值對應(yīng)索引的處理,從而得到我們想要的預(yù)測結(jié)果
 
#預(yù)測結(jié)果以及標(biāo)簽
print("預(yù)測結(jié)果")
print(test_pred)
print("標(biāo)簽")
print(test_label.cpu().data.numpy())

從預(yù)測的結(jié)果我們可以看到,整體上這么一個簡單的CNN搭配全連接分類器對MNIST這一批數(shù)據(jù)分類的效果還不錯。當(dāng)然,我這里只用了交叉熵?fù)p失函數(shù),并且沒有計(jì)算準(zhǔn)確率,僅供大家對于CNN學(xué)習(xí)和參考。

到此這篇關(guān)于Pytorch搭建簡單的卷積神經(jīng)網(wǎng)絡(luò)(CNN)實(shí)現(xiàn)MNIST數(shù)據(jù)集分類任務(wù)的文章就介紹到這了,更多相關(guān)Pytorch卷積神經(jīng)網(wǎng)絡(luò)內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • python 控制Asterisk AMI接口外呼電話的例子

    python 控制Asterisk AMI接口外呼電話的例子

    今天小編就為大家分享一篇python 控制Asterisk AMI接口外呼電話的例子,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-08-08
  • 使用Python寫個小監(jiān)控

    使用Python寫個小監(jiān)控

    最近使用python寫了個小監(jiān)控,為什么使用python?簡單、方便、好管理,Python如何實(shí)現(xiàn)簡單的小監(jiān)控,感興趣的小伙伴們可以參考一下
    2016-01-01
  • 用python對excel查重

    用python對excel查重

    這篇文章主要介紹了用python對excel查重的方法,幫助大家更好的利用python處理excel表格,感興趣的朋友可以了解下
    2020-12-12
  • Python解惑之整數(shù)比較詳解

    Python解惑之整數(shù)比較詳解

    這篇文章主要給大家介紹了Python中整數(shù)比較的相關(guān)資料,文中通過示例代碼介紹的非常詳細(xì),詳細(xì)會對大家學(xué)習(xí)python的整數(shù)具有一定的參考價值,需要的朋友下面跟著小編一起來學(xué)習(xí)學(xué)習(xí)吧。
    2017-04-04
  • Python進(jìn)行密碼學(xué)反向密碼教程

    Python進(jìn)行密碼學(xué)反向密碼教程

    這篇文章主要為大家介紹了Python進(jìn)行密碼學(xué)反向密碼的教程詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2022-05-05
  • Tensorflow訓(xùn)練模型越來越慢的2種解決方案

    Tensorflow訓(xùn)練模型越來越慢的2種解決方案

    今天小編就為大家分享一篇Tensorflow訓(xùn)練模型越來越慢的2種解決方案,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-02-02
  • python將每個單詞按空格分開并保存到文件中

    python將每個單詞按空格分開并保存到文件中

    這篇文章主要介紹了python將每個單詞按空格分開并保存到文件中,需要的朋友可以參考下
    2018-03-03
  • python實(shí)現(xiàn)bilibili動畫下載視頻批量改名功能

    python實(shí)現(xiàn)bilibili動畫下載視頻批量改名功能

    這篇文章主要介紹了python實(shí)現(xiàn)bilibili動畫下載視頻批量改名,本文通過實(shí)例代碼給大家介紹的非常詳細(xì),對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2021-11-11
  • Python字符串格式化方式

    Python字符串格式化方式

    這篇文章主要介紹了Python字符串格式化方式,字符串格式化在我們的開發(fā)過程中被廣泛的應(yīng)用,因此也是我們要重點(diǎn)掌握的內(nèi)容之一,下文相關(guān)介紹,需要的朋友可以參考一下
    2022-04-04
  • python 表格打印代碼實(shí)例解析

    python 表格打印代碼實(shí)例解析

    這篇文章主要介紹了python 表格打印代碼實(shí)例解析,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下
    2019-10-10

最新評論