Skip to content
Open
Show file tree
Hide file tree
Changes from 24 commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
ad620df
add enzyme integration proof of concept
dionhaefner Mar 25, 2026
b603f3d
Merge branch 'main' into dion/enzyme-demo
dionhaefner Apr 16, 2026
f11425d
add demo notebook
dionhaefner Apr 17, 2026
271eb76
fix enzyme
dionhaefner May 15, 2026
861d6e6
Merge remote-tracking branch 'origin/main' into dion/enzyme-demo
dionhaefner May 15, 2026
9d2f490
add blog post draft
dionhaefner May 15, 2026
74bae21
checkpoint
dionhaefner Jun 1, 2026
2b463df
wip on blog post
dionhaefner Jun 2, 2026
550245a
fix: green up CI for enzyme integration PR
dionhaefner Jun 2, 2026
d8235a7
doc: add enzyme_thermal_2d demo to docs toctree
dionhaefner Jun 2, 2026
b0d955d
doc: make Tesseract-JAX the through-line in the Enzyme blog post
dionhaefner Jun 2, 2026
23046cf
execute notebook
dionhaefner Jun 2, 2026
01617f5
Merge branch 'main' into dion/enzyme-demo
dionhaefner Jun 2, 2026
32d0e12
style(docs): tighten blog code-block and table spacing
dionhaefner Jun 2, 2026
a1808c8
tighten prose + diagrams
dionhaefner Jun 2, 2026
4186d8a
iterate on diagram
dionhaefner Jun 2, 2026
f03f1e4
Merge branch 'dion/enzyme-demo' of github.com:pasteurlabs/tesseract-c…
dionhaefner Jun 2, 2026
243aa2d
Merge branch 'main' into dion/enzyme-demo
dionhaefner Jun 2, 2026
4e869db
fine-tune blog posts
dionhaefner Jun 16, 2026
4906c55
Merge branch 'main' into dion/enzyme-demo
dionhaefner Jun 16, 2026
8a0f3ae
move blog post footer to template
dionhaefner Jun 16, 2026
cf380d3
more copy-edits, add llvm snippets
dionhaefner Jun 17, 2026
1144d7a
shout out to forum showcase
dionhaefner Jun 17, 2026
00d9ee0
Merge branch 'main' into dion/enzyme-demo
dionhaefner Jun 17, 2026
bb5f185
rename demo; copy-editing based on feedback
dionhaefner Jun 29, 2026
5d4d846
add building blocks page; add demo to landing page; copyedits
dionhaefner Jun 29, 2026
01ee624
tighten example page
dionhaefner Jun 29, 2026
f223108
update date on blog post
dionhaefner Jun 29, 2026
55d4d37
Merge remote-tracking branch 'origin/main' into dion/enzyme-demo
dionhaefner Jun 29, 2026
51f480e
fix tests
dionhaefner Jun 29, 2026
9e4bded
be less strict about LLVM GPG keys
dionhaefner Jun 29, 2026
5c7e83d
no ifs and buts
dionhaefner Jun 29, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
304 changes: 304 additions & 0 deletions demo/_showcase/enzyme-optimization.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,304 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Inverse Heat Conduction via Enzyme AD\n",
"\n",
"In this tutorial, you will learn how to:\n",
"\n",
"1. **Build a Tesseract** that wraps a Fortran heat equation solver differentiated by [Enzyme](https://enzyme.mit.edu/) (an LLVM-based automatic differentiation compiler plugin)\n",
"2. **Embed it in a JAX pipeline** via [tesseract-jax](https://github.com/pasteurlabs/tesseract-jax), making it a native JAX-differentiable function\n",
"3. **Solve an inverse problem**: recover the thermal diffusivity of a material from temperature observations, using `jax.grad` to differentiate *through* the compiled Fortran solver\n",
"\n",
"## Why this matters\n",
"\n",
"Scientific computing is full of Fortran and C code that researchers need gradients of -- for optimization, inverse problems, uncertainty quantification, or integration with ML pipelines. The traditional options are painful:\n",
"\n",
"- **Hand-written adjoints**: months of expert effort, error-prone, a maintenance nightmare\n",
"- **Finite differences**: slow ($O(n)$ evaluations per gradient), inaccurate (truncation vs. roundoff tradeoff)\n",
"- **Rewrite in JAX/PyTorch**: impractical for existing codebases\n",
"\n",
"Enzyme offers a fourth option: **automatic differentiation at the LLVM IR level**. It takes compiled code and synthesizes exact derivative functions -- no source modifications, no manual adjoints, no approximation. And because it operates on LLVM IR, it works with any language that compiles to it: Fortran, C, C++, Rust, and more.\n",
"\n",
"In this demo, we differentiate a Fortran heat equation solver using Enzyme, package it as a [Tesseract](https://github.com/pasteurlabs/tesseract), and use the exact gradients to solve an inverse problem entirely within JAX."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install additional requirements for this notebook\n",
"%pip install tesseract-jax -q"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 1: Build and serve the Enzyme AD Tesseract\n",
"\n",
"The `enzyme-ad` Tesseract wraps a Fortran implementation of a single explicit Euler step of the 1D heat equation:\n",
"\n",
"$$\\frac{\\partial T}{\\partial t} = \\alpha \\frac{\\partial^2 T}{\\partial x^2}$$\n",
"\n",
"The Fortran source (`heat_step.f90`) is just ~30 lines -- a simple finite difference stencil:\n",
"\n",
"```fortran\n",
"do i = 2, n - 1\n",
" T_out(i) = T_in(i) + r * (T_in(i-1) - 2.0d0*T_in(i) + T_in(i+1))\n",
"end do\n",
"```\n",
"\n",
"During `tesseract build`, the compilation pipeline runs:\n",
"\n",
"```\n",
"Fortran --> LFortran --> LLVM IR --> Enzyme AD pass --> libheat_ad.so\n",
"```\n",
"\n",
"Enzyme analyzes the LLVM IR and generates both forward-mode (JVP) and reverse-mode (VJP) derivative functions automatically. The resulting shared library has three entry points: `heat_step_forward`, `heat_step_jvp`, and `heat_step_vjp`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%%bash\n",
"tesseract build ../../examples/enzyme_ad/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from tesseract_core import Tesseract\n",
"\n",
"heat_tesseract = Tesseract.from_image(\"enzyme-ad\")\n",
"heat_tesseract.serve()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 2: Test a forward evaluation\n",
"\n",
"Before optimizing, let's verify the Tesseract works. We set up a temperature profile (a Gaussian bump on a uniform grid) and run one heat equation step."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": "import jax\nimport jax.numpy as jnp\nimport matplotlib.pyplot as plt\nimport numpy as np\nfrom tesseract_jax import apply_tesseract\n\njax.config.update(\"jax_enable_x64\", True)\n\n# Grid setup\nn_points = 50\ndx = 1.0 / (n_points - 1)\nx = jnp.linspace(0.0, 1.0, n_points)\n\n# Initial temperature: Gaussian bump\nT_init = jnp.exp(-((x - 0.5) ** 2) / (2 * 0.05**2))\n# Fix boundary conditions to zero\nT_init = T_init.at[0].set(0.0).at[-1].set(0.0)\n\n# Physical parameters\nalpha_true = 0.02 # true thermal diffusivity\ndt = 0.0001\n\n\ndef heat_step(T_in, alpha):\n \"\"\"Run one heat equation step through the Enzyme-differentiated Fortran solver.\"\"\"\n result = apply_tesseract(\n heat_tesseract,\n {\"T_in\": T_in, \"alpha\": alpha, \"dx\": dx, \"dt\": dt},\n )\n return result[\"T_out\"]\n\n\n# Run one step\nT_after_one = heat_step(T_init, alpha_true)\n\nplt.figure(figsize=(8, 4))\nplt.plot(x, T_init, label=\"Initial\", linewidth=2)\nplt.plot(x, T_after_one, label=\"After 1 step\", linewidth=2)\nplt.xlabel(\"x\")\nplt.ylabel(\"T\")\nplt.legend()\nplt.title(\"Single heat equation step (Enzyme-differentiated Fortran solver)\")\nplt.tight_layout()"
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 3: Run multiple steps to generate synthetic observations\n",
"\n",
"For the inverse problem, we need \"observed\" data. We'll run the solver forward for many steps with the true $\\alpha$ to produce a final temperature profile, then pretend we only know this final profile and try to recover $\\alpha$."
]
},
{
"cell_type": "code",
"source": "def simulate(T_init, alpha, n_steps):\n \"\"\"Run the heat equation forward for n_steps.\"\"\"\n T = T_init\n for _ in range(n_steps):\n T = heat_step(T, alpha)\n return T\n\n\n# Generate \"observed\" data with the true alpha\nn_steps = 200\nT_observed = jax.jit(simulate, static_argnums=(2,))(T_init, alpha_true, n_steps)\n\nplt.figure(figsize=(8, 4))\nplt.plot(x, T_init, label=\"Initial condition\", linewidth=2)\nplt.plot(x, T_observed, \"k--\", label=f\"Observed (after {n_steps} steps)\", linewidth=2)\nplt.xlabel(\"x\")\nplt.ylabel(\"T\")\nplt.legend()\nplt.title(\"Forward simulation with true $\\\\alpha$\")\nplt.tight_layout()",
"metadata": {},
"execution_count": null,
"outputs": []
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Step 4: Solve the inverse problem\n\nNow the fun part. We want to recover $\\alpha$ from the observed final temperature profile. We define the loss as the mean squared error between the simulated and observed profiles:\n\n$$\\mathcal{L}(\\alpha) = \\text{MSE}\\big(T_{\\text{sim}}(\\alpha),\\; T_{\\text{obs}}\\big)$$\n\nBecause the Tesseract exposes Enzyme-generated VJPs, `jax.grad` can differentiate through the entire multi-step simulation -- computing $\\partial \\mathcal{L}/\\partial \\alpha$ by backpropagating through hundreds of Fortran solver calls. We use L-BFGS-B from `scipy.optimize` to find the optimal $\\alpha$."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def loss_fn(alpha):\n",
" \"\"\"MSE between simulated and observed temperature profiles.\"\"\"\n",
" T_sim = simulate(T_init, alpha, n_steps)\n",
" return jnp.mean((T_sim - T_observed) ** 2)\n",
"\n",
"\n",
"# Verify gradients work\n",
"grad_fn = jax.jit(jax.value_and_grad(loss_fn))\n",
"test_loss, test_grad = grad_fn(jnp.float64(0.05))\n",
"print(f\"Loss at alpha=0.05: {test_loss:.6e}\")\n",
"print(f\"Gradient dL/dalpha: {test_grad:.6e}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from scipy.optimize import minimize as scipy_minimize\n",
"\n",
"# Start from a wrong initial guess: alpha=0.05 (true value is 0.02)\n",
"alpha_init = 0.05\n",
"history = {\"alpha\": [alpha_init], \"loss\": []}\n",
"\n",
"\n",
"def objective(alpha_arr):\n",
" \"\"\"Wrapper for scipy: returns (loss, grad) as plain floats.\"\"\"\n",
" loss, grad = grad_fn(float(alpha_arr[0]))\n",
" history[\"alpha\"].append(float(alpha_arr[0]))\n",
" history[\"loss\"].append(float(loss))\n",
" return float(loss), np.array([float(grad)])\n",
"\n",
"\n",
"result = scipy_minimize(\n",
" objective,\n",
" x0=np.array([alpha_init]),\n",
" method=\"L-BFGS-B\",\n",
" jac=True,\n",
" bounds=[(1e-6, 0.1)],\n",
" options={\"maxiter\": 100},\n",
")\n",
"\n",
"alpha_recovered = result.x[0]\n",
"print(f\"Converged in {len(history['loss'])} iterations\")\n",
"print(f\"Recovered alpha: {alpha_recovered:.8f}\")\n",
"print(f\"True alpha: {alpha_true}\")\n",
"print(f\"Final loss: {result.fun:.2e}\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"fig, axes = plt.subplots(1, 3, figsize=(15, 4))\n",
"\n",
"# Loss curve\n",
"axes[0].semilogy(history[\"loss\"])\n",
"axes[0].set_xlabel(\"Iteration\")\n",
"axes[0].set_ylabel(\"Loss (MSE)\")\n",
"axes[0].set_title(\"Loss convergence\")\n",
"axes[0].grid(True, alpha=0.3)\n",
"\n",
"# Alpha convergence\n",
"axes[1].plot(history[\"alpha\"], label=\"Estimated $\\\\alpha$\")\n",
"axes[1].axhline(y=alpha_true, color=\"r\", linestyle=\"--\", label=\"True $\\\\alpha$\")\n",
"axes[1].set_xlabel(\"Iteration\")\n",
"axes[1].set_ylabel(\"$\\\\alpha$\")\n",
"axes[1].set_title(\"Parameter convergence\")\n",
"axes[1].legend()\n",
"axes[1].grid(True, alpha=0.3)\n",
"\n",
"# Temperature profiles\n",
"T_recovered = simulate(T_init, alpha_recovered, n_steps)\n",
"T_initial_guess = simulate(T_init, alpha_init, n_steps)\n",
"axes[2].plot(x, T_observed, \"k--\", label=\"Observed\", linewidth=2)\n",
"axes[2].plot(\n",
" x, T_recovered, label=f\"Recovered ($\\\\alpha$={alpha_recovered:.4f})\", linewidth=2\n",
")\n",
"axes[2].plot(\n",
" x,\n",
" T_initial_guess,\n",
" \":\",\n",
" label=f\"Initial guess ($\\\\alpha$={alpha_init})\",\n",
" linewidth=2,\n",
" alpha=0.7,\n",
")\n",
"axes[2].set_xlabel(\"x\")\n",
"axes[2].set_ylabel(\"T\")\n",
"axes[2].set_title(\"Temperature profiles\")\n",
"axes[2].legend()\n",
"axes[2].grid(True, alpha=0.3)\n",
"\n",
"plt.tight_layout()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Step 5: Verify gradient accuracy against finite differences\n",
"\n",
"To confirm that Enzyme produces exact gradients, we compare against finite difference approximations at several step sizes. Enzyme's gradients should match to machine precision, while finite differences degrade for both too-large (truncation error) and too-small (roundoff error) step sizes."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Compare Enzyme gradient vs finite differences at a non-optimal alpha\n",
"# (At the optimum the gradient is zero, which makes relative error meaningless)\n",
"alpha_test = 0.03\n",
"\n",
"_, enzyme_grad = grad_fn(alpha_test)\n",
"\n",
"# Finite difference gradients at various step sizes\n",
"epsilons = np.logspace(-2, -12, 20)\n",
"fd_grads = []\n",
"for eps in epsilons:\n",
" loss_plus = loss_fn(alpha_test + eps)\n",
" loss_minus = loss_fn(alpha_test - eps)\n",
" fd_grads.append(float((loss_plus - loss_minus) / (2 * eps)))\n",
"\n",
"fd_grads = np.array(fd_grads)\n",
"rel_errors = np.abs(fd_grads - float(enzyme_grad)) / (\n",
" np.abs(float(enzyme_grad)) + 1e-30\n",
")\n",
"\n",
"plt.figure(figsize=(8, 4))\n",
"plt.loglog(epsilons, rel_errors, \"o-\", label=\"FD vs Enzyme\")\n",
"plt.xlabel(\"Finite difference step size $\\\\epsilon$\")\n",
"plt.ylabel(\"Relative error vs. Enzyme gradient\")\n",
"plt.title(\"Enzyme provides exact gradients; finite differences have a sweet spot\")\n",
"plt.grid(True, alpha=0.3)\n",
"plt.legend()\n",
"plt.tight_layout()\n",
"\n",
"print(f\"Enzyme gradient: {float(enzyme_grad):.10e}\")\n",
"print(f\"Best FD gradient: {fd_grads[np.argmin(rel_errors)]:.10e}\")\n",
"print(f\"Best FD relative error: {rel_errors.min():.2e}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": "## Takeaways\n\n1. **Exact gradients from compiled Fortran, automatically.** Enzyme differentiated the Fortran heat solver at the LLVM IR level -- no adjoint code, no source modifications, no approximation.\n\n2. **Full JAX composability.** By packaging the solver as a Tesseract and using tesseract-jax, we called `jax.grad` through hundreds of Fortran solver invocations. L-BFGS-B converged in just a handful of iterations -- the Fortran solver is just another differentiable function.\n\n3. **Machine-precision accuracy.** The Enzyme gradients match finite differences at their best, and remain exact where finite differences break down.\n\n4. **Language-agnostic.** Enzyme works on LLVM IR, so the same approach applies to C, C++, Rust -- any language with an LLVM frontend. The Fortran example here is just the beginning.\n\n5. **Reproducible and portable.** The entire toolchain (LFortran, LLVM 19, Enzyme) is packaged in the Tesseract container. `tesseract build` produces the same differentiated binary on any machine."
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"heat_tesseract.teardown()"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.10.0"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
3 changes: 3 additions & 0 deletions demo/enzyme_thermal_2d/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
# Generated figures (notebook + pipeline-diagram outputs).
# Regenerated on demand; copied into docs/static/blog/ by hand when publishing.
figures/
Loading
Loading