pytorch:實(shí)現(xiàn)簡(jiǎn)單的GAN示例(MNIST數(shù)據(jù)集)
我就廢話不多說(shuō)了,直接上代碼吧!
# -*- coding: utf-8 -*- """ Created on Sat Oct 13 10:22:45 2018 @author: www """ import torch from torch import nn from torch.autograd import Variable import torchvision.transforms as tfs from torch.utils.data import DataLoader, sampler from torchvision.datasets import MNIST import numpy as np import matplotlib.pyplot as plt import matplotlib.gridspec as gridspec plt.rcParams['figure.figsize'] = (10.0, 8.0) # 設(shè)置畫圖的尺寸 plt.rcParams['image.interpolation'] = 'nearest' plt.rcParams['image.cmap'] = 'gray' def show_images(images): # 定義畫圖工具 images = np.reshape(images, [images.shape[0], -1]) sqrtn = int(np.ceil(np.sqrt(images.shape[0]))) sqrtimg = int(np.ceil(np.sqrt(images.shape[1]))) fig = plt.figure(figsize=(sqrtn, sqrtn)) gs = gridspec.GridSpec(sqrtn, sqrtn) gs.update(wspace=0.05, hspace=0.05) for i, img in enumerate(images): ax = plt.subplot(gs[i]) plt.axis('off') ax.set_xticklabels([]) ax.set_yticklabels([]) ax.set_aspect('equal') plt.imshow(img.reshape([sqrtimg,sqrtimg])) return def preprocess_img(x): x = tfs.ToTensor()(x) return (x - 0.5) / 0.5 def deprocess_img(x): return (x + 1.0) / 2.0 class ChunkSampler(sampler.Sampler): # 定義一個(gè)取樣的函數(shù) """Samples elements sequentially from some offset. Arguments: num_samples: # of desired datapoints start: offset where we should start selecting from """ def __init__(self, num_samples, start=0): self.num_samples = num_samples self.start = start def __iter__(self): return iter(range(self.start, self.start + self.num_samples)) def __len__(self): return self.num_samples NUM_TRAIN = 50000 NUM_VAL = 5000 NOISE_DIM = 96 batch_size = 128 train_set = MNIST('E:/data', train=True, transform=preprocess_img) train_data = DataLoader(train_set, batch_size=batch_size, sampler=ChunkSampler(NUM_TRAIN, 0)) val_set = MNIST('E:/data', train=True, transform=preprocess_img) val_data = DataLoader(val_set, batch_size=batch_size, sampler=ChunkSampler(NUM_VAL, NUM_TRAIN)) imgs = deprocess_img(train_data.__iter__().next()[0].view(batch_size, 784)).numpy().squeeze() # 可視化圖片效果 show_images(imgs) #判別網(wǎng)絡(luò) def discriminator(): net = nn.Sequential( nn.Linear(784, 256), nn.LeakyReLU(0.2), nn.Linear(256, 256), nn.LeakyReLU(0.2), nn.Linear(256, 1) ) return net #生成網(wǎng)絡(luò) def generator(noise_dim=NOISE_DIM): net = nn.Sequential( nn.Linear(noise_dim, 1024), nn.ReLU(True), nn.Linear(1024, 1024), nn.ReLU(True), nn.Linear(1024, 784), nn.Tanh() ) return net #判別器的 loss 就是將真實(shí)數(shù)據(jù)的得分判斷為 1,假的數(shù)據(jù)的得分判斷為 0,而生成器的 loss 就是將假的數(shù)據(jù)判斷為 1 bce_loss = nn.BCEWithLogitsLoss()#交叉熵?fù)p失函數(shù) def discriminator_loss(logits_real, logits_fake): # 判別器的 loss size = logits_real.shape[0] true_labels = Variable(torch.ones(size, 1)).float() false_labels = Variable(torch.zeros(size, 1)).float() loss = bce_loss(logits_real, true_labels) + bce_loss(logits_fake, false_labels) return loss def generator_loss(logits_fake): # 生成器的 loss size = logits_fake.shape[0] true_labels = Variable(torch.ones(size, 1)).float() loss = bce_loss(logits_fake, true_labels) return loss # 使用 adam 來(lái)進(jìn)行訓(xùn)練,學(xué)習(xí)率是 3e-4, beta1 是 0.5, beta2 是 0.999 def get_optimizer(net): optimizer = torch.optim.Adam(net.parameters(), lr=3e-4, betas=(0.5, 0.999)) return optimizer def train_a_gan(D_net, G_net, D_optimizer, G_optimizer, discriminator_loss, generator_loss, show_every=250, noise_size=96, num_epochs=10): iter_count = 0 for epoch in range(num_epochs): for x, _ in train_data: bs = x.shape[0] # 判別網(wǎng)絡(luò) real_data = Variable(x).view(bs, -1) # 真實(shí)數(shù)據(jù) logits_real = D_net(real_data) # 判別網(wǎng)絡(luò)得分 sample_noise = (torch.rand(bs, noise_size) - 0.5) / 0.5 # -1 ~ 1 的均勻分布 g_fake_seed = Variable(sample_noise) fake_images = G_net(g_fake_seed) # 生成的假的數(shù)據(jù) logits_fake = D_net(fake_images) # 判別網(wǎng)絡(luò)得分 d_total_error = discriminator_loss(logits_real, logits_fake) # 判別器的 loss D_optimizer.zero_grad() d_total_error.backward() D_optimizer.step() # 優(yōu)化判別網(wǎng)絡(luò) # 生成網(wǎng)絡(luò) g_fake_seed = Variable(sample_noise) fake_images = G_net(g_fake_seed) # 生成的假的數(shù)據(jù) gen_logits_fake = D_net(fake_images) g_error = generator_loss(gen_logits_fake) # 生成網(wǎng)絡(luò)的 loss G_optimizer.zero_grad() g_error.backward() G_optimizer.step() # 優(yōu)化生成網(wǎng)絡(luò) if (iter_count % show_every == 0): print('Iter: {}, D: {:.4}, G:{:.4}'.format(iter_count, d_total_error.item(), g_error.item())) imgs_numpy = deprocess_img(fake_images.data.cpu().numpy()) show_images(imgs_numpy[0:16]) plt.show() print() iter_count += 1 D = discriminator() G = generator() D_optim = get_optimizer(D) G_optim = get_optimizer(G) train_a_gan(D, G, D_optim, G_optim, discriminator_loss, generator_loss)
以上這篇pytorch:實(shí)現(xiàn)簡(jiǎn)單的GAN示例(MNIST數(shù)據(jù)集)就是小編分享給大家的全部?jī)?nèi)容了,希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
- pytorch實(shí)現(xiàn)mnist數(shù)據(jù)集的圖像可視化及保存
- 關(guān)于Pytorch的MNIST數(shù)據(jù)集的預(yù)處理詳解
- 使用 PyTorch 實(shí)現(xiàn) MLP 并在 MNIST 數(shù)據(jù)集上驗(yàn)證方式
- 用Pytorch訓(xùn)練CNN(數(shù)據(jù)集MNIST,使用GPU的方法)
- 詳解PyTorch手寫數(shù)字識(shí)別(MNIST數(shù)據(jù)集)
- pytorch 把MNIST數(shù)據(jù)集轉(zhuǎn)換成圖片和txt的方法
- Python PyTorch 如何獲取 MNIST 數(shù)據(jù)
相關(guān)文章
Python PyQt5標(biāo)準(zhǔn)對(duì)話框用法示例
這篇文章主要介紹了Python PyQt5標(biāo)準(zhǔn)對(duì)話框用法,結(jié)合實(shí)例形式分析了PyQt5常用的標(biāo)準(zhǔn)對(duì)話框及相關(guān)使用技巧,需要的朋友可以參考下2017-08-08python之線程通過(guò)信號(hào)pyqtSignal刷新ui的方法
今天小編就為大家分享一篇python之線程通過(guò)信號(hào)pyqtSignal刷新ui的方法,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。一起跟隨小編過(guò)來(lái)看看吧2019-01-01淺談對(duì)pytroch中torch.autograd.backward的思考
這篇文章主要介紹了對(duì)pytroch中torch.autograd.backward的思考,文中通過(guò)示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來(lái)一起學(xué)習(xí)學(xué)習(xí)吧2019-12-12python中import和from-import的區(qū)別解析
這篇文章主要介紹了python中import和from-import的區(qū)別解析,本文通過(guò)實(shí)例代碼給大家講解的非常詳細(xì),對(duì)大家的學(xué)習(xí)或工作具有一定的參考借鑒價(jià)值,需要的朋友可以參考下2022-12-12Python使用虛擬環(huán)境(安裝下載更新卸載)命令
這篇文章主要為大家介紹了Python使用虛擬環(huán)境(安裝下載更新卸載)命令,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪2023-11-11Python實(shí)現(xiàn)拉格朗日插值法的示例詳解
插值法是一種數(shù)學(xué)方法,用于在已知數(shù)據(jù)點(diǎn)(離散數(shù)據(jù))之間插入數(shù)據(jù),以生成連續(xù)的函數(shù)曲線,而格朗日插值法是一種多項(xiàng)式插值法。本文就來(lái)用Python實(shí)現(xiàn)拉格朗日插值法,希望對(duì)大家有所幫助2023-02-02python應(yīng)用之如何使用Python發(fā)送通知到微信
現(xiàn)在通過(guò)發(fā)微信信息來(lái)做消息通知和告警已經(jīng)很普遍了,下面這篇文章主要給大家介紹了關(guān)于python應(yīng)用之如何使用Python發(fā)送通知到微信的相關(guān)資料,文中通過(guò)實(shí)例代碼介紹的非常詳細(xì),需要的朋友可以參考下2022-03-03