-
Notifications
You must be signed in to change notification settings - Fork 5
doc: Add learned closure demo, add T-torch to landing page + docs #626
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 4 commits
102197e
fe991f0
838924b
a6cb90b
4c89192
a2aea5f
965ab1f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,157 @@ | ||
| # Copyright 2025 Pasteur Labs. All Rights Reserved. | ||
| # SPDX-License-Identifier: Apache-2.0 | ||
|
|
||
| """Single-timestep Burgers' equation solver Tesseract (PyTorch). | ||
|
|
||
| Solves one explicit Euler step of the 1D viscous Burgers' equation: | ||
|
|
||
| u^{n+1} = u^n + dt * (-u * du/dx + nu * d²u/dx²) | ||
|
|
||
| The viscosity field nu is provided as an input — the solver does not compute it. | ||
| This clean interface (state + material field → next state) is the same contract | ||
| that a Fortran solver with an adjoint could implement. The outer time-stepping | ||
| loop and closure evaluation live in the caller, enabling per-timestep closure | ||
| calls and end-to-end gradient flow through both solver and closure. | ||
| """ | ||
|
|
||
| from typing import Any | ||
|
|
||
| import numpy as np | ||
| import torch | ||
| from pydantic import BaseModel, Field | ||
| from torch.utils._pytree import tree_map | ||
|
|
||
| from tesseract_core.runtime import Array, Differentiable, Float64 | ||
| from tesseract_core.runtime.tree_transforms import filter_func, flatten_with_paths | ||
|
|
||
| # Default grid size | ||
| N = 128 | ||
|
|
||
| # --- Grid setup (fixed for this Tesseract) --- | ||
| DX = 1.0 / (N - 1) | ||
|
|
||
| to_tensor = lambda x: ( | ||
| torch.tensor(x, dtype=torch.float64) | ||
| if isinstance(x, np.generic | np.ndarray) | ||
| else x | ||
| ) | ||
|
|
||
|
|
||
| class InputSchema(BaseModel): | ||
| u: Differentiable[Array[(N,), Float64]] = Field( | ||
| description="Current velocity field on the grid" | ||
| ) | ||
| nu: Differentiable[Array[(N,), Float64]] = Field( | ||
| description="Viscosity field at each grid point (must be positive)" | ||
| ) | ||
| dt: float = Field(description="Time step size", default=1e-4) | ||
|
|
||
|
|
||
| class OutputSchema(BaseModel): | ||
| u_next: Differentiable[Array[(N,), Float64]] = Field( | ||
| description="Velocity field after one time step" | ||
| ) | ||
|
|
||
|
|
||
| def evaluate(inputs: dict) -> dict: | ||
| """Core differentiable computation — pure torch operations.""" | ||
| u = inputs["u"] | ||
| nu = inputs["nu"] | ||
| dt = inputs["dt"] | ||
|
|
||
| # Spatial derivatives via central differences | ||
| dudx = torch.zeros_like(u) | ||
| dudx[1:-1] = (u[2:] - u[:-2]) / (2 * DX) | ||
|
|
||
| d2udx2 = torch.zeros_like(u) | ||
| d2udx2[1:-1] = (u[2:] - 2 * u[1:-1] + u[:-2]) / (DX**2) | ||
|
|
||
| # Burgers' equation: du/dt = -u * du/dx + nu * d²u/dx² | ||
| dudt = -u * dudx + nu * d2udx2 | ||
|
|
||
| # Forward Euler step | ||
| u_next = u + dt * dudt | ||
|
|
||
| # Enforce boundary conditions (Dirichlet: hold boundary values) | ||
| u_next = torch.cat([u[:1], u_next[1:-1], u[-1:]]) | ||
|
|
||
| return {"u_next": u_next} | ||
|
|
||
|
|
||
| def apply(inputs: InputSchema) -> OutputSchema: | ||
| tensor_inputs = tree_map(to_tensor, inputs.model_dump()) | ||
| return evaluate(tensor_inputs) | ||
|
|
||
|
|
||
| def abstract_eval(abstract_inputs: Any) -> Any: | ||
| return {"u_next": {"shape": [N], "dtype": "float64"}} | ||
|
|
||
|
|
||
| def jacobian_vector_product( | ||
| inputs: InputSchema, | ||
| jvp_inputs: set[str], | ||
| jvp_outputs: set[str], | ||
| tangent_vector: dict[str, Any], | ||
| ): | ||
| jvp_inputs = tuple(jvp_inputs) | ||
| tangent_vector = {key: tangent_vector[key] for key in jvp_inputs} | ||
|
|
||
| tensor_inputs = tree_map(to_tensor, inputs.model_dump()) | ||
| pos_tangent = tree_map(to_tensor, tangent_vector).values() | ||
| pos_inputs = flatten_with_paths(tensor_inputs, jvp_inputs).values() | ||
|
|
||
| filtered_pos_eval = filter_func( | ||
| evaluate, tensor_inputs, jvp_outputs, input_paths=jvp_inputs | ||
| ) | ||
|
|
||
| return torch.func.jvp(filtered_pos_eval, tuple(pos_inputs), tuple(pos_tangent))[1] | ||
|
|
||
|
|
||
| def vector_jacobian_product( | ||
| inputs: InputSchema, | ||
| vjp_inputs: set[str], | ||
| vjp_outputs: set[str], | ||
| cotangent_vector: dict[str, Any], | ||
| ): | ||
| vjp_inputs = tuple(vjp_inputs) | ||
| cotangent_vector = {key: cotangent_vector[key] for key in vjp_outputs} | ||
|
|
||
| tensor_inputs = tree_map(to_tensor, inputs.model_dump()) | ||
| tensor_cotangent = tree_map(to_tensor, cotangent_vector) | ||
| pos_inputs = flatten_with_paths(tensor_inputs, vjp_inputs).values() | ||
|
|
||
| filtered_pos_func = filter_func( | ||
| evaluate, tensor_inputs, vjp_outputs, input_paths=vjp_inputs | ||
| ) | ||
|
|
||
| _, vjp_func = torch.func.vjp(filtered_pos_func, *pos_inputs) | ||
| vjp_vals = vjp_func(tensor_cotangent) | ||
| return dict(zip(vjp_inputs, vjp_vals, strict=True)) | ||
|
|
||
|
|
||
| def jacobian( | ||
| inputs: InputSchema, | ||
| jac_inputs: set[str], | ||
| jac_outputs: set[str], | ||
| ): | ||
| jac_inputs = tuple(jac_inputs) | ||
| tensor_inputs = tree_map(to_tensor, inputs.model_dump()) | ||
| pos_inputs = flatten_with_paths(tensor_inputs, jac_inputs).values() | ||
|
|
||
| filtered_pos_eval = filter_func( | ||
| evaluate, tensor_inputs, jac_outputs, input_paths=jac_inputs | ||
| ) | ||
|
|
||
| def filtered_pos_eval_flat(*args): | ||
| res = filtered_pos_eval(*args) | ||
| return tuple(res[k] for k in jac_outputs) | ||
|
|
||
| jac = torch.autograd.functional.jacobian(filtered_pos_eval_flat, tuple(pos_inputs)) | ||
|
|
||
| jac_dict = {} | ||
| for dy, dys in zip(jac_outputs, jac, strict=True): | ||
| jac_dict[dy] = {} | ||
| for dx, dxs in zip(jac_inputs, dys, strict=True): | ||
| jac_dict[dy][dx] = dxs | ||
|
|
||
| return jac_dict | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,6 @@ | ||
| name: "burgers-solver" | ||
| version: "0.1.0" | ||
| description: "1D Burgers equation solver with pluggable neural viscosity closure (PyTorch)" | ||
|
|
||
| build_config: | ||
| target_platform: "native" |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,2 @@ | ||
| torch | ||
| tesseract-core |
Large diffs are not rendered by default.
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,4 @@ | ||
| torch | ||
| matplotlib | ||
| tesseract-core | ||
| tesseract-torch |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,201 @@ | ||
| """Smoke tests for the learned closure demo (PyTorch version). | ||
|
|
||
| Tests the composition pattern: a plain torch.nn closure predicts a viscosity | ||
| field, then the solver Tesseract steps the Burgers' equation forward as a | ||
| differentiable layer. Gradients flow end-to-end from the loss, through the | ||
| solver's VJP (via apply_tesseract / torch.autograd), into the network weights. | ||
|
|
||
| These tests load the solver via ``Tesseract.from_tesseract_api`` (in-process, no | ||
| Docker) so they run fast as a local smoke check. The demo notebook itself uses | ||
| ``Tesseract.from_image`` to serve the solver in a container over HTTP — the same | ||
| ``apply_tesseract`` call path works either way. This is also the same pattern | ||
| that would work with a Fortran solver Tesseract backed by Enzyme or a | ||
| hand-written adjoint: the solver just needs apply + VJP with the interface | ||
| (u, nu_field, dt) -> u_next. The closure stays ordinary PyTorch. | ||
|
Comment on lines
+8
to
+14
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I don't mind it (especially if this is just a private test file not part of the demo itself), but just wanted to note there's a strong AI smell about this paragraph.
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I do mind, thanks for catching it. |
||
| """ | ||
|
|
||
| import sys | ||
|
|
||
| sys.path.insert(0, "burgers_solver") | ||
|
|
||
| import burgers_solver.tesseract_api as solver_api | ||
| import numpy as np | ||
| import torch | ||
| import torch.nn as nn | ||
| from tesseract_torch import apply_tesseract | ||
|
|
||
| from tesseract_core import Tesseract | ||
|
|
||
| torch.set_default_dtype(torch.float64) | ||
|
|
||
| SOLVER_API_PATH = "burgers_solver/tesseract_api.py" | ||
|
|
||
| N = 128 | ||
| DX = 1.0 / (N - 1) | ||
| X_GRID = torch.linspace(0.0, 1.0, N) | ||
|
|
||
|
|
||
| class ViscosityNet(nn.Module): | ||
| """MLP closure: local flow features (u, du/dx, x) -> viscosity nu.""" | ||
|
|
||
| def __init__(self, hidden_dim=32, nu_max=0.05): | ||
| super().__init__() | ||
| self.nu_max = nu_max | ||
| self.net = nn.Sequential( | ||
| nn.Linear(3, hidden_dim), | ||
| nn.Tanh(), | ||
| nn.Linear(hidden_dim, hidden_dim), | ||
| nn.Tanh(), | ||
| nn.Linear(hidden_dim, 1), | ||
| ) | ||
|
|
||
| def forward(self, u, dudx, x): | ||
| features = torch.stack([u, dudx, x], dim=-1) | ||
| out = self.net(features)[:, 0] | ||
| return self.nu_max * torch.sigmoid(out) | ||
|
|
||
|
|
||
| def _make_initial_condition(): | ||
| """Smooth initial condition: a sine wave.""" | ||
| return torch.sin(2 * np.pi * X_GRID) | ||
|
|
||
|
|
||
| def test_closure_forward(): | ||
| print("=== Neural viscosity closure forward pass ===") | ||
| torch.manual_seed(0) | ||
| closure = ViscosityNet() | ||
| u0 = _make_initial_condition() | ||
| dudx = torch.gradient(u0, spacing=(DX,))[0] | ||
|
|
||
| with torch.no_grad(): | ||
| nu = closure(u0, dudx, X_GRID) | ||
|
|
||
| print(f" Shape: {nu.shape}, range: [{float(nu.min()):.4f}, {float(nu.max()):.4f}]") | ||
| assert nu.shape == (N,) | ||
| assert torch.all(nu > 0), "Viscosity must be positive" | ||
| print(" PASSED") | ||
|
|
||
|
|
||
| def test_solver_single_step(): | ||
| print("\n=== Solver single timestep ===") | ||
| u0 = _make_initial_condition() | ||
| nu = torch.full((N,), 0.01) | ||
| dt = 1e-4 | ||
|
|
||
| inputs = solver_api.InputSchema(u=u0, nu=nu, dt=dt) | ||
| out = solver_api.apply(inputs) | ||
| u_next = out["u_next"] | ||
|
|
||
| print(f" Shape: {u_next.shape}") | ||
| print(f" Max change: {float(torch.max(torch.abs(u_next - u0))):.6e}") | ||
| assert u_next.shape == (N,) | ||
| assert torch.all(torch.isfinite(u_next)), "Solution contains NaN or Inf" | ||
| # Boundary values should be preserved | ||
| assert float(u_next[0]) == float(u0[0]), "Left BC violated" | ||
| assert float(u_next[-1]) == float(u0[-1]), "Right BC violated" | ||
| print(" PASSED") | ||
|
|
||
|
|
||
| def test_solver_gradient(): | ||
| print("\n=== Solver gradient (VJP w.r.t. nu field) ===") | ||
| u0 = _make_initial_condition() | ||
| nu = torch.full((N,), 0.01, requires_grad=True) | ||
| dt = 1e-4 | ||
|
|
||
| tensor_inputs = { | ||
| "u": u0.clone(), | ||
| "nu": nu, | ||
| "dt": torch.tensor(dt), | ||
| } | ||
| out = solver_api.evaluate(tensor_inputs) | ||
| loss = torch.mean(out["u_next"] ** 2) | ||
| loss.backward() | ||
|
|
||
| grad_nu = nu.grad | ||
| print( | ||
| f" Gradient shape: {grad_nu.shape}, norm: {float(torch.linalg.norm(grad_nu)):.6e}" | ||
| ) | ||
| assert grad_nu.shape == (N,) | ||
| assert torch.all(torch.isfinite(grad_nu)) | ||
| print(" PASSED") | ||
|
|
||
|
|
||
| def _solve_with_closure(u0, closure, solver_tess, dt, n_steps): | ||
| u = u0 | ||
| for _step in range(n_steps): | ||
| dudx = torch.zeros_like(u) | ||
| dudx[1:-1] = (u[2:] - u[:-2]) / (2 * DX) | ||
| nu = closure(u, dudx, X_GRID) | ||
| solver_out = apply_tesseract(solver_tess, {"u": u, "nu": nu, "dt": dt}) | ||
| u = solver_out["u_next"] | ||
| return u | ||
|
|
||
|
|
||
| def test_composition_forward(): | ||
| """Outer loop: plain torch closure + solver Tesseract via apply_tesseract.""" | ||
| print("\n=== Composed forward pass (closure + solver Tesseract) ===") | ||
| solver_tess = Tesseract.from_tesseract_api(SOLVER_API_PATH) | ||
|
|
||
| torch.manual_seed(42) | ||
| closure = ViscosityNet() | ||
| u0 = _make_initial_condition() | ||
|
|
||
| with torch.no_grad(): | ||
| u = _solve_with_closure(u0, closure, solver_tess, dt=1e-4, n_steps=50) | ||
|
|
||
| print(f" Shape: {u.shape}") | ||
| print(f" Range: [{float(u.min()):.4f}, {float(u.max()):.4f}]") | ||
| assert u.shape == (N,) | ||
| assert torch.all(torch.isfinite(u)), "Solution contains NaN or Inf" | ||
| print(" PASSED") | ||
|
|
||
|
|
||
| def test_composition_gradient(): | ||
| """End-to-end gradient: loss -> solver VJP -> network weights.""" | ||
| print("\n=== End-to-end gradient (closure + solver Tesseract) ===") | ||
| solver_tess = Tesseract.from_tesseract_api(SOLVER_API_PATH) | ||
|
|
||
| torch.manual_seed(42) | ||
| closure = ViscosityNet() | ||
| u0 = _make_initial_condition() | ||
| target = 0.9 * u0 | ||
| n_steps = 20 | ||
|
|
||
| def run_forward(): | ||
| u = _solve_with_closure( | ||
| u0.clone(), closure, solver_tess, dt=1e-4, n_steps=n_steps | ||
| ) | ||
| return torch.mean((u - target) ** 2) | ||
|
|
||
| # AD gradient on one weight element of the first layer | ||
| closure.zero_grad() | ||
| loss = run_forward() | ||
| loss.backward() | ||
| w = closure.net[0].weight | ||
| idx = (0, 0) | ||
| ad_val = float(w.grad[idx]) | ||
|
|
||
| # Finite difference check on the same element | ||
| eps = 1e-5 | ||
| with torch.no_grad(): | ||
| orig = w[idx].item() | ||
| w[idx] = orig + eps | ||
| l_plus = float(run_forward()) | ||
| w[idx] = orig - eps | ||
| l_minus = float(run_forward()) | ||
| w[idx] = orig | ||
| fd = (l_plus - l_minus) / (2 * eps) | ||
|
|
||
| rel_err = abs(ad_val - fd) / (abs(fd) + 1e-30) | ||
| print(f" AD: {ad_val:.6e}, FD: {fd:.6e}, Rel error: {rel_err:.2e}") | ||
| assert rel_err < 1e-2, f"Gradient error too large: {rel_err}" | ||
| print(" PASSED") | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| test_closure_forward() | ||
| test_solver_single_step() | ||
| test_solver_gradient() | ||
| test_composition_forward() | ||
| test_composition_gradient() | ||
| print("\nAll smoke tests passed.") | ||
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
you might want a comment here about why you're not upwinding, using conservative form, ETDRK methods or anything else more fancy (e.g. low Reynolds, no shocks, role of nu_max etc.)