"""Helper script to split test into n buckets."""

import argparse
from concurrent.futures import ProcessPoolExecutor
from dataclasses import dataclass, field
from math import ceil
import os
from pathlib import Path
import subprocess
import sys
from typing import Final

# tests/components has ~1000 sub-directories, which makes it the natural
# place to subdivide to keep each pytest invocation roughly equal in size.
_FAN_OUT_DIRS: Final = frozenset({"components"})

# Exit status pytest uses when a run finished but collected no tests.
_PYTEST_NO_TESTS_COLLECTED: Final = 5

# NOTE(review): Bucket, TestFile and TestFolder are defined in the part of the
# file that sits between these constants and the helpers below; only fragments
# of them are visible in this chunk, so they are not reproduced here.


def _collect_batch(paths: list[Path]) -> tuple[str, str, int]:
    """Run ``pytest --collect-only`` on a batch of paths.

    Returns a ``(stdout, stderr, returncode)`` tuple of plain values so the
    result can be handed back from a worker process. The return code is not
    interpreted here (``check=False``); the caller decides what is fatal.
    """
    result = subprocess.run(
        ["pytest", "--collect-only", "-qq", "-p", "no:warnings", *map(str, paths)],
        check=False,
        capture_output=True,
        text=True,
    )
    return result.stdout, result.stderr, result.returncode


def _iter_eligible_children(path: Path) -> list[Path]:
    """Return immediate children of ``path`` that pytest should collect.

    Filters out hidden/dunder entries, non-``test_*.py`` files (so helper
    modules like ``conftest.py`` and ``common.py`` are not passed as
    explicit collection targets), and pycache-style directories.

    The result is sorted by path so batching is deterministic.
    """
    children: list[Path] = []
    for entry in sorted(path.iterdir()):
        # Covers dot-files as well as ``__pycache__`` / ``__init__.py``.
        if entry.name.startswith((".", "_")):
            continue
        if entry.is_dir() or (
            entry.suffix == ".py" and entry.name.startswith("test_")
        ):
            children.append(entry)
    return children


def _enumerate_batch_paths(path: Path) -> list[Path]:
    """Return the child paths to run pytest --collect-only over.

    Anything that is not a directory is returned as-is. Directories are
    expanded one level deep, with a second level of expansion for entries
    named in ``_FAN_OUT_DIRS`` so the enormous ``tests/components`` tree fans
    out into per-integration paths.
    """
    # A regular file is collected directly. A path that does not exist is
    # also passed through untouched so pytest reports the problem, instead
    # of ``iterdir()`` raising an unhandled exception here.
    if not path.is_dir():
        return [path]

    paths: list[Path] = []
    for entry in _iter_eligible_children(path):
        if entry.is_dir() and entry.name in _FAN_OUT_DIRS:
            paths.extend(_iter_eligible_children(entry))
        else:
            paths.append(entry)
    return paths


def collect_tests(path: Path) -> TestFolder:
    """Collect all tests.

    The eligible paths below ``path`` are split into one batch per worker and
    each batch is collected by its own ``pytest --collect-only`` run. Every
    ``<file>: <count>`` line pytest prints is added to the returned folder.

    Exits the script with status 1 when there is nothing to collect, when a
    pytest run fails, or when pytest prints a line that cannot be parsed.
    """
    batch_paths = _enumerate_batch_paths(path)
    if not batch_paths:
        print(f"No eligible test paths found under {path}")
        sys.exit(1)

    # batch_paths is non-empty here, so this is always at least 1.
    workers = min(len(batch_paths), os.cpu_count() or 1)
    # Round-robin chunking keeps batches roughly balanced when path
    # ordering correlates with test size.
    batches = [batch_paths[i::workers] for i in range(workers)]

    if workers == 1:
        results = [_collect_batch(batches[0])]
    else:
        with ProcessPoolExecutor(max_workers=workers) as executor:
            results = list(executor.map(_collect_batch, batches))

    folder = TestFolder(path)
    collected_files = 0
    for stdout, stderr, returncode in results:
        # A batch may consist solely of directories without tests (for
        # example a lone ``fixtures`` folder). That is not an error as long
        # as some other batch contributes tests.
        if returncode == _PYTEST_NO_TESTS_COLLECTED:
            continue
        if returncode != 0:
            print("Failed to collect tests:")
            print(stderr)
            print(stdout)
            sys.exit(1)

        for line in stdout.splitlines():
            if not line.strip():
                continue
            file_path, _, total_tests = line.partition(": ")
            try:
                test_count = int(total_tests) if file_path else None
            except ValueError:
                test_count = None
            if test_count is None:
                print(f"Unexpected line: {line}")
                sys.exit(1)

            folder.add_test_file(TestFile(test_count, Path(file_path)))
            collected_files += 1

    # Same outcome as a single pytest run that collected nothing.
    if not collected_files:
        print("Failed to collect tests:")
        print(f"No tests were collected under {path}")
        sys.exit(1)

    return folder