Skip to content
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ Each of these is a separate repository/Python package.

- **Tesseract Core** is the main codebase that defines the Tesseract specification, the Python SDK for defining and building Tesseracts, and the runtime for executing Tesseracts in containers.
- **Tesseract-JAX** is a mature package that supports full integration of Tesseract calls into JAX programs, including JIT compilation and automatic differentiation of code that mixes Tesseract calls and JAX operations.
- **Tesseract-Torch** is the PyTorch counterpart to Tesseract-JAX: it embeds Tesseract calls as PyTorch operators so that `torch.autograd` flows through code that mixes Tesseract calls and PyTorch operations.
- **Tesseract-Streamlit** provides tools to auto-generate Streamlit apps from (externally running / locally built) Tesseracts. It can be used to quickly create interactive demos for Tesseracts and custom visualization without writing any Streamlit code, but is limited to forward application (`apply`).
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ with Tesseract.from_image("my-tesseract") as t:

- **[Tesseract Core](https://github.com/pasteurlabs/tesseract-core)** — CLI, Python SDK, and runtime (this repo).
- **[Tesseract-JAX](https://github.com/pasteurlabs/tesseract-jax)** — Embed Tesseracts as JAX primitives into end-to-end differentiable JAX programs.
- **[Tesseract-Torch](https://github.com/pasteurlabs/tesseract-torch)** — Embed Tesseracts as PyTorch operators into end-to-end differentiable PyTorch programs.
- **[Tesseract-Streamlit](https://github.com/pasteurlabs/tesseract-streamlit)** — Auto-generate interactive web apps from Tesseracts.

## Learn more
Expand Down
157 changes: 157 additions & 0 deletions demo/learned-closure/burgers_solver/tesseract_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""Single-timestep Burgers' equation solver Tesseract (PyTorch).

Solves one explicit Euler step of the 1D viscous Burgers' equation:

u^{n+1} = u^n + dt * (-u * du/dx + nu * d²u/dx²)

The viscosity field nu is provided as an input — the solver does not compute it.
This clean interface (state + material field → next state) is the same contract
that a Fortran solver with an adjoint could implement. The outer time-stepping
loop and closure evaluation live in the caller, enabling per-timestep closure
calls and end-to-end gradient flow through both solver and closure.
"""

from typing import Any

import numpy as np
import torch
from pydantic import BaseModel, Field
from torch.utils._pytree import tree_map

from tesseract_core.runtime import Array, Differentiable, Float64
from tesseract_core.runtime.tree_transforms import filter_func, flatten_with_paths

# Default grid size
N = 128

# --- Grid setup (fixed for this Tesseract) ---
DX = 1.0 / (N - 1)

to_tensor = lambda x: (
torch.tensor(x, dtype=torch.float64)
if isinstance(x, np.generic | np.ndarray)
else x
)


class InputSchema(BaseModel):
u: Differentiable[Array[(N,), Float64]] = Field(
description="Current velocity field on the grid"
)
nu: Differentiable[Array[(N,), Float64]] = Field(
description="Viscosity field at each grid point (must be positive)"
)
dt: float = Field(description="Time step size", default=1e-4)


class OutputSchema(BaseModel):
u_next: Differentiable[Array[(N,), Float64]] = Field(
description="Velocity field after one time step"
)


def evaluate(inputs: dict) -> dict:
"""Core differentiable computation — pure torch operations."""
u = inputs["u"]
nu = inputs["nu"]
dt = inputs["dt"]

# Spatial derivatives via central differences
dudx = torch.zeros_like(u)
dudx[1:-1] = (u[2:] - u[:-2]) / (2 * DX)
Comment on lines +62 to +64

@jpbrodrick89 jpbrodrick89 Jun 25, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you might want a comment here about why you're not upwinding, using conservative form, ETDRK methods or anything else more fancy (e.g. low Reynolds, no shocks, role of nu_max etc.)


d2udx2 = torch.zeros_like(u)
d2udx2[1:-1] = (u[2:] - 2 * u[1:-1] + u[:-2]) / (DX**2)

# Burgers' equation: du/dt = -u * du/dx + nu * d²u/dx²
dudt = -u * dudx + nu * d2udx2

# Forward Euler step
u_next = u + dt * dudt

# Enforce boundary conditions (Dirichlet: hold boundary values)
u_next = torch.cat([u[:1], u_next[1:-1], u[-1:]])

return {"u_next": u_next}


def apply(inputs: InputSchema) -> OutputSchema:
tensor_inputs = tree_map(to_tensor, inputs.model_dump())
return evaluate(tensor_inputs)


def abstract_eval(abstract_inputs: Any) -> Any:
return {"u_next": {"shape": [N], "dtype": "float64"}}


def jacobian_vector_product(
inputs: InputSchema,
jvp_inputs: set[str],
jvp_outputs: set[str],
tangent_vector: dict[str, Any],
):
jvp_inputs = tuple(jvp_inputs)
tangent_vector = {key: tangent_vector[key] for key in jvp_inputs}

tensor_inputs = tree_map(to_tensor, inputs.model_dump())
pos_tangent = tree_map(to_tensor, tangent_vector).values()
pos_inputs = flatten_with_paths(tensor_inputs, jvp_inputs).values()

filtered_pos_eval = filter_func(
evaluate, tensor_inputs, jvp_outputs, input_paths=jvp_inputs
)

return torch.func.jvp(filtered_pos_eval, tuple(pos_inputs), tuple(pos_tangent))[1]


def vector_jacobian_product(
inputs: InputSchema,
vjp_inputs: set[str],
vjp_outputs: set[str],
cotangent_vector: dict[str, Any],
):
vjp_inputs = tuple(vjp_inputs)
cotangent_vector = {key: cotangent_vector[key] for key in vjp_outputs}

tensor_inputs = tree_map(to_tensor, inputs.model_dump())
tensor_cotangent = tree_map(to_tensor, cotangent_vector)
pos_inputs = flatten_with_paths(tensor_inputs, vjp_inputs).values()

filtered_pos_func = filter_func(
evaluate, tensor_inputs, vjp_outputs, input_paths=vjp_inputs
)

_, vjp_func = torch.func.vjp(filtered_pos_func, *pos_inputs)
vjp_vals = vjp_func(tensor_cotangent)
return dict(zip(vjp_inputs, vjp_vals, strict=True))


def jacobian(
inputs: InputSchema,
jac_inputs: set[str],
jac_outputs: set[str],
):
jac_inputs = tuple(jac_inputs)
tensor_inputs = tree_map(to_tensor, inputs.model_dump())
pos_inputs = flatten_with_paths(tensor_inputs, jac_inputs).values()

filtered_pos_eval = filter_func(
evaluate, tensor_inputs, jac_outputs, input_paths=jac_inputs
)

def filtered_pos_eval_flat(*args):
res = filtered_pos_eval(*args)
return tuple(res[k] for k in jac_outputs)

jac = torch.autograd.functional.jacobian(filtered_pos_eval_flat, tuple(pos_inputs))

jac_dict = {}
for dy, dys in zip(jac_outputs, jac, strict=True):
jac_dict[dy] = {}
for dx, dxs in zip(jac_inputs, dys, strict=True):
jac_dict[dy][dx] = dxs

return jac_dict
6 changes: 6 additions & 0 deletions demo/learned-closure/burgers_solver/tesseract_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
name: "burgers-solver"
version: "0.1.0"
description: "1D Burgers equation solver with pluggable neural viscosity closure (PyTorch)"

build_config:
target_platform: "native"
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
torch
tesseract-core
842 changes: 842 additions & 0 deletions demo/learned-closure/demo.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions demo/learned-closure/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
torch
matplotlib
tesseract-core
tesseract-torch
201 changes: 201 additions & 0 deletions demo/learned-closure/test_solvers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""Smoke tests for the learned closure demo (PyTorch version).

Tests the composition pattern: a plain torch.nn closure predicts a viscosity
field, then the solver Tesseract steps the Burgers' equation forward as a
differentiable layer. Gradients flow end-to-end from the loss, through the
solver's VJP (via apply_tesseract / torch.autograd), into the network weights.

These tests load the solver via ``Tesseract.from_tesseract_api`` (in-process, no
Docker) so they run fast as a local smoke check. The demo notebook itself uses
``Tesseract.from_image`` to serve the solver in a container over HTTP — the same
``apply_tesseract`` call path works either way. This is also the same pattern
that would work with a Fortran solver Tesseract backed by Enzyme or a
hand-written adjoint: the solver just needs apply + VJP with the interface
(u, nu_field, dt) -> u_next. The closure stays ordinary PyTorch.
Comment on lines +8 to +14

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't mind it (especially if this is just a private test file not part of the demo itself), but just wanted to note there's a strong AI smell about this paragraph.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do mind, thanks for catching it.

"""

import sys

sys.path.insert(0, "burgers_solver")

import burgers_solver.tesseract_api as solver_api
import numpy as np
import torch
import torch.nn as nn
from tesseract_torch import apply_tesseract

from tesseract_core import Tesseract

torch.set_default_dtype(torch.float64)

SOLVER_API_PATH = "burgers_solver/tesseract_api.py"

N = 128
DX = 1.0 / (N - 1)
X_GRID = torch.linspace(0.0, 1.0, N)


class ViscosityNet(nn.Module):
"""MLP closure: local flow features (u, du/dx, x) -> viscosity nu."""

def __init__(self, hidden_dim=32, nu_max=0.05):
super().__init__()
self.nu_max = nu_max
self.net = nn.Sequential(
nn.Linear(3, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 1),
)

def forward(self, u, dudx, x):
features = torch.stack([u, dudx, x], dim=-1)
out = self.net(features)[:, 0]
return self.nu_max * torch.sigmoid(out)


def _make_initial_condition():
"""Smooth initial condition: a sine wave."""
return torch.sin(2 * np.pi * X_GRID)


def test_closure_forward():
print("=== Neural viscosity closure forward pass ===")
torch.manual_seed(0)
closure = ViscosityNet()
u0 = _make_initial_condition()
dudx = torch.gradient(u0, spacing=(DX,))[0]

with torch.no_grad():
nu = closure(u0, dudx, X_GRID)

print(f" Shape: {nu.shape}, range: [{float(nu.min()):.4f}, {float(nu.max()):.4f}]")
assert nu.shape == (N,)
assert torch.all(nu > 0), "Viscosity must be positive"
print(" PASSED")


def test_solver_single_step():
print("\n=== Solver single timestep ===")
u0 = _make_initial_condition()
nu = torch.full((N,), 0.01)
dt = 1e-4

inputs = solver_api.InputSchema(u=u0, nu=nu, dt=dt)
out = solver_api.apply(inputs)
u_next = out["u_next"]

print(f" Shape: {u_next.shape}")
print(f" Max change: {float(torch.max(torch.abs(u_next - u0))):.6e}")
assert u_next.shape == (N,)
assert torch.all(torch.isfinite(u_next)), "Solution contains NaN or Inf"
# Boundary values should be preserved
assert float(u_next[0]) == float(u0[0]), "Left BC violated"
assert float(u_next[-1]) == float(u0[-1]), "Right BC violated"
print(" PASSED")


def test_solver_gradient():
print("\n=== Solver gradient (VJP w.r.t. nu field) ===")
u0 = _make_initial_condition()
nu = torch.full((N,), 0.01, requires_grad=True)
dt = 1e-4

tensor_inputs = {
"u": u0.clone(),
"nu": nu,
"dt": torch.tensor(dt),
}
out = solver_api.evaluate(tensor_inputs)
loss = torch.mean(out["u_next"] ** 2)
loss.backward()

grad_nu = nu.grad
print(
f" Gradient shape: {grad_nu.shape}, norm: {float(torch.linalg.norm(grad_nu)):.6e}"
)
assert grad_nu.shape == (N,)
assert torch.all(torch.isfinite(grad_nu))
print(" PASSED")


def _solve_with_closure(u0, closure, solver_tess, dt, n_steps):
u = u0
for _step in range(n_steps):
dudx = torch.zeros_like(u)
dudx[1:-1] = (u[2:] - u[:-2]) / (2 * DX)
nu = closure(u, dudx, X_GRID)
solver_out = apply_tesseract(solver_tess, {"u": u, "nu": nu, "dt": dt})
u = solver_out["u_next"]
return u


def test_composition_forward():
"""Outer loop: plain torch closure + solver Tesseract via apply_tesseract."""
print("\n=== Composed forward pass (closure + solver Tesseract) ===")
solver_tess = Tesseract.from_tesseract_api(SOLVER_API_PATH)

torch.manual_seed(42)
closure = ViscosityNet()
u0 = _make_initial_condition()

with torch.no_grad():
u = _solve_with_closure(u0, closure, solver_tess, dt=1e-4, n_steps=50)

print(f" Shape: {u.shape}")
print(f" Range: [{float(u.min()):.4f}, {float(u.max()):.4f}]")
assert u.shape == (N,)
assert torch.all(torch.isfinite(u)), "Solution contains NaN or Inf"
print(" PASSED")


def test_composition_gradient():
"""End-to-end gradient: loss -> solver VJP -> network weights."""
print("\n=== End-to-end gradient (closure + solver Tesseract) ===")
solver_tess = Tesseract.from_tesseract_api(SOLVER_API_PATH)

torch.manual_seed(42)
closure = ViscosityNet()
u0 = _make_initial_condition()
target = 0.9 * u0
n_steps = 20

def run_forward():
u = _solve_with_closure(
u0.clone(), closure, solver_tess, dt=1e-4, n_steps=n_steps
)
return torch.mean((u - target) ** 2)

# AD gradient on one weight element of the first layer
closure.zero_grad()
loss = run_forward()
loss.backward()
w = closure.net[0].weight
idx = (0, 0)
ad_val = float(w.grad[idx])

# Finite difference check on the same element
eps = 1e-5
with torch.no_grad():
orig = w[idx].item()
w[idx] = orig + eps
l_plus = float(run_forward())
w[idx] = orig - eps
l_minus = float(run_forward())
w[idx] = orig
fd = (l_plus - l_minus) / (2 * eps)

rel_err = abs(ad_val - fd) / (abs(fd) + 1e-30)
print(f" AD: {ad_val:.6e}, FD: {fd:.6e}, Rel error: {rel_err:.2e}")
assert rel_err < 1e-2, f"Gradient error too large: {rel_err}"
print(" PASSED")


if __name__ == "__main__":
test_closure_forward()
test_solver_single_step()
test_solver_gradient()
test_composition_forward()
test_composition_gradient()
print("\nAll smoke tests passed.")
Loading
Loading