Skip to content

Commit 2727b3d

Browse files
bugfixes and touchup
1 parent 40bf29b commit 2727b3d

4 files changed

Lines changed: 122 additions & 28 deletions

File tree

pymc_extras/statespace/core/statespace.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2347,7 +2347,7 @@ def _build_forecast_model(
23472347
cov_dims = ["data_time", ALL_STATE_DIM, ALL_STATE_AUX_DIM]
23482348

23492349
with pm.Model(coords=temp_coords) as forecast_model:
2350-
(_, _, *matrices), grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
2350+
_, grouped_outputs = self._kalman_filter_outputs_from_dummy_graph(
23512351
data_dims=["data_time", OBS_STATE_DIM],
23522352
)
23532353

@@ -2374,9 +2374,20 @@ def _build_forecast_model(
23742374
"P0_slice", cov_frozen[t0_idx], dims=cov_dims[1:] if cov_dims is not None else None
23752375
)
23762376

2377-
# Get matrices with n_timesteps set to forecast length for time-varying models
2378-
# Note: matrices already has x0, P0 skipped from _kalman_filter_outputs_from_dummy_graph
2379-
forecast_matrices = self._insert_constant_timestep(matrices, len(forecast_index))
2377+
# Get fresh matrices with n_timesteps placeholder still intact.
2378+
# Build for the full timeline (training + forecast) so that time-varying matrices
2379+
# continue at the correct phase, then slice to keep only the forecast portion.
2380+
n_train = len(time_index)
2381+
n_total = n_train + len(forecast_index)
2382+
2383+
full_matrices = self._insert_constant_timestep(self.unpack_statespace(), n_total)
2384+
_, _, *forecast_matrices = full_matrices
2385+
2386+
forecast_names = MATRIX_NAMES[2:] # c, d, T, Z, R, H, Q
2387+
forecast_matrices = [
2388+
m[n_train:] if m.ndim == (2 if name in VECTOR_VALUED else 3) else m
2389+
for m, name in zip(forecast_matrices, forecast_names)
2390+
]
23802391

23812392
_ = LinearGaussianStateSpace(
23822393
"forecast",
@@ -2640,6 +2651,12 @@ def impulse_response_function(
26402651
-------
26412652
pm.InferenceData
26422653
An Arviz InferenceData object containing impulse response function in a variable named "irf".
2654+
2655+
Notes
2656+
-----
2657+
For models with time-varying transition matrices, the IRF is computed starting at phase 0 of the
2658+
time-varying cycle. This means the response represents the effect of a shock occurring at the first
2659+
modeled state, T(0).
26432660
"""
26442661
options = [shock_size, shock_cov, shock_trajectory]
26452662
n_options = sum(x is not None for x in options)

pymc_extras/statespace/models/structural/components/seasonality.py

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -452,14 +452,15 @@ def _build_time_varying_transition(self) -> TensorVariable:
452452

453453
T_hold = pt.eye(n)
454454

455-
time_idx = pt.arange(self.n_timesteps)
456-
is_rotation_step = pt.eq(time_idx % self.duration, self.duration - 1)
455+
# Build one complete cycle: [I, I, ..., I, T_rotate] of length `duration`
456+
# Then tile to cover n_timesteps
457+
cycle_matrices = [T_hold for _ in range(self.duration - 1)] + [T_rotate]
458+
T_cycle = pt.stack(cycle_matrices) # (duration, n, n)
457459

458-
return pt.where(
459-
is_rotation_step[:, None, None],
460-
pt.broadcast_to(T_rotate, (self.n_timesteps, n, n)),
461-
pt.broadcast_to(T_hold, (self.n_timesteps, n, n)),
462-
)
460+
n_cycles = (self.n_timesteps + self.duration - 1) // self.duration # ceiling division
461+
T_tiled = pt.tile(T_cycle, (n_cycles, 1, 1))
462+
463+
return T_tiled[: self.n_timesteps]
463464

464465
def _build_transition_matrix(self) -> TensorVariable:
465466
"""Build the full transition matrix, handling multivariate via block_diag."""
@@ -522,15 +523,18 @@ def _build_initial_state_simple(self, initial_params: TensorVariable) -> TensorV
522523
if use_tv:
523524
return initial_params
524525
else:
525-
return pt.extra_ops.repeat(initial_params, self.duration, axis=0)
526+
# Static mode: state is d blocks of s elements each
527+
# Tile the full season vector d times
528+
return pt.tile(initial_params, self.duration)
526529
else:
527530
if use_tv:
528531
return initial_params.ravel()
529532
else:
530-
return pt.extra_ops.repeat(initial_params, self.duration, axis=1).ravel()
533+
# Static mode: tile each row (endog) d times, then ravel
534+
return pt.tile(initial_params, (1, self.duration)).ravel()
531535

532536
def _apply_start_state_shift(
533-
self, initial_state: TensorVariable, T: TensorVariable
537+
self, initial_state: TensorVariable, T: TensorVariable | None
534538
) -> TensorVariable:
535539
"""Shift initial state to account for start_state offset."""
536540
if self.start_idx == 0:
@@ -569,7 +573,7 @@ def _build_selection_and_state_cov(self) -> None:
569573
k_posdef = self._k_posdef_per_endog()
570574

571575
R = pt.zeros((k_states, k_posdef))[0, 0].set(1.0)
572-
self.ssm["selection", :, :] = pt.join(0, *[R for _ in range(k_endog_effective)])
576+
self.ssm["selection", :, :] = pt.linalg.block_diag(*[R for _ in range(k_endog_effective)])
573577

574578
sigma = self.make_and_register_variable(
575579
f"sigma_{self.name}",
@@ -606,8 +610,10 @@ def make_symbolic_graph(self) -> None:
606610
else:
607611
initial_state = self._build_initial_state_simple(initial_params)
608612

609-
# Apply start_state shift (handles time-varying vs static internally)
610-
T_for_shift = self._build_static_transition()
613+
# Apply start_state shift
614+
T_for_shift = (
615+
None if self._uses_time_varying_transition else self._build_static_transition()
616+
)
611617
initial_state = self._apply_start_state_shift(initial_state, T_for_shift)
612618

613619
self.ssm["initial_state", :] = initial_state

tests/statespace/core/test_statespace.py

Lines changed: 23 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1358,12 +1358,32 @@ def test_sample_unconditional_posterior(self, ss_mod_time_varying, idata_time_va
13581358

13591359
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
13601360
@pytest.mark.filterwarnings("ignore:No start date provided")
1361-
def test_forecast(self, ss_mod_time_varying, idata_time_varying):
1362-
result = ss_mod_time_varying.forecast(idata_time_varying, periods=10)
1361+
@pytest.mark.parametrize(
1362+
"periods", [10, 50], ids=["shorter_than_training", "longer_than_training"]
1363+
)
1364+
def test_forecast(self, ss_mod_time_varying, idata_time_varying, periods):
1365+
n_obs = 40 # must match pymc_mod_time_varying fixture
1366+
result = ss_mod_time_varying.forecast(idata_time_varying, periods=periods)
1367+
13631368
assert "forecast_latent" in result
13641369
assert "forecast_observed" in result
1365-
assert result["forecast_latent"].shape[2] == 10
1370+
assert result["forecast_latent"].dims == ("chain", "draw", "time", "state")
1371+
assert result["forecast_observed"].dims == ("chain", "draw", "time", "observed_state")
1372+
assert result["forecast_latent"].shape[2] == periods
13661373
assert not np.any(np.isnan(result["forecast_latent"].values))
1374+
assert not np.any(np.isnan(result["forecast_observed"].values))
1375+
1376+
# Value check: the model has y_t = d_t + Z @ x_t with Z=[[1]] and H=0 (no obs noise),
1377+
# so forecast_observed - forecast_latent = d_t = slope * t.
1378+
# Forecast matrices are phase-aligned: they continue from n_obs, not from 0.
1379+
latent = result["forecast_latent"].values # (chain, draw, time, state)
1380+
observed = result["forecast_observed"].values # (chain, draw, time, obs)
1381+
slope = idata_time_varying.posterior["slope"].values # (chain, draw)
1382+
1383+
intercepts = observed[..., 0] - latent[..., 0] # (chain, draw, time)
1384+
expected = slope[..., None] * np.arange(n_obs, n_obs + periods)[None, None, :]
1385+
1386+
assert_allclose(intercepts, expected, atol=1e-5, rtol=1e-5)
13671387

13681388
@pytest.mark.filterwarnings("ignore:No time index found on the supplied data.")
13691389
def test_impulse_response_function(self, ss_mod_time_varying, idata_time_varying):

tests/statespace/models/structural/components/test_seasonality.py

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ def test_time_seasonality(s, d, innovations, remove_first_state, rng):
5555

5656
@pytest.mark.parametrize("d", [1, 3])
5757
@pytest.mark.parametrize("start_state", [0, 2, "state_2"])
58-
def test_time_seasonality_start_state(d, start_state, rng):
58+
@pytest.mark.parametrize("use_time_varying", [True, False], ids=["time_varying", "static"])
59+
@pytest.mark.parametrize("remove_first_state", [True, False])
60+
def test_time_seasonality_start_state(d, start_state, use_time_varying, remove_first_state, rng):
5961
s = 4
6062
state_names = [f"state_{i}" for i in range(s)]
6163

@@ -65,13 +67,18 @@ def test_time_seasonality_start_state(d, start_state, rng):
6567
innovations=False,
6668
name="season",
6769
state_names=state_names,
68-
remove_first_state=True,
70+
remove_first_state=remove_first_state,
6971
start_state=start_state,
72+
use_time_varying=use_time_varying,
7073
)
7174

72-
params = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
73-
implied_gamma0 = -params.sum()
74-
default_seasons = [implied_gamma0, params[0], params[1], params[2]]
75+
if remove_first_state:
76+
params = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
77+
implied_gamma0 = -params.sum()
78+
default_seasons = [implied_gamma0, params[0], params[1], params[2]]
79+
else:
80+
params = np.array([1.0, 2.0, 3.0, 4.0], dtype=config.floatX)
81+
default_seasons = [params[0], params[1], params[2], params[3]]
7582

7683
start_idx = state_names.index(start_state) if isinstance(start_state, str) else start_state
7784
expected_seasons = default_seasons[start_idx:] + default_seasons[:start_idx]
@@ -175,6 +182,12 @@ def test_time_seasonality_multiple_observed(rng, d, remove_first_state):
175182
np.testing.assert_allclose(T_v, expected_T, atol=ATOL, rtol=RTOL)
176183
np.testing.assert_allclose(Q_v, np.array([[0.1**2, 0.0], [0.0, 0.8**2]]), atol=ATOL, rtol=RTOL)
177184

185+
k_states_per_endog = d * (s - 1) if remove_first_state else d * s
186+
Z0 = np.zeros((1, k_states_per_endog))
187+
Z0[0, 0] = 1.0
188+
expected_Z = np.block([[Z0, np.zeros_like(Z0)], [np.zeros_like(Z0), Z0]])
189+
np.testing.assert_allclose(Z_v, expected_Z, atol=ATOL, rtol=RTOL)
190+
178191

179192
@pytest.mark.parametrize("d", [2, 3])
180193
@pytest.mark.parametrize("remove_first_state", [True, False])
@@ -204,6 +217,13 @@ def test_time_seasonality_multiple_observed_time_varying(rng, d, remove_first_st
204217
assert_pattern_repeats(y[:, 0], s * d, atol=ATOL, rtol=RTOL)
205218
assert_pattern_repeats(y[:, 1], s * d, atol=ATOL, rtol=RTOL)
206219

220+
k_states_per_endog = (s - 1) if remove_first_state else s
221+
Z0 = np.zeros((1, k_states_per_endog))
222+
Z0[0, 0] = 1.0
223+
expected_Z = np.block([[Z0, np.zeros_like(Z0)], [np.zeros_like(Z0), Z0]])
224+
Z_v = mod.ssm["design"].eval()
225+
np.testing.assert_allclose(Z_v, expected_Z, atol=ATOL, rtol=RTOL)
226+
207227

208228
def test_time_seasonality_shared_states():
209229
mod = st.TimeSeasonality(
@@ -231,6 +251,34 @@ def test_time_seasonality_shared_states():
231251
np.testing.assert_allclose(R, np.array([[1.0], [0.0], [0.0]]))
232252

233253

254+
@pytest.mark.parametrize("use_time_varying", [True, False], ids=["time_varying", "static"])
255+
def test_time_seasonality_shared_states_with_duration(rng, use_time_varying):
256+
s, d = 4, 3
257+
mod = st.TimeSeasonality(
258+
season_length=s,
259+
duration=d,
260+
innovations=False,
261+
name="season",
262+
observed_state_names=["data_1", "data_2"],
263+
remove_first_state=True,
264+
share_states=True,
265+
use_time_varying=use_time_varying,
266+
)
267+
268+
assert mod.k_endog == 2
269+
expected_k_states = (s - 1) if use_time_varying else d * (s - 1)
270+
assert mod.k_states == expected_k_states
271+
assert mod.k_posdef == 0
272+
273+
x0 = np.array([1.0, 2.0, 3.0], dtype=config.floatX)
274+
params = {"params_season": x0}
275+
_, y = simulate_from_numpy_model(mod, rng, params, steps=s * d * 4)
276+
277+
assert_pattern_repeats(y[:, 0], s * d, atol=ATOL, rtol=RTOL)
278+
assert_pattern_repeats(y[:, 1], s * d, atol=ATOL, rtol=RTOL)
279+
np.testing.assert_allclose(y[:, 0], y[:, 1], atol=ATOL, rtol=RTOL)
280+
281+
234282
def test_add_mixed_shared_not_shared_time_seasonality():
235283
shared = st.TimeSeasonality(
236284
season_length=3,
@@ -394,8 +442,10 @@ def test_add_time_varying_and_static_seasonality(rng):
394442
x, y = simulate_from_numpy_model(mod, rng, params, steps=steps)
395443

396444
# Combined output should have period = LCM of individual periods
397-
# For testing, verify the model runs without error and output shape is correct
445+
# LCM(s1*d1, s2*d2) = LCM(12, 10) = 60
398446
assert y.shape == (steps,)
447+
assert not np.any(np.isnan(y))
448+
assert_pattern_repeats(y, 60, atol=ATOL, rtol=RTOL)
399449

400450

401451
@pytest.mark.parametrize("n", [1, 2, 3, None])
@@ -409,8 +459,9 @@ def test_frequency_seasonality(n, s, rng):
409459
x0 = rng.normal(size=mod.n_coefs).astype(config.floatX)
410460
params = {"params_season": x0, "sigma_season": 0.0}
411461

412-
decimal = s_str.split(".") if "." in (s_str := str(s)) else "0"
413-
T = int(s * 10 ** len(decimal))
462+
s_str = str(s)
463+
n_decimal = len(s_str.split(".")[1]) if "." in s_str else 0
464+
T = int(s * 10**n_decimal)
414465

415466
x, y = simulate_from_numpy_model(mod, rng, params, steps=2 * T)
416467
assert_pattern_repeats(y, T, atol=ATOL, rtol=RTOL)

0 commit comments

Comments
 (0)