Pytorch搭建SRGAN平臺(tái)提升圖片超分辨率
網(wǎng)絡(luò)構(gòu)建
一、什么是SRGAN
SRGAN出自論文Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network。
如果將SRGAN看作一個(gè)黑匣子,其主要的功能就是輸入一張低分辨率圖片,生成高分辨率圖片。
該文章提到,普通的超分辨率模型訓(xùn)練網(wǎng)絡(luò)時(shí)只用到了均方差作為損失函數(shù),雖然能夠獲得很高的峰值信噪比,但是恢復(fù)出來(lái)的圖像通常會(huì)丟失高頻細(xì)節(jié)。
SRGAN利用感知損失(perceptual loss)和對(duì)抗損失(adversarial loss)來(lái)提升恢復(fù)出的圖片的真實(shí)感。
二、生成網(wǎng)絡(luò)的構(gòu)建
生成網(wǎng)絡(luò)的構(gòu)成如上圖所示,生成網(wǎng)絡(luò)的作用是輸入一張低分辨率圖片,生成高分辨率圖片。:
SRGAN的生成網(wǎng)絡(luò)由三個(gè)部分組成。
1、低分辨率圖像進(jìn)入后會(huì)經(jīng)過(guò)一個(gè)卷積+RELU函數(shù)。
2、然后經(jīng)過(guò)B個(gè)殘差網(wǎng)絡(luò)結(jié)構(gòu),每個(gè)殘差結(jié)構(gòu)都包含兩個(gè)卷積+標(biāo)準(zhǔn)化+RELU,還有一個(gè)殘差邊。
3、然后進(jìn)入上采樣部分,在經(jīng)過(guò)兩次上采樣后,原圖的高寬變?yōu)樵瓉?lái)的4倍,實(shí)現(xiàn)分辨率的提升。
前兩個(gè)部分用于特征提取,第三部分用于提高分辨率。
import math import torch from torch import nn class ResidualBlock(nn.Module): def __init__(self, channels): super(ResidualBlock, self).__init__() self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) self.bn1 = nn.BatchNorm2d(channels) self.prelu = nn.PReLU(channels) self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1) self.bn2 = nn.BatchNorm2d(channels) def forward(self, x): short_cut = x x = self.conv1(x) x = self.bn1(x) x = self.prelu(x) x = self.conv2(x) x = self.bn2(x) return x + short_cut class UpsampleBLock(nn.Module): def __init__(self, in_channels, up_scale): super(UpsampleBLock, self).__init__() self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1) self.pixel_shuffle = nn.PixelShuffle(up_scale) self.prelu = nn.PReLU(in_channels) def forward(self, x): x = self.conv(x) x = self.pixel_shuffle(x) x = self.prelu(x) return x class Generator(nn.Module): def __init__(self, scale_factor, num_residual=16): upsample_block_num = int(math.log(scale_factor, 2)) super(Generator, self).__init__() self.block_in = nn.Sequential( nn.Conv2d(3, 64, kernel_size=9, padding=4), nn.PReLU(64) ) self.blocks = [] for _ in range(num_residual): self.blocks.append(ResidualBlock(64)) self.blocks = nn.Sequential(*self.blocks) self.block_out = nn.Sequential( nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64) ) self.upsample = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)] self.upsample.append(nn.Conv2d(64, 3, kernel_size=9, padding=4)) self.upsample = nn.Sequential(*self.upsample) def forward(self, x): x = self.block_in(x) short_cut = x x = self.blocks(x) x = self.block_out(x) upsample = self.upsample(x + short_cut) return torch.tanh(upsample)
三、判別網(wǎng)絡(luò)的構(gòu)建
判別網(wǎng)絡(luò)的構(gòu)成如上圖所示:
SRGAN的判別網(wǎng)絡(luò)由不斷重復(fù)的 卷積+LeakyRELU和標(biāo)準(zhǔn)化 組成。
對(duì)于判斷網(wǎng)絡(luò)來(lái)講,它的目的是判斷輸入圖片的真假,它的輸入是圖片,輸出是判斷結(jié)果。
判斷結(jié)果處于0-1之間,利用接近1代表判斷為真圖片,接近0代表判斷為假圖片。
判斷網(wǎng)絡(luò)的構(gòu)建和普通卷積網(wǎng)絡(luò)差距不大,都是不斷的卷積對(duì)圖片進(jìn)行下采用,在多次卷積后,最終接一次全連接判斷結(jié)果。
實(shí)現(xiàn)代碼如下:
class Discriminator(nn.Module): def __init__(self): super(Discriminator, self).__init__() self.net = nn.Sequential( nn.Conv2d(3, 64, kernel_size=3, padding=1), nn.LeakyReLU(0.2), nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(64), nn.LeakyReLU(0.2), nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2), nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(128), nn.LeakyReLU(0.2), nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2), nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(256), nn.LeakyReLU(0.2), nn.Conv2d(256, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2), nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1), nn.BatchNorm2d(512), nn.LeakyReLU(0.2), nn.AdaptiveAvgPool2d(1), nn.Conv2d(512, 1024, kernel_size=1), nn.LeakyReLU(0.2), nn.Conv2d(1024, 1, kernel_size=1) ) def forward(self, x): batch_size = x.size(0) return torch.sigmoid(self.net(x).view(batch_size))
訓(xùn)練思路
SRGAN的訓(xùn)練可以分為生成器訓(xùn)練和判別器訓(xùn)練:
每一個(gè)step中一般先訓(xùn)練判別器,然后訓(xùn)練生成器。
一、判別器的訓(xùn)練
在訓(xùn)練判別器的時(shí)候我們希望判別器可以判斷輸入圖片的真?zhèn)?,因此我們的輸入就是真圖片、假圖片和它們對(duì)應(yīng)的標(biāo)簽。
因此判別器的訓(xùn)練步驟如下:
1、隨機(jī)選取batch_size個(gè)真實(shí)高分辨率圖片。
2、利用resize后的低分辨率圖片,傳入到Generator中生成batch_size個(gè)虛假高分辨率圖片。
3、真實(shí)圖片的label為1,虛假圖片的label為0,將真實(shí)圖片和虛假圖片當(dāng)作訓(xùn)練集傳入到Discriminator中進(jìn)行訓(xùn)練。
二、生成器的訓(xùn)練
在訓(xùn)練生成器的時(shí)候我們希望生成器可以生成極為真實(shí)的假圖片。因此我們?cè)谟?xùn)練生成器需要知道判別器認(rèn)為什么圖片是真圖片。
因此生成器的訓(xùn)練步驟如下:
1、將低分辨率圖像傳入生成模型,得到虛假高分辨率圖像,將虛假高分辨率圖像獲得判別結(jié)果與1進(jìn)行對(duì)比得到loss。(與1對(duì)比的意思是,讓生成器根據(jù)判別器判別的結(jié)果進(jìn)行訓(xùn)練)。
2、將真實(shí)高分辨率圖像和虛假高分辨率圖像傳入VGG網(wǎng)絡(luò),獲得兩個(gè)圖像的特征,通過(guò)這兩個(gè)圖像的特征進(jìn)行比較獲得loss
利用SRGAN生成圖片
SRGAN的庫(kù)整體結(jié)構(gòu)如下:
一、數(shù)據(jù)集的準(zhǔn)備
在訓(xùn)練前需要準(zhǔn)備好數(shù)據(jù)集,數(shù)據(jù)集保存在datasets文件夾里面。
二、數(shù)據(jù)集的處理
打開txt_annotation.py,默認(rèn)指向根目錄下的datasets。運(yùn)行txt_annotation.py。
此時(shí)生成根目錄下面的train_lines.txt。
三、模型訓(xùn)練
在完成數(shù)據(jù)集處理后,運(yùn)行train.py即可開始訓(xùn)練。
訓(xùn)練過(guò)程中,可在results文件夾內(nèi)查看訓(xùn)練效果:
以上就是Pytorch搭建SRGAN平臺(tái)提升圖片超分辨率的詳細(xì)內(nèi)容,更多關(guān)于Pytorch搭建SRGAN圖片超分辨率的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
利用PyQt5模擬實(shí)現(xiàn)網(wǎng)頁(yè)鼠標(biāo)移動(dòng)特效
不知道大家有沒(méi)有發(fā)現(xiàn),博客園有些博客左側(cè)會(huì)有鼠標(biāo)移動(dòng)特效。通過(guò)移動(dòng)鼠標(biāo),會(huì)形成類似蜘蛛網(wǎng)的特效,本文將用PyQt5實(shí)現(xiàn)這一特效,需要的可以參考一下2022-03-03如何輕松實(shí)現(xiàn)Python數(shù)組降維?
歡迎來(lái)到Python數(shù)組降維實(shí)現(xiàn)方法的指南!這里,你將探索一種神秘又強(qiáng)大的編程技術(shù),想要提升你的Python編程技巧嗎?別猶豫,跟我一起深入探索吧!2024-01-01Python字符串對(duì)象實(shí)現(xiàn)原理詳解
這篇文章主要介紹了Python字符串對(duì)象實(shí)現(xiàn)原理詳解,在Python世界中將對(duì)象分為兩種:一種是定長(zhǎng)對(duì)象,比如整數(shù),整數(shù)對(duì)象定義的時(shí)候就能確定它所占用的內(nèi)存空間大小,另一種是變長(zhǎng)對(duì)象,在對(duì)象定義時(shí)并不知道是多少,需要的朋友可以參考下2019-07-07Python使用matplotlib填充圖形指定區(qū)域代碼示例
這篇文章主要介紹了Python使用matplotlib填充圖形指定區(qū)域代碼示例,具有一定借鑒價(jià)值,需要的朋友可以參考下2018-01-01Python 創(chuàng)建守護(hù)進(jìn)程的示例
這篇文章主要介紹了Python 創(chuàng)建守護(hù)進(jìn)程的示例,幫助大家更好的理解和使用python,感興趣的朋友可以了解下2020-09-09Python處理字符串的常用函數(shù)實(shí)例總結(jié)
在數(shù)據(jù)分析中,特別是文本分析中,字符處理需要耗費(fèi)極大的精力,因而了解字符處理對(duì)于數(shù)據(jù)分析而言,也是一項(xiàng)很重要的能力,這篇文章主要給大家介紹了關(guān)于Python處理字符串的常用函數(shù),需要的朋友可以參考下2021-11-11