欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch使用CNN實現(xiàn)圖像分類

 更新時間:2025年03月13日 08:49:47   作者:夢想畫家  
圖像分類是計算機視覺領域的一項基本任務,也是深度學習技術的一個常見應用,近年來,卷積神經(jīng)網(wǎng)絡(cnn)和PyTorch庫的結合由于其易用性和魯棒性已經(jīng)成為執(zhí)行圖像分類的流行選擇,所以本文給大家介紹了PyTorch使用CNN實現(xiàn)圖像分類的示例,需要的朋友可以參考下

理解卷積神經(jīng)網(wǎng)絡(cnn)

卷積神經(jīng)網(wǎng)絡是一類深度神經(jīng)網(wǎng)絡,對分析視覺圖像特別有效。他們利用多層構建一個可以直接從圖像中識別模式的模型。這些模型對于圖像識別和分類等任務特別有用,因為它們不需要手動提取特征。

cnn的關鍵組成部分

  • 卷積層:這些層對輸入應用卷積操作,將結果傳遞給下一層。每個過濾器(或核)可以捕獲不同的特征,如邊緣、角或其他模式。
  • 池化層:這些層減少了表示的空間大小,以減少參數(shù)的數(shù)量并加快計算速度。池化層簡化了后續(xù)層的處理。
  • 完全連接層:在這些層中,神經(jīng)元與前一層的所有激活具有完全連接,就像傳統(tǒng)的神經(jīng)網(wǎng)絡一樣。它們有助于對前一層識別的對象進行分類。

使用PyTorch進行圖像分類

PyTorch是開源的深度學習庫,提供了極大的靈活性和多功能性。研究人員和從業(yè)人員廣泛使用它來輕松有效地實現(xiàn)尖端的機器學習模型。

設置PyTorch

首先,確保在開發(fā)環(huán)境中安裝了PyTorch。你可以通過pip安裝它:

pip install torch torchvision

用PyTorch創(chuàng)建簡單的CNN示例

下面是如何定義簡單的CNN來使用PyTorch對圖像進行分類的示例。

import torch
import torch.nn as nn
import torch.nn.functional as F

# 定義CNN模型(修復了變量引用問題)
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)      # 第一個卷積層:3輸入通道,6輸出通道,5x5卷積核
        self.pool = nn.MaxPool2d(2, 2)        # 最大池化層:2x2窗口,步長2
        self.conv2 = nn.Conv2d(6, 16, 5)     # 第二個卷積層:6輸入通道,16輸出通道,5x5卷積核
        self.fc1 = nn.Linear(16 * 5 * 5, 120)# 全連接層1:400輸入 -> 120輸出
        self.fc2 = nn.Linear(120, 84)      # 全連接層2:120輸入 -> 84輸出
        self.fc3 = nn.Linear(84, 10)       # 輸出層:84輸入 -> 10類 logits

    def forward(self, x):
        # 輸入形狀:[batch_size, 3, 32, 32]
        x = self.pool(F.relu(self.conv1(x)))  # -> [batch, 6, 14, 14](池化后尺寸減半)
        x = self.pool(F.relu(self.conv2(x)))  # -> [batch, 16, 5, 5] 
        x = x.view(-1, 16 * 5 * 5)            # 展平為一維向量:16 * 5 * 5=400
        x = F.relu(self.fc1(x))             # -> [batch, 120]
        x = F.relu(self.fc2(x))             # -> [batch, 84]
        x = self.fc3(x)                     # -> [batch, 10](未應用softmax,配合CrossEntropyLoss使用)
        return x

這個特殊的網(wǎng)絡接受一個輸入圖像,通過兩組卷積和池化層,然后是三個完全連接的層。根據(jù)數(shù)據(jù)集的復雜性和大小調(diào)整網(wǎng)絡的架構和超參數(shù)。

模型定義

  • SimpleCNN 繼承自 nn.Module
  • 使用兩個卷積層提取特征,三個全連接層進行分類
  • 最終輸出未應用 softmax,而是直接輸出 logits(與 CrossEntropyLoss 配合使用)

訓練網(wǎng)絡

對于訓練,你需要一個數(shù)據(jù)集。PyTorch通過torchvision包提供了用于數(shù)據(jù)加載和預處理的實用程序。

import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader

# 初始化模型、損失函數(shù)和優(yōu)化器
net = SimpleCNN()               # 實例化模型
criterion = nn.CrossEntropyLoss()  # 使用交叉熵損失函數(shù)(自動處理softmax)
optimizer = torch.optim.SGD(net.parameters(), 
                            lr=0.001,      # 學習率
                            momentum=0.9)   # 動量參數(shù)

# 數(shù)據(jù)預處理和加載
transform = transforms.Compose([
    transforms.ToTensor(),          
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  

# 加載CIFAR-10訓練集
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True,
    download=True,  # 自動下載數(shù)據(jù)集
    transform=transform
)

trainloader = DataLoader(trainset, 
                     batch_size=4,   # 每個batch包含4張圖像
                     shuffle=True)  # 打亂數(shù)據(jù)順序

模型配置

  • 損失函數(shù)CrossEntropyLoss(自動包含 softmax 和 log_softmax)
  • 優(yōu)化器:SGD with momentum,學習率 0.001

數(shù)據(jù)加載

  • 使用 torchvision.datasets.CIFAR10 加載數(shù)據(jù)集

  • batch_size:4(根據(jù) GPU 內(nèi)存調(diào)整,CIFAR-10 建議 batch size ≥ 32)

  • transforms.Compose 定義數(shù)據(jù)預處理流程:

    • ToTensor():將圖像轉(zhuǎn)換為 PyTorch Tensor
    • Normalize():標準化圖像像素值到 [-1, 1]

加載數(shù)據(jù)后,訓練過程包括通過數(shù)據(jù)集進行多次迭代,使用反向傳播和合適的損失函數(shù):

# 訓練循環(huán)
for epoch in range(2):  # 進行2個epoch的訓練
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        
        # 前向傳播
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        
        # 反向傳播和優(yōu)化
        optimizer.zero_grad()   # 清空梯度
        loss.backward()         # 計算梯度
        optimizer.step()       # 更新參數(shù)
        
        running_loss += loss.item()
        
        # 每2000個batch打印一次
        if i % 2000 == 1999:
            avg_loss = running_loss / 2000
            print(f'Epoch [{epoch+1}/{2}], Batch [{i+1}/2000], Loss: {avg_loss:.3f}')
            running_loss = 0.0

print("訓練完成!")

訓練循環(huán)

  • epoch:完整遍歷數(shù)據(jù)集一次
  • batch:數(shù)據(jù)加載器中的一個批次
  • 梯度清零:每次反向傳播前需要清空梯度
  • 損失計算outputs 的形狀為 [batch_size, 10]labels 為整數(shù)標簽

完整代碼

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader

# 定義CNN模型(修復了變量引用問題)
class SimpleCNN(nn.Module):
    def __init__(self):
        super(SimpleCNN, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)      # 第一個卷積層:3輸入通道,6輸出通道,5x5卷積核
        self.pool = nn.MaxPool2d(2, 2)        # 最大池化層:2x2窗口,步長2
        self.conv2 = nn.Conv2d(6, 16, 5)     # 第二個卷積層:6輸入通道,16輸出通道,5x5卷積核
        self.fc1 = nn.Linear(16 * 5 * 5, 120)# 全連接層1:400輸入 -> 120輸出
        self.fc2 = nn.Linear(120, 84)      # 全連接層2:120輸入 -> 84輸出
        self.fc3 = nn.Linear(84, 10)       # 輸出層:84輸入 -> 10類 logits

    def forward(self, x):
        # 輸入形狀:[batch_size, 3, 32, 32]
        x = self.pool(F.relu(self.conv1(x)))  # -> [batch, 6, 14, 14](池化后尺寸減半)
        x = self.pool(F.relu(self.conv2(x)))  # -> [batch, 16, 5, 5] 
        x = x.view(-1, 16 * 5 * 5)            # 展平為一維向量:16 * 5 * 5=400
        x = F.relu(self.fc1(x))             # -> [batch, 120]
        x = F.relu(self.fc2(x))             # -> [batch, 84]
        x = self.fc3(x)                     # -> [batch, 10](未應用softmax,配合CrossEntropyLoss使用)
        return x

# 初始化模型、損失函數(shù)和優(yōu)化器
net = SimpleCNN()               # 實例化模型
criterion = nn.CrossEntropyLoss()  # 使用交叉熵損失函數(shù)(自動處理softmax)
optimizer = torch.optim.SGD(net.parameters(), 
                            lr=0.001,      # 學習率
                            momentum=0.9)   # 動量參數(shù)

# 數(shù)據(jù)預處理和加載
transform = transforms.Compose([
    transforms.ToTensor(),            
    transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))  
])

# 加載CIFAR-10訓練集
trainset = torchvision.datasets.CIFAR10(
    root='./data', train=True,
    download=True,  # 自動下載數(shù)據(jù)集
    transform=transform
)
trainloader = DataLoader(trainset, 
                         batch_size=4,   # 每個batch包含4張圖像
                         shuffle=True)  # 打亂數(shù)據(jù)順序

# 訓練循環(huán)
for epoch in range(2):  # 進行2個epoch的訓練
    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
        
        # 前向傳播
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        
        # 反向傳播和優(yōu)化
        optimizer.zero_grad()   # 清空梯度
        loss.backward()         # 計算梯度
        optimizer.step()       # 更新參數(shù)
        
        running_loss += loss.item()
        
        # 每2000個batch打印一次
        if i % 2000 == 1999:
            avg_loss = running_loss / 2000
            print(f'Epoch [{epoch+1}/{2}], Batch [{i+1}/2000], Loss: {avg_loss:.3f}')
            running_loss = 0.0

print("訓練完成!")

最后總結

通過PyTorch和卷積神經(jīng)網(wǎng)絡,你可以有效地處理圖像分類任務。借助PyTorch的靈活性,可以根據(jù)特定的數(shù)據(jù)集和應用程序構建、訓練和微調(diào)模型。示例代碼僅為理論過程,實際項目中還有大量優(yōu)化空間。

以上就是PyTorch使用CNN實現(xiàn)圖像分類的詳細內(nèi)容,更多關于PyTorch CNN圖像分類的資料請關注腳本之家其它相關文章!

相關文章

  • matplotlib 曲線圖 和 折線圖 plt.plot()實例

    matplotlib 曲線圖 和 折線圖 plt.plot()實例

    這篇文章主要介紹了matplotlib 曲線圖 和 折線圖 plt.plot()實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2020-04-04
  • Python3.x檢查內(nèi)存可用大小的兩種實現(xiàn)

    Python3.x檢查內(nèi)存可用大小的兩種實現(xiàn)

    本文將介紹如何使用Python 3實現(xiàn)檢查Linux服務器內(nèi)存可用大小的方法,包括使用Python標準庫實現(xiàn)和使用Linux命令實現(xiàn)兩種方式,感興趣可以了解一下
    2023-05-05
  • python 調(diào)用Google翻譯接口的方法

    python 調(diào)用Google翻譯接口的方法

    這篇文章主要介紹了python 調(diào)用Google翻譯接口的方法,幫助大家更好的理解和使用python處理url,感興趣的朋友可以了解下
    2020-12-12
  • Python創(chuàng)建普通菜單示例【基于win32ui模塊】

    Python創(chuàng)建普通菜單示例【基于win32ui模塊】

    這篇文章主要介紹了Python創(chuàng)建普通菜單,結合實例形式分析了Python基于win32ui模塊創(chuàng)建普通菜單及添加菜單項的相關操作技巧,并附帶說明了win32ui模塊的安裝命令,需要的朋友可以參考下
    2018-05-05
  • springboot配置文件抽離 git管理統(tǒng) 配置中心詳解

    springboot配置文件抽離 git管理統(tǒng) 配置中心詳解

    在本篇文章里小編給大家整理的是關于springboot配置文件抽離 git管理統(tǒng) 配置中心的相關知識點內(nèi)容,有需要的朋友們可以學習下。
    2019-09-09
  • Python使用MYSQLDB實現(xiàn)從數(shù)據(jù)庫中導出XML文件的方法

    Python使用MYSQLDB實現(xiàn)從數(shù)據(jù)庫中導出XML文件的方法

    這篇文章主要介紹了Python使用MYSQLDB實現(xiàn)從數(shù)據(jù)庫中導出XML文件的方法,涉及Python使用MYSQLDB操作數(shù)據(jù)庫及XML文件的相關技巧,需要的朋友可以參考下
    2015-05-05
  • Python批量合并有合并單元格的Excel文件詳解

    Python批量合并有合并單元格的Excel文件詳解

    經(jīng)常使用Excel的用戶都知道,合并單元格的存在,這篇文章主要給大家介紹了關于利用Python如何批量合并有合并單元格的Excel文件的相關資料,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面來一起看看吧。
    2018-04-04
  • calendar在python3時間中常用函數(shù)舉例詳解

    calendar在python3時間中常用函數(shù)舉例詳解

    這篇文章主要介紹了calendar在python3時間中常用函數(shù)的相關文章,對此知識點有興趣的朋友們可以學習下。
    2020-11-11
  • Python 使用with上下文實現(xiàn)計時功能

    Python 使用with上下文實現(xiàn)計時功能

    with 語句適用于對資源進行訪問的場合,確保不管使用過程中是否發(fā)生異常都會執(zhí)行必要的“清理”操作,釋放資源,比如文件使用后自動關閉、線程中鎖的自動獲取和釋放等。這篇文章主要介紹了Python 使用with上下文實現(xiàn)計時,需要的朋友可以參考下
    2018-03-03
  • 關于Pycharm安裝第三方庫超時 Read time-out的問題

    關于Pycharm安裝第三方庫超時 Read time-out的問題

    這篇文章主要介紹了關于Pycharm安裝第三方庫超時 Read time-out的問題, 找了幾個命令都不是很好用,最后找到解決的步驟,感興趣的朋友跟隨小編一起看看吧
    2021-10-10

最新評論