Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@
- Added a model overview vignette with an architecture diagram showing how the package's models connect.
- Added a model features vignette providing a quick reference to all modelling options with links to detailed documentation.

## Model changes

- The non-stationary Gaussian process used to model Rt over time now samples in a mean-centred parameterisation internally (`gp -= mean(gp)` in `inst/stan/functions/rt.stan`), eliminating the `(R0, drift)` ridge in the joint posterior that caused stuck chains and catastrophic R-hat values on some seeds. The user-facing interpretation is unchanged: the `prior` argument in `rt_opts()` is still the prior on the initial Rt. This is achieved by sampling the trajectory mean internally and applying the user prior to the derived initial Rt via a new generic `centred_gp_init_lpdf` Stan helper, with Jacobian-correct change of variables (so the joint prior matches the pre-change model). The plumbing (`init_dists`, `init_dist_params`, `pack_init_prior()`, `centred_gp_init_lpdf`) is parameter-agnostic, designed as a prototype for the general time-varying-parameter framework planned for issue #600.

## Bug fixes

- Fixed a bug in `forecast_infections()` where the summary call to extract dates was using modified args instead of the original fit dimensions, causing a date-dimension mismatch when extending the R trajectory beyond the original observation period.
Expand Down
25 changes: 24 additions & 1 deletion R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -181,10 +181,12 @@ estimate_infections <- function(data,
# Add initial zeroes
model_data <- pad_reported_cases(model_data, seeding_time)

# R0 is handled separately from the generic params system: it is wrapped
# by the centred non-stationary GP, so its user-facing prior is on the
# initial Rt (R[1]) rather than on the sampled internal log-mean.
params <- list(
make_param("alpha", gp$alpha, lower_bound = 0),
make_param("rho", gp$ls, lower_bound = 0),
make_param("R0", rt$prior, lower_bound = 0),
make_param("fraction_observed", obs$scale, lower_bound = 0),
make_param("reporting_overdispersion", obs$dispersion, lower_bound = 0),
make_param("pop", rt$pop, lower_bound = 0)
Expand All @@ -202,6 +204,27 @@ estimate_infections <- function(data,
params = params
)

# Register R0 as the (single, for now) centred-GP-wrapped parameter with
# its user prior applied to the derived initial Rt. The dispatch in
# `inst/stan/estimate_infections.stan` (the `init lp` profile block) uses
# `param_id_R0` to know which trajectory's initial value to apply the
# prior to. Generic plumbing: additional time-varying parameters drop in
# alongside R0 by appending to `init_param_ids` / `init_dists` /
# `init_dist_params` and adding one dispatch branch in stan.
stan_data$param_id_R0 <- stan_data$n_params_variable + 1L
if (isTRUE(rt$use_rt)) {
init_R <- pack_init_prior(rt$prior)
stan_data$n_init_priors <- 1L
stan_data$init_param_ids <- array(stan_data$param_id_R0)
stan_data$init_dists <- array(init_R$dist_type)
stan_data$init_dist_params <- array(init_R$params)
} else {
stan_data$n_init_priors <- 0L
stan_data$init_param_ids <- array(integer(0))
stan_data$init_dists <- array(integer(0))
stan_data$init_dist_params <- array(numeric(0))
}

stan_data <- c(stan_data, create_stan_delays(
generation_time = generation_time,
reporting = delays,
Expand Down
1 change: 0 additions & 1 deletion R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -341,7 +341,6 @@ rt_opts <- function(prior = LogNormal(mean = 1, sd = 1),
pop_floor = pop_floor,
growth_method = arg_match(growth_method)
)

# replace default settings with those specified by user
if (opts$rw > 0) {
opts$use_breakpoints <- TRUE
Expand Down
8 changes: 8 additions & 0 deletions R/simulate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,14 @@ simulate_infections <- function(R,
## set empty params matrix - variable parameters not supported here
stan_data$params <- array(dim = c(1, 0))

## init priors: not used in forward simulation (no model fitting). Provide
## empty defaults so the shared estimate_infections_params.stan data block
## is satisfied.
stan_data$n_init_priors <- 0L
stan_data$init_param_ids <- array(integer(0))
stan_data$init_dists <- array(integer(0))
stan_data$init_dist_params <- array(numeric(0))

## day of week effect
if (is.null(day_of_week_effect)) {
day_of_week_effect <- rep(1, stan_data$week_effect)
Expand Down
37 changes: 37 additions & 0 deletions R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -481,6 +481,43 @@ make_param <- function(name, dist = NULL, lower_bound = -Inf) {
params
}

##' Pack a dist_spec into stan-side init-prior fields
##'
##' For parameters wrapped in a centred non-stationary GP (today: R0), the
##' user-facing prior in `rt_opts()` is on the initial value of the trajectory
##' rather than on the sampled internal mean. This helper converts a
##' `<dist_spec>` into the integer/vector representation consumed by the
##' generic init-prior loop in `inst/stan/estimate_infections.stan` (data
##' items `init_dists` and `init_dist_params`).
##'
##' Generic over the parameter — designed to lift directly into a future
##' general time-varying-parameter framework (issue #600).
##'
##' @param dist A `<dist_spec>` (LogNormal, Gamma, or Normal).
##' @return A list with elements `dist_type` (integer code: 0 = lognormal,
##' 1 = gamma, 2 = normal) and `params` (numeric of length 2 with the
##' distribution parameters).
##' @keywords internal
pack_init_prior <- function(dist) {
dist_name <- get_distribution(dist)
dist_type <- switch(dist_name,
lognormal = 0L,
gamma = 1L,
normal = 2L,
cli_abort(c(
"!" = "Init prior distribution {.val {dist_name}} not supported.",
"i" = "Use {.fn LogNormal}, {.fn Gamma}, or {.fn Normal}."
))
)
params <- unlist(get_parameters(dist))
if (length(params) != 2) {
cli_abort(
"Init prior must have exactly 2 distribution parameters; got {length(params)}." # nolint
)
}
list(dist_type = dist_type, params = as.numeric(params))
}

#' @importFrom stats glm median na.omit pexp pgamma plnorm quasipoisson rexp
#' @importFrom stats rlnorm rnorm rpois runif sd var rgamma pnorm
globalVariables(
Expand Down
9 changes: 9 additions & 0 deletions inst/stan/data/estimate_infections_params.stan
Original file line number Diff line number Diff line change
Expand Up @@ -4,3 +4,12 @@ int<lower = 0> param_id_R0; // parameter id of R0
int<lower = 0> param_id_fraction_observed; // parameter id of fraction_observed
int<lower = 0> param_id_reporting_overdispersion; // parameter id of reporting_overdispersion
int<lower = 0> param_id_pop; // parameter id of pop

// Init priors for centred-GP-wrapped parameters. Today only R0 is wrapped
// (its prior in rt_opts() is on the *initial* Rt, applied to the derived
// R[1] from the centred GP). Generic shape so additional time-varying
// parameters can be added without changing the data plumbing.
int<lower = 0> n_init_priors;
array[n_init_priors] int<lower = 1> init_param_ids;
array[n_init_priors] int<lower = 0, upper = 2> init_dists;
vector[2 * n_init_priors] init_dist_params;
36 changes: 31 additions & 5 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,10 @@ parameters {
vector<lower = params_lower, upper = params_upper>[n_params_variable] params;
// gaussian process
vector[fixed ? 0 : gp_type == 1 ? 2*M : M] eta; // unconstrained noise
// Rt
// Rt — mean log Rt over the window (sampled internally). User prior is on
// R[1] (initial Rt) and is applied via centred_gp_init_lpdf in the model
// block, with R[1] derived from R_mean + gp_centred[1].
array[estimate_r] real<lower = 0> R_mean;
array[estimate_r] real initial_infections; // seed infections
// standard deviation of breakpoint effect
array[bp_n > 0 ? 1 : 0] real<lower = 0> bp_sd;
Expand Down Expand Up @@ -106,11 +109,11 @@ transformed parameters {
);
}
profile("R0") {
real R0 = get_param(
param_id_R0, params_fixed_lookup, params_variable_lookup, params_value, params
);
// R_mean is sampled directly (centred GP scaffolding). The user prior
// lives on R[1] (initial Rt), applied in the model block via the
// init_priors plumbing.
R = update_Rt(
ot_h, R0, noise, breakpoints, bp_effects, stationary
ot_h, R_mean[1], noise, breakpoints, bp_effects, stationary
);
}
profile("infections") {
Expand Down Expand Up @@ -231,6 +234,29 @@ model {
}
}

// Apply user priors on the initial value of any centred-GP-wrapped
// parameter (today: R0 -> R[1]). Generic loop so new time-varying
// parameters drop in via a one-line dispatch addition below.
if (n_init_priors > 0) {
profile("init lp") {
for (i in 1:n_init_priors) {
real init_value;
int pid = init_param_ids[i];
if (pid == param_id_R0) {
init_value = R[1];
} else {
reject("no time-varying parameter registered for id ", pid);
}
target += centred_gp_init_lpdf(
init_value |
init_dists[i],
init_dist_params[2 * i - 1],
init_dist_params[2 * i]
);
}
}
}

// observed reports from mean of reports (update likelihood)
if (likelihood) {
profile("report lp") {
Expand Down
33 changes: 33 additions & 0 deletions inst/stan/functions/params.stan
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,36 @@ void params_lp(vector params, array[] int prior_dist,
}
}

/**
* Apply user prior on the initial value of a centred-GP-wrapped parameter.
*
* When a parameter is wrapped by a centred non-stationary GP, the user-facing
* prior is on the initial value of the trajectory (X[1]) rather than on the
* sampled internal log-mean. This helper applies the prior to the derived
* initial value with the Jacobian correction for the linear-shift transform
* from log-mean to log-initial.
*
* Generic over the parameter — used by R0 today, lifts to any future
* time-varying parameter (alpha, dispersion, ...) via the same call.
*
* @param init_value Derived initial value of the trajectory (e.g. R[1]).
* @param dist_type Prior distribution type (0 = lognormal, 1 = gamma, 2 = normal).
* @param p1 First distribution parameter.
* @param p2 Second distribution parameter.
* @return Log-density contribution to the target, Jacobian included.
*
* @ingroup parameter_handlers
*/
real centred_gp_init_lpdf(real init_value, int dist_type, real p1, real p2) {
if (dist_type == 0) {
return lognormal_lpdf(init_value | p1, p2) + log(init_value);
} else if (dist_type == 1) {
return gamma_lpdf(init_value | p1, p2) + log(init_value);
} else if (dist_type == 2) {
return normal_lpdf(init_value | p1, p2) + log(init_value);
} else {
reject("centred_gp_init_lpdf: dist_type must be 0, 1, or 2");
}
return 0;
}

4 changes: 4 additions & 0 deletions inst/stan/functions/rt.stan
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,10 @@ vector update_Rt(int t, real R0, vector noise, array[] int bps,
} else {
gp[2:(gp_n + 1)] = noise;
gp = cumulative_sum(gp);
// Identifiability: subtract the trajectory mean so log R0 is the mean
// log Rt over the window rather than the initial value. Eliminates
// the (R0, drift) ridge in the joint posterior.
gp -= mean(gp);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is fine but it changes the interpretation of the prior passed to rt_opts() -- this needs to be reflected in the documentation in various places (also the model description vignette I think), and should probably have an argument rename from init and a deprecation cycle.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It also makes it much hard to set. I wonder if we can keep it with its current defintion but also keep this change? A prestan transform of some kind?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Addressed in be45c1b: rt_opts() prior doc and the three estimate_infections vignettes (workflow, options, math description) now reflect that with the default non-stationary GP, R0 is the mean Rt over the observation window rather than the initial Rt. Stationary GP and no-GP cases keep the initial-Rt meaning.

On the argument rename + deprecation: the current argument is prior (not init) and is generic enough that I don't think the name itself needs to change — the meaning shift is captured in the docs. Happy to do a rename (e.g. priorr_mean_prior with lifecycle::deprecate_warn()) if you'd prefer, just want to confirm the new name before doing the cycle. Could you point me at what you had in mind?

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

we might not need a rename but if we change the interpretation of the parameter we need, at the minimum, a warning to anyone who sets this to something other than the default.

What about @seabbs's comment? It's often easier to set a prior for initial R than mean R. Could this behaviour be recovered under the updated cumsum centering?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added the warning in 0055390: rt_opts() now emits a cli_warn() whenever a user supplies a non-default prior while gp_on = "R_t-1" (the default), explicitly noting the mean-Rt interpretation. Stationary GP (gp_on = "R0") and no-GP paths don't trigger it.

On @seabbs's question about recovering the initial-Rt semantics with a pre-stan transform: I worked through the algebra and it doesn't separate cleanly. The geometric improvement from centring comes from imposing mean(gp) = 0. If you also impose gp[1] = 0 (so that R0 = R[1]), then gp_for_logR = gp_centred - gp_centred[1] = cumsum(noise) — the centring cancels out and you're back to the original (bad-geometry) parameterisation. Any single linear constraint on the trajectory removes one degree of freedom from the (R0, drift) ridge; choosing mean(gp) = 0 and choosing gp[1] = 0 are two different choices of the same kind of constraint, and they produce the same likelihood but a different identification for R0.

So the only way to keep the initial-Rt interpretation is the original parameterisation, which has the ridge. We have to pick one. The warning is the honest minimum here.

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Better answer landed in d5f9081: the user-facing interpretation is preserved (rt_opts(prior=...) is once again the prior on the initial Rt). No warning, no rename, no doc shift.

Implementation: R0 now leaves the generic params vector. R_mean (the trajectory mean) is sampled internally — that gets the geometry win from centring. The user's prior is applied to the derived initial Rt R[1] via a new centred_gp_init_lpdf Stan helper, with a Jacobian correction (just a +log(R[1]) term) so that the resulting joint prior over (R[1], noise, R[t]) matches main exactly. Change of variables, determinant 1.

The plumbing (init_param_ids, init_dists, init_dist_params data items + pack_init_prior() R helper + centred_gp_init_lpdf Stan helper) is intentionally parameter-agnostic — designed as a prototype for the time-varying-parameter framework in #600. Today R0 is the only entry; adding a time-varying alpha later is a one-line dispatch branch in Stan plus appending to the data items.

Verified end-to-end on the previously catastrophic seed=8 (R-hat=6.10 on main): td=0, div=1, R-hat=1.005, ESS=663. R[1] posterior mean=2.29 (user prior on initial Rt → posterior on R[1]); R_mean[1] posterior mean=1.00 (data-determined trajectory mean). The two quantities are distinct and the user's prior lands where intended.

}
logR = logR + gp;
}
Expand Down
29 changes: 29 additions & 0 deletions man/pack_init_prior.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 12 additions & 4 deletions tests/testthat/test-stan-rt.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,14 @@ test_that("update_Rt works to produce multiple Rt estimates with a static gaussi
)
})
test_that("update_Rt works to produce multiple Rt estimates with a non-static gaussian process", {
# Non-stationary GP: cumulated trajectory is centred so log R0 = mean log Rt
# over the window rather than the initial value (eliminates the (R0, drift)
# ridge in the joint posterior). For noise = rep(0.1, 9), gp_n = 9:
# gp = cumsum(noise) = c(0, 0.1, 0.2, ..., 0.9), mean = 0.45,
# centred = c(-0.45, -0.35, ..., 0.45). log Rt = log(1.2) + centred.
expect_equal(
round(update_Rt(10, 1.2, rep(0.1, 9), rep(10, 0), numeric(0), 0), 2),
c(1.20, 1.33, 1.47, 1.62, 1.79, 1.98, 2.19, 2.42, 2.67, 2.95)
round(update_Rt(10, 1.2, rep(0.1, 9), rep(10, 0), numeric(0), 0), 3),
c(0.765, 0.846, 0.935, 1.033, 1.141, 1.262, 1.394, 1.541, 1.703, 1.882)
)
})
test_that("update_Rt works to produce multiple Rt estimates with a non-static stationary gaussian process", {
Expand Down Expand Up @@ -53,9 +58,12 @@ test_that("update_Rt works when Rt is variable and a breakpoint is present", {
round(update_Rt(5, 1.2, rep(0, 5), c(1, 1, 2, 2, 2), 0.1, 1), 2),
c(1.2, 1.2, rep(1.33, 3))
)
# Non-stationary GP: see explanation in the earlier non-static GP test.
# Here gp_n = 4, gp_centred = c(-0.2, -0.1, 0, 0.1, 0.2), breakpoint adds
# 0.1 from t = 3 onward.
expect_equal(
round(update_Rt(5, 1.2, rep(0.1, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2),
c(1.20, 1.33, 1.62, 1.79, 1.98)
round(update_Rt(5, 1.2, rep(0.1, 4), c(1, 1, 2, 2, 2), 0.1, 0), 3),
c(0.982, 1.086, 1.326, 1.466, 1.620)
)
})

Expand Down
Loading