Skip to content

Commit 10a3b0b

Browse files
Support datetime treatment_time in SyntheticControl plots (#731)
* Support datetime treatment_time in SyntheticControl plots. Convert intervention line positions through each axis unit converter so both numeric and datetime indices render correctly, and add an sklearn integration regression test for datetime treatment_time plotting. Co-authored-by: Cursor <cursoragent@cursor.com> * Add fallback coverage test for SyntheticControl axis conversion. Exercise the error-handling path in _convert_treatment_time_for_axis for TypeError and ValueError so the synthetic_control.py diff reaches Codecov patch coverage requirements. Co-authored-by: Cursor <cursoragent@cursor.com> --------- Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 996a31f commit 10a3b0b

2 files changed

Lines changed: 65 additions & 3 deletions

File tree

causalpy/experiments/synthetic_control.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -281,6 +281,18 @@ def summary(self, round_to: int | None = None) -> None:
281281
print(f"Treated unit: {self.treated_units[0]}")
282282
self.print_coefficients(round_to)
283283

284+
@staticmethod
285+
def _convert_treatment_time_for_axis(
286+
axis: plt.Axes, treatment_time: int | float | pd.Timestamp
287+
) -> int | float | pd.Timestamp:
288+
"""
289+
Convert treatment time into the plotting units expected by a specific axis.
290+
"""
291+
try:
292+
return axis.xaxis.convert_units(treatment_time)
293+
except (TypeError, ValueError):
294+
return treatment_time
295+
284296
def _bayesian_plot(
285297
self,
286298
round_to: int | None = None,
@@ -402,8 +414,11 @@ def _bayesian_plot(
402414

403415
# Intervention line
404416
for i in [0, 1, 2]:
417+
treatment_time = self._convert_treatment_time_for_axis(
418+
ax[i], self.treatment_time
419+
)
405420
ax[i].axvline(
406-
x=self.treatment_time,
421+
x=treatment_time,
407422
ls="-",
408423
lw=3,
409424
color="r",
@@ -528,10 +543,13 @@ def _ols_plot(
528543
label="Causal impact",
529544
)
530545

531-
# Intervention line (see #725 for datetime treatment_time support)
546+
# Intervention line
532547
for i in [0, 1, 2]:
548+
treatment_time = self._convert_treatment_time_for_axis(
549+
ax[i], self.treatment_time
550+
)
533551
ax[i].axvline(
534-
x=self.treatment_time,
552+
x=treatment_time,
535553
ls="-",
536554
lw=3,
537555
color="r",

causalpy/tests/test_integration_skl_examples.py

Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,50 @@ def test_sc():
174174
)
175175

176176

177+
@pytest.mark.integration
178+
def test_sc_datetime_treatment_time_plot():
179+
"""Test SyntheticControl plotting with datetime treatment_time and sklearn model."""
180+
df = (
181+
cp.load_data("geolift1")
182+
.assign(time=lambda x: pd.to_datetime(x["time"]))
183+
.set_index("time")
184+
)
185+
treatment_time = pd.to_datetime("2022-01-01")
186+
187+
result = cp.SyntheticControl(
188+
df,
189+
treatment_time,
190+
control_units=["Austria", "Belgium", "Bulgaria", "Croatia", "Cyprus"],
191+
treated_units=["Denmark"],
192+
model=cp.skl_models.WeightedProportion(),
193+
)
194+
195+
fig, ax = result.plot()
196+
assert isinstance(fig, plt.Figure)
197+
assert isinstance(ax, np.ndarray) and all(
198+
isinstance(item, plt.Axes) for item in ax
199+
), "ax must be a numpy.ndarray of plt.Axes"
200+
201+
202+
@pytest.mark.parametrize("error_type", [TypeError, ValueError])
203+
def test_sc_convert_treatment_time_for_axis_fallback(error_type):
204+
"""Return original treatment_time when axis conversion raises."""
205+
206+
class FailingXAxis:
207+
def convert_units(self, _value):
208+
raise error_type("conversion failed")
209+
210+
class FailingAxis:
211+
xaxis = FailingXAxis()
212+
213+
treatment_time = pd.Timestamp("2022-01-01")
214+
converted = cp.SyntheticControl._convert_treatment_time_for_axis(
215+
FailingAxis(), treatment_time
216+
)
217+
218+
assert converted is treatment_time
219+
220+
177221
@pytest.mark.integration
178222
def test_rd_linear_main_effects():
179223
"""

0 commit comments

Comments
 (0)