pytorch 實(shí)現(xiàn)變分自動(dòng)編碼器的操作
本來以為自動(dòng)編碼器是很簡單的東西,但是也是看了好多資料仍然不太懂它的原理。先把代碼記錄下來,有時(shí)間好好研究。
這個(gè)例子是用MNIST數(shù)據(jù)集生成為例子
# -*- coding: utf-8 -*- """ Created on Fri Oct 12 11:42:19 2018 @author: www """ import os import torch from torch.autograd import Variable import torch.nn.functional as F from torch import nn from torch.utils.data import DataLoader from torchvision.datasets import MNIST from torchvision import transforms as tfs from torchvision.utils import save_image im_tfs = tfs.Compose([ tfs.ToTensor(), tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 標(biāo)準(zhǔn)化 ]) train_set = MNIST('E:\data', transform=im_tfs) train_data = DataLoader(train_set, batch_size=128, shuffle=True) class VAE(nn.Module): def __init__(self): super(VAE, self).__init__() self.fc1 = nn.Linear(784, 400) self.fc21 = nn.Linear(400, 20) # mean self.fc22 = nn.Linear(400, 20) # var self.fc3 = nn.Linear(20, 400) self.fc4 = nn.Linear(400, 784) def encode(self, x): h1 = F.relu(self.fc1(x)) return self.fc21(h1), self.fc22(h1) def reparametrize(self, mu, logvar): std = logvar.mul(0.5).exp_() eps = torch.FloatTensor(std.size()).normal_() if torch.cuda.is_available(): eps = Variable(eps.cuda()) else: eps = Variable(eps) return eps.mul(std).add_(mu) def decode(self, z): h3 = F.relu(self.fc3(z)) return F.tanh(self.fc4(h3)) def forward(self, x): mu, logvar = self.encode(x) # 編碼 z = self.reparametrize(mu, logvar) # 重新參數(shù)化成正態(tài)分布 return self.decode(z), mu, logvar # 解碼,同時(shí)輸出均值方差 net = VAE() # 實(shí)例化網(wǎng)絡(luò) if torch.cuda.is_available(): net = net.cuda() x, _ = train_set[0] x = x.view(x.shape[0], -1) if torch.cuda.is_available(): x = x.cuda() x = Variable(x) _, mu, var = net(x) print(mu) #可以看到,對(duì)于輸入,網(wǎng)絡(luò)可以輸出隱含變量的均值和方差,這里的均值方差還沒有訓(xùn)練 #下面開始訓(xùn)練 reconstruction_function = nn.MSELoss(size_average=False) def loss_function(recon_x, x, mu, logvar): """ recon_x: generating images x: origin images mu: latent mean logvar: latent log variance """ MSE = reconstruction_function(recon_x, x) # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) KLD = torch.sum(KLD_element).mul_(-0.5) # KL divergence return MSE + KLD optimizer = torch.optim.Adam(net.parameters(), lr=1e-3) def to_img(x): ''' 定義一個(gè)函數(shù)將最后的結(jié)果轉(zhuǎn)換回圖片 ''' x = 0.5 * (x + 1.) x = x.clamp(0, 1) x = x.view(x.shape[0], 1, 28, 28) return x for e in range(100): for im, _ in train_data: im = im.view(im.shape[0], -1) im = Variable(im) if torch.cuda.is_available(): im = im.cuda() recon_im, mu, logvar = net(im) loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 將 loss 平均 optimizer.zero_grad() loss.backward() optimizer.step() if (e + 1) % 20 == 0: print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.item())) save = to_img(recon_im.cpu().data) if not os.path.exists('./vae_img'): os.mkdir('./vae_img') save_image(save, './vae_img/image_{}.png'.format(e + 1))
補(bǔ)充:PyTorch 深度學(xué)習(xí)快速入門——變分自動(dòng)編碼器
變分編碼器是自動(dòng)編碼器的升級(jí)版本,其結(jié)構(gòu)跟自動(dòng)編碼器是類似的,也由編碼器和解碼器構(gòu)成。
回憶一下,自動(dòng)編碼器有個(gè)問題,就是并不能任意生成圖片,因?yàn)槲覀儧]有辦法自己去構(gòu)造隱藏向量,需要通過一張圖片輸入編碼我們才知道得到的隱含向量是什么,這時(shí)我們就可以通過變分自動(dòng)編碼器來解決這個(gè)問題。
其實(shí)原理特別簡單,只需要在編碼過程給它增加一些限制,迫使其生成的隱含向量能夠粗略的遵循一個(gè)標(biāo)準(zhǔn)正態(tài)分布,這就是其與一般的自動(dòng)編碼器最大的不同。
這樣我們生成一張新圖片就很簡單了,我們只需要給它一個(gè)標(biāo)準(zhǔn)正態(tài)分布的隨機(jī)隱含向量,這樣通過解碼器就能夠生成我們想要的圖片,而不需要給它一張?jiān)紙D片先編碼。
一般來講,我們通過 encoder 得到的隱含向量并不是一個(gè)標(biāo)準(zhǔn)的正態(tài)分布,為了衡量兩種分布的相似程度,我們使用 KL divergence,利用其來表示隱含向量與標(biāo)準(zhǔn)正態(tài)分布之間差異的 loss,另外一個(gè) loss 仍然使用生成圖片與原圖片的均方誤差來表示。
KL divergence 的公式如下
重參數(shù) 為了避免計(jì)算 KL divergence 中的積分,我們使用重參數(shù)的技巧,不是每次產(chǎn)生一個(gè)隱含向量,而是生成兩個(gè)向量,一個(gè)表示均值,一個(gè)表示標(biāo)準(zhǔn)差,這里我們默認(rèn)編碼之后的隱含向量服從一個(gè)正態(tài)分布的之后,就可以用一個(gè)標(biāo)準(zhǔn)正態(tài)分布先乘上標(biāo)準(zhǔn)差再加上均值來合成這個(gè)正態(tài)分布,最后 loss 就是希望這個(gè)生成的正態(tài)分布能夠符合一個(gè)標(biāo)準(zhǔn)正態(tài)分布,也就是希望均值為 0,方差為 1
所以最后我們可以將我們的 loss 定義為下面的函數(shù),由均方誤差和 KL divergence 求和得到一個(gè)總的 loss
def loss_function(recon_x, x, mu, logvar): """ recon_x: generating images x: origin images mu: latent mean logvar: latent log variance """ MSE = reconstruction_function(recon_x, x) # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) KLD = torch.sum(KLD_element).mul_(-0.5) # KL divergence return MSE + KLD
用 mnist 數(shù)據(jù)集來簡單說明一下變分自動(dòng)編碼器
import os import torch from torch.autograd import Variable import torch.nn.functional as F from torch import nn from torch.utils.data import DataLoader from torchvision.datasets import MNIST from torchvision import transforms as tfs from torchvision.utils import save_image im_tfs = tfs.Compose([ tfs.ToTensor(), tfs.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]) # 標(biāo)準(zhǔn)化 ]) train_set = MNIST('./mnist', transform=im_tfs) train_data = DataLoader(train_set, batch_size=128, shuffle=True) class VAE(nn.Module): def __init__(self): super(VAE, self).__init__() self.fc1 = nn.Linear(784, 400) self.fc21 = nn.Linear(400, 20) # mean self.fc22 = nn.Linear(400, 20) # var self.fc3 = nn.Linear(20, 400) self.fc4 = nn.Linear(400, 784) def encode(self, x): h1 = F.relu(self.fc1(x)) return self.fc21(h1), self.fc22(h1) def reparametrize(self, mu, logvar): std = logvar.mul(0.5).exp_() eps = torch.FloatTensor(std.size()).normal_() if torch.cuda.is_available(): eps = Variable(eps.cuda()) else: eps = Variable(eps) return eps.mul(std).add_(mu) def decode(self, z): h3 = F.relu(self.fc3(z)) return F.tanh(self.fc4(h3)) def forward(self, x): mu, logvar = self.encode(x) # 編碼 z = self.reparametrize(mu, logvar) # 重新參數(shù)化成正態(tài)分布 return self.decode(z), mu, logvar # 解碼,同時(shí)輸出均值方差 net = VAE() # 實(shí)例化網(wǎng)絡(luò) if torch.cuda.is_available(): net = net.cuda() x, _ = train_set[0] x = x.view(x.shape[0], -1) if torch.cuda.is_available(): x = x.cuda() x = Variable(x) _, mu, var = net(x) print(mu) Variable containing: Columns 0 to 9 -0.0307 -0.1439 -0.0435 0.3472 0.0368 -0.0339 0.0274 -0.5608 0.0280 0.2742 Columns 10 to 19 -0.6221 -0.0894 -0.0933 0.4241 0.1611 0.3267 0.5755 -0.0237 0.2714 -0.2806 [torch.cuda.FloatTensor of size 1x20 (GPU 0)]
可以看到,對(duì)于輸入,網(wǎng)絡(luò)可以輸出隱含變量的均值和方差,這里的均值方差還沒有訓(xùn)練 下面開始訓(xùn)練
reconstruction_function = nn.MSELoss(size_average=False) def loss_function(recon_x, x, mu, logvar): """ recon_x: generating images x: origin images mu: latent mean logvar: latent log variance """ MSE = reconstruction_function(recon_x, x) # loss = 0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) KLD_element = mu.pow(2).add_(logvar.exp()).mul_(-1).add_(1).add_(logvar) KLD = torch.sum(KLD_element).mul_(-0.5) # KL divergence return MSE + KLD optimizer = torch.optim.Adam(net.parameters(), lr=1e-3) def to_img(x): ''' 定義一個(gè)函數(shù)將最后的結(jié)果轉(zhuǎn)換回圖片 ''' x = 0.5 * (x + 1.) x = x.clamp(0, 1) x = x.view(x.shape[0], 1, 28, 28) return x for e in range(100): for im, _ in train_data: im = im.view(im.shape[0], -1) im = Variable(im) if torch.cuda.is_available(): im = im.cuda() recon_im, mu, logvar = net(im) loss = loss_function(recon_im, im, mu, logvar) / im.shape[0] # 將 loss 平均 optimizer.zero_grad() loss.backward() optimizer.step() if (e + 1) % 20 == 0: print('epoch: {}, Loss: {:.4f}'.format(e + 1, loss.data[0])) save = to_img(recon_im.cpu().data) if not os.path.exists('./vae_img'): os.mkdir('./vae_img') save_image(save, './vae_img/image_{}.png'.format(e + 1)) epoch: 20, Loss: 61.5803 epoch: 40, Loss: 62.9573 epoch: 60, Loss: 63.4285 epoch: 80, Loss: 64.7138 epoch: 100, Loss: 63.3343
變分自動(dòng)編碼器雖然比一般的自動(dòng)編碼器效果要好,而且也限制了其輸出的編碼 (code) 的概率分布,但是它仍然是通過直接計(jì)算生成圖片和原始圖片的均方誤差來生成 loss,這個(gè)方式并不好,生成對(duì)抗網(wǎng)絡(luò)中,我們會(huì)講一講這種方式計(jì)算 loss 的局限性,然后會(huì)介紹一種新的訓(xùn)練辦法,就是通過生成對(duì)抗的訓(xùn)練方式來訓(xùn)練網(wǎng)絡(luò)而不是直接比較兩張圖片的每個(gè)像素點(diǎn)的均方誤差
以上為個(gè)人經(jīng)驗(yàn),希望能給大家一個(gè)參考,也希望大家多多支持腳本之家。
相關(guān)文章
python 中的9個(gè)實(shí)用技巧,助你提高開發(fā)效率
這篇文章主要介紹了python 中的9個(gè)實(shí)用技巧,幫助大家提高python開發(fā)時(shí)的效率,感興趣的朋友可以了解下2020-08-08numpy中幾種隨機(jī)數(shù)生成函數(shù)的用法
numpy是Python中常用的科學(xué)計(jì)算庫,其中也包含了一些隨機(jī)數(shù)生成函數(shù),本文主要介紹了numpy中幾種隨機(jī)數(shù)生成函數(shù)的用法,具有一定的參考價(jià)值,感興趣的可以了解一下2023-11-11python刪除文件、清空目錄的實(shí)現(xiàn)方法
這篇文章主要介紹了python刪除文件、清空目錄的實(shí)現(xiàn)方法,文中通過示例代碼介紹的非常詳細(xì),對(duì)大家的學(xué)習(xí)或者工作具有一定的參考學(xué)習(xí)價(jià)值,需要的朋友們下面隨著小編來一起學(xué)習(xí)學(xué)習(xí)吧2020-09-09Pandas執(zhí)行SQL操作的實(shí)現(xiàn)
使用SQL語句能夠完成對(duì)table的增刪改查操作,Pandas同樣也可以實(shí)現(xiàn)SQL語句的基本功能,本文就來介紹一下,具有一檔的參考價(jià)值,感興趣的可以了解一下2024-07-07Python算法練習(xí)之二分查找算法的實(shí)現(xiàn)
二分查找也稱折半查找(Binary Search),它是一種效率較高的查找方法。本文將介紹python如何實(shí)現(xiàn)二分查找算法,幫助大家更好的理解和使用python,感興趣的朋友可以了解下2022-06-06pytorch實(shí)現(xiàn)下載加載mnist數(shù)據(jù)集
這篇文章主要介紹了pytorch實(shí)現(xiàn)下載加載mnist數(shù)據(jù)集方式,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教2024-06-06