PyTorch詳解經(jīng)典網(wǎng)絡(luò)ResNet實(shí)現(xiàn)流程
簡述
GoogleNet 和 VGG 等網(wǎng)絡(luò)證明了,更深度的網(wǎng)絡(luò)可以抽象出表達(dá)能力更強(qiáng)的特征,進(jìn)而獲得更強(qiáng)的分類能力。在深度網(wǎng)絡(luò)中,隨之網(wǎng)絡(luò)深度的增加,每層輸出的特征圖分辨率主要是高和寬越來越小,而深度逐漸增加。
深度的增加理論上能夠提升網(wǎng)絡(luò)的表達(dá)能力,但是對于優(yōu)化來說就會產(chǎn)生梯度消失的問題。在深度網(wǎng)絡(luò)中,反向傳播時(shí),梯度從輸出端向數(shù)據(jù)端逐層傳播,傳播過程中,梯度的累乘使得近數(shù)據(jù)段接近0值,使得網(wǎng)絡(luò)的訓(xùn)練失效。
為了解決梯度消失問題,可以在網(wǎng)絡(luò)中加入BatchNorm,激活函數(shù)換成ReLU,一定程度緩解了梯度消失問題。
深度增加的另一個(gè)問題就是網(wǎng)絡(luò)的退化(Degradation of deep network)問題。即,在現(xiàn)有網(wǎng)絡(luò)的基礎(chǔ)上,增加網(wǎng)絡(luò)的深度,理論上,只有訓(xùn)練到最佳情況,新網(wǎng)絡(luò)的性能應(yīng)該不會低于淺層的網(wǎng)絡(luò)。因?yàn)椋灰獙⑿略黾拥膶訉W(xué)習(xí)成恒等映射(identity mapping)就可以。換句話說,淺網(wǎng)絡(luò)的解空間是深的網(wǎng)絡(luò)的解空間的子集。但是由于Degradation問題,更深的網(wǎng)絡(luò)并不一定好于淺層網(wǎng)絡(luò)。
Residual模塊的想法就是認(rèn)為的讓網(wǎng)絡(luò)實(shí)現(xiàn)這種恒等映射。如圖,殘差結(jié)構(gòu)在兩層卷積的基礎(chǔ)上,并行添加了一個(gè)分支,將輸入直接加到最后的ReLU激活函數(shù)之前,如果兩層卷積改變大量輸入的分辨率和通道數(shù),為了能夠相加,可以在添加的分支上使用1x1卷積來匹配尺寸。
殘差結(jié)構(gòu)
ResNet網(wǎng)絡(luò)有兩種殘差塊,一種是兩個(gè)3x3卷積,一種是1x1,3x3,1x1三個(gè)卷積網(wǎng)絡(luò)串聯(lián)成殘差模塊。
PyTorch 實(shí)現(xiàn):
class Residual_1(nn.Module): r""" 18-layer, 34-layer 殘差塊 1. 使用了類似VGG的3×3卷積層設(shè)計(jì); 2. 首先使用兩個(gè)相同輸出通道數(shù)的3×3卷積層,后接一個(gè)批量規(guī)范化和ReLU激活函數(shù); 3. 加入跨過卷積層的通路,加到最后的ReLU激活函數(shù)前; 4. 如果要匹配卷積后的輸出的尺寸和通道數(shù),可以在加入的跨通路上使用1×1卷積; """ def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1): r""" parameters: input_channels: 輸入的通道上數(shù) num_channels: 輸出的通道數(shù) use_1x1conv: 是否需要使用1x1卷積控制尺寸 stride: 第一個(gè)卷積的步長 """ super().__init__() # 3×3卷積,strides控制分辨率是否縮小 self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1, stride=strides) # 3×3卷積,不改變分辨率 self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1) # 使用 1x1 卷積變換輸入的分辨率和通道 if use_1x1conv: self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides) else: self.conv3 = None # 批量規(guī)范化層 self.bn1 = nn.BatchNorm2d(num_channels) self.bn2 = nn.BatchNorm2d(num_channels) def forward(self, X): Y = F.relu(self.bn1(self.conv1(X))) Y = self.bn2(self.conv2(Y)) if self.conv3: X = self.conv3(X) # print(X.shape) Y += X return F.relu(Y)
class Residual_2(nn.Module): r""" 50-layer, 101-layer, 152-layer 殘差塊 1. 首先使用1x1卷積,ReLU激活函數(shù); 2. 然后用3×3卷積層,在接一個(gè)批量規(guī)范化,ReLU激活函數(shù); 3. 再接1x1卷積層; 4. 加入跨過卷積層的通路,加到最后的ReLU激活函數(shù)前; 5. 如果要匹配卷積后的輸出的尺寸和通道數(shù),可以在加入的跨通路上使用1×1卷積; """ def __init__(self, input_channels, num_channels, use_1x1conv=False, strides=1): r""" parameters: input_channels: 輸入的通道上數(shù) num_channels: 輸出的通道數(shù) use_1x1conv: 是否需要使用1x1卷積控制尺寸 stride: 第一個(gè)卷積的步長 """ super().__init__() # 1×1卷積,strides控制分辨率是否縮小 self.conv1 = nn.Conv2d(input_channels, num_channels, kernel_size=1, padding=1, stride=strides) # 3×3卷積,不改變分辨率 self.conv2 = nn.Conv2d(num_channels, num_channels, kernel_size=3, padding=1) # 1×1卷積,strides控制分辨率是否縮小 self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, padding=1) # 使用 1x1 卷積變換輸入的分辨率和通道 if use_1x1conv: self.conv3 = nn.Conv2d(input_channels, num_channels, kernel_size=1, stride=strides) else: self.conv3 = None # 批量規(guī)范化層 self.bn1 = nn.BatchNorm2d(num_channels) self.bn2 = nn.BatchNorm2d(num_channels) def forward(self, X): Y = F.relu(self.bn1(self.conv1(X))) Y = F.relu(self.bn2(self.conv2(Y))) Y = self.conv3(Y) if self.conv3: X = self.conv3(X) # print(X.shape) Y += X return F.relu(Y)
ResNet有不同的網(wǎng)絡(luò)層數(shù),比較常用的是50-layer,101-layer,152-layer。他們都是由上述的殘差模塊堆疊在一起實(shí)現(xiàn)的。
以18-layer為例,層數(shù)是指:首先,conv_1 的一層7x7卷積,然后conv_2~conv_5四個(gè)模塊,每個(gè)模塊兩個(gè)殘差塊,每個(gè)殘差塊有兩層的3x3卷積組成,共4×2×2=16層,最后是一層分類層(fc),加總一起共1+16+1=18層。
18-layer 實(shí)現(xiàn)
首先定義由殘差結(jié)構(gòu)組成的模塊:
# ResNet模塊 def resnet_block(input_channels, num_channels, num_residuals, first_block=False): r"""殘差塊組成的模塊""" blk = [] for i in range(num_residuals): if i == 0 and not first_block: blk.append(Residual_1(input_channels, num_channels, use_1x1conv=True, strides=2)) else: blk.append(Residual_1(num_channels, num_channels)) return blk
定義18-layer的最開始的層:
# ResNet的前兩層: # 1. 輸出通道數(shù)64, 步幅為2的7x7卷積層 # 2. 步幅為2的3x3最大匯聚層 conv_1 = nn.Sequential(nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(kernel_size=3, stride=2, padding=1))
定義殘差組模塊:
# ResNet模塊 conv_2 = nn.Sequential(*resnet_block(64, 64, 2, first_block=True)) conv_3 = nn.Sequential(*resnet_block(64, 128, 2)) conv_4 = nn.Sequential(*resnet_block(128, 256, 2)) conv_5 = nn.Sequential(*resnet_block(256, 512, 2))
ResNet 18-layer模型:
net = nn.Sequential(conv_1, conv_2, conv_3, conv_4, conv_5, nn.AdaptiveAvgPool2d((1, 1)), nn.Flatten(), nn.Linear(512, 10)) # 觀察模型各層的輸出尺寸 X = torch.rand(size=(1, 1, 224, 224)) for layer in net: X = layer(X) print(layer.__class__.__name__,'output shape:\t', X.shape)
輸出:
Sequential output shape: torch.Size([1, 64, 56, 56])
Sequential output shape: torch.Size([1, 64, 56, 56])
Sequential output shape: torch.Size([1, 128, 28, 28])
Sequential output shape: torch.Size([1, 256, 14, 14])
Sequential output shape: torch.Size([1, 512, 7, 7])
AdaptiveAvgPool2d output shape: torch.Size([1, 512, 1, 1])
Flatten output shape: torch.Size([1, 512])
Linear output shape: torch.Size([1, 10])
在數(shù)據(jù)集訓(xùn)練
def load_datasets_Cifar10(batch_size, resize=None): trans = [transforms.ToTensor()] if resize: transform = trans.insert(0, transforms.Resize(resize)) trans = transforms.Compose(trans) train_data = torchvision.datasets.CIFAR10(root="../data", train=True, transform=trans, download=True) test_data = torchvision.datasets.CIFAR10(root="../data", train=False, transform=trans, download=True) print("Cifar10 下載完成...") return (torch.utils.data.DataLoader(train_data, batch_size, shuffle=True), torch.utils.data.DataLoader(test_data, batch_size, shuffle=False)) def load_datasets_FashionMNIST(batch_size, resize=None): trans = [transforms.ToTensor()] if resize: transform = trans.insert(0, transforms.Resize(resize)) trans = transforms.Compose(trans) train_data = torchvision.datasets.FashionMNIST(root="../data", train=True, transform=trans, download=True) test_data = torchvision.datasets.FashionMNIST(root="../data", train=False, transform=trans, download=True) print("FashionMNIST 下載完成...") return (torch.utils.data.DataLoader(train_data, batch_size, shuffle=True), torch.utils.data.DataLoader(test_data, batch_size, shuffle=False)) def load_datasets(dataset, batch_size, resize): if dataset == "Cifar10": return load_datasets_Cifar10(batch_size, resize=resize) else: return load_datasets_FashionMNIST(batch_size, resize=resize) train_iter, test_iter = load_datasets("", 128, 224) # Cifar10
到此這篇關(guān)于PyTorch詳解經(jīng)典網(wǎng)絡(luò)ResNet實(shí)現(xiàn)流程的文章就介紹到這了,更多相關(guān)PyTorch ResNet內(nèi)容請搜索腳本之家以前的文章或繼續(xù)瀏覽下面的相關(guān)文章希望大家以后多多支持腳本之家!
相關(guān)文章
Python調(diào)用Prometheus監(jiān)控?cái)?shù)據(jù)并計(jì)算
Prometheus是一套開源監(jiān)控系統(tǒng)和告警為一體,由go語言(golang)開發(fā),是監(jiān)控+報(bào)警+時(shí)間序列數(shù)據(jù)庫的組合。本文將介紹Python如何調(diào)用Prometheus實(shí)現(xiàn)數(shù)據(jù)的監(jiān)控與計(jì)算,需要的可以參考一下2021-12-12Python如何通過Flask-Mail發(fā)送電子郵件
這篇文章主要介紹了Python如何通過Flask-Mail發(fā)送電子郵件,文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友可以參考下2020-01-01pycharm 更改創(chuàng)建文件默認(rèn)路徑的操作
今天小編就為大家分享一篇pycharm 更改創(chuàng)建文件默認(rèn)路徑的操作,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-02-02使用python裝飾器計(jì)算函數(shù)運(yùn)行時(shí)間的實(shí)例
下面小編就為大家分享一篇使用python裝飾器計(jì)算函數(shù)運(yùn)行時(shí)間的實(shí)例,具有很好的參考價(jià)值,希望對大家有所幫助。一起跟隨小編過來看看吧2018-04-04使用httplib模塊來制作Python下HTTP客戶端的方法
這篇文章主要介紹了使用httplib模塊來制作Python下HTTP客戶端的方法,文中列舉了一些httplib下常用的HTTP方法,需要的朋友可以參考下2015-06-06Python數(shù)據(jù)結(jié)構(gòu)與算法之跳表詳解
跳表是帶有附加指針的鏈表,使用這些附加指針可以跳過一些中間結(jié)點(diǎn),用以快速完成查找、插入和刪除等操作。本節(jié)將詳細(xì)介紹跳表的相關(guān)概念及其具體實(shí)現(xiàn),需要的可以參考一下2022-02-02Python cookbook(數(shù)據(jù)結(jié)構(gòu)與算法)根據(jù)字段將記錄分組操作示例
這篇文章主要介紹了Python cookbook(數(shù)據(jù)結(jié)構(gòu)與算法)根據(jù)字段將記錄分組操作,結(jié)合實(shí)例形式分析了itertools.groupby()函數(shù)針對字典進(jìn)行分組操作的相關(guān)實(shí)現(xiàn)技巧,需要的朋友可以參考下2018-03-03在Apache服務(wù)器上同時(shí)運(yùn)行多個(gè)Django程序的方法
這篇文章主要介紹了在Apache服務(wù)器上同時(shí)運(yùn)行多個(gè)Django程序的方法,Django是Python各色高人氣web框架中最為著名的一個(gè),需要的朋友可以參考下2015-07-07Python+Matplotlib繪制發(fā)散條形圖的示例代碼
發(fā)散條形圖(Diverging Bar)是一種用于顯示數(shù)據(jù)分布的圖表,可以幫助我們比較不同類別或分組的數(shù)據(jù)的差異和相對性,本文介紹了Matplotlib繪制發(fā)散條形圖的函數(shù)源碼,需要的可以參考一下2023-06-06