Pytorch復(fù)現(xiàn)擴散模型的示例詳解
開發(fā)環(huán)境
集成開發(fā)工具:jupyter notebook 6.5.2
集成開發(fā)環(huán)境:Python 3.10.6
第三方庫:torch、matplotlib、sklearn、numpy
1 加載相關(guān)第三方庫
# 使得在 notebook 中顯示繪圖,而不是在外部窗口中顯示 %matplotlib inline import matplotlib.pyplot as plt import numpy as np from sklearn.datasets import make_s_curve import torch import torch.nn as nn import io from PIL import Image
2 加載數(shù)據(jù)集
這里選擇S 形曲線數(shù)據(jù)集作為本次復(fù)現(xiàn)擴散模型所用數(shù)據(jù)集。
s_curve, _ = make_s_curve(10 ** 4, noise=0.1)
將數(shù)據(jù)集中的特征縮放到一個相對較小的范圍內(nèi),以便于模型的訓練和收斂。這樣做可以避免數(shù)據(jù)的特征值之間差異過大,導(dǎo)致某些特征對模型的影響過大,而其他特征的影響被忽略的情況。同時,將數(shù)據(jù)的特征縮放到一個相對較小的范圍內(nèi),也有助于提高模型的泛化能力,使其能夠更好地適應(yīng)新的未知數(shù)據(jù)。
s_curve = s_curve[:, [0, 2]] / 10. print(F"shape of Moons:{np.shape(s_curve)}")
將數(shù)據(jù)集從原來的 (10000, 2) 轉(zhuǎn)換為 (2, 10000),即每一列對應(yīng)一個樣本的所有特征值,這樣的形狀更適合一些深度學習框架的輸入格式。同時還可以保持數(shù)據(jù)的連續(xù)性:對數(shù)據(jù)進行轉(zhuǎn)置操作可以保持數(shù)據(jù)之間的連續(xù)性。在某些機器學習算法或深度學習框架中,連續(xù)的數(shù)據(jù)在內(nèi)存中存儲更加緊湊,可以更快地讀取和處理數(shù)據(jù),從而提高模型的訓練和預(yù)測效率。
data = s_curve.T # 繪制 S 形曲線數(shù)據(jù)集 fig, ax = plt.subplots() ax.scatter(*data, color='red', edgecolor='white') ax.axis('off')
因為在深度學習中,通常使用 PyTorch 等深度學習框架來實現(xiàn)模型的訓練和預(yù)測。而 PyTorch 中的數(shù)據(jù)處理對象是張量(Tensor),因此我們需要將原始數(shù)據(jù)集轉(zhuǎn)換為張量對象才能進行后續(xù)的深度學習模型的訓練和預(yù)測。另外,由于深度學習模型通常需要浮點數(shù)類型的數(shù)據(jù)作為輸入,因此我們需要使用 float() 將張量的數(shù)據(jù)類型設(shè)置為浮點型。這樣做可以保證輸入數(shù)據(jù)類型的一致性,避免數(shù)據(jù)類型不匹配導(dǎo)致的錯誤。
dataset = torch.Tensor(s_curve).float()
3 確定超參數(shù)的值
首先,指定步數(shù)(num_step),這個步數(shù)可以根據(jù) beta、分布的均值和標準差來共同確定。num_step 指定了擴散模型的最終狀態(tài)的計算次數(shù),每一次計算對應(yīng)一個 beta 值。
接著,使用 torch.linspace() 函數(shù)生成一個等間隔的 num_step 個 beta 值。然后,通過對這些 beta 值執(zhí)行 sigmoid 激活函數(shù)以及線性變換,將它們轉(zhuǎn)換為介于 1e-5 到 0.5e-2 之間的浮點數(shù)。這些 beta 值將在后續(xù)計算中用于計算擴散模型的每一步的參數(shù)。
接下來,計算一些中間變量,包括 alphas、alphas_prod、alphas_prod_p、alphas_bar_sqrt、one_minus_alphas_bar_log 和 one_minus_alphas_bar_sqrt。其中,alphas 表示每一步的 alpha 值,alphas_prod 表示前 t 步的 alpha 值的累積乘積,alphas_prod_p 表示前 t-1 步的 alpha 值的累積乘積,alphas_bar_sqrt 表示前 t 步的 alpha 值的累積乘積的平方根,one_minus_alphas_bar_log 表示前 t 步的 alpha 值的累積乘積的對數(shù)的負值,one_minus_alphas_bar_sqrt 表示前 t 步的 alpha 值的累積乘積的差值的平方根。
最后,使用 assert 命令檢查計算的所有變量的形狀是否相同,并打印出 betas 變量的形狀。
num_step = 100 # 一開始可以由beta、分布的均值和標準差來共同確定 # 指定每一步的beta betas = torch.linspace(-6, 6, num_step) betas = torch.sigmoid(betas) * (0.5e-2 - 1e-5) + 1e-5 # 計算alpha、alpha_prod、alpha_prod_previous、alpha_bar_sqrt等變量的值 alphas = 1 - betas alphas_prod = torch.cumprod(alphas, dim=0) alphas_prod_p = torch.cat([torch.tensor([1]).float(), alphas_prod[:-1]], 0) # p表示previous alphas_bar_sqrt = torch.sqrt(alphas_prod) one_minus_alphas_bar_log = torch.log(1 - alphas_prod) one_minus_alphas_bar_sqrt = torch.sqrt(1 - alphas_prod) assert alphas.shape == alphas_prod.shape == alphas_prod_p.shape == alphas_bar_sqrt.shape == one_minus_alphas_bar_log.shape == one_minus_alphas_bar_sqrt.shape print(f"all the same shape:{betas.shape}")
4 確定擴散過程任意時刻的采樣值
首先從正態(tài)分布中生成隨機噪聲。然后,根據(jù)參數(shù)重整化技巧,使用預(yù)先計算好的 alpha_bar_sqrt 和 one_minus_alphas_bar_sqrt,將初始值 x_0 進行變換,得到時刻 t 的采樣值。最后,將噪聲加入到采樣值中,得到最終的采樣值。
# 計算任意時刻的x的采樣值,基于x_0和參數(shù)重整化技巧 def q_x(x_0, t): """可以基于x[0]得到任意時刻t的x[t]""" noise = torch.randn_like(x_0) # noise是從正態(tài)分布中生成的隨機噪聲 alphas_t = alphas_bar_sqrt[t] alphas_l_m_t = one_minus_alphas_bar_sqrt[t] return (alphas_t * x_0 + alphas_l_m_t * noise) # 在x[0]的基礎(chǔ)上添加噪聲
5 演示原始數(shù)據(jù)分布加噪100步后的效果
生成樣本點隨時間變化的演化過程圖。生成一個大小為2x10的子圖網(wǎng)格,每個子圖顯示了原始S曲線數(shù)據(jù)集在經(jīng)過噪聲添加和擴散操作后在某個時間點t時的圖像。其中,num_shows變量指定了要顯示的時間步數(shù),這里為20,因此總共會顯示20張子圖。在每個子圖中,使用q_x函數(shù)對原始數(shù)據(jù)集進行噪聲添加和擴散操作,得到對應(yīng)時間點t時的新數(shù)據(jù)集,然后在子圖中以紅色散點圖的形式繪制出來。每個子圖的標題顯示了該子圖所對應(yīng)的時間步t。
num_shows = 20 fig, axs = plt.subplots(2, 10, figsize=(28, 7)) plt.rc('text', color='blue') # 共有10000個點,每個點包含兩個坐標 # 生成100步以內(nèi)每隔5步加噪聲后的圖像 for i in range(num_shows): j = i // 10 k = i % 10 q_i = q_x(dataset, torch.tensor([i * num_step // num_shows])) # 生成t時刻的采樣數(shù)據(jù) axs[j, k].scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white') axs[j, k].set_axis_off() axs[j, k].set_title('$q(\mathbf{x}_{'+ str(i * num_step // num_shows) + '})$')
6 編寫擬合逆擴散過程高斯分布的模型
在輸入的基礎(chǔ)上添加一個時間步長 t,并對此進行嵌入。具體來說,它使用了 3 個 nn.Embedding
層,分別對應(yīng)于嵌入 t 的 3 個維度。
模型的 forward
方法接受一個輸入 x 和一個時間步長 t,并返回輸出 y。在 forward
方法中,輸入 x 會經(jīng)過一系列的全連接層(使用 nn.Linear
實現(xiàn)),其中每兩個全連接層之間都有一個 ReLU 激活函數(shù)。在這些全連接層之前和之后,模型都會使用 nn.Embedding
層將 t 嵌入到向量中。最終的輸出 y 是一個 2 維向量。
class MLPDiffusion(nn.Module): def __init__(self, n_steps, num_units=128): super(MLPDiffusion, self).__init__() self.linears = nn.ModuleList( [ nn.Linear(2, num_units), nn.ReLU(), nn.Linear(num_units, num_units), nn.ReLU(), nn.Linear(num_units, num_units), nn.ReLU(), nn.Linear(num_units, 2), ]) self.step_embeddings = nn.ModuleList( [ nn.Embedding(n_steps, num_units), nn.Embedding(n_steps, num_units), nn.Embedding(n_steps, num_units), ]) def forward(self, x, t): for idx, embedding_layer in enumerate(self.step_embeddings): t_embedding = embedding_layer(t) x = self.linears[2 * idx](x) x += t_embedding x = self.linears[2 * idx + 1](x) x = self.linears[-1](x) return x
7 編寫訓練的誤差函數(shù)
def diffusion_loss_fn(model, x_0, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, n_steps): """ 對任意時刻t進行采樣計算loss param: model:模型 x_0:初始狀態(tài) alphas_bar_sqrt、one_minus_alphas_bar_sqrt: 參數(shù) n_steps:時間步數(shù) return:損失值 """ batch_size = x_0.shape[0] # 隨機采樣一個時刻t,為了提高訓練效率,這里確保t不重復(fù) # 對一個batchsize樣本生成隨機的時刻t,覆蓋到更多不同的t t = torch.randint(0, n_steps, size=(batch_size // 2,)) t = torch.cat([t, n_steps - 1 - t], dim=0) # [batch] t = t.unsqueeze(-1) # [batch, 1] # x0的系數(shù) a = alphas_bar_sqrt[t] # eps的系數(shù) aml = one_minus_alphas_bar_sqrt[t] # 生成隨機噪聲eps e = torch.randn_like(x_0) # 構(gòu)造模型的輸入 x = x_0 * a + e * aml # 送入模型,得到t時刻的隨機噪聲預(yù)測值 output = model(x, t.squeeze(-1)) # 與真實噪聲一起計算誤差,求平均值 return (e - output).square().mean()
8 編寫逆擴散采樣函數(shù)(inference過程)
進行擴散模型的采樣。具體來說,p_sample_loop函數(shù)是從x[T]恢復(fù)x[T-1]、x[T-2]、...、x[0]的過程,其中x[T]是輸入的初始值。在這個函數(shù)里,使用了一個for循環(huán),從最后一個時刻T開始往前推,依次對每個時刻進行采樣。在每個時刻,調(diào)用p_sample函數(shù)進行采樣。
p_sample函數(shù)的主要作用是從x[T]采樣t時刻的重構(gòu)值,其中x[T]是輸入的初始值,t表示當前時刻。具體來說,首先通過模型預(yù)測出eps_theta,然后通過一些計算,得到該時刻的重構(gòu)值sample。其中,mean表示重構(gòu)值的均值,z是服從標準正態(tài)分布的噪聲,sigma_t是該時刻的標準差。最后,將sample作為當前時刻的重構(gòu)值返回。
def p_sample_loop(model, shape, n_step, betas, one_minus_alphas_bar_sqrt): """從x[T]恢復(fù)x[T - 1]、x[T - 2]、...、x[0]""" cur_x = torch.randn(shape) x_seq = [cur_x] for i in reversed(range(n_step)): cur_x = p_sample(model, cur_x, i, betas, one_minus_alphas_bar_sqrt) x_seq.append(cur_x) return x_seq def p_sample(model, x, t, betas, one_minus_alphas_bar_sqrt): """從x[T]采樣t時刻的重構(gòu)值""" t = torch.tensor([t]) coeff = betas[t] / one_minus_alphas_bar_sqrt[t] eps_theta = model(x, t) mean = (1 / (1 - betas[t]).sqrt()) * (x - (coeff * eps_theta)) z = torch.randn_like(x) sigma_t = betas[t].sqrt() sample = mean + sigma_t * z return (sample)
9 開始訓練模型,并打印loss及中間重構(gòu)效果
這段代碼定義了一個EMA(Exponential Moving Average,指數(shù)平滑移動平均)類,它用于對模型的參數(shù)進行平滑處理。構(gòu)造函數(shù)中的 mu 參數(shù)控制平滑程度,shadow 是一個字典,用于存儲參數(shù)的平滑后的值。
register 方法將參數(shù) val 注冊到 shadow 字典中,__call__方法對指定名稱的參數(shù) name 進行平滑處理。其中,x 是當前時刻參數(shù)的值。計算完成后,將結(jié)果存儲在 shadow 字典中,并返回平滑后的值。
seed = 1234 # 確保程序在每次運行時生成的隨機數(shù)序列都是一樣的 class EMA(): """構(gòu)建一個參數(shù)平滑器,以便更好地泛化模型并減少過擬合""" def __init__(self, mu=0.01): self.mu = mu self.shadow = {} def register(self, name, val): self.shadow[name] = val.clone() def __call__(self, name, x): assert name in self.shadow new_average = self.mu * x + (1.0 - self.mu) * self.shadow[name] return new_average print('Training model.....') batch_size = 512 # 批訓練大小 dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True) num_epoch = 4000 # 定義迭代4000次 plt.rc('text', color='blue') model = MLPDiffusion(num_step) # 輸出維度是2,輸入是x和step optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) for t in range(num_epoch): for idx, batch_x in enumerate(dataloader): loss = diffusion_loss_fn(model, batch_x, alphas_bar_sqrt, one_minus_alphas_bar_sqrt, num_step) optimizer.zero_grad() # 對梯度進行清零,防止網(wǎng)絡(luò)權(quán)重更新過于迅速或不穩(wěn)定,無法得到正確的收斂結(jié)果 loss.backward() nn.utils.clip_grad_norm_(model.parameters(), 1.) # 對梯度進行裁剪,避免出現(xiàn)梯度爆炸 optimizer.step() if (t % 100 == 0): print(loss) x_seq = p_sample_loop(model, dataset.shape, num_step, betas, one_minus_alphas_bar_sqrt) # 共有100個元素 fig, axs = plt.subplots(1, 5, figsize=(28, 7)) for i in range(1, 6): cur_x = x_seq[i * 20].detach() axs[i - 1].scatter(cur_x[:, 0], cur_x[:, 1], color='red', edgecolor='white') axs[i - 1].set_axis_off() axs[i - 1].set_title('$q(\mathbf{x}_{'+str(i * 20)+'})$')
部分效果圖:
10 動畫演示擴散過程和逆擴散過程
# 生成前向過程,也就是逐步加噪聲 imgs = [] for i in range(100): plt.clf() torch_i = q_x(dataset, torch.tensor([i])) plt.scatter(q_i[:, 0], q_i[:, 1], color='red', edgecolor='white', s=5) plt.axis('off') img_buf = io.BytesIO() plt.savefig(img_buf, format='png') img = Image.open(img_buf) imgs.append(img)
# 生成逆過程,也就是逐步復(fù)原 reverse = [] for i in range(100): plt.clf() cur_x = x_seq[i].detach() # 拿到訓練末尾階段生成的x_seq plt.scatter(cur_x[:, 0], cur_x[:, 1], color='red', edgecolor='white', s=5) plt.axis('off') img_buf = io.BytesIO() plt.savefig(img_buf, format='png') img = Image.open(img_buf) reverse.append(img)
imgs = imgs + reverse imgs[0].save("diffusion.gif", format='GIF', append_images=imgs, save_all=True, duration=100, loop=0)
動畫效果圖:
以上就是Pytorch復(fù)現(xiàn)擴散模型的示例詳解的詳細內(nèi)容,更多關(guān)于Pytorch擴散模型的資料請關(guān)注腳本之家其它相關(guān)文章!
相關(guān)文章
Python minidom模塊用法示例【DOM寫入和解析XML】
這篇文章主要介紹了Python minidom模塊用法,結(jié)合實例形式分析了Python DOM創(chuàng)建、寫入和解析XML文件相關(guān)操作技巧,需要的朋友可以參考下2019-03-03pytorch 如何實現(xiàn)HWC轉(zhuǎn)CHW
這篇文章主要介紹了pytorch HWC轉(zhuǎn)CHW的實現(xiàn)方式,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2021-05-05計算pytorch標準化(Normalize)所需要數(shù)據(jù)集的均值和方差實例
今天小編就為大家分享一篇計算pytorch標準化(Normalize)所需要數(shù)據(jù)集的均值和方差實例,具有很好的參考價值,希望對大家有所幫助。一起跟隨小編過來看看吧2020-01-01Python模塊psycopg2連接postgresql的實現(xiàn)
本文主要介紹了Python模塊psycopg2連接postgresql的實現(xiàn),文中通過示例代碼介紹的非常詳細,對大家的學習或者工作具有一定的參考學習價值,需要的朋友們下面隨著小編來一起學習學習吧2023-07-07