欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

使用PyTorch實(shí)現(xiàn)手寫數(shù)字識(shí)別功能

 更新時(shí)間:2025年03月24日 09:35:51   作者:不惑_  
在人工智能的世界里,計(jì)算機(jī)視覺是最具魅力的領(lǐng)域之一,通過(guò)PyTorch這一強(qiáng)大的深度學(xué)習(xí)框架,我們將在經(jīng)典的MNIST數(shù)據(jù)集上,見證一個(gè)神經(jīng)網(wǎng)絡(luò)從零開始學(xué)會(huì)識(shí)別數(shù)字的全過(guò)程,本文給大家介紹了如何使用PyTorch實(shí)現(xiàn)手寫數(shù)字識(shí)別,需要的朋友可以參考下

當(dāng)計(jì)算機(jī)學(xué)會(huì)“看”數(shù)字

在人工智能的世界里,計(jì)算機(jī)視覺是最具魅力的領(lǐng)域之一。通過(guò)PyTorch這一強(qiáng)大的深度學(xué)習(xí)框架,我們將在經(jīng)典的MNIST數(shù)據(jù)集上,見證一個(gè)神經(jīng)網(wǎng)絡(luò)從零開始學(xué)會(huì)識(shí)別數(shù)字的全過(guò)程。本文將以通俗易懂的方式,帶你走進(jìn)這個(gè)看似神秘實(shí)則充滿邏輯的美妙世界。

搭建開發(fā)環(huán)境

在開始訓(xùn)練之前,我們需要準(zhǔn)備好三個(gè)基礎(chǔ)要素:騰訊云HAI,騰訊云HAI,騰訊云HAI。導(dǎo)入必要的工具庫(kù):

import torch  # 深度學(xué)習(xí)框架核心
import torch.nn as nn  # 神經(jīng)網(wǎng)絡(luò)模塊
from torchvision import datasets, transforms  # 數(shù)據(jù)處理利器

MNIST數(shù)據(jù)集解析

1. 認(rèn)識(shí)手寫數(shù)字?jǐn)?shù)據(jù)庫(kù)

MNIST數(shù)據(jù)集包含6萬(wàn)張訓(xùn)練圖片和1萬(wàn)張測(cè)試圖片,每張都是28x28像素的灰度圖。這些數(shù)字由美國(guó)高中生和人口普查局員工書寫,構(gòu)成了計(jì)算機(jī)視覺領(lǐng)域的"Hello World"。

2. 數(shù)據(jù)預(yù)處理的藝術(shù)

原始圖片需要經(jīng)過(guò)精心處理才能被模型理解:

transform = transforms.Compose([
    transforms.ToTensor(),  # 將圖像轉(zhuǎn)換為數(shù)值矩陣
    transforms.Normalize((0.1307,), (0.3081,))  # 標(biāo)準(zhǔn)化處理
])

3. 可視化的重要性

通過(guò)Matplotlib展示樣本圖片,我們能直觀感受數(shù)據(jù)的特征:

plt.imshow(images[0].squeeze(), cmap='gray')
plt.title(f'Label: {labels[0]}')

神經(jīng)網(wǎng)絡(luò)設(shè)計(jì)

1. 網(wǎng)絡(luò)結(jié)構(gòu)藍(lán)圖

我們?cè)O(shè)計(jì)一個(gè)全連接網(wǎng)絡(luò)(FCN),其結(jié)構(gòu)如同人類神經(jīng)系統(tǒng)的簡(jiǎn)化版:

class Net(nn.Module):
    def __init__(self):
        super().__init__()
        self.flatten = nn.Flatten()  # 將圖片展開為向量
        self.fc1 = nn.Linear(28*28, 128)  # 第一隱藏層
        self.fc2 = nn.Linear(128, 64)    # 第二隱藏層
        self.fc3 = nn.Linear(64, 10)     # 輸出層
        self.dropout = nn.Dropout(0.5)   # 正則化裝置
  • 神經(jīng)元數(shù)量的選擇需要平衡學(xué)習(xí)能力與過(guò)擬合風(fēng)險(xiǎn)
  • Dropout層像隨機(jī)關(guān)閉部分神經(jīng)元,防止模型"死記硬背"

2. 信息傳遞機(jī)制

前向傳播模擬人腦的信息處理過(guò)程:ReLU激活函數(shù)如同神經(jīng)元的開關(guān),決定是否傳遞信號(hào)。

def forward(self, x):
    x = self.flatten(x)  # 展平操作:將圖片變?yōu)?84維向量
    x = torch.relu(self.fc1(x))  # 通過(guò)第一個(gè)全連接層
    x = self.dropout(x)         # 隨機(jī)屏蔽部分神經(jīng)元
    x = torch.relu(self.fc2(x))  # 第二個(gè)全連接層
    return self.fc3(x)          # 最終輸出10個(gè)數(shù)字的概率

讓模型學(xué)會(huì)思考

1. 配置學(xué)習(xí)參數(shù)

  • 損失函數(shù):交叉熵?fù)p失(CrossEntropyLoss),衡量預(yù)測(cè)與真實(shí)的差距
  • 優(yōu)化器:Adam優(yōu)化器,智能調(diào)節(jié)學(xué)習(xí)步伐的導(dǎo)航員
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

2. 訓(xùn)練循環(huán)解析

每個(gè)epoch都是一次完整的學(xué)習(xí)輪回:

def train(epoch):
    model.train()  # 切換至訓(xùn)練模式
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()  # 清空之前的梯度
        output = model(data)    # 前向傳播
        loss = criterion(output, target)  # 計(jì)算損失值
        loss.backward()         # 反向傳播求梯度
        optimizer.step()        # 更新網(wǎng)絡(luò)參數(shù)
  • 梯度清零避免不同批次數(shù)據(jù)的干擾
  • 反向傳播就像糾錯(cuò)老師,沿著計(jì)算鏈修正參數(shù)

完整代碼示例

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision

from torchvision import transforms
import matplotlib.pyplot as plt

# 2. 數(shù)據(jù)準(zhǔn)備
# 定義數(shù)據(jù)預(yù)處理:轉(zhuǎn)換為Tensor并標(biāo)準(zhǔn)化(MNIST的均值和標(biāo)準(zhǔn)差)
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,))
])

# 加載訓(xùn)練集和測(cè)試集
train_dataset = torchvision.datasets.MNIST(
    root='./data', 
    train=True,
    download=True,
    transform=transform
)

test_dataset = torchvision.datasets.MNIST(
    root='./data',
    train=False,
    transform=transform
)

# 創(chuàng)建數(shù)據(jù)加載器
train_loader = torch.utils.data.DataLoader(
    train_dataset,

    batch_size=64,
    shuffle=True
)

test_loader = torch.utils.data.DataLoader(
    test_dataset,
    batch_size=1000,
    shuffle=False
)

# 查看數(shù)據(jù)集信息
print(f'Train samples: {len(train_dataset)}')
print(f'Test samples: {len(test_dataset)}')

# 可視化樣本
images, labels = next(iter(train_loader))
plt.imshow(images[0].squeeze(), cmap='gray')
plt.title(f'Label: {labels[0]}')
plt.show()

# 3. 定義神經(jīng)網(wǎng)絡(luò)模型
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(28*28, 128)
        self.fc2 = nn.Linear(128, 64)
        self.fc3 = nn.Linear(64, 10)
        self.dropout = nn.Dropout(0.5)
        
    def forward(self, x):
        x = self.flatten(x)
        x = torch.relu(self.fc1(x))
        x = self.dropout(x)
        x = torch.relu(self.fc2(x))
        x = self.fc3(x)
        return x

model = Net()
print(model)

# 4. 定義損失函數(shù)和優(yōu)化器
criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)


# 5. 訓(xùn)練模型
def train(epoch):

    model.train()
    running_loss = 0.0
    for batch_idx, (data, target) in enumerate(train_loader):
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        loss.backward()
        optimizer.step()
        

        running_loss += loss.item()
        if batch_idx % 100 == 99:

            print(f'Epoch: {epoch+1}, Batch: {batch_idx+1}, Loss: {running_loss/100:.3f}')
            running_loss = 0.0

# 6. 測(cè)試模型

def test():
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            outputs = model(data)
            _, predicted = torch.max(outputs.data, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    
    accuracy = 100 * correct / total
    print(f'Test Accuracy: {accuracy:.2f}%')
    return accuracy

# 7. 執(zhí)行訓(xùn)練和測(cè)試

epochs = 5
for epoch in range(epochs):
    train(epoch)
    test()

# 8. 保存模型
torch.save(model.state_dict(), 'mnist_model.pth')

針對(duì)1-9數(shù)字的測(cè)試

# 擴(kuò)展測(cè)試函數(shù),增加按數(shù)字統(tǒng)計(jì)的功能

def detailed_test():
    model.eval()
    class_correct = [0] * 10  # 存儲(chǔ)每個(gè)數(shù)字的正確計(jì)數(shù)
    class_total = [0] * 10    # 存儲(chǔ)每個(gè)數(shù)字的總樣本數(shù)
    
    with torch.no_grad():
        for images, labels in test_loader:
            outputs = model(images)
            _, predicted = torch.max(outputs, 1)
            
            # 遍歷每個(gè)預(yù)測(cè)結(jié)果
            for label, prediction in zip(labels, predicted):
                class_total[label] += 1
                if label == prediction:
                    class_correct[label] += 1


    # 打印每個(gè)數(shù)字的準(zhǔn)確率
    print("{:^10} | {:^10} | {:^10}".format("數(shù)字", "正確數(shù)", "準(zhǔn)確率"))
    print("-"*33)

    for i in range(10):
        acc = 100 * class_correct[i] / class_total[i]
        print("{:^10} | {:^10} | {:^10.2f}%".format(i, class_correct[i], acc))
    
    # 可視化錯(cuò)誤案例
    wrong_examples = []

    for images, labels in test_loader:
        outputs = model(images)
        _, predicted = torch.max(outputs, 1)
        mask = predicted != labels
        wrong_examples.extend(zip(images[mask], labels[mask], predicted[mask]))
    
    # 隨機(jī)展示3個(gè)錯(cuò)誤樣本

    fig, axes = plt.subplots(1, 3, figsize=(12,4))

    for ax, (img, true, pred) in zip(axes, wrong_examples[:3]):
        ax.imshow(img.squeeze(), cmap='gray')

        ax.set_title(f'True: {true}\nPred: {pred}')
        ax.axis('off')
    plt.show()

# 執(zhí)行詳細(xì)測(cè)試
detailed_test()

PyTorch vs TensorFlow 深度對(duì)比

1. 核心架構(gòu)差異

特性PyTorchTensorFlow
計(jì)算圖動(dòng)態(tài)圖(即時(shí)執(zhí)行)靜態(tài)圖(需預(yù)先定義)
調(diào)試便利性支持標(biāo)準(zhǔn)Python調(diào)試工具需要特殊工具(tfdbg)
API設(shè)計(jì)更接近Python原生語(yǔ)法自成體系的API風(fēng)格
移動(dòng)端部署支持但生態(tài)較弱通過(guò)TF Lite有成熟解決方案

2. 相同功能的代碼對(duì)比

以定義全連接層為例:

# PyTorch版
import torch.nn as nn
layer = nn.Linear(in_features=784, out_features=128)

# TensorFlow版
from tensorflow.keras.layers import Dense
layer = Dense(units=128, input_dim=784)

3. 訓(xùn)練流程對(duì)比

PyTorch訓(xùn)練循環(huán)

for epoch in range(epochs):
    for data, labels in train_loader:
        optimizer.zero_grad()
        outputs = model(data)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

TensorFlow訓(xùn)練流程

model.compile(optimizer='adam', loss='sparse_categorical_crossentropy')
model.fit(train_dataset, epochs=epochs)  # 自動(dòng)完成訓(xùn)練循環(huán)

4. 性能對(duì)比(MNIST示例)

指標(biāo)PyTorch(CPU)TensorFlow(CPU)
訓(xùn)練時(shí)間/epoch~45秒~50秒
內(nèi)存占用~800MB~1GB
測(cè)試準(zhǔn)確率97.8-98.2%97.5-98.0%

工具的本質(zhì)

PyTorch與TensorFlow的差異,本質(zhì)上是靈活性規(guī)范性的不同追求。就像畫家選擇畫筆,PyTorch提供的是自由揮灑的水彩,TensorFlow則是精準(zhǔn)可控的鋼筆。理解它們的特性差異,根據(jù)項(xiàng)目需求選擇合適的工具,才是提升開發(fā)效率的關(guān)鍵。無(wú)論是哪個(gè)框架,最終目標(biāo)都是將數(shù)學(xué)公式轉(zhuǎn)化為智能的力量。

以上就是使用PyTorch實(shí)現(xiàn)手寫數(shù)字識(shí)別功能的詳細(xì)內(nèi)容,更多關(guān)于PyTorch手寫數(shù)字識(shí)別的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • Python+OpenCV感興趣區(qū)域ROI提取方法

    Python+OpenCV感興趣區(qū)域ROI提取方法

    今天小編就為大家分享一篇Python+OpenCV感興趣區(qū)域ROI提取方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2019-01-01
  • 時(shí)間序列重采樣和pandas的resample方法示例解析

    時(shí)間序列重采樣和pandas的resample方法示例解析

    這篇文章主要為大家介紹了時(shí)間序列重采樣和pandas的resample方法示例解析,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2023-09-09
  • Python 批量合并多個(gè)txt文件的實(shí)例講解

    Python 批量合并多個(gè)txt文件的實(shí)例講解

    今天小編就為大家分享一篇Python 批量合并多個(gè)txt文件的實(shí)例講解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2018-05-05
  • python thrift搭建服務(wù)端和客戶端測(cè)試程序

    python thrift搭建服務(wù)端和客戶端測(cè)試程序

    這篇文章主要為大家詳細(xì)介紹了python thrift搭建服務(wù)端和客戶端測(cè)試程序,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2018-01-01
  • Python的Django框架可適配的各種數(shù)據(jù)庫(kù)介紹

    Python的Django框架可適配的各種數(shù)據(jù)庫(kù)介紹

    這篇文章主要介紹了Python的Django框架可適配的各種數(shù)據(jù)庫(kù),簡(jiǎn)單總結(jié)為就是流行的幾種數(shù)據(jù)庫(kù)Python基本上全部能用XD 需要的朋友可以參考下
    2015-07-07
  • python3使用tkinter實(shí)現(xiàn)ui界面簡(jiǎn)單實(shí)例

    python3使用tkinter實(shí)現(xiàn)ui界面簡(jiǎn)單實(shí)例

    使用tkinter創(chuàng)建一個(gè)小窗口,布置2個(gè)按鈕,一個(gè)btn關(guān)閉窗口,另一個(gè)btn用于切換執(zhí)行傳入的2個(gè)函數(shù),簡(jiǎn)單的小代碼,大家參考使用吧
    2014-01-01
  • Django實(shí)現(xiàn)翻頁(yè)的示例代碼

    Django實(shí)現(xiàn)翻頁(yè)的示例代碼

    翻頁(yè)是經(jīng)常使用的功能,Django提供了翻頁(yè)器。用Django的Paginator類實(shí)現(xiàn),有需要了解Paginator類用法的朋友可參考。希望此文章對(duì)各位有所幫助
    2021-05-05
  • 基于OpenCV實(shí)現(xiàn)小型的圖像數(shù)據(jù)庫(kù)檢索功能

    基于OpenCV實(shí)現(xiàn)小型的圖像數(shù)據(jù)庫(kù)檢索功能

    下面就使用VLAD表示圖像,實(shí)現(xiàn)一個(gè)小型的圖像數(shù)據(jù)庫(kù)的檢索程序。下面實(shí)現(xiàn)需要的功能模塊,分步驟給大家介紹的非常詳細(xì),對(duì)OpenCV圖像數(shù)據(jù)庫(kù)檢索功能感興趣的朋友跟隨小編一起看看吧
    2021-12-12
  • PySpark和RDD對(duì)象最新詳解

    PySpark和RDD對(duì)象最新詳解

    Spark是一款分布式的計(jì)算框架,用于調(diào)度成百上千的服務(wù)器集群,計(jì)算TB、PB乃至EB級(jí)別的海量數(shù)據(jù),PySpark是由Spark官方開發(fā)的Python語(yǔ)言第三方庫(kù),本文重點(diǎn)介紹PySpark和RDD對(duì)象,感興趣的朋友一起看看吧
    2023-01-01
  • Python編寫登陸接口的方法

    Python編寫登陸接口的方法

    這篇文章主要為大家詳細(xì)介紹了Python編寫登陸接口的方法,具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2017-07-07

最新評(píng)論