Skip to content

Commit 44e13ec

Browse files
committed
Centralize deprecated design-alias boilerplate and fix internal callers
Move _build_design_dataset helper and __getattr__ deprecation forwarding into BaseExperiment, replacing per-class @Property blocks with a declarative _deprecated_design_aliases dict. Fix convex_hull.py and maketables_adapters.py to use the new API directly. Add parametrized backward-compatibility tests covering all deprecated aliases. Made-with: Cursor
1 parent aafb28c commit 44e13ec

13 files changed

Lines changed: 324 additions & 306 deletions

causalpy/checks/convex_hull.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -50,8 +50,8 @@ def run(
5050
) -> CheckResult:
5151
"""Run the convex hull violation check on pre-treatment data."""
5252
sc = experiment
53-
datapre_control = sc.datapre_control # type: ignore[attr-defined]
54-
datapre_treated = sc.datapre_treated # type: ignore[attr-defined]
53+
datapre_control = sc.pre_design["control"] # type: ignore[attr-defined]
54+
datapre_treated = sc.pre_design["treated"] # type: ignore[attr-defined]
5555

5656
all_results = []
5757
total_violations = 0

causalpy/experiments/base.py

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,13 +18,16 @@
1818
from __future__ import annotations
1919

2020
import contextlib
21+
import warnings
2122
from abc import ABC, abstractmethod
2223
from pathlib import Path
2324
from typing import Any, Literal
2425

2526
import arviz as az
2627
import matplotlib.pyplot as plt
28+
import numpy as np
2729
import pandas as pd
30+
import xarray as xr
2831
from sklearn.base import RegressorMixin
2932

3033
from causalpy.maketables_adapters import get_maketables_adapter
@@ -114,6 +117,67 @@ class BaseExperiment(ABC):
114117

115118
_default_model_class: type[PyMCModel] | None = None
116119

120+
_deprecated_design_aliases: dict[str, tuple[str, str]] = {}
121+
"""Mapping of ``old_attr -> (dataset_attr, key)`` for deprecated design
122+
matrix accessors. Subclasses populate this so that
123+
``__getattr__`` can forward accesses with a deprecation warning."""
124+
125+
def __getattr__(self, name: str) -> Any:
126+
aliases = type(self)._deprecated_design_aliases
127+
if name in aliases:
128+
dataset_attr, key = aliases[name]
129+
warnings.warn(
130+
f"{name} is deprecated, use {dataset_attr}['{key}']",
131+
DeprecationWarning,
132+
stacklevel=2,
133+
)
134+
return getattr(self, dataset_attr)[key]
135+
raise AttributeError(
136+
f"'{type(self).__name__}' object has no attribute '{name}'"
137+
)
138+
139+
@staticmethod
140+
def _build_design_dataset(
141+
X_raw: np.ndarray,
142+
y_raw: np.ndarray,
143+
*,
144+
obs_ind: np.ndarray | pd.Index,
145+
coeffs: list[str],
146+
treated_units: list[str] | None = None,
147+
) -> xr.Dataset:
148+
"""Build a standard ``xr.Dataset`` from raw design matrices.
149+
150+
Parameters
151+
----------
152+
X_raw : np.ndarray
153+
Predictor matrix, shape ``(n_obs, n_coeffs)``.
154+
y_raw : np.ndarray
155+
Outcome matrix, shape ``(n_obs, n_units)``.
156+
obs_ind : array-like
157+
Observation index coordinates.
158+
coeffs : list[str]
159+
Coefficient / column names for ``X_raw``.
160+
treated_units : list[str], optional
161+
Names for the treated-unit dimension of ``y_raw``.
162+
Defaults to ``["unit_0"]``.
163+
"""
164+
if treated_units is None:
165+
treated_units = ["unit_0"]
166+
return xr.Dataset(
167+
{
168+
"X": xr.DataArray(
169+
X_raw,
170+
dims=["obs_ind", "coeffs"],
171+
coords={"obs_ind": obs_ind, "coeffs": coeffs},
172+
),
173+
"y": xr.DataArray(
174+
y_raw,
175+
dims=["obs_ind", "treated_units"],
176+
coords={"obs_ind": obs_ind, "treated_units": treated_units},
177+
),
178+
}
179+
)
180+
117181
def __init__(self, model: PyMCModel | RegressorMixin | None = None) -> None:
118182
# Ensure we've made any provided Scikit Learn model (as identified as being type
119183
# RegressorMixin) compatible with CausalPy by appending our custom methods.

causalpy/experiments/diff_in_diff.py

Lines changed: 6 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
Difference in differences
1616
"""
1717

18-
import warnings
1918
from typing import Any, Literal
2019

2120
import arviz as az
@@ -98,6 +97,7 @@ class DifferenceInDifferences(BaseExperiment):
9897
supports_ols = True
9998
supports_bayes = True
10099
_default_model_class = LinearRegression
100+
_deprecated_design_aliases = {"X": ("design", "X"), "y": ("design", "y")}
101101

102102
def __init__(
103103
self,
@@ -136,38 +136,14 @@ def _build_design_matrices(self) -> None:
136136
def _prepare_data(self) -> None:
137137
"""Bundle design matrices into an ``xr.Dataset``."""
138138
n = self._X_raw.shape[0]
139-
self.design = xr.Dataset(
140-
{
141-
"X": xr.DataArray(
142-
self._X_raw,
143-
dims=["obs_ind", "coeffs"],
144-
coords={"obs_ind": np.arange(n), "coeffs": self.labels},
145-
),
146-
"y": xr.DataArray(
147-
self._y_raw,
148-
dims=["obs_ind", "treated_units"],
149-
coords={"obs_ind": np.arange(n), "treated_units": ["unit_0"]},
150-
),
151-
}
139+
self.design = self._build_design_dataset(
140+
self._X_raw,
141+
self._y_raw,
142+
obs_ind=np.arange(n),
143+
coeffs=self.labels,
152144
)
153145
del self._X_raw, self._y_raw
154146

155-
@property
156-
def X(self) -> xr.DataArray:
157-
""".. deprecated:: Use ``self.design['X']`` instead."""
158-
warnings.warn(
159-
"X is deprecated, use design['X']", DeprecationWarning, stacklevel=2
160-
)
161-
return self.design["X"]
162-
163-
@property
164-
def y(self) -> xr.DataArray:
165-
""".. deprecated:: Use ``self.design['y']`` instead."""
166-
warnings.warn(
167-
"y is deprecated, use design['y']", DeprecationWarning, stacklevel=2
168-
)
169-
return self.design["y"]
170-
171147
def algorithm(self) -> None:
172148
"""Run the experiment algorithm: fit model, predict, and calculate causal impact."""
173149
X = self.design["X"]

causalpy/experiments/interrupted_time_series.py

Lines changed: 16 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
Interrupted Time Series Analysis
1616
"""
1717

18-
import warnings
1918
from typing import Any, Literal
2019

2120
import arviz as az
@@ -129,6 +128,12 @@ class InterruptedTimeSeries(BaseExperiment):
129128
supports_ols = True
130129
supports_bayes = True
131130
_default_model_class = LinearRegression
131+
_deprecated_design_aliases = {
132+
"pre_X": ("pre_design", "X"),
133+
"pre_y": ("pre_design", "y"),
134+
"post_X": ("post_design", "X"),
135+
"post_y": ("post_design", "y"),
136+
}
132137

133138
def __init__(
134139
self,
@@ -169,39 +174,17 @@ def _build_design_matrices(self) -> None:
169174

170175
def _prepare_data(self) -> None:
171176
"""Bundle design matrices into ``xr.Dataset`` objects for pre and post periods."""
172-
self.pre_design = xr.Dataset(
173-
{
174-
"X": xr.DataArray(
175-
self._pre_X_raw,
176-
dims=["obs_ind", "coeffs"],
177-
coords={"obs_ind": self.datapre.index, "coeffs": self.labels},
178-
),
179-
"y": xr.DataArray(
180-
self._pre_y_raw,
181-
dims=["obs_ind", "treated_units"],
182-
coords={
183-
"obs_ind": self.datapre.index,
184-
"treated_units": ["unit_0"],
185-
},
186-
),
187-
}
177+
self.pre_design = self._build_design_dataset(
178+
self._pre_X_raw,
179+
self._pre_y_raw,
180+
obs_ind=self.datapre.index,
181+
coeffs=self.labels,
188182
)
189-
self.post_design = xr.Dataset(
190-
{
191-
"X": xr.DataArray(
192-
self._post_X_raw,
193-
dims=["obs_ind", "coeffs"],
194-
coords={"obs_ind": self.datapost.index, "coeffs": self.labels},
195-
),
196-
"y": xr.DataArray(
197-
self._post_y_raw,
198-
dims=["obs_ind", "treated_units"],
199-
coords={
200-
"obs_ind": self.datapost.index,
201-
"treated_units": ["unit_0"],
202-
},
203-
),
204-
}
183+
self.post_design = self._build_design_dataset(
184+
self._post_X_raw,
185+
self._post_y_raw,
186+
obs_ind=self.datapost.index,
187+
coeffs=self.labels,
205188
)
206189
del self._pre_X_raw, self._pre_y_raw, self._post_X_raw, self._post_y_raw
207190

@@ -320,46 +303,6 @@ def datapost(self) -> pd.DataFrame:
320303
"""
321304
return self.data[self.data.index >= self.treatment_time]
322305

323-
@property
324-
def pre_X(self) -> xr.DataArray:
325-
""".. deprecated:: Use ``self.pre_design['X']`` instead."""
326-
warnings.warn(
327-
"pre_X is deprecated, use pre_design['X']",
328-
DeprecationWarning,
329-
stacklevel=2,
330-
)
331-
return self.pre_design["X"]
332-
333-
@property
334-
def pre_y(self) -> xr.DataArray:
335-
""".. deprecated:: Use ``self.pre_design['y']`` instead."""
336-
warnings.warn(
337-
"pre_y is deprecated, use pre_design['y']",
338-
DeprecationWarning,
339-
stacklevel=2,
340-
)
341-
return self.pre_design["y"]
342-
343-
@property
344-
def post_X(self) -> xr.DataArray:
345-
""".. deprecated:: Use ``self.post_design['X']`` instead."""
346-
warnings.warn(
347-
"post_X is deprecated, use post_design['X']",
348-
DeprecationWarning,
349-
stacklevel=2,
350-
)
351-
return self.post_design["X"]
352-
353-
@property
354-
def post_y(self) -> xr.DataArray:
355-
""".. deprecated:: Use ``self.post_design['y']`` instead."""
356-
warnings.warn(
357-
"post_y is deprecated, use post_design['y']",
358-
DeprecationWarning,
359-
stacklevel=2,
360-
)
361-
return self.post_design["y"]
362-
363306
def _split_post_period(self) -> None:
364307
"""Split post period into intervention and post-intervention periods.
365308

causalpy/experiments/panel_regression.py

Lines changed: 6 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,12 @@
1515
Panel Regression with Fixed Effects
1616
"""
1717

18-
import warnings
1918
from typing import Any, Literal
2019

2120
import arviz as az
2221
import matplotlib.pyplot as plt
2322
import numpy as np
2423
import pandas as pd
25-
import xarray as xr
2624
from matplotlib.gridspec import GridSpec
2725
from patsy import dmatrices
2826
from scipy import stats
@@ -179,6 +177,7 @@ class PanelRegression(BaseExperiment):
179177

180178
supports_ols = True
181179
supports_bayes = True
180+
_deprecated_design_aliases = {"X": ("design", "X"), "y": ("design", "y")}
182181

183182
def __init__(
184183
self,
@@ -283,38 +282,14 @@ def _build_design_matrices(self) -> None:
283282
def _prepare_data(self) -> None:
284283
"""Bundle design matrices into an ``xr.Dataset``."""
285284
n = self._X_raw.shape[0]
286-
self.design = xr.Dataset(
287-
{
288-
"X": xr.DataArray(
289-
self._X_raw,
290-
dims=["obs_ind", "coeffs"],
291-
coords={"obs_ind": np.arange(n), "coeffs": self.labels},
292-
),
293-
"y": xr.DataArray(
294-
self._y_raw,
295-
dims=["obs_ind", "treated_units"],
296-
coords={"obs_ind": np.arange(n), "treated_units": ["unit_0"]},
297-
),
298-
}
285+
self.design = self._build_design_dataset(
286+
self._X_raw,
287+
self._y_raw,
288+
obs_ind=np.arange(n),
289+
coeffs=self.labels,
299290
)
300291
del self._X_raw, self._y_raw
301292

302-
@property
303-
def X(self) -> xr.DataArray:
304-
""".. deprecated:: Use ``self.design['X']`` instead."""
305-
warnings.warn(
306-
"X is deprecated, use design['X']", DeprecationWarning, stacklevel=2
307-
)
308-
return self.design["X"]
309-
310-
@property
311-
def y(self) -> xr.DataArray:
312-
""".. deprecated:: Use ``self.design['y']`` instead."""
313-
warnings.warn(
314-
"y is deprecated, use design['y']", DeprecationWarning, stacklevel=2
315-
)
316-
return self.design["y"]
317-
318293
def algorithm(self) -> None:
319294
"""Run the experiment algorithm: fit the model."""
320295
X = self.design["X"]

0 commit comments

Comments
 (0)