Skip to content

Commit 3e21285

Browse files
committed
Respect requested draws in notebook mock sampling
Use the requested draw count when provided and fall back to the minimum to avoid indexing errors in notebooks that iterate over many draws.
1 parent 97db952 commit 3e21285

1 file changed

Lines changed: 6 additions & 3 deletions

File tree

scripts/run_notebooks/injected.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,9 +19,12 @@ def mock_sample(*args, **kwargs):
1919
if isinstance(first_arg, pm.Model):
2020
model = first_arg
2121

22-
# Always use MIN_DRAWS to ensure compatibility with code that iterates
23-
# over posterior samples (ignoring the requested draws for speed)
24-
n_draws = MIN_DRAWS
22+
requested_draws = kwargs.get("draws")
23+
if requested_draws is None and len(args) > 1 and isinstance(args[1], int):
24+
requested_draws = args[1]
25+
26+
# Ensure enough draws for notebook code while keeping execution fast.
27+
n_draws = max(MIN_DRAWS, requested_draws or MIN_DRAWS)
2528

2629
idata = pm.sample_prior_predictive(
2730
model=model,

0 commit comments

Comments
 (0)