Draft
Changes from 10 commits
4 changes: 4 additions & 0 deletions NEWS.md
Expand Up @@ -17,6 +17,10 @@
- Added a model overview vignette with an architecture diagram showing how the package's models connect.
- Added a model features vignette providing a quick reference to all modelling options with links to detailed documentation.

## Model changes

- The cumulated non-stationary Gaussian process used to model Rt over time is now mean-centred (`gp -= mean(gp)` in `inst/stan/functions/rt.stan`), so `log R0` is the mean of log Rt over the trajectory rather than the initial value. This eliminates the `(R0, drift)` ridge in the joint posterior that was responsible for stuck chains and catastrophic R-hat values on some seeds. No API change and no change to the `alpha` prior. Verified across previously stuck seeds: R-hat drops from as high as 6.10 to <1.01, and treedepth hits drop from hundreds to zero.

## Bug fixes

- Fixed a bug in `forecast_infections()` where the summary call to extract dates was using modified args instead of the original fit dimensions, causing a date-dimension mismatch when extending the R trajectory beyond the original observation period.
4 changes: 4 additions & 0 deletions inst/stan/functions/rt.stan
Expand Up @@ -48,6 +48,10 @@ vector update_Rt(int t, real R0, vector noise, array[] int bps,
} else {
gp[2:(gp_n + 1)] = noise;
gp = cumulative_sum(gp);
// Identifiability: subtract the trajectory mean so log R0 is the mean
// log Rt over the window rather than the initial value. Eliminates
// the (R0, drift) ridge in the joint posterior.
gp -= mean(gp);
Contributor

this is fine but it changes the interpretation of the prior passed to rt_opts() -- this needs to be reflected in the documentation in various places (also the model description vignette I think), and should probably have an argument rename from init and a deprecation cycle.

Contributor

It also makes it much harder to set. I wonder if we can keep it with its current definition but also keep this change? A pre-Stan transform of some kind?

Collaborator Author

Addressed in be45c1b: rt_opts() prior doc and the three estimate_infections vignettes (workflow, options, math description) now reflect that with the default non-stationary GP, R0 is the mean Rt over the observation window rather than the initial Rt. Stationary GP and no-GP cases keep the initial-Rt meaning.

On the argument rename + deprecation: the current argument is prior (not init) and is generic enough that I don't think the name itself needs to change; the meaning shift is captured in the docs. Happy to do a rename (e.g. prior → r_mean_prior with lifecycle::deprecate_warn()) if you'd prefer, I just want to confirm the new name before doing the cycle. Could you point me at what you had in mind?

Contributor

we might not need a rename but if we change the interpretation of the parameter we need, at the minimum, a warning to anyone who sets this to something other than the default.

What about @seabbs's comment? It's often easier to set a prior for initial R than mean R. Could this behaviour be recovered under the updated cumsum centering?

Collaborator Author

Added the warning in 0055390: rt_opts() now emits a cli_warn() whenever a user supplies a non-default prior while gp_on = "R_t-1" (the default), explicitly noting the mean-Rt interpretation. Stationary GP (gp_on = "R0") and no-GP paths don't trigger it.

On @seabbs's question about recovering the initial-Rt semantics with a pre-stan transform: I worked through the algebra and it doesn't separate cleanly. The geometric improvement from centring comes from imposing mean(gp) = 0. If you also impose gp[1] = 0 (so that R0 = R[1]), then gp_for_logR = gp_centred - gp_centred[1] = cumsum(noise) — the centring cancels out and you're back to the original (bad-geometry) parameterisation. Any single linear constraint on the trajectory removes one degree of freedom from the (R0, drift) ridge; choosing mean(gp) = 0 and choosing gp[1] = 0 are two different choices of the same kind of constraint, and they produce the same likelihood but a different identification for R0.

So the only way to keep the initial-Rt interpretation is the original parameterisation, which has the ridge. We have to pick one. The warning is the honest minimum here.
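
A minimal base-R sketch of the algebra above, using the same inputs as the updated non-stationary GP unit test. The helper names `cumulate()`, `centre()` and `anchor()` are hypothetical and exist only for this illustration; they are not package functions.

```r
# Hypothetical helpers for illustration only; not part of the package.

# Cumulated GP path as built in update_Rt(): a leading zero followed by
# the running sum of the noise terms.
cumulate <- function(noise) cumsum(c(0, noise))

# Constraint used in this PR: mean(gp) = 0.
centre <- function(gp) gp - mean(gp)

# Constraint that would give R0 = R[1]: gp[1] = 0.
anchor <- function(gp) gp - gp[1]

noise <- rep(0.1, 9)
gp <- cumulate(noise)
gp_centred <- centre(gp)

# Re-anchoring the centred path undoes the centring exactly.
stopifnot(isTRUE(all.equal(anchor(gp_centred), gp)))

# The two paths differ only by a constant shift, so the same Rt trajectory
# is reached with a different R0:
# log(R0_initial) = log(R0_mean) + gp_centred[1].
R0_mean <- 1.2
Rt <- R0_mean * exp(gp_centred)
R0_initial <- R0_mean * exp(gp_centred[1])
stopifnot(isTRUE(all.equal(R0_initial * exp(gp), Rt)))

round(Rt, 3)
```

The rounded trajectory matches the expected values in the updated test, and the second check shows that both constraints describe the same Rt path with R0 attached to a different point on it.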

Collaborator Author

Better answer landed in d5f9081: the user-facing interpretation is preserved (rt_opts(prior=...) is once again the prior on the initial Rt). No warning, no rename, no doc shift.

Implementation: R0 now leaves the generic params vector. R_mean (the trajectory mean) is sampled internally — that gets the geometry win from centring. The user's prior is applied to the derived initial Rt R[1] via a new centred_gp_init_lpdf Stan helper, with a Jacobian correction (just a +log(R[1]) term) so that the resulting joint prior over (R[1], noise, R[t]) matches main exactly. Change of variables, determinant 1.

The plumbing (init_param_ids, init_dists, init_dist_params data items + pack_init_prior() R helper + centred_gp_init_lpdf Stan helper) is intentionally parameter-agnostic — designed as a prototype for the time-varying-parameter framework in #600. Today R0 is the only entry; adding a time-varying alpha later is a one-line dispatch branch in Stan plus appending to the data items.

Verified end-to-end on the previously catastrophic seed=8 (R-hat=6.10 on main): td=0, div=1, R-hat=1.005, ESS=663. R[1] posterior mean=2.29 (user prior on initial Rt → posterior on R[1]); R_mean[1] posterior mean=1.00 (data-determined trajectory mean). The two quantities are distinct and the user's prior lands where intended.
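
One way to read the change of variables described above, written out as a sketch. It assumes no breakpoints, that the sampler works with log R_mean, and that the user's prior density p is specified on R[1]; none of this is confirmed by the thread. The symbols ε (noise), g (cumulated path) and T (number of time points) are notation introduced here for illustration.

```latex
\begin{aligned}
g_t &= \sum_{k=1}^{t-1} \epsilon_k, \qquad
\bar g = \frac{1}{T} \sum_{t=1}^{T} g_t, \\
\log R_t &= \log R_{\mathrm{mean}} + g_t - \bar g, \\
\log R_1 &= \log R_{\mathrm{mean}} - \bar g
  \quad (\text{since } g_1 = 0), \\
\left| \frac{\partial (\log R_1, \epsilon)}
            {\partial (\log R_{\mathrm{mean}}, \epsilon)} \right| &= 1, \\
\log \pi(\log R_{\mathrm{mean}} \mid \epsilon) &= \log p(R_1) + \log R_1 .
\end{aligned}
```

Under these assumptions the map from (log R_mean, ε) to (log R_1, ε) is a shift with unit determinant, and the only extra term is the log R_1 that comes from moving the prior on R_1 to the log scale.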

}
logR = logR + gp;
}
16 changes: 12 additions & 4 deletions tests/testthat/test-stan-rt.R
Expand Up @@ -9,9 +9,14 @@ test_that("update_Rt works to produce multiple Rt estimates with a static gaussi
)
})
test_that("update_Rt works to produce multiple Rt estimates with a non-static gaussian process", {
# Non-stationary GP: cumulated trajectory is centred so log R0 = mean log Rt
# over the window rather than the initial value (eliminates the (R0, drift)
# ridge in the joint posterior). For noise = rep(0.1, 9), gp_n = 9:
# gp = cumsum(c(0, noise)) = c(0, 0.1, 0.2, ..., 0.9), mean = 0.45,
# centred = c(-0.45, -0.35, ..., 0.45). log Rt = log(1.2) + centred.
expect_equal(
round(update_Rt(10, 1.2, rep(0.1, 9), rep(10, 0), numeric(0), 0), 2),
c(1.20, 1.33, 1.47, 1.62, 1.79, 1.98, 2.19, 2.42, 2.67, 2.95)
round(update_Rt(10, 1.2, rep(0.1, 9), rep(10, 0), numeric(0), 0), 3),
c(0.765, 0.846, 0.935, 1.033, 1.141, 1.262, 1.394, 1.541, 1.703, 1.882)
)
})
test_that("update_Rt works to produce multiple Rt estimates with a non-static stationary gaussian process", {
Expand Up @@ -53,9 +58,12 @@ test_that("update_Rt works when Rt is variable and a breakpoint is present", {
round(update_Rt(5, 1.2, rep(0, 5), c(1, 1, 2, 2, 2), 0.1, 1), 2),
c(1.2, 1.2, rep(1.33, 3))
)
# Non-stationary GP: see explanation in the earlier non-static GP test.
# Here gp_n = 4, gp_centred = c(-0.2, -0.1, 0, 0.1, 0.2), breakpoint adds
# 0.1 from t = 3 onward.
expect_equal(
round(update_Rt(5, 1.2, rep(0.1, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2),
c(1.20, 1.33, 1.62, 1.79, 1.98)
round(update_Rt(5, 1.2, rep(0.1, 4), c(1, 1, 2, 2, 2), 0.1, 0), 3),
c(0.982, 1.086, 1.326, 1.466, 1.620)
)
})
