Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions .env.example
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,15 @@ AMI_ANTENNA_API_BASE_URL=http://localhost:8000/api/v2
AMI_ANTENNA_API_AUTH_TOKEN=your_antenna_auth_token_here
AMI_ANTENNA_API_BATCH_SIZE=4
AMI_ANTENNA_SERVICE_NAME="AMI Data Companion"

# DataLoader subprocess hygiene (see ami-data-companion#140, #145)
# multiprocessing context to start DataLoader workers from.
# Recommended: "forkserver" — avoids inheriting the parent process heap
# (which on production GPU workers carries large CUDA / pinned-memory
# state and leaks into shared memory). Allowed: fork | spawn | forkserver
# Set to an empty string to let PyTorch use its default ("fork" on Linux).
AMI_ANTENNA_API_DATALOADER_MP_CONTEXT=forkserver
# Per-batch DataLoader timeout in seconds. Converts a silent
# subprocess hang (the original #140 failure mode) into a RuntimeError
# that the supervisor can restart from. Set to 0 to disable.
AMI_ANTENNA_API_DATALOADER_TIMEOUT_S=300
42 changes: 42 additions & 0 deletions trapdata/antenna/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@

import requests
import torch
import torch.multiprocessing as mp
import torch.utils.data
import torchvision
from PIL import Image
Expand Down Expand Up @@ -417,12 +418,21 @@ def get_rest_dataloader(
DataLoader num_workers > 0 is safe here because Antenna dequeues
tasks atomically — each worker subprocess gets a unique set of tasks.

Subprocess hygiene (see #140, #145): we explicitly set
multiprocessing_context (default ``forkserver``) and a non-zero
``timeout`` to keep the DataLoader from inheriting stale parent state
via ``fork`` and from hanging forever when a subprocess dies
mid-shared-memory-cleanup. Both knobs are env-configurable so a single
bad pipeline can fall back without touching the others.

Args:
job_id: Job ID to fetch tasks for
settings: Settings object. Relevant fields:
- antenna_api_base_url / antenna_api_auth_token
- antenna_api_batch_size (tasks per API call and GPU batch size)
- num_workers (DataLoader subprocesses)
- antenna_api_dataloader_mp_context (fork / spawn / forkserver / "")
- antenna_api_dataloader_timeout_s (per-batch timeout, seconds)
"""
dataset = RESTDataset(
base_url=settings.antenna_api_base_url,
Expand All @@ -431,6 +441,36 @@ def get_rest_dataloader(
batch_size=settings.antenna_api_batch_size,
)

# Resolve multiprocessing context. Empty string / unset means "let
# PyTorch choose" (today: "fork" on Linux), which is the historical
# behavior. Anything else must be one of the multiprocessing start
# methods supported by the host. Only relevant when num_workers > 0
# because num_workers=0 runs the dataset inline in the main process.
# We query mp.get_all_start_methods() (rather than hardcoding the
# ("fork", "spawn", "forkserver") tuple) so that the validation matches
# what the running interpreter / platform actually supports — e.g. macOS
# under Python 3.13+ no longer lists "fork" by default.
mp_context = (
getattr(settings, "antenna_api_dataloader_mp_context", "forkserver") or ""
).strip().lower()
if settings.num_workers > 0 and mp_context:
allowed = set(mp.get_all_start_methods())
if mp_context not in allowed:
raise ValueError(
f"antenna_api_dataloader_mp_context must be one of "
f"{sorted(allowed)!r} or empty; got {mp_context!r}"
)
dataloader_mp_context: object | None = mp.get_context(mp_context)
else:
dataloader_mp_context = None

# Per-batch timeout. PyTorch interprets 0 as "wait forever" (the
# behavior that caused the silent hangs in #140); we treat negative or
# zero as "disable the guard" so an operator can opt out if needed,
# but the default is a 5-minute ceiling.
timeout_s = int(getattr(settings, "antenna_api_dataloader_timeout_s", 300) or 0)
dataloader_timeout = timeout_s if timeout_s > 0 and settings.num_workers > 0 else 0

return torch.utils.data.DataLoader(
dataset,
batch_size=1, # We collate manually in rest_collate_fn, so set batch_size=1 here
Expand All @@ -439,6 +479,8 @@ def get_rest_dataloader(
pin_memory=True,
persistent_workers=settings.num_workers > 0,
prefetch_factor=4 if settings.num_workers > 0 else None,
multiprocessing_context=dataloader_mp_context,
timeout=dataloader_timeout,
)


Expand Down
128 changes: 128 additions & 0 deletions trapdata/antenna/tests/test_dataloader_hygiene.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
"""Regression test for DataLoader subprocess-hygiene settings.

This test guards the fix shipped for ami-data-companion#140 / #145:
``get_rest_dataloader()`` must apply the ``multiprocessing_context`` and
``timeout`` knobs from settings. Without these, the DataLoader inherits the
parent's heap via ``fork`` (leaking CUDA / pinned-memory state into shared
memory) and silently hangs forever when a subprocess dies mid-batch.

The RSS-growth side of the regression is already covered by
``test_memory_leak.py``; this file focuses on the *configuration* surface so
the fix can't be silently regressed by a future refactor of
``get_rest_dataloader``.
"""

from types import SimpleNamespace
from unittest import TestCase

import pytest

from trapdata.antenna.datasets import get_rest_dataloader


def _make_settings(**overrides) -> SimpleNamespace:
"""Build a minimal duck-typed Settings object.

Using SimpleNamespace (not MagicMock) so that attribute access on a
*missing* field raises AttributeError — that lets the test catch a
typo'd setting name on the production side rather than swallowing it.
"""
defaults = dict(
antenna_api_base_url="http://testserver/api/v2",
antenna_api_auth_token="test-token",
antenna_api_batch_size=2,
num_workers=0,
antenna_api_dataloader_mp_context="forkserver",
antenna_api_dataloader_timeout_s=300,
)
defaults.update(overrides)
return SimpleNamespace(**defaults)


class TestDataLoaderHygieneDefaults(TestCase):
"""The default config must apply both new knobs."""

def test_num_workers_zero_does_not_set_mp_context(self):
"""num_workers=0 = no subprocesses, so mp_context must stay None.

Setting multiprocessing_context on a num_workers=0 DataLoader is a
no-op at best and a TypeError in some torch versions.
"""
loader = get_rest_dataloader(job_id=1, settings=_make_settings(num_workers=0))
assert loader.multiprocessing_context is None
assert loader.timeout == 0 # 0 = no timeout, same as PyTorch default

def test_num_workers_positive_applies_forkserver_context_by_default(self):
"""num_workers > 0 = mp_context must be the forkserver context."""
loader = get_rest_dataloader(job_id=1, settings=_make_settings(num_workers=1))
ctx = loader.multiprocessing_context
assert ctx is not None, "DataLoader must have an explicit multiprocessing context"
# multiprocessing.get_context returns one of the context singletons; check
# its start method matches what we configured.
assert ctx.get_start_method() == "forkserver"

def test_num_workers_positive_applies_timeout_by_default(self):
loader = get_rest_dataloader(job_id=1, settings=_make_settings(num_workers=1))
assert loader.timeout == 300


class TestDataLoaderHygieneOverrides(TestCase):
"""Operators can override or disable each knob via settings."""

def test_mp_context_can_be_overridden_to_spawn(self):
loader = get_rest_dataloader(
job_id=1,
settings=_make_settings(num_workers=1, antenna_api_dataloader_mp_context="spawn"),
)
assert loader.multiprocessing_context.get_start_method() == "spawn"

def test_mp_context_empty_string_falls_back_to_pytorch_default(self):
"""Empty string = let PyTorch pick (the historical pre-fix behavior).

Operators who need the old `fork` behavior on a specific host can
set this without a code change.
"""
loader = get_rest_dataloader(
job_id=1,
settings=_make_settings(num_workers=1, antenna_api_dataloader_mp_context=""),
)
assert loader.multiprocessing_context is None

def test_timeout_zero_disables_the_guard(self):
loader = get_rest_dataloader(
job_id=1,
settings=_make_settings(num_workers=1, antenna_api_dataloader_timeout_s=0),
)
assert loader.timeout == 0

def test_invalid_mp_context_raises(self):
"""Typoed values get caught up front instead of producing a confusing
torch error later."""
with pytest.raises(ValueError, match="antenna_api_dataloader_mp_context"):
get_rest_dataloader(
job_id=1,
settings=_make_settings(
num_workers=1, antenna_api_dataloader_mp_context="not-a-real-method"
),
)


class TestDataLoaderHygieneBackwardsCompat(TestCase):
"""A Settings object without the new fields must still work.

Older deploys / test fixtures may not have the new fields yet; we use
getattr() with sensible defaults so the worker can keep running.
"""

def test_missing_fields_use_defaults(self):
bare = SimpleNamespace(
antenna_api_base_url="http://testserver/api/v2",
antenna_api_auth_token="test-token",
antenna_api_batch_size=2,
num_workers=1,
)
loader = get_rest_dataloader(job_id=1, settings=bare)
# Defaults: forkserver context + 300s timeout
assert loader.multiprocessing_context is not None
assert loader.multiprocessing_context.get_start_method() == "forkserver"
assert loader.timeout == 300
21 changes: 21 additions & 0 deletions trapdata/antenna/worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import datetime
import gc
import time
from collections.abc import Callable

Expand Down Expand Up @@ -539,3 +540,23 @@ def _process_job(
finally:
if result_poster:
result_poster.shutdown()
# Explicit DataLoader teardown — defense against DataLoader
# subprocess shared-memory cleanup races (#140) and the per-batch
# pipe-FD accumulation observed in #145. Dropping the loader
# reference here forces PyTorch to join its subprocesses
# immediately instead of relying on Python's garbage collector to
# run __del__ at some indeterminate later time. gc.collect()
# picks up the cycle the loader keeps internally; empty_cache()
# returns CUDA buffers to the allocator pool so RSS / shmem-rss
# can fall back toward baseline between jobs.
try:
del loader
except UnboundLocalError:
pass
try:
del batch_source
except UnboundLocalError:
pass
gc.collect()
if torch.cuda.is_available():
torch.cuda.empty_cache()
13 changes: 13 additions & 0 deletions trapdata/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,19 @@ class Settings(BaseSettings):
antenna_service_name: str = "AMI Data Companion"
antenna_api_batch_size: int = 24

# DataLoader subprocess hygiene (see RolnickLab/ami-data-companion#140, #145)
# multiprocessing context: "forkserver" (default) avoids the parent
# heap-inheritance behavior of "fork" that lets stale CUDA / pinned-memory
# state leak into DataLoader subprocesses as shared memory. Allowed values
# are "fork", "spawn", "forkserver", or "" (let PyTorch pick its default).
antenna_api_dataloader_mp_context: str = "forkserver"
# Per-batch timeout in seconds: if a DataLoader subprocess wedges on a
# shared-memory cleanup race (#140), the main thread will currently wait
# forever. A non-zero timeout converts that silent hang into a RuntimeError
# that supervisor can restart the worker from. 300s is well above any
# legitimate per-batch wait we have measured.
antenna_api_dataloader_timeout_s: int = 300

@pydantic.field_validator("image_base_path", "user_data_path")
def validate_path(cls, v):
"""
Expand Down
Loading