Skip to content

Commit 87a2c2c

Browse files
Fix forecasting logic to handle time-varying matrices
1 parent 2727b3d commit 87a2c2c

1 file changed

Lines changed: 25 additions & 0 deletions

File tree

pymc_extras/statespace/core/statespace.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2383,6 +2383,31 @@ def _build_forecast_model(
23832383
full_matrices = self._insert_constant_timestep(self.unpack_statespace(), n_total)
23842384
_, _, *forecast_matrices = full_matrices
23852385

2386+
# For exogenous-data-driven matrices the time dimension comes from the
2387+
# data shared variable, not from the n_timesteps symbolic. Replace the
2388+
# shared variables with concatenated training + scenario tensors so the
2389+
# [n_train:] slice below yields the correct forecast portion.
2390+
# TODO: Is there a way to handle this in a fully symbolic way, without having to
2391+
# run the full scan on training data to get the system's state at the start date?
2392+
if scenario is not None and self._needs_exog_data:
2393+
exog_replace = {}
2394+
for name in self.data_names:
2395+
if name not in scenario:
2396+
continue
2397+
forecast_data = scenario[name]
2398+
train_val = self._fit_exog_data[name]["value"]
2399+
fc_val = (
2400+
forecast_data.values
2401+
if isinstance(forecast_data, pd.DataFrame)
2402+
else np.asarray(forecast_data)
2403+
)
2404+
combined = np.concatenate([train_val, fc_val], axis=0)
2405+
exog_replace[forecast_model[name]] = pt.as_tensor_variable(combined, name=name)
2406+
if exog_replace:
2407+
forecast_matrices = graph_replace(
2408+
forecast_matrices, replace=exog_replace, strict=False
2409+
)
2410+
23862411
forecast_names = MATRIX_NAMES[2:] # c, d, T, Z, R, H, Q
23872412
forecast_matrices = [
23882413
m[n_train:] if m.ndim == (2 if name in VECTOR_VALUED else 3) else m

0 commit comments

Comments
 (0)