欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch使用MNIST數(shù)據(jù)集實現(xiàn)基礎(chǔ)GAN和DCGAN詳解

 更新時間:2020年01月10日 10:06:33   作者:shiheyingzhe  
今天小編就為大家分享一篇Pytorch使用MNIST數(shù)據(jù)集實現(xiàn)基礎(chǔ)GAN和DCGAN詳解,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

原始生成對抗網(wǎng)絡(luò)Generative Adversarial Networks GAN包含生成器Generator和判別器Discriminator,數(shù)據(jù)有真實數(shù)據(jù)groundtruth,還有需要網(wǎng)絡(luò)生成的“fake”數(shù)據(jù),目的是網(wǎng)絡(luò)生成的fake數(shù)據(jù)可以“騙過”判別器,讓判別器認不出來,就是讓判別器分不清進入的數(shù)據(jù)是真實數(shù)據(jù)還是fake數(shù)據(jù)??偟膩碚f是:判別器區(qū)分真實數(shù)據(jù)和fake數(shù)據(jù)的能力越強越好;生成器生成的數(shù)據(jù)騙過判別器的能力越強越好,這個是矛盾的,所以只能交替訓(xùn)練網(wǎng)絡(luò)。

需要搭建生成器網(wǎng)絡(luò)和判別器網(wǎng)絡(luò),訓(xùn)練的時候交替訓(xùn)練。

首先訓(xùn)練判別器的參數(shù),固定生成器的參數(shù),讓判別器判斷生成器生成的數(shù)據(jù),讓其和0接近,讓判別器判斷真實數(shù)據(jù),讓其和1接近;

接著訓(xùn)練生成器的參數(shù),固定判別器的參數(shù),讓生成器生成的數(shù)據(jù)進入判別器,讓判斷結(jié)果和1接近。生成器生成數(shù)據(jù)需要給定隨機初始值

線性版:

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
import matplotlib.gridspec as gridspec
 
def showimg(images,count):
 images=images.detach().numpy()[0:16,:]
 images=255*(0.5*images+0.5)
 images = images.astype(np.uint8)
 grid_length=int(np.ceil(np.sqrt(images.shape[0])))
 plt.figure(figsize=(4,4))
 width = int(np.sqrt((images.shape[1])))
 gs = gridspec.GridSpec(grid_length,grid_length,wspace=0,hspace=0)
 # gs.update(wspace=0, hspace=0)
 print('starting...')
 for i, img in enumerate(images):
 ax = plt.subplot(gs[i])
 ax.set_xticklabels([])
 ax.set_yticklabels([])
 ax.set_aspect('equal')
 plt.imshow(img.reshape([width,width]),cmap = plt.cm.gray)
 plt.axis('off')
 plt.tight_layout()
 print('showing...')
 plt.tight_layout()
 plt.savefig('./GAN_Image/%d.png'%count, bbox_inches='tight')
 
def loadMNIST(batch_size): #MNIST圖片的大小是28*28
 trans_img=transforms.Compose([transforms.ToTensor()])
 trainset=MNIST('./data',train=True,transform=trans_img,download=True)
 testset=MNIST('./data',train=False,transform=trans_img,download=True)
 # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=10)
 testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=10)
 return trainset,testset,trainloader,testloader
 
class discriminator(nn.Module):
 def __init__(self):
 super(discriminator,self).__init__()
 self.dis=nn.Sequential(
  nn.Linear(784,300),
  nn.LeakyReLU(0.2),
  nn.Linear(300,150),
  nn.LeakyReLU(0.2),
  nn.Linear(150,1),
  nn.Sigmoid()
 )
 def forward(self, x):
 x=self.dis(x)
 return x
 
class generator(nn.Module):
 def __init__(self,input_size):
 super(generator,self).__init__()
 self.gen=nn.Sequential(
  nn.Linear(input_size,150),
  nn.ReLU(True),
  nn.Linear(150,300),
  nn.ReLU(True),
  nn.Linear(300,784),
  nn.Tanh()
 )
 def forward(self, x):
 x=self.gen(x)
 return x
 
if __name__=="__main__":
 criterion=nn.BCELoss()
 num_img=100
 z_dimension=100
 D=discriminator()
 G=generator(z_dimension)
 trainset, testset, trainloader, testloader = loadMNIST(num_img) # data
 d_optimizer=optim.Adam(D.parameters(),lr=0.0003)
 g_optimizer=optim.Adam(G.parameters(),lr=0.0003)
 '''
 交替訓(xùn)練的方式訓(xùn)練網(wǎng)絡(luò)
 先訓(xùn)練判別器網(wǎng)絡(luò)D再訓(xùn)練生成器網(wǎng)絡(luò)G
 不同網(wǎng)絡(luò)的訓(xùn)練次數(shù)是超參數(shù)
 也可以兩個網(wǎng)絡(luò)訓(xùn)練相同的次數(shù)
 這樣就可以不用分別訓(xùn)練兩個網(wǎng)絡(luò)
 '''
 count=0
 #鑒別器D的訓(xùn)練,固定G的參數(shù)
 epoch = 100
 gepoch = 1
 for i in range(epoch):
 for (img, label) in trainloader:
  # num_img=img.size()[0]
  real_img=img.view(num_img,-1)#展開為28*28=784
  real_label=torch.ones(num_img)#真實label為1
  fake_label=torch.zeros(num_img)#假的label為0
 
  #compute loss of real_img
  real_out=D(real_img) #真實圖片送入判別器D輸出0~1
  d_loss_real=criterion(real_out,real_label)#得到loss
  real_scores=real_out#真實圖片放入判別器輸出越接近1越好
 
  #compute loss of fake_img
  z=torch.randn(num_img,z_dimension)#隨機生成向量
  fake_img=G(z)#將向量放入生成網(wǎng)絡(luò)G生成一張圖片
  fake_out=D(fake_img)#判別器判斷假的圖片
  d_loss_fake=criterion(fake_out,fake_label)#假的圖片的loss
  fake_scores=fake_out#假的圖片放入判別器輸出越接近0越好
 
  #D bp and optimize
  d_loss=d_loss_real+d_loss_fake
  d_optimizer.zero_grad() #判別器D的梯度歸零
  d_loss.backward() #反向傳播
  d_optimizer.step() #更新判別器D參數(shù)
 
  #生成器G的訓(xùn)練compute loss of fake_img
  for j in range(gepoch):
  fake_label = torch.ones(num_img) # 真實label為1
  z = torch.randn(num_img, z_dimension) # 隨機生成向量
  fake_img = G(z) # 將向量放入生成網(wǎng)絡(luò)G生成一張圖片
  output = D(fake_img) # 經(jīng)過判別器得到結(jié)果
  g_loss = criterion(output, fake_label)#得到假的圖片與真實標(biāo)簽的loss
  #bp and optimize
  g_optimizer.zero_grad() #生成器G的梯度歸零
  g_loss.backward() #反向傳播
  g_optimizer.step()#更新生成器G參數(shù)
 print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
   'D real: {:.6f}, D fake: {:.6f}'.format(
  i, epoch, d_loss.data[0], g_loss.data[0],
  real_scores.data.mean(), fake_scores.data.mean()))
 showimg(fake_img,count)
 # plt.show()
 count += 1

這里的圖分別是 epoch為0、50、100、150、190的運行結(jié)果,可以看到圖片中的數(shù)字并不單一

卷積版 Deep Convolutional Generative Adversarial Networks:

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
 
import matplotlib.gridspec as gridspec
import os
 
def showimg(images,count):
 images=images.to('cpu')
 images=images.detach().numpy()
 images=images[[6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96]]
 images=255*(0.5*images+0.5)
 images = images.astype(np.uint8)
 grid_length=int(np.ceil(np.sqrt(images.shape[0])))
 plt.figure(figsize=(4,4))
 width = images.shape[2]
 gs = gridspec.GridSpec(grid_length,grid_length,wspace=0,hspace=0)
 print(images.shape)
 for i, img in enumerate(images):
 ax = plt.subplot(gs[i])
 ax.set_xticklabels([])
 ax.set_yticklabels([])
 ax.set_aspect('equal')
 plt.imshow(img.reshape(width,width),cmap = plt.cm.gray)
 plt.axis('off')
 plt.tight_layout()
# print('showing...')
 plt.tight_layout()
# plt.savefig('./GAN_Imaget/%d.png'%count, bbox_inches='tight')
 
def loadMNIST(batch_size): #MNIST圖片的大小是28*28
 trans_img=transforms.Compose([transforms.ToTensor()])
 trainset=MNIST('./data',train=True,transform=trans_img,download=True)
 testset=MNIST('./data',train=False,transform=trans_img,download=True)
 # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=10)
 testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=10)
 return trainset,testset,trainloader,testloader
 
class discriminator(nn.Module):
 def __init__(self):
 super(discriminator,self).__init__()
 self.dis=nn.Sequential(
  nn.Conv2d(1,32,5,stride=1,padding=2),
  nn.LeakyReLU(0.2,True),
  nn.MaxPool2d((2,2)),
 
  nn.Conv2d(32,64,5,stride=1,padding=2),
  nn.LeakyReLU(0.2,True),
  nn.MaxPool2d((2,2))
 )
 self.fc=nn.Sequential(
  nn.Linear(7 * 7 * 64, 1024),
  nn.LeakyReLU(0.2, True),
  nn.Linear(1024, 1),
  nn.Sigmoid()
 )
 def forward(self, x):
 x=self.dis(x)
 x=x.view(x.size(0),-1)
 x=self.fc(x)
 return x
 
class generator(nn.Module):
 def __init__(self,input_size,num_feature):
 super(generator,self).__init__()
 self.fc=nn.Linear(input_size,num_feature) #1*56*56
 self.br=nn.Sequential(
  nn.BatchNorm2d(1),
  nn.ReLU(True)
 )
 self.gen=nn.Sequential(
  nn.Conv2d(1,50,3,stride=1,padding=1),
  nn.BatchNorm2d(50),
  nn.ReLU(True),
 
  nn.Conv2d(50,25,3,stride=1,padding=1),
  nn.BatchNorm2d(25),
  nn.ReLU(True),
 
  nn.Conv2d(25,1,2,stride=2),
  nn.Tanh()
 )
 def forward(self, x):
 x=self.fc(x)
 x=x.view(x.size(0),1,56,56)
 x=self.br(x)
 x=self.gen(x)
 return x
 
if __name__=="__main__":
 criterion=nn.BCELoss()
 num_img=100
 z_dimension=100
 D=discriminator()
 G=generator(z_dimension,3136) #1*56*56
 trainset, testset, trainloader, testloader = loadMNIST(num_img) # data
 D=D.cuda()
 G=G.cuda()
 d_optimizer=optim.Adam(D.parameters(),lr=0.0003)
 g_optimizer=optim.Adam(G.parameters(),lr=0.0003)
 '''
 交替訓(xùn)練的方式訓(xùn)練網(wǎng)絡(luò)
 先訓(xùn)練判別器網(wǎng)絡(luò)D再訓(xùn)練生成器網(wǎng)絡(luò)G
 不同網(wǎng)絡(luò)的訓(xùn)練次數(shù)是超參數(shù)
 也可以兩個網(wǎng)絡(luò)訓(xùn)練相同的次數(shù),
 這樣就可以不用分別訓(xùn)練兩個網(wǎng)絡(luò)
 '''
 count=0
 #鑒別器D的訓(xùn)練,固定G的參數(shù)
 epoch = 100
 gepoch = 1
 for i in range(epoch):
 for (img, label) in trainloader:
  # num_img=img.size()[0]
  img=Variable(img).cuda()
  real_label=Variable(torch.ones(num_img)).cuda()#真實label為1
  fake_label=Variable(torch.zeros(num_img)).cuda()#假的label為0
 
  #compute loss of real_img
  real_out=D(img) #真實圖片送入判別器D輸出0~1
  d_loss_real=criterion(real_out,real_label)#得到loss
  real_scores=real_out#真實圖片放入判別器輸出越接近1越好
 
  #compute loss of fake_img
  z=Variable(torch.randn(num_img,z_dimension)).cuda()#隨機生成向量
  fake_img=G(z)#將向量放入生成網(wǎng)絡(luò)G生成一張圖片
  fake_out=D(fake_img)#判別器判斷假的圖片
  d_loss_fake=criterion(fake_out,fake_label)#假的圖片的loss
  fake_scores=fake_out#假的圖片放入判別器輸出越接近0越好
 
  #D bp and optimize
  d_loss=d_loss_real+d_loss_fake
  d_optimizer.zero_grad() #判別器D的梯度歸零
  d_loss.backward() #反向傳播
  d_optimizer.step() #更新判別器D參數(shù)
 
  #生成器G的訓(xùn)練compute loss of fake_img
  for j in range(gepoch):
  fake_label = Variable(torch.ones(num_img)).cuda() # 真實label為1
  z = Variable(torch.randn(num_img, z_dimension)).cuda() # 隨機生成向量
  fake_img = G(z) # 將向量放入生成網(wǎng)絡(luò)G生成一張圖片
  output = D(fake_img) # 經(jīng)過判別器得到結(jié)果
  g_loss = criterion(output, fake_label)#得到假的圖片與真實標(biāo)簽的loss
  #bp and optimize
  g_optimizer.zero_grad() #生成器G的梯度歸零
  g_loss.backward() #反向傳播
  g_optimizer.step()#更新生成器G參數(shù)
  # if ((i+1)%1000==0):
  # print("[%d/%d] GLoss: %.5f" % (i + 1, gepoch, g_loss.data[0]))
 print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
   'D real: {:.6f}, D fake: {:.6f}'.format(
  i, epoch, d_loss.data[0], g_loss.data[0],
  real_scores.data.mean(), fake_scores.data.mean()))
 showimg(fake_img,count)
 plt.show()
 count += 1

這里的gepoch設(shè)置為1,運行39次的結(jié)果是:

gepoch設(shè)置為2,運行0、25、50、75、100次的結(jié)果是:

gepoch設(shè)置為3,運行25、50、75次的結(jié)果是:

gepoch設(shè)置為4,運行0、10、20、30、35次的結(jié)果是:

gepoch設(shè)置為5,運行0、10、20、25、29次的結(jié)果是:

gepoch設(shè)置為3,z_dimension設(shè)置為190,epoch運行0、10、15、20、25、35的結(jié)果是:

可以看到生成的數(shù)字基本沒有太多的規(guī)律,可能最終都是同個數(shù)字,不能生成指定的數(shù)字,CGAN就很好的解決這個問題,可以生成指定的數(shù)字 Pytorch使用MNIST數(shù)據(jù)集實現(xiàn)CGAN和生成指定的數(shù)字方式

以上這篇Pytorch使用MNIST數(shù)據(jù)集實現(xiàn)基礎(chǔ)GAN和DCGAN詳解就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • python中可以聲明變量類型嗎

    python中可以聲明變量類型嗎

    在本篇文章里小編給大家整理了關(guān)于python中聲明變量類型的相關(guān)知識點,需要的朋友們可以學(xué)習(xí)下。
    2020-06-06
  • Python內(nèi)置函數(shù)詳細解析

    Python內(nèi)置函數(shù)詳細解析

    這篇文章主要介紹了Python內(nèi)置函數(shù)詳細解析,Python?自帶了很多的內(nèi)置函數(shù),極大地方便了我們的開發(fā),下文小編總結(jié)了一些內(nèi)置函數(shù)的相關(guān)內(nèi)容,需要的小伙伴可以參考一下
    2022-05-05
  • python局部賦值的規(guī)則

    python局部賦值的規(guī)則

    Python提出如下假設(shè):如果在函數(shù)體內(nèi)的任何地方對變量賦值,則Python將名稱添加到局部命名空間中。
    2013-03-03
  • 全面了解python中的類,對象,方法,屬性

    全面了解python中的類,對象,方法,屬性

    下面小編就為大家?guī)硪黄媪私鈖ython中的類,對象,方法,屬性。小編覺得挺不錯的,現(xiàn)在就分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2016-09-09
  • python 實現(xiàn)長數(shù)據(jù)完整打印方案

    python 實現(xiàn)長數(shù)據(jù)完整打印方案

    這篇文章主要介紹了python 實現(xiàn)長數(shù)據(jù)完整打印方案,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧
    2021-03-03
  • python?實時獲取kafka消費隊列信息示例詳解

    python?實時獲取kafka消費隊列信息示例詳解

    這篇文章主要介紹了python實時獲取kafka消費隊列信息,本文通過實例代碼給大家介紹的非常詳細,對大家的學(xué)習(xí)或工作具有一定的參考借鑒價值,需要的朋友可以參考下
    2023-07-07
  • python爬取”頂點小說網(wǎng)“《純陽劍尊》的示例代碼

    python爬取”頂點小說網(wǎng)“《純陽劍尊》的示例代碼

    這篇文章主要介紹了python爬取”頂點小說網(wǎng)“《純陽劍尊》的示例代碼,幫助大家更好的利用python 爬蟲爬取數(shù)據(jù),感興趣的朋友可以了解下
    2020-10-10
  • 一文了解Python3的錯誤和異常

    一文了解Python3的錯誤和異常

    Python 的語法錯誤或者稱之為解析錯,是初學(xué)者經(jīng)常碰到的。即便 Python 程序的語法是正確的,在運行它的時候,也有可能發(fā)生錯誤。運行期檢測到的錯誤被稱為異常。本文就來和大家聊聊Python3的錯誤和異常,感興趣的可以學(xué)習(xí)一下
    2022-09-09
  • python之pyinstaller組件打包命令和異常解析實戰(zhàn)

    python之pyinstaller組件打包命令和異常解析實戰(zhàn)

    前段時間在制作小工具的時候,直接在命令行用pyinstaller工具打包成功后,啟動exe可執(zhí)行文件的時候各種報錯, 今天,我們就分享一下踩坑經(jīng)過,需要的朋友可以參考下
    2021-09-09
  • Python類定義和類繼承詳解

    Python類定義和類繼承詳解

    這篇文章主要介紹了Python類定義和類繼承詳解,本文講解了類的私有屬性、類的方法、私有的類方法、類的專有方法、類的定義、類的單繼承、類的多繼承等內(nèi)容,需要的朋友可以參考下
    2015-05-05

最新評論