@@ -55,7 +55,9 @@ def test_time_seasonality(s, d, innovations, remove_first_state, rng):
5555
5656@pytest .mark .parametrize ("d" , [1 , 3 ])
5757@pytest .mark .parametrize ("start_state" , [0 , 2 , "state_2" ])
58- def test_time_seasonality_start_state (d , start_state , rng ):
58+ @pytest .mark .parametrize ("use_time_varying" , [True , False ], ids = ["time_varying" , "static" ])
59+ @pytest .mark .parametrize ("remove_first_state" , [True , False ])
60+ def test_time_seasonality_start_state (d , start_state , use_time_varying , remove_first_state , rng ):
5961 s = 4
6062 state_names = [f"state_{ i } " for i in range (s )]
6163
@@ -65,13 +67,18 @@ def test_time_seasonality_start_state(d, start_state, rng):
6567 innovations = False ,
6668 name = "season" ,
6769 state_names = state_names ,
68- remove_first_state = True ,
70+ remove_first_state = remove_first_state ,
6971 start_state = start_state ,
72+ use_time_varying = use_time_varying ,
7073 )
7174
72- params = np .array ([1.0 , 2.0 , 3.0 ], dtype = config .floatX )
73- implied_gamma0 = - params .sum ()
74- default_seasons = [implied_gamma0 , params [0 ], params [1 ], params [2 ]]
75+ if remove_first_state :
76+ params = np .array ([1.0 , 2.0 , 3.0 ], dtype = config .floatX )
77+ implied_gamma0 = - params .sum ()
78+ default_seasons = [implied_gamma0 , params [0 ], params [1 ], params [2 ]]
79+ else :
80+ params = np .array ([1.0 , 2.0 , 3.0 , 4.0 ], dtype = config .floatX )
81+ default_seasons = [params [0 ], params [1 ], params [2 ], params [3 ]]
7582
7683 start_idx = state_names .index (start_state ) if isinstance (start_state , str ) else start_state
7784 expected_seasons = default_seasons [start_idx :] + default_seasons [:start_idx ]
@@ -175,6 +182,12 @@ def test_time_seasonality_multiple_observed(rng, d, remove_first_state):
175182 np .testing .assert_allclose (T_v , expected_T , atol = ATOL , rtol = RTOL )
176183 np .testing .assert_allclose (Q_v , np .array ([[0.1 ** 2 , 0.0 ], [0.0 , 0.8 ** 2 ]]), atol = ATOL , rtol = RTOL )
177184
185+ k_states_per_endog = d * (s - 1 ) if remove_first_state else d * s
186+ Z0 = np .zeros ((1 , k_states_per_endog ))
187+ Z0 [0 , 0 ] = 1.0
188+ expected_Z = np .block ([[Z0 , np .zeros_like (Z0 )], [np .zeros_like (Z0 ), Z0 ]])
189+ np .testing .assert_allclose (Z_v , expected_Z , atol = ATOL , rtol = RTOL )
190+
178191
179192@pytest .mark .parametrize ("d" , [2 , 3 ])
180193@pytest .mark .parametrize ("remove_first_state" , [True , False ])
@@ -204,6 +217,13 @@ def test_time_seasonality_multiple_observed_time_varying(rng, d, remove_first_st
204217 assert_pattern_repeats (y [:, 0 ], s * d , atol = ATOL , rtol = RTOL )
205218 assert_pattern_repeats (y [:, 1 ], s * d , atol = ATOL , rtol = RTOL )
206219
220+ k_states_per_endog = (s - 1 ) if remove_first_state else s
221+ Z0 = np .zeros ((1 , k_states_per_endog ))
222+ Z0 [0 , 0 ] = 1.0
223+ expected_Z = np .block ([[Z0 , np .zeros_like (Z0 )], [np .zeros_like (Z0 ), Z0 ]])
224+ Z_v = mod .ssm ["design" ].eval ()
225+ np .testing .assert_allclose (Z_v , expected_Z , atol = ATOL , rtol = RTOL )
226+
207227
208228def test_time_seasonality_shared_states ():
209229 mod = st .TimeSeasonality (
@@ -231,6 +251,34 @@ def test_time_seasonality_shared_states():
231251 np .testing .assert_allclose (R , np .array ([[1.0 ], [0.0 ], [0.0 ]]))
232252
233253
254+ @pytest .mark .parametrize ("use_time_varying" , [True , False ], ids = ["time_varying" , "static" ])
255+ def test_time_seasonality_shared_states_with_duration (rng , use_time_varying ):
256+ s , d = 4 , 3
257+ mod = st .TimeSeasonality (
258+ season_length = s ,
259+ duration = d ,
260+ innovations = False ,
261+ name = "season" ,
262+ observed_state_names = ["data_1" , "data_2" ],
263+ remove_first_state = True ,
264+ share_states = True ,
265+ use_time_varying = use_time_varying ,
266+ )
267+
268+ assert mod .k_endog == 2
269+ expected_k_states = (s - 1 ) if use_time_varying else d * (s - 1 )
270+ assert mod .k_states == expected_k_states
271+ assert mod .k_posdef == 0
272+
273+ x0 = np .array ([1.0 , 2.0 , 3.0 ], dtype = config .floatX )
274+ params = {"params_season" : x0 }
275+ _ , y = simulate_from_numpy_model (mod , rng , params , steps = s * d * 4 )
276+
277+ assert_pattern_repeats (y [:, 0 ], s * d , atol = ATOL , rtol = RTOL )
278+ assert_pattern_repeats (y [:, 1 ], s * d , atol = ATOL , rtol = RTOL )
279+ np .testing .assert_allclose (y [:, 0 ], y [:, 1 ], atol = ATOL , rtol = RTOL )
280+
281+
234282def test_add_mixed_shared_not_shared_time_seasonality ():
235283 shared = st .TimeSeasonality (
236284 season_length = 3 ,
@@ -394,8 +442,10 @@ def test_add_time_varying_and_static_seasonality(rng):
394442 x , y = simulate_from_numpy_model (mod , rng , params , steps = steps )
395443
396444 # Combined output should have period = LCM of individual periods
397- # For testing, verify the model runs without error and output shape is correct
445+ # LCM(s1*d1, s2*d2) = LCM(12, 10) = 60
398446 assert y .shape == (steps ,)
447+ assert not np .any (np .isnan (y ))
448+ assert_pattern_repeats (y , 60 , atol = ATOL , rtol = RTOL )
399449
400450
401451@pytest .mark .parametrize ("n" , [1 , 2 , 3 , None ])
@@ -409,8 +459,9 @@ def test_frequency_seasonality(n, s, rng):
409459 x0 = rng .normal (size = mod .n_coefs ).astype (config .floatX )
410460 params = {"params_season" : x0 , "sigma_season" : 0.0 }
411461
412- decimal = s_str .split ("." ) if "." in (s_str := str (s )) else "0"
413- T = int (s * 10 ** len (decimal ))
462+ s_str = str (s )
463+ n_decimal = len (s_str .split ("." )[1 ]) if "." in s_str else 0
464+ T = int (s * 10 ** n_decimal )
414465
415466 x , y = simulate_from_numpy_model (mod , rng , params , steps = 2 * T )
416467 assert_pattern_repeats (y , T , atol = ATOL , rtol = RTOL )
0 commit comments