PyTorch使用Torchdyn實現(xiàn)連續(xù)時間神經(jīng)網(wǎng)絡的代碼示例
Torchdyn概述
Torchdyn是基于PyTorch構(gòu)建的專業(yè)庫,專注于連續(xù)深度學習和隱式神經(jīng)網(wǎng)絡模型(如Neural ODEs)的開發(fā)。該庫具有以下核心特性:
- 支持深度不變性和深度可變性的ODE模型
- 提供多種數(shù)值求解算法(如Runge-Kutta法,Dormand-Prince法)
- 與PyTorch Lightning框架的無縫集成,便于訓練流程管理
本教程將以經(jīng)典的moons數(shù)據(jù)集為例,展示Neural ODEs在分類問題中的應用。
數(shù)據(jù)集構(gòu)建
首先,我們使用Torchdyn內(nèi)置的數(shù)據(jù)集生成工具創(chuàng)建實驗數(shù)據(jù):
from torchdyn.datasets import ToyDataset import matplotlib.pyplot as plt # 生成示例數(shù)據(jù) d = ToyDataset() X, yn = d.generate(n_samples=512, noise=1e-1, dataset_type='moons') # 可視化數(shù)據(jù)集 colors = ['orange', 'blue'] fig, ax = plt.subplots(figsize=(3, 3)) for i in range(len(X)): ax.scatter(X[i, 0], X[i, 1], s=1, color=colors[yn[i].int()]) plt.show()
數(shù)據(jù)預處理
將生成的數(shù)據(jù)轉(zhuǎn)換為PyTorch張量格式,并構(gòu)建訓練數(shù)據(jù)加載器。Torchdyn支持CPU和GPU計算,可根據(jù)硬件環(huán)境靈活選擇:
import torch import torch.utils.data as data device = torch.device("cpu") # 如果使用GPU則改為'cuda' X_train = torch.Tensor(X).to(device) y_train = torch.LongTensor(yn.long()).to(device) train = data.TensorDataset(X_train, y_train) trainloader = data.DataLoader(train, batch_size=len(X), shuffle=True)
Neural ODE模型構(gòu)建
Neural ODEs的核心組件是向量場(vector field),它通過神經(jīng)網(wǎng)絡定義了數(shù)據(jù)在連續(xù)深度域中的演化規(guī)律。以下代碼展示了向量場的基本實現(xiàn):
import torch.nn as nn # 定義向量場f f = nn.Sequential( nn.Linear(2, 16), nn.Tanh(), nn.Linear(16, 2) )
接下來,我們使用Torchdyn的
NeuralODE
類定義Neural ODE模型。這個類接收向量場和求解器設置作為輸入。
from torchdyn.core import NeuralODE t_span = torch.linspace(0, 1, 5) # 時間跨度 model = NeuralODE(f, sensitivity='adjoint', solver='dopri5').to(device)
類來管理訓練過程:
import pytorch_lightning as pl class Learner(pl.LightningModule): def __init__(self, t_span: torch.Tensor, model: nn.Module): super().__init__() self.model, self.t_span = model, t_span def forward(self, x): return self.model(x) def training_step(self, batch, batch_idx): x, y = batch t_eval, y_hat = self.model(x, self.t_span) y_hat = y_hat[-1] # 選擇軌跡的最后一個點 loss = nn.CrossEntropyLoss()(y_hat, y) return {'loss': loss} def configure_optimizers(self): return torch.optim.Adam(self.model.parameters(), lr=0.01) def train_dataloader(self): return trainloader
最后訓練模型:
learn = Learner(t_span, model) trainer = pl.Trainer(max_epochs=200) trainer.fit(learn)
實驗結(jié)果可視化
深度域軌跡分析
訓練完成后,我們可以觀察數(shù)據(jù)樣本在深度域(即ODE的時間維度)中的演化軌跡:
t_eval, trajectory = model(X_train, t_span) trajectory = trajectory.detach().cpu() fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 2)) for i in range(500): ax0.plot(t_span, trajectory[:, i, 0], alpha=0.1, color=colors[int(yn[i])]) ax1.plot(t_span, trajectory[:, i, 1], alpha=0.1, color=colors[int(yn[i])]) ax0.set_title("維度 0") ax1.set_title("維度 1") plt.show()
向量場可視化
通過可視化學習得到的向量場,我們可以直觀理解模型的動力學特性:
x = torch.linspace(trajectory[:, :, 0].min(), trajectory[:, :, 0].max(), 50) y = torch.linspace(trajectory[:, :, 1].min(), trajectory[:, :, 1].max(), 50) X, Y = torch.meshgrid(x, y) z = torch.cat([X.reshape(-1, 1), Y.reshape(-1, 1)], 1) f_eval = model.vf(0, z.to(device)).cpu().detach() fx, fy = f_eval[:, 0], f_eval[:, 1] fx, fy = fx.reshape(50, 50), fy.reshape(50, 50) fig, ax = plt.subplots(figsize=(4, 4)) ax.streamplot(X.numpy(), Y.numpy(), fx.numpy(), fy.numpy(), color='black') plt.show()
Torchdyn進階特性
Torchdyn框架的功能遠不限于基礎的Neural ODEs實現(xiàn)。它提供了豐富的高級特性,包括:
- 高精度數(shù)值求解器
- 平衡模型支持
- 自定義微分方程系統(tǒng)
無論是物理模型的數(shù)值模擬,還是連續(xù)深度學習模型的開發(fā),Torchdyn都提供了完整的工具鏈支持。
以上就是PyTorch使用Torchdyn實現(xiàn)連續(xù)時間神經(jīng)網(wǎng)絡的代碼示例的詳細內(nèi)容,更多關于PyTorch Torchdyn連續(xù)時間神經(jīng)網(wǎng)絡的資料請關注腳本之家其它相關文章!
相關文章
Python3 sort和sorted用法+cmp_to_key()函數(shù)詳解
這篇文章主要介紹了Python3 sort和sorted用法+cmp_to_key()函數(shù)詳解,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教2023-07-07python 動態(tài)獲取當前運行的類名和函數(shù)名的方法
這篇文章主要介紹了python 動態(tài)獲取當前運行的類名和函數(shù)名的方法,分別介紹使用內(nèi)置方法、sys模塊、修飾器、inspect模塊等方法,需要的朋友可以參考下2014-04-04python使用IP歸屬地查詢API追蹤網(wǎng)絡活動
這篇文章主要為大家介紹了python使用IP歸屬地查詢API追蹤網(wǎng)絡活動實現(xiàn)詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪2023-09-09Win10里python3創(chuàng)建虛擬環(huán)境的步驟
在本篇文章里小編給大家整理的是一篇關于Win10里python3創(chuàng)建虛擬環(huán)境的步驟內(nèi)容,需要的朋友們可以學習參考下。2020-01-01Python Matplotlib條形圖之垂直條形圖和水平條形圖詳解
這篇文章主要為大家詳細介紹了Python Matplotlib條形圖之垂直條形圖和水平條形圖,使用數(shù)據(jù)庫,文中示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下2022-03-03python中的break、continue、exit()、pass全面解析
下面小編就為大家?guī)硪黄猵ython中的break、continue、exit()、pass全面解析。小編覺得挺不錯的,現(xiàn)在就分享給大家,也給大家做個參考。一起跟隨小編過來看看吧2017-08-08