PyTorch實(shí)現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)的搭建詳解
PyTorch中實(shí)現(xiàn)卷積的重要基礎(chǔ)函數(shù)
1、nn.Conv2d:
nn.Conv2d在pytorch中用于實(shí)現(xiàn)卷積。
nn.Conv2d( in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, )
1、in_channels為輸入通道數(shù)。
2、out_channels為輸出通道數(shù)。
3、kernel_size為卷積核大小。
4、stride為步數(shù)。
5、padding為padding情況。
6、dilation表示空洞卷積情況。
2、nn.MaxPool2d(kernel_size=2)
nn.MaxPool2d在pytorch中用于實(shí)現(xiàn)最大池化。
具體使用方式如下:
MaxPool2d(kernel_size, stride=None, padding=0, dilation=1, return_indices=False, ceil_mode=False)
1、kernel_size為池化核的大小
2、stride為步長(zhǎng)
3、padding為填充情況
3、nn.ReLU()
nn.ReLU()用來實(shí)現(xiàn)Relu函數(shù),實(shí)現(xiàn)非線性。
4、x.view()
x.view用于reshape特征層的形狀。
全部代碼
這是一個(gè)簡(jiǎn)單的CNN模型,用于預(yù)測(cè)mnist手寫體。
import os import numpy as np import torch import torch.nn as nn import torch.utils.data as Data import torchvision import matplotlib.pyplot as plt # 循環(huán)世代 EPOCH = 20 BATCH_SIZE = 50 # 下載mnist數(shù)據(jù)集 train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=True,) # (60000, 28, 28) print(train_data.train_data.size()) # (60000) print(train_data.train_labels.size()) train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True) # 測(cè)試集 test_data = torchvision.datasets.MNIST(root='./mnist/', train=False) # (2000, 1, 28, 28) # 標(biāo)準(zhǔn)化 test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255. test_y = test_data.test_labels[:2000] # 建立pytorch神經(jīng)網(wǎng)絡(luò) class CNN(nn.Module): def __init__(self): super(CNN, self).__init__() #----------------------------# # 第一部分卷積 #----------------------------# self.conv1 = nn.Sequential( nn.Conv2d( in_channels=1, out_channels=32, kernel_size=5, stride=1, padding=2, dilation=1 ), nn.ReLU(), nn.MaxPool2d(kernel_size=2), ) #----------------------------# # 第二部分卷積 #----------------------------# self.conv2 = nn.Sequential( nn.Conv2d( in_channels=32, out_channels=64, kernel_size=3, stride=1, padding=1, dilation=1 ), nn.ReLU(), nn.MaxPool2d(kernel_size=2), ) #----------------------------# # 全連接+池化+全連接 #----------------------------# self.ful1 = nn.Linear(64 * 7 * 7, 512) self.drop = nn.Dropout(0.5) self.ful2 = nn.Sequential(nn.Linear(512, 10),nn.Softmax()) #----------------------------# # 前向傳播 #----------------------------# def forward(self, x): x = self.conv1(x) x = self.conv2(x) x = x.view(x.size(0), -1) x = self.ful1(x) x = self.drop(x) output = self.ful2(x) return output cnn = CNN() # 指定優(yōu)化器 optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3) # 指定loss函數(shù) loss_func = nn.CrossEntropyLoss() for epoch in range(EPOCH): for step, (b_x, b_y) in enumerate(train_loader): #----------------------------# # 計(jì)算loss并修正權(quán)值 #----------------------------# output = cnn(b_x) loss = loss_func(output, b_y) optimizer.zero_grad() loss.backward() optimizer.step() #----------------------------# # 打印 #----------------------------# if step % 50 == 0: test_output = cnn(test_x) pred_y = torch.max(test_output, 1)[1].data.numpy() accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0)) print('Epoch: %2d'% epoch, ', loss: %.4f' % loss.data.numpy(), ', accuracy: %.4f' % accuracy)
以上就是PyTorch實(shí)現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)的搭建詳解的詳細(xì)內(nèi)容,更多關(guān)于PyTorch搭建卷積神經(jīng)網(wǎng)絡(luò)的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
- PyTorch中的神經(jīng)網(wǎng)絡(luò) Mnist 分類任務(wù)
- 使用Pytorch構(gòu)建第一個(gè)神經(jīng)網(wǎng)絡(luò)模型?附案例實(shí)戰(zhàn)
- pytorch簡(jiǎn)單實(shí)現(xiàn)神經(jīng)網(wǎng)絡(luò)功能
- pytorch深度神經(jīng)網(wǎng)絡(luò)入門準(zhǔn)備自己的圖片數(shù)據(jù)
- Pytorch卷積神經(jīng)網(wǎng)絡(luò)遷移學(xué)習(xí)的目標(biāo)及好處
- Pytorch深度學(xué)習(xí)經(jīng)典卷積神經(jīng)網(wǎng)絡(luò)resnet模塊訓(xùn)練
- Pytorch卷積神經(jīng)網(wǎng)絡(luò)resent網(wǎng)絡(luò)實(shí)踐
- Pytorch神經(jīng)網(wǎng)絡(luò)參數(shù)管理方法詳細(xì)講解
相關(guān)文章
Python 實(shí)現(xiàn)簡(jiǎn)單的電話本功能
這篇文章主要介紹了Python 實(shí)現(xiàn)簡(jiǎn)單的電話本功能的相關(guān)資料,包括添加聯(lián)系人信息,查找姓名顯示聯(lián)系人,存儲(chǔ)聯(lián)系人到 TXT 文檔等內(nèi)容,十分的細(xì)致,有需要的小伙伴可以參考下2015-08-08Python開啟Http Server的實(shí)現(xiàn)步驟
本文主要介紹了Python開啟Http Server的實(shí)現(xiàn)步驟,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-07-07在Python中使用matplotlib模塊繪制數(shù)據(jù)圖的示例
這篇文章主要介紹了在Python中使用matplotlib模塊繪制數(shù)據(jù)圖的示例,matplotlib模塊經(jīng)常被用來實(shí)現(xiàn)數(shù)據(jù)的可視化,需要的朋友可以參考下2015-05-05TensorFlow Autodiff自動(dòng)微分詳解
這篇文章主要介紹了TensorFlow Autodiff自動(dòng)微分詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-07-07Python 實(shí)現(xiàn)進(jìn)度條的六種方式
這篇文章主要介紹了Python 實(shí)現(xiàn)進(jìn)度條的六種方式,幫助大家更好的理解和使用python,感興趣的朋友可以了解下2021-01-01PyCharm插件開發(fā)實(shí)踐之PyGetterAndSetter詳解
這篇文章主要介紹了PyCharm插件開發(fā)實(shí)踐-PyGetterAndSetter,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2021-10-10pyenv虛擬環(huán)境管理python多版本和軟件庫的方法
這篇文章主要介紹了pyenv虛擬環(huán)境管理python多版本和軟件庫,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2019-12-12python 將列表中的字符串連接成一個(gè)長(zhǎng)路徑的方法
今天小編就為大家分享一篇python 將列表中的字符串連接成一個(gè)長(zhǎng)路徑的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2018-10-10