欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch 利用lstm做mnist手寫數(shù)字識別分類的實例

 更新時間:2020年01月10日 10:43:23   作者:xckkcxxck  
今天小編就為大家分享一篇pytorch 利用lstm做mnist手寫數(shù)字識別分類的實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

代碼如下,U我認為對于新手來說最重要的是學會rnn讀取數(shù)據(jù)的格式。

# -*- coding: utf-8 -*-
"""
Created on Tue Oct 9 08:53:25 2018
@author: www
"""
 
import sys
sys.path.append('..')
 
import torch
import datetime
from torch.autograd import Variable
from torch import nn
from torch.utils.data import DataLoader
 
from torchvision import transforms as tfs
from torchvision.datasets import MNIST
 
#定義數(shù)據(jù)
data_tf = tfs.Compose([
   tfs.ToTensor(),
   tfs.Normalize([0.5], [0.5])
])
train_set = MNIST('E:/data', train=True, transform=data_tf, download=True)
test_set = MNIST('E:/data', train=False, transform=data_tf, download=True)
 
train_data = DataLoader(train_set, 64, True, num_workers=4)
test_data = DataLoader(test_set, 128, False, num_workers=4)
 
#定義模型
class rnn_classify(nn.Module):
   def __init__(self, in_feature=28, hidden_feature=100, num_class=10, num_layers=2):
     super(rnn_classify, self).__init__()
     self.rnn = nn.LSTM(in_feature, hidden_feature, num_layers)#使用兩層lstm
     self.classifier = nn.Linear(hidden_feature, num_class)#將最后一個的rnn使用全連接的到最后的輸出結(jié)果
     
   def forward(self, x):
     #x的大小為(batch,1,28,28),所以我們需要將其轉(zhuǎn)化為rnn的輸入格式(28,batch,28)
     x = x.squeeze() #去掉(batch,1,28,28)中的1,變成(batch, 28,28)
     x = x.permute(2, 0, 1)#將最后一維放到第一維,變成(batch,28,28)
     out, _ = self.rnn(x) #使用默認的隱藏狀態(tài),得到的out是(28, batch, hidden_feature)
     out = out[-1,:,:]#取序列中的最后一個,大小是(batch, hidden_feature)
     out = self.classifier(out) #得到分類結(jié)果
     return out
     
net = rnn_classify()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adadelta(net.parameters(), 1e-1)
 
#定義訓練過程
def get_acc(output, label):
  total = output.shape[0]
  _, pred_label = output.max(1)
  num_correct = (pred_label == label).sum().item()
  return num_correct / total
  
  
def train(net, train_data, valid_data, num_epochs, optimizer, criterion):
  if torch.cuda.is_available():
    net = net.cuda()
  prev_time = datetime.datetime.now()
  for epoch in range(num_epochs):
    train_loss = 0
    train_acc = 0
    net = net.train()
    for im, label in train_data:
      if torch.cuda.is_available():
        im = Variable(im.cuda()) # (bs, 3, h, w)
        label = Variable(label.cuda()) # (bs, h, w)
      else:
        im = Variable(im)
        label = Variable(label)
      # forward
      output = net(im)
      loss = criterion(output, label)
      # backward
      optimizer.zero_grad()
      loss.backward()
      optimizer.step()
 
      train_loss += loss.item()
      train_acc += get_acc(output, label)
 
    cur_time = datetime.datetime.now()
    h, remainder = divmod((cur_time - prev_time).seconds, 3600)
    m, s = divmod(remainder, 60)
    time_str = "Time %02d:%02d:%02d" % (h, m, s)
    if valid_data is not None:
      valid_loss = 0
      valid_acc = 0
      net = net.eval()
      for im, label in valid_data:
        if torch.cuda.is_available():
          im = Variable(im.cuda())
          label = Variable(label.cuda())
        else:
          im = Variable(im)
          label = Variable(label)
        output = net(im)
        loss = criterion(output, label)
        valid_loss += loss.item()
        valid_acc += get_acc(output, label)
      epoch_str = (
        "Epoch %d. Train Loss: %f, Train Acc: %f, Valid Loss: %f, Valid Acc: %f, "
        % (epoch, train_loss / len(train_data),
          train_acc / len(train_data), valid_loss / len(valid_data),
          valid_acc / len(valid_data)))
    else:
      epoch_str = ("Epoch %d. Train Loss: %f, Train Acc: %f, " %
             (epoch, train_loss / len(train_data),
             train_acc / len(train_data)))
    prev_time = cur_time
    print(epoch_str + time_str)
    
train(net, train_data, test_data, 10, optimizer, criterion)    

以上這篇pytorch 利用lstm做mnist手寫數(shù)字識別分類的實例就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • el-table 多表格彈窗嵌套數(shù)據(jù)顯示異常錯亂問題解決方案

    el-table 多表格彈窗嵌套數(shù)據(jù)顯示異常錯亂問題解決方案

    使用vue+element開發(fā)報表功能時,需要列表上某列的超鏈接按鈕彈窗展示,在彈窗的el-table列表某列中再次使用超鏈接按鈕點開彈窗,以此類推多表格彈窗嵌套,本文以彈窗兩次為例,需要的朋友可以參考下
    2023-11-11
  • 利用Python將原始邊列表轉(zhuǎn)換為鄰接矩陣的過程

    利用Python將原始邊列表轉(zhuǎn)換為鄰接矩陣的過程

    有時候,我們會從外部數(shù)據(jù)源中得到原始的邊列表,而需要將其轉(zhuǎn)換為鄰接矩陣以便進行后續(xù)的分析和處理,本文將介紹如何使用Python來實現(xiàn)這一轉(zhuǎn)換過程,需要的朋友可以參考下
    2024-04-04
  • JPype實現(xiàn)在python中調(diào)用JAVA的實例

    JPype實現(xiàn)在python中調(diào)用JAVA的實例

    本篇文章主要介紹了JPype實現(xiàn)在python中調(diào)用JAVA的實例,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2017-07-07
  • pandas使用get_dummies進行one-hot編碼的方法

    pandas使用get_dummies進行one-hot編碼的方法

    今天小編就為大家分享一篇pandas使用get_dummies進行one-hot編碼的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2018-07-07
  • python從入門到精通(DAY 1)

    python從入門到精通(DAY 1)

    本文是此次python從入門到精通系列文章的第一篇,給大家匯總一下常用的Python的基礎(chǔ)知識,非常的簡單,但是很全面,有需要的小伙伴可以參考下
    2015-12-12
  • Win10環(huán)境中如何實現(xiàn)python2和python3并存

    Win10環(huán)境中如何實現(xiàn)python2和python3并存

    這篇文章主要介紹了Win10環(huán)境中如何實現(xiàn)python2和python3并存,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下
    2020-07-07
  • pytorch permute維度轉(zhuǎn)換方法

    pytorch permute維度轉(zhuǎn)換方法

    今天小編就為大家分享一篇pytorch permute維度轉(zhuǎn)換方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2018-12-12
  • 簡單了解Python字典copy與賦值的區(qū)別

    簡單了解Python字典copy與賦值的區(qū)別

    這篇文章主要介紹了簡單了解Python字典copy與賦值區(qū)別,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下
    2020-09-09
  • python字符串判斷密碼強弱

    python字符串判斷密碼強弱

    這篇文章主要為大家詳細介紹了python字符串判斷密碼強弱,文中示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2020-03-03
  • Python CSV模塊使用實例

    Python CSV模塊使用實例

    這篇文章主要介紹了Python CSV模塊使用實例,本文將舉幾個例子來介紹一下Python的CSV模塊的使用方法,包括reader、writer、DictReader、DictWriter.register_dialect等,需要的朋友可以參考下
    2015-04-04

最新評論