Skip to content

Add ConvergentFilter#676

Open
jessegrabowski wants to merge 17 commits intomainfrom
convergant-scan
Open

Add ConvergentFilter#676
jessegrabowski wants to merge 17 commits intomainfrom
convergant-scan

Conversation

@jessegrabowski
Copy link
Copy Markdown
Member

@jessegrabowski jessegrabowski commented Apr 23, 2026

Closes #394
Related to #332

For certain classes of models, the latent state covariance matrix estimated by the Kalman Filter converges to a fixed point, after which is is no longer necessary to compute via expensive linear algebra. This PR adds a new filter, ConvergentFIlter, that exploits this property. To take advantage, a model must:

  • Have no missing values, and
  • No time varying matrices

The ConvergentFilter uses a while scan, so this cannot be used with jax. All timings are done with numba.

These break the convergence assumptions. But nothing else is required. If these criteria are met, we get big speedups. Here is a table of models, where n is the number of timesteps, m is the number of hidden states, p is the number of observed states, and r is the number of shocks. The biggest speedups come from long time series, where there is more iterations in the compute-saving post-convergence scan. Here are some timings for the compiled logp_dlogp function for models taken from the example notebooks:

Model Shape (m, p, r) n StandardFilter ConvergentFilter Speedup
VARMAX(2,0) 5-obs (10, 5, 5) 200 3.4 1.3 2.5x
VARMAX(2,0) 5-obs (10, 5, 5) 1000 17.2 5.1 3.4x
VARMAX(2,0) 5-obs (10, 5, 5) 5000 96.0 26.6 3.6x
SARIMAX(1,0,1) (2, 1, 1) 200 1.1 0.6 1.9x
SARIMAX(1,0,1) (2, 1, 1) 1000 5.4 2.1 2.6x
SARIMAX(1,0,1) (2, 1, 1) 5000 27.9 10.1 2.8x
SARIMAX(2,1,2)(2,0,2,12) (28, 1, 1) 200 5.7 3.1 1.8x
SARIMAX(2,1,2)(2,0,2,12) (28, 1, 1) 1000 30.7 14.8 2.1x
SARIMAX(2,1,2)(2,0,2,12) (28, 1, 1) 5000 153.7 63.9 2.4x
ETS(A,Ad,A,12) (15, 1, 1) 200 3.4 1.8 1.9x
ETS(A,Ad,A,12) (15, 1, 1) 1000 17.6 6.4 2.8x
ETS(A,Ad,A,12) (15, 1, 1) 5000 90.7 33.6 2.7x
DFM 1-factor 2-lag 4-obs (10, 4, 5) 200 2.9 1.5 1.9x
DFM 1-factor 2-lag 4-obs (10, 4, 5) 1000 15.6 4.9 3.2x
DFM 1-factor 2-lag 4-obs (10, 4, 5) 5000 78.3 24.8 3.2x
Structural LvlTrend+FreqSeas (airpass) (14, 1, 14) 200 2.5 1.5 1.7x
Structural LvlTrend+FreqSeas (airpass) (14, 1, 14) 1000 12.8 4.6 2.8x
Structural LvlTrend+FreqSeas (airpass) (14, 1, 14) 5000 66.5 23.3 2.9x

Special thanks to @JeanVanDyk, who derived components of the closed forms used in the handoff step between the two scans. I tried using the full analytic forms he worked out end-to-end, but it was slower than autodiff, I believe because of the number of matmuls involved. It might be faster on GPU, so I want to revist that again at some point.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

enhancements New feature or request statespace

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Use Kalman Filter convergence to better handle long time series

3 participants