Pytorch框架實(shí)現(xiàn)mnist手寫庫識(shí)別(與tensorflow對(duì)比)
前言最近在學(xué)習(xí)過程中需要用到pytorch框架,簡(jiǎn)單學(xué)習(xí)了一下,寫了一個(gè)簡(jiǎn)單的案例,記錄一下pytorch中搭建一個(gè)識(shí)別網(wǎng)絡(luò)基礎(chǔ)的東西。對(duì)應(yīng)一位博主寫的tensorflow的識(shí)別mnist數(shù)據(jù)集,將其改為pytorch框架,也可以詳細(xì)看到兩個(gè)框架大體的區(qū)別。
Tensorflow版本轉(zhuǎn)載來源(CSDN博主「兔八哥1024」):http://www.dbjr.com.cn/article/191157.htm
Pytorch實(shí)戰(zhàn)mnist手寫數(shù)字識(shí)別
#需要導(dǎo)入的包
import torch
import torch.nn as nn#用于構(gòu)建網(wǎng)絡(luò)層
import torch.optim as optim#導(dǎo)入優(yōu)化器
from torch.utils.data import DataLoader#加載數(shù)據(jù)集的迭代器
from torchvision import datasets, transforms#用于加載mnsit數(shù)據(jù)集
#下載數(shù)據(jù)集
train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1037,), (0.3081,))
]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1037,), (0.3081,))
]))
#構(gòu)建網(wǎng)絡(luò)(網(wǎng)絡(luò)結(jié)構(gòu)對(duì)應(yīng)tensorflow的那一篇文章)
class Net(nn.Module):
def __init__(self, num_classes=10):
super(Net, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
nn.MaxPool2d(kernel_size=2,stride=2),
nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
nn.MaxPool2d(kernel_size=2,stride=2),
)
self.classifier = nn.Sequential(
nn.Linear(3136, 7*7*64),
nn.Linear(3136, num_classes),
)
def forward(self,x):
x = self.features(x)
x = torch.flatten(x, 1)
x = self.classifier(x)
return x
net=Net()
net.cuda()#用GPU運(yùn)行
#計(jì)算誤差,使用adam優(yōu)化器優(yōu)化誤差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)
train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
#訓(xùn)練過程
for epoch in range(1):
net.train() ##在進(jìn)行訓(xùn)練時(shí)加上train(),測(cè)試時(shí)加上eval()
batch = 0
for batch_images, batch_labels in train_data:
average_loss = 0
train_acc = 0
##在pytorch0.4之后將Variable 與tensor進(jìn)行合并,所以這里不需要進(jìn)行Variable封裝
if torch.cuda.is_available():
batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()
#前向傳播
out = net(batch_images)
loss = criterion(out,batch_labels)
average_loss = loss
prediction = torch.max(out,1)[1]
# print(prediction)
train_correct = (prediction == batch_labels).sum()
##這里得到的train_correct是一個(gè)longtensor型,需要轉(zhuǎn)換為float
train_acc = (train_correct.float()) / 128
optimizer.zero_grad() #清空梯度信息,否則在每次進(jìn)行反向傳播時(shí)都會(huì)累加
loss.backward() #loss反向傳播
optimizer.step() ##梯度更新
batch+=1
print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
%(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))
# 在測(cè)試集上檢驗(yàn)效果
net.eval() # 將模型改為預(yù)測(cè)模式
for idx,(im1, label1) in enumerate(test_data):
if torch.cuda.is_available():
im, label = im1.cuda(),label1.cuda()
out = net(im)
loss = criterion(out, label)
eval_loss = loss
pred = torch.max(out,1)[1]
num_correct = (pred == label).sum()
acc = (num_correct.float())/ 128
eval_acc = acc
print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
.format(idx,eval_loss , eval_acc))
運(yùn)行結(jié)果:

到此這篇關(guān)于Pytorch框架實(shí)現(xiàn)mnist手寫庫識(shí)別(與tensorflow對(duì)比)的文章就介紹到這了,更多相關(guān)Pytorch框架實(shí)現(xiàn)mnist手寫庫識(shí)別(與tensorflow對(duì)比)內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python基礎(chǔ)之賦值,淺拷貝,深拷貝的區(qū)別
這篇文章主要介紹了Python基礎(chǔ)之賦值,淺拷貝,深拷貝的區(qū)別,文中有非常詳細(xì)的代碼示例,對(duì)正在學(xué)習(xí)python基礎(chǔ)的小伙伴們也有非常好的幫助,需要的朋友可以參考下2021-04-04
詳解Selenium-webdriver繞開反爬蟲機(jī)制的4種方法
這篇文章主要介紹了詳解Selenium-webdriver繞開反爬蟲機(jī)制的4種方法,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-10-10
python采集天氣數(shù)據(jù)并做數(shù)據(jù)可視化
本文主要介紹了python采集天氣數(shù)據(jù)并做數(shù)據(jù)可視化,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2022-07-07
詳解Django項(xiàng)目中模板標(biāo)簽及模板的繼承與引用(網(wǎng)站中快速布置廣告)
這篇文章主要介紹了詳解Django項(xiàng)目中模板標(biāo)簽及模板的繼承與引用【網(wǎng)站中快速布置廣告】,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2019-03-03
Django中使用極驗(yàn)Geetest滑動(dòng)驗(yàn)證碼過程解析
這篇文章主要介紹了Django中使用極驗(yàn)Geetest滑動(dòng)驗(yàn)證碼過程解析,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-07-07

