1515Interrupted Time Series Analysis
1616"""
1717
18+ import warnings
1819from typing import Any , Literal
1920
2021import arviz as az
@@ -139,9 +140,8 @@ def __init__(
139140 ** kwargs : Any ,
140141 ) -> None :
141142 super ().__init__ (model = model )
142- self .pre_y : xr .DataArray
143- self .post_y : xr .DataArray
144- # rename the index to "obs_ind"
143+ self .pre_design : xr .Dataset
144+ self .post_design : xr .Dataset
145145 data .index .name = "obs_ind"
146146 self .data = data
147147 self .input_validation (data , treatment_time , treatment_end_time )
@@ -155,96 +155,98 @@ def __init__(
155155
def _build_design_matrices(self) -> None:
    """Create the raw outcome/predictor arrays for both periods with patsy.

    The formula is evaluated on the pre-intervention data; the resulting
    design information is stored and then reused to build matching
    matrices from the post-intervention data. Results are kept as plain
    NumPy arrays in ``_pre_X_raw``, ``_pre_y_raw``, ``_post_X_raw`` and
    ``_post_y_raw``.
    """
    pre_outcome, pre_predictors = dmatrices(self.formula, self.datapre)
    outcome_info = pre_outcome.design_info
    predictor_info = pre_predictors.design_info

    self.outcome_variable_name = outcome_info.column_names[0]
    self._y_design_info = outcome_info
    self._x_design_info = predictor_info
    self.labels = predictor_info.column_names
    self._pre_y_raw = np.asarray(pre_outcome)
    self._pre_X_raw = np.asarray(pre_predictors)

    # Reuse the pre-period design info so both periods share one encoding.
    post_outcome, post_predictors = build_design_matrices(
        [outcome_info, predictor_info], self.datapost
    )
    self._post_X_raw = np.asarray(post_predictors)
    self._post_y_raw = np.asarray(post_outcome)
171169
def _prepare_data(self) -> None:
    """Bundle the raw design matrices into ``xr.Dataset`` objects.

    Sets ``self.pre_design`` and ``self.post_design``. Each dataset holds
    an ``"X"`` variable with dims ``("obs_ind", "coeffs")`` and a ``"y"``
    variable with dims ``("obs_ind", "treated_units")``, where
    ``treated_units`` has the single label ``"unit_0"``.

    The temporary ``_*_raw`` arrays stored by ``_build_design_matrices``
    are deleted at the end, so ``_build_design_matrices`` must run again
    before this method can be called a second time.
    """

    def bundle(X_raw, y_raw, index) -> xr.Dataset:
        # A single constructor for both periods keeps dims/coords in lockstep.
        return xr.Dataset(
            {
                "X": xr.DataArray(
                    X_raw,
                    dims=["obs_ind", "coeffs"],
                    coords={"obs_ind": index, "coeffs": self.labels},
                ),
                "y": xr.DataArray(
                    y_raw,
                    dims=["obs_ind", "treated_units"],
                    coords={
                        "obs_ind": index,
                        "treated_units": ["unit_0"],
                    },
                ),
            }
        )

    self.pre_design = bundle(
        self._pre_X_raw, self._pre_y_raw, self.datapre.index
    )
    self.post_design = bundle(
        self._post_X_raw, self._post_y_raw, self.datapost.index
    )
    # The raw arrays now live inside the datasets; drop the duplicates.
    del self._pre_X_raw, self._pre_y_raw, self._post_X_raw, self._post_y_raw
200207
201208 def algorithm (self ) -> None :
202209 """Run the experiment algorithm: fit model, predict, and calculate causal impact."""
203- # fit the model to the observed (pre-intervention) data
204- # All PyMC models now accept xr.DataArray with consistent API
210+ pre_X = self .pre_design ["X" ]
211+ pre_y = self .pre_design ["y" ]
212+ post_X = self .post_design ["X" ]
213+ post_y = self .post_design ["y" ]
214+
205215 if isinstance (self .model , PyMCModel ):
206216 COORDS : dict [str , Any ] = {
207217 "coeffs" : self .labels ,
208- "obs_ind" : np .arange (self . pre_X .shape [0 ]),
218+ "obs_ind" : np .arange (pre_X .shape [0 ]),
209219 "treated_units" : ["unit_0" ],
210- "datetime_index" : self .datapre .index , # For time series models
220+ "datetime_index" : self .datapre .index ,
211221 }
212- self .model .fit (X = self . pre_X , y = self . pre_y , coords = COORDS )
222+ self .model .fit (X = pre_X , y = pre_y , coords = COORDS )
213223 elif isinstance (self .model , RegressorMixin ):
214- # For OLS models, use 1D y data
215- self .model .fit (X = self .pre_X , y = self .pre_y .isel (treated_units = 0 ))
224+ self .model .fit (X = pre_X , y = pre_y .isel (treated_units = 0 ))
216225 else :
217226 raise ValueError ("Model type not recognized" )
218227
219- # score the goodness of fit to the pre-intervention data
220228 if isinstance (self .model , PyMCModel ):
221- self .score = self .model .score (X = self . pre_X , y = self . pre_y )
229+ self .score = self .model .score (X = pre_X , y = pre_y )
222230 elif isinstance (self .model , RegressorMixin ):
223- self .score = self .model .score (
224- X = self .pre_X , y = self .pre_y .isel (treated_units = 0 )
225- )
231+ self .score = self .model .score (X = pre_X , y = pre_y .isel (treated_units = 0 ))
226232
227- # get the model predictions of the observed (pre-intervention) data
228233 if isinstance (self .model , PyMCModel | RegressorMixin ):
229- self .pre_pred = self .model .predict (X = self . pre_X )
234+ self .pre_pred = self .model .predict (X = pre_X )
230235
231- # calculate the counterfactual (post period)
232236 if isinstance (self .model , PyMCModel ):
233- self .post_pred = self .model .predict (X = self . post_X , out_of_sample = True )
237+ self .post_pred = self .model .predict (X = post_X , out_of_sample = True )
234238 elif isinstance (self .model , RegressorMixin ):
235- self .post_pred = self .model .predict (X = self . post_X )
239+ self .post_pred = self .model .predict (X = post_X )
236240
237- # calculate impact - all PyMC models now use 2D data with treated_units
238241 if isinstance (self .model , PyMCModel ):
239- self .pre_impact = self .model .calculate_impact (self . pre_y , self .pre_pred )
240- self .post_impact = self .model .calculate_impact (self . post_y , self .post_pred )
242+ self .pre_impact = self .model .calculate_impact (pre_y , self .pre_pred )
243+ self .post_impact = self .model .calculate_impact (post_y , self .post_pred )
241244 elif isinstance (self .model , RegressorMixin ):
242- # SKL models work with 1D data
243245 self .pre_impact = self .model .calculate_impact (
244- self . pre_y .isel (treated_units = 0 ), self .pre_pred
246+ pre_y .isel (treated_units = 0 ), self .pre_pred
245247 )
246248 self .post_impact = self .model .calculate_impact (
247- self . post_y .isel (treated_units = 0 ), self .post_pred
249+ post_y .isel (treated_units = 0 ), self .post_pred
248250 )
249251
250252 self .post_impact_cumulative = self .model .calculate_cumulative_impact (
@@ -318,6 +320,46 @@ def datapost(self) -> pd.DataFrame:
318320 """
319321 return self .data [self .data .index >= self .treatment_time ]
320322
@property
def pre_X(self) -> xr.DataArray:
    """Return the pre-period design matrix (deprecated alias).

    Deprecated: read ``self.pre_design['X']`` instead. Every access emits
    a ``DeprecationWarning``; ``stacklevel=2`` attributes it to the code
    that accessed the property rather than to this method.

    Returns
    -------
    xr.DataArray
        The ``"X"`` variable of ``self.pre_design``.
    """
    warnings.warn(
        "pre_X is deprecated, use pre_design['X']",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.pre_design["X"]
332+
@property
def pre_y(self) -> xr.DataArray:
    """Return the pre-period outcome array (deprecated alias).

    Deprecated: read ``self.pre_design['y']`` instead. Every access emits
    a ``DeprecationWarning``; ``stacklevel=2`` attributes it to the code
    that accessed the property rather than to this method.

    Returns
    -------
    xr.DataArray
        The ``"y"`` variable of ``self.pre_design``.
    """
    warnings.warn(
        "pre_y is deprecated, use pre_design['y']",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.pre_design["y"]
342+
@property
def post_X(self) -> xr.DataArray:
    """Return the post-period design matrix (deprecated alias).

    Deprecated: read ``self.post_design['X']`` instead. Every access emits
    a ``DeprecationWarning``; ``stacklevel=2`` attributes it to the code
    that accessed the property rather than to this method.

    Returns
    -------
    xr.DataArray
        The ``"X"`` variable of ``self.post_design``.
    """
    warnings.warn(
        "post_X is deprecated, use post_design['X']",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.post_design["X"]
352+
@property
def post_y(self) -> xr.DataArray:
    """Return the post-period outcome array (deprecated alias).

    Deprecated: read ``self.post_design['y']`` instead. Every access emits
    a ``DeprecationWarning``; ``stacklevel=2`` attributes it to the code
    that accessed the property rather than to this method.

    Returns
    -------
    xr.DataArray
        The ``"y"`` variable of ``self.post_design``.
    """
    warnings.warn(
        "post_y is deprecated, use post_design['y']",
        DeprecationWarning,
        stacklevel=2,
    )
    return self.post_design["y"]
362+
321363 def _split_post_period (self ) -> None :
322364 """Split post period into intervention and post-intervention periods.
323365
@@ -627,9 +669,7 @@ def _bayesian_plot(
627669
628670 (h ,) = ax [0 ].plot (
629671 self .datapre .index ,
630- self .pre_y .isel (treated_units = 0 )
631- if hasattr (self .pre_y , "isel" )
632- else self .pre_y [:, 0 ],
672+ self .pre_design ["y" ].isel (treated_units = 0 ),
633673 "k." ,
634674 label = "Observations" ,
635675 )
@@ -654,9 +694,7 @@ def _bayesian_plot(
654694
655695 ax [0 ].plot (
656696 self .datapost .index ,
657- self .post_y .isel (treated_units = 0 )
658- if hasattr (self .post_y , "isel" )
659- else self .post_y [:, 0 ],
697+ self .post_design ["y" ].isel (treated_units = 0 ),
660698 "k." ,
661699 )
662700 # Shaded causal effect
@@ -669,9 +707,7 @@ def _bayesian_plot(
669707 h = ax [0 ].fill_between (
670708 self .datapost .index ,
671709 y1 = post_pred_mu ,
672- y2 = self .post_y .isel (treated_units = 0 )
673- if hasattr (self .post_y , "isel" )
674- else self .post_y [:, 0 ],
710+ y2 = self .post_design ["y" ].isel (treated_units = 0 ),
675711 color = "C0" ,
676712 alpha = 0.25 ,
677713 )
@@ -807,10 +843,10 @@ def _ols_plot(
807843
808844 fig , ax = plt .subplots (3 , 1 , sharex = True , figsize = (7 , 8 ))
809845
810- ax [0 ].plot (self .datapre .index , self .pre_y , "k." )
846+ ax [0 ].plot (self .datapre .index , self .pre_design [ "y" ] , "k." )
811847 ax [0 ].plot (self .datapre .index , self .pre_pred , c = "k" , label = "model fit" )
812848
813- ax [0 ].plot (self .datapost .index , self .post_y , "k." )
849+ ax [0 ].plot (self .datapost .index , self .post_design [ "y" ] , "k." )
814850 ax [0 ].plot (
815851 self .datapost .index ,
816852 self .post_pred ,
@@ -822,7 +858,7 @@ def _ols_plot(
822858 ax [0 ].fill_between (
823859 self .datapost .index ,
824860 y1 = np .squeeze (self .post_pred ),
825- y2 = np .squeeze (self .post_y ),
861+ y2 = np .squeeze (self .post_design [ "y" ] ),
826862 color = "C0" ,
827863 alpha = 0.25 ,
828864 label = "Causal impact" ,
0 commit comments