欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch實(shí)現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)的搭建詳解

 更新時(shí)間:2022年05月07日 08:56:01   作者:Bubbliiiing  
這篇文章主要為大家介紹了PyTorch實(shí)現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)的搭建詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪

PyTorch中實(shí)現(xiàn)卷積的重要基礎(chǔ)函數(shù)

1、nn.Conv2d:

nn.Conv2d在pytorch中用于實(shí)現(xiàn)卷積。

nn.Conv2d(
    in_channels=32,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
)

1、in_channels為輸入通道數(shù)。

2、out_channels為輸出通道數(shù)。

3、kernel_size為卷積核大小。

4、stride為步數(shù)。

5、padding為padding情況。

6、dilation表示空洞卷積情況。

2、nn.MaxPool2d(kernel_size=2)

nn.MaxPool2d在pytorch中用于實(shí)現(xiàn)最大池化。

具體使用方式如下:

MaxPool2d(kernel_size, 
		stride=None, 
		padding=0, 
		dilation=1, 
		return_indices=False, 
		ceil_mode=False)

1、kernel_size為池化核的大小

2、stride為步長(zhǎng)

3、padding為填充情況

3、nn.ReLU()

nn.ReLU()用來實(shí)現(xiàn)Relu函數(shù),實(shí)現(xiàn)非線性。

4、x.view()

x.view用于reshape特征層的形狀。

全部代碼

這是一個(gè)簡(jiǎn)單的CNN模型,用于預(yù)測(cè)mnist手寫體。

import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
# 循環(huán)世代
EPOCH = 20
BATCH_SIZE = 50
# 下載mnist數(shù)據(jù)集
train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=True,)
# (60000, 28, 28)
print(train_data.train_data.size())                 
# (60000)
print(train_data.train_labels.size())               
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 測(cè)試集
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# (2000, 1, 28, 28)
# 標(biāo)準(zhǔn)化
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.
test_y = test_data.test_labels[:2000]
# 建立pytorch神經(jīng)網(wǎng)絡(luò)
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        #----------------------------#
        #   第一部分卷積
        #----------------------------#
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels=1,
                out_channels=32,
                kernel_size=5,
                stride=1,
                padding=2,
                dilation=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        #----------------------------#
        #   第二部分卷積
        #----------------------------#
        self.conv2 = nn.Sequential( 
            nn.Conv2d(
                in_channels=32,
                out_channels=64,
                kernel_size=3,
                stride=1,
                padding=1,
                dilation=1
            ),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        )
        #----------------------------#
        #   全連接+池化+全連接
        #----------------------------#
        self.ful1 = nn.Linear(64 * 7 * 7, 512)
        self.drop = nn.Dropout(0.5)
        self.ful2 = nn.Sequential(nn.Linear(512, 10),nn.Softmax())
    #----------------------------#
    #   前向傳播
    #----------------------------#   
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)
        x = self.ful1(x)
        x = self.drop(x)
        output = self.ful2(x)
        return output
cnn = CNN()
# 指定優(yōu)化器
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3) 
# 指定loss函數(shù)
loss_func = nn.CrossEntropyLoss()
for epoch in range(EPOCH):
    for step, (b_x, b_y) in enumerate(train_loader): 
        #----------------------------#
        #   計(jì)算loss并修正權(quán)值
        #----------------------------#   
        output = cnn(b_x)
        loss = loss_func(output, b_y) 
        optimizer.zero_grad() 
        loss.backward() 
        optimizer.step() 
        #----------------------------#
        #   打印
        #----------------------------#   
        if step % 50 == 0:
            test_output = cnn(test_x)
            pred_y = torch.max(test_output, 1)[1].data.numpy()
            accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
            print('Epoch: %2d'% epoch, ', loss: %.4f' % loss.data.numpy(), ', accuracy: %.4f' % accuracy)

以上就是PyTorch實(shí)現(xiàn)卷積神經(jīng)網(wǎng)絡(luò)的搭建詳解的詳細(xì)內(nèi)容,更多關(guān)于PyTorch搭建卷積神經(jīng)網(wǎng)絡(luò)的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • Python 實(shí)現(xiàn)簡(jiǎn)單的電話本功能

    Python 實(shí)現(xiàn)簡(jiǎn)單的電話本功能

    這篇文章主要介紹了Python 實(shí)現(xiàn)簡(jiǎn)單的電話本功能的相關(guān)資料,包括添加聯(lián)系人信息,查找姓名顯示聯(lián)系人,存儲(chǔ)聯(lián)系人到 TXT 文檔等內(nèi)容,十分的細(xì)致,有需要的小伙伴可以參考下
    2015-08-08
  • Python開啟Http Server的實(shí)現(xiàn)步驟

    Python開啟Http Server的實(shí)現(xiàn)步驟

    本文主要介紹了Python開啟Http Server的實(shí)現(xiàn)步驟,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2023-07-07
  • 在Python中使用matplotlib模塊繪制數(shù)據(jù)圖的示例

    在Python中使用matplotlib模塊繪制數(shù)據(jù)圖的示例

    這篇文章主要介紹了在Python中使用matplotlib模塊繪制數(shù)據(jù)圖的示例,matplotlib模塊經(jīng)常被用來實(shí)現(xiàn)數(shù)據(jù)的可視化,需要的朋友可以參考下
    2015-05-05
  • matplotlib作圖添加表格實(shí)例代碼

    matplotlib作圖添加表格實(shí)例代碼

    這篇文章主要介紹了matplotlib作圖添加表格實(shí)例代碼,實(shí)例繪制了一個(gè)簡(jiǎn)單的折線圖,并且在圖中添加了一個(gè)表格,小編覺得還是挺不錯(cuò)的,具有一定借鑒價(jià)值,需要的朋友可以參考下
    2018-01-01
  • TensorFlow Autodiff自動(dòng)微分詳解

    TensorFlow Autodiff自動(dòng)微分詳解

    這篇文章主要介紹了TensorFlow Autodiff自動(dòng)微分詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2020-07-07
  • 解決IDEA 的 plugins 搜不到任何的插件問題

    解決IDEA 的 plugins 搜不到任何的插件問題

    這篇文章主要介紹了解決IDEA 的 plugins 搜不到任何的插件問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2020-05-05
  • Python 實(shí)現(xiàn)進(jìn)度條的六種方式

    Python 實(shí)現(xiàn)進(jìn)度條的六種方式

    這篇文章主要介紹了Python 實(shí)現(xiàn)進(jìn)度條的六種方式,幫助大家更好的理解和使用python,感興趣的朋友可以了解下
    2021-01-01
  • PyCharm插件開發(fā)實(shí)踐之PyGetterAndSetter詳解

    PyCharm插件開發(fā)實(shí)踐之PyGetterAndSetter詳解

    這篇文章主要介紹了PyCharm插件開發(fā)實(shí)踐-PyGetterAndSetter,本文給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2021-10-10
  • pyenv虛擬環(huán)境管理python多版本和軟件庫的方法

    pyenv虛擬環(huán)境管理python多版本和軟件庫的方法

    這篇文章主要介紹了pyenv虛擬環(huán)境管理python多版本和軟件庫,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧
    2019-12-12
  • python 將列表中的字符串連接成一個(gè)長(zhǎng)路徑的方法

    python 將列表中的字符串連接成一個(gè)長(zhǎng)路徑的方法

    今天小編就為大家分享一篇python 將列表中的字符串連接成一個(gè)長(zhǎng)路徑的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧
    2018-10-10

最新評(píng)論