Skip to content

Commit 2c48962

Browse files
Add Donut Regression Discontinuity functionality (#610)
* Add donut RDD support to RegressionDiscontinuity

  Introduces a donut_hole parameter to RegressionDiscontinuity, allowing exclusion of observations within a specified distance from the treatment threshold for robustness against manipulation or heaping. Updates plotting, summary, and input validation to support donut RDD, adds comprehensive tests for donut_hole behavior and validation, expands glossary with donut RDD concepts, and provides a new notebook demonstrating donut RDD usage. References on donut RDD, manipulation, and heaping are added to the bibliography.

* update cell tags

* Increase sample size and strengthen manipulation in donut RDD notebook

  Raised the number of generated observations from 500 to 1000 and increased manipulation parameters to better demonstrate the donut RDD approach. Updated output and data table examples to reflect the new data generation settings.

* edit cell tag

* add glossary terms

* Expand explanation of heaping and manipulation in Donut RDD

* Add regression discontinuity edge case tests

  Added tests for warnings when bandwidth or donut_hole filters leave few datapoints, for unrecognized model types, and for donut hole boundary lines in OLS and Bayesian plots. Also updated interrogate badge coverage from 96.3% to 96.4%.

* Fix plot, legend, warning, and validation issues in donut RDD

  - Conditional two-layer scatter: only show excluded/fit data distinction when data is actually excluded; default case shows single "data" layer
  - Fix Bayesian plot legend to include all labeled artists (donut boundaries, threshold, scatter labels) instead of only posterior mean
  - Fix malformed warning when filter_desc is empty on small datasets
  - Use ValueError instead of DataException for donut_hole param validation
  - Update tests to expect ValueError for donut_hole validation

  Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent d7f3dc7 commit 2c48962

6 files changed

Lines changed: 1931 additions & 35 deletions

File tree

causalpy/experiments/regression_discontinuity.py

Lines changed: 125 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -60,6 +60,12 @@ class RegressionDiscontinuity(BaseExperiment):
6060
:param bandwidth:
6161
Data outside of the bandwidth (relative to the discontinuity) is not used to fit
6262
the model.
63+
:param donut_hole:
64+
Observations within this distance from the treatment threshold are excluded from
65+
model fitting. Used as a robustness check when observations closest to the
66+
threshold may be problematic (e.g., due to manipulation or heaping). Defaults
67+
to 0.0 (no exclusion). Must be non-negative and less than bandwidth if bandwidth
68+
is finite.
6369
6470
Example
6571
--------
@@ -94,6 +100,7 @@ def __init__(
94100
running_variable_name: str = "x",
95101
epsilon: float = 0.001,
96102
bandwidth: float = np.inf,
103+
donut_hole: float = 0.0,
97104
**kwargs: dict,
98105
) -> None:
99106
super().__init__(model=model)
@@ -104,28 +111,43 @@ def __init__(
104111
self.treatment_threshold = treatment_threshold
105112
self.epsilon = epsilon
106113
self.bandwidth = bandwidth
114+
self.donut_hole = donut_hole
107115
self.input_validation()
108116
self._build_design_matrices()
109117
self._prepare_data()
110118
self.algorithm()
111119

112120
def _build_design_matrices(self) -> None:
113-
"""Build design matrices from formula and data, applying bandwidth filtering."""
121+
"""Build design matrices from formula and data, applying bandwidth and donut hole filtering."""
122+
x_vals = self.data[self.running_variable_name]
123+
c = self.treatment_threshold
124+
mask = pd.Series(True, index=self.data.index)
125+
114126
if self.bandwidth is not np.inf:
115-
fmin = self.treatment_threshold - self.bandwidth
116-
fmax = self.treatment_threshold + self.bandwidth
117-
filtered_data = self.data.query(
118-
f"{fmin} <= {self.running_variable_name} <= {fmax}"
119-
)
120-
if len(filtered_data) <= 10:
121-
warnings.warn(
122-
f"Choice of bandwidth parameter has lead to only {len(filtered_data)} remaining datapoints. Consider increasing the bandwidth parameter.", # noqa: E501
123-
UserWarning,
124-
stacklevel=2,
127+
mask &= np.abs(x_vals - c) <= self.bandwidth
128+
129+
if self.donut_hole > 0:
130+
mask &= np.abs(x_vals - c) >= self.donut_hole
131+
132+
self.fit_data = self.data.loc[mask]
133+
134+
if len(self.fit_data) <= 10:
135+
filter_desc = []
136+
if self.bandwidth is not np.inf:
137+
filter_desc.append(f"bandwidth={self.bandwidth}")
138+
if self.donut_hole > 0:
139+
filter_desc.append(f"donut_hole={self.donut_hole}")
140+
if filter_desc:
141+
msg = (
142+
f"Choice of {' and '.join(filter_desc)} parameters has led to only "
143+
f"{len(self.fit_data)} remaining datapoints. "
144+
f"Consider adjusting these parameters."
125145
)
126-
y, X = dmatrices(self.formula, filtered_data)
127-
else:
128-
y, X = dmatrices(self.formula, self.data)
146+
else:
147+
msg = f"Only {len(self.fit_data)} datapoints in the dataset."
148+
warnings.warn(msg, UserWarning, stacklevel=2)
149+
150+
y, X = dmatrices(self.formula, self.fit_data)
129151

130152
self._y_design_info = y.design_info
131153
self._x_design_info = X.design_info
@@ -227,6 +249,16 @@ def input_validation(self) -> None:
227249
"""The treated variable should be dummy coded. Consisting of 0's and 1's only.""" # noqa: E501
228250
)
229251

252+
# Validate donut_hole parameter
253+
if self.donut_hole < 0:
254+
raise ValueError("donut_hole must be non-negative.")
255+
256+
if self.bandwidth is not np.inf and self.donut_hole >= self.bandwidth:
257+
raise ValueError(
258+
f"donut_hole ({self.donut_hole}) must be less than bandwidth "
259+
f"({self.bandwidth}) when bandwidth is finite."
260+
)
261+
230262
# Convert integer treated variable to boolean if needed
231263
if self.data["treated"].dtype in ["int64", "int32"]:
232264
# Make a copy to avoid SettingWithCopyWarning
@@ -249,10 +281,13 @@ def summary(self, round_to: int | None = None) -> None:
249281
:param round_to:
250282
Number of decimals used to round results. Defaults to 2. Use "None" to return raw numbers.
251283
"""
252-
print("Difference in Differences experiment")
284+
print("Regression Discontinuity experiment")
253285
print(f"Formula: {self.formula}")
254286
print(f"Running variable: {self.running_variable_name}")
255287
print(f"Threshold on running variable: {self.treatment_threshold}")
288+
print(f"Bandwidth: {self.bandwidth}")
289+
print(f"Donut hole: {self.donut_hole}")
290+
print(f"Observations used for fit: {len(self.fit_data)}")
256291
print("\nResults:")
257292
print(
258293
f"Discontinuity at threshold = {convert_to_string(self.discontinuity_at_threshold)}"
@@ -265,28 +300,39 @@ def _bayesian_plot(
265300
) -> tuple[plt.Figure, plt.Axes]:
266301
"""Generate plot for regression discontinuity designs."""
267302
fig, ax = plt.subplots()
268-
# Plot raw data
303+
304+
# Plot data: use two layers only when there are excluded observations
305+
has_exclusion = len(self.fit_data) < len(self.data)
306+
if has_exclusion:
307+
sns.scatterplot(
308+
self.data,
309+
x=self.running_variable_name,
310+
y=self.outcome_variable_name,
311+
color="lightgray",
312+
ax=ax,
313+
label="excluded data",
314+
)
269315
sns.scatterplot(
270-
self.data,
316+
self.fit_data,
271317
x=self.running_variable_name,
272318
y=self.outcome_variable_name,
273-
c="k",
319+
color="k",
274320
ax=ax,
321+
label="fit data" if has_exclusion else "data",
275322
)
276323

277324
# Plot model fit to data
278-
h_line, h_patch = plot_xY(
325+
plot_xY(
279326
self.x_pred[self.running_variable_name],
280327
self.pred["posterior_predictive"].mu.isel(treated_units=0),
281328
ax=ax,
282329
plot_hdi_kwargs={"color": "C1"},
330+
label="Posterior mean",
283331
)
284-
handles = [(h_line, h_patch)]
285-
labels = ["Posterior mean"]
286332

287333
# create strings to compose title
288334
title_info = f"{round_num(self.score['unit_0_r2'], round_to)} (std = {round_num(self.score['unit_0_r2_std'], round_to)})"
289-
r2 = f"Bayesian $R^2$ on all data = {title_info}"
335+
r2 = f"Bayesian $R^2$ on fit data = {title_info}"
290336
percentiles = self.discontinuity_at_threshold.quantile([0.03, 1 - 0.03]).values
291337
ci = (
292338
r"$CI_{94\%}$"
@@ -296,34 +342,61 @@ def _bayesian_plot(
296342
Discontinuity at threshold = {round_num(self.discontinuity_at_threshold.mean(), round_to)},
297343
"""
298344
ax.set(title=r2 + "\n" + discon + ci)
299-
# Intervention line
345+
346+
# Treatment threshold line
300347
ax.axvline(
301348
x=self.treatment_threshold,
302349
ls="-",
303350
lw=3,
304351
color="r",
305352
label="treatment threshold",
306353
)
307-
ax.legend(
308-
handles=(h_tuple for h_tuple in handles),
309-
labels=labels,
310-
fontsize=LEGEND_FONT_SIZE,
311-
)
354+
355+
# Add donut hole boundary lines if donut_hole > 0
356+
if self.donut_hole > 0:
357+
ax.axvline(
358+
x=self.treatment_threshold - self.donut_hole,
359+
ls="--",
360+
lw=2,
361+
color="orange",
362+
label="donut boundary",
363+
)
364+
ax.axvline(
365+
x=self.treatment_threshold + self.donut_hole,
366+
ls="--",
367+
lw=2,
368+
color="orange",
369+
)
370+
371+
ax.legend(fontsize=LEGEND_FONT_SIZE)
312372
return (fig, ax)
313373

314374
def _ols_plot(
315375
self, round_to: int | None = None, **kwargs: dict
316376
) -> tuple[plt.Figure, plt.Axes]:
317377
"""Generate plot for regression discontinuity designs."""
318378
fig, ax = plt.subplots()
319-
# Plot raw data
379+
380+
# Plot data: use two layers only when there are excluded observations
381+
has_exclusion = len(self.fit_data) < len(self.data)
382+
if has_exclusion:
383+
sns.scatterplot(
384+
self.data,
385+
x=self.running_variable_name,
386+
y=self.outcome_variable_name,
387+
color="lightgray",
388+
ax=ax,
389+
label="excluded data",
390+
)
320391
sns.scatterplot(
321-
self.data,
392+
self.fit_data,
322393
x=self.running_variable_name,
323394
y=self.outcome_variable_name,
324-
c="k", # hue="treated",
395+
color="k",
325396
ax=ax,
397+
label="fit data" if has_exclusion else "data",
326398
)
399+
327400
# Plot model fit to data
328401
ax.plot(
329402
self.x_pred[self.running_variable_name],
@@ -332,18 +405,37 @@ def _ols_plot(
332405
markersize=10,
333406
label="model fit",
334407
)
408+
335409
# create strings to compose title
336-
r2 = f"$R^2$ on all data = {round_num(float(self.score), round_to)}"
410+
r2 = f"$R^2$ on fit data = {round_num(float(self.score), round_to)}"
337411
discon = f"Discontinuity at threshold = {round_num(self.discontinuity_at_threshold, round_to)}"
338412
ax.set(title=r2 + "\n" + discon)
339-
# Intervention line
413+
414+
# Treatment threshold line
340415
ax.axvline(
341416
x=self.treatment_threshold,
342417
ls="-",
343418
lw=3,
344419
color="r",
345420
label="treatment threshold",
346421
)
422+
423+
# Add donut hole boundary lines if donut_hole > 0
424+
if self.donut_hole > 0:
425+
ax.axvline(
426+
x=self.treatment_threshold - self.donut_hole,
427+
ls="--",
428+
lw=2,
429+
color="orange",
430+
label="donut boundary",
431+
)
432+
ax.axvline(
433+
x=self.treatment_threshold + self.donut_hole,
434+
ls="--",
435+
lw=2,
436+
color="orange",
437+
)
438+
347439
ax.legend(fontsize=LEGEND_FONT_SIZE)
348440
return (fig, ax)
349441

0 commit comments

Comments
 (0)