
Commit 855ab81

Add market correlation analysis for Synthetic Control (G8) (#760)
* Add market correlation analysis for Synthetic Control (G8, #759)

  Add a plot_correlations() utility for pre-experiment diagnostics, and report
  pre-treatment correlations in SyntheticControl.summary(). Document the
  utility in the geolift notebooks and replace the manual heatmap code in the
  Brexit notebook.

  Made-with: Cursor

* Add donor pool curation with factor model DGP and literature guidance

  - Rewrite generate_geolift_data() with a latent factor model (K=3 GP
    factors, unit-specific loadings) producing realistic positive and
    negative correlations across 10 control countries
  - Add donor pool selection sections to all 5 SC notebooks, with
    correlation-based filtering and citations to Abadie (2010, 2021) and
    Abadie & L'Hour (2021)
  - Add a "Donor pool selection" glossary entry with {term} links
  - Add 2 references to references.bib
  - Regenerate geolift1.csv from the new factor model DGP

  Made-with: Cursor

* Fix notebook output metadata for docs build compatibility

  Add the missing 'name' field to stream outputs in sc_pymc.ipynb and
  sc_skl.ipynb, and the missing 'metadata' field on display_data outputs.
  These fields are required by nbformat validation and myst-nb rendering.

  Made-with: Cursor

* Fix missing execution_count in notebook execute_result outputs

  Add the required execution_count field to execute_result outputs in
  sc_pymc.ipynb and sc_skl.ipynb to pass the strict nbformat validation
  used in CI.

  Made-with: Cursor
Parent commit: 6d6e076

13 files changed: 9,446 additions & 8,544 deletions


causalpy/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -30,13 +30,14 @@
 from .experiments.regression_kink import RegressionKink
 from .experiments.staggered_did import StaggeredDifferenceInDifferences
 from .experiments.synthetic_control import SyntheticControl
-from .utils import extract_lift_for_mmm
+from .utils import extract_lift_for_mmm, plot_correlations

 __all__ = [
     "__version__",
     "create_causalpy_compatible_class",
     "DifferenceInDifferences",
     "extract_lift_for_mmm",
+    "plot_correlations",
     "InstrumentalVariable",
     "InterruptedTimeSeries",
     "InversePropensityWeighting",

causalpy/data/geolift1.csv

Lines changed: 209 additions & 209 deletions
Large diffs are not rendered by default.

causalpy/data/simulate_data.py

Lines changed: 45 additions & 32 deletions
@@ -333,53 +333,66 @@ def generate_ancova_data(
     return df


-def generate_geolift_data() -> pd.DataFrame:
-    """Generate synthetic data for a geolift example. This will consists of 6 untreated
-    countries. The treated unit `Denmark` is a weighted combination of the untreated
-    units. We additionally specify a treatment effect which takes effect after the
-    `treatment_time`. The timeseries data is observed at weekly resolution and has
-    annual seasonality, with this seasonality being a drawn from a Gaussian Process with
-    a periodic kernel."""
+def generate_geolift_data(seed: int | None = None) -> pd.DataFrame:
+    """Generate synthetic geolift data using a latent factor model.
+
+    Each unit's time series is a linear combination of K=3 shared seasonal
+    factors (GP draws) with unit-specific loadings, plus observation noise.
+    Most countries share positive loadings and are therefore positively
+    correlated, while 2 "contrarian" countries carry a negative loading on
+    one factor, making them negatively correlated with the majority. The
+    treated unit (Denmark) is a Dirichlet-weighted combination of the
+    positively-loaded countries only, so it is well-reconstructed by good
+    donors but poorly correlated with the contrarian ones.
+
+    This mirrors the latent factor DGP used to motivate synthetic control
+    methods in Abadie (2010, 2021).
+    """
+    rng = np.random.default_rng(seed)
     n_years = 4
     treatment_time = pd.to_datetime("2022-01-01")
     causal_impact = 0.2
-
     time = pd.date_range(start="2019-01-01", periods=52 * n_years, freq="W")
+    n_obs = len(time)

-    untreated = [
+    K = 3
+    factors = np.column_stack(
+        [create_series(n_years=n_years, intercept=0) for _ in range(K)]
+    )  # (n_obs, K)
+
+    similar = [
         "Austria",
         "Belgium",
         "Bulgaria",
         "Croatia",
         "Cyprus",
         "Czech_Republic",
+        "Estonia",
+        "Finland",
     ]
-
-    df = (
-        pd.DataFrame(
-            {
-                country: create_series(n_years=n_years, intercept=3)
-                for country in untreated
-            }
-        )
-        .assign(time=time)
-        .set_index("time")
-    )
-
-    # create treated unit as a weighted sum of the untreated units
-    weights = np.random.dirichlet(np.ones(len(untreated)), size=1)[0]
-    df = df.assign(Denmark=np.dot(df[untreated].values, weights))
-
-    # add observation noise
-    for col in untreated + ["Denmark"]:
-        df[col] += np.random.normal(size=len(df), scale=0.1)
-
-    # add treatment effect
+    contrarian = ["Greece", "Hungary"]
+    untreated = similar + contrarian
+
+    # Positive loadings for similar countries, one negative loading for contrarians
+    loadings: dict[str, np.ndarray] = {}
+    for country in similar:
+        loadings[country] = rng.uniform(0.3, 1.0, size=K)
+    loadings["Greece"] = np.array([-0.6, -0.3, 0.8])
+    loadings["Hungary"] = np.array([0.3, -0.7, -0.5])
+
+    df = pd.DataFrame(index=time)
+    df.index.name = "time"
+    for country in untreated:
+        df[country] = factors @ loadings[country] + 3 + rng.normal(0, 0.1, size=n_obs)
+
+    # Denmark as a weighted sum of similar countries only
+    w = rng.dirichlet(np.ones(len(similar)))
+    df["Denmark"] = df[similar].values @ w + rng.normal(0, 0.1, size=n_obs)
+
+    # treatment effect
     df["Denmark"] += np.where(df.index < treatment_time, 0, causal_impact)

-    # ensure we never see any negative sales
     df = df.clip(lower=0)
-
     return df
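
A quick way to sanity-check the correlation structure the new factor-model
DGP produces (a sketch; the seed and the printed slice are illustrative):

    from causalpy.data.simulate_data import generate_geolift_data

    df = generate_geolift_data(seed=42)
    corr = df.corr()

    # Greece and Hungary carry negative factor loadings, so they should
    # correlate weakly or negatively with Denmark and the similar donors
    print(corr.loc[["Greece", "Hungary"], "Denmark"])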

causalpy/experiments/synthetic_control.py

Lines changed: 27 additions & 0 deletions
@@ -267,6 +267,30 @@ def input_validation(
                 "If data.index is not DatetimeIndex, treatment_time must be pd.Timestamp."  # noqa: E501
             )

+    def _pre_treatment_correlations(self) -> dict[str, float]:
+        """Compute Pearson correlation between each treated unit and its
+        synthetic control prediction in the pre-treatment period.
+
+        Returns
+        -------
+        dict[str, float]
+            Mapping from treated unit name to correlation coefficient.
+        """
+        correlations: dict[str, float] = {}
+        for unit in self.treated_units:
+            observed = self.datapre_treated.sel(treated_units=unit).values.flatten()
+            if isinstance(self.model, PyMCModel):
+                predicted = (
+                    self.pre_pred["posterior_predictive"]["mu"]
+                    .sel(treated_units=unit)
+                    .mean(dim=["chain", "draw"])
+                    .values.flatten()
+                )
+            else:
+                predicted = np.asarray(self.pre_pred).flatten()
+            correlations[unit] = float(np.corrcoef(observed, predicted)[0, 1])
+        return correlations
+
     def summary(self, round_to: int | None = None) -> None:
         """Print summary of main results and model coefficients.

@@ -280,6 +304,9 @@ def summary(self, round_to: int | None = None) -> None:
         else:
             print(f"Treated unit: {self.treated_units[0]}")
         self.print_coefficients(round_to)
+        corrs = self._pre_treatment_correlations()
+        for unit, r in corrs.items():
+            print(f"Pre-treatment correlation ({unit}): {r:.4f}")

     @staticmethod
     def _convert_treatment_time_for_axis(
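
The statistic printed per treated unit is just the Pearson correlation
between the observed pre-period outcome and the model's mean prediction.
A standalone sketch of that computation on toy data (not the CausalPy API):

    import numpy as np

    rng = np.random.default_rng(0)
    t = np.linspace(0, 4 * np.pi, 100)
    observed = np.sin(t)
    predicted = observed + rng.normal(0, 0.05, size=100)  # a good reconstruction

    # same statistic that summary() reports for each treated unit
    r = float(np.corrcoef(observed, predicted)[0, 1])
    print(f"Pre-treatment correlation (toy): {r:.4f}")  # close to 1.0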

causalpy/tests/test_utils.py

Lines changed: 87 additions & 0 deletions
@@ -15,6 +15,8 @@
 Tests for utility functions
 """

+import matplotlib
+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
 import pytest
@@ -26,6 +28,7 @@
     check_convex_hull_violation,
     extract_lift_for_mmm,
     get_interaction_terms,
+    plot_correlations,
     round_num,
 )

@@ -369,3 +372,87 @@ def test_extract_lift_for_mmm_raises_for_ols():
         x=0.0,
         delta_x=1000,
     )
+
+
+# ============================================================================
+# Tests for plot_correlations
+# ============================================================================
+
+
+@pytest.fixture
+def panel_data():
+    """Simple wide-format panel data for correlation tests."""
+    rng = np.random.default_rng(0)
+    n = 50
+    base = np.sin(np.linspace(0, 4 * np.pi, n))
+    return pd.DataFrame(
+        {
+            "A": base + rng.normal(0, 0.1, n),
+            "B": base + rng.normal(0, 0.1, n),
+            "C": -base + rng.normal(0, 0.1, n),
+        }
+    )
+
+
+def test_plot_correlations_returns_matrix_and_axes(panel_data):
+    corr, ax = plot_correlations(panel_data)
+    assert isinstance(corr, pd.DataFrame)
+    assert corr.shape == (3, 3)
+    assert isinstance(ax, matplotlib.axes.Axes)
+    plt.close("all")
+
+
+def test_plot_correlations_diagonal_is_one(panel_data):
+    corr, _ = plot_correlations(panel_data)
+    np.testing.assert_allclose(np.diag(corr.values), 1.0)
+    plt.close("all")
+
+
+def test_plot_correlations_symmetric(panel_data):
+    corr, _ = plot_correlations(panel_data)
+    np.testing.assert_allclose(corr.values, corr.values.T)
+    plt.close("all")
+
+
+def test_plot_correlations_column_subset(panel_data):
+    corr, _ = plot_correlations(panel_data, columns=["A", "B"])
+    assert corr.shape == (2, 2)
+    assert list(corr.columns) == ["A", "B"]
+    plt.close("all")
+
+
+def test_plot_correlations_custom_ax(panel_data):
+    fig, provided_ax = plt.subplots()
+    _, returned_ax = plot_correlations(panel_data, ax=provided_ax)
+    assert returned_ax is provided_ax
+    plt.close("all")
+
+
+def test_plot_correlations_kwargs_forwarded(panel_data):
+    corr, _ = plot_correlations(panel_data, annot=False, vmin=0)
+    assert isinstance(corr, pd.DataFrame)
+    plt.close("all")
+
+
+# ============================================================================
+# Tests for SyntheticControl._pre_treatment_correlations
+# ============================================================================
+
+
+def test_pre_treatment_correlations_single_unit(sc_result_single_unit):
+    corrs = sc_result_single_unit._pre_treatment_correlations()
+    assert "actual" in corrs
+    assert 0 < corrs["actual"] <= 1.0
+
+
+def test_pre_treatment_correlations_multi_unit(sc_result_multi_unit):
+    corrs = sc_result_multi_unit._pre_treatment_correlations()
+    assert set(corrs.keys()) == {"t1", "t2"}
+    for r in corrs.values():
+        assert -1 <= r <= 1
+
+
+def test_summary_prints_correlation(sc_result_single_unit, capsys):
+    sc_result_single_unit.summary()
+    captured = capsys.readouterr()
+    assert "Pre-treatment correlation" in captured.out

causalpy/utils.py

Lines changed: 81 additions & 1 deletion
@@ -18,10 +18,12 @@
 from __future__ import annotations

 import re
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, Literal

+import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
+import seaborn as sns
 import xarray as xr

 if TYPE_CHECKING:
@@ -220,6 +222,84 @@ def check_convex_hull_violation(
     }


+def plot_correlations(
+    data: pd.DataFrame,
+    columns: list[str] | None = None,
+    method: Literal["pearson", "kendall", "spearman"] = "pearson",
+    figsize: tuple[float, float] | None = None,
+    ax: plt.Axes | None = None,
+    **kwargs: Any,
+) -> tuple[pd.DataFrame, plt.Axes]:
+    """Plot a pairwise correlation heatmap for panel data columns.
+
+    Computes the pairwise correlation matrix between the specified columns
+    (typically geographic units or time series) and displays it as a
+    lower-triangle heatmap. This is a pre-experiment diagnostic for
+    synthetic control analyses: markets that are highly correlated in the
+    pre-treatment period are more likely to produce reliable counterfactuals.
+
+    Parameters
+    ----------
+    data : pd.DataFrame
+        Wide-format panel data with time as the index and locations/units
+        as columns.
+    columns : list[str], optional
+        Subset of columns to include. If ``None``, all numeric columns
+        are used.
+    method : {"pearson", "kendall", "spearman"}, default "pearson"
+        Correlation method passed to :meth:`pandas.DataFrame.corr`.
+    figsize : tuple[float, float], optional
+        Width and height in inches for the figure. Only used when ``ax``
+        is not provided. If ``None``, matplotlib's default is used.
+    ax : matplotlib.axes.Axes, optional
+        Axes on which to draw the heatmap. If ``None``, a new figure and
+        axes are created (sized according to ``figsize``).
+    **kwargs
+        Additional keyword arguments forwarded to :func:`seaborn.heatmap`
+        (e.g., ``vmin``, ``vmax``, ``annot``, ``annot_kws``).
+
+    Returns
+    -------
+    tuple[pd.DataFrame, matplotlib.axes.Axes]
+        The correlation matrix and the axes containing the heatmap.
+
+    Examples
+    --------
+    .. code-block:: python
+
+        import causalpy as cp
+
+        df = cp.load_data("geolift1")
+        corr, ax = cp.plot_correlations(df)
+
+        # Larger figure with smaller annotation text
+        corr, ax = cp.plot_correlations(df, figsize=(10, 8), annot_kws={"size": 7})
+    """
+    subset = data[columns] if columns is not None else data.select_dtypes("number")
+    corr = subset.corr(method=method)
+    mask = np.triu(np.ones_like(corr, dtype=bool))
+
+    if ax is None:
+        _, ax = plt.subplots(figsize=figsize)
+
+    defaults: dict[str, Any] = {
+        "mask": mask,
+        "cmap": sns.diverging_palette(230, 20, as_cmap=True),
+        "vmin": -1,
+        "vmax": 1,
+        "center": 0,
+        "square": True,
+        "linewidths": 0.5,
+        "cbar_kws": {"shrink": 0.8},
+        "annot": True,
+        "fmt": ".2f",
+    }
+    defaults.update(kwargs)
+
+    sns.heatmap(corr, ax=ax, **defaults)
+    return corr, ax
+
+
 def extract_lift_for_mmm(
     sc_result: SyntheticControl,
     channel: str,
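
Because plot_correlations() returns the correlation matrix alongside the
axes, it can drive donor pool filtering directly. A sketch (the 0.5
threshold is illustrative, not a recommendation made by this commit):

    import causalpy as cp

    df = cp.load_data("geolift1")
    corr, ax = cp.plot_correlations(df)

    # keep only donors positively correlated with the treated unit
    treated = "Denmark"
    donor_corrs = corr[treated].drop(treated)
    donor_pool = donor_corrs[donor_corrs > 0.5].index.tolist()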

docs/source/knowledgebase/glossary.rst

Lines changed: 4 additions & 0 deletions
@@ -45,6 +45,10 @@ Glossary
    DiD
       Analysis where the treatment effect is estimated as a difference between treatment conditions in the differences between pre-treatment to post treatment observations.

+   Donor pool
+   Donor pool selection
+      In synthetic control methods, the donor pool is the set of untreated units available to construct the synthetic control. Donor pool selection (or curation) is the process of choosing which untreated units to include. Units that are structurally dissimilar to the treated unit -- for example, those with negative pre-treatment correlations -- should be excluded because they can introduce interpolation bias and degrade the synthetic control fit. This is especially important in Bayesian implementations where priors (e.g. Dirichlet) assign non-zero weight to every donor by construction :footcite:p:`abadie2021using,abadie2010synthetic`.
+
    Donut regression discontinuity
    Donut RDD
       A robustness approach for regression discontinuity designs where observations within a specified distance from the treatment threshold are excluded from model fitting. This technique is used when observations closest to the cutoff may be problematic due to manipulation, sorting, or heaping/rounding of the running variable. By excluding the "donut hole" around the threshold, the analysis relies on observations that are less likely to be affected by such issues. See :footcite:t:`noack2024donut` for formal discussion of donut RDD properties.
