-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathode.py
More file actions
52 lines (48 loc) · 1.78 KB
/
ode.py
File metadata and controls
52 lines (48 loc) · 1.78 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
import torch.nn as nn
from torchdiffeq import odeint, odeint_adjoint
class ODEBlock(nn.Module):
    """Wrap a dynamics function and integrate it over a set of time points.

    The block stores the solver configuration at construction time and, on
    each forward pass, hands the dynamics function, the initial state and the
    requested time points to ``odeint`` (or ``odeint_adjoint``).
    """

    def __init__(self, func, method='rk4', atol=1e-3, rtol=1e-4, adjoint=False):
        """Store the dynamics function and the solver settings.

        Args:
            func: GDEFunc instance defining the dynamics.
            method: Name of the ODE solver ('euler', 'rk4', 'dopri5', etc.).
            atol: Absolute tolerance, forwarded only for adaptive solvers.
            rtol: Relative tolerance, forwarded only for adaptive solvers.
            adjoint: If true, ``forward`` integrates with ``odeint_adjoint``
                instead of ``odeint`` (adjoint-method backprop, intended to
                save memory).
        """
        super().__init__()
        self.func = func
        self.method = method
        self.atol = atol
        self.rtol = rtol
        self.adjoint = adjoint

        # Solvers in this list step on a fixed grid, so no tolerances are
        # forwarded for them; every other method receives atol/rtol.
        self.fixed_grid_solvers = [
            'euler',
            'midpoint',
            'rk4',
            'explicit_adams',
            'implicit_adams',
        ]
        uses_fixed_grid = method in self.fixed_grid_solvers
        self.options = {} if uses_fixed_grid else {'atol': atol, 'rtol': rtol}

    def forward(self, x, t_eval):
        """Integrate the dynamics from ``x`` across ``t_eval``.

        Args:
            x: Initial state, shape (num_nodes, features).
            t_eval: Integration time points, shape (num_steps,).

        Returns:
            The integrator's output unchanged, i.e. the full trajectory at
            all time points, shape (num_steps, num_nodes, features).
        """
        if self.adjoint:
            solve = odeint_adjoint
        else:
            solve = odeint
        # self.options expands to the atol/rtol keywords (or to nothing for
        # fixed-grid methods).
        return solve(self.func, x, t_eval, method=self.method, **self.options)

    def __repr__(self):
        """Return a one-line summary of the solver configuration."""
        settings = (
            f'method={self.method}, atol={self.atol}, '
            f'rtol={self.rtol}, adjoint={self.adjoint}'
        )
        return f'ODEBlock({settings})'