Sequential Monte Carlo and particle filtering in JAX.
smcjax is a JAX implementation of the methods developed in my master's thesis on sequential inference for Hidden Markov Models (University of Arkansas, 2018). It extends Dynamax and BlackJAX with particle filters and Bayesian workflow diagnostics. Every filter runs its time loop with `jax.lax.scan`, so it JIT-compiles and runs on GPU.
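The block below is a toy bootstrap filter written directly in JAX. It shows what a scan-based particle filter looks like. It is a sketch under stated assumptions, not smcjax's implementation. `toy_bootstrap_filter` and `fast_filter` are hypothetical names. The model is a hard-coded scalar linear Gaussian model with a = 0.9, q² = 0.25 and r = 1. The filter resamples multinomially at every step instead of conditionally.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp
from jax.scipy.stats import norm


def toy_bootstrap_filter(key, emissions, num_particles=1_000, a=0.9, q=0.5, r=1.0):
    """Hypothetical sketch of a bootstrap particle filter for the scalar model
    x_0 ~ N(0, 1),  x_t = a * x_{t-1} + q * eps_t,  y_t = x_t + r * nu_t,  t = 1..T.

    Returns (log marginal likelihood estimate, filtered means of shape (T,)).
    """
    key, init_key = jr.split(key)
    particles = jr.normal(init_key, (num_particles,))  # draws from p(x_0)

    def step(carry, inputs):
        particles, log_ml = carry
        step_key, y = inputs
        prop_key, res_key = jr.split(step_key)
        # 1. Propagate through the transition density (the bootstrap proposal).
        particles = a * particles + q * jr.normal(prop_key, particles.shape)
        # 2. Weight each particle by the observation likelihood p(y_t | x_t).
        log_w = norm.logpdf(y, loc=particles, scale=r)
        # 3. Estimate log p(y_t | y_{1:t-1}) by the log of the mean weight.
        log_ml = log_ml + logsumexp(log_w) - jnp.log(num_particles)
        w = jnp.exp(log_w - logsumexp(log_w))
        filtered_mean = jnp.sum(w * particles)
        # 4. Multinomial resampling at every step (no ESS check in this sketch).
        idx = jr.choice(res_key, num_particles, shape=(num_particles,), p=w)
        return (particles[idx], log_ml), filtered_mean

    step_keys = jr.split(key, emissions.shape[0])
    init_carry = (particles, jnp.zeros(()))
    (_, log_ml), means = jax.lax.scan(step, init_carry, (step_keys, emissions))
    return log_ml, means


# num_particles fixes array shapes, so it must be static under jit.
fast_filter = jax.jit(toy_bootstrap_filter, static_argnames="num_particles")
```

The whole time loop is a single `lax.scan`. JAX therefore traces the step function once and compiles the filter into one XLA program, without unrolling it over T.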
- Bootstrap (SIR) particle filter — Gordon et al. (1993)
- Auxiliary particle filter — Pitt & Shephard (1999)
- Liu-West filter — joint state-parameter estimation via kernel density smoothing (Liu & West, 2001)
- Forward simulation — generate trajectories from state-space models
- Diagnostics — weighted mean/variance/quantiles, parameter summaries, ESS traces, particle diversity, per-step log evidence increments, replicated log-ML, log Bayes factors, CRPS
- 4 resampling schemes (via BlackJAX): systematic, stratified, multinomial, residual
- Conditional resampling with configurable ESS threshold
- All functions are `jit`- and `vmap`-compatible
- Type annotations via jaxtyping
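The block below sketches two ideas from the list above in plain JAX: systematic resampling, and conditional resampling triggered by an ESS threshold. smcjax delegates resampling to BlackJAX, so this code is not the library's code path. The helpers `ess`, `systematic_resample` and `maybe_resample` are hypothetical names. The default threshold of 0.5 × N is an assumed value.

```python
import jax
import jax.numpy as jnp
import jax.random as jr
from jax.scipy.special import logsumexp


def ess(log_weights):
    """Effective sample size 1 / sum_i w_i^2 of the normalized weights."""
    w = jnp.exp(log_weights - logsumexp(log_weights))
    return 1.0 / jnp.sum(w**2)


def systematic_resample(key, log_weights):
    """Systematic resampling: one uniform offset shared by N evenly spaced points."""
    n = log_weights.shape[0]
    w = jnp.exp(log_weights - logsumexp(log_weights))
    positions = (jr.uniform(key) + jnp.arange(n)) / n
    # For each position, pick the first particle whose cumulative weight reaches it;
    # the clip guards against round-off leaving the last cumsum slightly below 1.
    idx = jnp.searchsorted(jnp.cumsum(w), positions)
    return jnp.clip(idx, 0, n - 1)


def maybe_resample(key, particles, log_weights, threshold=0.5):
    """Conditional resampling: resample only when ESS < threshold * N."""
    n = log_weights.shape[0]

    def resample(p, lw):
        idx = systematic_resample(key, lw)
        # After resampling, every particle carries equal weight 1/N.
        return p[idx], jnp.full(n, -jnp.log(n), dtype=lw.dtype)

    def keep(p, lw):
        return p, lw

    return jax.lax.cond(
        ess(log_weights) < threshold * n, resample, keep, particles, log_weights
    )
```

`jax.lax.cond` keeps the decision inside the traced program. The same function therefore works under `jit` and `vmap`. Under `vmap`, JAX evaluates both branches and selects the result per batch element.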
Requirements:

- Python 3.10 or later
- uv (for installing from source and for development)
Install the released package with pip:

```bash
pip install smcjax
```

Or from source:
```bash
git clone https://github.com/michaelellis003/smcjax.git
cd smcjax
uv sync
```

Install the pre-commit hooks (one-time setup):
```bash
uv run pre-commit install
uv run pre-commit install --hook-type commit-msg
uv run pre-commit install --hook-type pre-push
```

Quick start: run the bootstrap particle filter on a 1-D linear Gaussian state-space model and compute a few diagnostics:

```python
import jax.numpy as jnp
import jax.random as jr
import jax.scipy.stats as jstats

from smcjax import bootstrap_filter, weighted_mean, log_ml_increments

# Define a 1-D linear Gaussian state-space model:
#   x_0 ~ N(m0, P0),  x_t = F x_{t-1} + N(0, Q),  y_t = H x_t + N(0, R)
m0, P0 = jnp.array([0.0]), jnp.array([[1.0]])
F, Q = jnp.array([[0.9]]), jnp.array([[0.25]])
H, R = jnp.array([[1.0]]), jnp.array([[1.0]])
chol_P0 = jnp.linalg.cholesky(P0)
chol_Q = jnp.linalg.cholesky(Q)


def initial_sampler(key, n):
    # Draw n initial particles; returns shape (n, 1).
    return m0 + jr.normal(key, (n, 1)) @ chol_P0.T


def transition_sampler(key, state):
    # Sample the next state for a single particle (state has shape (1,)).
    mean = (F @ state[:, None]).squeeze(-1)
    return mean + jr.normal(key, (1,)) @ chol_Q.T


def log_observation_fn(emission, state):
    # Log-likelihood of one emission given one particle's state.
    mean = (H @ state[:, None]).squeeze(-1)
    return jstats.multivariate_normal.logpdf(emission, mean, R)


# Placeholder data: i.i.d. standard normal draws (not simulated from the model)
key = jr.PRNGKey(0)
T = 100
emissions = jr.normal(key, (T, 1))

# Run the bootstrap particle filter
posterior = bootstrap_filter(
    key=jr.PRNGKey(1),
    initial_sampler=initial_sampler,
    transition_sampler=transition_sampler,
    log_observation_fn=log_observation_fn,
    emissions=emissions,
    num_particles=1_000,
)

print(f"Log marginal likelihood: {posterior.marginal_loglik:.2f}")
print(f"Particles shape: {posterior.filtered_particles.shape}")
print(f"Mean ESS: {posterior.ess.mean():.1f}")

# Diagnostics
means = weighted_mean(posterior)
increments = log_ml_increments(posterior)
```

A Makefile collects the common development tasks:
```bash
make test        # lint + pytest
make lint        # ruff check, format check, license headers, ty
make format      # add license headers, ruff format, ruff fix
make license     # add missing license headers
make docs        # build documentation
make serve-docs  # serve documentation locally
make install     # uv sync
make clean       # git clean (preserves .venv)
```

Releases are fully automated. When a commit lands on main and CI passes, python-semantic-release inspects the commit history to determine whether a version bump is warranted:
- `fix: ...` produces a patch release
- `feat: ...` produces a minor release
- A `BREAKING CHANGE` footer or a `!` after the commit type (e.g. `feat!: ...`) produces a major release
Licensed under Apache-2.0. See LICENSE for the full text.