PyTorch使用CNN實(shí)現(xiàn)圖像分類
理解卷積神經(jīng)網(wǎng)絡(luò)(cnn)
卷積神經(jīng)網(wǎng)絡(luò)是一類深度神經(jīng)網(wǎng)絡(luò),對分析視覺圖像特別有效。他們利用多層構(gòu)建一個可以直接從圖像中識別模式的模型。這些模型對于圖像識別和分類等任務(wù)特別有用,因?yàn)樗鼈儾恍枰謩犹崛√卣鳌?/p>
cnn的關(guān)鍵組成部分
- 卷積層:這些層對輸入應(yīng)用卷積操作,將結(jié)果傳遞給下一層。每個過濾器(或核)可以捕獲不同的特征,如邊緣、角或其他模式。
- 池化層:這些層減少了表示的空間大小,以減少參數(shù)的數(shù)量并加快計算速度。池化層簡化了后續(xù)層的處理。
- 完全連接層:在這些層中,神經(jīng)元與前一層的所有激活具有完全連接,就像傳統(tǒng)的神經(jīng)網(wǎng)絡(luò)一樣。它們有助于對前一層識別的對象進(jìn)行分類。

使用PyTorch進(jìn)行圖像分類
PyTorch是開源的深度學(xué)習(xí)庫,提供了極大的靈活性和多功能性。研究人員和從業(yè)人員廣泛使用它來輕松有效地實(shí)現(xiàn)尖端的機(jī)器學(xué)習(xí)模型。
設(shè)置PyTorch
首先,確保在開發(fā)環(huán)境中安裝了PyTorch。你可以通過pip安裝它:
pip install torch torchvision
用PyTorch創(chuàng)建簡單的CNN示例
下面是如何定義簡單的CNN來使用PyTorch對圖像進(jìn)行分類的示例。
import torch
import torch.nn as nn
import torch.nn.functional as F
# 定義CNN模型(修復(fù)了變量引用問題)
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5) # 第一個卷積層:3輸入通道,6輸出通道,5x5卷積核
self.pool = nn.MaxPool2d(2, 2) # 最大池化層:2x2窗口,步長2
self.conv2 = nn.Conv2d(6, 16, 5) # 第二個卷積層:6輸入通道,16輸出通道,5x5卷積核
self.fc1 = nn.Linear(16 * 5 * 5, 120)# 全連接層1:400輸入 -> 120輸出
self.fc2 = nn.Linear(120, 84) # 全連接層2:120輸入 -> 84輸出
self.fc3 = nn.Linear(84, 10) # 輸出層:84輸入 -> 10類 logits
def forward(self, x):
# 輸入形狀:[batch_size, 3, 32, 32]
x = self.pool(F.relu(self.conv1(x))) # -> [batch, 6, 14, 14](池化后尺寸減半)
x = self.pool(F.relu(self.conv2(x))) # -> [batch, 16, 5, 5]
x = x.view(-1, 16 * 5 * 5) # 展平為一維向量:16 * 5 * 5=400
x = F.relu(self.fc1(x)) # -> [batch, 120]
x = F.relu(self.fc2(x)) # -> [batch, 84]
x = self.fc3(x) # -> [batch, 10](未應(yīng)用softmax,配合CrossEntropyLoss使用)
return x
這個特殊的網(wǎng)絡(luò)接受一個輸入圖像,通過兩組卷積和池化層,然后是三個完全連接的層。根據(jù)數(shù)據(jù)集的復(fù)雜性和大小調(diào)整網(wǎng)絡(luò)的架構(gòu)和超參數(shù)。
模型定義:
SimpleCNN繼承自nn.Module- 使用兩個卷積層提取特征,三個全連接層進(jìn)行分類
- 最終輸出未應(yīng)用 softmax,而是直接輸出 logits(與
CrossEntropyLoss配合使用)
訓(xùn)練網(wǎng)絡(luò)
對于訓(xùn)練,你需要一個數(shù)據(jù)集。PyTorch通過torchvision包提供了用于數(shù)據(jù)加載和預(yù)處理的實(shí)用程序。
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
# 初始化模型、損失函數(shù)和優(yōu)化器
net = SimpleCNN() # 實(shí)例化模型
criterion = nn.CrossEntropyLoss() # 使用交叉熵?fù)p失函數(shù)(自動處理softmax)
optimizer = torch.optim.SGD(net.parameters(),
lr=0.001, # 學(xué)習(xí)率
momentum=0.9) # 動量參數(shù)
# 數(shù)據(jù)預(yù)處理和加載
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
# 加載CIFAR-10訓(xùn)練集
trainset = torchvision.datasets.CIFAR10(
root='./data', train=True,
download=True, # 自動下載數(shù)據(jù)集
transform=transform
)
trainloader = DataLoader(trainset,
batch_size=4, # 每個batch包含4張圖像
shuffle=True) # 打亂數(shù)據(jù)順序
模型配置:
- 損失函數(shù):
CrossEntropyLoss(自動包含 softmax 和 log_softmax) - 優(yōu)化器:SGD with momentum,學(xué)習(xí)率 0.001
數(shù)據(jù)加載:
使用
torchvision.datasets.CIFAR10加載數(shù)據(jù)集batch_size:4(根據(jù) GPU 內(nèi)存調(diào)整,CIFAR-10 建議 batch size ≥ 32)
transforms.Compose定義數(shù)據(jù)預(yù)處理流程:ToTensor():將圖像轉(zhuǎn)換為 PyTorch TensorNormalize():標(biāo)準(zhǔn)化圖像像素值到 [-1, 1]
加載數(shù)據(jù)后,訓(xùn)練過程包括通過數(shù)據(jù)集進(jìn)行多次迭代,使用反向傳播和合適的損失函數(shù):
# 訓(xùn)練循環(huán)
for epoch in range(2): # 進(jìn)行2個epoch的訓(xùn)練
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
# 前向傳播
outputs = net(inputs)
loss = criterion(outputs, labels)
# 反向傳播和優(yōu)化
optimizer.zero_grad() # 清空梯度
loss.backward() # 計算梯度
optimizer.step() # 更新參數(shù)
running_loss += loss.item()
# 每2000個batch打印一次
if i % 2000 == 1999:
avg_loss = running_loss / 2000
print(f'Epoch [{epoch+1}/{2}], Batch [{i+1}/2000], Loss: {avg_loss:.3f}')
running_loss = 0.0
print("訓(xùn)練完成!")
訓(xùn)練循環(huán):
- epoch:完整遍歷數(shù)據(jù)集一次
- batch:數(shù)據(jù)加載器中的一個批次
- 梯度清零:每次反向傳播前需要清空梯度
- 損失計算:
outputs的形狀為[batch_size, 10],labels為整數(shù)標(biāo)簽
完整代碼
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import torchvision
from torch.utils.data import DataLoader
# 定義CNN模型(修復(fù)了變量引用問題)
class SimpleCNN(nn.Module):
def __init__(self):
super(SimpleCNN, self).__init__()
self.conv1 = nn.Conv2d(3, 6, 5) # 第一個卷積層:3輸入通道,6輸出通道,5x5卷積核
self.pool = nn.MaxPool2d(2, 2) # 最大池化層:2x2窗口,步長2
self.conv2 = nn.Conv2d(6, 16, 5) # 第二個卷積層:6輸入通道,16輸出通道,5x5卷積核
self.fc1 = nn.Linear(16 * 5 * 5, 120)# 全連接層1:400輸入 -> 120輸出
self.fc2 = nn.Linear(120, 84) # 全連接層2:120輸入 -> 84輸出
self.fc3 = nn.Linear(84, 10) # 輸出層:84輸入 -> 10類 logits
def forward(self, x):
# 輸入形狀:[batch_size, 3, 32, 32]
x = self.pool(F.relu(self.conv1(x))) # -> [batch, 6, 14, 14](池化后尺寸減半)
x = self.pool(F.relu(self.conv2(x))) # -> [batch, 16, 5, 5]
x = x.view(-1, 16 * 5 * 5) # 展平為一維向量:16 * 5 * 5=400
x = F.relu(self.fc1(x)) # -> [batch, 120]
x = F.relu(self.fc2(x)) # -> [batch, 84]
x = self.fc3(x) # -> [batch, 10](未應(yīng)用softmax,配合CrossEntropyLoss使用)
return x
# 初始化模型、損失函數(shù)和優(yōu)化器
net = SimpleCNN() # 實(shí)例化模型
criterion = nn.CrossEntropyLoss() # 使用交叉熵?fù)p失函數(shù)(自動處理softmax)
optimizer = torch.optim.SGD(net.parameters(),
lr=0.001, # 學(xué)習(xí)率
momentum=0.9) # 動量參數(shù)
# 數(shù)據(jù)預(yù)處理和加載
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
# 加載CIFAR-10訓(xùn)練集
trainset = torchvision.datasets.CIFAR10(
root='./data', train=True,
download=True, # 自動下載數(shù)據(jù)集
transform=transform
)
trainloader = DataLoader(trainset,
batch_size=4, # 每個batch包含4張圖像
shuffle=True) # 打亂數(shù)據(jù)順序
# 訓(xùn)練循環(huán)
for epoch in range(2): # 進(jìn)行2個epoch的訓(xùn)練
running_loss = 0.0
for i, data in enumerate(trainloader, 0):
inputs, labels = data
# 前向傳播
outputs = net(inputs)
loss = criterion(outputs, labels)
# 反向傳播和優(yōu)化
optimizer.zero_grad() # 清空梯度
loss.backward() # 計算梯度
optimizer.step() # 更新參數(shù)
running_loss += loss.item()
# 每2000個batch打印一次
if i % 2000 == 1999:
avg_loss = running_loss / 2000
print(f'Epoch [{epoch+1}/{2}], Batch [{i+1}/2000], Loss: {avg_loss:.3f}')
running_loss = 0.0
print("訓(xùn)練完成!")
最后總結(jié)
通過PyTorch和卷積神經(jīng)網(wǎng)絡(luò),你可以有效地處理圖像分類任務(wù)。借助PyTorch的靈活性,可以根據(jù)特定的數(shù)據(jù)集和應(yīng)用程序構(gòu)建、訓(xùn)練和微調(diào)模型。示例代碼僅為理論過程,實(shí)際項(xiàng)目中還有大量優(yōu)化空間。
以上就是PyTorch使用CNN實(shí)現(xiàn)圖像分類的詳細(xì)內(nèi)容,更多關(guān)于PyTorch CNN圖像分類的資料請關(guān)注腳本之家其它相關(guān)文章!
- 使用pytorch進(jìn)行圖像分類的詳細(xì)步驟
- 如何使用Pytorch完成圖像分類任務(wù)詳解
- Pytorch深度學(xué)習(xí)之實(shí)現(xiàn)病蟲害圖像分類
- Python Pytorch深度學(xué)習(xí)之圖像分類器
- Python深度學(xué)習(xí)pytorch實(shí)現(xiàn)圖像分類數(shù)據(jù)集
- 基于PyTorch實(shí)現(xiàn)一個簡單的CNN圖像分類器
- Pytorch 使用CNN圖像分類的實(shí)現(xiàn)
- 使用PyTorch訓(xùn)練一個圖像分類器實(shí)例
- PyTorch中圖像多分類的實(shí)現(xiàn)
相關(guān)文章
使用Python中PDB模塊中的命令來調(diào)試Python代碼的教程
這篇文章主要介紹了使用Python中PDB模塊中的命令來調(diào)試Python代碼的教程,包括設(shè)置斷點(diǎn)來修改代碼等、對于Python團(tuán)隊(duì)項(xiàng)目工作有一定幫助,需要的朋友可以參考下2015-03-03
Python使用DrissionPage實(shí)現(xiàn)網(wǎng)頁自動化采集
DrissionPage 是一個基于 python 的網(wǎng)頁自動化工具,它既能控制瀏覽器,也能收發(fā)數(shù)據(jù)包,還能把兩者合而為一,可兼顧瀏覽器自動化的便利性和 requests 的高效率,本文給大家介紹了Python使用DrissionPage實(shí)現(xiàn)網(wǎng)頁自動化采集,需要的朋友可以參考下2025-03-03
Django 導(dǎo)出項(xiàng)目依賴庫到 requirements.txt過程解析
這篇文章主要介紹了Django 導(dǎo)出項(xiàng)目依賴庫到 requirements.txt過程解析,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友可以參考下2019-08-08
Gradio機(jī)器學(xué)習(xí)模型快速部署工具quickstart前篇
這篇文章主要為大家介紹了Gradio機(jī)器學(xué)習(xí)模型快速部署工具quickstart準(zhǔn)備原文翻譯,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-04-04
Python利用CNN實(shí)現(xiàn)對時序數(shù)據(jù)進(jìn)行分類
這篇文章主要為大家詳細(xì)介紹了Python如何利用CNN實(shí)現(xiàn)對時序數(shù)據(jù)進(jìn)行分類功能,文中的示例代碼講解詳細(xì),感興趣的小伙伴可以了解一下2023-02-02
Django數(shù)據(jù)庫遷移的實(shí)現(xiàn)步驟
本文主要介紹了Django數(shù)據(jù)庫遷移的實(shí)現(xiàn)步驟,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2025-08-08
Python實(shí)現(xiàn)計算文件MD5和SHA1的方法示例
這篇文章主要介紹了Python實(shí)現(xiàn)計算文件MD5和SHA1的方法,結(jié)合具體實(shí)例形式分析了Python針對文件MD5及SHA1的計算方法,需要的朋友可以參考下2019-06-06

