YOLOv5中SPP/SPPF結(jié)構(gòu)源碼詳析(內(nèi)含注釋分析)
一、SPP的應(yīng)用的背景
在卷積神經(jīng)網(wǎng)絡(luò)中我們經(jīng)常看到固定輸入的設(shè)計(jì),但是如果我們輸入的不能是固定尺寸的該怎么辦呢?
通常來(lái)說(shuō),我們有以下幾種方法:
(1)對(duì)輸入進(jìn)行resize操作,讓他們統(tǒng)統(tǒng)變成你設(shè)計(jì)的層的輸入規(guī)格那樣。但是這樣過(guò)于暴力直接,可能會(huì)丟失很多信息或者多出很多不該有的信息(圖片變形等),影響最終的結(jié)果。
(2)替換網(wǎng)絡(luò)中的全連接層,對(duì)最后的卷積層使用global average pooling,全局平均池化只和通道數(shù)有關(guān),而與特征圖大小沒(méi)有關(guān)系
(3)最后一個(gè)當(dāng)然是我們要講的SPP結(jié)構(gòu)啦~
二、SPP結(jié)構(gòu)分析
SPP結(jié)構(gòu)又被稱為空間金字塔池化,能將任意大小的特征圖轉(zhuǎn)換成固定大小的特征向量。
接下來(lái)我們來(lái)詳述一下SPP是怎么處理滴~
輸入層:首先我們現(xiàn)在有一張任意大小的圖片,其大小為w * h。
輸出層:21個(gè)神經(jīng)元 -- 即我們待會(huì)希望提取到21個(gè)特征。
分析如下圖所示:分別對(duì)1 * 1分塊,2 * 2分塊和4 * 4子圖里分別取每一個(gè)框內(nèi)的max值(即取藍(lán)框框內(nèi)的最大值),這一步就是作最大池化,這樣最后提取出來(lái)的特征值(即取出來(lái)的最大值)一共有1 * 1 + 2 * 2 + 4 * 4 = 21個(gè)。得出的特征再concat在一起。
而在YOLOv5中SPP的結(jié)構(gòu)圖如下圖所示:
其中,前后各多加一個(gè)CBL,中間的kernel size分別為1 * 1,5 * 5,9 * 9和13 * 13。
三、SPPF結(jié)構(gòu)分析
(x,y1這些是啥請(qǐng)看下面的代碼)
四、YOLOv5中SPP/SPPF結(jié)構(gòu)源碼解析(內(nèi)含注釋分析)
代碼注釋與上圖的SPP結(jié)構(gòu)相對(duì)應(yīng)。
class SPP(nn.Module): def __init__(self, c1, c2, k=(5, 9, 13)):#這里5,9,13,就是初始化的kernel size super().__init__() c_ = c1 // 2 # hidden channels self.cv1 = Conv(c1, c_, 1, 1)#這里對(duì)應(yīng)第一個(gè)CBL self.cv2 = Conv(c_ * (len(k) + 1), c2, 1, 1)#這里對(duì)應(yīng)SPP操作里的最后一個(gè)CBL self.m = nn.ModuleList([nn.MaxPool2d(kernel_size=x, stride=1, padding=x // 2) for x in k]) #這里對(duì)應(yīng)SPP核心操作,對(duì)5 * 5分塊,9 * 9分塊和13 * 13子圖分別取最大池化 def forward(self, x): x = self.cv1(x) with warnings.catch_warnings(): warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning忽略警告 return self.cv2(torch.cat([x] + [m(x) for m in self.m], 1)) #torch.cat對(duì)應(yīng)concat
SPPF結(jié)構(gòu)
class SPPF(nn.Module): # Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocher def __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13)) super().__init__() c_ = c1 // 2 # hidden channels self.cv1 = Conv(c1, c_, 1, 1) self.cv2 = Conv(c_ * 4, c2, 1, 1) self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2) def forward(self, x): x = self.cv1(x)#先通過(guò)CBL進(jìn)行通道數(shù)的減半 with warnings.catch_warnings(): warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warning y1 = self.m(x) y2 = self.m(y1) #上述兩次最大池化 return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1)) #將原來(lái)的x,一次池化后的y1,兩次池化后的y2,3次池化的self.m(y2)先進(jìn)行拼接,然后再CBL
總結(jié)
到此這篇關(guān)于YOLOv5中SPP/SPPF結(jié)構(gòu)源碼詳析的文章就介紹到這了,更多相關(guān)YOLOv5 SPP/SPPF結(jié)構(gòu)內(nèi)容請(qǐng)搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章

numpy數(shù)組坐標(biāo)軸問(wèn)題解決

python定時(shí)復(fù)制遠(yuǎn)程文件夾中所有文件

numpy中數(shù)組拼接、數(shù)組合并方法總結(jié)(append(),?concatenate,?hstack,?vstack

使用OpenCV-python3實(shí)現(xiàn)滑動(dòng)條更新圖像的Canny邊緣檢測(cè)功能

python實(shí)現(xiàn)批量轉(zhuǎn)換圖片為黑白