Skip to content

Commit 97db952

Browse files
committed
Improve notebook runner robustness
Add nbclient widget output guards, clear notebook outputs before injection, and support optional parallel execution to reduce notebook flakiness in CI.
1 parent 7aeabd3 commit 97db952

3 files changed

Lines changed: 62 additions & 8 deletions

File tree

docs/source/_static/interrogate_badge.svg

Lines changed: 3 additions & 3 deletions
Loading

scripts/run_notebooks/README.md

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,9 @@ This script runs Jupyter notebooks from `docs/source/notebooks/` to validate the
66

77
1. **Mocks `pm.sample()`** — Replaces MCMC sampling with prior predictive (1 chain × 100 draws) for speed
88
2. **Uses Papermill** — Executes notebooks programmatically
9-
3. **Discards outputs** — Only checks for errors, doesn't save results
9+
3. **Clears saved outputs** — Avoids widget state issues during execution
10+
4. **Guards widget updates** — Patches nbclient to ignore display_id assertion errors
11+
5. **Discards outputs** — Only checks for errors, doesn't save results
1012

1113
## Usage
1214

@@ -22,6 +24,9 @@ python scripts/run_notebooks/runner.py --pattern "*_skl*.ipynb"
2224

2325
# Exclude PyMC and sklearn notebooks (run others)
2426
python scripts/run_notebooks/runner.py --exclude-pattern _pymc --exclude-pattern _skl
27+
28+
# Run notebooks in parallel (requires joblib)
29+
python scripts/run_notebooks/runner.py --parallel
2530
```
2631

2732
## CI Integration

scripts/run_notebooks/runner.py

Lines changed: 53 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,12 @@
2424
import logging
2525
from pathlib import Path
2626
from tempfile import NamedTemporaryFile
27+
from uuid import uuid4
2728

29+
# Monkey-patch nbclient to handle display_id=None for widget updates.
30+
# This fixes an issue where ipywidgets/tqdm progress bars cause
31+
# "assert display_id is not None" errors in nbclient.
32+
import nbclient.client
2833
import papermill
2934
import yaml
3035
from nbformat.notebooknode import NotebookNode
@@ -40,6 +45,21 @@
4045
SKIP_NOTEBOOKS_FILE = HERE / "skip_notebooks.yml"
4146
SKIP_NOTEBOOKS = set(yaml.safe_load(SKIP_NOTEBOOKS_FILE.read_text()))
4247

48+
_original_output = nbclient.client.NotebookClient.output
49+
50+
51+
def _patched_output(self, outs, msg, display_id, cell_index):
52+
"""Patched output method that catches assertion errors from widget updates."""
53+
try:
54+
return _original_output(self, outs, msg, display_id, cell_index)
55+
except AssertionError:
56+
# Silently skip messages that cause display_id assertion errors
57+
# (typically from ipywidgets/tqdm progress bar updates)
58+
return None
59+
60+
61+
nbclient.client.NotebookClient.output = _patched_output
62+
4363

4464
def setup_logging() -> None:
4565
logging.basicConfig(
@@ -48,13 +68,26 @@ def setup_logging() -> None:
4868
)
4969

5070

71+
def generate_random_id() -> str:
72+
return str(uuid4())
73+
74+
75+
def clear_cell_outputs(cells: list) -> None:
76+
"""Clear all outputs from cells to avoid widget state issues with nbclient."""
77+
for cell in cells:
78+
if cell.get("cell_type") == "code":
79+
cell["outputs"] = []
80+
cell["execution_count"] = None
81+
82+
5183
def inject_mock_code(cells: list) -> None:
5284
"""Inject mock pm.sample code at the start of the notebook."""
85+
clear_cell_outputs(cells)
5386
cells.insert(
5487
0,
5588
NotebookNode(
56-
id="mock-injection",
57-
execution_count=0,
89+
id=f"code-injection-{generate_random_id()}",
90+
execution_count=sum(map(ord, "Mock pm.sample")),
5891
cell_type="code",
5992
metadata={"tags": []},
6093
outputs=[],
@@ -133,6 +166,12 @@ def parse_args() -> argparse.Namespace:
133166
dest="exclude_patterns",
134167
help="Pattern to exclude from notebook names (can be used multiple times)",
135168
)
169+
parser.add_argument(
170+
"--parallel",
171+
action="store_true",
172+
default=False,
173+
help="Run notebooks in parallel when possible.",
174+
)
136175
return parser.parse_args()
137176

138177

@@ -149,7 +188,17 @@ def parse_args() -> argparse.Namespace:
149188
for nb in notebooks:
150189
logging.info(f" - {nb.name}")
151190

152-
for notebook in notebooks:
153-
run_notebook(notebook)
191+
if args.parallel:
192+
try:
193+
from joblib import Parallel, delayed
194+
except ImportError as exc:
195+
raise ImportError(
196+
"Parallel execution requires joblib. Install it or run without --parallel."
197+
) from exc
198+
199+
Parallel(n_jobs=-1)(delayed(run_notebook)(notebook) for notebook in notebooks)
200+
else:
201+
for notebook in notebooks:
202+
run_notebook(notebook)
154203

155204
logging.info("All notebooks completed successfully!")

0 commit comments

Comments
 (0)