diff --git a/.env.example b/.env.example index 12a099f..d95c1e0 100644 --- a/.env.example +++ b/.env.example @@ -15,3 +15,15 @@ AMI_ANTENNA_API_BASE_URL=http://localhost:8000/api/v2 AMI_ANTENNA_API_AUTH_TOKEN=your_antenna_auth_token_here AMI_ANTENNA_API_BATCH_SIZE=4 AMI_ANTENNA_SERVICE_NAME="AMI Data Companion" + +# DataLoader subprocess hygiene (see ami-data-companion#140, #145) +# multiprocessing context to start DataLoader workers from. +# Recommended: "forkserver" — avoids inheriting the parent process heap +# (which on production GPU workers carries large CUDA / pinned-memory +# state and leaks into shared memory). Allowed: fork | spawn | forkserver +# Set to an empty string to let PyTorch use its default ("fork" on Linux). +AMI_ANTENNA_API_DATALOADER_MP_CONTEXT=forkserver +# Per-batch DataLoader timeout in seconds. Converts a silent +# subprocess hang (the original #140 failure mode) into a RuntimeError +# that the supervisor can restart from. Set to 0 to disable. +AMI_ANTENNA_API_DATALOADER_TIMEOUT_S=300 diff --git a/trapdata/antenna/datasets.py b/trapdata/antenna/datasets.py index 7ecc7bd..3154fd0 100644 --- a/trapdata/antenna/datasets.py +++ b/trapdata/antenna/datasets.py @@ -69,6 +69,7 @@ import requests import torch +import torch.multiprocessing as mp import torch.utils.data import torchvision from PIL import Image @@ -417,12 +418,21 @@ def get_rest_dataloader( DataLoader num_workers > 0 is safe here because Antenna dequeues tasks atomically — each worker subprocess gets a unique set of tasks. + Subprocess hygiene (see #140, #145): we explicitly set + multiprocessing_context (default ``forkserver``) and a non-zero + ``timeout`` to keep the DataLoader from inheriting stale parent state + via ``fork`` and from hanging forever when a subprocess dies + mid-shared-memory-cleanup. Both knobs are env-configurable so a single + bad pipeline can fall back without touching the others. + Args: job_id: Job ID to fetch tasks for settings: Settings object. 
Relevant fields: - antenna_api_base_url / antenna_api_auth_token - antenna_api_batch_size (tasks per API call and GPU batch size) - num_workers (DataLoader subprocesses) + - antenna_api_dataloader_mp_context (fork / spawn / forkserver / "") + - antenna_api_dataloader_timeout_s (per-batch timeout, seconds) """ dataset = RESTDataset( base_url=settings.antenna_api_base_url, @@ -431,6 +441,36 @@ def get_rest_dataloader( batch_size=settings.antenna_api_batch_size, ) + # Resolve multiprocessing context. Empty string / None means "let + # PyTorch choose" (today: "fork" on Linux), which is the historical + # behavior. Anything else must be one of the multiprocessing start + # methods supported by the host. Only relevant when num_workers > 0 + # because num_workers=0 runs the dataset inline in the main process. + # We query mp.get_all_start_methods() (rather than hardcoding the + # ("fork", "spawn", "forkserver") tuple) so that the validation matches + # what the running interpreter / platform actually supports — e.g. + # Windows only supports "spawn". + mp_context = ( + getattr(settings, "antenna_api_dataloader_mp_context", "forkserver") or "" + ).strip().lower() + if settings.num_workers > 0 and mp_context: + allowed = set(mp.get_all_start_methods()) + if mp_context not in allowed: + raise ValueError( + f"antenna_api_dataloader_mp_context must be one of " + f"{sorted(allowed)!r} or empty; got {mp_context!r}" + ) + dataloader_mp_context: object | None = mp.get_context(mp_context) + else: + dataloader_mp_context = None + + # Per-batch timeout. PyTorch interprets 0 as "wait forever" (the + # behavior that caused the silent hangs in #140); we treat negative or + # zero as "disable the guard" so an operator can opt out if needed, + # but the default is a 5-minute ceiling. 
+ timeout_s = int(getattr(settings, "antenna_api_dataloader_timeout_s", 300) or 0) + dataloader_timeout = timeout_s if timeout_s > 0 and settings.num_workers > 0 else 0 + return torch.utils.data.DataLoader( dataset, batch_size=1, # We collate manually in rest_collate_fn, so set batch_size=1 here @@ -439,6 +479,8 @@ def get_rest_dataloader( pin_memory=True, persistent_workers=settings.num_workers > 0, prefetch_factor=4 if settings.num_workers > 0 else None, + multiprocessing_context=dataloader_mp_context, + timeout=dataloader_timeout, ) diff --git a/trapdata/antenna/tests/test_dataloader_hygiene.py b/trapdata/antenna/tests/test_dataloader_hygiene.py new file mode 100644 index 0000000..78b631e --- /dev/null +++ b/trapdata/antenna/tests/test_dataloader_hygiene.py @@ -0,0 +1,128 @@ +"""Regression test for DataLoader subprocess-hygiene settings. + +This test guards the fix shipped for ami-data-companion#140 / #145: +``get_rest_dataloader()`` must apply the ``multiprocessing_context`` and +``timeout`` knobs from settings. Without these, the DataLoader inherits the +parent's heap via ``fork`` (leaking CUDA / pinned-memory state into shared +memory) and silently hangs forever when a subprocess dies mid-batch. + +The RSS-growth side of the regression is already covered by +``test_memory_leak.py``; this file focuses on the *configuration* surface so +the fix can't be silently regressed by a future refactor of +``get_rest_dataloader``. +""" + +from types import SimpleNamespace +from unittest import TestCase + +import pytest + +from trapdata.antenna.datasets import get_rest_dataloader + + +def _make_settings(**overrides) -> SimpleNamespace: + """Build a minimal duck-typed Settings object. + + Using SimpleNamespace (not MagicMock) so that attribute access on a + *missing* field raises AttributeError — that lets the test catch a + typo'd setting name on the production side rather than swallowing it. 
+ """ + defaults = dict( + antenna_api_base_url="http://testserver/api/v2", + antenna_api_auth_token="test-token", + antenna_api_batch_size=2, + num_workers=0, + antenna_api_dataloader_mp_context="forkserver", + antenna_api_dataloader_timeout_s=300, + ) + defaults.update(overrides) + return SimpleNamespace(**defaults) + + +class TestDataLoaderHygieneDefaults(TestCase): + """The default config must apply both new knobs.""" + + def test_num_workers_zero_does_not_set_mp_context(self): + """num_workers=0 = no subprocesses, so mp_context must stay None. + + Passing a multiprocessing_context to a num_workers=0 DataLoader is + rejected by PyTorch (the DataLoader raises ValueError). + """ + loader = get_rest_dataloader(job_id=1, settings=_make_settings(num_workers=0)) + assert loader.multiprocessing_context is None + assert loader.timeout == 0 # 0 = no timeout, same as PyTorch default + + def test_num_workers_positive_applies_forkserver_context_by_default(self): + """num_workers > 0 = mp_context must be the forkserver context.""" + loader = get_rest_dataloader(job_id=1, settings=_make_settings(num_workers=1)) + ctx = loader.multiprocessing_context + assert ctx is not None, "DataLoader must have an explicit multiprocessing context" + # multiprocessing.get_context returns one of the context singletons; check + # its start method matches what we configured. 
+ assert ctx.get_start_method() == "forkserver" + + def test_num_workers_positive_applies_timeout_by_default(self): + loader = get_rest_dataloader(job_id=1, settings=_make_settings(num_workers=1)) + assert loader.timeout == 300 + + +class TestDataLoaderHygieneOverrides(TestCase): + """Operators can override or disable each knob via settings.""" + + def test_mp_context_can_be_overridden_to_spawn(self): + loader = get_rest_dataloader( + job_id=1, + settings=_make_settings(num_workers=1, antenna_api_dataloader_mp_context="spawn"), + ) + assert loader.multiprocessing_context.get_start_method() == "spawn" + + def test_mp_context_empty_string_falls_back_to_pytorch_default(self): + """Empty string = let PyTorch pick (the historical pre-fix behavior). + + Operators who need the old `fork` behavior on a specific host can + set this without a code change. + """ + loader = get_rest_dataloader( + job_id=1, + settings=_make_settings(num_workers=1, antenna_api_dataloader_mp_context=""), + ) + assert loader.multiprocessing_context is None + + def test_timeout_zero_disables_the_guard(self): + loader = get_rest_dataloader( + job_id=1, + settings=_make_settings(num_workers=1, antenna_api_dataloader_timeout_s=0), + ) + assert loader.timeout == 0 + + def test_invalid_mp_context_raises(self): + """Typoed values get caught up front instead of producing a confusing + torch error later.""" + with pytest.raises(ValueError, match="antenna_api_dataloader_mp_context"): + get_rest_dataloader( + job_id=1, + settings=_make_settings( + num_workers=1, antenna_api_dataloader_mp_context="not-a-real-method" + ), + ) + + +class TestDataLoaderHygieneBackwardsCompat(TestCase): + """A Settings object without the new fields must still work. + + Older deploys / test fixtures may not have the new fields yet; we use + getattr() with sensible defaults so the worker can keep running. 
+ """ + + def test_missing_fields_use_defaults(self): + bare = SimpleNamespace( + antenna_api_base_url="http://testserver/api/v2", + antenna_api_auth_token="test-token", + antenna_api_batch_size=2, + num_workers=1, + ) + loader = get_rest_dataloader(job_id=1, settings=bare) + # Defaults: forkserver context + 300s timeout + assert loader.multiprocessing_context is not None + assert loader.multiprocessing_context.get_start_method() == "forkserver" + assert loader.timeout == 300 diff --git a/trapdata/antenna/worker.py b/trapdata/antenna/worker.py index 2b7e1db..c62cb40 100644 --- a/trapdata/antenna/worker.py +++ b/trapdata/antenna/worker.py @@ -3,6 +3,7 @@ from __future__ import annotations import datetime +import gc import time from collections.abc import Callable @@ -539,3 +540,23 @@ def _process_job( finally: if result_poster: result_poster.shutdown() + # Explicit DataLoader teardown — defense against DataLoader + # subprocess shared-memory cleanup races (#140) and the per-batch + # pipe-FD accumulation observed in #145. Dropping the loader + # reference here forces PyTorch to join its subprocesses + # immediately instead of relying on Python's garbage collector to + # run __del__ at some indeterminate later time. gc.collect() + # picks up the cycle the loader keeps internally; empty_cache() + # releases unused cached CUDA memory back to the driver so memory + # use can fall back toward baseline between jobs. 
+ try: + del loader + except UnboundLocalError: + pass + try: + del batch_source + except UnboundLocalError: + pass + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() diff --git a/trapdata/settings.py b/trapdata/settings.py index b07e043..f1a7d2f 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -43,6 +43,19 @@ class Settings(BaseSettings): antenna_service_name: str = "AMI Data Companion" antenna_api_batch_size: int = 24 + # DataLoader subprocess hygiene (see RolnickLab/ami-data-companion#140, #145) + # multiprocessing context: "forkserver" (default) avoids the parent + # heap-inheritance behavior of "fork" that lets stale CUDA / pinned-memory + # state leak into DataLoader subprocesses as shared memory. Allowed values + # are "fork", "spawn", "forkserver", or "" (let PyTorch pick its default). + antenna_api_dataloader_mp_context: str = "forkserver" + # Per-batch timeout in seconds: if a DataLoader subprocess wedges on a + # shared-memory cleanup race (#140), the main thread would otherwise wait + # forever. A non-zero timeout converts that silent hang into a RuntimeError + # that the supervisor can restart the worker from. 300s is well above any + # legitimate per-batch wait we have measured. + antenna_api_dataloader_timeout_s: int = 300 + @pydantic.field_validator("image_base_path", "user_data_path") def validate_path(cls, v): """