Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 3 additions & 10 deletions inst/stan/functions/convolve.stan
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ array[] int calc_conv_indices_len(int s, int xlen, int ylen) {
vector convolve_with_rev_pmf(vector x, vector y, int len) {
int xlen = num_elements(x);
int ylen = num_elements(y);
vector[len] z;

if (xlen + ylen - 1 < len) {
reject("convolve_with_rev_pmf: len is longer than x and y convolved");
Expand All @@ -72,18 +71,12 @@ vector convolve_with_rev_pmf(vector x, vector y, int len) {
reject("convolve_with_rev_pmf: len is shorter than x");
}

for (s in 1:xlen) {
vector[len] z;
int ub = max(len, xlen);
for (s in 1:ub) {
array[4] int indices = calc_conv_indices_xlen(s, xlen, ylen);
z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]);
}
Comment thread
coderabbitai[bot] marked this conversation as resolved.

if (len > xlen) {
for (s in (xlen + 1):len) {
array[4] int indices = calc_conv_indices_len(s, xlen, ylen);
z[s] = dot_product(x[indices[1]:indices[2]], y[indices[3]:indices[4]]);
}
}

return z;
}

Expand Down
52 changes: 31 additions & 21 deletions inst/stan/functions/gaussian_process.stan
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,21 @@ vector diagSPD_EQ(real alpha, real rho, real L, int M) {
return factor * exp(exponent * square(indices));
}

/**
* Index set for M basis functions of length L for Matern kernel.
*
* The function returns pow(pi() / 2 / L * linspaced_vector(M, 1, M), 2),
* or equivalently, square(pi() / (2 * L) * linspaced_vector(M, 1, M)).
*
* @param L Length of the interval
* @param M Number of basis functions
* @return Linearly spaced M-vector
*/
vector matern_indices(int M, int L) {
real factor = pi() / (2 * L);
return square(linspaced_vector(M, factor, factor * M));
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
}

/**
* Spectral density for 1/2 Matern (Ornstein-Uhlenbeck) kernel
*
Expand All @@ -35,10 +50,8 @@ vector diagSPD_EQ(real alpha, real rho, real L, int M) {
* @ingroup estimates_smoothing
*/
vector diagSPD_Matern12(real alpha, real rho, real L, int M) {
vector[M] indices = linspaced_vector(M, 1, M);
real factor = 2;
vector[M] denom = rho * ((1 / rho)^2 + pow(pi() / 2 / L * indices, 2));
return alpha * sqrt(factor * inv(denom));
vector[M] denom = 1 / rho + rho * matern_indices(M, L);
return alpha * sqrt(2 ./ denom);
}

/**
Expand All @@ -53,10 +66,9 @@ vector diagSPD_Matern12(real alpha, real rho, real L, int M) {
* @ingroup estimates_smoothing
*/
vector diagSPD_Matern32(real alpha, real rho, real L, int M) {
vector[M] indices = linspaced_vector(M, 1, M);
real factor = 2 * alpha * pow(sqrt(3) / rho, 1.5);
vector[M] denom = (sqrt(3) / rho)^2 + pow((pi() / 2 / L) * indices, 2);
return factor * inv(denom);
real factor1 = 2 * alpha * (sqrt(3) / rho)^1.5;
vector[M] denom = 3 / square(rho) + matern_indices(M, L);
return factor ./ denom;
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
}

/**
Expand All @@ -73,9 +85,8 @@ vector diagSPD_Matern32(real alpha, real rho, real L, int M) {
vector diagSPD_Matern52(real alpha, real rho, real L, int M) {
vector[M] indices = linspaced_vector(M, 1, M);
real factor = 16 * pow(sqrt(5) / rho, 5);
vector[M] denom =
3 * pow((sqrt(5) / rho)^2 + pow((pi() / 2 / L) * indices, 2), 3);
return alpha * sqrt(factor * inv(denom));
vector[M] denom = 3 * pow(5 / square(rho) + matern_indices(M, L), 3);
return alpha * sqrt(factor ./ denom);
}

/**
Expand All @@ -92,10 +103,11 @@ vector diagSPD_Periodic(real alpha, real rho, int M) {
real a = inv_square(rho);
vector[M] indices = linspaced_vector(M, 1, M);
vector[M] q = exp(
log(alpha) + 0.5 *
(log(2) - a + to_vector(log_modified_bessel_first_kind(indices, a)))
log(alpha) +
0.5 * (log2() - a + log_modified_bessel_first_kind(indices, a))
);
return append_row(q, q);

}

/**
Expand Down Expand Up @@ -129,11 +141,11 @@ matrix PHI(int N, int M, real L, vector x) {
*
* @ingroup estimates_smoothing
*/

matrix PHI_periodic(int N, int M, real w0, vector x) {
matrix[N, M] mw0x = diag_post_multiply(
rep_matrix(w0 * x, M), linspaced_vector(M, 1, M)
);
return append_col(cos(mw0x), sin(mw0x));
row_vector[M] k = linspaced_row_vector(M, 1, M);
matrix[N, M] w0xk = (w0 * x) * k;
return append_col(cos(w0xk), sin(w0xk));
}

/**
Expand All @@ -153,9 +165,7 @@ matrix PHI_periodic(int N, int M, real w0, vector x) {
int setup_noise(int ot_h, int t, int horizon, int estimate_r,
int stationary, int future_fixed, int fixed_from) {
int noise_time = estimate_r > 0 ? (stationary > 0 ? ot_h : ot_h - 1) : t;
int noise_terms =
future_fixed > 0 ? (noise_time - horizon + fixed_from) : noise_time;
return noise_terms;
return future_fixed > 0 ? (noise_time - horizon + fixed_from) : noise_time;
}

/**
Expand Down Expand Up @@ -210,7 +220,7 @@ vector update_gp(matrix PHI, int M, real L, real alpha,
} else if (nu == 2.5) {
diagSPD = diagSPD_Matern52(alpha, rho, L, M);
} else {
reject("nu must be one of 1/2, 3/2 or 5/2; found nu=", nu);
reject("nu must be one of 0.5, 1.5, or 2.5; found nu=", nu);
}
}
return PHI * (diagSPD .* eta);
Expand Down
Loading