Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
198 changes: 198 additions & 0 deletions demo/bayesian-inference/_test_e2e.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
#!/usr/bin/env python
"""End-to-end test of the Bayesian inference demo notebook logic."""

import time

import jax
import jax.numpy as jnp
import numpy as np
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, NUTS, SA
from tesseract_jax import apply_tesseract

from tesseract_core import Tesseract

# Reduce sample counts for a faster test
NUM_WARMUP = 100
NUM_SAMPLES = 200

# ── Step 1: Serve the JAX Lorenz Tesseract ──────────────────────────────
print("=== Step 1: Serving JAX Lorenz Tesseract ===")
lorenz = Tesseract.from_image("lorenz-bayesian")
lorenz.serve()
print(f"Available endpoints: {lorenz.available_endpoints}")

# ── Step 2: Generate synthetic observations ─────────────────────────────
print("\n=== Step 2: Generating synthetic observations ===")
data = np.load("lorenz96_two_scale_F_18_sample_0_small.npz")
X_states = data["X_states"]
true_trajectory = X_states[500:]
X0 = true_trajectory[0]
x0_jax = jnp.array(X0, dtype=jnp.float32)

OBS_GAP = 10
N_OBS = 3
STD_OBS = 0.5
TRUE_F = 18.0
N_STEPS = OBS_GAP * N_OBS

# Generate self-consistent observations from the model itself
true_result = apply_tesseract(
lorenz,
{"state": x0_jax, "F": jnp.float32(TRUE_F), "dt": 0.005, "n_steps": N_STEPS},
)
true_traj = true_result["result"]
obs_indices = jnp.arange(OBS_GAP - 1, N_STEPS, OBS_GAP)
true_obs = true_traj[obs_indices]

key = jax.random.PRNGKey(42)
observations = true_obs + STD_OBS * jax.random.normal(key, true_obs.shape)
print(f"Observations shape: {observations.shape}")

# ── Step 3: Test apply_tesseract works ──────────────────────────────────
print("\n=== Step 3: Testing apply_tesseract ===")
test_result = apply_tesseract(
lorenz,
{"state": x0_jax, "F": jnp.float32(TRUE_F), "dt": 0.005, "n_steps": N_STEPS},
)
print(f"apply_tesseract output keys: {list(test_result.keys())}")
print(f"result shape: {test_result['result'].shape}")

# ── Step 4: Test jax.grad flows through ─────────────────────────────────
print("\n=== Step 4: Testing jax.grad through Tesseract ===")


def loss_fn(F_val):
result = apply_tesseract(
lorenz,
{"state": x0_jax, "F": F_val, "dt": 0.005, "n_steps": OBS_GAP},
)
return jnp.sum(result["result"] ** 2)


grad_F = jax.grad(loss_fn)(jnp.float32(18.0))
print(f"grad w.r.t. F: {grad_F}")
assert not jnp.isnan(grad_F), "Gradient is NaN!"
print("Gradient OK")

# ── Step 5: Define NumPyro model and run NUTS ───────────────────────────
print("\n=== Step 5: Running NUTS ===")


def bayesian_lorenz_model(observations, x0, obs_gap, n_obs, std_obs):
F = numpyro.sample("F", dist.Normal(15.0, 5.0))
result = apply_tesseract(
lorenz,
{"state": x0, "F": F, "dt": 0.005, "n_steps": obs_gap * n_obs},
)
trajectory = result["result"]
obs_idx = jnp.arange(obs_gap - 1, obs_gap * n_obs, obs_gap)
predicted_obs = trajectory[obs_idx]
numpyro.sample("obs", dist.Normal(predicted_obs, std_obs), obs=observations)


nuts_kernel = NUTS(bayesian_lorenz_model)
mcmc_nuts = MCMC(
nuts_kernel, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES, num_chains=1
)

start = time.time()
mcmc_nuts.run(
jax.random.PRNGKey(0),
observations=observations,
x0=x0_jax,
obs_gap=OBS_GAP,
n_obs=N_OBS,
std_obs=STD_OBS,
)
nuts_time = time.time() - start
mcmc_nuts.print_summary()

nuts_samples = mcmc_nuts.get_samples()
F_mean = float(nuts_samples["F"].mean())
F_std = float(nuts_samples["F"].std())
print(
f"\nNUTS: F = {F_mean:.2f} ± {F_std:.2f} (true: {TRUE_F}), time: {nuts_time:.1f}s"
)
assert abs(F_mean - TRUE_F) < 5.0, f"NUTS posterior mean too far from truth: {F_mean}"
print("NUTS posterior check PASSED")

# ── Step 6: Test finite-diff Tesseract ──────────────────────────────────
print("\n=== Step 6: Serving finite-diff Lorenz Tesseract ===")
lorenz_fd = Tesseract.from_image("lorenz-finitediff")
lorenz_fd.serve()
print(f"Available endpoints: {lorenz_fd.available_endpoints}")


def bayesian_lorenz_model_fd(observations, x0, obs_gap, n_obs, std_obs):
F = numpyro.sample("F", dist.Normal(15.0, 5.0))
result = apply_tesseract(
lorenz_fd,
{"state": x0, "F": F, "dt": 0.005, "n_steps": obs_gap * n_obs},
)
trajectory = result["result"]
obs_idx = jnp.arange(obs_gap - 1, obs_gap * n_obs, obs_gap)
predicted_obs = trajectory[obs_idx]
numpyro.sample("obs", dist.Normal(predicted_obs, std_obs), obs=observations)


nuts_fd_kernel = NUTS(bayesian_lorenz_model_fd)
mcmc_nuts_fd = MCMC(
nuts_fd_kernel, num_warmup=NUM_WARMUP, num_samples=NUM_SAMPLES, num_chains=1
)

start = time.time()
mcmc_nuts_fd.run(
jax.random.PRNGKey(0),
observations=observations,
x0=x0_jax,
obs_gap=OBS_GAP,
n_obs=N_OBS,
std_obs=STD_OBS,
)
nuts_fd_time = time.time() - start
mcmc_nuts_fd.print_summary()

fd_samples = mcmc_nuts_fd.get_samples()
F_mean_fd = float(fd_samples["F"].mean())
print(
f"\nNUTS (FD): F = {F_mean_fd:.2f} ± {float(fd_samples['F'].std()):.2f} (true: {TRUE_F}), time: {nuts_fd_time:.1f}s"
)
assert abs(F_mean_fd - TRUE_F) < 5.0, (
f"NUTS-FD posterior mean too far from truth: {F_mean_fd}"
)
print("NUTS-FD posterior check PASSED")

# ── Step 7: Gradient-free baseline ──────────────────────────────────────
print("\n=== Step 7: Running SA (gradient-free) ===")
sa_kernel = SA(bayesian_lorenz_model)
mcmc_sa = MCMC(
sa_kernel, num_warmup=NUM_WARMUP * 5, num_samples=NUM_SAMPLES, num_chains=1
)

start = time.time()
mcmc_sa.run(
jax.random.PRNGKey(0),
observations=observations,
x0=x0_jax,
obs_gap=OBS_GAP,
n_obs=N_OBS,
std_obs=STD_OBS,
)
sa_time = time.time() - start
mcmc_sa.print_summary()

sa_samples = mcmc_sa.get_samples()
F_mean_sa = float(sa_samples["F"].mean())
print(
f"\nSA: F = {F_mean_sa:.2f} ± {float(sa_samples['F'].std()):.2f} (true: {TRUE_F}), time: {sa_time:.1f}s"
)
# SA is expected to perform poorly — no assertion on accuracy

# ── Cleanup ─────────────────────────────────────────────────────────────
print("\n=== Cleanup ===")
lorenz.teardown()
lorenz_fd.teardown()

print("\n=== ALL STEPS PASSED ===")
758 changes: 758 additions & 0 deletions demo/bayesian-inference/demo.ipynb

Large diffs are not rendered by default.

Binary file not shown.
186 changes: 186 additions & 0 deletions demo/bayesian-inference/lorenz_tesseract/tesseract_api.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
# Copyright 2025 Pasteur Labs. All Rights Reserved.
# SPDX-License-Identifier: Apache-2.0

"""Lorenz 96 Tesseract with differentiable forcing parameter F.

This is a variant of the Lorenz 96 Tesseract from the 4D-Var demo, modified
to mark the forcing parameter F as Differentiable. This allows JAX autodiff
(and thus NumPyro's NUTS) to compute gradients w.r.t. F for Bayesian
parameter inference.
"""

from typing import Any

import equinox as eqx
import jax
import jax.numpy as jnp
from pydantic import BaseModel, Field

from tesseract_core.runtime import Array, Differentiable, Float32
from tesseract_core.runtime.tree_transforms import filter_func, flatten_with_paths


class InputSchema(BaseModel):
"""Input schema for forecasting of Lorenz 96 system."""

state: Differentiable[Array[(None,), Float32]] = Field(
description="A state vector for the Lorenz 96 system"
)
F: Differentiable[Array[(), Float32]] = Field(
description="Forcing parameter for Lorenz 96", default=8.0
)
dt: float = Field(description="Time step for integration", default=0.05)
n_steps: int = Field(description="Number of integration steps", default=1)


class OutputSchema(BaseModel):
"""Output schema for forecasting of Lorenz 96 system."""

result: Differentiable[Array[(None, None), Float32]] = Field(
description="A trajectory of predictions after integration"
)


def lorenz96_step(state: jnp.ndarray, F: float, dt: float) -> jnp.ndarray:
"""Perform one step of RK4 integration for the Lorenz 96 system."""

def lorenz96_derivatives(x: jnp.ndarray) -> jnp.ndarray:
N = x.shape[0]
ip1 = (jnp.arange(N) + 1) % N
im1 = (jnp.arange(N) - 1) % N
im2 = (jnp.arange(N) - 2) % N
d = (x[ip1] - x[im2]) * x[im1] - x + F
return d

k1 = lorenz96_derivatives(state)
k2 = lorenz96_derivatives(state + dt * k1 / 2)
k3 = lorenz96_derivatives(state + dt * k2 / 2)
k4 = lorenz96_derivatives(state + dt * k3)
return state + dt * (k1 + 2 * k2 + 2 * k3 + k4) / 6


def lorenz96_multi_step(
state: jnp.ndarray, F: float, dt: float, n_steps: int
) -> jnp.ndarray:
"""Perform multiple steps of Lorenz 96 integration using scan."""

def step_fn(state: jnp.ndarray, _: Any) -> tuple:
return lorenz96_step(state, F, dt), state

_, trajectory = jax.lax.scan(step_fn, state, None, length=n_steps)
return trajectory


@eqx.filter_jit
def apply_jit(inputs: dict) -> dict:
trajectory = lorenz96_multi_step(**inputs)
return dict(result=trajectory)


def apply(inputs: InputSchema) -> OutputSchema:
return apply_jit(inputs.model_dump())


def abstract_eval(abstract_inputs: Any) -> Any:
"""Calculate output shape of apply from the shape of its inputs."""
is_shapedtype_dict = lambda x: type(x) is dict and (x.keys() == {"shape", "dtype"})
is_shapedtype_struct = lambda x: isinstance(x, jax.ShapeDtypeStruct)

jaxified_inputs = jax.tree.map(
lambda x: jax.ShapeDtypeStruct(**x) if is_shapedtype_dict(x) else x,
abstract_inputs.model_dump(),
is_leaf=is_shapedtype_dict,
)
dynamic_inputs, static_inputs = eqx.partition(
jaxified_inputs, filter_spec=is_shapedtype_struct
)

def wrapped_apply(dynamic_inputs: Any) -> Any:
inputs = eqx.combine(static_inputs, dynamic_inputs)
return apply_jit(inputs)

jax_shapes = jax.eval_shape(wrapped_apply, dynamic_inputs)
return jax.tree.map(
lambda x: (
{"shape": x.shape, "dtype": str(x.dtype)} if is_shapedtype_struct(x) else x
),
jax_shapes,
is_leaf=is_shapedtype_struct,
)


@eqx.filter_jit
def jac_jit(
inputs: dict,
jac_inputs: tuple[str],
jac_outputs: tuple[str],
) -> dict:
filtered_apply = filter_func(apply_jit, inputs, jac_outputs)
return jax.jacrev(filtered_apply)(
flatten_with_paths(inputs, include_paths=jac_inputs)
)


def jacobian(
inputs: InputSchema,
jac_inputs: set[str],
jac_outputs: set[str],
) -> Any:
return jac_jit(inputs.model_dump(), tuple(jac_inputs), tuple(jac_outputs))


@eqx.filter_jit
def jvp_jit(
inputs: dict,
jvp_inputs: tuple[str],
jvp_outputs: tuple[str],
tangent_vector: dict,
) -> Any:
filtered_apply = filter_func(apply_jit, inputs, jvp_outputs)
return jax.jvp(
filtered_apply,
[flatten_with_paths(inputs, include_paths=jvp_inputs)],
[tangent_vector],
)[1]


def jacobian_vector_product(
inputs: InputSchema,
jvp_inputs: set[str],
jvp_outputs: set[str],
tangent_vector: dict[str, Any],
) -> Any:
return jvp_jit(
inputs.model_dump(),
tuple(jvp_inputs),
tuple(jvp_outputs),
tangent_vector,
)


@eqx.filter_jit
def vjp_jit(
inputs: dict,
vjp_inputs: tuple[str],
vjp_outputs: tuple[str],
cotangent_vector: dict,
) -> Any:
filtered_apply = filter_func(apply_jit, inputs, vjp_outputs)
_, vjp_func = jax.vjp(
filtered_apply, flatten_with_paths(inputs, include_paths=vjp_inputs)
)
return vjp_func(cotangent_vector)[0]


def vector_jacobian_product(
inputs: InputSchema,
vjp_inputs: set[str],
vjp_outputs: set[str],
cotangent_vector: dict[str, Any],
) -> Any:
return vjp_jit(
inputs.model_dump(),
tuple(vjp_inputs),
tuple(vjp_outputs),
cotangent_vector,
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
# Tesseract configuration file

name: "lorenz-bayesian"
version: "0.1.0"
description: "Lorenz 96 Tesseract with differentiable forcing parameter for Bayesian inference"

build_config:
target_platform: "native"
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# Tesseract requirements file
numpy==1.26.0
jax[cpu]==0.5.2
equinox
Loading
Loading