Swin?Transformer圖像處理深度學(xué)習(xí)模型
Swin Transformer
Swin Transformer是一種用于圖像處理的深度學(xué)習(xí)模型,它可以用于各種計(jì)算機(jī)視覺任務(wù),如圖像分類、目標(biāo)檢測(cè)和語義分割等。它的主要特點(diǎn)是采用了分層的窗口機(jī)制,可以處理比較大的圖像,同時(shí)也減少了模型參數(shù)的數(shù)量,提高了計(jì)算效率。Swin Transformer在圖像處理領(lǐng)域取得了很好的表現(xiàn),成為了最先進(jìn)的模型之一。
Swin Transformer通過從小尺寸的圖像塊(用灰色輪廓線框出)開始,并逐漸合并相鄰塊,構(gòu)建了一個(gè)分層的表示形式,在更深層的Transformer中實(shí)現(xiàn)。
整體架構(gòu)
Swin Transformer 模塊
Swin Transformer模塊是基于Transformer塊中標(biāo)準(zhǔn)的多頭自注意力模塊(MSA)進(jìn)行替換構(gòu)建的,用的是一種基于滑動(dòng)窗口的模塊(在后面細(xì)說),而其他層保持不變。如上圖所示,Swin Transformer模塊由基于滑動(dòng)窗口的多頭注意力模塊組成,后跟一個(gè)2層MLP,在中間使用GELU非線性激活函數(shù)。在每個(gè)MSA模塊和每個(gè)MLP之前都應(yīng)用了LayerNorm(LN)層,并在每個(gè)模塊之后應(yīng)用了殘差連接。
滑動(dòng)窗口機(jī)制
Cyclic Shift
Cyclic Shift是Swin Transformer中一種有效的處理局部特征的方法。在Swin Transformer中,為了處理高分辨率的輸入特征圖,需要將輸入特征圖分割成小塊(一個(gè)patch可能有多個(gè)像素)進(jìn)行處理。然而,這樣會(huì)導(dǎo)致局部特征在不同塊之間被分割開來,影響了局部特征的提取。Cyclic Shift將輸入特征圖沿著寬度和高度方向分別平移一個(gè)固定的距離,使得每個(gè)塊的局部特征可以與相鄰塊的局部特征進(jìn)行交互,從而增強(qiáng)了局部特征的表達(dá)能力。另外,Cyclic Shift還可以通過多次平移來增加塊之間的交互,進(jìn)一步提升了模型的性能。需要注意的是,Cyclic Shift只在訓(xùn)練過程中使用,因?yàn)樗鼤?huì)改變輸入特征圖的分布。在測(cè)試過程中,輸入特征圖的大小和分布與訓(xùn)練時(shí)相同,因此不需要使用Cyclic Shift操作。
Efficient batch computation for shifted configuration
Cyclic Shift會(huì)將輸入特征圖沿著寬度和高度方向進(jìn)行平移操作,以便讓不同塊之間的局部特征進(jìn)行交互。這樣的操作會(huì)導(dǎo)致每個(gè)塊的特征值的位置發(fā)生改變,從而需要在每個(gè)塊上重新計(jì)算注意力機(jī)制。
為了加速計(jì)算過程,Swin Transformer中引入了"Efficient batch computation for shifted configuration"這一技巧。該技巧首先將每個(gè)塊的特征值復(fù)制多次,分別放置在Cyclic Shift平移后的不同位置上,使得每個(gè)塊都可以在平移后的不同的位置上參與到注意力機(jī)制的計(jì)算中。然后,將這些位置不同的塊的特征值進(jìn)行合并拼接,計(jì)算注意力。
需要注意的是,這種技巧只在訓(xùn)練時(shí)使用,因?yàn)樗鼤?huì)增加計(jì)算量,而在測(cè)試時(shí),可以將每個(gè)塊的特征值計(jì)算一次,然后在不同位置上進(jìn)行拼接,以得到最終的輸出。
Relative position bias
在傳統(tǒng)的Transformer模型中,為了考慮單詞之間的位置關(guān)系,通常采用絕對(duì)位置編碼(Absolute Positional Encoding)的方式。這種方法是在每個(gè)單詞的embedding中添加位置編碼向量,以表示該單詞在序列中的絕對(duì)位置。但是,當(dāng)序列長度很長時(shí),絕對(duì)位置編碼會(huì)面臨兩個(gè)問題:
- 編碼向量的大小會(huì)隨著序列長度的增加而增加,導(dǎo)致模型參數(shù)量增大,訓(xùn)練難度加大;
- 當(dāng)序列長度超過一定限制時(shí),模型的性能會(huì)下降。
為了解決這些問題,Swin Transformer采用了Relative Positional Encoding,它通過編碼單詞之間的相對(duì)位置信息來代替絕對(duì)位置編碼。相對(duì)位置編碼是由每個(gè)單詞對(duì)其它單詞的相對(duì)位置關(guān)系計(jì)算得出的。在計(jì)算相對(duì)位置時(shí),Swin Transformer引入了Relative Position Bias,即相對(duì)位置偏置,它是一個(gè)可學(xué)習(xí)的參數(shù)矩陣,用于調(diào)整不同位置之間的相對(duì)位置關(guān)系。這樣做可以有效地減少相對(duì)位置編碼的參數(shù)量,同時(shí)提高模型的性能和效率。相對(duì)位置編碼可以通過以下公式計(jì)算:
最終,相對(duì)位置編碼和相對(duì)位置偏置的結(jié)果會(huì)被加到點(diǎn)積注意力機(jī)制中,用于計(jì)算不同位置之間的相關(guān)性,從而實(shí)現(xiàn)序列的建模。
代碼實(shí)現(xiàn):
下面是一個(gè)用PyTorch實(shí)現(xiàn)Swin B模型的示例代碼,其中包含了相對(duì)位置編碼和相對(duì)位置偏置的實(shí)現(xiàn):
import torch import torch.nn as nn from einops.layers.torch import Rearrange class SwinBlock(nn.Module): def __init__(self, in_channels, out_channels, window_size=7, shift_size=0): super(SwinBlock, self).__init__() self.window_size = window_size self.shift_size = shift_size self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) self.norm1 = nn.BatchNorm2d(out_channels) self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=window_size, stride=1, padding=window_size//2, groups=out_channels) self.norm2 = nn.BatchNorm2d(out_channels) self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=1, stride=1, padding=0) self.norm3 = nn.BatchNorm2d(out_channels) if in_channels == out_channels: self.downsample = None else: self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) self.norm_downsample = nn.BatchNorm2d(out_channels) def forward(self, x): residual = x out = self.conv1(x) out = self.norm1(out) out = nn.functional.relu(out) out = Rearrange(out, 'b c h w -> b (h w) c') out = self.shift_window(out) out = Rearrange(out, 'b (h w) c -> b c h w', h=int(x.shape[2]), w=int(x.shape[3])) out = self.conv2(out) out = self.norm2(out) out = nn.functional.relu(out) out = self.conv3(out) out = self.norm3(out) if self.downsample is not None: residual = self.downsample(x) residual = self.norm_downsample(residual) out += residual out = nn.functional.relu(out) return out def shift_window(self, x): # x: (B, L, C) B, L, C = x.shape if self.shift_size == 0: shifted_x = torch.zeros_like(x) shifted_x[:, self.window_size//2:L-self.window_size//2, :] = x[:, self.window_size//2:L-self.window_size//2, :] return shifted_x else: # pad feature maps to shift window left_pad = self.window_size // 2 + self.shift_size right_pad = left_pad - self.shift_size x = nn.functional.pad(x, (0, 0, left_pad, right_pad), mode='constant', value=0) # Reshape X to (B, H, W, C) H = W = int(x.shape[1] ** 0.5) x = Rearrange(x, 'b (h w) c -> b c h w', h=H, w=W) # Shift window x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(2, 3)) # Reshape back to (B, L, C) x = Rearrange(x, 'b c h w -> b (h w) c') return x[:, self.window] class SwinTransformer(nn.Module): def __init__(self, in_channels=3, num_classes=1000, num_layers=12, embed_dim=96, window_sizes=(7, 3, 3, 3), shift_sizes=(0, 1, 2, 3)): super(SwinTransformer, self).__init__() self.in_channels = in_channels self.num_classes = num_classes self.num_layers = num_layers self.embed_dim = embed_dim self.window_sizes = window_sizes self.shift_sizes = shift_sizes self.conv1 = nn.Conv2d(in_channels, embed_dim, kernel_size=4, stride=4, padding=0) self.norm1 = nn.BatchNorm2d(embed_dim) self.blocks = nn.ModuleList() for i in range(num_layers): self.blocks.append(SwinBlock(embed_dim * 2**i, embed_dim * 2**(i+1), window_size=window_sizes[i%4], shift_size=shift_sizes[i%4])) self.avgpool = nn.AdaptiveAvgPool2d((1, 1)) self.fc = nn.Linear(embed_dim * 2**num_layers, num_classes) # add relative position bias self.relative_position_bias_table = nn.Parameter( torch.zeros((2 * (2 * window_sizes[-1] - 1), embed_dim // 8, embed_dim // 8)), requires_grad=True) nn.init.kaiming_uniform_(self.relative_position_bias_table, a=1) # add relative position encoding self.pos_embed = nn.Parameter( torch.zeros(1, embed_dim * 2**num_layers, 7, 7), requires_grad=True) nn.init.kaiming_uniform_(self.pos_embed, a=1) def forward(self, x): out = self.conv1(x) out = self.norm1(out) out = nn.functional.relu(out) for block in self.blocks: out = block(out) out = self.avgpool(out) out = Rearrange(out, 'b c h w -> b (c h w)') out = self.fc(out) return out def get_relative_position_bias(self, H, W): # H, W: height and width of feature maps in the last block # output: (2HW-1, 8, 8) relative_position_bias_h = self.relative_position_bias_table[:, :(2 * H - 1), :(2 * W - 1)].transpose(0, 1) relative_position_bias_w = self.relative_position_bias_table[:, (2 * H - 1):, (2 * W - 1):].transpose(0, 1) relative_position_bias = torch.cat([relative_position_bias_h, relative_position_bias_w], dim=0) return relative_position_bias def get_relative_position_encoding(self, H, W): # H, W: height and width of feature maps in the last block # output: (1, HW, C) pos_x, pos_y = torch.meshgrid(torch.arange(H), torch.arange(W)) pos_x, pos_y = pos_x.float(), pos_y.float() pos_x = pos_x / (H-1) * 2 - 1 pos_y = pos_y / (W-1) * 2 - 1 pos_encoding = torch.stack((pos_y, pos_x), dim=-1) pos_encoding = pos_encoding.reshape(1, -1, 2) pos_encoding = pos_encoding.repeat(1, 1, embed_dim // 2) pos_encoding = pos_encoding.transpose(1, 2) return pos_encoding
以上就是Swin Transformer圖像處理深度學(xué)習(xí)模型的詳細(xì)內(nèi)容,更多關(guān)于Swin Transformer深度學(xué)習(xí)的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Python?Ruby?等語言棄用自增運(yùn)算符原因剖析
這篇文章主要為大家介紹了Python?Ruby?等語言棄用自增運(yùn)算符原因剖析,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2022-08-08keras model.fit 解決validation_spilt=num 的問題
這篇文章主要介紹了keras model.fit 解決validation_spilt=num 的問題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2020-06-06python求平均數(shù)、方差、中位數(shù)的例子
今天小編就為大家分享一篇python求平均數(shù)、方差、中位數(shù)的例子,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過來看看吧2019-08-08python字符串str和字節(jié)數(shù)組相互轉(zhuǎn)化方法
下面小編就為大家?guī)硪黄猵ython字符串str和字節(jié)數(shù)組相互轉(zhuǎn)化方法。小編覺得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過來看看吧2017-03-03python快速排序的實(shí)現(xiàn)及運(yùn)行時(shí)間比較
這篇文章主要介紹了python快速排序的實(shí)現(xiàn)及運(yùn)行時(shí)間比較,本文通過兩種方法給大家介紹,大家可以根據(jù)自己需要選擇適合自己的方法,對(duì)python實(shí)現(xiàn)快速排序相關(guān)知識(shí)感興趣的朋友一起看看吧2019-11-11Python創(chuàng)建類的方法及成員訪問的相關(guān)知識(shí)總結(jié)
今天給大家?guī)淼氖顷P(guān)于Python基礎(chǔ)的相關(guān)知識(shí),文章圍繞著Python類的方法及成員訪問展開,文中有非常詳細(xì)的介紹及代碼示例,需要的朋友可以參考下2021-06-06python2和python3應(yīng)該學(xué)哪個(gè)(python3.6與python3.7的選擇)
許多剛?cè)腴T Python 的朋友都在糾結(jié)的的問題是:我應(yīng)該選擇學(xué)習(xí) python2 還是 python3,Python 3.7 已經(jīng)發(fā)布了,目前Python的用戶,主要使用的版本 應(yīng)該是 Python3.6 和 Python2.7 ,那么是不是該轉(zhuǎn)到 Python 3.7 呢2019-10-10