Skip to content

Commit ef450b0

Browse files
committed
Increase prior predictive draws to 500 in notebook runner
Updated the mock for pm.sample to use 500 draws instead of 100 to ensure compatibility with notebook code that iterates over posterior samples, such as plot_ate which defaults to 500 draws. Adjusted documentation and injected.py accordingly.
1 parent 29c1e29 commit ef450b0

2 files changed

Lines changed: 9 additions & 2 deletions

File tree

scripts/run_notebooks/README.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ This script runs Jupyter notebooks from `docs/source/notebooks/` to validate the
44

55
## How It Works
66

7-
1. **Mocks `pm.sample()`** — Replaces MCMC sampling with prior predictive (1 chain × 100 draws) for speed
7+
1. **Mocks `pm.sample()`** — Replaces MCMC sampling with prior predictive (1 chain × 500 draws) for speed
88
2. **Uses Papermill** — Executes notebooks programmatically
99
3. **Discards outputs** — Only checks for errors, doesn't save results
1010

scripts/run_notebooks/injected.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@
44
import pymc as pm
55
import xarray as xr
66

7+
# Minimum draws needed to satisfy notebook code that iterates over posterior samples
8+
# (e.g., plot_ate uses ate_draws=500 by default)
9+
MIN_DRAWS = 500
10+
711

812
def mock_sample(*args, **kwargs):
913
"""Mock pm.sample using prior predictive sampling for speed."""
@@ -15,7 +19,10 @@ def mock_sample(*args, **kwargs):
1519
first_arg = args[0]
1620
if isinstance(first_arg, pm.Model):
1721
model = first_arg
18-
n_draws = 100
22+
23+
# Always use MIN_DRAWS to ensure compatibility with code that iterates
24+
# over posterior samples (ignoring the requested draws for speed)
25+
n_draws = MIN_DRAWS
1926

2027
idata = pm.sample_prior_predictive(
2128
model=model,

0 commit comments

Comments
 (0)