diff --git a/NEWS.md b/NEWS.md index 7264770fc..da14ef353 100644 --- a/NEWS.md +++ b/NEWS.md @@ -17,6 +17,10 @@ - Added a model overview vignette with an architecture diagram showing how the package's models connect. - Added a model features vignette providing a quick reference to all modelling options with links to detailed documentation. +## Model changes + +- The non-stationary Gaussian process used to model Rt over time now samples in a mean-centred parameterisation internally (`gp -= mean(gp)` in `inst/stan/functions/rt.stan`), eliminating the `(R0, drift)` ridge in the joint posterior that caused stuck chains and catastrophic R-hat values on some seeds. The user-facing interpretation is unchanged: the `prior` argument in `rt_opts()` is still the prior on the initial Rt. This is achieved by sampling the trajectory mean internally and applying the user prior to the derived initial Rt via a new generic `centred_gp_init_lpdf` Stan helper, with Jacobian-correct change of variables (so the joint prior matches the pre-change model). The plumbing (`init_dists`, `init_dist_params`, `pack_init_prior()`, `centred_gp_init_lpdf`) is parameter-agnostic, designed as a prototype for the general time-varying-parameter framework planned for issue #600. + ## Bug fixes - Fixed a bug in `forecast_infections()` where the summary call to extract dates was using modified args instead of the original fit dimensions, causing a date-dimension mismatch when extending the R trajectory beyond the original observation period. 
diff --git a/R/estimate_infections.R b/R/estimate_infections.R index 00578d508..c620ee70c 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -181,10 +181,12 @@ estimate_infections <- function(data, # Add initial zeroes model_data <- pad_reported_cases(model_data, seeding_time) + # R0 is handled separately from the generic params system: it is wrapped + # by the centred non-stationary GP, so its user-facing prior is on the + # initial Rt (R[1]) rather than on the sampled internal log-mean. params <- list( make_param("alpha", gp$alpha, lower_bound = 0), make_param("rho", gp$ls, lower_bound = 0), - make_param("R0", rt$prior, lower_bound = 0), make_param("fraction_observed", obs$scale, lower_bound = 0), make_param("reporting_overdispersion", obs$dispersion, lower_bound = 0), make_param("pop", rt$pop, lower_bound = 0) @@ -202,6 +204,27 @@ estimate_infections <- function(data, params = params ) + # Register R0 as the (single, for now) centred-GP-wrapped parameter with + # its user prior applied to the derived initial Rt. The dispatch in + # `inst/stan/estimate_infections.stan` (the `init lp` profile block) uses + # `param_id_R0` to know which trajectory's initial value to apply the + # prior to. Generic plumbing: additional time-varying parameters drop in + # alongside R0 by appending to `init_param_ids` / `init_dists` / + # `init_dist_params` and adding one dispatch branch in stan. 
+ stan_data$param_id_R0 <- stan_data$n_params_variable + 1L + if (isTRUE(rt$use_rt)) { + init_R <- pack_init_prior(rt$prior) + stan_data$n_init_priors <- 1L + stan_data$init_param_ids <- array(stan_data$param_id_R0) + stan_data$init_dists <- array(init_R$dist_type) + stan_data$init_dist_params <- array(init_R$params) + } else { + stan_data$n_init_priors <- 0L + stan_data$init_param_ids <- array(integer(0)) + stan_data$init_dists <- array(integer(0)) + stan_data$init_dist_params <- array(numeric(0)) + } + stan_data <- c(stan_data, create_stan_delays( generation_time = generation_time, reporting = delays, diff --git a/R/opts.R b/R/opts.R index ccc236700..64ec18757 100644 --- a/R/opts.R +++ b/R/opts.R @@ -341,7 +341,6 @@ rt_opts <- function(prior = LogNormal(mean = 1, sd = 1), pop_floor = pop_floor, growth_method = arg_match(growth_method) ) - # replace default settings with those specified by user if (opts$rw > 0) { opts$use_breakpoints <- TRUE diff --git a/R/simulate_infections.R b/R/simulate_infections.R index af35ee884..0288d3dc4 100644 --- a/R/simulate_infections.R +++ b/R/simulate_infections.R @@ -198,6 +198,14 @@ simulate_infections <- function(R, ## set empty params matrix - variable parameters not supported here stan_data$params <- array(dim = c(1, 0)) + ## init priors: not used in forward simulation (no model fitting). Provide + ## empty defaults so the shared estimate_infections_params.stan data block + ## is satisfied. 
+ stan_data$n_init_priors <- 0L + stan_data$init_param_ids <- array(integer(0)) + stan_data$init_dists <- array(integer(0)) + stan_data$init_dist_params <- array(numeric(0)) + ## day of week effect if (is.null(day_of_week_effect)) { day_of_week_effect <- rep(1, stan_data$week_effect) diff --git a/R/utilities.R b/R/utilities.R index 276a72d4c..bef5b1a81 100644 --- a/R/utilities.R +++ b/R/utilities.R @@ -481,6 +481,43 @@ make_param <- function(name, dist = NULL, lower_bound = -Inf) { params } +##' Pack a dist_spec into stan-side init-prior fields +##' +##' For parameters wrapped in a centred non-stationary GP (today: R0), the +##' user-facing prior in `rt_opts()` is on the initial value of the trajectory +##' rather than on the sampled internal mean. This helper converts a +##' `<dist_spec>` into the integer/vector representation consumed by the +##' generic init-prior loop in `inst/stan/estimate_infections.stan` (data +##' items `init_dists` and `init_dist_params`). +##' +##' Generic over the parameter — designed to lift directly into a future +##' general time-varying-parameter framework (issue #600). +##' +##' @param dist A `<dist_spec>` (LogNormal, Gamma, or Normal). +##' @return A list with elements `dist_type` (integer code: 0 = lognormal, +##' 1 = gamma, 2 = normal) and `params` (numeric of length 2 with the +##' distribution parameters). +##' @keywords internal +pack_init_prior <- function(dist) { + dist_name <- get_distribution(dist) + dist_type <- switch(dist_name, + lognormal = 0L, + gamma = 1L, + normal = 2L, + cli_abort(c( + "!" = "Init prior distribution {.val {dist_name}} not supported.", + "i" = "Use {.fn LogNormal}, {.fn Gamma}, or {.fn Normal}." + )) + ) + params <- unlist(get_parameters(dist)) + if (length(params) != 2) { + cli_abort( + "Init prior must have exactly 2 distribution parameters; got {length(params)}." 
# nolint + ) + } + list(dist_type = dist_type, params = as.numeric(params)) +} + #' @importFrom stats glm median na.omit pexp pgamma plnorm quasipoisson rexp #' @importFrom stats rlnorm rnorm rpois runif sd var rgamma pnorm globalVariables( diff --git a/inst/stan/data/estimate_infections_params.stan b/inst/stan/data/estimate_infections_params.stan index 8b48f1975..edea29b46 100644 --- a/inst/stan/data/estimate_infections_params.stan +++ b/inst/stan/data/estimate_infections_params.stan @@ -4,3 +4,12 @@ int param_id_R0; // parameter id of R0 int param_id_fraction_observed; // parameter id of fraction_observed int param_id_reporting_overdispersion; // parameter id of reporting_overdispersion int param_id_pop; // parameter id of pop + +// Init priors for centred-GP-wrapped parameters. Today only R0 is wrapped +// (its prior in rt_opts() is on the *initial* Rt, applied to the derived +// R[1] from the centred GP). Generic shape so additional time-varying +// parameters can be added without changing the data plumbing. +int n_init_priors; +array[n_init_priors] int init_param_ids; +array[n_init_priors] int init_dists; +vector[2 * n_init_priors] init_dist_params; diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index d535d016d..5b0753fb4 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -58,7 +58,10 @@ parameters { vector[n_params_variable] params; // gaussian process vector[fixed ? 0 : gp_type == 1 ? 2*M : M] eta; // unconstrained noise - // Rt + // Rt — mean log Rt over the window (sampled internally). User prior is on + // R[1] (initial Rt) and is applied via centred_gp_init_lpdf in the model + // block, with R[1] derived from R_mean + gp_centred[1]. + array[estimate_r] real R_mean; array[estimate_r] real initial_infections; // seed infections // standard deviation of breakpoint effect array[bp_n > 0 ? 
1 : 0] real bp_sd; @@ -106,11 +109,11 @@ transformed parameters { ); } profile("R0") { - real R0 = get_param( - param_id_R0, params_fixed_lookup, params_variable_lookup, params_value, params - ); + // R_mean is sampled directly (centred GP scaffolding). The user prior + // lives on R[1] (initial Rt), applied in the model block via the + // init_priors plumbing. R = update_Rt( - ot_h, R0, noise, breakpoints, bp_effects, stationary + ot_h, R_mean[1], noise, breakpoints, bp_effects, stationary ); } profile("infections") { @@ -231,6 +234,29 @@ model { } } + // Apply user priors on the initial value of any centred-GP-wrapped + // parameter (today: R0 -> R[1]). Generic loop so new time-varying + // parameters drop in via a one-line dispatch addition below. + if (n_init_priors > 0) { + profile("init lp") { + for (i in 1:n_init_priors) { + real init_value; + int pid = init_param_ids[i]; + if (pid == param_id_R0) { + init_value = R[1]; + } else { + reject("no time-varying parameter registered for id ", pid); + } + target += centred_gp_init_lpdf( + init_value | + init_dists[i], + init_dist_params[2 * i - 1], + init_dist_params[2 * i] + ); + } + } + } + // observed reports from mean of reports (update likelihood) if (likelihood) { profile("report lp") { diff --git a/inst/stan/functions/params.stan b/inst/stan/functions/params.stan index 523b12025..b758558a5 100644 --- a/inst/stan/functions/params.stan +++ b/inst/stan/functions/params.stan @@ -106,3 +106,36 @@ void params_lp(vector params, array[] int prior_dist, } } +/** + * Apply user prior on the initial value of a centred-GP-wrapped parameter. + * + * When a parameter is wrapped by a centred non-stationary GP, the user-facing + * prior is on the initial value of the trajectory (X[1]) rather than on the + * sampled internal log-mean. This helper applies the prior to the derived + * initial value with the Jacobian correction for the linear-shift transform + * from log-mean to log-initial. 
+ * + * Generic over the parameter — used by R0 today, lifts to any future + * time-varying parameter (alpha, dispersion, ...) via the same call. + * + * @param init_value Derived initial value of the trajectory (e.g. R[1]). + * @param dist_type Prior distribution type (0 = lognormal, 1 = gamma, 2 = normal). + * @param p1 First distribution parameter. + * @param p2 Second distribution parameter. + * @return Log-density contribution to the target, Jacobian included. + * + * @ingroup parameter_handlers + */ +real centred_gp_init_lpdf(real init_value, int dist_type, real p1, real p2) { + if (dist_type == 0) { + return lognormal_lpdf(init_value | p1, p2) + log(init_value); + } else if (dist_type == 1) { + return gamma_lpdf(init_value | p1, p2) + log(init_value); + } else if (dist_type == 2) { + return normal_lpdf(init_value | p1, p2) + log(init_value); + } else { + reject("centred_gp_init_lpdf: dist_type must be 0, 1, or 2"); + } + return 0; +} + diff --git a/inst/stan/functions/rt.stan b/inst/stan/functions/rt.stan index 0c1a85ce2..3210ed8b7 100644 --- a/inst/stan/functions/rt.stan +++ b/inst/stan/functions/rt.stan @@ -48,6 +48,10 @@ vector update_Rt(int t, real R0, vector noise, array[] int bps, } else { gp[2:(gp_n + 1)] = noise; gp = cumulative_sum(gp); + // Identifiability: subtract the trajectory mean so log R0 is the mean + // log Rt over the window rather than the initial value. Eliminates + // the (R0, drift) ridge in the joint posterior. 
+ gp -= mean(gp); } logR = logR + gp; } diff --git a/man/pack_init_prior.Rd b/man/pack_init_prior.Rd new file mode 100644 index 000000000..a44643f36 --- /dev/null +++ b/man/pack_init_prior.Rd @@ -0,0 +1,29 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utilities.R +\name{pack_init_prior} +\alias{pack_init_prior} +\title{Pack a dist_spec into stan-side init-prior fields} +\usage{ +pack_init_prior(dist) +} +\arguments{ +\item{dist}{A \verb{<dist_spec>} (LogNormal, Gamma, or Normal).} +} +\value{ +A list with elements \code{dist_type} (integer code: 0 = lognormal, +1 = gamma, 2 = normal) and \code{params} (numeric of length 2 with the +distribution parameters). +} +\description{ +For parameters wrapped in a centred non-stationary GP (today: R0), the +user-facing prior in \code{rt_opts()} is on the initial value of the trajectory +rather than on the sampled internal mean. This helper converts a +\verb{<dist_spec>} into the integer/vector representation consumed by the +generic init-prior loop in \code{inst/stan/estimate_infections.stan} (data +items \code{init_dists} and \code{init_dist_params}). +} +\details{ +Generic over the parameter — designed to lift directly into a future +general time-varying-parameter framework (issue #600). +} +\keyword{internal} diff --git a/tests/testthat/test-stan-rt.R b/tests/testthat/test-stan-rt.R index bffc9c8b8..13cc93e39 100644 --- a/tests/testthat/test-stan-rt.R +++ b/tests/testthat/test-stan-rt.R @@ -9,9 +9,14 @@ test_that("update_Rt works to produce multiple Rt estimates with a static gaussi ) }) test_that("update_Rt works to produce multiple Rt estimates with a non-static gaussian process", { + # Non-stationary GP: cumulated trajectory is centred so log R0 = mean log Rt + # over the window rather than the initial value (eliminates the (R0, drift) + # ridge in the joint posterior). 
For noise = rep(0.1, 9), gp_n = 9: + # gp = cumsum(c(0, noise)) = c(0, 0.1, 0.2, ..., 0.9), mean = 0.45, + # centred = c(-0.45, -0.35, ..., 0.45). log Rt = log(1.2) + centred. expect_equal( - round(update_Rt(10, 1.2, rep(0.1, 9), rep(10, 0), numeric(0), 0), 2), - c(1.20, 1.33, 1.47, 1.62, 1.79, 1.98, 2.19, 2.42, 2.67, 2.95) + round(update_Rt(10, 1.2, rep(0.1, 9), rep(10, 0), numeric(0), 0), 3), + c(0.765, 0.846, 0.935, 1.033, 1.141, 1.262, 1.394, 1.541, 1.703, 1.882) ) }) test_that("update_Rt works to produce multiple Rt estimates with a non-static stationary gaussian process", { @@ -53,9 +58,12 @@ test_that("update_Rt works when Rt is variable and a breakpoint is present", { round(update_Rt(5, 1.2, rep(0, 5), c(1, 1, 2, 2, 2), 0.1, 1), 2), c(1.2, 1.2, rep(1.33, 3)) ) + # Non-stationary GP: see explanation in the earlier non-static GP test. + # Here gp_n = 4, gp_centred = c(-0.2, -0.1, 0, 0.1, 0.2), breakpoint adds + # 0.1 from t = 3 onward. expect_equal( - round(update_Rt(5, 1.2, rep(0.1, 4), c(1, 1, 2, 2, 2), 0.1, 0), 2), - c(1.20, 1.33, 1.62, 1.79, 1.98) + round(update_Rt(5, 1.2, rep(0.1, 4), c(1, 1, 2, 2, 2), 0.1, 0), 3), + c(0.982, 1.086, 1.326, 1.466, 1.620) ) })