Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@

- Delay distribution discretisation now properly accounts for primary event censoring during model fitting, matching the correction already applied on the R side since v1.8.0. This improves accuracy for short delays where the observation window is large relative to the delay.
- Left truncation of delay distributions (e.g. excluding generation times of zero) is now handled analytically rather than by zeroing and renormalising, giving more accurate PMFs near the truncation point.
- Refactored the Gaussian process, convolution, and observation model Stan functions for efficiency and readability, including a shared `matern_indices()` helper for the Matern spectral densities, outer-product basis function construction in `PHI()` and `PHI_periodic()`, and a shared `reporting_phi()` helper for the negative binomial overdispersion.

## Package changes

Expand Down
14 changes: 7 additions & 7 deletions inst/stan/functions/convolve.stan
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ array[] int calc_conv_indices_len(int s, int xlen, int ylen) {
int s_minus_ylen = s - ylen;
int start_x = max(1, s_minus_ylen + 1);
int end_x = xlen;
int start_y = max(1, 1 - s_minus_ylen);;
int start_y = max(1, 1 - s_minus_ylen);
int end_y = ylen + xlen - s;
return {start_x, end_x, start_y, end_y};
}
Expand All @@ -62,7 +62,6 @@ array[] int calc_conv_indices_len(int s, int xlen, int ylen) {
vector convolve_with_rev_pmf(vector x, vector y, int len) {
int xlen = num_elements(x);
int ylen = num_elements(y);
vector[len] z;

if (xlen + ylen - 1 < len) {
reject("convolve_with_rev_pmf: len is longer than x and y convolved");
Expand All @@ -72,16 +71,17 @@ vector convolve_with_rev_pmf(vector x, vector y, int len) {
reject("convolve_with_rev_pmf: len is shorter than x");
}

vector[len] z;

for (s in 1:xlen) {
array[4] int indices = calc_conv_indices_xlen(s, xlen, ylen);
z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]);
}

if (len > xlen) {
for (s in (xlen + 1):len) {
array[4] int indices = calc_conv_indices_len(s, xlen, ylen);
z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]);
}
// runs zero times unless len > xlen
for (s in (xlen + 1):len) {
array[4] int indices = calc_conv_indices_len(s, xlen, ylen);
z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]);
}

return z;
Expand Down
55 changes: 33 additions & 22 deletions inst/stan/functions/gaussian_process.stan
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,26 @@ vector diagSPD_EQ(real alpha, real rho, real L, int M) {
return factor * exp(exponent * square(indices));
}

/**
 * Squared, scaled basis indices for the Matern spectral densities
 *
 * Evaluates `(pi() / (2 * L) * k)^2` element-wise for `k = 1, ..., M`. The
 * Matern spectral density functions in this file add this vector to a
 * kernel-specific term when building their denominators.
 *
 * The index sequence is created first with integer bounds and only then
 * multiplied by the (possibly non-data) scale, so that `linspaced_vector()`
 * never receives `L` as an argument. This keeps the function compiling under
 * the data-only argument constraint in older Stan versions (e.g. the one
 * shipped with rstan).
 *
 * @param M Number of basis functions
 * @param L Length of the interval
 * @return A vector of length M holding the squared spectral indices
 *
 * @ingroup estimates_smoothing
 */
vector matern_indices(int M, real L) {
  // scale applied to every index: pi / (2L)
  real scale = pi() / (2 * L);
  return square(scale * linspaced_vector(M, 1, M));
}

/**
* Spectral density for 1/2 Matern (Ornstein-Uhlenbeck) kernel
*
Expand All @@ -35,10 +55,8 @@ vector diagSPD_EQ(real alpha, real rho, real L, int M) {
* @ingroup estimates_smoothing
*/
vector diagSPD_Matern12(real alpha, real rho, real L, int M) {
vector[M] indices = linspaced_vector(M, 1, M);
real factor = 2;
vector[M] denom = rho * ((1 / rho)^2 + pow(pi() / 2 / L * indices, 2));
return alpha * sqrt(factor * inv(denom));
vector[M] denom = 1 / rho + rho * matern_indices(M, L);
return alpha * sqrt(2 ./ denom);
}

/**
Expand All @@ -53,10 +71,9 @@ vector diagSPD_Matern12(real alpha, real rho, real L, int M) {
* @ingroup estimates_smoothing
*/
vector diagSPD_Matern32(real alpha, real rho, real L, int M) {
vector[M] indices = linspaced_vector(M, 1, M);
real factor = 2 * alpha * pow(sqrt(3) / rho, 1.5);
vector[M] denom = (sqrt(3) / rho)^2 + pow((pi() / 2 / L) * indices, 2);
return factor * inv(denom);
real factor = 2 * alpha * (sqrt(3) / rho)^1.5;
vector[M] denom = 3 / square(rho) + matern_indices(M, L);
return factor ./ denom;
}

/**
Expand All @@ -71,11 +88,9 @@ vector diagSPD_Matern32(real alpha, real rho, real L, int M) {
* @ingroup estimates_smoothing
*/
vector diagSPD_Matern52(real alpha, real rho, real L, int M) {
vector[M] indices = linspaced_vector(M, 1, M);
real factor = 16 * pow(sqrt(5) / rho, 5);
vector[M] denom =
3 * pow((sqrt(5) / rho)^2 + pow((pi() / 2 / L) * indices, 2), 3);
return alpha * sqrt(factor * inv(denom));
vector[M] denom = 3 * pow(5 / square(rho) + matern_indices(M, L), 3);
return alpha * sqrt(factor ./ denom);
}

/**
Expand Down Expand Up @@ -110,11 +125,8 @@ vector diagSPD_Periodic(real alpha, real rho, int M) {
* @ingroup estimates_smoothing
*/
matrix PHI(int N, int M, real L, vector x) {
matrix[N, M] phi = sin(
diag_post_multiply(
rep_matrix(pi() / (2 * L) * (x + L), M), linspaced_vector(M, 1, M)
)
) / sqrt(L);
row_vector[M] k = linspaced_row_vector(M, 1, M);
matrix[N, M] phi = sin((pi() / (2 * L) * (x + L)) * k) / sqrt(L);
return phi;
}

Expand All @@ -130,10 +142,9 @@ matrix PHI(int N, int M, real L, vector x) {
* @ingroup estimates_smoothing
*/
matrix PHI_periodic(int N, int M, real w0, vector x) {
matrix[N, M] mw0x = diag_post_multiply(
rep_matrix(w0 * x, M), linspaced_vector(M, 1, M)
);
return append_col(cos(mw0x), sin(mw0x));
row_vector[M] k = linspaced_row_vector(M, 1, M);
matrix[N, M] w0xk = (w0 * x) * k;
return append_col(cos(w0xk), sin(w0xk));
}

/**
Expand Down Expand Up @@ -210,7 +221,7 @@ vector update_gp(matrix PHI, int M, real L, real alpha,
} else if (nu == 2.5) {
diagSPD = diagSPD_Matern52(alpha, rho, L, M);
} else {
reject("nu must be one of 1/2, 3/2 or 5/2; found nu=", nu);
reject("nu must be one of 0.5, 1.5, or 2.5; found nu=", nu);
}
}
return PHI * (diagSPD .* eta);
Expand Down
28 changes: 22 additions & 6 deletions inst/stan/functions/observation_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,25 @@ void truncation_lp(array[] real truncation_mean, array[] real truncation_sd,
}
}

/**
 * Negative binomial `phi` used by the reporting model
 *
 * Maps the reporting overdispersion parameter onto the `phi` argument of the
 * negative binomial. If `model_type` is non-zero the result is
 * `inv_square(reporting_overdispersion)`; if it is zero a fixed large value
 * (1e5) is returned so that the negative binomial behaves like a Poisson.
 *
 * @param reporting_overdispersion Real value for reporting overdispersion.
 * @param model_type Integer indicating the model type (0 for Poisson,
 * non-zero for Negative Binomial).
 *
 * @return The negative binomial overdispersion `phi`.
 *
 * @ingroup observation_model
 */
real reporting_phi(real reporting_overdispersion, int model_type) {
  if (model_type) {
    // overdispersed reporting: phi = 1 / overdispersion^2
    return inv_square(reporting_overdispersion);
  }
  // no overdispersion modelled: effectively Poisson
  return 1e5;
}

/**
* Update log density for reported cases
*
Expand All @@ -127,7 +146,7 @@ void report_lp(array[] int cases, array[] int case_times, vector reports,
int n = num_elements(case_times); // number of observations
vector[n] obs_reports = reports[case_times]; // reports at observation time
if (model_type) {
real phi = inv_square(reporting_overdispersion);
real phi = reporting_phi(reporting_overdispersion, model_type);
if (weight == 1) {
cases ~ neg_binomial_2(obs_reports, phi);
} else {
Expand Down Expand Up @@ -197,7 +216,7 @@ vector report_log_lik(array[] int cases, vector reports,
log_lik[i] = poisson_lpmf(cases[i] | reports[i]) * weight;
}
} else {
real phi = inv_square(reporting_overdispersion);
real phi = reporting_phi(reporting_overdispersion, model_type);
for (i in 1:t) {
log_lik[i] = neg_binomial_2_lpmf(
cases[i] | reports[i], phi
Expand Down Expand Up @@ -257,10 +276,7 @@ int neg_binomial_2_safe_rng(real mu, real phi) {
array[] int report_rng(vector reports, real reporting_overdispersion, int model_type) {
int t = num_elements(reports);
array[t] int sampled_reports;
real phi = 1e5;
if (model_type) {
phi = inv_square(reporting_overdispersion);
}
real phi = reporting_phi(reporting_overdispersion, model_type);

for (s in 1:t) {
sampled_reports[s] = neg_binomial_2_safe_rng(reports[s], phi);
Expand Down
12 changes: 12 additions & 0 deletions tests/testthat/test-stan-guassian-process.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,18 @@ test_that("diagSPD_EQ returns correct dimensions and values", {
expect_equal(result, expected_result, tolerance = 1e-8)
})

test_that("matern_indices returns correct dimensions and values", {
  n_basis <- 5
  interval_length <- 1.0
  squared_indices <- matern_indices(n_basis, interval_length)

  # One strictly positive entry per basis function
  expect_equal(length(squared_indices), n_basis)
  expect_true(all(squared_indices > 0))

  # Compare against the closed-form expression (pi / (2L) * k)^2, k = 1..M
  k <- linspaced_vector(n_basis, 1, n_basis)
  reference <- (pi / (2 * interval_length) * k)^2
  expect_equal(squared_indices, reference, tolerance = 1e-8)
})

test_that("diagSPD_Matern functions return correct dimensions and values", {
alpha <- 1.0
rho <- 2.0
Expand Down
11 changes: 11 additions & 0 deletions tests/testthat/test-stan-observation_model.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,17 @@
skip_on_cran()
skip_on_os("windows")

test_that("reporting_phi returns overdispersion or Poisson fallback", {
  overdispersion <- 0.5
  phi_negbin <- reporting_phi(overdispersion, 1)
  phi_poisson <- reporting_phi(overdispersion, 0)

  # model_type = 1 (negative binomial): phi is the inverse square of the
  # overdispersion
  expect_equal(phi_negbin, 1 / overdispersion^2)

  # model_type = 0 (Poisson): a large fixed phi so the negative binomial
  # behaves like a Poisson
  expect_equal(phi_poisson, 1e5)
})

test_that("day_of_week_effect applies day of week effect correctly", {
reports <- c(100, 200, 300, 400, 500, 600, 700, 800, 900, 1000)
day_of_week <- c(1, 2, 3, 1, 2, 3, 1, 2, 3, 1)
Expand Down
Loading