欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch框架實(shí)現(xiàn)mnist手寫庫識別(與tensorflow對比)

 更新時(shí)間:2020年07月20日 08:37:52   作者:社會青年技術(shù)官  
這篇文章主要介紹了Pytorch框架實(shí)現(xiàn)mnist手寫庫識別(與tensorflow對比),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧

前言最近在學(xué)習(xí)過程中需要用到pytorch框架,簡單學(xué)習(xí)了一下,寫了一個(gè)簡單的案例,記錄一下pytorch中搭建一個(gè)識別網(wǎng)絡(luò)基礎(chǔ)的東西。對應(yīng)一位博主寫的tensorflow的識別mnist數(shù)據(jù)集,將其改為pytorch框架,也可以詳細(xì)看到兩個(gè)框架大體的區(qū)別。

Tensorflow版本轉(zhuǎn)載來源(CSDN博主「兔八哥1024」):http://www.dbjr.com.cn/article/191157.htm

Pytorch實(shí)戰(zhàn)mnist手寫數(shù)字識別

#需要導(dǎo)入的包
import torch
import torch.nn as nn#用于構(gòu)建網(wǎng)絡(luò)層
import torch.optim as optim#導(dǎo)入優(yōu)化器
from torch.utils.data import DataLoader#加載數(shù)據(jù)集的迭代器
from torchvision import datasets, transforms#用于加載mnsit數(shù)據(jù)集

#下載數(shù)據(jù)集

train_set = datasets.MNIST('./data', train=True, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))
test_set = datasets.MNIST('./data', train=False, download=True,transform = transforms.Compose([
         transforms.ToTensor(),
         transforms.Normalize((0.1037,), (0.3081,))
       ]))

#構(gòu)建網(wǎng)絡(luò)(網(wǎng)絡(luò)結(jié)構(gòu)對應(yīng)tensorflow的那一篇文章)

class Net(nn.Module):

  def __init__(self, num_classes=10):
    super(Net, self).__init__()
    self.features = nn.Sequential(
      nn.Conv2d(1, 32, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),
      nn.Conv2d(32, 64, kernel_size=5, stride=1, padding=2),
      nn.MaxPool2d(kernel_size=2,stride=2),

    )
    self.classifier = nn.Sequential(
      nn.Linear(3136, 7*7*64),
      nn.Linear(3136, num_classes),

    )

  def forward(self,x):
    x = self.features(x)
    x = torch.flatten(x, 1)
    x = self.classifier(x)

    return x
net=Net()
net.cuda()#用GPU運(yùn)行

#計(jì)算誤差,使用adam優(yōu)化器優(yōu)化誤差
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), 1e-2)

train_data = DataLoader(train_set, batch_size=128, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)


#訓(xùn)練過程
for epoch in range(1):
  net.train() ##在進(jìn)行訓(xùn)練時(shí)加上train(),測試時(shí)加上eval()
  batch = 0

  for batch_images, batch_labels in train_data:

    average_loss = 0
    train_acc = 0

    ##在pytorch0.4之后將Variable 與tensor進(jìn)行合并,所以這里不需要進(jìn)行Variable封裝
    if torch.cuda.is_available():
      batch_images, batch_labels = batch_images.cuda(),batch_labels.cuda()

    #前向傳播
    out = net(batch_images)
    loss = criterion(out,batch_labels)


    average_loss = loss
    prediction = torch.max(out,1)[1]
    # print(prediction)

    train_correct = (prediction == batch_labels).sum()
    ##這里得到的train_correct是一個(gè)longtensor型,需要轉(zhuǎn)換為float

    train_acc = (train_correct.float()) / 128

    optimizer.zero_grad() #清空梯度信息,否則在每次進(jìn)行反向傳播時(shí)都會累加
    loss.backward() #loss反向傳播
    optimizer.step() ##梯度更新

    batch+=1
    print("Epoch: %d/%d || batch:%d/%d average_loss: %.3f || train_acc: %.2f"
       %(epoch, 20, batch, float(int(50000/128)), average_loss, train_acc))

# 在測試集上檢驗(yàn)效果
net.eval() # 將模型改為預(yù)測模式
for idx,(im1, label1) in enumerate(test_data):
  if torch.cuda.is_available():
    im, label = im1.cuda(),label1.cuda()
  out = net(im)
  loss = criterion(out, label)

  eval_loss = loss

  pred = torch.max(out,1)[1]
  num_correct = (pred == label).sum()
  acc = (num_correct.float())/ 128
  eval_acc = acc

  print('EVA_Batch:{}, Eval Loss: {:.6f}, Eval Acc: {:.6f}'
   .format(idx,eval_loss , eval_acc))

運(yùn)行結(jié)果:

到此這篇關(guān)于Pytorch框架實(shí)現(xiàn)mnist手寫庫識別(與tensorflow對比)的文章就介紹到這了,更多相關(guān)Pytorch框架實(shí)現(xiàn)mnist手寫庫識別(與tensorflow對比)內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!

相關(guān)文章

  • Python基礎(chǔ)之賦值,淺拷貝,深拷貝的區(qū)別

    Python基礎(chǔ)之賦值,淺拷貝,深拷貝的區(qū)別

    這篇文章主要介紹了Python基礎(chǔ)之賦值,淺拷貝,深拷貝的區(qū)別,文中有非常詳細(xì)的代碼示例,對正在學(xué)習(xí)python基礎(chǔ)的小伙伴們也有非常好的幫助,需要的朋友可以參考下
    2021-04-04
  • 詳解Selenium-webdriver繞開反爬蟲機(jī)制的4種方法

    詳解Selenium-webdriver繞開反爬蟲機(jī)制的4種方法

    這篇文章主要介紹了詳解Selenium-webdriver繞開反爬蟲機(jī)制的4種方法,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2020-10-10
  • python采集天氣數(shù)據(jù)并做數(shù)據(jù)可視化

    python采集天氣數(shù)據(jù)并做數(shù)據(jù)可視化

    本文主要介紹了python采集天氣數(shù)據(jù)并做數(shù)據(jù)可視化,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2022-07-07
  • 詳解Django項(xiàng)目中模板標(biāo)簽及模板的繼承與引用(網(wǎng)站中快速布置廣告)

    詳解Django項(xiàng)目中模板標(biāo)簽及模板的繼承與引用(網(wǎng)站中快速布置廣告)

    這篇文章主要介紹了詳解Django項(xiàng)目中模板標(biāo)簽及模板的繼承與引用【網(wǎng)站中快速布置廣告】,小編覺得挺不錯(cuò)的,現(xiàn)在分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧
    2019-03-03
  • numpy矩陣數(shù)值太多不能全部顯示的解決

    numpy矩陣數(shù)值太多不能全部顯示的解決

    這篇文章主要介紹了numpy矩陣數(shù)值太多不能全部顯示的解決,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-05-05
  • VSCode下好用的Python插件及配置

    VSCode下好用的Python插件及配置

    這篇文章主要介紹了微軟官方的Python插件,已經(jīng)自帶很多功能,下面是插件功能描述,其中部分內(nèi)容我做了翻譯,需要的朋友可以參考下
    2018-04-04
  • 在Python的Flask框架下收發(fā)電子郵件的教程

    在Python的Flask框架下收發(fā)電子郵件的教程

    這篇文章主要介紹了在Python的Flask框架下收發(fā)電子郵件的教程,主要用到了Flask中的Flask-mail工具,需要的朋友可以參考下
    2015-04-04
  • Python中Tkinter組件Frame的具體使用

    Python中Tkinter組件Frame的具體使用

    本文主要介紹了Python中Tkinter組件Frame的具體使用,文中通過示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2022-01-01
  • python切割圖片的實(shí)現(xiàn)示例

    python切割圖片的實(shí)現(xiàn)示例

    本文主要介紹了python切割圖片的實(shí)現(xiàn)示例,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2022-05-05
  • Django中使用極驗(yàn)Geetest滑動驗(yàn)證碼過程解析

    Django中使用極驗(yàn)Geetest滑動驗(yàn)證碼過程解析

    這篇文章主要介紹了Django中使用極驗(yàn)Geetest滑動驗(yàn)證碼過程解析,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下
    2019-07-07

最新評論