Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
4f98c71
Run pytest --collect-only in parallel batches in split_tests
bdraco May 21, 2026
8dadaa2
Filter fan-out children and fail fast on empty batch list
bdraco May 21, 2026
9ed16b6
Cache per-file test counts in split_tests
bdraco May 21, 2026
5975f4b
Skip cache walking when --cache is not passed
bdraco May 21, 2026
a8bc244
DNM: test cache, touch cloud manifest only
bdraco May 21, 2026
7c18b67
DNM: test cache bust, touch cloud conftest
bdraco May 21, 2026
81e0653
Revert "DNM: test cache bust, touch cloud conftest"
bdraco May 21, 2026
22fb68b
Revert "DNM: test cache, touch cloud manifest only"
bdraco May 21, 2026
1009ce4
Merge branch 'dev' into cache-split-tests
bdraco May 22, 2026
4a6c5b5
cleanups
bdraco May 22, 2026
7c137b5
cleanup
bdraco May 22, 2026
add8a5f
Merge branch 'dev' into cache-split-tests
bdraco May 22, 2026
4033a8b
Apply suggestions from code review
bdraco May 22, 2026
584b32c
address copilot, cleanups
bdraco May 22, 2026
0ec0ea3
single pass
bdraco May 22, 2026
b2257ca
touch ups
bdraco May 22, 2026
1b6e9f5
trim
bdraco May 22, 2026
944fb1e
Merge branch 'dev' into cache-split-tests
bdraco May 22, 2026
3e289da
drop bad copilot suggest
bdraco May 22, 2026
5771b0c
Merge remote-tracking branch 'refs/remotes/upstream/cache-split-tests…
bdraco May 22, 2026
ecc8e52
make bot happy
bdraco May 22, 2026
d942262
handle bot review comments
bdraco May 22, 2026
305b5d6
preen
bdraco May 22, 2026
69efa8e
Merge branch 'dev' into cache-split-tests
bdraco May 22, 2026
7534c43
fix cache bust
bdraco May 22, 2026
9dc37a2
Merge remote-tracking branch 'upstream/cache-split-tests' into cache-…
bdraco May 22, 2026
7835a49
simplify
bdraco May 22, 2026
277a2d8
more cleanups
bdraco May 22, 2026
cab7c41
more cleanups
bdraco May 22, 2026
e9a58cd
restore
bdraco May 22, 2026
e589017
preen
bdraco May 22, 2026
d7bf7df
preen
bdraco May 22, 2026
878761c
preen
bdraco May 22, 2026
11903ac
dry
bdraco May 22, 2026
77bc932
will copilot ever end
bdraco May 22, 2026
8301add
bot comments
bdraco May 22, 2026
ecac38a
Merge branch 'dev' into cache-split-tests
bdraco May 22, 2026
1cc91cd
another round of copilot
bdraco May 22, 2026
ff1177d
Merge branch 'dev' into cache-split-tests
bdraco May 22, 2026
2c25c5a
another round of copilot comments
bdraco May 22, 2026
bf48806
Merge remote-tracking branch 'upstream/cache-split-tests' into cache-…
bdraco May 22, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 12 additions & 1 deletion .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -917,12 +917,23 @@ jobs:
key: >-
${{ runner.os }}-${{ runner.arch }}-${{ steps.python.outputs.python-version }}-${{
needs.info.outputs.python_cache_key }}
- name: Restore pytest test counts cache
uses: actions/cache@27d5ce7f107fe9357f9df03efb73ab90386fccae # v5.0.5
with:
path: pytest_test_counts.json
key: >-
pytest-counts-${{ runner.os }}-${{ runner.arch }}-${{
steps.python.outputs.python-version }}-${{ github.sha }}
restore-keys: |
pytest-counts-${{ runner.os }}-${{ runner.arch }}-${{ steps.python.outputs.python-version }}-
Comment thread
bdraco marked this conversation as resolved.
Outdated
- name: Run split_tests.py
env:
TEST_GROUP_COUNT: ${{ needs.info.outputs.test_group_count }}
run: |
. venv/bin/activate
python -m script.split_tests ${TEST_GROUP_COUNT} tests
python -m script.split_tests \
--cache pytest_test_counts.json \
${TEST_GROUP_COUNT} tests
- name: Upload pytest_buckets
uses: actions/upload-artifact@043fb46d1a93c77aae656e7c1c64a875d1fc6a0a # v7.0.1
with:
Expand Down
309 changes: 279 additions & 30 deletions script/split_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@
import argparse
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass, field
import hashlib
import json
from math import ceil
import os
from pathlib import Path
Expand All @@ -15,13 +17,15 @@
# place to subdivide to keep each pytest invocation roughly equal in size.
_FAN_OUT_DIRS: Final = frozenset({"components"})

# Cache file format version; bump on any incompatible schema change so old
# caches are ignored rather than misread.
_CACHE_VERSION: Final = 1


class Bucket:
"""Class to hold bucket."""

def __init__(
self,
):
def __init__(self) -> None:
"""Initialize bucket."""
self.total_tests = 0
self._paths: list[str] = []
Expand Down Expand Up @@ -83,7 +87,7 @@ def split_tests(self, test_folder: TestFolder) -> None:

def create_ouput_file(self) -> None:
"""Create output file."""
with Path("pytest_buckets.txt").open("w") as file:
with Path("pytest_buckets.txt").open("w", encoding="utf-8") as file:
for idx, bucket in enumerate(self._buckets):
print(f"Bucket {idx + 1} has {bucket.total_tests} tests")
Comment thread
bdraco marked this conversation as resolved.
Outdated
file.write(bucket.get_paths_line())
Expand Down Expand Up @@ -216,44 +220,283 @@ def _enumerate_batch_paths(path: Path) -> list[Path]:
return paths


def collect_tests(path: Path) -> TestFolder:
"""Collect all tests."""
batch_paths = _enumerate_batch_paths(path)
if not batch_paths:
print(f"No eligible test paths found under {path}")
sys.exit(1)
workers = min(len(batch_paths), os.cpu_count() or 1) or 1
# Round-robin chunking keeps batches roughly balanced when path
# ordering correlates with test size.
batches = [batch_paths[i::workers] for i in range(workers)]
def _hash_file(path: Path) -> str:
"""Return a short content hash for ``path``."""
return hashlib.sha256(path.read_bytes()).hexdigest()[:16]


def _walk_test_tree(root: Path) -> tuple[list[Path], list[Path]]:
"""Walk ``root`` once and return (test files, conftest files).

Uses ``os.walk`` rather than ``Path.rglob`` because it's ~2x faster on
a 5000-file tree, and we prune hidden/dunder subdirectories instead of
visiting them. Doing both walks in one pass keeps total tree I/O down.
"""
if root.is_file():
if root.name.startswith("test_") and root.suffix == ".py":
return [root], []
return [], []
Comment thread
bdraco marked this conversation as resolved.
Outdated

test_files: list[Path] = []
conftests: list[Path] = []
for dirpath, dirnames, filenames in os.walk(root):
dirnames[:] = [d for d in dirnames if not d.startswith((".", "_"))]
base = Path(dirpath)
Comment thread
bdraco marked this conversation as resolved.
for name in filenames:
if name == "conftest.py":
conftests.append(base / name)
elif name.startswith("test_") and name.endswith(".py"):
test_files.append(base / name)
test_files.sort()
conftests.sort()
return test_files, conftests


def _compute_conftest_hash(root: Path, conftests: list[Path]) -> str:
"""Return a hash that changes whenever any conftest.py under ``root`` changes.

Any change to a conftest invalidates the entire test-count cache. This is
coarse but safe: conftests can change fixture parametrization in ways the
cache cannot otherwise detect, so we just re-collect everything.
"""
digest = hashlib.sha256()
for conftest in conftests:
digest.update(str(conftest.relative_to(root)).encode())
digest.update(b"\0")
digest.update(conftest.read_bytes())
digest.update(b"\0")
return digest.hexdigest()


@dataclass
class _CacheEntry:
"""Cached test count for a single file."""

hash: str
count: int


@dataclass
class _Cache:
"""Mapping of test file path → cached entry, plus invalidation key."""

conftest_hash: str
entries: dict[str, _CacheEntry]

@classmethod
def empty(cls, conftest_hash: str = "") -> _Cache:
"""Return a new empty cache."""
return cls(conftest_hash=conftest_hash, entries={})

@classmethod
def load(cls, path: Path, current_conftest_hash: str) -> _Cache:
"""Load cache from ``path`` and invalidate it on schema/conftest drift.

Any failure (missing file, bad JSON, version drift, conftest drift)
returns an empty cache so the script just falls back to a full
collection. This is the self-healing path.
"""
try:
raw = json.loads(path.read_bytes())
except OSError, ValueError:
return cls.empty(current_conftest_hash)
if not isinstance(raw, dict) or raw.get("version") != _CACHE_VERSION:
return cls.empty(current_conftest_hash)
if raw.get("conftest_hash") != current_conftest_hash:
return cls.empty(current_conftest_hash)
files = raw.get("files")
if not isinstance(files, dict):
return cls.empty(current_conftest_hash)
entries: dict[str, _CacheEntry] = {}
for key, value in files.items():
if (
not isinstance(value, dict)
or not isinstance(value.get("hash"), str)
or not isinstance(value.get("count"), int)
):
# Skip malformed entries instead of discarding the whole cache.
continue
entries[key] = _CacheEntry(hash=value["hash"], count=value["count"])
Comment thread
bdraco marked this conversation as resolved.
Outdated
return cls(conftest_hash=current_conftest_hash, entries=entries)

def save(self, path: Path) -> None:
"""Write the cache to ``path``."""
Comment thread
bdraco marked this conversation as resolved.
Outdated
path.write_text(
json.dumps(
{
"version": _CACHE_VERSION,
"conftest_hash": self.conftest_hash,
"files": {
key: {"hash": entry.hash, "count": entry.count}
for key, entry in sorted(self.entries.items())
},
},
indent=2,
ensure_ascii=False,
)
+ "\n",
encoding="utf-8",
)


def _resolve_from_cache(
test_files: list[Path],
cache: _Cache,
root: Path,
) -> tuple[dict[Path, int], list[Path]]:
"""Split ``test_files`` into ``(cached_counts, needs_collection)``.

A file is served from cache when its content hash matches what we
previously stored; otherwise it is queued for re-collection.
"""
cached: dict[Path, int] = {}
misses: list[Path] = []
for file in test_files:
key = str(file.relative_to(root))
entry = cache.entries.get(key)
if entry is None:
misses.append(file)
continue
if entry.hash != _hash_file(file):
misses.append(file)
continue
cached[file] = entry.count
return cached, misses


def _run_collect_batches(paths: list[Path]) -> list[tuple[str, str, int]]:
"""Run pytest --collect-only across ``paths`` using a process pool."""
workers = min(len(paths), os.cpu_count() or 1) or 1
batches = [paths[i::workers] for i in range(workers)]
if workers == 1:
results = [_collect_batch(batches[0])]
else:
with ProcessPoolExecutor(max_workers=workers) as executor:
results = list(executor.map(_collect_batch, batches))
return [_collect_batch(batches[0])]
with ProcessPoolExecutor(max_workers=workers) as executor:
return list(executor.map(_collect_batch, batches))

folder = TestFolder(path)
for stdout, stderr, returncode in results:

def _parse_collect_output(stdout: str) -> dict[Path, int]:
"""Parse ``pytest --collect-only -qq`` output into ``{path: count}``."""
counts: dict[Path, int] = {}
for line in stdout.splitlines():
if not line.strip():
continue
file_path, _, total_tests = line.partition(": ")
if not file_path or not total_tests:
raise ValueError(f"Unexpected line: {line}")
counts[Path(file_path)] = int(total_tests)
return counts


def _run_pytest_collect(paths: list[Path]) -> dict[Path, int]:
"""Run pytest --collect-only across ``paths`` and parse the output."""
counts: dict[Path, int] = {}
for stdout, stderr, returncode in _run_collect_batches(paths):
if returncode != 0:
print("Failed to collect tests:")
print(stderr)
print(stdout)
sys.exit(1)
for line in stdout.splitlines():
if not line.strip():
continue
file_path, _, total_tests = line.partition(": ")
if not file_path or not total_tests:
print(f"Unexpected line: {line}")
sys.exit(1)
try:
counts.update(_parse_collect_output(stdout))
except ValueError as err:
print(err)
sys.exit(1)
return counts


file = TestFile(int(total_tests), Path(file_path))
folder.add_test_file(file)
def _collect_tests_uncached(path: Path) -> TestFolder:
"""Collect tests by handing pytest the top-level directories.

Skips the tree walk and per-file hashing; used when no cache file is
requested so the script behaves like the pre-cache implementation.
"""
batch_paths = _enumerate_batch_paths(path)
if not batch_paths:
print(f"No eligible test paths found under {path}")
sys.exit(1)

folder = TestFolder(path)
for file_path, total_tests in _run_pytest_collect(batch_paths).items():
folder.add_test_file(TestFile(total_tests, file_path))
return folder


def _collect_tests_cached(path: Path, cache_path: Path) -> TestFolder:
"""Collect tests using an on-disk cache for incremental updates."""
all_test_files, conftests = _walk_test_tree(path)
if not all_test_files:
print(f"No eligible test paths found under {path}")
sys.exit(1)

conftest_hash = _compute_conftest_hash(path, conftests)
cache = _Cache.load(cache_path, conftest_hash)

cached_counts, missing = _resolve_from_cache(all_test_files, cache, path)
print(
f"Cache: {len(cached_counts)} hits / {len(missing)} misses"
f" / {len(all_test_files)} total"
)

new_counts: dict[Path, int] = {}
if missing:
# On a full cold-cache run, hand pytest the top-level directories
# instead of 5000+ individual file paths: pytest walks dirs much
# faster than it resolves each file argument. Once any cache hits
# exist, use file-level collection so we only re-collect the diff.
if not cached_counts:
collect_paths = _enumerate_batch_paths(path)
else:
collect_paths = missing
new_counts = _run_pytest_collect(collect_paths)

counts: dict[Path, int] = {**cached_counts, **new_counts}

folder = TestFolder(path)
for file_path, total_tests in counts.items():
if total_tests == 0:
# Files with no collected tests (eg helper modules named
# test_init.py with no test functions) shouldn't enter
# bucketing, but we still cache them below as count=0 so
# they don't get re-collected next run.
continue
folder.add_test_file(TestFile(total_tests, file_path))

# Rebuild the cache from scratch on every run so deleted files are
# dropped and re-collected files get a refreshed hash.
missing_set = set(missing)
updated_entries: dict[str, _CacheEntry] = {}
for file in all_test_files:
if file in counts:
count = counts[file]
elif file in missing_set:
# We asked pytest about this file and got no count back,
# so it has no collectible tests; cache it as 0 to avoid
# repeating the work next run.
count = 0
else:
continue
updated_entries[str(file.relative_to(path))] = _CacheEntry(
hash=_hash_file(file), count=count
)
Comment thread
bdraco marked this conversation as resolved.
Outdated
_Cache(conftest_hash=conftest_hash, entries=updated_entries).save(cache_path)

return folder


def collect_tests(path: Path, cache_path: Path | None = None) -> TestFolder:
"""Collect all tests, using an on-disk cache when ``cache_path`` is set."""
if cache_path is None:
return _collect_tests_uncached(path)
if path.is_file():
# The cache keys on conftest_hash, but a single file root has no
# ancestor conftests to walk and the hash would always be empty,
# which would let stale counts survive conftest edits. Skip the
# cache for the file-root case rather than silently mis-caching.
print(f"--cache ignored: {path} is a single file")
return _collect_tests_uncached(path)
return _collect_tests_cached(path, cache_path)


def main() -> None:
"""Execute script."""
parser = argparse.ArgumentParser(description="Split tests into n buckets.")
Expand All @@ -276,11 +519,17 @@ def check_greater_0(value: str) -> int:
help="Path to the test files to split into buckets",
type=Path,
)
parser.add_argument(
"--cache",
help="Path to a JSON file used to cache per-file test counts",
type=Path,
default=None,
)

arguments = parser.parse_args()

print("Collecting tests...")
tests = collect_tests(arguments.path)
tests = collect_tests(arguments.path, arguments.cache)
tests_per_bucket = ceil(tests.total_tests / arguments.bucket_count)

bucket_holder = BucketHolder(tests_per_bucket, arguments.bucket_count)
Expand Down
Loading