亚洲乱码中文字幕综合,中国熟女仑乱hd,亚洲精品乱拍国产一区二区三区,一本大道卡一卡二卡三乱码全集资源,又粗又黄又硬又爽的免费视频

PyTorch使用Torchdyn實(shí)現(xiàn)連續(xù)時(shí)間神經(jīng)網(wǎng)絡(luò)的代碼示例

 更新時(shí)間:2025年02月05日 09:35:36   作者:deephub  
神經(jīng)常微分方程(Neural ODEs)是深度學(xué)習(xí)領(lǐng)域的創(chuàng)新性模型架構(gòu),它將神經(jīng)網(wǎng)絡(luò)的離散變換擴(kuò)展為連續(xù)時(shí)間動(dòng)力系統(tǒng),本文將基于Torchdyn(一個(gè)專門用于連續(xù)深度學(xué)習(xí)和平衡模型的PyTorch擴(kuò)展庫(kù))介紹Neural ODE的實(shí)現(xiàn)與訓(xùn)練方法,需要的朋友可以參考下

Torchdyn概述

Torchdyn是基于PyTorch構(gòu)建的專業(yè)庫(kù),專注于連續(xù)深度學(xué)習(xí)和隱式神經(jīng)網(wǎng)絡(luò)模型(如Neural ODEs)的開(kāi)發(fā)。該庫(kù)具有以下核心特性:

  • 支持深度不變性和深度可變性的ODE模型
  • 提供多種數(shù)值求解算法(如Runge-Kutta法,Dormand-Prince法)
  • 與PyTorch Lightning框架的無(wú)縫集成,便于訓(xùn)練流程管理

本教程將以經(jīng)典的moons數(shù)據(jù)集為例,展示Neural ODEs在分類問(wèn)題中的應(yīng)用。

數(shù)據(jù)集構(gòu)建

首先,我們使用Torchdyn內(nèi)置的數(shù)據(jù)集生成工具創(chuàng)建實(shí)驗(yàn)數(shù)據(jù):

 from torchdyn.datasets import ToyDataset  
 import matplotlib.pyplot as plt  
   
 # 生成示例數(shù)據(jù)
 d = ToyDataset()  
 X, yn = d.generate(n_samples=512, noise=1e-1, dataset_type='moons')  
 # 可視化數(shù)據(jù)集
 colors = ['orange', 'blue']  
 fig, ax = plt.subplots(figsize=(3, 3))  
 for i in range(len(X)):  
     ax.scatter(X[i, 0], X[i, 1], s=1, color=colors[yn[i].int()])  
 plt.show()

數(shù)據(jù)預(yù)處理

將生成的數(shù)據(jù)轉(zhuǎn)換為PyTorch張量格式,并構(gòu)建訓(xùn)練數(shù)據(jù)加載器。Torchdyn支持CPU和GPU計(jì)算,可根據(jù)硬件環(huán)境靈活選擇:

 import torch  
 import torch.utils.data as data  
   
 device = torch.device("cpu")  # 如果使用GPU則改為'cuda'
 X_train = torch.Tensor(X).to(device)  
 y_train = torch.LongTensor(yn.long()).to(device)  
 train = data.TensorDataset(X_train, y_train)  
 trainloader = data.DataLoader(train, batch_size=len(X), shuffle=True)

Neural ODE模型構(gòu)建

Neural ODEs的核心組件是向量場(chǎng)(vector field),它通過(guò)神經(jīng)網(wǎng)絡(luò)定義了數(shù)據(jù)在連續(xù)深度域中的演化規(guī)律。以下代碼展示了向量場(chǎng)的基本實(shí)現(xiàn):

 import torch.nn as nn  
   
 # 定義向量場(chǎng)f
 f = nn.Sequential(  
     nn.Linear(2, 16),  
     nn.Tanh(),  
     nn.Linear(16, 2)  
 )

接下來(lái),我們使用Torchdyn的

NeuralODE

類定義Neural ODE模型。這個(gè)類接收向量場(chǎng)和求解器設(shè)置作為輸入。

 from torchdyn.core import NeuralODE  
   
 t_span = torch.linspace(0, 1, 5)  # 時(shí)間跨度
 model = NeuralODE(f, sensitivity='adjoint', solver='dopri5').to(device)

類來(lái)管理訓(xùn)練過(guò)程:

 import pytorch_lightning as pl  
   
 class Learner(pl.LightningModule):  
     def __init__(self, t_span: torch.Tensor, model: nn.Module):  
         super().__init__()  
         self.model, self.t_span = model, t_span  
     def forward(self, x):  
         return self.model(x)  
     def training_step(self, batch, batch_idx):  
         x, y = batch  
         t_eval, y_hat = self.model(x, self.t_span)  
         y_hat = y_hat[-1]  # 選擇軌跡的最后一個(gè)點(diǎn)
         loss = nn.CrossEntropyLoss()(y_hat, y)  
         return {'loss': loss}  
     def configure_optimizers(self):  
         return torch.optim.Adam(self.model.parameters(), lr=0.01)  
     def train_dataloader(self):  
         return trainloader

最后訓(xùn)練模型:

 learn = Learner(t_span, model)  
 trainer = pl.Trainer(max_epochs=200)  
 trainer.fit(learn)

實(shí)驗(yàn)結(jié)果可視化

深度域軌跡分析

訓(xùn)練完成后,我們可以觀察數(shù)據(jù)樣本在深度域(即ODE的時(shí)間維度)中的演化軌跡:

 t_eval, trajectory = model(X_train, t_span)  
 trajectory = trajectory.detach().cpu()  
   
 fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(10, 2))  
 for i in range(500):  
     ax0.plot(t_span, trajectory[:, i, 0], alpha=0.1, color=colors[int(yn[i])])  
     ax1.plot(t_span, trajectory[:, i, 1], alpha=0.1, color=colors[int(yn[i])])  
 ax0.set_title("維度 0")  
 ax1.set_title("維度 1")  
 plt.show()

向量場(chǎng)可視化

通過(guò)可視化學(xué)習(xí)得到的向量場(chǎng),我們可以直觀理解模型的動(dòng)力學(xué)特性:

 x = torch.linspace(trajectory[:, :, 0].min(), trajectory[:, :, 0].max(), 50)  
 y = torch.linspace(trajectory[:, :, 1].min(), trajectory[:, :, 1].max(), 50)  
 X, Y = torch.meshgrid(x, y)  
 z = torch.cat([X.reshape(-1, 1), Y.reshape(-1, 1)], 1)  
 f_eval = model.vf(0, z.to(device)).cpu().detach()  
   
 fx, fy = f_eval[:, 0], f_eval[:, 1]  
 fx, fy = fx.reshape(50, 50), fy.reshape(50, 50)  
 fig, ax = plt.subplots(figsize=(4, 4))  
 ax.streamplot(X.numpy(), Y.numpy(), fx.numpy(), fy.numpy(), color='black')  
 plt.show()

Torchdyn進(jìn)階特性

Torchdyn框架的功能遠(yuǎn)不限于基礎(chǔ)的Neural ODEs實(shí)現(xiàn)。它提供了豐富的高級(jí)特性,包括:

  • 高精度數(shù)值求解器
  • 平衡模型支持
  • 自定義微分方程系統(tǒng)

無(wú)論是物理模型的數(shù)值模擬,還是連續(xù)深度學(xué)習(xí)模型的開(kāi)發(fā),Torchdyn都提供了完整的工具鏈支持。

以上就是PyTorch使用Torchdyn實(shí)現(xiàn)連續(xù)時(shí)間神經(jīng)網(wǎng)絡(luò)的代碼示例的詳細(xì)內(nèi)容,更多關(guān)于PyTorch Torchdyn連續(xù)時(shí)間神經(jīng)網(wǎng)絡(luò)的資料請(qǐng)關(guān)注腳本之家其它相關(guān)文章!

相關(guān)文章

  • Python3 sort和sorted用法+cmp_to_key()函數(shù)詳解

    Python3 sort和sorted用法+cmp_to_key()函數(shù)詳解

    這篇文章主要介紹了Python3 sort和sorted用法+cmp_to_key()函數(shù)詳解,具有很好的參考價(jià)值,希望對(duì)大家有所幫助。如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2023-07-07
  • python 動(dòng)態(tài)獲取當(dāng)前運(yùn)行的類名和函數(shù)名的方法

    python 動(dòng)態(tài)獲取當(dāng)前運(yùn)行的類名和函數(shù)名的方法

    這篇文章主要介紹了python 動(dòng)態(tài)獲取當(dāng)前運(yùn)行的類名和函數(shù)名的方法,分別介紹使用內(nèi)置方法、sys模塊、修飾器、inspect模塊等方法,需要的朋友可以參考下
    2014-04-04
  • 使用Python提取PDF表格到Excel文件的操作步驟

    使用Python提取PDF表格到Excel文件的操作步驟

    在對(duì)PDF中的表格進(jìn)行再利用時(shí),除了直接將PDF文檔轉(zhuǎn)換為Excel文件,我們還可以提取PDF文檔中的表格數(shù)據(jù)并寫入Excel工作表,本文將介紹如何使用Python提取PDF文檔中的表格并寫入Excel文件中,需要的朋友可以參考下
    2024-09-09
  • 在windows下Python打印彩色字體的方法

    在windows下Python打印彩色字體的方法

    這篇文章主要介紹了Python在windows下打印彩色字體的方法;具有很好的參考價(jià)值,希望對(duì)大家有所幫助,一起跟隨小編過(guò)來(lái)看看吧
    2018-05-05
  • python使用IP歸屬地查詢API追蹤網(wǎng)絡(luò)活動(dòng)

    python使用IP歸屬地查詢API追蹤網(wǎng)絡(luò)活動(dòng)

    這篇文章主要為大家介紹了python使用IP歸屬地查詢API追蹤網(wǎng)絡(luò)活動(dòng)實(shí)現(xiàn)詳解,有需要的朋友可以借鑒參考下,希望能夠有所幫助,祝大家多多進(jìn)步,早日升職加薪
    2023-09-09
  • Python入門之布爾值詳解

    Python入門之布爾值詳解

    Python中布爾值(Booleans)表示以下兩個(gè)值之一:True或False。本文主要介紹布爾值(Booleans)的使用,和使用時(shí)需要注意的地方,需要的可以參考一下
    2023-02-02
  • Win10里python3創(chuàng)建虛擬環(huán)境的步驟

    Win10里python3創(chuàng)建虛擬環(huán)境的步驟

    在本篇文章里小編給大家整理的是一篇關(guān)于Win10里python3創(chuàng)建虛擬環(huán)境的步驟內(nèi)容,需要的朋友們可以學(xué)習(xí)參考下。
    2020-01-01
  • Python Matplotlib條形圖之垂直條形圖和水平條形圖詳解

    Python Matplotlib條形圖之垂直條形圖和水平條形圖詳解

    這篇文章主要為大家詳細(xì)介紹了Python Matplotlib條形圖之垂直條形圖和水平條形圖,使用數(shù)據(jù)庫(kù),文中示例代碼介紹的非常詳細(xì),具有一定的參考價(jià)值,感興趣的小伙伴們可以參考一下
    2022-03-03
  • python中的break、continue、exit()、pass全面解析

    python中的break、continue、exit()、pass全面解析

    下面小編就為大家?guī)?lái)一篇python中的break、continue、exit()、pass全面解析。小編覺(jué)得挺不錯(cuò)的,現(xiàn)在就分享給大家,也給大家做個(gè)參考。一起跟隨小編過(guò)來(lái)看看吧
    2017-08-08
  • python plt如何保存為emf圖像

    python plt如何保存為emf圖像

    這篇文章主要介紹了python plt如何保存為emf圖像問(wèn)題,具有很好的參考價(jià)值,希望對(duì)大家有所幫助,如有錯(cuò)誤或未考慮完全的地方,望不吝賜教
    2023-09-09

最新評(píng)論