PyTorch: a simple GAN example (MNIST dataset)
Updated: 2020-01-10 09:17:37  Author: xckkcxxck
This post walks through a simple PyTorch GAN example on the MNIST dataset. It should make a handy reference; let's look at it together.
Without further ado, here is the code.
# -*- coding: utf-8 -*-
"""
Created on Sat Oct 13 10:22:45 2018
@author: www
"""
import torch
from torch import nn
import torchvision.transforms as tfs
from torch.utils.data import DataLoader, sampler
from torchvision.datasets import MNIST
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec

plt.rcParams['figure.figsize'] = (10.0, 8.0)  # default figure size
plt.rcParams['image.interpolation'] = 'nearest'
plt.rcParams['image.cmap'] = 'gray'
def show_images(images):  # plotting helper: draw a square grid of images
    images = np.reshape(images, [images.shape[0], -1])  # one flattened image per row
    sqrtn = int(np.ceil(np.sqrt(images.shape[0])))
    sqrtimg = int(np.ceil(np.sqrt(images.shape[1])))
    fig = plt.figure(figsize=(sqrtn, sqrtn))
    gs = gridspec.GridSpec(sqrtn, sqrtn)
    gs.update(wspace=0.05, hspace=0.05)
    for i, img in enumerate(images):
        ax = plt.subplot(gs[i])
        plt.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])
        ax.set_aspect('equal')
        plt.imshow(img.reshape([sqrtimg, sqrtimg]))
    return
def preprocess_img(x):
    x = tfs.ToTensor()(x)   # PIL image -> tensor in [0, 1]
    return (x - 0.5) / 0.5  # rescale to [-1, 1], matching the generator's tanh output

def deprocess_img(x):
    return (x + 1.0) / 2.0  # map [-1, 1] back to [0, 1] for display
class ChunkSampler(sampler.Sampler):  # sampler over a fixed contiguous slice of the dataset
    """Samples elements sequentially from some offset.
    Arguments:
        num_samples: # of desired datapoints
        start: offset where we should start selecting from
    """
    def __init__(self, num_samples, start=0):
        self.num_samples = num_samples
        self.start = start

    def __iter__(self):
        return iter(range(self.start, self.start + self.num_samples))

    def __len__(self):
        return self.num_samples
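As a quick illustration (a hypothetical snippet, not part of the original code): iterating the sampler simply yields consecutive dataset indices starting at `start`, which is how the train/validation split below carves two disjoint slices out of the same MNIST training set.

# hypothetical check: a sampler covering indices 50000..50004
print(list(ChunkSampler(5, start=50000)))  # [50000, 50001, 50002, 50003, 50004]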
NUM_TRAIN = 50000
NUM_VAL = 5000
NOISE_DIM = 96
batch_size = 128

# download=True fetches MNIST on first use; replace 'E:/data' with your own data directory
train_set = MNIST('E:/data', train=True, download=True, transform=preprocess_img)
train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0))
val_set = MNIST('E:/data', train=True, download=True, transform=preprocess_img)
val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN))

# visualize one batch of real images
imgs = deprocess_img(next(iter(train_data))[0].view(batch_size, 784)).numpy().squeeze()
show_images(imgs)
plt.show()
# Discriminator network
def discriminator():
    net = nn.Sequential(
        nn.Linear(784, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 256),
        nn.LeakyReLU(0.2),
        nn.Linear(256, 1)  # raw logit; the sigmoid lives inside BCEWithLogitsLoss
    )
    return net

# Generator network
def generator(noise_dim=NOISE_DIM):
    net = nn.Sequential(
        nn.Linear(noise_dim, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 1024),
        nn.ReLU(True),
        nn.Linear(1024, 784),
        nn.Tanh()  # outputs in [-1, 1], matching the preprocessed images
    )
    return net
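Before wiring these into the training loop, it can help to confirm the shapes line up. This is a hypothetical sanity check (not in the original post) using fresh, untrained networks: a batch of noise should map to 784-dimensional flattened "images", and the discriminator should return one logit per example.

z = torch.rand(4, NOISE_DIM)         # 4 noise vectors
fake = generator()(z)                # fresh, untrained generator
print(fake.shape)                    # torch.Size([4, 784])
print(discriminator()(fake).shape)   # torch.Size([4, 1])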
# The discriminator's loss pushes real data towards a score of 1 and fake data towards 0,
# while the generator's loss pushes fake data towards a score of 1
bce_loss = nn.BCEWithLogitsLoss()  # sigmoid + binary cross-entropy in one numerically stable op

def discriminator_loss(logits_real, logits_fake):  # discriminator loss
    size = logits_real.shape[0]
    true_labels = torch.ones(size, 1).float()
    false_labels = torch.zeros(size, 1).float()
    loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels)
    return loss

def generator_loss(logits_fake):  # generator loss
    size = logits_fake.shape[0]
    true_labels = torch.ones(size, 1).float()
    loss = bce_loss(logits_fake, true_labels)
    return loss
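A quick hypothetical check of the label convention described above: a confident, correct logit scored against its target gives a loss near zero, while a confident, wrong one gives a large loss.

# hypothetical check of the label convention
print(bce_loss(torch.tensor([[10.0]]), torch.ones(1, 1)).item())   # ~0.0000454 (confident and correct)
print(bce_loss(torch.tensor([[-10.0]]), torch.ones(1, 1)).item())  # ~10.0 (confident and wrong)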
# train with Adam: learning rate 3e-4, beta1 = 0.5, beta2 = 0.999
def get_optimizer(net):
    optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999))
    return optimizer
def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss,
                show_every=250, noise_size=96, num_epochs=10):
    iter_count = 0
    for epoch in range(num_epochs):
        for x, _ in train_data:
            bs = x.shape[0]
            # train the discriminator
            real_data = x.view(bs, -1)                               # real images, flattened
            logits_real = D_net(real_data)                           # discriminator scores on real data
            sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5  # uniform noise in [-1, 1)
            fake_images = G_net(sample_noise)                        # generated (fake) images
            logits_fake = D_net(fake_images)                         # discriminator scores on fake data
            d_total_error = discriminator_loss(logits_real, logits_fake)
            D_optimizer.zero_grad()
            d_total_error.backward()
            D_optimizer.step()                                       # update the discriminator
            # train the generator
            fake_images = G_net(sample_noise)                        # regenerate so the graph is fresh
            gen_logits_fake = D_net(fake_images)
            g_error = generator_loss(gen_logits_fake)                # generator loss
            G_optimizer.zero_grad()
            g_error.backward()
            G_optimizer.step()                                       # update the generator
            if iter_count % show_every == 0:
                print('Iter: {}, D: {:.4}, G: {:.4}'.format(iter_count, d_total_error.item(), g_error.item()))
                imgs_numpy = deprocess_img(fake_images.data.cpu().numpy())
                show_images(imgs_numpy[0:16])
                plt.show()
                print()
            iter_count += 1
D = discriminator()
G = generator()
D_optim = get_optimizer(D)
G_optim = get_optimizer(G)
train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)
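Once training finishes, new digits can be sampled straight from the generator. The closing snippet below is a hedged sketch (not part of the original post), reusing `deprocess_img` and `show_images` from above and drawing noise from the same distribution used during training.

# hypothetical post-training sampling: draw noise, generate, and display 16 digits
with torch.no_grad():
    z = (torch.rand(16, NOISE_DIM) - 0.5) / 0.5  # uniform noise in [-1, 1)
    samples = G(z)
show_images(deprocess_img(samples.numpy()))
plt.show()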
That is the whole of this simple PyTorch GAN example on the MNIST dataset. I hope it serves as a useful reference, and please keep supporting 腳本之家.