欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Python實戰(zhàn)小項目之Mnist手寫數(shù)字識別

 更新時間:2021年10月20日 15:01:10   作者:GSAU-深藍工作室  
MNIST 數(shù)據(jù)集已經(jīng)是一個被”嚼爛”了的數(shù)據(jù)集, 很多教程都會對它”下手”, 幾乎成為一個 “典范”. 不過有些人可能對它還不是很了解, 下面通過一個小實例來帶你了解它

程序流程分析圖:

傳播過程:

代碼展示:

創(chuàng)建環(huán)境

使用<pip install+包名>來下載torch,torchvision包

準備數(shù)據(jù)集

設置一次訓練所選取的樣本數(shù)Batch_Sized的值為512,訓練此時Epochs的值為8

BATCH_SIZE = 512
EPOCHS = 8
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

下載數(shù)據(jù)集

Normalize()數(shù)字歸一化,轉(zhuǎn)換使用的值0.1307和0.3081是MNIST數(shù)據(jù)集的全局平均值和標準偏差,這里我們將它們作為給定值。model

train_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=True, download=True,
                   transform=transforms.Compose([.
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)

下載測試集

test_loader = torch.utils.data.DataLoader(
    datasets.MNIST('data', train=False,
                   transform=transforms.Compose([
                       transforms.ToTensor(),
                       transforms.Normalize((0.1307,), (0.3081,))
                   ])),
    batch_size=BATCH_SIZE, shuffle=True)

繪制圖像

我們可以使用matplotlib來繪制其中的一些圖像

examples = enumerate(test_loader)
batch_idx, (example_data, example_targets) = next(examples)
print(example_targets)
print(example_data.shape)
print(example_data)
 
import matplotlib.pyplot as plt
fig = plt.figure()
for i in range(6):
  plt.subplot(2,3,i+1)
  plt.tight_layout()
  plt.imshow(example_data[i][0], cmap='gray', interpolation='none')
  plt.title("Ground Truth: {}".format(example_targets[i]))
  plt.xticks([])
  plt.yticks([])
plt.show()

搭建神經(jīng)網(wǎng)絡

這里我們構建全連接神經(jīng)網(wǎng)絡,我們使用三個全連接(或線性)層進行前向傳播。

class linearNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(784, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
    def forward(self, x):
        x = x.view(-1, 784)
        x = self.fc1(x)
        x = F.relu(x)
        x = self.fc2(x)
        x = F.relu(x)
        x = self.fc3(x)
        x = F.log_softmax(x, dim=1)
        return x

訓練模型

首先,我們需要使用optimizer.zero_grad()手動將梯度設置為零,因為PyTorch在默認情況下會累積梯度。然后,我們生成網(wǎng)絡的輸出(前向傳遞),并計算輸出與真值標簽之間的負對數(shù)概率損失?,F(xiàn)在,我們收集一組新的梯度,并使用optimizer.step()將其傳播回每個網(wǎng)絡參數(shù)。

def train(model, device, train_loader, optimizer, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
 
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()
        output = model(data)
        loss = F.nll_loss(output, target)
        loss.backward()
        optimizer.step()
        if (batch_idx) % 30 == 0:
            print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                epoch, batch_idx * len(data), len(train_loader.dataset),
                       100. * batch_idx / len(train_loader), loss.item()))

測試模型

def test(model, device, test_loader):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += F.nll_loss(output, target, reduction='sum').item() # 將一批的損失相加
            pred = output.max(1, keepdim=True)[1] # 找到概率最大的下標
            correct += pred.eq(target.view_as(pred)).sum().item()
 
    test_loss /= len(test_loader.dataset)
    print('\nTest set: Average loss: {:.4f}, Accuracy: {}/{} ({:.0f}%)\n'.format(
        test_loss, correct, len(test_loader.dataset),
        100. * correct / len(test_loader.dataset)))

將訓練次數(shù)進行循環(huán)

if __name__ == '__main__':
    model = linearNet()
    optimizer = optim.Adam(model.parameters())
 
    for epoch in range(1, EPOCHS + 1):
        train(model, device, train_loader, optimizer, epoch)
        test(model, device, test_loader)

保存訓練模型

torch.save(model, 'MNIST.pth')

運行結果展示:

分享人:蘇云云

到此這篇關于Python實戰(zhàn)小項目之Mnist手寫數(shù)字識別的文章就介紹到這了,更多相關Python Mnist手寫數(shù)字識別內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關文章希望大家以后多多支持腳本之家!

相關文章

  • Python如何使用內(nèi)置庫matplotlib繪制折線圖

    Python如何使用內(nèi)置庫matplotlib繪制折線圖

    這篇文章主要介紹了Python如何使用內(nèi)置庫matplotlib繪制折線圖,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友可以參考下
    2020-02-02
  • Python超簡單分析評論提取關鍵詞制作精美詞云流程

    Python超簡單分析評論提取關鍵詞制作精美詞云流程

    這篇文章主要介紹了使用Python來分析評論并且提取其中的關鍵詞,用于制作精美詞云的方法,感興趣的朋友來看看吧
    2022-03-03
  • Python標準庫之os模塊詳解

    Python標準庫之os模塊詳解

    Python的os模塊是用于與操作系統(tǒng)進行交互的模塊,它提供了許多函數(shù)和方法來執(zhí)行文件和目錄操作、進程管理、環(huán)境變量訪問等,本文詳細介紹了Python標準庫中os模塊,感興趣的同學跟著小編一起來看看吧
    2023-08-08
  • PyTorch的Optimizer訓練工具的實現(xiàn)

    PyTorch的Optimizer訓練工具的實現(xiàn)

    這篇文章主要介紹了PyTorch的Optimizer訓練工具的實現(xiàn),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2019-08-08
  • Pandas?Query方法使用深度總結

    Pandas?Query方法使用深度總結

    大多數(shù)Pandas用戶都熟悉iloc[]和loc[]索引器方法,用于檢索行和列。但是隨著檢索數(shù)據(jù)的規(guī)則變得越來越復雜,這些方法也隨之變得更加復雜而臃腫。本文將展示如何使用?query()?方法對數(shù)據(jù)框執(zhí)行查詢,感興趣的可以了解一下
    2022-07-07
  • Python實現(xiàn)插入排序和選擇排序的方法

    Python實現(xiàn)插入排序和選擇排序的方法

    這篇文章主要介紹了Python實現(xiàn)插入排序和選擇排序的方法,非常不錯,具有一定的參考借鑒價值,需要的朋友可以參考下
    2019-05-05
  • 解決python的空格和tab混淆而報錯的問題

    解決python的空格和tab混淆而報錯的問題

    這篇文章主要介紹了解決python的空格和tab混淆而報錯的問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2021-02-02
  • python:HDF和CSV存儲優(yōu)劣對比分析

    python:HDF和CSV存儲優(yōu)劣對比分析

    這篇文章主要介紹了python:HDF和CSV存儲優(yōu)劣對比分析,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-06-06
  • Python面向?qū)ο罂偨Y及類與正則表達式詳解

    Python面向?qū)ο罂偨Y及類與正則表達式詳解

    Python中的類提供了面向?qū)ο缶幊痰乃谢竟δ埽侯惖睦^承機制允許多個基類,派生類可以覆蓋基類中的任何方法,方法中可以調(diào)用基類中的同名方法。這篇文章主要介紹了Python面向?qū)ο罂偨Y及類與正則表達式 ,需要的朋友可以參考下
    2019-04-04
  • Python cv2 圖像自適應灰度直方圖均衡化處理方法

    Python cv2 圖像自適應灰度直方圖均衡化處理方法

    今天小編就為大家分享一篇Python cv2 圖像自適應灰度直方圖均衡化處理方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2018-12-12

最新評論