diff --git a/pyproject.toml b/pyproject.toml index a76e2de..bbf9251 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ authors = [ {name = "sgkit Developers", email = "project@sgkit.dev"}, ] dependencies = [ + "anyio>=4", "numpy>=2", "zarr>=3.1", "click>=8.2.0", @@ -139,6 +140,7 @@ unfixable = [] [tool.ruff.lint.isort] known-third-party = [ + "anyio", "bio2zarr", "click", "cyvcf2", diff --git a/tests/test_cli.py b/tests/test_cli.py index 7d2d62d..df650ec 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -951,9 +951,13 @@ def test_size_param_rejects_invalid(self): with pytest.raises(click.UsageError): cli.SIZE.convert("not-a-size", None, None) - def test_make_reader_forwards_workers(self, fx_vcz_path): - with cli.make_reader(fx_vcz_path, readahead_workers=4) as reader: - assert reader._readahead_workers == 4 + def test_make_reader_forwards_io_concurrency(self, fx_vcz_path): + with cli.make_reader(fx_vcz_path, io_concurrency=4) as reader: + assert reader._io_concurrency == 4 + + def test_make_reader_forwards_decode_threads(self, fx_vcz_path): + with cli.make_reader(fx_vcz_path, decode_threads=2) as reader: + assert reader._decode_threads == 2 def test_make_reader_forwards_bytes(self, fx_vcz_path): with cli.make_reader(fx_vcz_path, readahead_bytes=1024) as reader: @@ -983,14 +987,15 @@ def test_view_forwards_flags(self, monkeypatch, tmp_path, fx_vcz_path): runner = ct.CliRunner() result = runner.invoke( cli.vcztools_main, - f"view --no-version --readahead-workers 4 " + f"view --no-version --io-concurrency 4 --decode-threads 3 " f"--readahead-buffer-size 100M {fx_vcz_path} " f"-o {output_path.as_posix()}", catch_exceptions=False, ) assert result.exit_code == 0 assert captured == { - "readahead_workers": 4, + "io_concurrency": 4, + "decode_threads": 3, "readahead_bytes": 100 * 1024 * 1024, } @@ -1000,13 +1005,17 @@ def test_query_forwards_flags(self, monkeypatch, tmp_path, fx_vcz_path): runner = ct.CliRunner() result = runner.invoke( cli.vcztools_main, - f"query -f '%POS\n' --readahead-workers 2 " + f"query -f '%POS\n' --io-concurrency 2 --decode-threads 1 " f"--readahead-buffer-size 1024 {fx_vcz_path} " f"-o {output_path.as_posix()}", catch_exceptions=False, ) assert result.exit_code == 0 - assert captured == {"readahead_workers": 2, "readahead_bytes": 1024} + assert captured == { + "io_concurrency": 2, + "decode_threads": 1, + "readahead_bytes": 1024, + } def test_view_plink_forwards_flags(self, monkeypatch, tmp_path, fx_vcz_path): captured = self._spy_vcz_reader_init(monkeypatch) @@ -1015,13 +1024,14 @@ def test_view_plink_forwards_flags(self, monkeypatch, tmp_path, fx_vcz_path): result = runner.invoke( cli.vcztools_main, f"view-plink --max-alleles 2 -e 'CHROM==\"X\"' " - f"--readahead-workers 8 --readahead-buffer-size 2M " + f"--io-concurrency 8 --decode-threads 5 --readahead-buffer-size 2M " f"{fx_vcz_path} --out {out.as_posix()}", catch_exceptions=False, ) assert result.exit_code == 0 assert captured == { - "readahead_workers": 8, + "io_concurrency": 8, + "decode_threads": 5, "readahead_bytes": 2 * 1024 * 1024, } diff --git a/tests/test_retrieval.py b/tests/test_retrieval.py index 47e9911..0b00344 100644 --- a/tests/test_retrieval.py +++ b/tests/test_retrieval.py @@ -1,9 +1,5 @@ import concurrent.futures as cf -import contextlib -import gc import logging -import threading -import time import numpy as np import numpy.testing as nt @@ -15,7 +11,7 @@ from vcztools import regions as regions_mod from vcztools import retrieval as retrieval_mod from vcztools import samples as samples_mod -from vcztools import utils +from vcztools import utils, zarr_direct from vcztools.bcftools_filter import BcftoolsFilter from vcztools.retrieval import CachedVariantChunk, VczReader @@ -258,67 +254,6 @@ def test_string_field(self, readahead_bytes): ) -def _make_pipeline( - root, - *, - readahead_bytes=10**9, - read_fields=None, - n_chunks=None, - executor=None, -): - """Construct a ``ReadaheadPipeline`` directly against ``root``, - matching the wiring ``VczReader.variant_chunks`` does (default - sample-chunk plan over non-null samples; one ``ChunkRead`` per - variant chunk; no view-mode column remap). - - ``executor`` is the thread pool the pipeline submits reads to. The - caller is responsible for shutting it down (e.g. via a ``with`` - block); ``None`` means "build a small pool here", which is the - common case for tests that only need the pipeline to run once and - don't share the executor with other pipelines. - """ - if read_fields is None: - read_fields = ["variant_position"] - if executor is None: - executor = cf.ThreadPoolExecutor(max_workers=2) - samples_chunk_size = int(root["sample_id"].chunks[0]) - raw_sample_ids = root["sample_id"][:] - samples_selection = np.flatnonzero(raw_sample_ids != "") - sample_chunk_plan = samples_mod.build_chunk_plan( - samples_selection, samples_chunk_size=samples_chunk_size - ) - if n_chunks is None: - n_chunks = int(root["variant_position"].cdata_shape[0]) - variants_chunk_size = int(root["variant_position"].chunks[0]) - num_variants = int(root["variant_position"].shape[0]) - plan_length = min(n_chunks * variants_chunk_size, num_variants) - variant_chunk_plan = utils.ChunkRead.simple_plan(plan_length, variants_chunk_size) - return retrieval_mod.ReadaheadPipeline( - root, - variant_chunk_plan, - sample_chunk_plan, - None, - read_fields, - readahead_bytes=readahead_bytes, - executor=executor, - ) - - -class _DepthTrackingPipeline(retrieval_mod.ReadaheadPipeline): - """Pipeline subclass that records ``len(_in_flight)`` after each - ``_refill`` call. Used to assert depth-control behaviour under - different ``readahead_bytes`` values without observing the executor - directly.""" - - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.depths = [] - - def _refill(self): - super()._refill() - self.depths.append(len(self._in_flight)) - - def _vcz_for_template_tests(): """Small VCZ exposing all four field shapes in the templates: @@ -359,29 +294,47 @@ def _sample_plan(self, root): non_null, samples_chunk_size=samples_chunk_size ) + def _get_block_reader(self, root): + cache = {} + + def get(name): + cached = cache.get(name) + if cached is not None: + return cached + reader = zarr_direct.BlockReader(root[name]) + cache[name] = reader + return reader + + return get + def test_static_field_rejected(self): root = _vcz_for_template_tests() with pytest.raises(AssertionError, match="non-variants-axis"): retrieval_mod.create_chunk_read_list( - root, self._sample_plan(root), ["sample_id"] + root, + self._sample_plan(root), + ["sample_id"], + get_block_reader=self._get_block_reader(root), ) def test_variant_axis_1d_field(self): root = _vcz_for_template_tests() + get = self._get_block_reader(root) templates = retrieval_mod.create_chunk_read_list( - root, self._sample_plan(root), ["variant_position"] + root, self._sample_plan(root), ["variant_position"], get_block_reader=get ) assert len(templates) == 1 t = templates[0] assert t.key == ("variant_position",) - assert t.arr == root["variant_position"] + assert t.block_reader is get("variant_position") # 1-D variants axis → no extra dims after the variant chunk slot. assert t.block_index_suffix == () def test_variant_axis_2d_field(self): root = _vcz_for_template_tests() + get = self._get_block_reader(root) templates = retrieval_mod.create_chunk_read_list( - root, self._sample_plan(root), ["variant_allele"] + root, self._sample_plan(root), ["variant_allele"], get_block_reader=get ) assert len(templates) == 1 t = templates[0] @@ -391,22 +344,28 @@ def test_variant_axis_2d_field(self): def test_call_field_2d_fans_out_per_sample_chunk(self): root = _vcz_for_template_tests() + get = self._get_block_reader(root) plan = self._sample_plan(root) # 4 samples, samples_chunk_size=2 → 2 sample chunks. assert len(plan.chunk_reads) == 2 - templates = retrieval_mod.create_chunk_read_list(root, plan, ["call_DP"]) + templates = retrieval_mod.create_chunk_read_list( + root, plan, ["call_DP"], get_block_reader=get + ) assert len(templates) == 2 assert [t.key for t in templates] == [("call_DP", 0), ("call_DP", 1)] - call_dp = root["call_DP"] + # All templates for call_DP share the same BlockReader instance. + assert templates[0].block_reader is templates[1].block_reader for t, cr in zip(templates, plan.chunk_reads): - assert t.arr == call_dp # 2-D (variants, samples) → suffix is (sci,), no trailing slices. assert t.block_index_suffix == (cr.index,) def test_call_field_3d_keeps_trailing_slice(self): root = _vcz_for_template_tests() + get = self._get_block_reader(root) plan = self._sample_plan(root) - templates = retrieval_mod.create_chunk_read_list(root, plan, ["call_genotype"]) + templates = retrieval_mod.create_chunk_read_list( + root, plan, ["call_genotype"], get_block_reader=get + ) assert len(templates) == len(plan.chunk_reads) for t, cr in zip(templates, plan.chunk_reads): assert t.key == ("call_genotype", cr.index) @@ -418,9 +377,13 @@ def test_multiple_fields_in_input_order(self): # templates; ordering is "fields in input order, then sample # chunks within a call_*". root = _vcz_for_template_tests() + get = self._get_block_reader(root) plan = self._sample_plan(root) templates = retrieval_mod.create_chunk_read_list( - root, plan, ["variant_position", "call_DP", "variant_allele"] + root, + plan, + ["variant_position", "call_DP", "variant_allele"], + get_block_reader=get, ) assert [t.key for t in templates] == [ ("variant_position",), @@ -436,6 +399,19 @@ class TestUpdateChunkReadList: list is reusable across every chunk in a query. """ + def _get_block_reader(self, root): + cache = {} + + def get(name): + cached = cache.get(name) + if cached is not None: + return cached + reader = zarr_direct.BlockReader(root[name]) + cache[name] = reader + return reader + + return get + def test_variant_non_call_template_prepends_variant_chunk_index(self): root = _vcz_for_template_tests() plan = samples_mod.SampleChunkPlan( @@ -443,7 +419,10 @@ def test_variant_non_call_template_prepends_variant_chunk_index(self): permutation=None, ) templates = retrieval_mod.create_chunk_read_list( - root, plan, ["variant_position", "variant_allele"] + root, + plan, + ["variant_position", "variant_allele"], + get_block_reader=self._get_block_reader(root), ) reads = retrieval_mod.update_chunk_read_list(templates, 3) keys = [r[0] for r in reads] @@ -456,7 +435,9 @@ def test_call_template_keeps_sample_chunk_index_after_variant(self): plan = samples_mod.build_chunk_plan( np.array([0, 1, 2, 3], dtype=np.int64), samples_chunk_size=2 ) - templates = retrieval_mod.create_chunk_read_list(root, plan, ["call_DP"]) + templates = retrieval_mod.create_chunk_read_list( + root, plan, ["call_DP"], get_block_reader=self._get_block_reader(root) + ) reads = retrieval_mod.update_chunk_read_list(templates, 5) assert [r[0] for r in reads] == [("call_DP", 0), ("call_DP", 1)] assert [r[2] for r in reads] == [(5, 0), (5, 1)] @@ -471,7 +452,10 @@ def test_two_calls_yield_independent_lists(self): permutation=None, ) templates = retrieval_mod.create_chunk_read_list( - root, plan, ["variant_position"] + root, + plan, + ["variant_position"], + get_block_reader=self._get_block_reader(root), ) reads_a = retrieval_mod.update_chunk_read_list(templates, 0) reads_b = retrieval_mod.update_chunk_read_list(templates, 4) @@ -480,175 +464,6 @@ def test_two_calls_yield_independent_lists(self): assert reads_b[0][2] == (4,) -class TestReadaheadPipeline: - """Direct unit tests for ``retrieval.ReadaheadPipeline``. - - The end-to-end suites cover correctness; this class targets the - pipeline's own state machine — bootstrap, budget-driven scheduling - depth, executor cleanup, and behaviour at the edges (empty plan, - empty read columns). - """ - - @staticmethod - def _vcz(num_variants=12, variants_chunk_size=3, num_samples=2): - return vcz_builder.make_vcz( - variant_contig=[0] * num_variants, - variant_position=list(range(100, 100 + num_variants)), - alleles=[("A", "T")] * num_variants, - num_samples=num_samples, - variants_chunk_size=variants_chunk_size, - ) - - def test_yields_one_chunk_per_plan_entry_in_order(self): - root = self._vcz() - pipeline = _make_pipeline(root) - indexes = [chunk.variant_chunk.index for chunk in pipeline] - assert indexes == [0, 1, 2, 3] - - def test_empty_plan_yields_nothing(self): - root = self._vcz() - pipeline = _make_pipeline(root, n_chunks=0) - assert list(pipeline) == [] - # No pending futures left after a clean drain. - assert pipeline._in_flight == [] - - def test_single_chunk_plan(self): - root = self._vcz(num_variants=3, variants_chunk_size=3) - pipeline = _make_pipeline(root) - chunks = list(pipeline) - assert len(chunks) == 1 - assert chunks[0].variant_chunk.index == 0 - - def test_bootstrap_runs_first_chunk_solo(self): - # Until the first chunk's prefetch lands the pipeline can't - # measure per-chunk bytes, so _refill must schedule exactly one - # chunk on the bootstrap path. - root = self._vcz() - pipeline = _make_pipeline(root, readahead_bytes=10**9) - gen = iter(pipeline) - # Generator hasn't run yet — no scheduling, no measurement. - assert pipeline._per_chunk_bytes is None - chunk = next(gen) - # After bootstrap the measurement is recorded and matches the - # prefetched blocks' content. - assert isinstance(pipeline._per_chunk_bytes, int) - assert pipeline._per_chunk_bytes > 0 - expected = sum(utils.array_memory_bytes(v) for v in chunk._blocks.values()) - assert pipeline._per_chunk_bytes == expected - gen.close() - - def test_readahead_bytes_zero_keeps_depth_one(self): - # Budget=0 → after every refill, exactly one chunk is queued - # ahead of the consumer (and zero on the final, plan-exhausted - # refill). - root = self._vcz(num_variants=12, variants_chunk_size=3) - with cf.ThreadPoolExecutor(max_workers=2) as executor: - pipeline = _DepthTrackingPipeline( - root, - [utils.ChunkRead(index=i, num_selected=3) for i in range(4)], - samples_mod.build_chunk_plan( - np.arange(2, dtype=np.int64), samples_chunk_size=2 - ), - None, - ["variant_position"], - readahead_bytes=0, - executor=executor, - ) - list(pipeline) - # 4 chunks → 5 refills (one per consume + the post-final empty refill). - assert pipeline.depths == [1, 1, 1, 1, 0] - - def test_large_readahead_schedules_all_remaining_after_bootstrap(self): - # Budget of 10**9 dwarfs the per-chunk cost, so the second - # refill fills with every remaining chunk in one go. - root = self._vcz(num_variants=12, variants_chunk_size=3) - with cf.ThreadPoolExecutor(max_workers=2) as executor: - pipeline = _DepthTrackingPipeline( - root, - [utils.ChunkRead(index=i, num_selected=3) for i in range(4)], - samples_mod.build_chunk_plan( - np.arange(2, dtype=np.int64), samples_chunk_size=2 - ), - None, - ["variant_position"], - readahead_bytes=10**9, - executor=executor, - ) - list(pipeline) - # Bootstrap depth=1, then post-yield-1 schedules the remaining - # 3, then drains. - assert pipeline.depths == [1, 3, 2, 1, 0] - - def test_max_in_flight_tracks_peak_depth(self): - # Budget large enough to fit every remaining chunk after the - # bootstrap; peak depth should be (plan length - 1) reached at - # the post-yield refill that schedules every remaining chunk. - root = self._vcz(num_variants=12, variants_chunk_size=3) - pipeline = _make_pipeline(root, readahead_bytes=10**9) - assert pipeline.max_in_flight == 0 - list(pipeline) - assert pipeline.max_in_flight == 3 - - def test_max_in_flight_pinned_at_one_with_zero_budget(self): - root = self._vcz(num_variants=12, variants_chunk_size=3) - pipeline = _make_pipeline(root, readahead_bytes=0) - list(pipeline) - assert pipeline.max_in_flight == 1 - - def test_empty_read_fields_does_not_infinite_loop(self): - # With no fields to prefetch the bootstrap measurement is 0 - # bytes; without the ``max(1, per_chunk_bytes)`` guard the - # budget loop would never exit. List materialises the full - # sequence. - root = self._vcz(num_variants=6, variants_chunk_size=3) - pipeline = _make_pipeline(root, read_fields=[], readahead_bytes=10**9) - chunks = list(pipeline) - assert len(chunks) == 2 - for chunk in chunks: - assert chunk._blocks == {} - - def test_chunks_have_prefetched_blocks(self): - # Every (key, future) submitted lands in chunk._blocks before - # the consumer receives the chunk. - root = self._vcz(num_variants=6, variants_chunk_size=3, num_samples=2) - pipeline = _make_pipeline( - root, - read_fields=["variant_position", "variant_contig"], - readahead_bytes=0, - ) - for chunk in pipeline: - assert ("variant_position",) in chunk._blocks - assert ("variant_contig",) in chunk._blocks - - def test_executor_outlives_full_iteration(self): - # The pipeline does not own the executor; full drain leaves - # the pool alive and ready to serve another pipeline. - root = self._vcz() - with cf.ThreadPoolExecutor(max_workers=2) as executor: - pipeline = _make_pipeline(root, executor=executor) - list(pipeline) - assert executor._shutdown is False - # Pool is still usable for a second pipeline. - second = _make_pipeline(root, executor=executor) - assert len(list(second)) > 0 - - def test_pending_futures_cancelled_on_early_break(self): - # Abandoning iteration cancels still-pending futures (those - # that hadn't started); the executor itself stays alive. - root = self._vcz(num_variants=24, variants_chunk_size=3) - with cf.ThreadPoolExecutor(max_workers=2) as executor: - pipeline = _make_pipeline(root, executor=executor, readahead_bytes=10**9) - gen = iter(pipeline) - next(gen) - in_flight_snapshot = [ - fut for _, futures in pipeline._in_flight for _, fut in futures - ] - gen.close() - assert executor._shutdown is False - for fut in in_flight_snapshot: - assert fut.cancelled() or fut.done() - - class TestVczReaderBackendsEndToEnd: """All four storage backends read the same local-directory VCZ identically. @@ -1017,9 +832,24 @@ def _make_cached_chunk( num_selected=variant_num_selected, selection=variant_selection, ) - templates = retrieval_mod.create_chunk_read_list(root, sample_chunk_plan, fields) + block_readers: dict[str, zarr_direct.BlockReader] = {} + + def get_block_reader(name): + cached = block_readers.get(name) + if cached is not None: + return cached + reader = zarr_direct.BlockReader(root[name]) + block_readers[name] = reader + return reader + + templates = retrieval_mod.create_chunk_read_list( + root, sample_chunk_plan, fields, get_block_reader=get_block_reader + ) reads = retrieval_mod.update_chunk_read_list(templates, variant_chunk.index) - blocks = {key: retrieval_mod._read_block(arr, idx) for key, arr, idx in reads} + # Build expected blocks via Zarr's high-level path; the + # CachedVariantChunk under test consumes the dict shape, not the + # source of the bytes. + blocks = {key: root[key[0]].blocks[idx] for key, _reader, idx in reads} return CachedVariantChunk( root, variant_chunk, @@ -2395,23 +2225,23 @@ def test_variant_chunks_with_static_query_field(self): assert chunk["filter_id"] is first_filter_id def test_filter_referenced_static_field_not_in_pipeline(self, monkeypatch): - # When a FILTER expression references filter_id the readahead - # pipeline must NOT submit a (filter_id,) read — the value - # comes from the reader's static cache. + # When a FILTER expression references filter_id the producer + # must NOT submit a (filter_id,) fetch — the value comes from + # the reader's static cache. seen_fields: list[str] = [] - original = retrieval_mod._read_block + original = retrieval_mod._read_block_async - def capturing_read_block(arr, block_index): - seen_fields.append(arr.path.rsplit("/", 1)[-1]) - return original(arr, block_index) + async def capturing_read_block(reader, block_index, io_lim, decode_lim): + seen_fields.append(reader._path.rsplit("/", 1)[-1]) + return await original(reader, block_index, io_lim, decode_lim) - monkeypatch.setattr(retrieval_mod, "_read_block", capturing_read_block) + monkeypatch.setattr(retrieval_mod, "_read_block_async", capturing_read_block) root = _make_filter_vcz(num_variants=9, variants_chunk_size=3) reader = make_reader(root, include='FILTER="PASS"') list(reader.variant_chunks(fields=["variant_position"])) # filter_id is referenced by the FILTER expression but is read - # from the reader cache, never via _read_block. + # from the reader cache, never via _read_block_async. assert "filter_id" not in seen_fields assert "variant_position" in seen_fields assert "variant_filter" in seen_fields @@ -2524,9 +2354,8 @@ def test_debug_per_chunk_lines(self, fx_sample_vcz, caplog): reader = VczReader(fx_sample_vcz.group) with caplog.at_level(logging.DEBUG, logger="vcztools.retrieval"): list(reader.variant_chunks(fields=["variant_position"])) - assert "ReadaheadPipeline init:" in caplog.text assert "read complete in" in caplog.text - assert "yielded" in caplog.text + assert "assembled" in caplog.text def test_trace_schedule_chunk(self, fx_sample_vcz, caplog): # schedule chunk lines fire once per chunk and are too noisy @@ -2608,287 +2437,78 @@ def test_warn_single_chunk_bound_budget(self, fx_sample_vcz, caplog): assert "Readahead budget is single-chunk-bound" in caplog.text -class TestVariantChunksPrefetch: - """``variant_chunks`` returns a one-deep prefetch iterator that - drives the inner generator in a background thread so the - consumer's per-chunk work overlaps with the producer's - per-chunk assembly. These cases lock in the wrapper's contract: - eager validation, empty-iterator short-circuit, exception - propagation, and clean worker-thread teardown.""" - - def _prefetch_threads(self): - return [t for t in threading.enumerate() if "vcztools-prefetch" in t.name] +class TestVariantChunksIterator: + """``variant_chunks()`` returns an :class:`_AsyncBackedIterator` + over a memory channel produced by the reader's anyio portal. + These cases lock in the wrapper's contract: eager validation, + empty-iterator short-circuit, exception propagation, and clean + teardown on close(). + """ def test_eager_negative_start_validation(self, fx_sample_vcz): - # Was previously raised on first next() (lazy generator); - # the wrapper validates eagerly on the call itself. reader = VczReader(fx_sample_vcz.group) with pytest.raises(ValueError, match="start must be >= 0"): reader.variant_chunks(start=-1) - def test_empty_fields_starts_no_worker(self, fx_sample_vcz): - # fields=[] short-circuits to iter(()) without spinning up - # the prefetch worker — exhausting the iterator must not - # leave a vcztools-prefetch thread alive. + def test_empty_fields_does_not_start_portal(self, fx_sample_vcz): + # fields=[] short-circuits to iter(()) before _ensure_portal(). reader = VczReader(fx_sample_vcz.group) - before = len(self._prefetch_threads()) result = list(reader.variant_chunks(fields=[])) assert result == [] - assert len(self._prefetch_threads()) == before + assert reader._portal is None - def test_exception_in_inner_gen_surfaces_to_consumer( + def test_exception_in_producer_surfaces_to_consumer( self, fx_sample_vcz, monkeypatch ): - # The wrapper retrieves each item from the worker future via - # result(); an exception raised by the inner generator must - # re-raise on the consumer's next() call rather than be - # swallowed. - sentinel = RuntimeError("boom from inner generator") - - def faulty_gen(self, *, fields=None, start=0): - yield {"variant_position": np.array([0])} - raise sentinel - - monkeypatch.setattr(VczReader, "_variant_chunks_gen", faulty_gen) - reader = VczReader(fx_sample_vcz.group) - it = reader.variant_chunks(fields=["variant_position"]) - # First chunk arrives normally. - next(it) - with pytest.raises(RuntimeError, match="boom from inner generator"): - next(it) - it.close() + # A producer that raises mid-iteration must surface its exception + # on the consumer's next() call once the channel is drained. Use + # a synthetic producer (rather than wrapping the real one) so the + # test isolates the iterator's exception-surfacing behaviour from + # the byte-budget and refill machinery. + sentinel = RuntimeError("boom from producer") + + async def faulty_producer(send_channel, ctx, telemetry): + async with send_channel: + await send_channel.send({"variant_position": np.array([42])}) + raise sentinel + + monkeypatch.setattr(retrieval_mod, "_produce_variant_chunks", faulty_producer) + with VczReader(fx_sample_vcz.group) as reader: + it = reader.variant_chunks(fields=["variant_position"]) + chunk = next(it) + assert list(chunk["variant_position"]) == [42] + with pytest.raises(RuntimeError, match="boom from producer"): + next(it) + it.close() - def test_close_terminates_worker_thread(self, fx_sample_vcz): - # After close(), no vcztools-prefetch thread should remain - # running — confirms the wrapper joins its worker pool. + def test_close_cancels_in_progress_iteration(self, fx_sample_vcz): reader = VczReader(fx_sample_vcz.group) - before = self._prefetch_threads() it = reader.variant_chunks(fields=["variant_position"]) - # Pull one chunk so the worker is definitely live. next(it) it.close() - # Pools shut down asynchronously; allow the worker a brief - # window to exit before asserting absence. - deadline = time.time() + 1.0 - while time.time() < deadline: - after = self._prefetch_threads() - if len(after) <= len(before): - break - time.sleep(0.01) - assert len(self._prefetch_threads()) <= len(before) - - -def _prefetch_threads(): - return [t for t in threading.enumerate() if "vcztools-prefetch" in t.name] - - -def _wait_for_thread_count(target, timeout=1.0): - """Block briefly while the executor's worker threads exit.""" - deadline = time.time() + timeout - while time.time() < deadline: - if len(_prefetch_threads()) <= target: - return - time.sleep(0.01) - - -class TestPrefetchIteratorDirect: - """Direct unit tests for ``_PrefetchIterator`` decoupled from - :class:`VczReader`. - - Locks in the wrapper's contract using plain Python iterables and - a controllable iterator: ordered yields, ``StopIteration`` - handling, exception propagation, ``close()`` /``__del__`` - cleanup, source-iterator close-through, and the actual - background-overlap behaviour the prefetch is meant to provide. - """ - - def test_yields_items_in_order(self): - with contextlib.closing(retrieval_mod._PrefetchIterator(iter([1, 2, 3]))) as it: - assert list(it) == [1, 2, 3] - - def test_iter_returns_self(self): - it = retrieval_mod._PrefetchIterator(iter([1])) - try: - assert iter(it) is it - finally: - it.close() - - def test_empty_source(self): - with contextlib.closing(retrieval_mod._PrefetchIterator(iter([]))) as it: - with pytest.raises(StopIteration): - next(it) - - def test_single_item_then_stopiteration(self): - with contextlib.closing(retrieval_mod._PrefetchIterator(iter([42]))) as it: - assert next(it) == 42 - with pytest.raises(StopIteration): - next(it) - - def test_repeated_stopiteration_after_exhaustion(self): - # Exhausted iterators are expected to keep raising; the - # _closed flag must not turn StopIteration into something - # else on a second pull. - with contextlib.closing(retrieval_mod._PrefetchIterator(iter([1]))) as it: - assert next(it) == 1 - for _ in range(3): - with pytest.raises(StopIteration): - next(it) - - def test_exception_on_first_item(self): - sentinel = RuntimeError("boom on first") - - def gen(): - raise sentinel - yield # pragma: no cover - - with contextlib.closing(retrieval_mod._PrefetchIterator(gen())) as it: - with pytest.raises(RuntimeError, match="boom on first"): - next(it) - - def test_exception_mid_iteration(self): - sentinel = RuntimeError("boom mid") - - def gen(): - yield 1 - yield 2 - raise sentinel - - with contextlib.closing(retrieval_mod._PrefetchIterator(gen())) as it: - assert next(it) == 1 - assert next(it) == 2 - with pytest.raises(RuntimeError, match="boom mid"): - next(it) - - def test_close_idempotent(self): - it = retrieval_mod._PrefetchIterator(iter([1, 2, 3])) - it.close() - it.close() # second call must not raise - - def test_next_after_close_raises_stopiteration(self): - it = retrieval_mod._PrefetchIterator(iter([1, 2, 3])) - it.close() + # After close, subsequent next() returns StopIteration cleanly. with pytest.raises(StopIteration): next(it) - def test_source_close_called_on_wrapper_close(self): - # The wrapper must close the underlying iterator (mirroring - # the generator-finalisation contract that variant_chunks - # callers rely on for the ``iteration done`` log). - events = [] - - class TrackingIter: - def __iter__(self): - return self - - def __next__(self): - return 1 - - def close(self): - events.append("closed") - - it = retrieval_mod._PrefetchIterator(TrackingIter()) - next(it) # ensure the worker is live - it.close() - assert events == ["closed"] - - def test_close_drains_pending_exception(self): - # If the in-flight prefetch was about to raise, close() - # must drain it without re-raising (the user has explicitly - # given up on the iterator). - def gen(): - yield 1 - raise RuntimeError("would-be uncaught") - - it = retrieval_mod._PrefetchIterator(gen()) - # Pull the first item so the worker is now computing the - # second (which will raise). - assert next(it) == 1 - # Give the worker a moment to evaluate the second next(). - deadline = time.time() + 0.5 - while time.time() < deadline and not it._next_future.done(): - time.sleep(0.005) - # close() must not propagate the worker's exception. - it.close() - - def test_close_terminates_worker_thread(self): - # Block the source long enough to confirm the worker is alive, - # then unblock and assert close() joins it. - gate = threading.Event() - - def gen(): - gate.wait(timeout=2.0) - yield 1 + def test_max_in_flight_is_one_with_zero_budget(self, fx_sample_vcz): + reader = VczReader(fx_sample_vcz.group, readahead_bytes=0) + it = reader.variant_chunks(fields=["variant_position"]) + list(it) + assert it.max_in_flight == 1 - before = len(_prefetch_threads()) - it = retrieval_mod._PrefetchIterator(gen()) - # Worker is now blocked inside _fetch waiting on gate. - assert len(_prefetch_threads()) >= before + 1 - gate.set() - it.close() - _wait_for_thread_count(before) - assert len(_prefetch_threads()) <= before - - def test_del_without_close_cleans_up_worker(self): - # __del__ must defensively close the iterator even if the - # caller never called close(). Locks in the no-thread-leak - # contract under garbage collection. - before = len(_prefetch_threads()) - it = retrieval_mod._PrefetchIterator(iter([1, 2, 3])) - next(it) - del it - gc.collect() - _wait_for_thread_count(before) - assert len(_prefetch_threads()) <= before - - def test_prefetch_runs_in_background(self): - # While the consumer holds item N, the worker should already - # have produced item N+1. Verified by recording each - # production timestamp and asserting item 1 is produced - # before the consumer pulls it. - produced_at: list[float] = [] - - def gen(): - for i in range(3): - produced_at.append(time.perf_counter()) - yield i - - with contextlib.closing(retrieval_mod._PrefetchIterator(gen())) as it: - t0 = time.perf_counter() - assert next(it) == 0 - # By now the worker has been told to compute item 1; it - # should land before the consumer next()s for it. - deadline = time.time() + 0.5 - while time.time() < deadline and len(produced_at) < 2: - time.sleep(0.001) - assert len(produced_at) >= 2 - # The second item must have been produced *after* the - # iterator started but before we ask for it. - t_pull_1 = time.perf_counter() - assert next(it) == 1 - assert produced_at[1] >= t0 - assert produced_at[1] <= t_pull_1 + 1e-3 - - def test_source_runs_on_worker_thread_not_caller(self): - # Sanity: confirm the source iterator is being driven by the - # prefetch worker, not the calling thread. - seen_threads: list[str] = [] - - def gen(): - for i in range(3): - seen_threads.append(threading.current_thread().name) - yield i - - with contextlib.closing(retrieval_mod._PrefetchIterator(gen())) as it: - list(it) - assert all("vcztools-prefetch" in name for name in seen_threads) - - def test_close_called_multiple_times_after_exhaustion(self): - # After the iterator is naturally exhausted, the wrapper - # has already shut its executor. Subsequent close() calls - # must remain no-ops. - it = retrieval_mod._PrefetchIterator(iter([1, 2])) - list(it) # exhaust - it.close() - it.close() - with pytest.raises(StopIteration): - next(it) + def test_max_in_flight_grows_with_large_budget(self): + # Synthetic VCZ with many chunks; a generous budget should let + # the producer schedule several chunks ahead of the consumer. + root = vcz_builder.make_vcz( + variant_contig=[0] * 12, + variant_position=list(range(100, 112)), + alleles=[("A", "T")] * 12, + num_samples=2, + variants_chunk_size=3, + ) + reader = VczReader(root, readahead_bytes=10**9) + it = reader.variant_chunks(fields=["variant_position"]) + list(it) + # 4 chunks; budget is huge, so by the post-bootstrap refill + # every remaining chunk gets scheduled immediately. + assert it.max_in_flight >= 2 diff --git a/tests/test_zarr_direct.py b/tests/test_zarr_direct.py new file mode 100644 index 0000000..7296e70 --- /dev/null +++ b/tests/test_zarr_direct.py @@ -0,0 +1,229 @@ +"""Tests for vcztools.zarr_direct.BlockReader. + +The contract: ``BlockReader.read_chunk(coords)`` returns the same +ndarray as ``zarr.Array.blocks[coords]`` for every chunk position, +across both Zarr v2 and v3 metadata, every codec we encounter in +real fixtures, missing chunks (fill values), and boundary chunks. +Sharded arrays are refused at construction. +""" + +import itertools +import warnings + +import anyio +import numpy as np +import numpy.testing as nt +import pytest +import zarr +import zarr.codecs +import zarr.storage + +from vcztools.zarr_direct import BlockReader + + +def _all_chunk_coords(arr: zarr.Array): + """Iterate every chunk index tuple for ``arr``.""" + return itertools.product(*[range(n) for n in arr.cdata_shape]) + + +def _assert_array_parity(arr: zarr.Array): + """Every chunk read via BlockReader equals arr.blocks[coords].""" + reader = BlockReader(arr) + for coords in _all_chunk_coords(arr): + expected = arr.blocks[coords] + got = anyio.run(reader.read_chunk, coords) + nt.assert_array_equal(got, expected, err_msg=f"mismatch at {coords}") + assert got.shape == expected.shape, f"shape mismatch at {coords}" + assert got.dtype == expected.dtype, f"dtype mismatch at {coords}" + + +@pytest.fixture +def synthetic_v3_group(): + """Tiny v3 group with one int array, default codecs (Bytes + Zstd).""" + store = zarr.storage.MemoryStore() + g = zarr.group(store=store, zarr_format=3) + arr = g.create_array(name="ints", shape=(7, 4), chunks=(3, 2), dtype=np.int32) + arr[:] = np.arange(28, dtype=np.int32).reshape(7, 4) + return g + + +@pytest.fixture +def synthetic_v2_group(): + """v2 group with default Blosc compressor.""" + store = zarr.storage.MemoryStore() + g = zarr.group(store=store, zarr_format=2) + arr = g.create_array(name="ints", shape=(7, 4), chunks=(3, 2), dtype=np.int32) + arr[:] = np.arange(28, dtype=np.int32).reshape(7, 4) + return g + + +class TestBlockReaderSynthetic: + """Parity + boundary + fill-value tests on tiny in-memory arrays.""" + + def test_v3_default_codecs(self, synthetic_v3_group): + _assert_array_parity(synthetic_v3_group["ints"]) + + def test_v2_default_codecs(self, synthetic_v2_group): + _assert_array_parity(synthetic_v2_group["ints"]) + + def test_uncompressed_v3(self): + store = zarr.storage.MemoryStore() + g = zarr.group(store=store, zarr_format=3) + arr = g.create_array( + name="x", + shape=(7,), + chunks=(3,), + dtype=np.int32, + compressors=None, + filters=None, + ) + arr[:] = np.arange(7, dtype=np.int32) + _assert_array_parity(arr) + + def test_vlen_strings_v3(self): + store = zarr.storage.MemoryStore() + g = zarr.group(store=store, zarr_format=3) + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + data = np.array( + [["hello", "world"], ["foo", ""], ["", "bar"]], dtype="/"`` prefix.""" + store = zarr.storage.MemoryStore() + # Create a v3 array directly under the store root (no group). + arr = zarr.create_array( + store=store, shape=(5,), chunks=(2,), dtype=np.int32, zarr_format=3 + ) + arr[:] = np.arange(5, dtype=np.int32) + reader = BlockReader(arr) + assert reader.chunk_key((0,)) == "c/0" + nt.assert_array_equal(anyio.run(reader.read_chunk, (0,)), [0, 1]) + + def test_decode_limiter_is_honoured(self): + """A CapacityLimiter with capacity 1 must let read_chunk complete; + we don't assert serialisation here, just that the path doesn't + deadlock or error when a limiter is supplied.""" + store = zarr.storage.MemoryStore() + g = zarr.group(store=store, zarr_format=3) + arr = g.create_array(name="x", shape=(5,), chunks=(2,), dtype=np.int32) + arr[:] = np.arange(5, dtype=np.int32) + reader = BlockReader(arr) + + async def run(): + limiter = anyio.CapacityLimiter(1) + return await reader.read_chunk((1,), decode_limiter=limiter) + + got = anyio.run(run) + nt.assert_array_equal(got, [2, 3]) + + +class TestBlockReaderShardingRefusal: + def test_sharded_array_refused(self): + store = zarr.storage.MemoryStore() + g = zarr.group(store=store, zarr_format=3) + arr = g.create_array( + name="x", + shape=(8, 8), + chunks=(8, 8), + shards=(8, 8), + dtype=np.int32, + ) + arr[:] = np.arange(64, dtype=np.int32).reshape(8, 8) + with pytest.raises(NotImplementedError, match="ShardingCodec"): + BlockReader(arr) + + +class TestBlockReaderRealFixtures: + """Parity against committed VCZ fixtures — exercises the codecs and + metadata produced by bio2zarr in real workloads (vlen UTF-8, blosc, + zstd, transpose, NaN-bearing floats, etc.).""" + + @pytest.fixture(autouse=True) + def _silence_zarr_warnings(self): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + yield + + def _check_full_group_parity(self, root: zarr.Group): + """Read every chunk of every array in the group through both paths + and assert they match. Uses ``equal_nan=True`` since real fixtures + contain NaN-encoded missing values.""" + for name in sorted(root.array_keys()): + arr = root[name] + reader = BlockReader(arr) + for coords in _all_chunk_coords(arr): + expected = arr.blocks[coords] + got = anyio.run(reader.read_chunk, coords) + assert got.shape == expected.shape, ( + f"{name}{coords}: shape {got.shape} vs {expected.shape}" + ) + assert got.dtype == expected.dtype, ( + f"{name}{coords}: dtype {got.dtype} vs {expected.dtype}" + ) + if np.issubdtype(expected.dtype, np.floating): + nt.assert_array_equal( + np.where(np.isnan(expected), 0, expected), + np.where(np.isnan(got), 0, got), + err_msg=f"{name}{coords}", + ) + nt.assert_array_equal( + np.isnan(expected), + np.isnan(got), + err_msg=f"{name}{coords} NaN mask", + ) + else: + nt.assert_array_equal(got, expected, err_msg=f"{name}{coords}") + + def test_sample_v2(self, fx_sample_vcz): + self._check_full_group_parity(fx_sample_vcz.group) + + def test_sample_v3(self, fx_sample_vcz3): + self._check_full_group_parity(fx_sample_vcz3.group) + + def test_field_type_combos(self, fx_field_type_combos_vcz): + self._check_full_group_parity(fx_field_type_combos_vcz.group) diff --git a/uv.lock b/uv.lock index f9bad8d..039208d 100644 --- a/uv.lock +++ b/uv.lock @@ -158,6 +158,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/32/34/d4e1c02d3bee589efb5dfa17f88ea08bdb3e3eac12bc475462aec52ed223/alabaster-0.7.16-py3-none-any.whl", hash = "sha256:b46733c07dce03ae4e150330b975c75737fa60f0a7c591b6c8bf4928a28e2c92", size = 13511, upload-time = "2024-01-10T00:56:08.388Z" }, ] +[[package]] +name = "anyio" +version = "4.13.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "idna" }, + { name = "typing-extensions", marker = "python_full_version < '3.13'" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/19/14/2c5dd9f512b66549ae92767a9c7b330ae88e1932ca57876909410251fe13/anyio-4.13.0.tar.gz", hash = "sha256:334b70e641fd2221c1505b3890c69882fe4a2df910cba14d97019b90b24439dc", size = 231622, upload-time = "2026-03-24T12:59:09.671Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/da/42/e921fccf5015463e32a3cf6ee7f980a6ed0f395ceeaa45060b61d86486c2/anyio-4.13.0-py3-none-any.whl", hash = "sha256:08b310f9e24a9594186fd75b4f73f4a4152069e3853f1ed8bfbf58369f4ad708", size = 114353, upload-time = "2026-03-24T12:59:08.246Z" }, +] + [[package]] name = "appnope" version = "0.1.4" @@ -3589,6 +3602,7 @@ all = [ name = "vcztools" source = { editable = "." } dependencies = [ + { name = "anyio" }, { name = "click" }, { name = "humanfriendly" }, { name = "numpy" }, @@ -3682,6 +3696,7 @@ wheels = [ [package.metadata] requires-dist = [ + { name = "anyio", specifier = ">=4" }, { name = "click", specifier = ">=8.2.0" }, { name = "humanfriendly" }, { name = "icechunk", marker = "extra == 'icechunk'" }, diff --git a/vcztools/cli.py b/vcztools/cli.py index d10eab4..abe4b04 100644 --- a/vcztools/cli.py +++ b/vcztools/cli.py @@ -180,11 +180,23 @@ def convert(self, value, param, ctx): SIZE = _SizeParam() -readahead_workers = click.option( - "--readahead-workers", +io_concurrency = click.option( + "--io-concurrency", type=int, default=None, - help=("Worker threads servicing the cross-chunk readahead pool. Default: 32."), + help=( + "Cap on concurrent store.get calls per iteration. Default: 32. " + "Stores are async-native, so this is a coroutine count." + ), +) +decode_threads = click.option( + "--decode-threads", + type=int, + default=None, + help=( + "Size of the decode thread pool that runs codec decode (zstd, " + "blosc, etc.). Default: os.cpu_count()." + ), ) readahead_buffer_size = click.option( "--readahead-buffer-size", @@ -367,7 +379,8 @@ def make_reader( drop_genotypes=False, backend_storage=None, storage_options=None, - readahead_workers=None, + io_concurrency=None, + decode_threads=None, readahead_bytes=None, ): """Resolve file arguments and create a VczReader.""" @@ -419,7 +432,8 @@ def make_reader( ) reader = retrieval.VczReader( root, - readahead_workers=readahead_workers, + io_concurrency=io_concurrency, + decode_threads=decode_threads, readahead_bytes=readahead_bytes, ) @@ -511,7 +525,8 @@ class ViewPlinkOptions: max_alleles: int | None = None backend_storage: str | None = None storage_options: dict | None = None - readahead_workers: int | None = None + io_concurrency: int | None = None + decode_threads: int | None = None readahead_bytes: int | None = None @classmethod @@ -551,7 +566,8 @@ def view_plink_options(f): max_alleles_opt, backend_storage, storage_option, - readahead_workers, + io_concurrency, + decode_threads, readahead_buffer_size, ] for d in reversed(decorators): @@ -680,7 +696,8 @@ def index(path, nrecords, stats, backend_storage, storage_options, log_level, lo ) @backend_storage @storage_option -@readahead_workers +@io_concurrency +@decode_threads @readahead_buffer_size @log_level @log_file @@ -702,7 +719,8 @@ def query( disable_automatic_newline, backend_storage, storage_options, - readahead_workers, + io_concurrency, + decode_threads, readahead_bytes, log_level, log_file, @@ -747,7 +765,8 @@ def query( force_samples=force_samples, backend_storage=backend_storage, storage_options=parsed_storage_options, - readahead_workers=readahead_workers, + io_concurrency=io_concurrency, + decode_threads=decode_threads, readahead_bytes=readahead_bytes, ) with reader, handle_broken_pipe(output): @@ -806,7 +825,8 @@ def query( @max_alleles_opt @backend_storage @storage_option -@readahead_workers +@io_concurrency +@decode_threads @readahead_buffer_size @log_level @log_file @@ -834,7 +854,8 @@ def view( max_alleles, backend_storage, storage_options, - readahead_workers, + io_concurrency, + decode_threads, readahead_bytes, log_level, log_file, @@ -879,7 +900,8 @@ def view( drop_genotypes=drop_genotypes, backend_storage=backend_storage, storage_options=_parse_storage_options(storage_options), - readahead_workers=readahead_workers, + io_concurrency=io_concurrency, + decode_threads=decode_threads, readahead_bytes=readahead_bytes, ) subsetting_samples = ( diff --git a/vcztools/retrieval.py b/vcztools/retrieval.py index b7a3f8f..b926165 100644 --- a/vcztools/retrieval.py +++ b/vcztools/retrieval.py @@ -1,14 +1,20 @@ import concurrent.futures as cf import dataclasses import functools +import itertools import logging +import os +import threading import time +import weakref +import anyio +import anyio.from_thread import numpy as np from vcztools import regions as regions_mod from vcztools import samples as samples_mod -from vcztools import utils +from vcztools import utils, zarr_direct from vcztools import variant_filter as variant_filter_mod from vcztools.utils import ( _as_fixed_length_string, @@ -44,89 +50,63 @@ def _one_line_repr(obj) -> str: return " ".join(repr(obj).split()) -class _PrefetchIterator: - """One-deep prefetch wrapper around an iterator. - - On every ``__next__`` returns the previously prefetched item and - submits the next ``__next__`` call on the underlying iterator to - a dedicated single-worker pool. While the consumer's per-item - work runs, the producer's next item is being computed in the - background. Exceptions raised by the underlying iterator surface - on the consumer's ``__next__`` call. - - Lifetime: the worker pool is created in ``__init__`` and shut - down by ``close()`` (also called from ``__del__`` defensively to - prevent thread leaks if a caller forgets to close). +DEFAULT_READAHEAD_BYTES = 256 * 1024 * 1024 +# Default cap on concurrent ``store.get`` calls per variant_chunks() +# iteration. Stores are async-native so this is a coroutine count, not +# a thread count. +DEFAULT_IO_CONCURRENCY = 32 +# Default size of the per-reader decode thread pool. CPU-bound work +# (zstd, blosc, etc. — all GIL-releasing C) runs in this pool inside +# the reader's anyio portal. +DEFAULT_DECODE_THREADS = os.cpu_count() or 1 + + +async def _read_block_async( + reader: zarr_direct.BlockReader, + block_index: tuple, + io_limiter: anyio.CapacityLimiter, + decode_limiter: anyio.CapacityLimiter, +) -> np.ndarray: + """Fetch the chunks covered by ``block_index`` and assemble them + into a single ndarray, matching the result of ``arr.blocks[idx]``. + + ``block_index`` may mix integer chunk indices with ``slice`` + objects (typically ``slice(None)`` over a non-variants axis). + Slices are resolved to their concrete chunk-coord ranges via + :attr:`BlockReader.cdata_shape`, every chunk is fetched + concurrently inside an :func:`anyio.create_task_group`, and the + results are assembled with :func:`numpy.block`. """ + coord_ranges = [] + for d, idx in enumerate(block_index): + if isinstance(idx, slice): + n_chunks = reader.cdata_shape[d] + coord_ranges.append(list(range(*idx.indices(n_chunks)))) + else: + coord_ranges.append([idx]) + coords_list = list(itertools.product(*coord_ranges)) - _SENTINEL = object() + async def fetch_one(coords): + async with io_limiter: + return await reader.read_chunk(coords, decode_limiter) - def __init__(self, source): - self._source = source - self._executor = cf.ThreadPoolExecutor( - max_workers=1, thread_name_prefix="vcztools-prefetch" - ) - self._next_future = self._executor.submit(self._fetch) - self._closed = False + if len(coords_list) == 1: + return await fetch_one(coords_list[0]) + fetched: dict[tuple, np.ndarray] = {} + async with anyio.create_task_group() as tg: + for coords in coords_list: - def _fetch(self): - try: - return next(self._source) - except StopIteration: - return self._SENTINEL + async def one(c=coords): + fetched[c] = await fetch_one(c) - def __iter__(self): - return self + tg.start_soon(one) - def __next__(self): - if self._closed: - raise StopIteration - result = self._next_future.result() - if result is self._SENTINEL: - self._closed = True - self._executor.shutdown(wait=False) - raise StopIteration - self._next_future = self._executor.submit(self._fetch) - return result + def build(axis, prefix): + if axis == len(coord_ranges): + return fetched[tuple(prefix)] + return [build(axis + 1, prefix + [c]) for c in coord_ranges[axis]] - def close(self): - if self._closed: - return - self._closed = True - # Drain the in-flight fetch so the worker isn't left producing - # into the void; the result (next item, sentinel, or - # exception) is no longer needed. - try: - self._next_future.result() - except BaseException: - pass - # Plain iterators (e.g. list_iterator) have no close(); only - # generators and similar resource-holding iterators do. - source_close = getattr(self._source, "close", None) - if source_close is not None: - source_close() - self._executor.shutdown(wait=True) - - def __del__(self): - # Defensive: prevent thread leaks if a caller forgets close(). - # Mirrors generator finalisation semantics. - try: - self.close() - except Exception: - pass - - -DEFAULT_READAHEAD_BYTES = 256 * 1024 * 1024 -# Fixed by design: these threads dispatch I/O to the Zarr backend -# (which already handles its own async/decompression parallelism), -# so usable parallelism is dispatch-bound and GIL-capped rather than -# scaling with cpu_count. -DEFAULT_READAHEAD_WORKERS = 32 - - -def _read_block(arr, block_index: tuple) -> np.ndarray: - """Fetch one Zarr block by block-index tuple.""" - return arr.blocks[block_index] + return np.block(build(0, [])) @dataclasses.dataclass(frozen=True) @@ -188,7 +168,7 @@ class BlockReadTemplate: """ key: tuple - arr: object + block_reader: zarr_direct.BlockReader block_index_suffix: tuple @@ -196,16 +176,19 @@ def create_chunk_read_list( root, sample_chunk_plan: "samples_mod.SampleChunkPlan", fields, + *, + get_block_reader, ) -> list[BlockReadTemplate]: """Resolve ``fields`` to a list of :class:`BlockReadTemplate` once per query, before any variant chunk is visited. Each template carries the variant-chunk-independent parts of one - block read — the cache key, the resolved Zarr array, and the - suffix of ``block_index`` that follows the variant chunk index - slot. :func:`update_chunk_read_list` substitutes a specific - variant chunk index to produce executor-ready - ``(key, arr, block_index)`` tuples. + block read — the cache key, a :class:`vcztools.zarr_direct.BlockReader` + bound to the field's array, and the suffix of ``block_index`` + that follows the variant chunk index slot. + :func:`update_chunk_read_list` substitutes a specific variant + chunk index to produce executor-ready + ``(key, block_reader, block_index)`` tuples. For a ``call_*`` field the template list fans out one entry per sample chunk in ``sample_chunk_plan.chunk_reads``; for any other @@ -213,22 +196,31 @@ def create_chunk_read_list( Every field must be variant-axis. Static (no-variants-axis) fields are handled by the reader's static-field cache, not the pipeline. + + ``get_block_reader`` is a callable ``str -> BlockReader`` that + typically routes through :meth:`VczReader._get_block_reader` so + BlockReader instances are cached for the reader's lifetime. """ templates = [] for field in fields: arr = root[field] assert _has_variants_axis(arr), f"non-variants-axis field in pipeline: {field}" + reader = get_block_reader(field) if not field.startswith("call_"): suffix = (slice(None),) * (arr.ndim - 1) templates.append( - BlockReadTemplate(key=(field,), arr=arr, block_index_suffix=suffix) + BlockReadTemplate( + key=(field,), block_reader=reader, block_index_suffix=suffix + ) ) else: for cr in sample_chunk_plan.chunk_reads: suffix = (cr.index,) + (slice(None),) * (arr.ndim - 2) templates.append( BlockReadTemplate( - key=(field, cr.index), arr=arr, block_index_suffix=suffix + key=(field, cr.index), + block_reader=reader, + block_index_suffix=suffix, ) ) return templates @@ -239,215 +231,369 @@ def update_chunk_read_list( variant_chunk_index: int, ) -> list[tuple]: """Substitute ``variant_chunk_index`` into each template, returning - the ``[(key, arr, block_index), ...]`` list that - :class:`ReadaheadPipeline` submits to the thread pool. The + the ``[(key, block_reader, block_index), ...]`` list that + :func:`_produce_variant_chunks` issues fetches against. The template list itself is unchanged. """ reads = [] for t in templates: block_index = (variant_chunk_index,) + t.block_index_suffix - reads.append((t.key, t.arr, block_index)) + reads.append((t.key, t.block_reader, block_index)) return reads -class ReadaheadPipeline: - """Cross-chunk readahead controller for ``VczReader.variant_chunks``. - - Resolves the per-field read pattern once at init via - :func:`create_chunk_read_list`, then for each entry in - ``variant_chunk_plan``: substitute the variant chunk index via - :func:`update_chunk_read_list`, submit the resulting block reads - to the reader-owned thread pool, collect results into a - ``blocks`` dict, then construct a :class:`CachedVariantChunk` - over those prefetched blocks and yield it. Cross-chunk readahead - overlaps later chunks' reads with the current chunk's processing - in the consumer. - - The executor is supplied by the caller (typically - :class:`VczReader`) and lives across pipelines. Multiple pipelines - on the same reader — for example the BedEncoder shared-reader - fanout — submit to a single shared pool. When iteration is - abandoned mid-stream (consumer breaks early, generator closed, - exception propagates), the pipeline cancels its own pending - futures only; the executor itself outlives the pipeline. - - The window is sized by a byte budget rather than a chunk count: - one variant-chunk prefetch can vary from a few MB (single - sample-chunk read for a partial subset) to >1 GB (every - sample chunk for a wide call_* field), so a count-based depth - would either starve fan-out or blow RSS. - - Per-chunk byte cost is *measured*, not predicted: the first chunk - is scheduled solo, and once its prefetched blocks land we sum - their :func:`vcztools.utils.array_memory_bytes` and use that as - the window-sizing estimate for every later chunk. The estimate is - approximate — - - - The bootstrap chunk runs even when its prefetch alone exceeds - ``readahead_bytes`` (the alternative is to never make progress). - - Chunks can drift in content size across the iteration, especially - when variable-length string fields are in the prefetch set, so - later chunks may over- or under-shoot the budget. - - ``readahead_bytes=0`` pins pipeline depth at 1: the consumer's - current chunk plus exactly one prefetched ahead. The pipeline - never goes below depth 1 (the consumer would have to wait for - every chunk's I/O on the request thread), so this is the - smallest readahead the caller can ask for. +@dataclasses.dataclass(frozen=True, slots=True) +class _VariantChunksContext: + """Frozen snapshot of every per-iteration parameter for the + variant_chunks producer. Built on the calling thread; consumed by + the producer task running on the reader's anyio portal so that + mid-iteration mutations of reader state do not affect the + in-flight iteration. """ - def __init__( - self, - root, - variant_chunk_plan: list[utils.ChunkRead], - sample_chunk_plan: "samples_mod.SampleChunkPlan", - output_columns: np.ndarray | None, - read_fields, - *, - readahead_bytes: int, - executor: cf.ThreadPoolExecutor, - ): - self.root = root - self._variant_chunk_plan_iter = iter(variant_chunk_plan) - self._sample_chunk_plan = sample_chunk_plan - self._output_columns = output_columns - self._read_templates = create_chunk_read_list( - root, sample_chunk_plan, read_fields - ) - self._readahead_bytes = readahead_bytes - # Set on the first chunk's completion in __iter__. - self._per_chunk_bytes: int | None = None - # Wall-clock seconds spent on the most recent chunk's block reads; - # consumed by VczReader.variant_chunks to attribute per-chunk time - # into "read" vs. "assemble". - self.last_chunk_read_seconds: float | None = None - # Sum of utils.array_memory_bytes() over the most recent chunk's - # decompressed blocks; consumed by VczReader.variant_chunks to - # accumulate retrieval-side throughput stats. - self.last_chunk_bytes: int | None = None - self._executor = executor - # in_flight: list of (variant_chunk, [(blocks_key, Future), ...]). - # The futures list is empty when the chunk needs no prefetch. - self._in_flight: list = [] - # Peak ``len(_in_flight)`` observed across the iteration; the - # consumer reads it after iteration to assess how effective - # the readahead window was at staying ahead of demand. - self.max_in_flight = 0 - logger.debug( - f"ReadaheadPipeline init: {len(read_fields)} read_fields, " - f"{len(self._read_templates)} templates, " - f"readahead_bytes={_fmt_bytes(readahead_bytes)}" - ) + root: object + variant_chunk_plan: list + sample_chunk_plan: "samples_mod.SampleChunkPlan" + output_columns: np.ndarray | None + read_fields: tuple + query_fields: tuple + filter_fields: frozenset + referenced_static_fields: dict + variant_filter: variant_filter_mod.VariantFilter | None + variants_chunk_size: int + io_limiter: anyio.CapacityLimiter + decode_limiter: anyio.CapacityLimiter + get_block_reader: object + readahead_bytes: int + + +async def _create_memory_channel(buffer_size: int): + """Return an anyio (send, recv) pair sized to ``buffer_size``.""" + return anyio.create_memory_object_stream[dict](max_buffer_size=buffer_size) + + +def _close_portal_cm(portal_cm) -> None: + """Exit the BlockingPortal context manager, swallowing any + teardown error so a misbehaving portal can't wedge GC. Used both + by :meth:`VczReader.close` and by the weakref finalizer that + arms close on garbage collection. + """ + try: + portal_cm.__exit__(None, None, None) + except Exception: + pass + + +async def _produce_variant_chunks(send_channel, ctx, telemetry): + """Async producer for :meth:`VczReader.variant_chunks`. + + Iterates ``ctx.variant_chunk_plan`` in order, fetching every + chunk's blocks concurrently inside an inner task group, applying + the variant filter, materialising the output dict, and sending it + through ``send_channel``. Backpressure is byte-budget controlled: + the in-flight window expands until measured per-chunk bytes times + in-flight count would exceed ``ctx.readahead_bytes``. + + ``telemetry`` is a shared dict that the iterator reads after + iteration completes. Updated keys: ``max_in_flight``, + ``last_chunk_bytes``, ``chunks_visited``, ``chunks_yielded``, + ``variants_yielded``, ``bytes_yielded``, ``producer_assemble_total``, + ``producer_read_total``. + """ + templates = create_chunk_read_list( + ctx.root, + ctx.sample_chunk_plan, + ctx.read_fields, + get_block_reader=ctx.get_block_reader, + ) + plan_iter = iter(ctx.variant_chunk_plan) + in_flight: list[dict] = [] + per_chunk_bytes: int | None = None + iter_start = time.perf_counter() - def _schedule_one(self) -> bool: - """Plan the next variant chunk's reads and submit them to the - thread pool. Returns False once the plan is exhausted.""" - try: - variant_chunk = next(self._variant_chunk_plan_iter) - except StopIteration: - return False - reads = update_chunk_read_list(self._read_templates, variant_chunk.index) - futures = [ - (key, self._executor.submit(_read_block, arr, block_index)) - for key, arr, block_index in reads - ] - self._in_flight.append((variant_chunk, futures)) - if len(self._in_flight) > self.max_in_flight: - self.max_in_flight = len(self._in_flight) - logger.log( - TRACE, - f"schedule chunk {variant_chunk.index}: {len(futures)} blocks submitted", - ) - return True - - def _refill(self) -> None: - # Until the first chunk has been measured we can't size the - # window — schedule exactly one chunk and wait for its reads - # to land. Subsequent refills fall through to the budget loop. - if self._per_chunk_bytes is None: - if len(self._in_flight) == 0: - self._schedule_one() - return - # Always keep at least one chunk in flight; otherwise honour the - # byte budget (use an effective per-chunk cost of at least 1 to - # avoid an infinite loop when read_fields is empty). - per_chunk = max(1, self._per_chunk_bytes) - while len(self._in_flight) == 0 or ( - len(self._in_flight) * per_chunk < self._readahead_bytes - ): - if not self._schedule_one(): - return + try: + async with send_channel, anyio.create_task_group() as tg: - def __iter__(self): - try: - self._refill() - while len(self._in_flight) > 0: - variant_chunk, futures = self._in_flight.pop(0) - future_to_key = {fut: key for key, fut in futures} + async def fetch_one(slot): + vc = slot["vc"] + t0 = time.perf_counter() blocks: dict[tuple, np.ndarray] = {} - read_start = time.perf_counter() - for fut in cf.as_completed(future_to_key): - blocks[future_to_key[fut]] = fut.result() - read_seconds = time.perf_counter() - read_start - self.last_chunk_read_seconds = read_seconds - chunk_bytes = sum(utils.array_memory_bytes(v) for v in blocks.values()) - self.last_chunk_bytes = chunk_bytes - if self._per_chunk_bytes is None: - self._per_chunk_bytes = chunk_bytes - if self._readahead_bytes > 0 and chunk_bytes > 0: + async with anyio.create_task_group() as inner: + for key, reader, block_index in update_chunk_read_list( + templates, vc.index + ): + + async def one(k=key, r=reader, bi=block_index): + blocks[k] = await _read_block_async( + r, bi, ctx.io_limiter, ctx.decode_limiter + ) + + inner.start_soon(one) + slot["blocks"] = blocks + slot["read_seconds"] = time.perf_counter() - t0 + slot["bytes"] = sum( + utils.array_memory_bytes(v) for v in blocks.values() + ) + slot["done"].set() + + def schedule_one() -> bool: + try: + vc = next(plan_iter) + except StopIteration: + return False + slot = {"vc": vc, "done": anyio.Event()} + in_flight.append(slot) + if len(in_flight) > telemetry["max_in_flight"]: + telemetry["max_in_flight"] = len(in_flight) + tg.start_soon(fetch_one, slot) + logger.log( + TRACE, + f"schedule chunk {vc.index}: {len(templates)} blocks submitted", + ) + return True + + def refill(): + # Until the first chunk has been measured we can't size the + # window — schedule exactly one chunk and wait for its reads + # to land. Subsequent refills fall through to the budget loop. + if per_chunk_bytes is None: + if not in_flight: + schedule_one() + return + # Always keep at least one chunk in flight; otherwise honour + # the byte budget (use an effective per-chunk cost of at + # least 1 to avoid an infinite loop when read_fields is + # empty). + pcb = max(1, per_chunk_bytes) + while not in_flight or (len(in_flight) * pcb < ctx.readahead_bytes): + if not schedule_one(): + return + + refill() + while in_flight: + slot = in_flight.pop(0) + await slot["done"].wait() + vc = slot["vc"] + blocks = slot["blocks"] + chunk_bytes = slot["bytes"] + read_seconds = slot["read_seconds"] + telemetry["last_chunk_bytes"] = chunk_bytes + telemetry["bytes_yielded"] += chunk_bytes + telemetry["producer_read_total"] += read_seconds + telemetry["chunks_visited"] += 1 + if per_chunk_bytes is None: + per_chunk_bytes = chunk_bytes + if ctx.readahead_bytes > 0 and chunk_bytes > 0: window_chunks = max( - 1, self._readahead_bytes // max(1, chunk_bytes) + 1, ctx.readahead_bytes // max(1, chunk_bytes) ) else: window_chunks = 1 logger.info( f"Per-chunk read size: {_fmt_bytes(chunk_bytes)} " - f"(chunk {variant_chunk.index}); window will hold " + f"(chunk {vc.index}); window will hold " f"~{window_chunks} chunks under budget " - f"{_fmt_bytes(self._readahead_bytes)}" + f"{_fmt_bytes(ctx.readahead_bytes)}" ) if ( - self._readahead_bytes > 0 - and chunk_bytes > self._readahead_bytes / 2 + ctx.readahead_bytes > 0 + and chunk_bytes > ctx.readahead_bytes / 2 ): logger.warning( f"Readahead budget is single-chunk-bound: " f"per-chunk {_fmt_bytes(chunk_bytes)} > " - f"half of {_fmt_bytes(self._readahead_bytes)}; " + f"half of {_fmt_bytes(ctx.readahead_bytes)}; " f"the prefetch window will be capped at ~1 " f"in flight regardless of worker count. " f"Increase readahead_bytes to widen the window." ) logger.debug( - f"chunk {variant_chunk.index} read complete in " - f"{read_seconds:.2f}s ({len(blocks)} blocks, " - f"{_fmt_bytes(chunk_bytes)})" + f"chunk {vc.index} read complete in {read_seconds:.2f}s " + f"({len(blocks)} blocks, {_fmt_bytes(chunk_bytes)})" ) - yield CachedVariantChunk( - self.root, - variant_chunk, - sample_chunk_plan=self._sample_chunk_plan, - output_columns=self._output_columns, + + assemble_start = time.perf_counter() + cached = CachedVariantChunk( + ctx.root, + vc, + sample_chunk_plan=ctx.sample_chunk_plan, + output_columns=ctx.output_columns, blocks=blocks, ) - # After the consumer drops the previous chunk reference, - # top the pipeline back up. - self._refill() - finally: - cancelled = 0 - for _variant_chunk, futures in self._in_flight: - for _key, fut in futures: - if fut.cancel(): - cancelled += 1 - if cancelled > 0: - logger.debug(f"cancelled {cancelled} pending futures") + + variants_selection = None + sample_filter_pass = None + if ctx.variant_filter is not None: + filter_data = { + f: ctx.referenced_static_fields[f] + if f in ctx.referenced_static_fields + else cached.filter_view(f) + for f in ctx.filter_fields + } + filter_result = ctx.variant_filter.evaluate(filter_data) + if filter_result.ndim == 1: + variants_selection = filter_result + logger.debug( + f"chunk {vc.index}: filter pass " + f"{int(filter_result.sum())}/{filter_result.size} " + f"variants" + ) + else: + variants_selection = filter_result.any(axis=1) + sample_filter_pass = filter_result[variants_selection] + if ctx.output_columns is not None: + sample_filter_pass = sample_filter_pass[ + :, ctx.output_columns + ] + logger.debug( + f"chunk {vc.index}: filter pass " + f"{int(variants_selection.sum())}/" + f"{variants_selection.size} variants, " + f"{int(filter_result.sum())}/{filter_result.size} " + f"sample cells" + ) + + if variants_selection is not None and not variants_selection.any(): + telemetry["producer_assemble_total"] += ( + time.perf_counter() - assemble_start + ) + refill() + continue + + chunk_data: dict[str, np.ndarray] = {} + for field in ctx.query_fields: + if field in ctx.referenced_static_fields: + chunk_data[field] = ctx.referenced_static_fields[field] + continue + if field == "variant_index": + value = _absolute_variant_indexes(vc, ctx.variants_chunk_size) + else: + value = cached.output_view(field) + if variants_selection is not None: + value = value[variants_selection] + chunk_data[field] = value + if sample_filter_pass is not None: + chunk_data["sample_filter_pass"] = sample_filter_pass + + non_static_query = [ + f for f in ctx.query_fields if f not in ctx.referenced_static_fields + ] + if len(non_static_query) > 0: + chunk_variants = len(chunk_data[non_static_query[0]]) + elif variants_selection is not None: + chunk_variants = int(variants_selection.sum()) + else: + chunk_variants = 0 + telemetry["variants_yielded"] += chunk_variants + telemetry["chunks_yielded"] += 1 + + assemble_seconds = time.perf_counter() - assemble_start + telemetry["producer_assemble_total"] += assemble_seconds + logger.debug( + f"chunk {vc.index}: assembled {chunk_variants} variants in " + f"{assemble_seconds:.2f}s" + ) + + await send_channel.send(chunk_data) + refill() + finally: + elapsed = time.perf_counter() - iter_start + mib = telemetry["bytes_yielded"] / (1024 * 1024) + rate = mib / elapsed if elapsed > 0 else 0.0 + logger.info( + f"variant_chunks: iteration done in {elapsed:.2f}s " + f"({telemetry['chunks_visited']} chunks visited, " + f"{telemetry['chunks_yielded']} yielded, " + f"{telemetry['variants_yielded']} variants, " + f"{mib:.1f} MiB retrieved, {rate:.1f} MiB/s, " + f"max readahead depth {telemetry['max_in_flight']}); " + f"producer_assemble={telemetry['producer_assemble_total']:.2f}s, " + f"producer_read_wait={telemetry['producer_read_total']:.2f}s" + ) + + +class _AsyncBackedIterator: + """Sync iterator over an anyio memory channel populated by + :func:`_produce_variant_chunks` running on the reader's portal. + + Bridges async → sync via ``portal.call(recv.receive)``. ``close()`` + cancels the producer and shuts the channel; ``__del__`` closes + defensively. Producer exceptions surface on ``__next__`` once the + channel is drained. + + ``max_in_flight`` and ``last_chunk_bytes`` expose telemetry the + producer updates as it runs; they are intended for diagnostics + and tests, not for the user-facing chunk dicts themselves. + """ + + def __init__(self, portal, recv, fut, telemetry): + self._portal = portal + self._recv = recv + self._fut = fut + self._telemetry = telemetry + self._closed = False + + @property + def max_in_flight(self) -> int: + return self._telemetry["max_in_flight"] + + @property + def last_chunk_bytes(self) -> int | None: + return self._telemetry["last_chunk_bytes"] + + def __iter__(self): + return self + + def __next__(self): + if self._closed: + raise StopIteration + try: + return self._portal.call(self._recv.receive) + except anyio.EndOfStream: + self._closed = True + try: + self._fut.result() + except BaseExceptionGroup as eg: + # anyio wraps every producer-side error in an + # ExceptionGroup via the task group's __aexit__. Unwrap + # a single-exception group so callers see the original + # error type (handle_exception in cli.py matches against + # ValueError, not BaseExceptionGroup). + if len(eg.exceptions) == 1: + raise eg.exceptions[0] from None + raise + raise StopIteration from None + + def close(self): + if self._closed: + return + self._closed = True + # Cancel the producer; close the receive end so a producer that's + # blocked on send.send() unblocks via BrokenResourceError; wait + # (bounded) for the task to actually finish so __del__ can't wedge + # interpreter shutdown. + self._fut.cancel() + try: + self._portal.call(self._recv.aclose) + except Exception: + pass + try: + self._fut.result(timeout=5) + except (cf.CancelledError, cf.TimeoutError, BaseException): + pass + + def __del__(self): + try: + self.close() + except Exception: + pass class CachedVariantChunk: """View assembler over prefetched blocks for one variant chunk visit. - Constructed by :class:`ReadaheadPipeline` once its block reads have - completed; performs no I/O itself. The constructor takes: + Constructed by :func:`_produce_variant_chunks` once its block reads + have completed; performs no I/O itself. The constructor takes: - ``blocks`` — ``{key: ndarray}`` of prefetched Zarr blocks keyed by ``(field,)`` for variants-axis non-``call_*`` reads and @@ -646,14 +792,14 @@ class VczReader: resolved by reading the corresponding property during default iteration. - The reader owns a single :class:`concurrent.futures.ThreadPoolExecutor` - that every :class:`ReadaheadPipeline` it spawns submits work to. - Use as a context manager (``with VczReader(root) as reader:``) so - the pool is torn down deterministically on exit. Multiple - pipelines (e.g. several :class:`vcztools.plink.BedEncoder` - instances driven concurrently against the same reader, or - repeated ``variant_chunks()`` calls) share the pool — submission - is thread-safe at the executor level. + The reader owns a single anyio :class:`BlockingPortal` (started + lazily on first ``variant_chunks()``) plus two + :class:`anyio.CapacityLimiter` knobs sized by ``io_concurrency`` + and ``decode_threads``. Use as a context manager + (``with VczReader(root) as reader:``) so the portal is shut down + deterministically on exit. Multiple concurrent + ``variant_chunks()`` callers share the portal — startup is + guarded by an internal lock. Parameters ---------- @@ -662,37 +808,50 @@ class VczReader: dataset. Use :func:`vcztools.open_zarr` to open a path (local, remote, or zip) with the desired backend before constructing the reader. - readahead_workers - Worker count for the readahead thread pool. ``None`` - (default) uses :data:`DEFAULT_READAHEAD_WORKERS` (``32``). - The pool is created at construction; this parameter has no - post-init knob. + io_concurrency + Cap on concurrent ``store.get`` calls per iteration. ``None`` + (default) uses :data:`DEFAULT_IO_CONCURRENCY` (``32``). Each + store fetch is async-native, so this is a coroutine count, + not a thread count. readahead_bytes Cap, in bytes, on the cross-chunk readahead window. ``None`` (default) uses :data:`DEFAULT_READAHEAD_BYTES` (256 MiB). - ``0`` pins pipeline depth at 1 (one chunk prefetched ahead of - the consumer); the pipeline cannot go lower. + ``0`` pins window depth at 1 (one chunk prefetched ahead of + the consumer). + decode_threads + Size of the decode thread pool that runs codec ``_decode_sync`` + calls inside the reader's portal. ``None`` (default) uses + :data:`DEFAULT_DECODE_THREADS` (``os.cpu_count()``). """ def __init__( self, root, *, - readahead_workers: int | None = None, + io_concurrency: int | None = None, readahead_bytes: int | None = None, + decode_threads: int | None = None, ): self.root = root self.readahead_bytes = readahead_bytes - workers = ( - readahead_workers - if readahead_workers is not None - else DEFAULT_READAHEAD_WORKERS + self._io_concurrency = ( + io_concurrency if io_concurrency is not None else DEFAULT_IO_CONCURRENCY ) - self._executor = cf.ThreadPoolExecutor( - max_workers=workers, - thread_name_prefix="vcztools-readahead", + self._decode_threads = ( + decode_threads if decode_threads is not None else DEFAULT_DECODE_THREADS ) - self._readahead_workers = workers + # Portal + limiters are started lazily on first variant_chunks() + # call — readers that only consume static metadata never pay + # the asyncio-thread cost. Concurrent variant_chunks() callers + # racing to start the portal would otherwise leak event-loop + # threads, so the bring-up is guarded by ``_portal_lock``. + self._portal_lock = threading.Lock() + self._portal_cm = None + self._portal: anyio.from_thread.BlockingPortal | None = None + self._io_limiter: anyio.CapacityLimiter | None = None + self._decode_limiter: anyio.CapacityLimiter | None = None + self._finalizer: weakref.finalize | None = None + self._block_readers: dict[str, zarr_direct.BlockReader] = {} self._sample_chunk_plan = None self._variant_chunk_plan = None self._samples_selection = None @@ -708,15 +867,64 @@ def __init__( f"num_variants={self.num_variants}, num_samples={self.num_samples}, " f"variants_chunk_size={self.variants_chunk_size}, " f"samples_chunk_size={self.samples_chunk_size}, " - f"readahead_workers={workers}, " + f"io_concurrency={self._io_concurrency}, " + f"decode_threads={self._decode_threads}, " f"readahead_bytes={readahead_bytes}" ) + def _ensure_portal(self): + """Start the anyio portal and limiter pair on first access; + return ``(portal, io_limiter, decode_limiter)``. Idempotent and + thread-safe — concurrent first callers see the same portal. + + A :func:`weakref.finalize` arms ``close`` so the portal thread + is shut down even if the caller doesn't use the reader as a + context manager — without it, the portal's daemon thread stays + alive past interpreter shutdown and joins on the asyncio + default executor's non-daemon workers can wedge process exit. + """ + with self._portal_lock: + if self._portal is None: + self._portal_cm = anyio.from_thread.start_blocking_portal( + backend="asyncio", name="vcztools-portal" + ) + self._portal = self._portal_cm.__enter__() + self._io_limiter = anyio.CapacityLimiter(self._io_concurrency) + self._decode_limiter = anyio.CapacityLimiter(self._decode_threads) + if self._finalizer is None: + self._finalizer = weakref.finalize( + self, + _close_portal_cm, + self._portal_cm, + ) + return self._portal, self._io_limiter, self._decode_limiter + + def _get_block_reader(self, name: str) -> zarr_direct.BlockReader: + """Lazily construct and cache one :class:`BlockReader` per field.""" + cached = self._block_readers.get(name) + if cached is not None: + return cached + reader = zarr_direct.BlockReader(self.root[name]) + self._block_readers[name] = reader + return reader + + def close(self) -> None: + """Tear down owned resources: the anyio portal (if started).""" + if self._finalizer is not None: + self._finalizer.detach() + self._finalizer = None + if self._portal_cm is not None: + _close_portal_cm(self._portal_cm) + self._portal_cm = None + self._portal = None + self._io_limiter = None + self._decode_limiter = None + def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): - self._executor.shutdown(wait=True) + self.close() return False def _load_static_field(self, name: str) -> np.ndarray: @@ -1151,21 +1359,8 @@ def variant_chunks( if fields is not None and len(fields) == 0: return iter(()) - return _PrefetchIterator(self._variant_chunks_gen(fields=fields, start=start)) - - def _variant_chunks_gen( - self, - *, - fields: list[str] | None = None, - start: int = 0, - ): - """Inner generator backing :meth:`variant_chunks`. The public - entry point validates arguments eagerly and wraps this - generator in a one-deep prefetch iterator; tests that need - the raw single-threaded behaviour (e.g. for deterministic - in-thread state inspection) can drive this directly.""" # Snapshot the filter so a mid-iteration set_variant_filter - # can't change behaviour for this generator. + # can't change behaviour for this iteration. variant_filter = self.variant_filter query_fields = self._resolve_query_fields(fields) filter_fields = frozenset( @@ -1220,142 +1415,42 @@ def _variant_chunks_gen( f"{len(variant_chunk_plan)} variant chunks, " f"{len(sample_chunk_plan.chunk_reads)} sample chunks, " f"readahead_bytes={_fmt_bytes(readahead_bytes)}, " - f"workers={self._readahead_workers}); " + f"io_concurrency={self._io_concurrency}, " + f"decode_threads={self._decode_threads}); " f"query_fields={list(query_fields)}, " f"read_fields={read_fields}" ) - pipeline = ReadaheadPipeline( - self.root, - variant_chunk_plan, - sample_chunk_plan, - output_columns, - read_fields, + portal, io_limiter, decode_limiter = self._ensure_portal() + ctx = _VariantChunksContext( + root=self.root, + variant_chunk_plan=variant_chunk_plan, + sample_chunk_plan=sample_chunk_plan, + output_columns=output_columns, + read_fields=tuple(read_fields), + query_fields=tuple(query_fields), + filter_fields=filter_fields, + referenced_static_fields=referenced_static_fields, + variant_filter=variant_filter, + variants_chunk_size=self.variants_chunk_size, + io_limiter=io_limiter, + decode_limiter=decode_limiter, + get_block_reader=self._get_block_reader, readahead_bytes=readahead_bytes, - executor=self._executor, ) - chunks_visited = 0 - chunks_yielded = 0 - variants_yielded = 0 - bytes_yielded = 0 - # Per-iteration time accounting. consumer_wait isolates the gap - # between yielding chunk N and the consumer pulling chunk N+1 - # (minus the producer's own read wait for N+1), exposing the - # downstream encoder/writer cost which the iterator otherwise - # can't see. - producer_assemble_total = 0.0 - producer_read_total = 0.0 - consumer_wait_total = 0.0 - last_yield_t: float | None = None - iter_start = time.perf_counter() - try: - for chunk in pipeline: - chunks_visited += 1 - chunk_start = time.perf_counter() - read_seconds = pipeline.last_chunk_read_seconds or 0.0 - producer_read_total += read_seconds - if last_yield_t is not None: - gap = chunk_start - last_yield_t - consumer_wait_total += max(0.0, gap - read_seconds) - bytes_yielded += pipeline.last_chunk_bytes or 0 - # variants_selection: 1-D bool over the chunk's variant axis, or - # None meaning "no filter, keep every variant". - # sample_filter_pass: 2-D bool over (surviving variants, output - # samples) for sample-scope filters only; published so query.py - # can emit only matching samples in FORMAT-loop queries. - variants_selection = None - sample_filter_pass = None - if variant_filter is not None: - filter_data = { - f: referenced_static_fields[f] - if f in referenced_static_fields - else chunk.filter_view(f) - for f in filter_fields - } - filter_result = variant_filter.evaluate(filter_data) - if filter_result.ndim == 1: - # Variant-scope filter: one bool per variant. - variants_selection = filter_result - logger.debug( - f"chunk {chunk.variant_chunk.index}: filter pass " - f"{int(filter_result.sum())}/{filter_result.size} variants" - ) - else: - # Sample-scope filter: a variant survives if at least one - # sample matched; the surviving rows are kept so the query - # layer can emit only matching samples. - variants_selection = filter_result.any(axis=1) - sample_filter_pass = filter_result[variants_selection] - if output_columns is not None: - # Filter ran on the real-sample axis but output is - # the user's subset axis; reindex columns to match. - sample_filter_pass = sample_filter_pass[:, output_columns] - logger.debug( - f"chunk {chunk.variant_chunk.index}: filter pass " - f"{int(variants_selection.sum())}/" - f"{variants_selection.size} variants, " - f"{int(filter_result.sum())}/{filter_result.size} " - f"sample cells" - ) - - if variants_selection is not None and not variants_selection.any(): - continue - - chunk_data = {} - for field in query_fields: - if field in referenced_static_fields: - chunk_data[field] = referenced_static_fields[field] - continue - if field == "variant_index": - value = _absolute_variant_indexes( - chunk.variant_chunk, self.variants_chunk_size - ) - else: - value = chunk.output_view(field) - if variants_selection is not None: - value = value[variants_selection] - chunk_data[field] = value - if sample_filter_pass is not None: - chunk_data["sample_filter_pass"] = sample_filter_pass - - # Count surviving variants from a dynamic (variants-axis) - # query field if there is one; static fields have the - # store-wide axis length, not the per-chunk variant count. - non_static_query = [ - f for f in query_fields if f not in referenced_static_fields - ] - if len(non_static_query) > 0: - chunk_variants = len(chunk_data[non_static_query[0]]) - elif variants_selection is not None: - chunk_variants = int(variants_selection.sum()) - else: - chunk_variants = 0 - variants_yielded += chunk_variants - chunks_yielded += 1 - total_seconds = time.perf_counter() - chunk_start + read_seconds - assemble_seconds = max(0.0, total_seconds - read_seconds) - producer_assemble_total += assemble_seconds - logger.debug( - f"chunk {chunk.variant_chunk.index}: yielded {chunk_variants} " - f"variants in {total_seconds:.2f}s " - f"(read {read_seconds:.2f}s, assemble {assemble_seconds:.2f}s)" - ) - last_yield_t = time.perf_counter() - yield chunk_data - finally: - elapsed = time.perf_counter() - iter_start - mib = bytes_yielded / (1024 * 1024) - rate = mib / elapsed if elapsed > 0 else 0.0 - logger.info( - f"variant_chunks: iteration done in {elapsed:.2f}s " - f"({chunks_visited} chunks visited, {chunks_yielded} yielded, " - f"{variants_yielded} variants, " - f"{mib:.1f} MiB retrieved, {rate:.1f} MiB/s, " - f"max readahead depth {pipeline.max_in_flight}); " - f"producer_assemble={producer_assemble_total:.2f}s, " - f"producer_read_wait={producer_read_total:.2f}s, " - f"consumer_wait={consumer_wait_total:.2f}s" - ) + telemetry: dict = { + "max_in_flight": 0, + "last_chunk_bytes": None, + "chunks_visited": 0, + "chunks_yielded": 0, + "variants_yielded": 0, + "bytes_yielded": 0, + "producer_assemble_total": 0.0, + "producer_read_total": 0.0, + } + send, recv = portal.call(_create_memory_channel, 1) + fut = portal.start_task_soon(_produce_variant_chunks, send, ctx, telemetry) + return _AsyncBackedIterator(portal, recv, fut, telemetry) def _resolve_query_fields(self, fields): if fields is not None: diff --git a/vcztools/zarr_direct.py b/vcztools/zarr_direct.py new file mode 100644 index 0000000..f671ebd --- /dev/null +++ b/vcztools/zarr_direct.py @@ -0,0 +1,136 @@ +"""Direct chunk fetch + decode without Zarr's high-level Array layer. + +Owns the path from ``(zarr.Array, chunk_coords)`` to ``np.ndarray``: + + raw bytes from store.get → codec pipeline.decode → np.ndarray + +Lets the caller drive concurrency with anyio (async store gets) and a +bounded decode pool (``anyio.CapacityLimiter``) instead of going through +``arr.blocks[idx]``, which spins up a fresh asyncio loop per call. + +Sharded arrays are explicitly unsupported. ``BlockReader`` rejects them +at construction with a clear error. +""" + +import logging + +import anyio +import numpy as np +import zarr +from zarr.abc.buffer import Buffer +from zarr.codecs.sharding import ShardingCodec +from zarr.core.array_spec import ArraySpec +from zarr.core.buffer import default_buffer_prototype + +logger = logging.getLogger(__name__) + + +class BlockReader: + """Resolves a single Zarr array to its store, key encoding, codec + pipeline, and chunk-spec factory. Constructed once per field and + reused across chunk reads. + + The output of :meth:`read_chunk` matches ``array.blocks[coords]`` + byte-for-byte: boundary chunks are clamped to the actual data + shape, missing chunks materialise as ``fill_value``. + """ + + def __init__(self, array: zarr.Array): + async_array = array.async_array + codec_pipeline = async_array.codec_pipeline + if isinstance(codec_pipeline.array_bytes_codec, ShardingCodec): + raise NotImplementedError( + f"Array {array.path!r} uses ShardingCodec, which is not " + "supported by vcztools. Re-encode the VCZ without sharding " + "(bio2zarr writes non-sharded VCZ by default)." + ) + + self._store = array.store_path.store + self._path = array.path + self._metadata = array.metadata + self._codec_pipeline = codec_pipeline + self._array_config = async_array.config + self._prototype = default_buffer_prototype() + self._shape = array.shape + self._chunk_shape = array.chunks + self._cdata_shape = tuple( + (s + c - 1) // c for s, c in zip(array.shape, array.chunks) + ) + self._dtype = array.dtype + # Zarr metadata may record fill_value=None (especially in v2); + # fall back to the dtype's default scalar (0 for ints, "" for + # strings, etc.) so missing-chunk decode matches arr.blocks[]. + fill = array.metadata.fill_value + if fill is None: + fill = array.metadata.dtype.default_scalar() + self._fill_value = fill + + @property + def cdata_shape(self) -> tuple[int, ...]: + """Number of chunks per axis (the size of the chunk grid).""" + return self._cdata_shape + + def chunk_key(self, coords: tuple[int, ...]) -> str: + """Store key for the chunk at ``coords``.""" + suffix = self._metadata.encode_chunk_key(coords) + if self._path == "": + return suffix + return f"{self._path}/{suffix}" + + def chunk_spec(self, coords: tuple[int, ...]) -> ArraySpec: + return self._metadata.get_chunk_spec( + coords, self._array_config, self._prototype + ) + + def actual_chunk_shape(self, coords: tuple[int, ...]) -> tuple[int, ...]: + """Boundary-clamped chunk shape — what ``arr.blocks[coords]`` returns. + + Boundary chunks are stored at the nominal chunk shape (padded with + fill values); ``arr.blocks[]`` slices them back to the actual + data extent. We mirror that here. + """ + result = [] + for d, coord in enumerate(coords): + stored = self._chunk_shape[d] + remaining = self._shape[d] - coord * self._chunk_shape[d] + result.append(max(0, min(stored, remaining))) + return tuple(result) + + async def fetch_chunk_bytes(self, coords: tuple[int, ...]) -> Buffer | None: + """Fetch raw chunk bytes from the store; ``None`` if absent.""" + key = self.chunk_key(coords) + return await self._store.get(key, prototype=self._prototype) + + async def decode_chunk( + self, raw: Buffer | None, coords: tuple[int, ...] + ) -> np.ndarray: + """Decode raw chunk bytes to a numpy array. + + Returns the boundary-clamped shape. ``None`` raw materialises + fill values at the actual shape. + """ + actual_shape = self.actual_chunk_shape(coords) + if raw is None: + return np.full(actual_shape, self._fill_value, dtype=self._dtype) + spec = self.chunk_spec(coords) + decoded = list(await self._codec_pipeline.decode([(raw, spec)])) + nd = decoded[0].as_numpy_array() + if nd.shape != actual_shape: + nd = nd[tuple(slice(0, s) for s in actual_shape)] + return nd + + async def read_chunk( + self, + coords: tuple[int, ...], + decode_limiter: anyio.CapacityLimiter | None = None, + ) -> np.ndarray: + """Fetch and decode a single chunk. + + ``decode_limiter`` bounds concurrent decodes when many readers + share a thread budget; pass ``None`` for unbounded. + """ + raw = await self.fetch_chunk_bytes(coords) + if decode_limiter is None: + return await self.decode_chunk(raw, coords) + async with decode_limiter: + return await self.decode_chunk(raw, coords)