Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AGENTS.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,5 @@ Each of these is a separate repository/Python package.

- **Tesseract Core** is the main codebase that defines the Tesseract specification, the Python SDK for defining and building Tesseracts, and the runtime for executing Tesseracts in containers.
- **Tesseract-JAX** is a mature package that supports full integration of Tesseract calls into JAX programs, including JIT compilation and automatic differentiation of code that mixes Tesseract calls and JAX operations.
- **Tesseract-Torch** is the PyTorch counterpart to Tesseract-JAX: it embeds Tesseract calls as PyTorch operators so that `torch.autograd` flows through code that mixes Tesseract calls and PyTorch operations.
- **Tesseract-Streamlit** provides tools to auto-generate Streamlit apps from (externally running / locally built) Tesseracts. It can be used to quickly create interactive demos for Tesseracts and custom visualization without writing any Streamlit code, but is limited to forward application (`apply`).
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ with Tesseract.from_image("my-tesseract") as t:

- **[Tesseract Core](https://github.com/pasteurlabs/tesseract-core)** — CLI, Python SDK, and runtime (this repo).
- **[Tesseract-JAX](https://github.com/pasteurlabs/tesseract-jax)** — Embed Tesseracts as JAX primitives into end-to-end differentiable JAX programs.
- **[Tesseract-Torch](https://github.com/pasteurlabs/tesseract-torch)** — Embed Tesseracts as PyTorch operators into end-to-end differentiable PyTorch programs.
- **[Tesseract-Streamlit](https://github.com/pasteurlabs/tesseract-streamlit)** — Auto-generate interactive web apps from Tesseracts.

## Learn more
Expand Down
157 changes: 157 additions & 0 deletions demo/learned-closure/burgers_solver/tesseract_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""Single-timestep Burgers' equation solver Tesseract (PyTorch).

Solves one explicit Euler step of the 1D viscous Burgers' equation:

u^{n+1} = u^n + dt * (-u * du/dx + nu * d²u/dx²)

The viscosity field nu is provided as an input — the solver does not compute it.
This clean interface (state + material field → next state) is the same contract
that a Fortran solver with an adjoint could implement. The outer time-stepping
loop and closure evaluation live in the caller, enabling per-timestep closure
calls and end-to-end gradient flow through both solver and closure.
"""

from typing import Any

import numpy as np
import torch
from pydantic import BaseModel, Field
from torch.utils._pytree import tree_map

from tesseract_core.runtime import Array, Differentiable, Float64
from tesseract_core.runtime.tree_transforms import filter_func, flatten_with_paths

# Default grid size
N = 128

# --- Grid setup (fixed for this Tesseract) ---
DX = 1.0 / (N - 1)

to_tensor = lambda x: (
torch.tensor(x, dtype=torch.float64)
if isinstance(x, np.generic | np.ndarray)
else x
)


class InputSchema(BaseModel):
u: Differentiable[Array[(N,), Float64]] = Field(
description="Current velocity field on the grid"
)
nu: Differentiable[Array[(N,), Float64]] = Field(
description="Viscosity field at each grid point (must be positive)"
)
dt: float = Field(description="Time step size", default=1e-4)


class OutputSchema(BaseModel):
u_next: Differentiable[Array[(N,), Float64]] = Field(
description="Velocity field after one time step"
)


def evaluate(inputs: dict) -> dict:
"""Core differentiable computation — pure torch operations."""
u = inputs["u"]
nu = inputs["nu"]
dt = inputs["dt"]

# Spatial derivatives via central differences
dudx = torch.zeros_like(u)
dudx[1:-1] = (u[2:] - u[:-2]) / (2 * DX)
Comment on lines +62 to +64

@jpbrodrick89 jpbrodrick89 Jun 25, 2026

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

you might want a comment here about why you're not upwinding, using conservative form, ETDRK methods or anything else more fancy (e.g. low Reynolds, no shocks, role of nu_max etc.)


d2udx2 = torch.zeros_like(u)
d2udx2[1:-1] = (u[2:] - 2 * u[1:-1] + u[:-2]) / (DX**2)

# Burgers' equation: du/dt = -u * du/dx + nu * d²u/dx²
dudt = -u * dudx + nu * d2udx2

# Forward Euler step
u_next = u + dt * dudt

# Enforce boundary conditions (Dirichlet: hold boundary values)
u_next = torch.cat([u[:1], u_next[1:-1], u[-1:]])

return {"u_next": u_next}


def apply(inputs: InputSchema) -> OutputSchema:
tensor_inputs = tree_map(to_tensor, inputs.model_dump())
return evaluate(tensor_inputs)


def abstract_eval(abstract_inputs: Any) -> Any:
return {"u_next": {"shape": [N], "dtype": "float64"}}


def jacobian_vector_product(
inputs: InputSchema,
jvp_inputs: set[str],
jvp_outputs: set[str],
tangent_vector: dict[str, Any],
):
jvp_inputs = tuple(jvp_inputs)
tangent_vector = {key: tangent_vector[key] for key in jvp_inputs}

tensor_inputs = tree_map(to_tensor, inputs.model_dump())
pos_tangent = tree_map(to_tensor, tangent_vector).values()
pos_inputs = flatten_with_paths(tensor_inputs, jvp_inputs).values()

filtered_pos_eval = filter_func(
evaluate, tensor_inputs, jvp_outputs, input_paths=jvp_inputs
)

return torch.func.jvp(filtered_pos_eval, tuple(pos_inputs), tuple(pos_tangent))[1]


def vector_jacobian_product(
inputs: InputSchema,
vjp_inputs: set[str],
vjp_outputs: set[str],
cotangent_vector: dict[str, Any],
):
vjp_inputs = tuple(vjp_inputs)
cotangent_vector = {key: cotangent_vector[key] for key in vjp_outputs}

tensor_inputs = tree_map(to_tensor, inputs.model_dump())
tensor_cotangent = tree_map(to_tensor, cotangent_vector)
pos_inputs = flatten_with_paths(tensor_inputs, vjp_inputs).values()

filtered_pos_func = filter_func(
evaluate, tensor_inputs, vjp_outputs, input_paths=vjp_inputs
)

_, vjp_func = torch.func.vjp(filtered_pos_func, *pos_inputs)
vjp_vals = vjp_func(tensor_cotangent)
return dict(zip(vjp_inputs, vjp_vals, strict=True))


def jacobian(
inputs: InputSchema,
jac_inputs: set[str],
jac_outputs: set[str],
):
jac_inputs = tuple(jac_inputs)
tensor_inputs = tree_map(to_tensor, inputs.model_dump())
pos_inputs = flatten_with_paths(tensor_inputs, jac_inputs).values()

filtered_pos_eval = filter_func(
evaluate, tensor_inputs, jac_outputs, input_paths=jac_inputs
)

def filtered_pos_eval_flat(*args):
res = filtered_pos_eval(*args)
return tuple(res[k] for k in jac_outputs)

jac = torch.autograd.functional.jacobian(filtered_pos_eval_flat, tuple(pos_inputs))

jac_dict = {}
for dy, dys in zip(jac_outputs, jac, strict=True):
jac_dict[dy] = {}
for dx, dxs in zip(jac_inputs, dys, strict=True):
jac_dict[dy][dx] = dxs

return jac_dict
6 changes: 6 additions & 0 deletions demo/learned-closure/burgers_solver/tesseract_config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
name: "burgers-solver"
version: "0.1.0"
description: "1D Burgers equation solver with pluggable neural viscosity closure (PyTorch)"

build_config:
target_platform: "native"
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
torch
tesseract-core
842 changes: 842 additions & 0 deletions demo/learned-closure/demo.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions demo/learned-closure/requirements.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
torch
matplotlib
tesseract-core
tesseract-torch
201 changes: 201 additions & 0 deletions demo/learned-closure/test_solvers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
"""Smoke tests for the learned closure demo (PyTorch version).

Tests the composition pattern: a plain torch.nn closure predicts a viscosity
field, then the solver Tesseract steps the Burgers' equation forward as a
differentiable layer. Gradients flow end-to-end from the loss, through the
solver's VJP (via apply_tesseract / torch.autograd), into the network weights.

These tests load the solver via ``Tesseract.from_tesseract_api`` (in-process, no
Docker) so they run fast as a local smoke check. The demo notebook itself uses
``Tesseract.from_image`` to serve the solver in a container over HTTP — the same
``apply_tesseract`` call path works either way. This is also the same pattern
that would work with a Fortran solver Tesseract backed by Enzyme or a
hand-written adjoint: the solver just needs apply + VJP with the interface
(u, nu_field, dt) -> u_next. The closure stays ordinary PyTorch.
Comment on lines +8 to +14

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't mind it (especially if this is just a private test file not part of the demo itself), but just wanted to note there's a strong AI smell about this paragraph.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I do mind, thanks for catching it.

"""

import sys

sys.path.insert(0, "burgers_solver")

import burgers_solver.tesseract_api as solver_api
import numpy as np
import torch
import torch.nn as nn
from tesseract_torch import apply_tesseract

from tesseract_core import Tesseract

torch.set_default_dtype(torch.float64)

SOLVER_API_PATH = "burgers_solver/tesseract_api.py"

N = 128
DX = 1.0 / (N - 1)
X_GRID = torch.linspace(0.0, 1.0, N)


class ViscosityNet(nn.Module):
"""MLP closure: local flow features (u, du/dx, x) -> viscosity nu."""

def __init__(self, hidden_dim=32, nu_max=0.05):
super().__init__()
self.nu_max = nu_max
self.net = nn.Sequential(
nn.Linear(3, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, hidden_dim),
nn.Tanh(),
nn.Linear(hidden_dim, 1),
)

def forward(self, u, dudx, x):
features = torch.stack([u, dudx, x], dim=-1)
out = self.net(features)[:, 0]
return self.nu_max * torch.sigmoid(out)


def _make_initial_condition():
"""Smooth initial condition: a sine wave."""
return torch.sin(2 * np.pi * X_GRID)


def test_closure_forward():
print("=== Neural viscosity closure forward pass ===")
torch.manual_seed(0)
closure = ViscosityNet()
u0 = _make_initial_condition()
dudx = torch.gradient(u0, spacing=(DX,))[0]

with torch.no_grad():
nu = closure(u0, dudx, X_GRID)

print(f" Shape: {nu.shape}, range: [{float(nu.min()):.4f}, {float(nu.max()):.4f}]")
assert nu.shape == (N,)
assert torch.all(nu > 0), "Viscosity must be positive"
print(" PASSED")


def test_solver_single_step():
print("\n=== Solver single timestep ===")
u0 = _make_initial_condition()
nu = torch.full((N,), 0.01)
dt = 1e-4

inputs = solver_api.InputSchema(u=u0, nu=nu, dt=dt)
out = solver_api.apply(inputs)
u_next = out["u_next"]

print(f" Shape: {u_next.shape}")
print(f" Max change: {float(torch.max(torch.abs(u_next - u0))):.6e}")
assert u_next.shape == (N,)
assert torch.all(torch.isfinite(u_next)), "Solution contains NaN or Inf"
# Boundary values should be preserved
assert float(u_next[0]) == float(u0[0]), "Left BC violated"
assert float(u_next[-1]) == float(u0[-1]), "Right BC violated"
print(" PASSED")


def test_solver_gradient():
print("\n=== Solver gradient (VJP w.r.t. nu field) ===")
u0 = _make_initial_condition()
nu = torch.full((N,), 0.01, requires_grad=True)
dt = 1e-4

tensor_inputs = {
"u": u0.clone(),
"nu": nu,
"dt": torch.tensor(dt),
}
out = solver_api.evaluate(tensor_inputs)
loss = torch.mean(out["u_next"] ** 2)
loss.backward()

grad_nu = nu.grad
print(
f" Gradient shape: {grad_nu.shape}, norm: {float(torch.linalg.norm(grad_nu)):.6e}"
)
assert grad_nu.shape == (N,)
assert torch.all(torch.isfinite(grad_nu))
print(" PASSED")


def _solve_with_closure(u0, closure, solver_tess, dt, n_steps):
u = u0
for _step in range(n_steps):
dudx = torch.zeros_like(u)
dudx[1:-1] = (u[2:] - u[:-2]) / (2 * DX)
nu = closure(u, dudx, X_GRID)
solver_out = apply_tesseract(solver_tess, {"u": u, "nu": nu, "dt": dt})
u = solver_out["u_next"]
return u


def test_composition_forward():
"""Outer loop: plain torch closure + solver Tesseract via apply_tesseract."""
print("\n=== Composed forward pass (closure + solver Tesseract) ===")
solver_tess = Tesseract.from_tesseract_api(SOLVER_API_PATH)

torch.manual_seed(42)
closure = ViscosityNet()
u0 = _make_initial_condition()

with torch.no_grad():
u = _solve_with_closure(u0, closure, solver_tess, dt=1e-4, n_steps=50)

print(f" Shape: {u.shape}")
print(f" Range: [{float(u.min()):.4f}, {float(u.max()):.4f}]")
assert u.shape == (N,)
assert torch.all(torch.isfinite(u)), "Solution contains NaN or Inf"
print(" PASSED")


def test_composition_gradient():
"""End-to-end gradient: loss -> solver VJP -> network weights."""
print("\n=== End-to-end gradient (closure + solver Tesseract) ===")
solver_tess = Tesseract.from_tesseract_api(SOLVER_API_PATH)

torch.manual_seed(42)
closure = ViscosityNet()
u0 = _make_initial_condition()
target = 0.9 * u0
n_steps = 20

def run_forward():
u = _solve_with_closure(
u0.clone(), closure, solver_tess, dt=1e-4, n_steps=n_steps
)
return torch.mean((u - target) ** 2)

# AD gradient on one weight element of the first layer
closure.zero_grad()
loss = run_forward()
loss.backward()
w = closure.net[0].weight
idx = (0, 0)
ad_val = float(w.grad[idx])

# Finite difference check on the same element
eps = 1e-5
with torch.no_grad():
orig = w[idx].item()
w[idx] = orig + eps
l_plus = float(run_forward())
w[idx] = orig - eps
l_minus = float(run_forward())
w[idx] = orig
fd = (l_plus - l_minus) / (2 * eps)

rel_err = abs(ad_val - fd) / (abs(fd) + 1e-30)
print(f" AD: {ad_val:.6e}, FD: {fd:.6e}, Rel error: {rel_err:.2e}")
assert rel_err < 1e-2, f"Gradient error too large: {rel_err}"
print(" PASSED")


if __name__ == "__main__":
test_closure_forward()
test_solver_single_step()
test_solver_gradient()
test_composition_forward()
test_composition_gradient()
print("\nAll smoke tests passed.")
Loading
Loading