Commit 57efa72

Authored by louismagowan, claude, and drbenvincent
Fix simulate_data.py reproducibility + remove synthetic CSVs (#794)
* fix: make all data generators reproducible via seed parameter

  - Add `seed` parameter to all generator functions that lacked it
  - Replace scipy.stats RNG calls (norm().rvs, uniform.rvs, dirichlet().rvs) with numpy Generator methods (rng.normal, rng.uniform, rng.dirichlet)
  - Replace np.random.* global-state calls with local rng instances
  - Fix create_series() bug: hardcoded length_scale=2 now uses the parameter
  - Rename create_series, generate_seasonality, and periodic_kernel to private (_create_series, _generate_seasonality, _periodic_kernel), since they are internal helpers that require an rng instance from their caller
  - Fix deprecated pandas freq="M" → freq="ME"
  - Remove unused scipy.stats imports (norm, uniform, dirichlet)
  - Remove the module-level global rng; keep the RANDOM_SEED constant for use by load_data() and test fixtures
  - Standardize generate_staggered_did_data to use the same rng pattern

  Closes #545

* refactor: generate synthetic datasets on-the-fly in load_data()

  Split the DATASETS dict into SYNTHETIC_DATASETS (generated via seeded functions) and REAL_WORLD_DATASETS (loaded from CSV). This removes the dependency on synthetic CSV files while keeping real-world data as shipped CSVs.

  - Replace _get_data_home() with a simple _DATA_DIR = Path(__file__).parent
  - Remove the circular `import causalpy as cp` dependency
  - All 8 synthetic datasets now call their generator with RANDOM_SEED

* test: add session-scoped fixtures for all synthetic datasets

  Eight fixtures (did_data, its_data, its_simple_data, rd_data, sc_data, anova1_data, geolift1_data, geolift_multi_cell_data) generate data once per test session using RANDOM_SEED, avoiding redundant calls to load_data() or the generators in individual tests.
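The seeded-generator pattern described in the first commit can be sketched as follows. Note that `generate_example_data` is a hypothetical stand-in for illustration, not one of the functions in simulate_data.py; the point is the local `np.random.default_rng(seed)` instance replacing scipy.stats `.rvs` calls and global `np.random.*` state:

```python
import numpy as np
import pandas as pd


def generate_example_data(n: int = 100, seed: int = 42) -> pd.DataFrame:
    # A local Generator instance replaces global np.random / scipy.stats
    # state, so the same seed always yields the same DataFrame.
    rng = np.random.default_rng(seed)
    return pd.DataFrame(
        {
            "x": rng.uniform(0, 1, size=n),      # was uniform.rvs(...)
            "noise": rng.normal(0, 1, size=n),   # was norm().rvs(...)
        }
    )
```

Because no global RNG state is touched, two calls with the same seed produce identical frames, which is exactly what makes CSV-free `load_data()` and the reproducibility tests possible.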
* test: replace cp.load_data() with fixtures for synthetic datasets

  Update 8 test files to use the session-scoped fixtures (did_data, its_data, rd_data, sc_data, anova1_data, geolift1_data) instead of calling cp.load_data() for synthetic datasets. Real-world dataset loading (banks, brexit, drinking, risk, nhefs) remains unchanged. Also rewrite test_data_loading.py to parametrize over all datasets (synthetic + real-world) and add reproducibility and unknown-key tests.

* chore: delete synthetic CSVs and unused gt_social_media_data.csv

  Remove 8 synthetic CSV files (~172K) now generated programmatically: did.csv, regression_discontinuity.csv, synthetic_control.csv, its.csv, its_simple.csv, ancova_generated.csv, geolift1.csv, geolift_multi_cell.csv. Also remove gt_social_media_data.csv, which was never referenced in the DATASETS dict or any code. Real-world CSVs (12 files, ~3.2MB) remain in the repo.

* fix(datasets): use Callable type hint instead of callable for mypy

* refactor(simulate_data): remove unused generate_time_series_data function

  Superseded by generate_time_series_data_seasonal and generate_time_series_data_simple. Not imported or called anywhere.

Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
Co-authored-by: Benjamin T. Vincent <inferencelab@gmail.com>
1 parent eaa37f8 commit 57efa72

20 files changed: 442 additions, 1703 deletions

causalpy/data/ancova_generated.csv

Lines changed: 0 additions & 201 deletions
This file was deleted.

causalpy/data/datasets.py

Lines changed: 57 additions & 35 deletions
@@ -15,39 +15,62 @@
 Functions to load example datasets
 """

-import pathlib
+from collections.abc import Callable
+from pathlib import Path

 import pandas as pd

-import causalpy as cp
-
-DATASETS = {
-    "banks": {"filename": "banks.csv"},
-    "brexit": {"filename": "GDP_in_dollars_billions.csv"},
-    "covid": {"filename": "deaths_and_temps_england_wales.csv"},
-    "did": {"filename": "did.csv"},
-    "drinking": {"filename": "drinking.csv"},
-    "its": {"filename": "its.csv"},
-    "its simple": {"filename": "its_simple.csv"},
-    "rd": {"filename": "regression_discontinuity.csv"},
-    "sc": {"filename": "synthetic_control.csv"},
-    "anova1": {"filename": "ancova_generated.csv"},
-    "geolift1": {"filename": "geolift1.csv"},
-    "geolift_multi_cell": {"filename": "geolift_multi_cell.csv"},
-    "risk": {"filename": "AJR2001.csv"},
-    "nhefs": {"filename": "nhefs.csv"},
-    "schoolReturns": {"filename": "schoolingReturns.csv"},
-    "pisa18": {"filename": "PISA18sampleScale.csv"},
-    "nets": {"filename": "nets_df.csv"},
-    "lalonde": {"filename": "lalonde.csv"},
-    "zipcodes": {"filename": "zipcodes_data.csv"},
-    "nevo": {"filename": "data_nevo.csv"},
+from .simulate_data import (
+    RANDOM_SEED,
+    generate_ancova_data,
+    generate_did,
+    generate_geolift_data,
+    generate_multicell_geolift_data,
+    generate_regression_discontinuity_data,
+    generate_synthetic_control_data,
+    generate_time_series_data_seasonal,
+    generate_time_series_data_simple,
+)
+
+_DATA_DIR = Path(__file__).parent
+
+# Synthetic datasets are generated programmatically for reproducibility.
+# .reset_index() on ITS functions because generators set date as the index,
+# but the old CSV-based load_data returned date as a column.
+SYNTHETIC_DATASETS: dict[str, Callable[[], pd.DataFrame]] = {
+    "did": lambda: generate_did(seed=RANDOM_SEED),
+    "rd": lambda: generate_regression_discontinuity_data(
+        true_treatment_threshold=0.5, seed=RANDOM_SEED
+    ),
+    "sc": lambda: generate_synthetic_control_data(seed=RANDOM_SEED)[0],
+    "its": lambda: generate_time_series_data_seasonal(
+        treatment_time=pd.to_datetime("2017-01-01"), seed=RANDOM_SEED
+    ).reset_index(),
+    "its simple": lambda: generate_time_series_data_simple(
+        treatment_time=pd.to_datetime("2015-01-01"), seed=RANDOM_SEED
+    ).reset_index(),
+    "anova1": lambda: generate_ancova_data(seed=RANDOM_SEED),
+    "geolift1": lambda: generate_geolift_data(seed=RANDOM_SEED).reset_index(),
+    "geolift_multi_cell": lambda: generate_multicell_geolift_data(
+        seed=RANDOM_SEED
+    ).reset_index(),
 }

-
-def _get_data_home() -> pathlib.Path:
-    """Return the path of the data directory"""
-    return pathlib.Path(cp.__file__).parents[1] / "causalpy" / "data"
+# Real-world datasets remain as CSV files shipped with the package.
+REAL_WORLD_DATASETS: dict[str, str] = {
+    "banks": "banks.csv",
+    "brexit": "GDP_in_dollars_billions.csv",
+    "covid": "deaths_and_temps_england_wales.csv",
+    "drinking": "drinking.csv",
+    "risk": "AJR2001.csv",
+    "nhefs": "nhefs.csv",
+    "schoolReturns": "schoolingReturns.csv",
+    "pisa18": "PISA18sampleScale.csv",
+    "nets": "nets_df.csv",
+    "lalonde": "lalonde.csv",
+    "zipcodes": "zipcodes_data.csv",
+    "nevo": "data_nevo.csv",
+}


 def load_data(dataset: str) -> pd.DataFrame:
@@ -84,6 +107,7 @@ def load_data(dataset: str) -> pd.DataFrame:
     - ``"zipcodes"`` - Geo-experimentation zipcode data for comparative interrupted time
       series analysis. Based on synthetic data from Juan Orduz's blog post on
       `time-based regression for geo-experiments <https://juanitorduz.github.io/time_based_regression_pymc/>`_.
+    - ``"nevo"`` - Berry, Levinsohn, and Pakes (1995) cereal data for BLP estimation

     Returns
     -------
@@ -106,11 +130,9 @@ def load_data(dataset: str) -> pd.DataFrame:

     >>> df = cp.load_data("rd")
     """
-
-    if dataset in DATASETS:
-        data_dir = _get_data_home()
-        datafile = DATASETS[dataset]
-        file_path = data_dir / datafile["filename"]
-        return pd.read_csv(file_path)
+    if dataset in SYNTHETIC_DATASETS:
+        return SYNTHETIC_DATASETS[dataset]()
+    elif dataset in REAL_WORLD_DATASETS:
+        return pd.read_csv(_DATA_DIR / REAL_WORLD_DATASETS[dataset])
     else:
-        raise ValueError(f"Dataset {dataset} not found!")
+        raise ValueError(f"Dataset {dataset!r} not found!")
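The two-registry dispatch shown in the diff can be sketched standalone. The `_generate_did` stub below is a hypothetical stand-in for the real generator, and `RANDOM_SEED = 42` is an assumed value; the structure mirrors the new datasets.py, with lambdas deferring generation until the dataset is requested:

```python
from collections.abc import Callable

import numpy as np
import pandas as pd

RANDOM_SEED = 42  # assumed value of the constant in simulate_data.py


def _generate_did(seed: int) -> pd.DataFrame:
    # Stand-in for the real generate_did: deterministic given the seed.
    rng = np.random.default_rng(seed)
    return pd.DataFrame({"y": rng.normal(size=5)})


# Lambdas defer generation until load_data() is called for that key.
SYNTHETIC_DATASETS: dict[str, Callable[[], pd.DataFrame]] = {
    "did": lambda: _generate_did(seed=RANDOM_SEED),
}
REAL_WORLD_DATASETS: dict[str, str] = {"banks": "banks.csv"}


def load_data(dataset: str) -> pd.DataFrame:
    if dataset in SYNTHETIC_DATASETS:
        return SYNTHETIC_DATASETS[dataset]()  # regenerated deterministically
    elif dataset in REAL_WORLD_DATASETS:
        # In the real module this resolves against _DATA_DIR.
        return pd.read_csv(REAL_WORLD_DATASETS[dataset])
    else:
        raise ValueError(f"Dataset {dataset!r} not found!")
```

Because every synthetic entry is generated with the same fixed seed, repeated calls to `load_data("did")` return identical frames, matching the behavior the old CSVs provided without shipping them.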

causalpy/data/did.csv

Lines changed: 0 additions & 41 deletions
This file was deleted.
