欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch寫數字識別LeNet模型

 更新時間:2022年01月26日 17:54:21   作者:Jokic_Rn?  
這篇文章主要介紹了Pytorch寫數字識別LeNet模型,LeNet-5是一個較簡單的卷積神經網絡,??LeNet-5?這個網絡雖然很小,但是它包含了深度學習的基本模塊:卷積層,池化層,全連接層。是其他深度學習模型的基礎,?這里我們對LeNet-5進行深入分析,需要的朋友可以參考下

LeNet網絡

LeNet網絡過卷積層時候保持分辨率不變,過池化層時候分辨率變小。實現(xiàn)如下

from PIL import Image
import cv2
import matplotlib.pyplot as plt
import torchvision
from torchvision import transforms
import torch
from torch.utils.data import DataLoader
import torch.nn as nn
import numpy as np
import tqdm as tqdm

class LeNet(nn.Module):
? ? def __init__(self) -> None:
? ? ? ? super().__init__()
? ? ? ? self.sequential = nn.Sequential(nn.Conv2d(1,6,kernel_size=5,padding=2),nn.Sigmoid(),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.AvgPool2d(kernel_size=2,stride=2),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Conv2d(6,16,kernel_size=5),nn.Sigmoid(),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.AvgPool2d(kernel_size=2,stride=2),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Flatten(),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(16*25,120),nn.Sigmoid(),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(120,84),nn.Sigmoid(),
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(84,10))
? ? ? ??
? ??
? ? def forward(self,x):
? ? ? ? return self.sequential(x)

class MLP(nn.Module):
? ? def __init__(self) -> None:
? ? ? ? super().__init__()
? ? ? ? self.sequential = nn.Sequential(nn.Flatten(),
? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(28*28,120),nn.Sigmoid(),
? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(120,84),nn.Sigmoid(),
? ? ? ? ? ? ? ? ? ? ? ? ? nn.Linear(84,10))
? ? ? ??
? ??
? ? def forward(self,x):
? ? ? ? return self.sequential(x)

epochs = 15
batch = 32
lr=0.9
loss = nn.CrossEntropyLoss()
model = LeNet()
optimizer = torch.optim.SGD(model.parameters(),lr)
device = torch.device('cuda')
root = r"./"
trans_compose ?= transforms.Compose([transforms.ToTensor(),
? ? ? ? ? ? ? ? ? ? ])
train_data = torchvision.datasets.MNIST(root,train=True,transform=trans_compose,download=True)
test_data = torchvision.datasets.MNIST(root,train=False,transform=trans_compose,download=True)
train_loader = DataLoader(train_data,batch_size=batch,shuffle=True)
test_loader = DataLoader(test_data,batch_size=batch,shuffle=False)

model.to(device)
loss.to(device)
# model.apply(init_weights)
for epoch in range(epochs):
? train_loss = 0
? test_loss = 0
? correct_train = 0
? correct_test = 0
? for index,(x,y) in enumerate(train_loader):
? ? x = x.to(device)
? ? y = y.to(device)
? ? predict = model(x)
? ? L = loss(predict,y)
? ? optimizer.zero_grad()
? ? L.backward()
? ? optimizer.step()
? ? train_loss = train_loss + L
? ? correct_train += (predict.argmax(dim=1)==y).sum()
? acc_train = correct_train/(batch*len(train_loader))
? with torch.no_grad():
? ? for index,(x,y) in enumerate(test_loader):
? ? ? [x,y] = [x.to(device),y.to(device)]
? ? ? predict = model(x)
? ? ? L1 = loss(predict,y)
? ? ? test_loss = test_loss + L1
? ? ? correct_test += (predict.argmax(dim=1)==y).sum()
? ? acc_test = correct_test/(batch*len(test_loader))
? print(f'epoch:{epoch},train_loss:{train_loss/batch},test_loss:{test_loss/batch},acc_train:{acc_train},acc_test:{acc_test}')

訓練結果

epoch:12,train_loss:2.235553741455078,test_loss:0.3947642743587494,acc_train:0.9879833459854126,acc_test:0.9851238131523132
epoch:13,train_loss:2.028963804244995,test_loss:0.3220392167568207,acc_train:0.9891499876976013,acc_test:0.9875199794769287
epoch:14,train_loss:1.8020273447036743,test_loss:0.34837451577186584,acc_train:0.9901833534240723,acc_test:0.98702073097229

泛化能力測試

找了一張圖片,將其分割成只含一個數字的圖片進行測試

images_np = cv2.imread("/content/R-C.png",cv2.IMREAD_GRAYSCALE)
h,w = images_np.shape
images_np = np.array(255*torch.ones(h,w))-images_np#圖片反色
images = Image.fromarray(images_np)
plt.figure(1)
plt.imshow(images)
test_images = []
for i in range(10):
? for j in range(16):
? ? test_images.append(images_np[h//10*i:h//10+h//10*i,w//16*j:w//16*j+w//16])
sample = test_images[77]
sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)
sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))
predict = model(sample_tensor)
output = predict.argmax()
print(output)
plt.figure(2)
plt.imshow(np.array(sample_tensor.squeeze().to('cpu')))

此時預測結果為4,預測正確。從這段代碼中可以看到有一個反色的步驟,若不反色,結果會受到影響,如下圖所示,預測為0,錯誤。
模型用于輸入的圖片是單通道的黑白圖片,這里由于可視化出現(xiàn)了黃色,但實際上是黑白色,反色操作說明了數據的預處理十分的重要,很多數據如果是不清理過是無法直接用于推理的。

將所有用來泛化性測試的圖片進行準確率測試:

correct = 0
i = 0
cnt = 1
for sample in test_images:
? sample_tensor = torch.tensor(sample).unsqueeze(0).unsqueeze(0).type(torch.FloatTensor).to(device)
? sample_tensor = torch.nn.functional.interpolate(sample_tensor,(28,28))
? predict = model(sample_tensor)
? output = predict.argmax()
? if(output==i):
? ? correct+=1
? if(cnt%16==0):
? ? i+=1
? cnt+=1
acc_g = correct/len(test_images)
print(f'acc_g:{acc_g}')

如果不反色,acc_g=0.15

acc_g:0.50625

到此這篇關于Pytorch寫數字識別LeNet模型的文章就介紹到這了,更多相關Pytorch寫數字識別LeNet模型內容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!

相關文章

  • Python實現(xiàn)將一個大文件按段落分隔為多個小文件的簡單操作方法

    Python實現(xiàn)將一個大文件按段落分隔為多個小文件的簡單操作方法

    這篇文章主要介紹了Python實現(xiàn)將一個大文件按段落分隔為多個小文件的簡單操作方法,涉及Python針對文件的讀取、遍歷、轉換、寫入等相關操作技巧,需要的朋友可以參考下
    2017-04-04
  • Ubuntu18.04安裝 PyCharm并使用 Anaconda 管理的Python環(huán)境

    Ubuntu18.04安裝 PyCharm并使用 Anaconda 管理的Python環(huán)境

    這篇文章主要介紹了Ubuntu18.04安裝 PyCharm并使用 Anaconda 管理的Python環(huán)境的教程,本文給大家介紹的非常詳細,對大家的學習或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2020-04-04
  • Keras神經網絡efficientnet模型搭建yolov3目標檢測平臺

    Keras神經網絡efficientnet模型搭建yolov3目標檢測平臺

    這篇文章主要為大家介紹了Keras利用efficientnet系列模型搭建yolov3目標檢測平臺的過程詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪
    2022-05-05
  • python中驗證碼連通域分割的方法詳解

    python中驗證碼連通域分割的方法詳解

    這篇文章主要給大家介紹了關于python中驗證碼連通域分割的相關資料,文中通過示例代碼介紹的非常詳細,對大家學習或者使用python具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2018-06-06
  • 你真的了解Python的random模塊嗎?

    你真的了解Python的random模塊嗎?

    這篇文章主要介紹了Python的random模塊的相關內容,具有一定借鑒價值,需要的朋友可以參考下。
    2017-12-12
  • python字典的值可以修改嗎

    python字典的值可以修改嗎

    在本篇文章里小編給大家分享的是一篇關于python字典的值修改的方法步驟,需要的朋友們可以學習下。
    2020-06-06
  • Pandas刪除數據的幾種情況(小結)

    Pandas刪除數據的幾種情況(小結)

    這篇文章主要介紹了Pandas刪除數據的幾種情況(小結),詳細的介紹了4種方式,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2019-06-06
  • Python實現(xiàn)12306火車票搶票系統(tǒng)

    Python實現(xiàn)12306火車票搶票系統(tǒng)

    這篇文章主要介紹了Python實現(xiàn)12306火車票搶票系統(tǒng),本文通過實例代碼給大家介紹的非常詳細,具有一定的參考借鑒價值 ,需要的朋友可以參考下
    2019-07-07
  • 關于Qt6中QtMultimedia多媒體模塊的重大改變分析

    關于Qt6中QtMultimedia多媒體模塊的重大改變分析

    如果您一直在 Qt 5 中使用 Qt Multimedia,則需要對您的實現(xiàn)進行更改。這篇博文將嘗試引導您完成最大的變化,同時查看 API 和內部結構
    2021-09-09
  • opencv實踐項目之圖像拼接詳細步驟

    opencv實踐項目之圖像拼接詳細步驟

    OpenCV的應用領域非常廣泛,包括圖像拼接、圖像降噪、產品質檢、人機交互、人臉識別、動作識別、動作跟蹤、無人駕駛等,下面這篇文章主要給大家介紹了關于opencv實踐項目之圖像拼接的相關資料,需要的朋友可以參考下
    2023-05-05

最新評論