Pytorch實現(xiàn)ResNet網(wǎng)絡之Residual Block殘差塊
Residual Block
ResNet中最重要的組件是殘差塊(residual block),也稱為殘差單元(residual unit)。一個標準的殘差塊包含兩層卷積層和一條跳過連接(skip connection),如下
假設輸入x的大小為F×H×W,其中FFF表示通道數(shù),H和W分別表示高度和寬度。那么通過殘差塊后輸出的特征圖的大小仍然是F×H×W。
跳過連接能夠使得該層網(wǎng)絡可以直接通過進行恒等映射(identity mapping)來優(yōu)化模型,并避免反激化迫使網(wǎng)絡退化。即殘差塊應該學習到輸入數(shù)據(jù)和輸出數(shù)據(jù)的差異,而不是完全復制輸入數(shù)據(jù)。
實現(xiàn)一個殘差塊
代碼如下所示:
import torch.nn as nn
class ResidualBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride=1):
super().__init__()
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=False)
self.bn1 = nn.BatchNorm2d(out_channels)
self.relu = nn.ReLU(inplace=True)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=False)
self.bn2 = nn.BatchNorm2d(out_channels)
self.shortcut = nn.Sequential()
if in_channels != out_channels or stride != 1:
self.shortcut = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, bias=False),
nn.BatchNorm2d(out_channels))
def forward(self, x):
residual = x
x = self.conv1(x)
x = self.bn1(x)
x = self.relu(x)
x = self.conv2(x)
x = self.bn2(x)
shortcut = self.shortcut(residual)
x += shortcut
x = self.relu(x)
return x這段代碼定義了一個繼承自nn.Module的殘差塊。在初始化過程中,我們定義了兩個卷積層、兩個批標準化(batch normalization)層以及一個恒等映射短連接(shortcut)。其中第二個卷積層的輸入通道數(shù)必須與輸出通道數(shù)相同。
在forward函數(shù)中,我們首先將輸入數(shù)據(jù)xxx保存到一個變量residual中。然后將xxx通過第一個卷積層、批標準化以及ReLU激活函數(shù),再通過第二個卷積層和批標準化。
默認情況下,跳過連接是一個恒等映射,即僅將輸入數(shù)據(jù)復制并直接加到輸出數(shù)據(jù)上。如果輸入的通道數(shù)與輸出的通道數(shù)不同,或者在卷積操作中改變了特征圖的大?。╯tride > 1),則需要對輸入進行適當?shù)奶幚硪耘c輸出相匹配。我們使用1×1卷積層(又稱為“投影級”)來改變大小和通道數(shù),并將其添加到shortcut`, 確保整個殘差塊拓撲中都能夠正確地實現(xiàn)殘差學習。
以上就是Pytorch實現(xiàn)ResNet網(wǎng)絡之Residual Block殘差塊的詳細內(nèi)容,更多關于Pytorch ResNet殘差塊的資料請關注腳本之家其它相關文章!
相關文章
jupyter notebook tensorflow打印device信息實例
這篇文章主要介紹了jupyter notebook tensorflow打印device信息實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-04-04
Python?計算機視覺編程進階之OpenCV?圖像銳化及邊緣檢測
計算機視覺這種技術(shù)可以將靜止圖像或視頻數(shù)據(jù)轉(zhuǎn)換為一種決策或新的表示。所有這樣的轉(zhuǎn)換都是為了完成某種特定的目的而進行的,本篇我們來學習下如何對圖像進行銳化處理以及如何進行邊緣檢測2021-11-11
Python+OpenCV圖像處理——實現(xiàn)輪廓發(fā)現(xiàn)
這篇文章主要介紹了Python+OpenCV實現(xiàn)輪廓發(fā)現(xiàn),幫助大家更好的利用python處理圖片,感興趣的朋友可以了解下2020-10-10
詳解在Python程序中解析并修改XML內(nèi)容的方法
這篇文章主要介紹了在Python程序中解析并修改XML內(nèi)容的方法,依賴于解析成樹狀結(jié)構(gòu)后的節(jié)點進行修改,需要的朋友可以參考下2015-11-11
Python cookbook(數(shù)據(jù)結(jié)構(gòu)與算法)通過公共鍵對字典列表排序算法示例
這篇文章主要介紹了Python cookbook(數(shù)據(jù)結(jié)構(gòu)與算法)通過公共鍵對字典列表排序算法,結(jié)合實例形式分析了Python基于operator模塊中的itemgetter()函數(shù)對字典進行排序的相關操作技巧,需要的朋友可以參考下2018-03-03

