From c826aea4c7e94f00f2bd881336e39d25070dde35 Mon Sep 17 00:00:00 2001 From: Bob Carpenter Date: Thu, 29 Jan 2026 17:09:42 -0500 Subject: [PATCH 1/3] refactor convolve, gaussian_process --- inst/stan/functions/convolve.stan | 13 ++---- inst/stan/functions/gaussian_process.stan | 52 ++++++++++++++--------- 2 files changed, 34 insertions(+), 31 deletions(-) diff --git a/inst/stan/functions/convolve.stan b/inst/stan/functions/convolve.stan index 2a84df929..7cdeefe18 100644 --- a/inst/stan/functions/convolve.stan +++ b/inst/stan/functions/convolve.stan @@ -62,7 +62,6 @@ array[] int calc_conv_indices_len(int s, int xlen, int ylen) { vector convolve_with_rev_pmf(vector x, vector y, int len) { int xlen = num_elements(x); int ylen = num_elements(y); - vector[len] z; if (xlen + ylen - 1 < len) { reject("convolve_with_rev_pmf: len is longer than x and y convolved"); @@ -72,18 +71,12 @@ vector convolve_with_rev_pmf(vector x, vector y, int len) { reject("convolve_with_rev_pmf: len is shorter than x"); } - for (s in 1:xlen) { + vector[len] z; + int ub = max(len, xlen); + for (s in 1:ub) { array[4] int indices = calc_conv_indices_xlen(s, xlen, ylen); z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]); } - - if (len > xlen) { - for (s in (xlen + 1):len) { - array[4] int indices = calc_conv_indices_len(s, xlen, ylen); - z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]); - } - } - return z; } diff --git a/inst/stan/functions/gaussian_process.stan b/inst/stan/functions/gaussian_process.stan index 2f067b4bd..aa8a992ce 100644 --- a/inst/stan/functions/gaussian_process.stan +++ b/inst/stan/functions/gaussian_process.stan @@ -23,6 +23,21 @@ vector diagSPD_EQ(real alpha, real rho, real L, int M) { return factor * exp(exponent * square(indices)); } +/** + * Index set for M basis functions of length L for Matern kernel. + * + * The function returns pow(pi() / 2 / L * linspaced_vector(M, 1, M), 2), + * or equivalently, square(pi() / (2 * L) * linspaced_vector(M, 1, M)). + * + * @param L Length of the interval + * @param M Number of basis functions + * @return Linearly spaced M-vector + */ +vector matern_indices(int M, int L) { + real factor = pi() / (2 * L); + return square(linspaced_vector(M, factor, factor * M)); +} + /** * Spectral density for 1/2 Matern (Ornstein-Uhlenbeck) kernel * @@ -35,10 +50,8 @@ vector diagSPD_EQ(real alpha, real rho, real L, int M) { * @ingroup estimates_smoothing */ vector diagSPD_Matern12(real alpha, real rho, real L, int M) { - vector[M] indices = linspaced_vector(M, 1, M); - real factor = 2; - vector[M] denom = rho * ((1 / rho)^2 + pow(pi() / 2 / L * indices, 2)); - return alpha * sqrt(factor * inv(denom)); + vector[M] denom = 1 / rho + rho * matern_indices(M, L); + return alpha * sqrt(2 ./ denom); } /** @@ -53,10 +66,9 @@ vector diagSPD_Matern12(real alpha, real rho, real L, int M) { * @ingroup estimates_smoothing */ vector diagSPD_Matern32(real alpha, real rho, real L, int M) { - vector[M] indices = linspaced_vector(M, 1, M); - real factor = 2 * alpha * pow(sqrt(3) / rho, 1.5); - vector[M] denom = (sqrt(3) / rho)^2 + pow((pi() / 2 / L) * indices, 2); - return factor * inv(denom); + real factor1 = 2 * alpha * (sqrt(3) / rho)^1.5; + vector[M] denom = 3 / square(rho) + matern_indices(M, L); + return factor ./ denom; } /** @@ -73,9 +85,8 @@ vector diagSPD_Matern32(real alpha, real rho, real L, int M) { vector diagSPD_Matern52(real alpha, real rho, real L, int M) { vector[M] indices = linspaced_vector(M, 1, M); real factor = 16 * pow(sqrt(5) / rho, 5); - vector[M] denom = - 3 * pow((sqrt(5) / rho)^2 + pow((pi() / 2 / L) * indices, 2), 3); - return alpha * sqrt(factor * inv(denom)); + vector[M] denom = 3 * pow(5 / square(rho) + matern_indices(M, L), 3); + return alpha * sqrt(factor ./ denom); } /** @@ -92,10 +103,11 @@ vector diagSPD_Periodic(real alpha, real rho, int M) { real a = inv_square(rho); vector[M] indices = linspaced_vector(M, 1, M); vector[M] q = exp( - log(alpha) + 0.5 * - (log(2) - a + to_vector(log_modified_bessel_first_kind(indices, a))) + log(alpha) + + 0.5 * (log2() - a + log_modified_bessel_first_kind(indices, a)) ); return append_row(q, q); + } /** @@ -129,11 +141,11 @@ matrix PHI(int N, int M, real L, vector x) { * * @ingroup estimates_smoothing */ + matrix PHI_periodic(int N, int M, real w0, vector x) { - matrix[N, M] mw0x = diag_post_multiply( - rep_matrix(w0 * x, M), linspaced_vector(M, 1, M) - ); - return append_col(cos(mw0x), sin(mw0x)); + row_vector[M] k = linspaced_row_vector(M, 1, M); + matrix[N, M] w0xk = (w0 * x) * k; + return append_col(cos(w0xk), sin(w0xk)); } /** @@ -153,9 +165,7 @@ matrix PHI_periodic(int N, int M, real w0, vector x) { int setup_noise(int ot_h, int t, int horizon, int estimate_r, int stationary, int future_fixed, int fixed_from) { int noise_time = estimate_r > 0 ? (stationary > 0 ? ot_h : ot_h - 1) : t; - int noise_terms = - future_fixed > 0 ? (noise_time - horizon + fixed_from) : noise_time; - return noise_terms; + return future_fixed > 0 ? (noise_time - horizon + fixed_from) : noise_time; } /** @@ -210,7 +220,7 @@ vector update_gp(matrix PHI, int M, real L, real alpha, } else if (nu == 2.5) { diagSPD = diagSPD_Matern52(alpha, rho, L, M); } else { - reject("nu must be one of 1/2, 3/2 or 5/2; found nu=", nu); + reject("nu must be one of 0.5, 1.5, or 2.5; found nu=", nu); } } return PHI * (diagSPD .* eta); From 6700753e53efe76e546ff00cb03a51b94dd6e231 Mon Sep 17 00:00:00 2001 From: Bob Carpenter Date: Thu, 29 Jan 2026 17:51:20 -0500 Subject: [PATCH 2/3] added NEWS item --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index 7264cfa51..eabcbdd4e 100644 --- a/NEWS.md +++ b/NEWS.md @@ -58,6 +58,7 @@ The function interface remains unchanged. ## Model changes +- Refactored Stan code for efficiency (@bob-carpenter, #1273). - MCMC runs are now initialised with parameter values drawn from a distribution that approximates their prior distributions. - Added an option to compute growth rates using an estimator by Parag et al. (2022) based on total infectiousness rather than new infections, see `growth_method` argument in rt_opts(). - Added support for fitting the susceptible population size. From 27eeaa37d0fdaf88921700103d4893f75345dba4 Mon Sep 17 00:00:00 2001 From: Bob Carpenter Date: Thu, 29 Jan 2026 17:55:04 -0500 Subject: [PATCH 3/3] code review patches on loops --- inst/stan/functions/convolve.stan | 10 ++++++++-- inst/stan/functions/gaussian_process.stan | 5 ++--- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/inst/stan/functions/convolve.stan b/inst/stan/functions/convolve.stan index 7cdeefe18..b11a0cf6e 100644 --- a/inst/stan/functions/convolve.stan +++ b/inst/stan/functions/convolve.stan @@ -72,11 +72,17 @@ vector convolve_with_rev_pmf(vector x, vector y, int len) { } vector[len] z; - int ub = max(len, xlen); - for (s in 1:ub) { + + for (s in 1:xlen) { array[4] int indices = calc_conv_indices_xlen(s, xlen, ylen); z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]); } + + for (s in (xlen + 1):len) { // zero iterations unless len > xlen + array[4] int indices = calc_conv_indices_len(s, xlen, ylen); + z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]); + } + return z; } diff --git a/inst/stan/functions/gaussian_process.stan b/inst/stan/functions/gaussian_process.stan index aa8a992ce..6480dccc1 100644 --- a/inst/stan/functions/gaussian_process.stan +++ b/inst/stan/functions/gaussian_process.stan @@ -33,7 +33,7 @@ vector diagSPD_EQ(real alpha, real rho, real L, int M) { * @param M Number of basis functions * @return Linearly spaced M-vector */ -vector matern_indices(int M, int L) { +vector matern_indices(int M, real L) { real factor = pi() / (2 * L); return square(linspaced_vector(M, factor, factor * M)); } @@ -66,7 +66,7 @@ vector diagSPD_Matern12(real alpha, real rho, real L, int M) { * @ingroup estimates_smoothing */ vector diagSPD_Matern32(real alpha, real rho, real L, int M) { - real factor1 = 2 * alpha * (sqrt(3) / rho)^1.5; + real factor = 2 * alpha * (sqrt(3) / rho)^1.5; vector[M] denom = 3 / square(rho) + matern_indices(M, L); return factor ./ denom; } @@ -83,7 +83,6 @@ vector diagSPD_Matern32(real alpha, real rho, real L, int M) { * @ingroup estimates_smoothing */ vector diagSPD_Matern52(real alpha, real rho, real L, int M) { - vector[M] indices = linspaced_vector(M, 1, M); real factor = 16 * pow(sqrt(5) / rho, 5); vector[M] denom = 3 * pow(5 / square(rho) + matern_indices(M, L), 3); return alpha * sqrt(factor ./ denom);