Pytorch 卷積中的 Input Shape用法
先看Pytorch中的卷積
class torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
二維卷積層, 輸入的尺度是(N, C_in,H,W),輸出尺度(N,C_out,H_out,W_out)的計算方式
這里比較奇怪的是這個卷積層居然沒有定義input shape,輸入尺寸明明是:(N, C_in, H,W),但是定義中卻只需要輸入in_channel的size,就能完成卷積,那是不是說這樣任意size的image都可以進行卷積呢?
然后我進行了下面這樣的實驗:
import torch import torch.nn as nn import torch.nn.functional as F class Net(nn.Module): def __init__(self): super(Net, self).__init__() # 輸入圖像channel:1;輸出channel:6;5x5卷積核 self.conv1 = nn.Conv2d(1, 6, 5) self.conv2 = nn.Conv2d(6, 16, 5) # an affine operation: y = Wx + b self.fc1 = nn.Linear(16 * 5 * 5, 120) self.fc2 = nn.Linear(120, 84) self.fc3 = nn.Linear(84, 10) def forward(self, x): # 2x2 Max pooling x = F.max_pool2d(F.relu(self.conv1(x)), (2, 2)) # If the size is a square you can only specify a single number x = F.max_pool2d(F.relu(self.conv2(x)), 2) x = x.view(-1, self.num_flat_features(x)) x = F.relu(self.fc1(x)) x = F.relu(self.fc2(x)) x = self.fc3(x) return x def num_flat_features(self, x): size = x.size()[1:] # 除去批大小維度的其余維度 num_features = 1 for s in size: num_features *= s return num_features net = Net() print(net)
輸出
Net( (conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1)) (conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1)) (fc1): Linear(in_features=400, out_features=120, bias=True) (fc2): Linear(in_features=120, out_features=84, bias=True) (fc3): Linear(in_features=84, out_features=10, bias=True) )
官網(wǎng)Tutorial 說:這個網(wǎng)絡(luò)(LeNet)的期待輸入是32x32,我就比較奇怪他又沒有設(shè)置Input shape或者Tensorflow里的Input層,怎么就知道(H,W) =(32, 32)。
輸入:
input = torch.randn(1, 1, 32, 32)
output = Net(input)
沒問題,但是
input = torch.randn(1, 1, 64, 64)
output = Net(input)
出現(xiàn):mismatch Error
我們看一下卷積模型部分。
input:(1, 1, 32, 32) --> conv1(1, 6, 5) --> (1, 6, 28, 28) --> max_pool1(2, 2) --> (1, 6, 14, 14) --> conv2(6, 16, 5) -->(1, 16, 10, 10) --> max_pool2(2, 2) --> (1, 16, 5, 5)
然后是將其作為一個全連接網(wǎng)絡(luò)的輸入。Linear相當于tensorflow 中的Dense。所以當你的輸入尺寸不為(32, 32)時,卷積得到最終feature map shape就不是(None, 16, 5, 5),而我們的第一個Linear層的輸入為(None, 16 * 5 * 5),故會出現(xiàn)mismatch Error。
之所以會有這樣一個問題還是因為keras model 必須提定義Input shape,而pytorch更像是一個流程化操作,具體看官網(wǎng)吧。
補充知識:pytorch 卷積 分組卷積 及其深度卷積
先來看看pytorch二維卷積的操作API
現(xiàn)在繼續(xù)講講幾個卷積是如何操作的。
一. 普通卷積
torch.nn.Conv2d(in_channels, out_channels, kernel_size, stride=1, padding=0, dilation=1, groups=1, bias=True)
普通卷積時group默認為1 dilation=1(這里先暫時不討論dilation)其余都正常的話,比如輸入為Nx in_channel x high x width
輸出為N x out_channel x high xwidth .還是來具體的數(shù)字吧,輸入為64通道的特征圖,輸出為32通道的特征圖,要想得到32通道的特征圖就必須得有32種不同的卷積核。
也就是上面?zhèn)魅氲膮?shù)out_channel。繼續(xù)說說具體是怎么的得到的,對于每一種卷積核會和64種不同的特征圖依次進行卷積,比如卷積核大小是3x3大小的,由于卷積權(quán)值共享,所以對于輸入的一張?zhí)卣鲌D卷積時,只有3x3個參數(shù),同一張?zhí)卣鲌D上進行滑窗卷積操作時參數(shù)是一樣的,剛才說的第一種卷積核和輸入的第一個特征圖卷積完后再繼續(xù)和后面的第2,3,........64個不同的特征圖依次卷積(一種卷積核對于輸入特征圖來說,同一特征圖上面卷積,參數(shù)一樣,對于不同的特征圖上卷積不一樣),最后的參數(shù)是3x3x64。
此時輸出才為一個特征圖,因為現(xiàn)在才只使用了一種卷積核。一種核分別在局部小窗口里面和64個特征圖卷積應(yīng)該得到64個數(shù),最后將64個數(shù)相加就可以得到一個數(shù)了,也就是輸出一個特征圖上對應(yīng)于那個窗口的值,依次滑窗就可以得到完整的特征圖了。
前面將了這么多才使用一種卷積核,那么現(xiàn)在依次類推使用32種不同的卷積核就可以得到輸出的32通道的特征圖。最終參數(shù)為64x3x3x32.
二.分組卷積
參數(shù)group=1時,就是和普通的卷積一樣?,F(xiàn)在假如group=4,前提是輸入特征圖和輸出特征圖必須是4的倍數(shù)?,F(xiàn)在來看看是如何操作的。in_channel64分成4組,out_inchannel(也就是32種核)也分成4組,依次對應(yīng)上面的普通卷方式,最終將每組輸出的8個特征圖依次concat起來,就是結(jié)果的out_channel
三. 深度卷積depthwise
此時group=in_channle,也就是對每一個輸入的特征圖分別用不同的卷積核卷積。out_channel必須是in_channel 的整數(shù)倍。
3.1 當k=1時,out_channel=in_channel ,每一個卷積核分別和每一個輸入的通道進行卷積,最后在concat起來。參數(shù)總量為3x3x64。如果此時卷積完之后接著一個64個1x1大小的卷積核。就是谷歌公司于2017年的CVPR中在論文”Xception: deep learning with depthwise separable convolutions”中提出的結(jié)構(gòu)。如下圖
上圖是將1x1放在depthwise前面,其實原理都一樣。最終參數(shù)的個數(shù)是64x1x1+64x3x3。參數(shù)要小于普通的卷積方法64x3x3x64
3.2 當k是大于1的整數(shù)時,比如k=2
此時每一個輸入的特征圖對應(yīng)k個卷積核,生成k特征圖,最終生成的特征圖個數(shù)就是k×in_channel .
以上這篇Pytorch 卷積中的 Input Shape用法就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關(guān)文章
每個 Python 開發(fā)者都應(yīng)該知道的7種好用工具(效率翻倍)
Python 從一種小的開源語言開始,到現(xiàn)在,它已經(jīng)成為開發(fā)者很受歡迎的編程語言之一。這篇文章主要介紹了每個 Python 開發(fā)者都應(yīng)該知道的7種好用工具(效率翻倍),需要的朋友可以參考下2021-03-03django配置連接數(shù)據(jù)庫及原生sql語句的使用方法
這篇文章主要給大家介紹了關(guān)于django配置連接數(shù)據(jù)庫,以及原生sql語句的使用方法,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面來一起學習學習吧2019-03-03python QT界面關(guān)閉線程池的線程跟隨退出完美解決方案
這篇文章主要介紹了python QT界面關(guān)閉,線程池的線程跟隨退出解決思路方法,本文給大家分享兩種方法結(jié)合實例代碼給大家介紹的非常詳細,需要的朋友可以參考下2022-11-11封裝?Python?時間處理庫創(chuàng)建自己的TimeUtil類示例
這篇文章主要為大家介紹了封裝?Python?時間處理庫創(chuàng)建自己的TimeUtil類示例詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步早日升職加薪2023-05-05Python flask框架定時任務(wù)apscheduler應(yīng)用介紹
Flask是Python社區(qū)非常流行的一個Web開發(fā)框架,本文將嘗試將介紹APScheduler應(yīng)用于Flask之中實現(xiàn)定時任務(wù),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習吧2022-10-10使用python中的in ,not in來檢查元素是不是在列表中的方法
今天小編就為大家分享一篇使用python中的in ,not in來檢查元素是不是在列表中的方法,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-07-07