Skip to content

Commit 711970c

Browse files
louismagowan, claude, and drbenvincent
authored
Clean up experiments/: typos, annotations, and constants (#796)
* Fix typos in diff_in_diff and prepostnegd - diff_in_diff.py: "treament" -> "treatment" in plot label - prepostnegd.py: "trestment" -> "treatment" in comment Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Fix **kwargs: dict -> **kwargs: Any in experiment __init__ and plot methods `**kwargs: dict` is incorrect — it annotates each value as a dict, not the collection. Every experiment except inverse_propensity_weighting (which already used Any) is fixed. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Replace `is False` with `not` for dummy-coded checks Use idiomatic `if not func(...)` instead of `if func(...) is False`. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Move misplaced docstring in DifferenceInDifferences.input_validation The docstring was after the first executable line; move it to the correct position immediately after the method signature. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Standardize expt_type to instance attribute in ITS and PiecewiseITS - interrupted_time_series.py: remove dead class attribute (overwritten by self.expt_type in __init__) - piecewise_its.py: move class attribute to self.expt_type in __init__ - Update test to check expt_type as instance attribute Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Extract LEGEND_FONT_SIZE to experiments/constants.py Create a shared constants module and replace the per-file `LEGEND_FONT_SIZE = 12` definitions in all 8 experiment files with an import from `causalpy.experiments.constants`. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Extract HDI_PROB constant to utils.py and use throughout Add `HDI_PROB: float = 0.94` to `causalpy/utils.py` and replace all hardcoded `0.94` parameter defaults and `0.03` / `1 - 0.03` quantile bounds with expressions derived from HDI_PROB. 
Updated files: utils.py, plot_utils.py, pymc_models.py, and 7 experiment modules (prepostnegd, regression_discontinuity, regression_kink, interrupted_time_series, piecewise_its, staggered_did, synthetic_control). Test files and docstrings are intentionally left unchanged. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Fix ruff linting: import ordering, HDI_PROB placement, line length, and UP038 - Move HDI_PROB after all imports in utils.py to avoid E402 - Sort `from causalpy.experiments.constants` alphabetically with other causalpy imports - Wrap long quantile lines for ruff format compliance - Fix pre-existing UP038: use `X | Y` instead of `(X, Y)` in isinstance - Restore `from typing import Any, Literal` in regression_discontinuity.py Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com> * Move HDI_PROB and LEGEND_FONT_SIZE to causalpy/constants.py and derive CI labels from HDI_PROB - Create causalpy/constants.py as single home for shared constants - Delete causalpy/experiments/constants.py (now unused) - Update all imports across 10 files to use causalpy.constants directly - Fix 5 hardcoded "94%" label strings to derive from HDI_PROB (pymc_models, utils, regression_kink, regression_discontinuity, prepostnegd); also fix missing backslash before % in prepostnegd CI string Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Fix ruff formatting: add spaces around * in HDI_PROB f-strings Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com> * Remove hardcoded '94%' from docstrings and comments These strings should not reference a specific percentage since the actual interval width is derived from HDI_PROB at runtime. Made-with: Cursor --------- Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com> Co-authored-by: Benjamin T. Vincent <inferencelab@gmail.com>
1 parent 57efa72 commit 711970c

14 files changed

Lines changed: 104 additions & 78 deletions

causalpy/constants.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,19 @@
1+
# Copyright 2022 - 2026 The PyMC Labs Developers
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
"""
15+
Shared constants for the CausalPy package.
16+
"""
17+
18+
HDI_PROB: float = 0.94
19+
LEGEND_FONT_SIZE: int = 12

causalpy/experiments/diff_in_diff.py

Lines changed: 8 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
from patsy import build_design_matrices, dmatrices
2727
from sklearn.base import RegressorMixin
2828

29+
from causalpy.constants import LEGEND_FONT_SIZE
2930
from causalpy.custom_exceptions import (
3031
DataException,
3132
FormulaException,
@@ -48,8 +49,6 @@
4849

4950
from .base import BaseExperiment
5051

51-
LEGEND_FONT_SIZE = 12
52-
5352

5453
class DifferenceInDifferences(BaseExperiment):
5554
"""A class to analyse data from Difference in Difference settings.
@@ -107,7 +106,7 @@ def __init__(
107106
group_variable_name: str,
108107
post_treatment_variable_name: str = "post_treatment",
109108
model: PyMCModel | RegressorMixin | None = None,
110-
**kwargs: dict,
109+
**kwargs: Any,
111110
) -> None:
112111
super().__init__(model=model)
113112
self.causal_impact: xr.DataArray | float | None
@@ -261,10 +260,9 @@ def algorithm(self) -> None:
261260
raise ValueError("Model type not recognized")
262261

263262
def input_validation(self) -> None:
263+
"""Validate the input data and model formula for correctness"""
264264
# Validate formula structure and interaction terms
265265
self._validate_formula_interaction_terms()
266-
267-
"""Validate the input data and model formula for correctness"""
268266
# Check if post_treatment_variable_name is in formula
269267
if self.post_treatment_variable_name not in self.formula:
270268
raise FormulaException(
@@ -282,7 +280,7 @@ def input_validation(self) -> None:
282280
"Require a `unit` column to label unique units. This is used for plotting purposes" # noqa: E501
283281
)
284282

285-
if _is_variable_dummy_coded(self.data[self.group_variable_name]) is False:
283+
if not _is_variable_dummy_coded(self.data[self.group_variable_name]):
286284
raise DataException(
287285
f"""The grouping variable {self.group_variable_name} should be dummy
288286
coded. Consisting of 0's and 1's only."""
@@ -331,11 +329,11 @@ def summary(self, round_to: int | None = 2) -> None:
331329
self.print_coefficients(round_to)
332330

333331
def _causal_impact_summary_stat(self, round_to: int | None = None) -> str:
334-
"""Computes the mean and 94% credible interval bounds for the causal impact."""
332+
"""Computes the mean and credible interval bounds for the causal impact."""
335333
return f"Causal impact = {convert_to_string(self.causal_impact, round_to=round_to)}"
336334

337335
def _bayesian_plot(
338-
self, round_to: int | None = None, **kwargs: dict
336+
self, round_to: int | None = None, **kwargs: Any
339337
) -> tuple[plt.Figure, plt.Axes]:
340338
"""
341339
Plot the results
@@ -485,7 +483,7 @@ def _plot_causal_impact_arrow(results, ax):
485483
return fig, ax
486484

487485
def _ols_plot(
488-
self, round_to: int | None = 2, **kwargs: dict
486+
self, round_to: int | None = 2, **kwargs: Any
489487
) -> tuple[plt.Figure, plt.Axes]:
490488
"""Generate plot for difference-in-differences"""
491489
fig, ax = plt.subplots()
@@ -517,7 +515,7 @@ def _ols_plot(
517515
"o",
518516
c="C1",
519517
markersize=10,
520-
label="model fit (treament group)",
518+
label="model fit (treatment group)",
521519
)
522520
# Plot counterfactual - post-test for treatment group IF no treatment
523521
# had occurred.

causalpy/experiments/instrumental_variable.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -126,7 +126,7 @@ def __init__(
126126
vs_prior_type=None,
127127
vs_hyperparams=None,
128128
binary_treatment=False,
129-
**kwargs: dict,
129+
**kwargs: Any,
130130
) -> None:
131131
super().__init__(model=model)
132132
self.expt_type = "Instrumental Variable Regression"

causalpy/experiments/interrupted_time_series.py

Lines changed: 6 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
from patsy import build_design_matrices, dmatrices
2626
from sklearn.base import RegressorMixin
2727

28+
from causalpy.constants import HDI_PROB, LEGEND_FONT_SIZE
2829
from causalpy.custom_exceptions import BadIndexException
2930
from causalpy.date_utils import _combine_datetime_indices, format_date_axes
3031
from causalpy.plot_utils import get_hdi_to_df, plot_xY
@@ -34,8 +35,6 @@
3435

3536
from .base import BaseExperiment
3637

37-
LEGEND_FONT_SIZE = 12
38-
3938

4039
class InterruptedTimeSeries(BaseExperiment):
4140
"""
@@ -126,7 +125,6 @@ class InterruptedTimeSeries(BaseExperiment):
126125
after the intervention ends.
127126
"""
128127

129-
expt_type = "Interrupted Time Series"
130128
supports_ols = True
131129
supports_bayes = True
132130
_default_model_class = LinearRegression
@@ -138,7 +136,7 @@ def __init__(
138136
formula: str,
139137
model: PyMCModel | RegressorMixin | None = None,
140138
treatment_end_time: int | float | pd.Timestamp | None = None,
141-
**kwargs: dict,
139+
**kwargs: Any,
142140
) -> None:
143141
super().__init__(model=model)
144142
self.pre_y: xr.DataArray
@@ -227,7 +225,7 @@ def algorithm(self) -> None:
227225
)
228226

229227
# get the model predictions of the observed (pre-intervention) data
230-
if isinstance(self.model, (PyMCModel, RegressorMixin)):
228+
if isinstance(self.model, PyMCModel | RegressorMixin):
231229
self.pre_pred = self.model.predict(X=self.pre_X)
232230

233231
# calculate the counterfactual (post period)
@@ -601,7 +599,7 @@ def summary(self, round_to: int | None = None) -> None:
601599
self.print_coefficients(round_to)
602600

603601
def _bayesian_plot(
604-
self, round_to: int | None = 2, **kwargs: dict
602+
self, round_to: int | None = 2, **kwargs: Any
605603
) -> tuple[plt.Figure, list[plt.Axes]]:
606604
"""
607605
Plot the results
@@ -797,7 +795,7 @@ def _bayesian_plot(
797795
return fig, ax
798796

799797
def _ols_plot(
800-
self, round_to: int | None = 2, **kwargs: dict
798+
self, round_to: int | None = 2, **kwargs: Any
801799
) -> tuple[plt.Figure, list[plt.Axes]]:
802800
"""
803801
Plot the results
@@ -887,7 +885,7 @@ def _ols_plot(
887885

888886
return (fig, ax)
889887

890-
def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
888+
def get_plot_data_bayesian(self, hdi_prob: float = HDI_PROB) -> pd.DataFrame:
891889
"""
892890
Recover the data of the experiment along with the prediction and causal impact information.
893891

causalpy/experiments/piecewise_its.py

Lines changed: 6 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
from patsy import dmatrices
2727
from sklearn.base import RegressorMixin
2828

29+
from causalpy.constants import HDI_PROB, LEGEND_FONT_SIZE
2930
from causalpy.custom_exceptions import FormulaException
3031
from causalpy.plot_utils import plot_xY
3132
from causalpy.pymc_models import LinearRegression, PyMCModel
@@ -35,8 +36,6 @@
3536

3637
from .base import BaseExperiment
3738

38-
LEGEND_FONT_SIZE = 12
39-
4039

4140
class PiecewiseITS(BaseExperiment):
4241
"""
@@ -144,7 +143,6 @@ class PiecewiseITS(BaseExperiment):
144143
the evaluation of public health interventions: a tutorial. Int J Epidemiol.
145144
"""
146145

147-
expt_type = "Piecewise Interrupted Time Series"
148146
supports_ols = True
149147
supports_bayes = True
150148
_default_model_class = LinearRegression
@@ -154,11 +152,12 @@ def __init__(
154152
data: pd.DataFrame,
155153
formula: str,
156154
model: PyMCModel | RegressorMixin | None = None,
157-
**kwargs: dict[str, Any],
155+
**kwargs: Any,
158156
) -> None:
159157
super().__init__(model=model)
160158

161159
# Store configuration
160+
self.expt_type = "Piecewise Interrupted Time Series"
162161
self.formula = formula
163162
self.data = data.copy()
164163

@@ -446,7 +445,7 @@ def summary(self, round_to: int | None = None) -> None:
446445
self.print_coefficients(round_to)
447446

448447
def _bayesian_plot(
449-
self, round_to: int | None = 2, **kwargs: dict[str, Any]
448+
self, round_to: int | None = 2, **kwargs: Any
450449
) -> tuple[plt.Figure, list[plt.Axes]]:
451450
"""
452451
Plot the results for Bayesian models.
@@ -563,7 +562,7 @@ def _bayesian_plot(
563562
return fig, ax
564563

565564
def _ols_plot(
566-
self, round_to: int | None = 2, **kwargs: dict[str, Any]
565+
self, round_to: int | None = 2, **kwargs: Any
567566
) -> tuple[plt.Figure, list[plt.Axes]]:
568567
"""
569568
Plot the results for OLS models.
@@ -626,7 +625,7 @@ def _ols_plot(
626625
plt.tight_layout()
627626
return fig, ax
628627

629-
def get_plot_data_bayesian(self, hdi_prob: float = 0.94) -> pd.DataFrame:
628+
def get_plot_data_bayesian(self, hdi_prob: float = HDI_PROB) -> pd.DataFrame:
630629
"""
631630
Recover the data of the experiment along with prediction and effect information.
632631

causalpy/experiments/prepostnegd.py

Lines changed: 9 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -26,6 +26,7 @@
2626
from patsy import build_design_matrices, dmatrices
2727
from sklearn.base import RegressorMixin
2828

29+
from causalpy.constants import HDI_PROB, LEGEND_FONT_SIZE
2930
from causalpy.custom_exceptions import (
3031
DataException,
3132
)
@@ -36,8 +37,6 @@
3637

3738
from .base import BaseExperiment
3839

39-
LEGEND_FONT_SIZE = 12
40-
4140

4241
class PrePostNEGD(BaseExperiment):
4342
"""
@@ -97,7 +96,7 @@ def __init__(
9796
group_variable_name: str,
9897
pretreatment_variable_name: str,
9998
model: PyMCModel | None = None,
100-
**kwargs: dict,
99+
**kwargs: Any,
101100
) -> None:
102101
super().__init__(model=model)
103102
self.causal_impact: xr.DataArray
@@ -184,7 +183,7 @@ def algorithm(self) -> None:
184183
(new_x_treated,) = build_design_matrices([self._x_design_info], x_pred_treated)
185184
self.pred_treated = self.model.predict(X=np.asarray(new_x_treated))
186185

187-
# Evaluate causal impact as equal to the trestment effect
186+
# Evaluate causal impact as equal to the treatment effect
188187
self.causal_impact = self.model.idata.posterior["beta"].sel(
189188
{"coeffs": self._get_treatment_effect_coeff()}
190189
)
@@ -213,10 +212,12 @@ def _get_treatment_effect_coeff(self) -> str:
213212
raise NameError("Unable to find coefficient name for the treatment effect")
214213

215214
def _causal_impact_summary_stat(self, round_to: int | None = 2) -> str:
216-
"""Computes the mean and 94% credible interval bounds for the causal impact."""
217-
percentiles = self.causal_impact.quantile([0.03, 1 - 0.03]).values
215+
"""Computes the mean and credible interval bounds for the causal impact."""
216+
percentiles = self.causal_impact.quantile(
217+
[(1 - HDI_PROB) / 2, 1 - (1 - HDI_PROB) / 2]
218+
).values
218219
ci = (
219-
r"$CI_{94%}$"
220+
rf"$CI_{{{HDI_PROB * 100:.0f}\%}}$"
220221
+ f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
221222
)
222223
causal_impact = f"{round_num(self.causal_impact.mean(), round_to)}, "
@@ -235,7 +236,7 @@ def summary(self, round_to: int | None = None) -> None:
235236
self.print_coefficients(round_to)
236237

237238
def _bayesian_plot(
238-
self, round_to: int | None = None, **kwargs: dict
239+
self, round_to: int | None = None, **kwargs: Any
239240
) -> tuple[plt.Figure, list[plt.Axes]]:
240241
"""Generate plot for ANOVA-like experiments with non-equivalent group designs."""
241242
fig, ax = plt.subplots(

causalpy/experiments/regression_discontinuity.py

Lines changed: 16 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616
"""
1717

1818
import warnings # noqa: I001
19-
19+
from typing import Any, Literal
2020

2121
import numpy as np
2222
import pandas as pd
@@ -29,15 +29,17 @@
2929
DataException,
3030
FormulaException,
3131
)
32+
from causalpy.constants import HDI_PROB, LEGEND_FONT_SIZE
3233
from causalpy.plot_utils import plot_xY
3334
from causalpy.pymc_models import LinearRegression, PyMCModel
34-
from causalpy.utils import _is_variable_dummy_coded, convert_to_string, round_num
35-
36-
from .base import BaseExperiment
3735
from causalpy.reporting import EffectSummary, _effect_summary_rd
38-
from typing import Any, Literal
36+
from causalpy.utils import (
37+
_is_variable_dummy_coded,
38+
convert_to_string,
39+
round_num,
40+
)
3941

40-
LEGEND_FONT_SIZE = 12
42+
from .base import BaseExperiment
4143

4244

4345
class RegressionDiscontinuity(BaseExperiment):
@@ -101,7 +103,7 @@ def __init__(
101103
epsilon: float = 0.001,
102104
bandwidth: float = np.inf,
103105
donut_hole: float = 0.0,
104-
**kwargs: dict,
106+
**kwargs: Any,
105107
) -> None:
106108
super().__init__(model=model)
107109
self.expt_type = "Regression Discontinuity"
@@ -244,7 +246,7 @@ def input_validation(self) -> None:
244246
"A predictor called `treated` should be in the formula"
245247
)
246248

247-
if _is_variable_dummy_coded(self.data["treated"]) is False:
249+
if not _is_variable_dummy_coded(self.data["treated"]):
248250
raise DataException(
249251
"""The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501
250252
)
@@ -296,7 +298,7 @@ def summary(self, round_to: int | None = None) -> None:
296298
self.print_coefficients(round_to)
297299

298300
def _bayesian_plot(
299-
self, round_to: int | None = 2, **kwargs: dict
301+
self, round_to: int | None = 2, **kwargs: Any
300302
) -> tuple[plt.Figure, plt.Axes]:
301303
"""Generate plot for regression discontinuity designs."""
302304
fig, ax = plt.subplots()
@@ -333,9 +335,11 @@ def _bayesian_plot(
333335
# create strings to compose title
334336
title_info = f"{round_num(self.score['unit_0_r2'], round_to)} (std = {round_num(self.score['unit_0_r2_std'], round_to)})"
335337
r2 = f"Bayesian $R^2$ on fit data = {title_info}"
336-
percentiles = self.discontinuity_at_threshold.quantile([0.03, 1 - 0.03]).values
338+
percentiles = self.discontinuity_at_threshold.quantile(
339+
[(1 - HDI_PROB) / 2, 1 - (1 - HDI_PROB) / 2]
340+
).values
337341
ci = (
338-
r"$CI_{94\%}$"
342+
rf"$CI_{{{HDI_PROB * 100:.0f}\%}}$"
339343
+ f"[{round_num(percentiles[0], round_to)}, {round_num(percentiles[1], round_to)}]"
340344
)
341345
discon = f"""
@@ -372,7 +376,7 @@ def _bayesian_plot(
372376
return (fig, ax)
373377

374378
def _ols_plot(
375-
self, round_to: int | None = None, **kwargs: dict
379+
self, round_to: int | None = None, **kwargs: Any
376380
) -> tuple[plt.Figure, plt.Axes]:
377381
"""Generate plot for regression discontinuity designs."""
378382
fig, ax = plt.subplots()

0 commit comments

Comments
 (0)