欧美bbbwbbbw肥妇,免费乱码人妻系列日韩,一级黄片

PyTorch使用Torchdyn實現(xiàn)連續(xù)時間神經(jīng)網(wǎng)絡的代碼示例

 更新時間:2025年02月05日 09:35:36   作者:deephub  
神經(jīng)常微分方程(Neural ODEs)是深度學習領域的創(chuàng)新性模型架構(gòu),它將神經(jīng)網(wǎng)絡的離散變換擴展為連續(xù)時間動力系統(tǒng),本文將基于Torchdyn(一個專門用于連續(xù)深度學習和平衡模型的PyTorch擴展庫)介紹Neural ODE的實現(xiàn)與訓練方法,需要的朋友可以參考下

Torchdyn概述

Torchdyn是基于PyTorch構(gòu)建的專業(yè)庫,專注于連續(xù)深度學習和隱式神經(jīng)網(wǎng)絡模型(如Neural ODEs)的開發(fā)。該庫具有以下核心特性:

  • 支持深度不變性和深度可變性的ODE模型
  • 提供多種數(shù)值求解算法(如Runge-Kutta法,Dormand-Prince法)
  • 與PyTorch Lightning框架的無縫集成,便于訓練流程管理

本教程將以經(jīng)典的moons數(shù)據(jù)集為例,展示Neural ODEs在分類問題中的應用。

數(shù)據(jù)集構(gòu)建

首先,我們使用Torchdyn內(nèi)置的數(shù)據(jù)集生成工具創(chuàng)建實驗數(shù)據(jù):

 from torchdyn.datasets import ToyDataset  
 import matplotlib.pyplot as plt  
   
 # 生成示例數(shù)據(jù)
 d = ToyDataset()  
 X, yn = d.generate(n_samples=512, noise=1e-1, dataset_type='moons')  
 # 可視化數(shù)據(jù)集
 colors = ['orange', 'blue']  
 fig, ax = plt.subplots(figsize=(3, 3))  
 for i in range(len(X)):  
     ax.scatter(X[i, 0], X[i, 1], s=1, color=colors[yn[i].int()])  
 plt.show()

數(shù)據(jù)預處理

將生成的數(shù)據(jù)轉(zhuǎn)換為PyTorch張量格式,并構(gòu)建訓練數(shù)據(jù)加載器。Torchdyn支持CPU和GPU計算,可根據(jù)硬件環(huán)境靈活選擇:

 import torch  
 import torch.utils.data as data  
   
 device = torch.device("cpu")  # 如果使用GPU則改為'cuda'
 X_train = torch.Tensor(X).to(device)  
 y_train = torch.LongTensor(yn.long()).to(device)  
 train = data.TensorDataset(X_train, y_train)  
 trainloader = data.DataLoader(train, batch_size=len(X), shuffle=True)

Neural ODE模型構(gòu)建

Neural ODEs的核心組件是向量場(vector field),它通過神經(jīng)網(wǎng)絡定義了數(shù)據(jù)在連續(xù)深度域中的演化規(guī)律。以下代碼展示了向量場的基本實現(xiàn):

 import torch.nn as nn  
   
 # 定義向量場f
 f = nn.Sequential(  
     nn.Linear(2, 16),  
     nn.Tanh(),  
     nn.Linear(16, 2)  
 )

接下來,我們使用Torchdyn的

NeuralODE

類定義Neural ODE模型。這個類接收向量場和求解器設置作為輸入。

 from torchdyn.core import NeuralODE  
   
 t_span = torch.linspace(0, 1, 5)  # 時間跨度
 model = NeuralODE(f, sensitivity='adjoint', solver='dopri5').to(device)

類來管理訓練過程:

 import pytorch_lightning as pl  
   
 class Learner(pl.LightningModule):  
     def __init__(self, t_span: torch.Tensor, model: nn.Module):  
         super().__init__()  
         self.model, self.t_span = model, t_span  
     def forward(self, x):  
         return self.model(x)  
     def training_step(self, batch, batch_idx):  
         x, y = batch  
         t_eval, y_hat = self.model(x, self.t_span)  
         y_hat = y_hat[-1]  # 選擇軌跡的最后一個點
         loss = nn.CrossEntropyLoss()(y_hat, y)  
         return {'loss': loss}  
     def configure_optimizers(self):  
         return torch.optim.Adam(self.model.parameters(), lr=0.01)  
     def train_dataloader(self):  
         return trainloader

最后訓練模型:

 learn = Learner(t_span, model)  
 trainer = pl.Trainer(max_epochs=200)  
 trainer.fit(learn)

實驗結(jié)果可視化

深度域軌跡分析

訓練完成后,我們可以觀察數(shù)據(jù)樣本在深度域(即ODE的時間維度)中的演化軌跡:

 t_eval, trajectory = model(X_train, t_span)  
 trajectory = trajectory.detach().cpu()  
   
 fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 2))  
 for i in range(500):  
     ax0.plot(t_span, trajectory[:, i, 0], alpha=0.1, color=colors[int(yn[i])])  
     ax1.plot(t_span, trajectory[:, i, 1], alpha=0.1, color=colors[int(yn[i])])  
 ax0.set_title("維度 0")  
 ax1.set_title("維度 1")  
 plt.show()

向量場可視化

通過可視化學習得到的向量場,我們可以直觀理解模型的動力學特性:

 x = torch.linspace(trajectory[:, :, 0].min(), trajectory[:, :, 0].max(), 50)  
 y = torch.linspace(trajectory[:, :, 1].min(), trajectory[:, :, 1].max(), 50)  
 X, Y = torch.meshgrid(x, y)  
 z = torch.cat([X.reshape(-1, 1), Y.reshape(-1, 1)], 1)  
 f_eval = model.vf(0, z.to(device)).cpu().detach()  
   
 fx, fy = f_eval[:, 0], f_eval[:, 1]  
 fx, fy = fx.reshape(50, 50), fy.reshape(50, 50)  
 fig, ax = plt.subplots(figsize=(4, 4))  
 ax.streamplot(X.numpy(), Y.numpy(), fx.numpy(), fy.numpy(), color='black')  
 plt.show()

Torchdyn進階特性

Torchdyn框架的功能遠不限于基礎的Neural ODEs實現(xiàn)。它提供了豐富的高級特性,包括:

  • 高精度數(shù)值求解器
  • 平衡模型支持
  • 自定義微分方程系統(tǒng)

無論是物理模型的數(shù)值模擬,還是連續(xù)深度學習模型的開發(fā),Torchdyn都提供了完整的工具鏈支持。

以上就是PyTorch使用Torchdyn實現(xiàn)連續(xù)時間神經(jīng)網(wǎng)絡的代碼示例的詳細內(nèi)容,更多關于PyTorch Torchdyn連續(xù)時間神經(jīng)網(wǎng)絡的資料請關注腳本之家其它相關文章!

相關文章

  • Python3 sort和sorted用法+cmp_to_key()函數(shù)詳解

    Python3 sort和sorted用法+cmp_to_key()函數(shù)詳解

    這篇文章主要介紹了Python3 sort和sorted用法+cmp_to_key()函數(shù)詳解,具有很好的參考價值,希望對大家有所幫助。如有錯誤或未考慮完全的地方,望不吝賜教
    2023-07-07
  • python 動態(tài)獲取當前運行的類名和函數(shù)名的方法

    python 動態(tài)獲取當前運行的類名和函數(shù)名的方法

    這篇文章主要介紹了python 動態(tài)獲取當前運行的類名和函數(shù)名的方法,分別介紹使用內(nèi)置方法、sys模塊、修飾器、inspect模塊等方法,需要的朋友可以參考下
    2014-04-04
  • 使用Python提取PDF表格到Excel文件的操作步驟

    使用Python提取PDF表格到Excel文件的操作步驟

    在對PDF中的表格進行再利用時,除了直接將PDF文檔轉(zhuǎn)換為Excel文件,我們還可以提取PDF文檔中的表格數(shù)據(jù)并寫入Excel工作表,本文將介紹如何使用Python提取PDF文檔中的表格并寫入Excel文件中,需要的朋友可以參考下
    2024-09-09
  • 在windows下Python打印彩色字體的方法

    在windows下Python打印彩色字體的方法

    這篇文章主要介紹了Python在windows下打印彩色字體的方法;具有很好的參考價值,希望對大家有所幫助,一起跟隨小編過來看看吧
    2018-05-05
  • python使用IP歸屬地查詢API追蹤網(wǎng)絡活動

    python使用IP歸屬地查詢API追蹤網(wǎng)絡活動

    這篇文章主要為大家介紹了python使用IP歸屬地查詢API追蹤網(wǎng)絡活動實現(xiàn)詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進步,早日升職加薪
    2023-09-09
  • Python入門之布爾值詳解

    Python入門之布爾值詳解

    Python中布爾值(Booleans)表示以下兩個值之一:True或False。本文主要介紹布爾值(Booleans)的使用,和使用時需要注意的地方,需要的可以參考一下
    2023-02-02
  • Win10里python3創(chuàng)建虛擬環(huán)境的步驟

    Win10里python3創(chuàng)建虛擬環(huán)境的步驟

    在本篇文章里小編給大家整理的是一篇關于Win10里python3創(chuàng)建虛擬環(huán)境的步驟內(nèi)容,需要的朋友們可以學習參考下。
    2020-01-01
  • Python Matplotlib條形圖之垂直條形圖和水平條形圖詳解

    Python Matplotlib條形圖之垂直條形圖和水平條形圖詳解

    這篇文章主要為大家詳細介紹了Python Matplotlib條形圖之垂直條形圖和水平條形圖,使用數(shù)據(jù)庫,文中示例代碼介紹的非常詳細,具有一定的參考價值,感興趣的小伙伴們可以參考一下
    2022-03-03
  • python中的break、continue、exit()、pass全面解析

    python中的break、continue、exit()、pass全面解析

    下面小編就為大家?guī)硪黄猵ython中的break、continue、exit()、pass全面解析。小編覺得挺不錯的,現(xiàn)在就分享給大家,也給大家做個參考。一起跟隨小編過來看看吧
    2017-08-08
  • python plt如何保存為emf圖像

    python plt如何保存為emf圖像

    這篇文章主要介紹了python plt如何保存為emf圖像問題,具有很好的參考價值,希望對大家有所幫助,如有錯誤或未考慮完全的地方,望不吝賜教
    2023-09-09

最新評論