Pytorch框架實(shí)現(xiàn)mnist手寫庫識別(與tensorflow對比)
前言最近在學(xué)習(xí)過程中需要用到pytorch框架,簡單學(xué)習(xí)了一下,寫了一個(gè)簡單的案例,記錄一下pytorch中搭建一個(gè)識別網(wǎng)絡(luò)基礎(chǔ)的東西。對應(yīng)一位博主寫的tensorflow的識別mnist數(shù)據(jù)集,將其改為pytorch框架,也可以詳細(xì)看到兩個(gè)框架大體的區(qū)別。
Tensorflow版本轉(zhuǎn)載來源(CSDN博主「兔八哥1024」):http://www.dbjr.com.cn/article/191157.htm
Pytorch實(shí)戰(zhàn)mnist手寫數(shù)字識別
#需要導(dǎo)入的包 import torch import torch.nn as nn#用于構(gòu)建網(wǎng)絡(luò)層 import torch.optim as optim#導(dǎo)入優(yōu)化器 from torch.utils.data import DataLoader#加載數(shù)據(jù)集的迭代器 from torchvision import datasets, transforms#用于加載mnsit數(shù)據(jù)集 #下載數(shù)據(jù)集 train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1037,), (0.3081,)) ])) test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([ transforms.ToTensor(), transforms.Normalize((0.1037,), (0.3081,)) ])) #構(gòu)建網(wǎng)絡(luò)(網(wǎng)絡(luò)結(jié)構(gòu)對應(yīng)tensorflow的那一篇文章) class Net(nn.Module): def __init__(self, num_classes=10): super(Net, self).__init__() self.features = nn.Sequential( nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2), nn.MaxPool2d(kernel_size=2,stride=2), nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2), nn.MaxPool2d(kernel_size=2,stride=2), ) self.classifier = nn.Sequential( nn.Linear(3136, 7*7*64), nn.Linear(3136, num_classes), ) def forward(self,x): x = self.features(x) x = torch.flatten(x, 1) x = self.classifier(x) return x net=Net() net.cuda()#用GPU運(yùn)行 #計(jì)算誤差,使用adam優(yōu)化器優(yōu)化誤差 criterion = nn.CrossEntropyLoss() optimizer = optim.Adam(net.parameters(), 1e-2) train_data = DataLoader(train_set, batch_size=128, shuffle=True) test_data = DataLoader(test_set, batch_size=128, shuffle=False) #訓(xùn)練過程 for epoch in range(1): net.train() ##在進(jìn)行訓(xùn)練時(shí)加上train(),測試時(shí)加上eval() batch = 0 for batch_images, batch_labels in train_data: average_loss = 0 train_acc = 0 ##在pytorch0.4之后將Variable 與tensor進(jìn)行合并,所以這里不需要進(jìn)行Variable封裝 if torch.cuda.is_available(): batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda() #前向傳播 out = net(batch_images) loss = criterion(out,batch_labels) average_loss = loss prediction = torch.max(out,1)[1] # print(prediction) train_correct = (prediction == batch_labels).sum() ##這里得到的train_correct是一個(gè)longtensor型,需要轉(zhuǎn)換為float train_acc = (train_correct.float()) / 128 optimizer.zero_grad() #清空梯度信息,否則在每次進(jìn)行反向傳播時(shí)都會累加 loss.backward() #loss反向傳播 optimizer.step() ##梯度更新 batch+=1 print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f" %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc)) # 在測試集上檢驗(yàn)效果 net.eval() # 將模型改為預(yù)測模式 for idx,(im1, label1) in enumerate(test_data): if torch.cuda.is_available(): im, label = im1.cuda(),label1.cuda() out = net(im) loss = criterion(out, label) eval_loss = loss pred = torch.max(out,1)[1] num_correct = (pred == label).sum() acc = (num_correct.float())/ 128 eval_acc = acc print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}' .format(idx,eval_loss , eval_acc))
運(yùn)行結(jié)果:
到此這篇關(guān)于Pytorch框架實(shí)現(xiàn)mnist手寫庫識別(與tensorflow對比)的文章就介紹到這了,更多相關(guān)Pytorch框架實(shí)現(xiàn)mnist手寫庫識別(與tensorflow對比)內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python基礎(chǔ)之賦值,淺拷貝,深拷貝的區(qū)別
這篇文章主要介紹了Python基礎(chǔ)之賦值,淺拷貝,深拷貝的區(qū)別,文中有非常詳細(xì)的代碼示例,對正在學(xué)習(xí)python基礎(chǔ)的小伙伴們也有非常好的幫助,需要的朋友可以參考下2021-04-04詳解Selenium-webdriver繞開反爬蟲機(jī)制的4種方法
這篇文章主要介紹了詳解Selenium-webdriver繞開反爬蟲機(jī)制的4種方法,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-10-10python采集天氣數(shù)據(jù)并做數(shù)據(jù)可視化
本文主要介紹了python采集天氣數(shù)據(jù)并做數(shù)據(jù)可視化,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2022-07-07詳解Django項(xiàng)目中模板標(biāo)簽及模板的繼承與引用(網(wǎng)站中快速布置廣告)
這篇文章主要介紹了詳解Django項(xiàng)目中模板標(biāo)簽及模板的繼承與引用【網(wǎng)站中快速布置廣告】,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2019-03-03Django中使用極驗(yàn)Geetest滑動驗(yàn)證碼過程解析
這篇文章主要介紹了Django中使用極驗(yàn)Geetest滑動驗(yàn)證碼過程解析,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2019-07-07