欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Python 實現(xiàn)LeNet網(wǎng)絡(luò)模型的訓(xùn)練及預(yù)測

 更新時間:2021年11月23日 17:02:28   作者:Serins  
本文將為大家詳細講解如何使用CIFR10數(shù)據(jù)集訓(xùn)練模型以及用訓(xùn)練好的模型做預(yù)測。代碼具有一定價值,感興趣的小伙伴可以學(xué)習(xí)一下

1.LeNet模型訓(xùn)練腳本

整體的訓(xùn)練代碼如下,下面我會為大家詳細講解這些代碼的意思

import torch
import torchvision
from torchvision.transforms import transforms
import torch.nn as nn
from torch.utils.data import DataLoader
from pytorch.lenet.model import LeNet
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt

transform = transforms.Compose(
    # 將數(shù)據(jù)集轉(zhuǎn)換成tensor形式
    [transforms.ToTensor(),
     # 進行標準化,0.5是均值,也是方差,對應(yīng)三個維度都是0.5
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

# 下載完整的數(shù)據(jù)集時,download=True,第一個為保存的路徑,下載完后download要改為False
# 為訓(xùn)練集時,train=True,為測試集時,train=False
train_set = torchvision.datasets.CIFAR10('./data', train=True,
                                         download=False, transform=transform)

# 加載訓(xùn)練集,設(shè)置批次大小,是否打亂,number_works是線程數(shù),window不設(shè)置為0會報錯,linux可以設(shè)置非零
train_loader = DataLoader(train_set, batch_size=36,
                          shuffle=True, num_workers=0)

test_set = torchvision.datasets.CIFAR10('./data', train=False,
                                        download=False, transform=transform)
# 設(shè)置的批次大小一次性將所有測試集圖片傳進去
test_loader = DataLoader(test_set, batch_size=10000,
                         shuffle=False, num_workers=0)

# 迭代測試集的圖片數(shù)據(jù)和標簽值
test_img, test_label = next(iter(test_loader))

# CIFAR10的十個類別名稱
classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

# # ----------------------------顯示圖片-----------------------------------
# def imshow(img, label):
#     fig = plt.figure()
#     for i in range(len(img)):
#         ax = fig.add_subplot(1, len(img), i+1)
#         nping = img[i].numpy().transpose([1, 2, 0])
#         npimg = (nping * 2 + 0.5)
#         plt.imshow(npimg)
#         title = '{}'.format(classes[label[i]])
#         ax.set_title(title)
#         plt.axis('off')
#     plt.show()
# 
# 
# batch_image = test_img[: 5]
# label_img = test_label[: 5]
# imshow(batch_image, label_img)
# # ----------------------------------------------------------------------

net = LeNet()
# 定義損失函數(shù),nn.CrossEntropyLoss()自帶softmax函數(shù),所以模型的最后一層不需要softmax進行激活
loss_function = nn.CrossEntropyLoss()
# 定義優(yōu)化器,優(yōu)化網(wǎng)絡(luò)模型所有參數(shù)
optimizer = optim.Adam(net.parameters(), lr=0.001)

# 迭代五次
for epoch in range(5):
    # 初始損失設(shè)置為0
    running_loss = 0
    # 循環(huán)訓(xùn)練集,從1開始
    for step, data in enumerate(train_loader, start=1):
        inputs, labels = data
        # 優(yōu)化器的梯度清零,每次循環(huán)都需要清零,否則梯度會無限疊加,相當(dāng)于增加批次大小
        optimizer.zero_grad()
        # 將圖片數(shù)據(jù)輸入模型中
        outputs = net(inputs)
        # 傳入預(yù)測值和真實值,計算當(dāng)前損失值
        loss = loss_function(outputs, labels)
        # 損失反向傳播
        loss.backward()
        # 進行梯度更新
        optimizer.step()
        # 計算該輪的總損失,因為loss是tensor類型,所以需要用item()取具體值
        running_loss += loss.item()
        # 每500次進行日志的打印,對測試集進行預(yù)測
        if step % 500 == 0:
            # torch.no_grad()就是上下文管理,測試時不需要梯度更新,不跟蹤梯度
            with torch.no_grad():
                # 傳入所有測試集圖片進行預(yù)測
                outputs = net(test_img)
                # torch.max()中dim=1是因為結(jié)果為(batch, 10)的形式,我們只需要取第二個維度的最大值
                # max這個函數(shù)返回[最大值, 最大值索引],我們只需要取索引就行了,所以用[1]
                predict_y = torch.max(outputs, dim=1)[1]
                # (predict_y == test_label)相同返回True,不相等返回False,sum()對正確率進行疊加
                # 因為計算的變量都是tensor,所以需要用item()拿到取值
                accuracy = (predict_y == test_label).sum().item() / test_label.size(0)
                # running_loss/500是計算每一個step的loss,即每一步的損失
                print('[%d, %5d] train_loss: %.3f   test_accuracy: %.3f' %
                      (epoch+1, step, running_loss/500, accuracy))
                running_loss = 0.0

print('Finished Training!')

save_path = 'lenet.pth'
# 保存模型,字典形式
torch.save(net.state_dict(), save_path)

(1).下載CIFAR10數(shù)據(jù)集

首先要訓(xùn)練一個網(wǎng)絡(luò)模型,我們需要足夠多的圖片做數(shù)據(jù)集,這里我們用的是torchvision.dataset為我們提供的CIFAR10數(shù)據(jù)集(更多的數(shù)據(jù)集可以去pytorch官網(wǎng)查看pytorch官網(wǎng)提供的數(shù)據(jù)集)

train_set = torchvision.datasets.CIFAR10('./data', train=True,
                                         download=False, transform=transform)
test_set = torchvision.datasets.CIFAR10('./data', train=False,
                                        download=False, transform=transform)

這部分代碼是下載CIFAR10,第一個參數(shù)是下載數(shù)據(jù)集后存放的路徑,train=True和False對應(yīng)下載的訓(xùn)練集和測試集,transform是對應(yīng)的圖像增強方式

(2).圖像增強

transform = transforms.Compose(
    # 將數(shù)據(jù)集轉(zhuǎn)換成tensor形式
    [transforms.ToTensor(),
     # 進行標準化,0.5是均值,也是方差,對應(yīng)三個維度都是0.5
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

這就是簡單的圖像圖像增強,transforms.ToTensor()將數(shù)據(jù)集的所有圖像轉(zhuǎn)換成tensor, transforms.Normalize()是標準化處理,包含兩個元組對應(yīng)均值和標準差,每個元組包含三個元素對應(yīng)圖片的三個維度[channels, height, width],為什么是這樣排序,別問,問就是pytorch要求的,順序不能變,之后會看到transforms.Normalize([0.485, 0.406, 0.456], [0.229, 0.224, 0.225])這兩組數(shù)據(jù),這是官方給出的均值和標準差,之后標準化的時候會經(jīng)常用到

(3).加載數(shù)據(jù)集

# 加載訓(xùn)練集,設(shè)置批次大小,是否打亂,number_works是線程數(shù),window不設(shè)置為0會報錯,linux可以設(shè)置非零
train_loader = DataLoader(dataset=train_set, batch_size=36,
                          shuffle=True, num_workers=0)
test_loader = DataLoader(dataset=test_set, batch_size=36,
                         shuffle=False, num_workers=0)

這里只簡單的設(shè)置的四個參數(shù)也是比較重要的,第一個就是需要加載的訓(xùn)練集和測試集,shuffle=True表示將數(shù)據(jù)集打亂,batch_size表示一次性向設(shè)備放入36張圖片,打包成一個batch,這時圖片的shape就會從[3, 32, 32]----》[36, 3, 32, 32],傳入網(wǎng)絡(luò)模型的shape也必須是[None, channels, height, width],None代表一個batch多少張圖片,否則就會報錯,number_works是代表線程數(shù),window系統(tǒng)必須設(shè)置為0,否則會報錯,linux系統(tǒng)可以設(shè)置非0數(shù)

(4).顯示部分圖像

def imshow(img, label):
    fig = plt.figure()
    for i in range(len(img)):
        ax = fig.add_subplot(1, len(img), i+1)
        nping = img[i].numpy().transpose([1, 2, 0])
        npimg = (nping * 2 + 0.5)
        plt.imshow(npimg)
        title = '{}'.format(classes[label[i]])
        ax.set_title(title)
        plt.axis('off')
    plt.show()


batch_image = test_img[: 5]
label_img = test_label[: 5]
imshow(batch_image, label_img)

這部分代碼是顯示測試集當(dāng)中前五張圖片,運行后會顯示5張拼接的圖片

由于這個數(shù)據(jù)集的圖片都比較小都是32x32的尺寸,有些可能也看的不太清楚,圖中顯示的是真實標簽,注:顯示圖片的代碼可能會這個報警(Clipping input data to the valid range for imshow with RGB data ([0…1] for floats or [0…255] for integers).),警告解決的方法:將圖片數(shù)組轉(zhuǎn)成uint8類型即可,即 plt.imshow(npimg.astype(‘uint8'),但是那樣顯示出來的圖片會變,所以暫時可以先不用管。

(5).初始化模型

數(shù)據(jù)圖片處理完了,下面就是我們的正式訓(xùn)練過程

net = LeNet()
# 定義損失函數(shù),nn.CrossEntropyLoss()自帶softmax函數(shù),所以模型的最后一層不需要softmax進行激活
loss_function = nn.CrossEntropyLoss()
# 定義優(yōu)化器,優(yōu)化模型所有參數(shù)
optimizer = optim.Adam(net.parameters(), lr=0.001)

首先初始化LeNet網(wǎng)絡(luò),定義交叉熵損失函數(shù),以及Adam優(yōu)化器,關(guān)于注釋寫的,我們可以ctrl+鼠標左鍵查看CrossEntropyLoss(),翻到CrossEntropyLoss類,可以看到注釋寫的這個標準包含LogSoftmax函數(shù),所以搭建LetNet模型的最后一層沒有使用softmax激活函數(shù)

(6).訓(xùn)練模型及保存模型參數(shù)

for epoch in range(5):
    # 初始損失設(shè)置為0
    running_loss = 0
    # 循環(huán)訓(xùn)練集,從1開始
    for step, data in enumerate(train_loader, start=1):
        inputs, labels = data
        # 優(yōu)化器的梯度清零,每次循環(huán)都需要清零,否則梯度會無限疊加,相當(dāng)于增加批次大小
        optimizer.zero_grad()
        # 將圖片數(shù)據(jù)輸入模型中得到輸出
        outputs = net(inputs)
        # 傳入預(yù)測值和真實值,計算當(dāng)前損失值
        loss = loss_function(outputs, labels)
        # 損失反向傳播
        loss.backward()
        # 進行梯度更新(更新W,b)
        optimizer.step()
        # 計算該輪的總損失,因為loss是tensor類型,所以需要用item()取到值
        running_loss += loss.item()
        # 每500次進行日志的打印,對測試集進行測試
        if step % 500 == 0:
            # torch.no_grad()就是上下文管理,測試時不需要梯度更新,不跟蹤梯度
            with torch.no_grad():
                # 傳入所有測試集圖片進行預(yù)測
                outputs = net(test_img)
                # torch.max()中dim=1是因為結(jié)果為(batch, 10)的形式,我們只需要取第二個維度的最大值,第二個維度是包含十個類別每個類別的概率的向量
                # max這個函數(shù)返回[最大值, 最大值索引],我們只需要取索引就行了,所以用[1]
                predict_y = torch.max(outputs, dim=1)[1]
                # (predict_y == test_label)相同返回True,不相等返回False,sum()對正確結(jié)果進行疊加,最后除測試集標簽的總個數(shù)
                # 因為計算的變量都是tensor,所以需要用item()拿到取值
                accuracy = (predict_y == test_label).sum().item() / test_label.size(0)
                # running_loss/500是計算每一個step的loss,即每一步的損失
                print('[%d, %5d] train_loss: %.3f   test_accuracy: %.3f' %
                      (epoch+1, step, running_loss/500, accuracy))
                running_loss = 0.0
                
print('Finished Training!')

save_path = 'lenet.pth'
# 保存模型,字典形式
torch.save(net.state_dict(), save_path)

這段代碼注釋寫的很清楚,大家仔細看就能看懂,流程不復(fù)雜,多看幾遍就能理解,最后再對訓(xùn)練好的模型進行保存就好了(* ̄︶ ̄)

2.預(yù)測腳本

上面已經(jīng)訓(xùn)練好了模型,得到了lenet.pth參數(shù)文件,預(yù)測就很簡單了,可以去網(wǎng)上隨便找一張數(shù)據(jù)集包含的類別圖片,將模型參數(shù)文件載入模型,通過對圖像進行一點處理,喂入模型即可,下面奉上代碼:

import torch
import numpy as np
import torchvision.transforms as transforms
from PIL import Image
from pytorch.lenet.model import LeNet

classes = ('plane', 'car', 'bird', 'cat', 'deer',
           'dog', 'frog', 'horse', 'ship', 'truck')

transforms = transforms.Compose(
    # 對數(shù)據(jù)圖片調(diào)整大小
    [transforms.Resize([32, 32]),
     transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

net = LeNet()
# 加載預(yù)訓(xùn)練模型
net.load_state_dict(torch.load('lenet.pth'))
# 網(wǎng)上隨便找的貓的圖片
img_path = '../../Photo/cat2.jpg'
img = Image.open(img_path)
# 圖片的處理
img = transforms(img)
# 增加一個維度,(channels, height, width)------->(batch, channels, height, width),pytorch要求必須輸入這樣的shape
img = torch.unsqueeze(img, dim=0)

with torch.no_grad():
    output = net(img)
    # dim=1,只取[batch, 10]中10個類別的那個維度,取預(yù)測結(jié)果的最大值索引,并轉(zhuǎn)換為numpy類型
    prediction1 = torch.max(output, dim=1)[1].data.numpy()
    # 用softmax()預(yù)測出一個概率矩陣
    prediction2 = torch.softmax(output, dim=1)
    # 得到概率最大的值得索引
    prediction2 = np.argmax(prediction2)
# 兩種方式都可以得到最后的結(jié)果
print(classes[int(prediction1)])
print(classes[int(prediction2)])

反正我最后預(yù)測出來結(jié)果把貓識別成了狗,還有90.01%的概率,就離譜哈哈哈,但也說明了LeNet這個網(wǎng)絡(luò)模型確實很淺,特征提取的不夠深,才會出現(xiàn)這種。

到此這篇關(guān)于Python 實現(xiàn)LeNet網(wǎng)絡(luò)模型的訓(xùn)練及預(yù)測的文章就介紹到這了,更多相關(guān)LeNet網(wǎng)絡(luò)模型訓(xùn)練及預(yù)測內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

最新評論