Python機器學(xué)習(xí)從ResNet到DenseNet示例詳解
從ResNet到DenseNet
上圖中,左邊是ResNet,右邊是DenseNet,它們在跨層上的主要區(qū)別是:使用相加和使用連結(jié)。
最后,將這些展開式結(jié)合到多層感知機中,再次減少特征的數(shù)量。實現(xiàn)起來非常簡單:我們不需要添加術(shù)語,而是將它們連接起來。DenseNet這個名字由變量之間的“稠密連接”而得來,最后一層與之前的所有層緊密相連。稠密連接如下圖所示:
稠密網(wǎng)絡(luò)主要由2部分構(gòu)成:稠密塊(dense block)和過渡層(trainsition block)。
前者定義如何連接輸入和輸出,而后者則控制通道數(shù)量,使其不會太復(fù)雜。
稠密塊體
DenseNet使用了ResNet改良版的“批量歸一化、激活和卷積”結(jié)構(gòu)。我們首先實現(xiàn)下這個結(jié)構(gòu)。
import torch from torch import nn from d2l import torch as d2l def conv_block(input_channels, num_channels): return nn.Sequential( nn.BatchNorm2d(input_channels), nn.ReLU(), nn.Conv2d(input_channels, num_channels, kernel_size=3, padding=1) )
一個稠密塊由多個卷積塊組成,每個卷積塊使用相同矢量的輸出通道。然而,在前向傳播中,我們將每個卷積塊的輸入和輸出在通道維上連結(jié)。
class DenseBlock(nn.Module): def __init__(self, num_convs, input_channels, num_channels): super(Denseblock, self).__init__() layer = [] for i in range(num_convs): layer.append(conv_block(num_channels * i + input_channels, num_channels)) self.net = nn.Sequential(*layer) def forward(self, X): for blk in self.net: Y = blk(X) # 連結(jié)通道維度上的每個塊的輸入和輸出 X = torch.cat((X, Y), dim=1) return X
在下面的例子中,我們定義一個有2個輸出通道數(shù)為10的DenseBlock。使用通道數(shù)為3的輸入時,我們會得到通道數(shù)為 3 + 2 × 10 = 23 3+2\times10=23 3+2×10=23的輸出。卷積塊的通道數(shù)控制了輸出通道數(shù)相對于輸入通道數(shù)的增長,因此也被稱為增長率(growth rate)。
blk = DenseBlock(2, 3, 10) X = torch.randn(4, 3, 8, 8) Y = blk(X) Y.shape
torch.Size([4, 23, 8, 8])
過渡層
由于每個稠密快都會帶來通道數(shù)的增加,使用過多則會過于復(fù)雜化模型。而過渡層可以用來控制模型復(fù)雜度。它通過 1×1卷積層來減小通道數(shù),并使用步幅為2的平均匯聚層減半高和寬,從而進一步降低模型復(fù)雜度。
def transition_block(input_channels, num_channels): return nn.Sequential( nn.BatchNorm2d(input_channels), nn.ReLU(), nn.Conv2d(input_channels, num_channels, kernel_size=1) nn.AvgPool2d(kernel_size=2, stride=2) )
對上一個例子中稠密塊的輸出使用通道數(shù)為10的過渡層。此時輸出的通道數(shù)減為10,高和寬均減半。
blk = transition_block(23, 10) blk(Y).shape
torch.Size([4, 10, 4, 4])
DenseNet模型
我們來構(gòu)造DenseNet模型。DenseNet首先使用同ResNet一樣的單卷積層和最大聚集層。
b1 = nn.Sequential( nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3), nn.BatchNorm2d(64), nn.ReLU(), nn.MaxPool2d(kernel_size=3, stride=2, padding=1) )
接下來,類似于ResNet使用的4個殘差塊,DenseNet使用的是4個稠密塊。與ResNet類似,我們可以設(shè)置每個稠密塊使用多少個卷積層。這里我們設(shè)成4,從而與之前的ResNet-18保持一致。稠密塊里的卷積層通道數(shù)(即增長率)設(shè)置為32,所以每個稠密塊將增加128個通道。
在每個模塊之間,ResNet通過步幅為2的殘差塊減小高和寬,而DenseNet則使用過渡層來減半高和寬,并減半通道數(shù)。
# 'num_channels'為當(dāng)前通道數(shù) num_channels, growth_rate = 64, 32 num_convs_in_dense_blocks = [4, 4, 4, 4] blks = [] for i, num_convs in enumerate(num_convs_in_dense_blocks): blks.append(DenseBlock(num_convs, num_channels, growth_rate)) # 上一個稠密塊的輸出通道數(shù) num_channels += num_convs * growth_rate # 在稠密塊之間添加一個轉(zhuǎn)換層,使通道數(shù)量減半 if i != len(num_convs_in_dense_blocks) - 1: blks.append(transition_block(num_channels, num_channels // 2)) num_channels = num_channels // 2
與ResNet類似,最后接上全局匯聚層和全連接層來輸出結(jié)果。
net = nn.Sequential( b1, *blks, nn.BatchNorm2d(num_channels), nn.ReLU(), nn.AdaptiveMaxPool2d((1, 1)), nn.Flatten(), nn.Linear(num_channels, 10) )
訓(xùn)練模型
由于這里使用了比較深的網(wǎng)絡(luò),本節(jié)里我們將輸入高和寬從224降到96來簡化計算。
lr, num_epochs, batch_size = 0.1, 10, 256 train_iter, test_iter = d2l.load_data_fashion_mnist(batch_size, resize=96) d2l.train_ch6(net, train_iter, test_iter, num_epochs, lr, d2l.try_gpu())
loss 0.154, train acc 0.943, test acc 0.880 5506.9 examples/sec on cuda:0
以上就是Python機器學(xué)習(xí)從ResNet到DenseNet示例詳解的詳細(xì)內(nèi)容,更多關(guān)于Python機器學(xué)習(xí)ResNet到DenseNet的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
利用Python實現(xiàn)自動生成圖文并茂的數(shù)據(jù)分析
這篇文章主要介紹了利用Python實現(xiàn)自動生成圖文并茂的數(shù)據(jù)分析,文章圍繞主題展開詳細(xì)的內(nèi)容介紹,具有一定的參考價值,需要的朋友可以參考一下2022-08-08pandas.DataFrame的for循環(huán)迭代的實現(xiàn)
本文主要介紹了pandas.DataFrame的for循環(huán)迭代的實現(xiàn),文中通過示例代碼介紹的非常詳細(xì),對大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2023-02-02Python高級應(yīng)用探索之元編程和并發(fā)編程詳解
Python作為一種簡單易用且功能強大的編程語言,廣泛應(yīng)用于各個領(lǐng)域,本文主要來和大家一起探索一下Python中的優(yōu)化技巧、元編程和并發(fā)編程,希望對大家有所幫助2023-11-11Pygame游戲開發(fā)之太空射擊實戰(zhàn)子彈與碰撞處理篇
相信大多數(shù)8090后都玩過太空射擊游戲,在過去游戲不多的年代太空射擊自然屬于經(jīng)典好玩的一款了,今天我們來自己動手實現(xiàn)它,在編寫學(xué)習(xí)中回顧過往展望未來,下面開始講解子彈與碰撞處理,在本課中,我們將添加玩家與敵人之間的碰撞,以及添加供玩家射擊的子彈2022-08-08