Skip to content

Commit aafb28c

Browse files
committed
Consolidate experiment design-matrix attributes into xr.Dataset
Bundle loose xr.DataArray attributes on experiment classes into xr.Dataset objects to reduce attribute sprawl. Pre/post classes (ITS, SC) use pre_design/post_design; formula-based classes use a single design Dataset. The old attribute names (e.g. X, y, pre_X, pre_y, post_X, post_y) remain available as deprecated @property accessors that emit a DeprecationWarning, preserving backward compatibility. Closes #199. Made-with: Cursor
1 parent 08ee2c3 commit aafb28c

11 files changed

Lines changed: 489 additions & 278 deletions

causalpy/experiments/diff_in_diff.py

Lines changed: 39 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Difference in differences
1616
"""
1717

18+
import warnings
1819
from typing import Any, Literal
1920

2021
import arviz as az
@@ -129,42 +130,60 @@ def _build_design_matrices(self) -> None:
129130
self._y_design_info = y.design_info
130131
self._x_design_info = X.design_info
131132
self.labels = X.design_info.column_names
132-
self.y, self.X = np.asarray(y), np.asarray(X)
133+
self._y_raw, self._X_raw = np.asarray(y), np.asarray(X)
133134
self.outcome_variable_name = y.design_info.column_names[0]
134135

135136
def _prepare_data(self) -> None:
136-
"""Convert design matrices to xarray DataArrays."""
137-
self.X = xr.DataArray(
138-
self.X,
139-
dims=["obs_ind", "coeffs"],
140-
coords={
141-
"obs_ind": np.arange(self.X.shape[0]),
142-
"coeffs": self.labels,
143-
},
137+
"""Bundle design matrices into an ``xr.Dataset``."""
138+
n = self._X_raw.shape[0]
139+
self.design = xr.Dataset(
140+
{
141+
"X": xr.DataArray(
142+
self._X_raw,
143+
dims=["obs_ind", "coeffs"],
144+
coords={"obs_ind": np.arange(n), "coeffs": self.labels},
145+
),
146+
"y": xr.DataArray(
147+
self._y_raw,
148+
dims=["obs_ind", "treated_units"],
149+
coords={"obs_ind": np.arange(n), "treated_units": ["unit_0"]},
150+
),
151+
}
144152
)
145-
self.y = xr.DataArray(
146-
self.y,
147-
dims=["obs_ind", "treated_units"],
148-
coords={"obs_ind": np.arange(self.y.shape[0]), "treated_units": ["unit_0"]},
153+
del self._X_raw, self._y_raw
154+
155+
@property
156+
def X(self) -> xr.DataArray:
157+
""".. deprecated:: Use ``self.design['X']`` instead."""
158+
warnings.warn(
159+
"X is deprecated, use design['X']", DeprecationWarning, stacklevel=2
160+
)
161+
return self.design["X"]
162+
163+
@property
164+
def y(self) -> xr.DataArray:
165+
""".. deprecated:: Use ``self.design['y']`` instead."""
166+
warnings.warn(
167+
"y is deprecated, use design['y']", DeprecationWarning, stacklevel=2
149168
)
169+
return self.design["y"]
150170

151171
def algorithm(self) -> None:
152172
"""Run the experiment algorithm: fit model, predict, and calculate causal impact."""
153-
# fit model
173+
X = self.design["X"]
174+
y = self.design["y"]
175+
154176
if isinstance(self.model, PyMCModel):
155177
COORDS = {
156178
"coeffs": self.labels,
157-
"obs_ind": np.arange(self.X.shape[0]),
179+
"obs_ind": np.arange(X.shape[0]),
158180
"treated_units": ["unit_0"],
159181
}
160-
self.model.fit(X=self.X, y=self.y, coords=COORDS)
182+
self.model.fit(X=X, y=y, coords=COORDS)
161183
elif isinstance(self.model, RegressorMixin):
162-
# Ensure the intercept is part of the coefficients array rather than
163-
# a separate intercept_ attribute. See #664 / PR #693 for
164-
# centralising this in BaseExperiment.
165184
if hasattr(self.model, "fit_intercept"):
166185
self.model.fit_intercept = False
167-
self.model.fit(X=self.X, y=self.y)
186+
self.model.fit(X=X, y=y)
168187
else:
169188
raise ValueError("Model type not recognized")
170189

causalpy/experiments/interrupted_time_series.py

Lines changed: 104 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
Interrupted Time Series Analysis
1616
"""
1717

18+
import warnings
1819
from typing import Any, Literal
1920

2021
import arviz as az
@@ -139,9 +140,8 @@ def __init__(
139140
**kwargs: Any,
140141
) -> None:
141142
super().__init__(model=model)
142-
self.pre_y: xr.DataArray
143-
self.post_y: xr.DataArray
144-
# rename the index to "obs_ind"
143+
self.pre_design: xr.Dataset
144+
self.post_design: xr.Dataset
145145
data.index.name = "obs_ind"
146146
self.data = data
147147
self.input_validation(data, treatment_time, treatment_end_time)
@@ -155,96 +155,98 @@ def __init__(
155155

156156
def _build_design_matrices(self) -> None:
157157
"""Build design matrices for pre and post intervention periods using patsy."""
158-
# set things up with pre-intervention data
159158
y, X = dmatrices(self.formula, self.datapre)
160159
self.outcome_variable_name = y.design_info.column_names[0]
161160
self._y_design_info = y.design_info
162161
self._x_design_info = X.design_info
163162
self.labels = X.design_info.column_names
164-
self.pre_y, self.pre_X = np.asarray(y), np.asarray(X)
165-
# process post-intervention data
163+
self._pre_y_raw, self._pre_X_raw = np.asarray(y), np.asarray(X)
166164
(new_y, new_x) = build_design_matrices(
167165
[self._y_design_info, self._x_design_info], self.datapost
168166
)
169-
self.post_X = np.asarray(new_x)
170-
self.post_y = np.asarray(new_y)
167+
self._post_X_raw = np.asarray(new_x)
168+
self._post_y_raw = np.asarray(new_y)
171169

172170
def _prepare_data(self) -> None:
173-
"""Convert design matrices to xarray DataArrays for pre and post periods."""
174-
self.pre_X = xr.DataArray(
175-
self.pre_X,
176-
dims=["obs_ind", "coeffs"],
177-
coords={
178-
"obs_ind": self.datapre.index,
179-
"coeffs": self.labels,
180-
},
181-
)
182-
self.pre_y = xr.DataArray(
183-
self.pre_y, # Keep 2D shape
184-
dims=["obs_ind", "treated_units"],
185-
coords={"obs_ind": self.datapre.index, "treated_units": ["unit_0"]},
186-
)
187-
self.post_X = xr.DataArray(
188-
self.post_X,
189-
dims=["obs_ind", "coeffs"],
190-
coords={
191-
"obs_ind": self.datapost.index,
192-
"coeffs": self.labels,
193-
},
171+
"""Bundle design matrices into ``xr.Dataset`` objects for pre and post periods."""
172+
self.pre_design = xr.Dataset(
173+
{
174+
"X": xr.DataArray(
175+
self._pre_X_raw,
176+
dims=["obs_ind", "coeffs"],
177+
coords={"obs_ind": self.datapre.index, "coeffs": self.labels},
178+
),
179+
"y": xr.DataArray(
180+
self._pre_y_raw,
181+
dims=["obs_ind", "treated_units"],
182+
coords={
183+
"obs_ind": self.datapre.index,
184+
"treated_units": ["unit_0"],
185+
},
186+
),
187+
}
194188
)
195-
self.post_y = xr.DataArray(
196-
self.post_y, # Keep 2D shape
197-
dims=["obs_ind", "treated_units"],
198-
coords={"obs_ind": self.datapost.index, "treated_units": ["unit_0"]},
189+
self.post_design = xr.Dataset(
190+
{
191+
"X": xr.DataArray(
192+
self._post_X_raw,
193+
dims=["obs_ind", "coeffs"],
194+
coords={"obs_ind": self.datapost.index, "coeffs": self.labels},
195+
),
196+
"y": xr.DataArray(
197+
self._post_y_raw,
198+
dims=["obs_ind", "treated_units"],
199+
coords={
200+
"obs_ind": self.datapost.index,
201+
"treated_units": ["unit_0"],
202+
},
203+
),
204+
}
199205
)
206+
del self._pre_X_raw, self._pre_y_raw, self._post_X_raw, self._post_y_raw
200207

201208
def algorithm(self) -> None:
202209
"""Run the experiment algorithm: fit model, predict, and calculate causal impact."""
203-
# fit the model to the observed (pre-intervention) data
204-
# All PyMC models now accept xr.DataArray with consistent API
210+
pre_X = self.pre_design["X"]
211+
pre_y = self.pre_design["y"]
212+
post_X = self.post_design["X"]
213+
post_y = self.post_design["y"]
214+
205215
if isinstance(self.model, PyMCModel):
206216
COORDS: dict[str, Any] = {
207217
"coeffs": self.labels,
208-
"obs_ind": np.arange(self.pre_X.shape[0]),
218+
"obs_ind": np.arange(pre_X.shape[0]),
209219
"treated_units": ["unit_0"],
210-
"datetime_index": self.datapre.index, # For time series models
220+
"datetime_index": self.datapre.index,
211221
}
212-
self.model.fit(X=self.pre_X, y=self.pre_y, coords=COORDS)
222+
self.model.fit(X=pre_X, y=pre_y, coords=COORDS)
213223
elif isinstance(self.model, RegressorMixin):
214-
# For OLS models, use 1D y data
215-
self.model.fit(X=self.pre_X, y=self.pre_y.isel(treated_units=0))
224+
self.model.fit(X=pre_X, y=pre_y.isel(treated_units=0))
216225
else:
217226
raise ValueError("Model type not recognized")
218227

219-
# score the goodness of fit to the pre-intervention data
220228
if isinstance(self.model, PyMCModel):
221-
self.score = self.model.score(X=self.pre_X, y=self.pre_y)
229+
self.score = self.model.score(X=pre_X, y=pre_y)
222230
elif isinstance(self.model, RegressorMixin):
223-
self.score = self.model.score(
224-
X=self.pre_X, y=self.pre_y.isel(treated_units=0)
225-
)
231+
self.score = self.model.score(X=pre_X, y=pre_y.isel(treated_units=0))
226232

227-
# get the model predictions of the observed (pre-intervention) data
228233
if isinstance(self.model, PyMCModel | RegressorMixin):
229-
self.pre_pred = self.model.predict(X=self.pre_X)
234+
self.pre_pred = self.model.predict(X=pre_X)
230235

231-
# calculate the counterfactual (post period)
232236
if isinstance(self.model, PyMCModel):
233-
self.post_pred = self.model.predict(X=self.post_X, out_of_sample=True)
237+
self.post_pred = self.model.predict(X=post_X, out_of_sample=True)
234238
elif isinstance(self.model, RegressorMixin):
235-
self.post_pred = self.model.predict(X=self.post_X)
239+
self.post_pred = self.model.predict(X=post_X)
236240

237-
# calculate impact - all PyMC models now use 2D data with treated_units
238241
if isinstance(self.model, PyMCModel):
239-
self.pre_impact = self.model.calculate_impact(self.pre_y, self.pre_pred)
240-
self.post_impact = self.model.calculate_impact(self.post_y, self.post_pred)
242+
self.pre_impact = self.model.calculate_impact(pre_y, self.pre_pred)
243+
self.post_impact = self.model.calculate_impact(post_y, self.post_pred)
241244
elif isinstance(self.model, RegressorMixin):
242-
# SKL models work with 1D data
243245
self.pre_impact = self.model.calculate_impact(
244-
self.pre_y.isel(treated_units=0), self.pre_pred
246+
pre_y.isel(treated_units=0), self.pre_pred
245247
)
246248
self.post_impact = self.model.calculate_impact(
247-
self.post_y.isel(treated_units=0), self.post_pred
249+
post_y.isel(treated_units=0), self.post_pred
248250
)
249251

250252
self.post_impact_cumulative = self.model.calculate_cumulative_impact(
@@ -318,6 +320,46 @@ def datapost(self) -> pd.DataFrame:
318320
"""
319321
return self.data[self.data.index >= self.treatment_time]
320322

323+
@property
324+
def pre_X(self) -> xr.DataArray:
325+
""".. deprecated:: Use ``self.pre_design['X']`` instead."""
326+
warnings.warn(
327+
"pre_X is deprecated, use pre_design['X']",
328+
DeprecationWarning,
329+
stacklevel=2,
330+
)
331+
return self.pre_design["X"]
332+
333+
@property
334+
def pre_y(self) -> xr.DataArray:
335+
""".. deprecated:: Use ``self.pre_design['y']`` instead."""
336+
warnings.warn(
337+
"pre_y is deprecated, use pre_design['y']",
338+
DeprecationWarning,
339+
stacklevel=2,
340+
)
341+
return self.pre_design["y"]
342+
343+
@property
344+
def post_X(self) -> xr.DataArray:
345+
""".. deprecated:: Use ``self.post_design['X']`` instead."""
346+
warnings.warn(
347+
"post_X is deprecated, use post_design['X']",
348+
DeprecationWarning,
349+
stacklevel=2,
350+
)
351+
return self.post_design["X"]
352+
353+
@property
354+
def post_y(self) -> xr.DataArray:
355+
""".. deprecated:: Use ``self.post_design['y']`` instead."""
356+
warnings.warn(
357+
"post_y is deprecated, use post_design['y']",
358+
DeprecationWarning,
359+
stacklevel=2,
360+
)
361+
return self.post_design["y"]
362+
321363
def _split_post_period(self) -> None:
322364
"""Split post period into intervention and post-intervention periods.
323365
@@ -627,9 +669,7 @@ def _bayesian_plot(
627669

628670
(h,) = ax[0].plot(
629671
self.datapre.index,
630-
self.pre_y.isel(treated_units=0)
631-
if hasattr(self.pre_y, "isel")
632-
else self.pre_y[:, 0],
672+
self.pre_design["y"].isel(treated_units=0),
633673
"k.",
634674
label="Observations",
635675
)
@@ -654,9 +694,7 @@ def _bayesian_plot(
654694

655695
ax[0].plot(
656696
self.datapost.index,
657-
self.post_y.isel(treated_units=0)
658-
if hasattr(self.post_y, "isel")
659-
else self.post_y[:, 0],
697+
self.post_design["y"].isel(treated_units=0),
660698
"k.",
661699
)
662700
# Shaded causal effect
@@ -669,9 +707,7 @@ def _bayesian_plot(
669707
h = ax[0].fill_between(
670708
self.datapost.index,
671709
y1=post_pred_mu,
672-
y2=self.post_y.isel(treated_units=0)
673-
if hasattr(self.post_y, "isel")
674-
else self.post_y[:, 0],
710+
y2=self.post_design["y"].isel(treated_units=0),
675711
color="C0",
676712
alpha=0.25,
677713
)
@@ -807,10 +843,10 @@ def _ols_plot(
807843

808844
fig, ax = plt.subplots(3, 1, sharex=True, figsize=(7, 8))
809845

810-
ax[0].plot(self.datapre.index, self.pre_y, "k.")
846+
ax[0].plot(self.datapre.index, self.pre_design["y"], "k.")
811847
ax[0].plot(self.datapre.index, self.pre_pred, c="k", label="model fit")
812848

813-
ax[0].plot(self.datapost.index, self.post_y, "k.")
849+
ax[0].plot(self.datapost.index, self.post_design["y"], "k.")
814850
ax[0].plot(
815851
self.datapost.index,
816852
self.post_pred,
@@ -822,7 +858,7 @@ def _ols_plot(
822858
ax[0].fill_between(
823859
self.datapost.index,
824860
y1=np.squeeze(self.post_pred),
825-
y2=np.squeeze(self.post_y),
861+
y2=np.squeeze(self.post_design["y"]),
826862
color="C0",
827863
alpha=0.25,
828864
label="Causal impact",

0 commit comments

Comments
 (0)