Commit f36e9db
Add automated notebook testing with Papermill (#602)
* Add automated notebook testing with Papermill

  Introduce CI workflow to validate that docs notebooks execute without errors.

  - Add `.github/workflows/test_notebook.yml` running notebooks in 3 parallel splits (PyMC, sklearn, other) on Python 3.12
  - Add `scripts/run_notebooks/runner.py` for Papermill-based execution with nbclient widget output guards and optional `--parallel` flag
  - Add `scripts/run_notebooks/injected.py` to mock `pm.sample` with prior predictive draws for fast CI execution
  - Add `scripts/run_notebooks/skip_notebooks.yml` for notebooks incompatible with the CI environment (JAX-dependent IV notebooks)
  - Add papermill to test dependencies in pyproject.toml
  - Fix sampling bug in iv_pymc.ipynb uncertainty plot
  - Remove watermark cell from inv_prop_latent.ipynb

  Co-authored-by: Cursor <cursoragent@cursor.com>

* Fix zizmor security alerts in test_notebook.yml

  - Add workflow-level `permissions: {}` and job-level `contents: read`
  - Pin actions/checkout and actions/setup-python to SHA digests
  - Set `persist-credentials: false` on checkout

  Co-authored-by: Cursor <cursoragent@cursor.com>

---------

Co-authored-by: Cursor <cursoragent@cursor.com>
1 parent 03caa7d commit f36e9db

8 files changed

Lines changed: 412 additions & 50 deletions

.github/workflows/test_notebook.yml

Lines changed: 57 additions & 0 deletions

```yaml
name: Test Notebooks

on:
  pull_request:
    branches: [main]
    paths:
      - "pyproject.toml"
      - "causalpy/**"
      - ".github/workflows/test_notebook.yml"
      - "scripts/run_notebooks/**"
      - "docs/source/notebooks/**"
  push:
    branches: [main]
    paths:
      - "pyproject.toml"
      - "causalpy/**"
      - ".github/workflows/test_notebook.yml"
      - "scripts/run_notebooks/**"
      - "docs/source/notebooks/**"

permissions: {}

concurrency:
  group: ${{ github.workflow }}-${{ github.event.pull_request.number || github.ref }}
  cancel-in-progress: true

jobs:
  notebooks:
    runs-on: ubuntu-latest
    permissions:
      contents: read
    timeout-minutes: 60
    strategy:
      matrix:
        split:
          - "--pattern *_pymc*.ipynb"
          - "--pattern *_skl*.ipynb"
          - "--exclude-pattern _pymc --exclude-pattern _skl"
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
        with:
          persist-credentials: false

      - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # v6.2.0
        with:
          python-version: "3.12"

      - name: Install system dependencies
        run: sudo apt-get update && sudo apt-get install -y graphviz

      - name: Install dependencies
        run: |
          pip install --upgrade pip
          pip install -e ".[test,docs]"

      - name: Run notebooks
        run: python scripts/run_notebooks/runner.py ${{ matrix.split }}
```
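The three matrix splits partition the notebook set by filename. As a sketch only (the filenames and matching helper here are illustrative, assuming the runner uses glob-style patterns), the partition behaves like:

```python
from fnmatch import fnmatch

# Hypothetical notebook filenames; the real set lives in docs/source/notebooks/
notebooks = ["iv_pymc.ipynb", "did_skl.ipynb", "inv_prop_latent.ipynb"]

# Split 1 and 2 select by glob pattern; split 3 excludes both substrings
pymc_split = [n for n in notebooks if fnmatch(n, "*_pymc*.ipynb")]
skl_split = [n for n in notebooks if fnmatch(n, "*_skl*.ipynb")]
other_split = [n for n in notebooks if "_pymc" not in n and "_skl" not in n]
```

Every notebook lands in exactly one split, so the three CI jobs together cover the full set.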

docs/source/notebooks/inv_prop_latent.ipynb

Lines changed: 1 addition & 44 deletions
Lines changed: 1 addition & 44 deletions

```diff
@@ -4751,49 +4751,6 @@
     ":filter: docname in docnames\n",
     ":::"
    ]
-  },
-  {
-   "cell_type": "markdown",
-   "metadata": {},
-   "source": [
-    "### Watermark"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 42,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Last updated: Tue Jul 29 2025\n",
-      "\n",
-      "Python implementation: CPython\n",
-      "Python version : 3.13.5\n",
-      "IPython version : 9.4.0\n",
-      "\n",
-      "pytensor: 2.31.7\n",
-      "xarray : 2025.7.0\n",
-      "\n",
-      "matplotlib: 3.10.3\n",
-      "arviz : 0.21.0\n",
-      "pandas : 2.3.1\n",
-      "causalpy : 0.4.2\n",
-      "patsy : 1.0.1\n",
-      "pymc : 5.23.0\n",
-      "numpy : 2.3.1\n",
-      "\n",
-      "Watermark: 2.5.0\n",
-      "\n"
-     ]
-    }
-   ],
-   "source": [
-    "%load_ext watermark\n",
-    "%watermark -n -u -v -iv -w -p pytensor,xarray"
-   ]
   }
  ],
  "metadata": {
@@ -4812,7 +4769,7 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.13.5"
+   "version": "3.14.2"
  }
 },
 "nbformat": 4,
```

docs/source/notebooks/iv_pymc.ipynb

Lines changed: 6 additions & 5 deletions
Lines changed: 6 additions & 5 deletions

```diff
@@ -12,7 +12,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 1,
+  "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
@@ -21,11 +21,11 @@
    "import numpy as np\n",
    "import pandas as pd\n",
    "from matplotlib.lines import Line2D\n",
+   "from sklearn.linear_model import LinearRegression as sk_lin_reg\n",
    "\n",
    "import causalpy as cp\n",
    "from causalpy import InstrumentalVariable\n",
-   "from causalpy.pymc_models import InstrumentalVariableRegression\n",
-   "from causalpy.skl_models import LinearRegression as sk_lin_reg"
+   "from causalpy.pymc_models import InstrumentalVariableRegression"
   ]
  },
  {
@@ -861,7 +861,7 @@
  },
  {
   "cell_type": "code",
-  "execution_count": 12,
+  "execution_count": null,
   "metadata": {},
   "outputs": [
    {
@@ -918,7 +918,8 @@
    " Line2D([0], [0], color=\"black\", lw=4),\n",
    "]\n",
    "\n",
-   "uncertainty.sample(500).T.plot(legend=False, color=\"orange\", alpha=0.4, ax=axs[1])\n",
+   "n_samples = min(500, len(uncertainty))\n",
+   "uncertainty.sample(n_samples).T.plot(legend=False, color=\"orange\", alpha=0.4, ax=axs[1])\n",
    "axs[1].plot(x, ols, color=\"black\", label=\"OLS fit\")\n",
    "axs[1].set_title(\"OLS versus Instrumental Regression Fits\", fontsize=20)\n",
    "axs[1].legend(custom_lines, [\"IV fits\", \"OlS fit\"])\n",
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
```diff
@@ -93,7 +93,7 @@ docs = [
     "sphinx-togglebutton",
 ]
 lint = ["interrogate", "pre-commit", "ruff", "mypy"]
-test = ["pytest", "pytest-cov", "codespell", "nbformat", "nbconvert"]
+test = ["pytest", "pytest-cov", "codespell", "nbformat", "nbconvert", "papermill"]

 [project.urls]
 Homepage = "https://github.com/pymc-labs/CausalPy"
```

scripts/run_notebooks/README.md

Lines changed: 78 additions & 0 deletions
````markdown
# Notebook Runner

This script runs Jupyter notebooks from `docs/source/notebooks/` to validate they execute without errors.

## How It Works

1. **Mocks `pm.sample()`** — Replaces MCMC sampling with prior predictive (1 chain × 100 draws) for speed
2. **Uses Papermill** — Executes notebooks programmatically
3. **Clears saved outputs** — Avoids widget state issues during execution
4. **Guards widget updates** — Patches nbclient to ignore display_id assertion errors
5. **Discards outputs** — Only checks for errors, doesn't save results

## Dependencies

The notebook runner mirrors the CI setup and expects a full docs/test environment.

1. **Install Python dependencies**

   ```bash
   pip install -e ".[test,docs]"
   ```

   This brings in Papermill, Jupyter, nbclient, and notebook-related dependencies.

2. **Install Graphviz (system dependency)**

   - macOS:
     ```bash
     brew install graphviz
     ```
   - Ubuntu/Debian:
     ```bash
     sudo apt-get update && sudo apt-get install -y graphviz
     ```

3. **Optional: parallel execution**

   ```bash
   pip install joblib
   ```

## Notes

- The runner executes using the `python3` Jupyter kernel. Ensure your environment
  provides that kernel (e.g., from `ipykernel` installed via the docs extras).
- The CI workflow uses Python 3.12 and installs the same extras.

## Usage

```bash
# Run all notebooks
python scripts/run_notebooks/runner.py

# Run only PyMC notebooks
python scripts/run_notebooks/runner.py --pattern "*_pymc*.ipynb"

# Run only sklearn notebooks
python scripts/run_notebooks/runner.py --pattern "*_skl*.ipynb"

# Exclude PyMC and sklearn notebooks (run others)
python scripts/run_notebooks/runner.py --exclude-pattern _pymc --exclude-pattern _skl

# Run notebooks in parallel (requires joblib)
python scripts/run_notebooks/runner.py --parallel
```

## CI Integration

The GitHub Actions workflow (`.github/workflows/test_notebook.yml`) runs this script in parallel:
- Job 1: PyMC notebooks
- Job 2: Sklearn notebooks
- Job 3: Other notebooks

## Files

- `runner.py` — Main script
- `injected.py` — Code injected into notebooks to mock `pm.sample()`
- `skip_notebooks.yml` — List of notebooks to skip (incompatible with mock sampling)
````
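The runner has to get `injected.py` into each notebook before execution. One plausible mechanism, sketched here without claiming it is what `runner.py` actually does: since `.ipynb` files are plain JSON, a mock cell can be prepended with the standard library alone (the notebook dict below is a minimal hypothetical stand-in).

```python
import json

# Minimal stand-in for a notebook loaded from disk
# (real files come from docs/source/notebooks/)
nb = {
    "cells": [
        {"cell_type": "code", "execution_count": None,
         "metadata": {}, "outputs": [], "source": "x = 1"},
    ],
    "metadata": {},
    "nbformat": 4,
    "nbformat_minor": 5,
}

# Prepend the mock-sampling code as a new first cell, so it runs
# before any notebook cell calls pm.sample
mock_cell = {
    "cell_type": "code",
    "execution_count": None,
    "metadata": {},
    "outputs": [],
    "source": "# contents of scripts/run_notebooks/injected.py go here",
}
nb["cells"].insert(0, mock_cell)

# The modified notebook serializes back to valid JSON for Papermill to execute
serialized = json.dumps(nb)
```

Because the injected cell executes first, every later `pm.sample` call in the notebook hits the mock instead of real MCMC.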

scripts/run_notebooks/injected.py

Lines changed: 58 additions & 0 deletions
```python
"""Injected code to mock pm.sample for faster notebook execution."""

import numpy as np
import pymc as pm
import xarray as xr

# Minimum draws needed to satisfy notebook code that iterates over posterior samples
MIN_DRAWS = 100


def mock_sample(*args, **kwargs):
    """Mock pm.sample using prior predictive sampling for speed."""
    random_seed = kwargs.get("random_seed")
    model = kwargs.get("model")

    # If no model is provided via kwargs, try to infer it from positional args
    if model is None and args:
        first_arg = args[0]
        if isinstance(first_arg, pm.Model):
            model = first_arg

    requested_draws = kwargs.get("draws")
    if requested_draws is None and len(args) > 1 and isinstance(args[1], int):
        requested_draws = args[1]

    # Ensure enough draws for notebook code while keeping execution fast.
    n_draws = max(MIN_DRAWS, requested_draws or MIN_DRAWS)

    idata = pm.sample_prior_predictive(
        model=model,
        random_seed=random_seed,
        draws=n_draws,
    )
    idata.add_groups(posterior=idata.prior)

    # Create mock sample stats with diverging data
    if "sample_stats" not in idata:
        n_chains = 1
        sample_stats = xr.Dataset(
            {
                "diverging": xr.DataArray(
                    np.zeros((n_chains, n_draws), dtype=int),
                    dims=("chain", "draw"),
                )
            }
        )
        idata.add_groups(sample_stats=sample_stats)

    del idata.prior
    if "prior_predictive" in idata:
        del idata.prior_predictive

    return idata


pm.sample = mock_sample
pm.HalfFlat = pm.HalfNormal
pm.Flat = pm.Normal
```
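The draw-count fallback in `mock_sample` can be exercised in isolation. The `resolve_draws` helper below is a name introduced here for illustration, extracting just that logic: an explicit `draws` keyword wins, then an integer second positional argument, and the result is never below `MIN_DRAWS`.

```python
MIN_DRAWS = 100  # mirrors the constant in injected.py


def resolve_draws(*args, **kwargs):
    # draws may arrive as a keyword or as the second positional argument
    requested = kwargs.get("draws")
    if requested is None and len(args) > 1 and isinstance(args[1], int):
        requested = args[1]
    # never drop below MIN_DRAWS, so notebooks that index many samples still work
    return max(MIN_DRAWS, requested or MIN_DRAWS)
```

So a notebook asking for 500 draws still gets 500, while one asking for 10 (or none at all) gets the 100-draw floor.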
