Pytorch搭建SRGAN平臺提升圖片超分辨率
網(wǎng)絡(luò)構(gòu)建
一、什么是SRGAN
SRGAN出自論文Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network。
如果將SRGAN看作一個黑匣子,其主要的功能就是輸入一張低分辨率圖片,生成高分辨率圖片。
該文章提到,普通的超分辨率模型訓練網(wǎng)絡(luò)時只用到了均方差作為損失函數(shù),雖然能夠獲得很高的峰值信噪比,但是恢復出來的圖像通常會丟失高頻細節(jié)。
SRGAN利用感知損失(perceptual loss)和對抗損失(adversarial loss)來提升恢復出的圖片的真實感。
二、生成網(wǎng)絡(luò)的構(gòu)建
生成網(wǎng)絡(luò)的構(gòu)成如上圖所示,生成網(wǎng)絡(luò)的作用是輸入一張低分辨率圖片,生成高分辨率圖片。:
SRGAN的生成網(wǎng)絡(luò)由三個部分組成。
1、低分辨率圖像進入后會經(jīng)過一個卷積+RELU函數(shù)。
2、然后經(jīng)過B個殘差網(wǎng)絡(luò)結(jié)構(gòu),每個殘差結(jié)構(gòu)都包含兩個卷積+標準化+RELU,還有一個殘差邊。
3、然后進入上采樣部分,在經(jīng)過兩次上采樣后,原圖的高寬變?yōu)樵瓉淼?倍,實現(xiàn)分辨率的提升。
前兩個部分用于特征提取,第三部分用于提高分辨率。
import math import torch from torch import nn class ResidualBlock(nn.Module): def __init__(self, channels): super(ResidualBlock, self).__init__() self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm2d(channels) self.prelu = nn.PReLU(channels) self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm2d(channels) def forward(self, x): short_cut = x x = self.conv1(x) x = self.bn1(x) x = self.prelu(x) x = self.conv2(x) x = self.bn2(x) return x + short_cut class UpsampleBLock(nn.Module): def __init__(self, in_channels, up_scale): super(UpsampleBLock, self).__init__() self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1) self.pixel_shuffle = nn.PixelShuffle(up_scale) self.prelu = nn.PReLU(in_channels) def forward(self, x): x = self.conv(x) x = self.pixel_shuffle(x) x = self.prelu(x) return x class Generator(nn.Module): def __init__(self, scale_factor, num_residual=16): upsample_block_num = int(math.log(scale_factor, 2)) super(Generator, self).__init__() self.block_in = nn.Sequential( nn.Conv2d(3, 64, kernel_size=9, padding=4), nn.PReLU(64) ) self.blocks = [] for _ in range(num_residual): self.blocks.append(ResidualBlock(64)) self.blocks = nn.Sequential(*self.blocks) self.block_out = nn.Sequential( nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64) ) self.upsample = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)] self.upsample.append(nn.Conv2d(64, 3, kernel_size=9, padding=4)) self.upsample = nn.Sequential(*self.upsample) def forward(self, x): x = self.block_in(x) short_cut = x x = self.blocks(x) x = self.block_out(x) upsample = self.upsample(x + short_cut) return torch.tanh(upsample)
三、判別網(wǎng)絡(luò)的構(gòu)建
判別網(wǎng)絡(luò)的構(gòu)成如上圖所示:
SRGAN的判別網(wǎng)絡(luò)由不斷重復的 卷積+LeakyRELU和標準化 組成。
對于判斷網(wǎng)絡(luò)來講,它的目的是判斷輸入圖片的真假,它的輸入是圖片,輸出是判斷結(jié)果。
判斷結(jié)果處于0-1之間,利用接近1代表判斷為真圖片,接近0代表判斷為假圖片。
判斷網(wǎng)絡(luò)的構(gòu)建和普通卷積網(wǎng)絡(luò)差距不大,都是不斷的卷積對圖片進行下采用,在多次卷積后,最終接一次全連接判斷結(jié)果。
實現(xiàn)代碼如下:
class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.net = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2), nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2), nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2), nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2), nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2), nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2), nn.Conv2d(256, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2), nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2), nn.AdaptiveAvgPool2d(1), nn.Conv2d(512, 1024, kernel_size=1), nn.LeakyReLU(0.2), nn.Conv2d(1024, 1, kernel_size=1) ) def forward(self, x): batch_size = x.size(0) return torch.sigmoid(self.net(x).view(batch_size))
訓練思路
SRGAN的訓練可以分為生成器訓練和判別器訓練:
每一個step中一般先訓練判別器,然后訓練生成器。
一、判別器的訓練
在訓練判別器的時候我們希望判別器可以判斷輸入圖片的真?zhèn)危虼宋覀兊妮斎刖褪钦鎴D片、假圖片和它們對應(yīng)的標簽。
因此判別器的訓練步驟如下:
1、隨機選取batch_size個真實高分辨率圖片。
2、利用resize后的低分辨率圖片,傳入到Generator中生成batch_size個虛假高分辨率圖片。
3、真實圖片的label為1,虛假圖片的label為0,將真實圖片和虛假圖片當作訓練集傳入到Discriminator中進行訓練。
二、生成器的訓練
在訓練生成器的時候我們希望生成器可以生成極為真實的假圖片。因此我們在訓練生成器需要知道判別器認為什么圖片是真圖片。
因此生成器的訓練步驟如下:
1、將低分辨率圖像傳入生成模型,得到虛假高分辨率圖像,將虛假高分辨率圖像獲得判別結(jié)果與1進行對比得到loss。(與1對比的意思是,讓生成器根據(jù)判別器判別的結(jié)果進行訓練)。
2、將真實高分辨率圖像和虛假高分辨率圖像傳入VGG網(wǎng)絡(luò),獲得兩個圖像的特征,通過這兩個圖像的特征進行比較獲得loss
利用SRGAN生成圖片
SRGAN的庫整體結(jié)構(gòu)如下:
一、數(shù)據(jù)集的準備
在訓練前需要準備好數(shù)據(jù)集,數(shù)據(jù)集保存在datasets文件夾里面。
二、數(shù)據(jù)集的處理
打開txt_annotation.py,默認指向根目錄下的datasets。運行txt_annotation.py。
此時生成根目錄下面的train_lines.txt。
三、模型訓練
在完成數(shù)據(jù)集處理后,運行train.py即可開始訓練。
訓練過程中,可在results文件夾內(nèi)查看訓練效果:
以上就是Pytorch搭建SRGAN平臺提升圖片超分辨率的詳細內(nèi)容,更多關(guān)于Pytorch搭建SRGAN圖片超分辨率的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
利用PyQt5模擬實現(xiàn)網(wǎng)頁鼠標移動特效
不知道大家有沒有發(fā)現(xiàn),博客園有些博客左側(cè)會有鼠標移動特效。通過移動鼠標,會形成類似蜘蛛網(wǎng)的特效,本文將用PyQt5實現(xiàn)這一特效,需要的可以參考一下2022-03-03Python使用matplotlib填充圖形指定區(qū)域代碼示例
這篇文章主要介紹了Python使用matplotlib填充圖形指定區(qū)域代碼示例,具有一定借鑒價值,需要的朋友可以參考下2018-01-01Python處理字符串的常用函數(shù)實例總結(jié)
在數(shù)據(jù)分析中,特別是文本分析中,字符處理需要耗費極大的精力,因而了解字符處理對于數(shù)據(jù)分析而言,也是一項很重要的能力,這篇文章主要給大家介紹了關(guān)于Python處理字符串的常用函數(shù),需要的朋友可以參考下2021-11-11