欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch搭建SRGAN平臺提升圖片超分辨率

 更新時間:2022年04月29日 17:37:32   作者:Bubbliiiing  
這篇文章主要為大家介紹了Pytorch搭建SRGAN平臺提升圖片超分辨率,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪

源碼下載地址

網(wǎng)絡(luò)構(gòu)建

一、什么是SRGAN

SRGAN出自論文Photo-Realistic Single Image Super-Resolution Using a Generative Adversarial Network。

如果將SRGAN看作一個黑匣子,其主要的功能就是輸入一張低分辨率圖片,生成高分辨率圖片。

該文章提到,普通的超分辨率模型訓練網(wǎng)絡(luò)時只用到了均方差作為損失函數(shù),雖然能夠獲得很高的峰值信噪比,但是恢復出來的圖像通常會丟失高頻細節(jié)。

SRGAN利用感知損失(perceptual loss)和對抗損失(adversarial loss)來提升恢復出的圖片的真實感。

二、生成網(wǎng)絡(luò)的構(gòu)建

生成網(wǎng)絡(luò)的構(gòu)成如上圖所示,生成網(wǎng)絡(luò)的作用是輸入一張低分辨率圖片,生成高分辨率圖片。:

SRGAN的生成網(wǎng)絡(luò)由三個部分組成。

1、低分辨率圖像進入后會經(jīng)過一個卷積+RELU函數(shù)。

2、然后經(jīng)過B個殘差網(wǎng)絡(luò)結(jié)構(gòu),每個殘差結(jié)構(gòu)都包含兩個卷積+標準化+RELU,還有一個殘差邊。

3、然后進入上采樣部分,在經(jīng)過兩次上采樣后,原圖的高寬變?yōu)樵瓉淼?倍,實現(xiàn)分辨率的提升。

前兩個部分用于特征提取,第三部分用于提高分辨率。

import math
import torch
from torch import nn
class ResidualBlock(nn.Module):
    def __init__(self, channels):
        super(ResidualBlock, self).__init__()
        self.conv1 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm2d(channels)
        self.prelu = nn.PReLU(channels)
        self.conv2 = nn.Conv2d(channels, channels, kernel_size=3, padding=1)
        self.bn2 = nn.BatchNorm2d(channels)
    def forward(self, x):
        short_cut = x
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.prelu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        return x + short_cut
class UpsampleBLock(nn.Module):
    def __init__(self, in_channels, up_scale):
        super(UpsampleBLock, self).__init__()
        self.conv = nn.Conv2d(in_channels, in_channels * up_scale ** 2, kernel_size=3, padding=1)
        self.pixel_shuffle = nn.PixelShuffle(up_scale)
        self.prelu = nn.PReLU(in_channels)
    def forward(self, x):
        x = self.conv(x)
        x = self.pixel_shuffle(x)
        x = self.prelu(x)
        return x
class Generator(nn.Module):
    def __init__(self, scale_factor, num_residual=16):
        upsample_block_num = int(math.log(scale_factor, 2))
        super(Generator, self).__init__()
        self.block_in = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=9, padding=4),
            nn.PReLU(64)
        )
        self.blocks = []
        for _ in range(num_residual):
            self.blocks.append(ResidualBlock(64))
        self.blocks = nn.Sequential(*self.blocks)
        self.block_out = nn.Sequential(
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64)
        )
        self.upsample = [UpsampleBLock(64, 2) for _ in range(upsample_block_num)]
        self.upsample.append(nn.Conv2d(64, 3, kernel_size=9, padding=4))
        self.upsample = nn.Sequential(*self.upsample)
    def forward(self, x):
        x = self.block_in(x)
        short_cut = x
        x = self.blocks(x)
        x = self.block_out(x)
        upsample = self.upsample(x + short_cut)
        return torch.tanh(upsample)

三、判別網(wǎng)絡(luò)的構(gòu)建

判別網(wǎng)絡(luò)的構(gòu)成如上圖所示:

SRGAN的判別網(wǎng)絡(luò)由不斷重復的 卷積+LeakyRELU和標準化 組成。
對于判斷網(wǎng)絡(luò)來講,它的目的是判斷輸入圖片的真假,它的輸入是圖片,輸出是判斷結(jié)果。

判斷結(jié)果處于0-1之間,利用接近1代表判斷為真圖片,接近0代表判斷為假圖片。

判斷網(wǎng)絡(luò)的構(gòu)建和普通卷積網(wǎng)絡(luò)差距不大,都是不斷的卷積對圖片進行下采用,在多次卷積后,最終接一次全連接判斷結(jié)果。

實現(xiàn)代碼如下:

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.net = nn.Sequential(
            nn.Conv2d(3, 64, kernel_size=3, padding=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 64, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(64),
            nn.LeakyReLU(0.2),
            nn.Conv2d(64, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 128, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(128),
            nn.LeakyReLU(0.2),
            nn.Conv2d(128, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 256, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(256),
            nn.LeakyReLU(0.2),
            nn.Conv2d(256, 512, kernel_size=3, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
            nn.BatchNorm2d(512),
            nn.LeakyReLU(0.2),
            nn.AdaptiveAvgPool2d(1),
            nn.Conv2d(512, 1024, kernel_size=1),
            nn.LeakyReLU(0.2),
            nn.Conv2d(1024, 1, kernel_size=1)
        )
    def forward(self, x):
        batch_size = x.size(0)
        return torch.sigmoid(self.net(x).view(batch_size))

訓練思路

SRGAN的訓練可以分為生成器訓練和判別器訓練:
每一個step中一般先訓練判別器,然后訓練生成器。

一、判別器的訓練

在訓練判別器的時候我們希望判別器可以判斷輸入圖片的真?zhèn)危虼宋覀兊妮斎刖褪钦鎴D片、假圖片和它們對應(yīng)的標簽。

因此判別器的訓練步驟如下:

1、隨機選取batch_size個真實高分辨率圖片。

2、利用resize后的低分辨率圖片,傳入到Generator中生成batch_size個虛假高分辨率圖片。

3、真實圖片的label為1,虛假圖片的label為0,將真實圖片和虛假圖片當作訓練集傳入到Discriminator中進行訓練。

二、生成器的訓練

在訓練生成器的時候我們希望生成器可以生成極為真實的假圖片。因此我們在訓練生成器需要知道判別器認為什么圖片是真圖片。

因此生成器的訓練步驟如下:

1、將低分辨率圖像傳入生成模型,得到虛假高分辨率圖像,將虛假高分辨率圖像獲得判別結(jié)果與1進行對比得到loss。(與1對比的意思是,讓生成器根據(jù)判別器判別的結(jié)果進行訓練)。

2、將真實高分辨率圖像和虛假高分辨率圖像傳入VGG網(wǎng)絡(luò),獲得兩個圖像的特征,通過這兩個圖像的特征進行比較獲得loss

利用SRGAN生成圖片

SRGAN的庫整體結(jié)構(gòu)如下:

一、數(shù)據(jù)集的準備

在訓練前需要準備好數(shù)據(jù)集,數(shù)據(jù)集保存在datasets文件夾里面。

二、數(shù)據(jù)集的處理

打開txt_annotation.py,默認指向根目錄下的datasets。運行txt_annotation.py。
此時生成根目錄下面的train_lines.txt。

三、模型訓練

在完成數(shù)據(jù)集處理后,運行train.py即可開始訓練。

訓練過程中,可在results文件夾內(nèi)查看訓練效果:

以上就是Pytorch搭建SRGAN平臺提升圖片超分辨率的詳細內(nèi)容,更多關(guān)于Pytorch搭建SRGAN圖片超分辨率的資料請關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • 利用PyQt5模擬實現(xiàn)網(wǎng)頁鼠標移動特效

    利用PyQt5模擬實現(xiàn)網(wǎng)頁鼠標移動特效

    不知道大家有沒有發(fā)現(xiàn),博客園有些博客左側(cè)會有鼠標移動特效。通過移動鼠標,會形成類似蜘蛛網(wǎng)的特效,本文將用PyQt5實現(xiàn)這一特效,需要的可以參考一下
    2022-03-03
  • 如何輕松實現(xiàn)Python數(shù)組降維?

    如何輕松實現(xiàn)Python數(shù)組降維?

    歡迎來到Python數(shù)組降維實現(xiàn)方法的指南!這里,你將探索一種神秘又強大的編程技術(shù),想要提升你的Python編程技巧嗎?別猶豫,跟我一起深入探索吧!
    2024-01-01
  • PyQt5 pyqt多線程操作入門

    PyQt5 pyqt多線程操作入門

    本篇文章主要介紹了PyQt5 pyqt多線程操作入門,小編覺得挺不錯的,現(xiàn)在分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2018-05-05
  • Python爬蟲Requests庫的使用詳情

    Python爬蟲Requests庫的使用詳情

    這篇文章主要介紹了Python爬蟲Requests庫的使用詳情,文章圍繞主題展開詳細的內(nèi)容介紹,具有一定的參考價值,需要的小伙伴可以參考一下
    2022-08-08
  • Python字符串對象實現(xiàn)原理詳解

    Python字符串對象實現(xiàn)原理詳解

    這篇文章主要介紹了Python字符串對象實現(xiàn)原理詳解,在Python世界中將對象分為兩種:一種是定長對象,比如整數(shù),整數(shù)對象定義的時候就能確定它所占用的內(nèi)存空間大小,另一種是變長對象,在對象定義時并不知道是多少,需要的朋友可以參考下
    2019-07-07
  • Python使用matplotlib填充圖形指定區(qū)域代碼示例

    Python使用matplotlib填充圖形指定區(qū)域代碼示例

    這篇文章主要介紹了Python使用matplotlib填充圖形指定區(qū)域代碼示例,具有一定借鑒價值,需要的朋友可以參考下
    2018-01-01
  • python中安裝模塊包版本沖突問題的解決

    python中安裝模塊包版本沖突問題的解決

    這篇文章主要給大家介紹了在python中安裝模塊包版本沖突問題的解決方法,文中介紹了該問題的原因與解決方法,需要的朋友可以參考借鑒,下面來一起看看吧。
    2017-05-05
  • Python 創(chuàng)建守護進程的示例

    Python 創(chuàng)建守護進程的示例

    這篇文章主要介紹了Python 創(chuàng)建守護進程的示例,幫助大家更好的理解和使用python,感興趣的朋友可以了解下
    2020-09-09
  • 解決python線程卡死的問題

    解決python線程卡死的問題

    今天小編就為大家分享一篇解決python線程卡死的問題,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2019-02-02
  • Python處理字符串的常用函數(shù)實例總結(jié)

    Python處理字符串的常用函數(shù)實例總結(jié)

    在數(shù)據(jù)分析中,特別是文本分析中,字符處理需要耗費極大的精力,因而了解字符處理對于數(shù)據(jù)分析而言,也是一項很重要的能力,這篇文章主要給大家介紹了關(guān)于Python處理字符串的常用函數(shù),需要的朋友可以參考下
    2021-11-11

最新評論