diff --git a/NEWS.md b/NEWS.md index 9ea7dc325..816f91704 100644 --- a/NEWS.md +++ b/NEWS.md @@ -53,6 +53,7 @@ - Delay distribution discretisation now properly accounts for primary event censoring during model fitting, matching the correction already applied on the R side since v1.8.0. This improves accuracy for short delays where the observation window is large relative to the delay. - Left truncation of delay distributions (e.g. excluding generation times of zero) is now handled analytically rather than by zeroing and renormalising, giving more accurate PMFs near the truncation point. +- Refactored the Gaussian process, convolution, and observation model Stan functions for efficiency and readability, including a shared `matern_indices()` helper for the Matern spectral densities, outer-product basis function construction in `PHI()` and `PHI_periodic()`, and a shared `reporting_phi()` helper for the negative binomial overdispersion. ## Package changes diff --git a/inst/stan/functions/convolve.stan b/inst/stan/functions/convolve.stan index 2a84df929..76e20cbdc 100644 --- a/inst/stan/functions/convolve.stan +++ b/inst/stan/functions/convolve.stan @@ -40,7 +40,7 @@ array[] int calc_conv_indices_len(int s, int xlen, int ylen) { int s_minus_ylen = s - ylen; int start_x = max(1, s_minus_ylen + 1); int end_x = xlen; - int start_y = max(1, 1 - s_minus_ylen);; + int start_y = max(1, 1 - s_minus_ylen); int end_y = ylen + xlen - s; return {start_x, end_x, start_y, end_y}; } @@ -62,7 +62,6 @@ array[] int calc_conv_indices_len(int s, int xlen, int ylen) { vector convolve_with_rev_pmf(vector x, vector y, int len) { int xlen = num_elements(x); int ylen = num_elements(y); - vector[len] z; if (xlen + ylen - 1 < len) { reject("convolve_with_rev_pmf: len is longer than x and y convolved"); @@ -72,16 +71,17 @@ vector convolve_with_rev_pmf(vector x, vector y, int len) { reject("convolve_with_rev_pmf: len is shorter than x"); } + vector[len] z; + for (s in 1:xlen) { array[4] int indices = calc_conv_indices_xlen(s, xlen, ylen); z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]); } - if (len > xlen) { - for (s in (xlen + 1):len) { - array[4] int indices = calc_conv_indices_len(s, xlen, ylen); - z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]); - } + // runs zero times unless len > xlen + for (s in (xlen + 1):len) { + array[4] int indices = calc_conv_indices_len(s, xlen, ylen); + z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]); } return z; diff --git a/inst/stan/functions/gaussian_process.stan b/inst/stan/functions/gaussian_process.stan index 2f067b4bd..ca1f80a21 100644 --- a/inst/stan/functions/gaussian_process.stan +++ b/inst/stan/functions/gaussian_process.stan @@ -23,6 +23,26 @@ vector diagSPD_EQ(real alpha, real rho, real L, int M) { return factor * exp(exponent * square(indices)); } +/** + * Squared spectral indices shared by the Matern kernels + * + * Returns `square(pi() / (2 * L) * linspaced_vector(M, 1, M))`, the term that + * appears in the denominator of every Matern spectral density. The + * `linspaced_vector()` call uses data-only bounds and is scaled afterwards so + * the function compiles under the data-only argument constraint in older Stan + * versions (e.g. the one shipped with rstan). + * + * @param M Number of basis functions + * @param L Length of the interval + * @return A vector of squared spectral indices + * + * @ingroup estimates_smoothing + */ +vector matern_indices(int M, real L) { + vector[M] indices = linspaced_vector(M, 1, M); + return square(pi() / (2 * L) * indices); +} + /** * Spectral density for 1/2 Matern (Ornstein-Uhlenbeck) kernel * @@ -35,10 +55,8 @@ vector diagSPD_EQ(real alpha, real rho, real L, int M) { * @ingroup estimates_smoothing */ vector diagSPD_Matern12(real alpha, real rho, real L, int M) { - vector[M] indices = linspaced_vector(M, 1, M); - real factor = 2; - vector[M] denom = rho * ((1 / rho)^2 + pow(pi() / 2 / L * indices, 2)); - return alpha * sqrt(factor * inv(denom)); + vector[M] denom = 1 / rho + rho * matern_indices(M, L); + return alpha * sqrt(2 ./ denom); } /** @@ -53,10 +71,9 @@ vector diagSPD_Matern12(real alpha, real rho, real L, int M) { * @ingroup estimates_smoothing */ vector diagSPD_Matern32(real alpha, real rho, real L, int M) { - vector[M] indices = linspaced_vector(M, 1, M); - real factor = 2 * alpha * pow(sqrt(3) / rho, 1.5); - vector[M] denom = (sqrt(3) / rho)^2 + pow((pi() / 2 / L) * indices, 2); - return factor * inv(denom); + real factor = 2 * alpha * (sqrt(3) / rho)^1.5; + vector[M] denom = 3 / square(rho) + matern_indices(M, L); + return factor ./ denom; } /** @@ -71,11 +88,9 @@ vector diagSPD_Matern32(real alpha, real rho, real L, int M) { * @ingroup estimates_smoothing */ vector diagSPD_Matern52(real alpha, real rho, real L, int M) { - vector[M] indices = linspaced_vector(M, 1, M); real factor = 16 * pow(sqrt(5) / rho, 5); - vector[M] denom = - 3 * pow((sqrt(5) / rho)^2 + pow((pi() / 2 / L) * indices, 2), 3); - return alpha * sqrt(factor * inv(denom)); + vector[M] denom = 3 * pow(5 / square(rho) + matern_indices(M, L), 3); + return alpha * sqrt(factor ./ denom); } /** @@ -110,11 +125,8 @@ vector diagSPD_Periodic(real alpha, real rho, int M) { * @ingroup estimates_smoothing */ matrix PHI(int N, int M, real L, vector x) { - matrix[N, M] phi = sin( - diag_post_multiply( - rep_matrix(pi() / (2 * L) * (x + L), M), linspaced_vector(M, 1, M) - ) - ) / sqrt(L); + row_vector[M] k = linspaced_row_vector(M, 1, M); + matrix[N, M] phi = sin((pi() / (2 * L) * (x + L)) * k) / sqrt(L); return phi; } @@ -130,10 +142,9 @@ matrix PHI(int N, int M, real L, vector x) { * @ingroup estimates_smoothing */ matrix PHI_periodic(int N, int M, real w0, vector x) { - matrix[N, M] mw0x = diag_post_multiply( - rep_matrix(w0 * x, M), linspaced_vector(M, 1, M) - ); - return append_col(cos(mw0x), sin(mw0x)); + row_vector[M] k = linspaced_row_vector(M, 1, M); + matrix[N, M] w0xk = (w0 * x) * k; + return append_col(cos(w0xk), sin(w0xk)); } /** @@ -210,7 +221,7 @@ vector update_gp(matrix PHI, int M, real L, real alpha, } else if (nu == 2.5) { diagSPD = diagSPD_Matern52(alpha, rho, L, M); } else { - reject("nu must be one of 1/2, 3/2 or 5/2; found nu=", nu); + reject("nu must be one of 0.5, 1.5, or 2.5; found nu=", nu); } } return PHI * (diagSPD .* eta); diff --git a/inst/stan/functions/observation_model.stan b/inst/stan/functions/observation_model.stan index d903a04a6..e149be104 100644 --- a/inst/stan/functions/observation_model.stan +++ b/inst/stan/functions/observation_model.stan @@ -106,6 +106,25 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd, } } +/** + * Negative binomial overdispersion for the reporting model + * + * Converts the reporting overdispersion parameter into the `phi` of the + * negative binomial. When no overdispersion is modelled a large value is + * returned so the negative binomial behaves like a Poisson. + * + * @param reporting_overdispersion Real value for reporting overdispersion. + * @param model_type Integer indicating the model type (0 for Poisson, >0 for + * Negative Binomial). + * + * @return The negative binomial overdispersion `phi`. + * + * @ingroup observation_model + */ +real reporting_phi(real reporting_overdispersion, int model_type) { + return model_type ? inv_square(reporting_overdispersion) : 1e5; +} + /** * Update log density for reported cases * @@ -127,7 +146,7 @@ void report_lp(array[] int cases, array[] int case_times, vector reports, int n = num_elements(case_times); // number of observations vector[n] obs_reports = reports[case_times]; // reports at observation time if (model_type) { - real phi = inv_square(reporting_overdispersion); + real phi = reporting_phi(reporting_overdispersion, model_type); if (weight == 1) { cases ~ neg_binomial_2(obs_reports, phi); } else { @@ -197,7 +216,7 @@ vector report_log_lik(array[] int cases, vector reports, log_lik[i] = poisson_lpmf(cases[i] | reports[i]) * weight; } } else { - real phi = inv_square(reporting_overdispersion); + real phi = reporting_phi(reporting_overdispersion, model_type); for (i in 1:t) { log_lik[i] = neg_binomial_2_lpmf( cases[i] | reports[i], phi @@ -257,10 +276,7 @@ int neg_binomial_2_safe_rng(real mu, real phi) { array[] int report_rng(vector reports, real reporting_overdispersion, int model_type) { int t = num_elements(reports); array[t] int sampled_reports; - real phi = 1e5; - if (model_type) { - phi = inv_square(reporting_overdispersion); - } + real phi = reporting_phi(reporting_overdispersion, model_type); for (s in 1:t) { sampled_reports[s] = neg_binomial_2_safe_rng(reports[s], phi); diff --git a/tests/testthat/test-stan-guassian-process.R b/tests/testthat/test-stan-guassian-process.R index 085285e67..08ed026b2 100644 --- a/tests/testthat/test-stan-guassian-process.R +++ b/tests/testthat/test-stan-guassian-process.R @@ -24,6 +24,18 @@ test_that("diagSPD_EQ returns correct dimensions and values", { expect_equal(result, expected_result, tolerance = 1e-8) }) +test_that("matern_indices returns correct dimensions and values", { + L <- 1.0 + M <- 5 + result <- matern_indices(M, L) + expect_equal(length(result), M) + expect_true(all(result > 0)) + # Check specific values for known inputs + indices <- linspaced_vector(M, 1, M) + expected_result <- (pi / (2 * L) * indices)^2 + expect_equal(result, expected_result, tolerance = 1e-8) +}) + test_that("diagSPD_Matern functions return correct dimensions and values", { alpha <- 1.0 rho <- 2.0 diff --git a/tests/testthat/test-stan-observation_model.R b/tests/testthat/test-stan-observation_model.R index a871d1f58..d0ada777e 100644 --- a/tests/testthat/test-stan-observation_model.R +++ b/tests/testthat/test-stan-observation_model.R @@ -1,6 +1,17 @@ skip_on_cran() skip_on_os("windows") +test_that("reporting_phi returns overdispersion or Poisson fallback", { + reporting_overdispersion <- 0.5 + # Negative binomial: phi is the inverse square of the overdispersion + expect_equal( + reporting_phi(reporting_overdispersion, 1), + 1 / reporting_overdispersion^2 + ) + # Poisson: large phi so the negative binomial behaves like a Poisson + expect_equal(reporting_phi(reporting_overdispersion, 0), 1e5) +}) + test_that("day_of_week_effect applies day of week effect correctly", { reports <- c(100, 200, 300, 400, 500, 600, 700, 800, 900, 1000) day_of_week <- c(1, 2, 3, 1, 2, 3, 1, 2, 3, 1)