欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

Pytorch使用MNIST數(shù)據(jù)集實現(xiàn)CGAN和生成指定的數(shù)字方式

 更新時間:2020年01月10日 09:58:04   作者:shiheyingzhe  
今天小編就為大家分享一篇Pytorch使用MNIST數(shù)據(jù)集實現(xiàn)CGAN和生成指定的數(shù)字方式,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧

CGAN的全拼是Conditional Generative Adversarial Networks,條件生成對抗網(wǎng)絡,在初始GAN的基礎上增加了圖片的相應信息。

這里用傳統(tǒng)的卷積方式實現(xiàn)CGAN。

import torch
from torch.utils.data import DataLoader
from torchvision.datasets import MNIST
from torchvision import transforms
from torch import optim
import torch.nn as nn
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
import pickle
import copy
 
import matplotlib.gridspec as gridspec
import os
 
def save_model(model, filename): #保存為CPU中可以打開的模型
 state = model.state_dict()
 x=state.copy()
 for key in x: 
  x[key] = x[key].clone().cpu()
 torch.save(x, filename)
 
def showimg(images,count):
 images=images.to('cpu')
 images=images.detach().numpy()
 images=images[[6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96]]
 images=255*(0.5*images+0.5)
 images = images.astype(np.uint8)
 grid_length=int(np.ceil(np.sqrt(images.shape[0])))
 plt.figure(figsize=(4,4))
 width = images.shape[2]
 gs = gridspec.GridSpec(grid_length,grid_length,wspace=0,hspace=0)
 for i, img in enumerate(images):
  ax = plt.subplot(gs[i])
  ax.set_xticklabels([])
  ax.set_yticklabels([])
  ax.set_aspect('equal')
  plt.imshow(img.reshape(width,width),cmap = plt.cm.gray)
  plt.axis('off')
  plt.tight_layout()
#  plt.tight_layout()
 plt.savefig(r'./CGAN/images/%d.png'% count, bbox_inches='tight')
 
def loadMNIST(batch_size): #MNIST圖片的大小是28*28
 trans_img=transforms.Compose([transforms.ToTensor()])
 trainset=MNIST('./data',train=True,transform=trans_img,download=True)
 testset=MNIST('./data',train=False,transform=trans_img,download=True)
 # device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
 trainloader=DataLoader(trainset,batch_size=batch_size,shuffle=True,num_workers=10)
 testloader = DataLoader(testset, batch_size=batch_size, shuffle=False, num_workers=10)
 return trainset,testset,trainloader,testloader
 
class discriminator(nn.Module):
 def __init__(self):
  super(discriminator,self).__init__()
  self.dis=nn.Sequential(
   nn.Conv2d(1,32,5,stride=1,padding=2),
   nn.LeakyReLU(0.2,True),
   nn.MaxPool2d((2,2)),
 
   nn.Conv2d(32,64,5,stride=1,padding=2),
   nn.LeakyReLU(0.2,True),
   nn.MaxPool2d((2,2))
  )
  self.fc=nn.Sequential(
   nn.Linear(7 * 7 * 64, 1024),
   nn.LeakyReLU(0.2, True),
   nn.Linear(1024, 10),
   nn.Sigmoid()
  )
 def forward(self, x):
  x=self.dis(x)
  x=x.view(x.size(0),-1)
  x=self.fc(x)
  return x
 
class generator(nn.Module):
 def __init__(self,input_size,num_feature):
  super(generator,self).__init__()
  self.fc=nn.Linear(input_size,num_feature) #1*56*56
  self.br=nn.Sequential(
   nn.BatchNorm2d(1),
   nn.ReLU(True)
  )
  self.gen=nn.Sequential(
   nn.Conv2d(1,50,3,stride=1,padding=1),
   nn.BatchNorm2d(50),
   nn.ReLU(True),
 
   nn.Conv2d(50,25,3,stride=1,padding=1),
   nn.BatchNorm2d(25),
   nn.ReLU(True),
 
   nn.Conv2d(25,1,2,stride=2),
   nn.Tanh()
  )
 def forward(self, x):
  x=self.fc(x)
  x=x.view(x.size(0),1,56,56)
  x=self.br(x)
  x=self.gen(x)
  return x
 
if __name__=="__main__":
 criterion=nn.BCELoss()
 num_img=100
 z_dimension=110
 D=discriminator()
 G=generator(z_dimension,3136) #1*56*56
 trainset, testset, trainloader, testloader = loadMNIST(num_img) # data
 D=D.cuda()
 G=G.cuda()
 d_optimizer=optim.Adam(D.parameters(),lr=0.0003)
 g_optimizer=optim.Adam(G.parameters(),lr=0.0003)
 '''
 交替訓練的方式訓練網(wǎng)絡
 先訓練判別器網(wǎng)絡D再訓練生成器網(wǎng)絡G
 不同網(wǎng)絡的訓練次數(shù)是超參數(shù)
 也可以兩個網(wǎng)絡訓練相同的次數(shù),
 這樣就可以不用分別訓練兩個網(wǎng)絡
 '''
 count=0
 #鑒別器D的訓練,固定G的參數(shù)
 epoch = 119
 gepoch = 1
 for i in range(epoch):
  for (img, label) in trainloader:
   labels_onehot = np.zeros((num_img,10))
   labels_onehot[np.arange(num_img),label.numpy()]=1
#    img=img.view(num_img,-1)
#    img=np.concatenate((img.numpy(),labels_onehot))
#    img=torch.from_numpy(img)
   img=Variable(img).cuda()
   real_label=Variable(torch.from_numpy(labels_onehot).float()).cuda()#真實label為1
   fake_label=Variable(torch.zeros(num_img,10)).cuda()#假的label為0
 
   #compute loss of real_img
   real_out=D(img) #真實圖片送入判別器D輸出0~1
   d_loss_real=criterion(real_out,real_label)#得到loss
   real_scores=real_out#真實圖片放入判別器輸出越接近1越好
 
   #compute loss of fake_img
   z=Variable(torch.randn(num_img,z_dimension)).cuda()#隨機生成向量
   fake_img=G(z)#將向量放入生成網(wǎng)絡G生成一張圖片
   fake_out=D(fake_img)#判別器判斷假的圖片
   d_loss_fake=criterion(fake_out,fake_label)#假的圖片的loss
   fake_scores=fake_out#假的圖片放入判別器輸出越接近0越好
 
   #D bp and optimize
   d_loss=d_loss_real+d_loss_fake
   d_optimizer.zero_grad() #判別器D的梯度歸零
   d_loss.backward() #反向傳播
   d_optimizer.step() #更新判別器D參數(shù)
 
   #生成器G的訓練compute loss of fake_img
   for j in range(gepoch):
    z =torch.randn(num_img, 100) # 隨機生成向量
    z=np.concatenate((z.numpy(),labels_onehot),axis=1)
    z=Variable(torch.from_numpy(z).float()).cuda()
    fake_img = G(z) # 將向量放入生成網(wǎng)絡G生成一張圖片
    output = D(fake_img) # 經(jīng)過判別器得到結果
    g_loss = criterion(output, real_label)#得到假的圖片與真實標簽的loss
    #bp and optimize
    g_optimizer.zero_grad() #生成器G的梯度歸零
    g_loss.backward() #反向傳播
    g_optimizer.step()#更新生成器G參數(shù)
    temp=real_label
  if (i%10==0) and (i!=0):
   print(i)
   torch.save(G.state_dict(),r'./CGAN/Generator_cuda_%d.pkl'%i)
   torch.save(D.state_dict(), r'./CGAN/Discriminator_cuda_%d.pkl' % i)
   save_model(G, r'./CGAN/Generator_cpu_%d.pkl'%i) #保存為CPU中可以打開的模型
   save_model(D, r'./CGAN/Discriminator_cpu_%d.pkl'%i) #保存為CPU中可以打開的模型
  print('Epoch [{}/{}], d_loss: {:.6f}, g_loss: {:.6f} '
     'D real: {:.6f}, D fake: {:.6f}'.format(
    i, epoch, d_loss.data[0], g_loss.data[0],
    real_scores.data.mean(), fake_scores.data.mean()))
  temp=temp.to('cpu')
  _,x=torch.max(temp,1)
  x=x.numpy()
  print(x[[6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96]])
  showimg(fake_img,count)
  plt.show()
  count += 1

和基礎GAN Pytorch使用MNIST數(shù)據(jù)集實現(xiàn)基礎GAN 里面的卷積版網(wǎng)絡比較起來,這里修改的主要是這幾個地方:

生成網(wǎng)絡的輸入值增加了真實圖片的類標簽,生成網(wǎng)絡的初始向量z_dimension之前用的是100維,由于MNIST有10類,Onehot以后一張圖片的類標簽是10維,所以將類標簽放在后面z_dimension=100+10=110維;

訓練生成器的時候,由于生成網(wǎng)絡的輸入向量z_dimension=110維,而且是100維隨機向量和10維真實圖片標簽拼接,需要做相應的拼接操作;

z =torch.randn(num_img, 100) # 隨機生成向量
z=np.concatenate((z.numpy(),labels_onehot),axis=1)
z=Variable(torch.from_numpy(z).float()).cuda()

由于計算Loss和生成網(wǎng)絡的輸入向量都需要用到真實圖片的類標簽,需要重新生成real_label,對label進行onehot。其中real_label就是真實圖片的標簽,當num_img=100時,real_label的維度是(100,10);

labels_onehot = np.zeros((num_img,10))
labels_onehot[np.arange(num_img),label.numpy()]=1
img=Variable(img).cuda()
real_label=Variable(torch.from_numpy(labels_onehot).float()).cuda()#真實label為1
fake_label=Variable(torch.zeros(num_img,10)).cuda()#假的label為0

real_label的維度是(100,10),計算Loss的時候也要有對應的維度,判別網(wǎng)絡的輸出也不再是標量,而是要修改為10維;

nn.Linear(1024, 10)

在輸出圖片的同時輸出期望的類標簽。

temp=temp.to('cpu')
_,x=torch.max(temp,1)#返回值有兩個,第一個是按列的最大值,第二個是相應最大值的列標號
x=x.numpy()
print(x[[6, 12, 18, 24, 30, 36, 42, 48, 54, 60, 66, 72, 78, 84, 90, 96]])

epoch等于0、25、50、75、100時訓練的結果:

可以看到訓練到后面圖像反而變模糊可能是訓練過擬合

用模型生成指定的數(shù)字:

在訓練的過程中保存了訓練好的模型,根據(jù)輸出圖片的清晰度,用清晰度較高的模型,使用隨機向量和10維類標簽來指定生成的數(shù)字。

import torch
import torch.nn as nn
import pickle
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
num_img=9
class discriminator(nn.Module):
 def __init__(self):
  super(discriminator, self).__init__()
  self.dis = nn.Sequential(
   nn.Conv2d(1, 32, 5, stride=1, padding=2),
   nn.LeakyReLU(0.2, True),
   nn.MaxPool2d((2, 2)),
 
   nn.Conv2d(32, 64, 5, stride=1, padding=2),
   nn.LeakyReLU(0.2, True),
   nn.MaxPool2d((2, 2))
  )
  self.fc = nn.Sequential(
   nn.Linear(7 * 7 * 64, 1024),
   nn.LeakyReLU(0.2, True),
   nn.Linear(1024, 10),
   nn.Sigmoid()
  )
 
 def forward(self, x):
  x = self.dis(x)
  x = x.view(x.size(0), -1)
  x = self.fc(x)
  return x
 
 
class generator(nn.Module):
 def __init__(self, input_size, num_feature):
  super(generator, self).__init__()
  self.fc = nn.Linear(input_size, num_feature) # 1*56*56
  self.br = nn.Sequential(
   nn.BatchNorm2d(1),
   nn.ReLU(True)
  )
  self.gen = nn.Sequential(
   nn.Conv2d(1, 50, 3, stride=1, padding=1),
   nn.BatchNorm2d(50),
   nn.ReLU(True),
 
   nn.Conv2d(50, 25, 3, stride=1, padding=1),
   nn.BatchNorm2d(25),
   nn.ReLU(True),
 
   nn.Conv2d(25, 1, 2, stride=2),
   nn.Tanh()
  )
 
 def forward(self, x):
  x = self.fc(x)
  x = x.view(x.size(0), 1, 56, 56)
  x = self.br(x)
  x = self.gen(x)
  return x
 
 
def show(images):
 images = images.detach().numpy()
 images = 255 * (0.5 * images + 0.5)
 images = images.astype(np.uint8)
 plt.figure(figsize=(4, 4))
 width = images.shape[2]
 gs = gridspec.GridSpec(1, num_img, wspace=0, hspace=0)
 for i, img in enumerate(images):
  ax = plt.subplot(gs[i])
  ax.set_xticklabels([])
  ax.set_yticklabels([])
  ax.set_aspect('equal')
  plt.imshow(img.reshape(width, width), cmap=plt.cm.gray)
  plt.axis('off')
  plt.tight_layout()
 plt.tight_layout()
 # plt.savefig(r'drive/深度學習/DCGAN/images/%d.png' % count, bbox_inches='tight')
 return width
 
def show_all(images_all):
 x=images_all[0]
 for i in range(1,len(images_all),1):
  x=np.concatenate((x,images_all[i]),0)
 print(x.shape)
 x = 255 * (0.5 * x + 0.5)
 x = x.astype(np.uint8)
 plt.figure(figsize=(9, 10))
 width = x.shape[2]
 gs = gridspec.GridSpec(10, num_img, wspace=0, hspace=0)
 for i, img in enumerate(x):
  ax = plt.subplot(gs[i])
  ax.set_xticklabels([])
  ax.set_yticklabels([])
  ax.set_aspect('equal')
  plt.imshow(img.reshape(width, width), cmap=plt.cm.gray)
  plt.axis('off')
  plt.tight_layout()
 
 
 # 導入相應的模型
z_dimension = 110
D = discriminator()
G = generator(z_dimension, 3136) # 1*56*56
D.load_state_dict(torch.load(r'./CGAN/Discriminator.pkl'))
G.load_state_dict(torch.load(r'./CGAN/Generator.pkl'))
# 依次生成0到9
lis=[]
for i in range(10):
 z = torch.randn((num_img, 100)) # 隨機生成向量
 x=np.zeros((num_img,10))
 x[:,i]=1
 z = np.concatenate((z.numpy(), x),1)
 z = torch.from_numpy(z).float()
 fake_img = G(z) # 將向量放入生成網(wǎng)絡G生成一張圖片
 lis.append(fake_img.detach().numpy())
 output = D(fake_img) # 經(jīng)過判別器得到結果
 show(fake_img)
 plt.savefig('./CGAN/generator/%d.png' % i, bbox_inches='tight')
 
show_all(lis)
plt.savefig('./CGAN/generator/all.png', bbox_inches='tight')
plt.show()

生成的結果是:

以上這篇Pytorch使用MNIST數(shù)據(jù)集實現(xiàn)CGAN和生成指定的數(shù)字方式就是小編分享給大家的全部內(nèi)容了,希望能給大家一個參考,也希望大家多多支持腳本之家。

相關文章

  • Python基于百度AI實現(xiàn)抓取表情包

    Python基于百度AI實現(xiàn)抓取表情包

    本文先抓取網(wǎng)絡上的表情圖像,然后利用百度 AI 識別表情包上的說明文字,并利用表情文字重命名文件,感興趣的小伙伴們可以參考一下
    2021-06-06
  • python列表的構造方法list()

    python列表的構造方法list()

    這篇文章主要介紹了python列表的構造方法list(),python中沒有數(shù)組這個概念,與之相應的是列表,本篇文章就來說說列表這個語法,下面文章詳細內(nèi)容,需要的小伙伴可以參考一下
    2022-03-03
  • Python列表list常用內(nèi)建函數(shù)實例小結

    Python列表list常用內(nèi)建函數(shù)實例小結

    這篇文章主要介紹了Python列表list常用內(nèi)建函數(shù),結合實例形式總結分析了Python列表list常見內(nèi)建函數(shù)的功能、使用方法及相關操作注意事項,需要的朋友可以參考下
    2019-10-10
  • python os模塊常用的29種方法使用詳解

    python os模塊常用的29種方法使用詳解

    這篇文章主要介紹了python os模塊常用的29種方法使用詳解,文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧
    2020-06-06
  • Python利用命名空間解析XML文檔

    Python利用命名空間解析XML文檔

    這篇文章主要介紹了Python利用命名空間解析XML文檔,幫助大家更好的理解和學習Python,感興趣的朋友可以了解下
    2020-08-08
  • Python中TK窗口的創(chuàng)建方式

    Python中TK窗口的創(chuàng)建方式

    這篇文章主要介紹了Python中TK窗口的創(chuàng)建方式,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教
    2022-11-11
  • 不可錯過的十本Python好書

    不可錯過的十本Python好書

    不可錯過的十本Python好書,分別適合入門、進階到精深三個不同階段的人來閱讀,感興趣的小伙伴們可以參考一下
    2017-07-07
  • 詳解Python IO口多路復用

    詳解Python IO口多路復用

    這篇文章主要介紹了Python IO口多路復用的的相關資料,文中講解的非常細致,幫助大家更好的理解和學習,感興趣的朋友可以參考下
    2020-06-06
  • Python中你應該知道的一些內(nèi)置函數(shù)

    Python中你應該知道的一些內(nèi)置函數(shù)

    python提供了內(nèi)聯(lián)模塊buidin,該模塊定義了一些軟件開發(fā)中常用的函數(shù),這些函數(shù)實現(xiàn)了數(shù)據(jù)類型的轉換,數(shù)據(jù)的計算,序列的處理等功能。下面這篇文章主要給大家介紹了Python中一些大家應該知道的內(nèi)置函數(shù),文中總結的非常詳細,需要的朋友們下面來一起看看吧。
    2017-03-03
  • python-pymysql如何實現(xiàn)更新mysql表中任意字段數(shù)據(jù)

    python-pymysql如何實現(xiàn)更新mysql表中任意字段數(shù)據(jù)

    這篇文章主要介紹了python-pymysql如何實現(xiàn)更新mysql表中任意字段數(shù)據(jù)問題,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教
    2023-05-05

最新評論