欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

pytorch:實(shí)現(xiàn)簡(jiǎn)單的GAN示例(MNIST數(shù)據(jù)集)

 更新時(shí)間:2020年01月10日 09:17:37   作者:xckkcxxck  
今天小編就為大家分享一篇pytorch:實(shí)現(xiàn)簡(jiǎn)單的GAN示例(MNIST數(shù)據(jù)集),具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧

我就廢話不多說(shuō)了,直接上代碼吧!

# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
 
import torch
from torch import nn
from torch.autograd import Variable
 
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
 
import numpy as np
 
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
 
plt.rcParams['figure.figsize'] = (10.0, 8.0) # 設(shè)置畫圖的尺寸
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
 
def show_images(images): # 定義畫圖工具
  images = np.reshape(images, [images.shape[0], -1])
  sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
  sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
 
  fig = plt.figure(figsize=(sqrtn, sqrtn))
  gs = gridspec.GridSpec(sqrtn, sqrtn)
  gs.update(wspace=0.05, hspace=0.05)
 
  for i, img in enumerate(images):
    ax = plt.subplot(gs[i])
    plt.axis('off')
    ax.set_xticklabels([])
    ax.set_yticklabels([])
    ax.set_aspect('equal')
    plt.imshow(img.reshape([sqrtimg,sqrtimg]))
  return 
  
def preprocess_img(x):
  x = tfs.ToTensor()(x)
  return (x - 0.5) / 0.5
 
def deprocess_img(x):
  return (x + 1.0) / 2.0
 
class ChunkSampler(sampler.Sampler): # 定義一個(gè)取樣的函數(shù)
  """Samples elements sequentially from some offset. 
  Arguments:
    num_samples: # of desired datapoints
    start: offset where we should start selecting from
  """
  def __init__(self, num_samples, start=0):
    self.num_samples = num_samples
    self.start = start
 
  def __iter__(self):
    return iter(range(self.start, self.start + self.num_samples))
 
  def __len__(self):
    return self.num_samples
    
NUM_TRAIN = 50000
NUM_VAL = 5000
 
NOISE_DIM = 96
batch_size = 128
 
train_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
 
val_set = MNIST('E:/data', train=True, transform=preprocess_img)
 
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))
 
imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可視化圖片效果
show_images(imgs)
 
#判別網(wǎng)絡(luò)
def discriminator():
  net = nn.Sequential(    
      nn.Linear(784, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 256),
      nn.LeakyReLU(0.2),
      nn.Linear(256, 1)
    )
  return net
  
#生成網(wǎng)絡(luò)
def generator(noise_dim=NOISE_DIM):  
  net = nn.Sequential(
    nn.Linear(noise_dim, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 1024),
    nn.ReLU(True),
    nn.Linear(1024, 784),
    nn.Tanh()
  )
  return net
  
#判別器的 loss 就是將真實(shí)數(shù)據(jù)的得分判斷為 1,假的數(shù)據(jù)的得分判斷為 0,而生成器的 loss 就是將假的數(shù)據(jù)判斷為 1
 
bce_loss = nn.BCEWithLogitsLoss()#交叉熵?fù)p失函數(shù)
 
def discriminator_loss(logits_real, logits_fake): # 判別器的 loss
  size = logits_real.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  false_labels = Variable(torch.zeros(size, 1)).float()
  loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
  return loss
  
def generator_loss(logits_fake): # 生成器的 loss 
  size = logits_fake.shape[0]
  true_labels = Variable(torch.ones(size, 1)).float()
  loss = bce_loss(logits_fake, true_labels)
  return loss
  
# 使用 adam 來(lái)進(jìn)行訓(xùn)練,學(xué)習(xí)率是 3e-4, beta1 是 0.5, beta2 是 0.999
def get_optimizer(net):
  optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
  return optimizer
  
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, 
        noise_size=96, num_epochs=10):
  iter_count = 0
  for epoch in range(num_epochs):
    for x, _ in train_data:
      bs = x.shape[0]
      # 判別網(wǎng)絡(luò)
      real_data = Variable(x).view(bs, -1) # 真實(shí)數(shù)據(jù)
      logits_real = D_net(real_data) # 判別網(wǎng)絡(luò)得分
      
      sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均勻分布
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的數(shù)據(jù)
      logits_fake = D_net(fake_images) # 判別網(wǎng)絡(luò)得分
 
      d_total_error = discriminator_loss(logits_real, logits_fake) # 判別器的 loss
      D_optimizer.zero_grad()
      d_total_error.backward()
      D_optimizer.step() # 優(yōu)化判別網(wǎng)絡(luò)
      
      # 生成網(wǎng)絡(luò)
      g_fake_seed = Variable(sample_noise)
      fake_images = G_net(g_fake_seed) # 生成的假的數(shù)據(jù)
 
      gen_logits_fake = D_net(fake_images)
      g_error = generator_loss(gen_logits_fake) # 生成網(wǎng)絡(luò)的 loss
      G_optimizer.zero_grad()
      g_error.backward()
      G_optimizer.step() # 優(yōu)化生成網(wǎng)絡(luò)
 
      if (iter_count % show_every == 0):
        print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
        imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
        show_images(imgs_numpy[0:16])
        plt.show()
        print()
      iter_count += 1
 
D = discriminator()
G = generator()
 
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
 
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)      

以上這篇pytorch:實(shí)現(xiàn)簡(jiǎn)單的GAN示例(MNIST數(shù)據(jù)集)就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。

相關(guān)文章

  • Python PyQt5標(biāo)準(zhǔn)對(duì)話框用法示例

    Python PyQt5標(biāo)準(zhǔn)對(duì)話框用法示例

    這篇文章主要介紹了Python PyQt5標(biāo)準(zhǔn)對(duì)話框用法,結(jié)合實(shí)例形式分析了PyQt5常用的標(biāo)準(zhǔn)對(duì)話框及相關(guān)使用技巧,需要的朋友可以參考下
    2017-08-08
  • python之線程通過(guò)信號(hào)pyqtSignal刷新ui的方法

    python之線程通過(guò)信號(hào)pyqtSignal刷新ui的方法

    今天小編就為大家分享一篇python之線程通過(guò)信號(hào)pyqtSignal刷新ui的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧
    2019-01-01
  • python創(chuàng)建模板文件及使用教程示例

    python創(chuàng)建模板文件及使用教程示例

    這篇文章主要介紹了python創(chuàng)建模板文件及使用教程示例
    2021-10-10
  • 淺談對(duì)pytroch中torch.autograd.backward的思考

    淺談對(duì)pytroch中torch.autograd.backward的思考

    這篇文章主要介紹了對(duì)pytroch中torch.autograd.backward的思考,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧
    2019-12-12
  • python中import和from-import的區(qū)別解析

    python中import和from-import的區(qū)別解析

    這篇文章主要介紹了python中import和from-import的區(qū)別解析,本文通過(guò)實(shí)例代碼給大家講解的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2022-12-12
  • Python多路復(fù)用selector模塊的基本使用

    Python多路復(fù)用selector模塊的基本使用

    Python提供了selector模塊來(lái)實(shí)現(xiàn)IO多路復(fù)用,這篇文章給大家介紹了Python多路復(fù)用selector模塊的基本使用,感興趣的朋友一起看看吧
    2021-11-11
  • Python使用虛擬環(huán)境(安裝下載更新卸載)命令

    Python使用虛擬環(huán)境(安裝下載更新卸載)命令

    這篇文章主要為大家介紹了Python使用虛擬環(huán)境(安裝下載更新卸載)命令,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2023-11-11
  • Python安裝spark的詳細(xì)過(guò)程

    Python安裝spark的詳細(xì)過(guò)程

    這篇文章主要介紹了Python安裝spark的詳細(xì)過(guò)程,本文通過(guò)圖文實(shí)例代碼相結(jié)合給大家介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下
    2021-10-10
  • Python實(shí)現(xiàn)拉格朗日插值法的示例詳解

    Python實(shí)現(xiàn)拉格朗日插值法的示例詳解

    插值法是一種數(shù)學(xué)方法,用于在已知數(shù)據(jù)點(diǎn)(離散數(shù)據(jù))之間插入數(shù)據(jù),以生成連續(xù)的函數(shù)曲線,而格朗日插值法是一種多項(xiàng)式插值法。本文就來(lái)用Python實(shí)現(xiàn)拉格朗日插值法,希望對(duì)大家有所幫助
    2023-02-02
  • python應(yīng)用之如何使用Python發(fā)送通知到微信

    python應(yīng)用之如何使用Python發(fā)送通知到微信

    現(xiàn)在通過(guò)發(fā)微信信息來(lái)做消息通知和告警已經(jīng)很普遍了,下面這篇文章主要給大家介紹了關(guān)于python應(yīng)用之如何使用Python發(fā)送通知到微信的相關(guān)資料,文中通過(guò)實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下
    2022-03-03

最新評(píng)論