Skip to content
62 changes: 52 additions & 10 deletions ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -908,10 +908,15 @@ def check_stale_jobs(minutes: int | None = None, dry_run: bool = False) -> list[
healthy. Default cutoff is :attr:`Job.STALLED_JOBS_MAX_MINUTES`.

For each stale job, checks Celery for a terminal task status. REVOKED is
always trusted. For async_api jobs, SUCCESS and FAILURE are only accepted
when job.progress.is_complete() — NATS workers may still be delivering
results after the Celery task finishes. All other cases result in revocation.
Async resources (NATS/Redis) are cleaned up in both branches.
always trusted. For async_api jobs, only SUCCESS is fast-pathed to a
terminal status, and only when AsyncJobStateManager.all_tasks_processed()
reports True (i.e. the Redis pending sets are drained); when Redis state is
unavailable it falls back to job.progress.is_complete(). Celery FAILURE is
not trusted for async_api jobs — update_job_failure() defers the terminal
outcome to the async result handler (FAILURE_THRESHOLD logic), so a stale
FAILED async_api job is revoked rather than forced to FAILURE. All other
cases result in revocation. Async resources (NATS/Redis) are cleaned up in
both branches.

Returns a list of dicts describing what was done to each job.
"""
Expand Down Expand Up @@ -960,20 +965,54 @@ def check_stale_jobs(minutes: int | None = None, dry_run: bool = False) -> list[
)
# Treat as unknown state — job will be revoked below.

# Only trust terminal Celery states. For async_api jobs, SUCCESS and
# FAILURE are only accepted when progress is complete — NATS workers may
# still be delivering results after the Celery task finishes.
# Only trust terminal Celery states. For async_api jobs, a SUCCESS
# Celery state is only accepted when all NATS tasks are processed —
# workers may still be delivering results after the Celery task
# finishes. Consult Redis (source of truth for SREM completeness)
# directly rather than Job.progress.is_complete(): that method reads
# the Job.progress JSONB blob, which has been racy under concurrent
# _update_job_progress writes since #1261.
#
# Celery FAILURE is deliberately NOT fast-pathed to a terminal
# status here: update_job_failure() defers post-queue run_job
# failures for async_api jobs to the async result handler, which
# decides the terminal outcome from the final processed/failed
# counts against FAILURE_THRESHOLD (a drained-but-failed Celery task
# can still resolve to SUCCESS). Trusting Celery FAILURE here would
# force the job to FAILURE and bypass that threshold logic, so a
# stale async_api job whose Celery task ended FAILURE falls through
# to the revoke branch instead.
is_terminal = celery_state in states.READY_STATES
is_async_api = job.dispatch_mode == JobDispatchMode.ASYNC_API
if is_async_api and celery_state in {states.SUCCESS, states.FAILURE} and not job.progress.is_complete():
if is_async_api and celery_state == states.SUCCESS:
processed = AsyncJobStateManager(job.pk).all_tasks_processed()
if processed is False:
Comment on lines 985 to +989
is_terminal = False
elif processed is None:
logger.warning(
"Reaper for job %s: Redis state unavailable, falling back to " "progress.is_complete()",
job.pk,
)
if not job.progress.is_complete():
is_terminal = False
# processed is True -> trust Celery SUCCESS
elif is_async_api and celery_state == states.FAILURE:
# Don't treat Celery FAILURE as authoritative for async_api jobs
# (see comment above); revoke instead of forcing FAILURE.
is_terminal = False

previous_status = job.status
if is_terminal:
if not dry_run:
job.update_status(celery_state, save=False)
job.finished_at = datetime.datetime.now()
job.save()
# Narrow the write to the fields the reaper actually mutates.
# A full save() would re-write the whole row from the snapshot
# fetched at select_for_update() time, clobbering `logs` and
# `progress.errors` that a concurrent _update_job_progress (no
# row lock since #1261) may commit. update_status() touches
# status + progress.summary.status, so `progress` is included.
job.save(update_fields=["status", "progress", "finished_at", "updated_at"])
else:
# Per-job diagnostic: surface enough state at revoke time that an
# operator can answer "why was this stalled?" without grepping
Expand All @@ -995,7 +1034,10 @@ def check_stale_jobs(minutes: int | None = None, dry_run: bool = False) -> list[
if not dry_run:
job.update_status(JobState.REVOKED, save=False)
job.finished_at = datetime.datetime.now()
job.save()
# See note on the terminal branch above: narrow the write so
# the reaper doesn't clobber a concurrent progress writer's
# `logs` / `progress.errors`.
job.save(update_fields=["status", "progress", "finished_at", "updated_at"])

# Async resource cleanup runs outside the transaction — it makes network
# calls (NATS/Redis) that should not hold the DB row lock.
Expand Down
230 changes: 230 additions & 0 deletions ami/jobs/tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,3 +1229,233 @@ def test_jobs_health_check_runs_lost_images_before_stale_jobs(self, mock_manager

job.refresh_from_db()
self.assertEqual(job.status, JobState.SUCCESS.value)


class TestCheckStaleJobsReaperGuard(TransactionTestCase):
    """Pin down how ``check_stale_jobs`` resolves a stale job from its Celery state.

    Scenarios for async_api jobs:

    * Celery SUCCESS, Redis pending sets drained, ``Job.progress`` still
      showing mid-flight values -> job lands SUCCESS.
    * Celery SUCCESS, Redis pending sets still populated -> job lands REVOKED.
    * Celery SUCCESS, no Redis state at all -> the outcome follows
      ``progress.is_complete()`` and a "Redis state unavailable" WARNING is
      emitted on the ``ami.jobs.tasks`` logger.
    * Celery FAILURE, Redis drained -> job lands REVOKED, never FAILURE.

    A sync_api job with Celery SUCCESS and no Redis state must land SUCCESS.

    Background, as recorded by the authors: the earlier guard read
    ``Job.progress.is_complete()``, which is racy under concurrent
    ``_update_job_progress`` writes since #1261 dropped ``select_for_update``;
    a production job once landed REVOKED with NATS+Redis fully drained because
    a slower committer clobbered the SUCCESS write.
    """

    def setUp(self):
        # Start from an empty cache so state left by another test cannot leak in.
        cache.clear()
        self.project = Project.objects.create(name="Reaper Guard Project")
        self.pipeline = Pipeline.objects.create(name="Reaper Pipeline", slug="reaper-pipeline")
        self.pipeline.projects.add(self.project)
        self.collection = SourceImageCollection.objects.create(name="Reaper Coll", project=self.project)

    def tearDown(self):
        cache.clear()

    def _started_job(self, *, name: str, dispatch_mode, task_id: str) -> Job:
        """Create an ML job for the fixture project/pipeline/collection, give it
        ``task_id`` and persist it in the STARTED state."""
        new_job = Job.objects.create(
            job_type_key=MLJob.key,
            project=self.project,
            name=name,
            pipeline=self.pipeline,
            source_image_collection=self.collection,
            dispatch_mode=dispatch_mode,
        )
        new_job.task_id = task_id
        new_job.update_status(JobState.STARTED, save=True)
        return new_job

    def _stale_async_job(self, *, task_id: str = "reaper-task") -> Job:
        """Return a STARTED async_api job. It is not stale yet: callers must
        finish any model writes and then call ``_mark_stale``."""
        return self._started_job(
            name="reaper async job",
            dispatch_mode=JobDispatchMode.ASYNC_API,
            task_id=task_id,
        )

    def _mark_stale(self, job: Job) -> None:
        """Backdate ``updated_at`` to just beyond STALLED_JOBS_MAX_MINUTES.

        Uses a queryset ``update()`` instead of ``Job.save()`` so save-time
        side effects (auto_now) cannot overwrite the backdated value. Must be
        the last helper called on the job before the reaper runs.
        """
        stale_age = datetime.timedelta(minutes=Job.STALLED_JOBS_MAX_MINUTES + 1)
        Job.objects.filter(pk=job.pk).update(updated_at=datetime.datetime.now() - stale_age)
        job.refresh_from_db()

    @staticmethod
    def _drain_redis_state(job: Job) -> None:
        """Initialise Redis state for ten image ids, then mark every id as
        handled in both the ``process`` and ``results`` stages."""
        image_ids = [str(i) for i in range(10)]
        state = AsyncJobStateManager(job.pk)
        state.initialize_job(image_ids)
        for stage_name in ("process", "results"):
            state.update_state(set(image_ids), stage=stage_name)

    def _set_progress_clobbered(self, job: Job, total: int, processed: int) -> None:
        """Persist a mid-flight ``Job.progress``: ``collect`` finished, while
        ``process`` and ``results`` sit at ``processed / total`` with status
        STARTED. Reproduces the progress shape of the "2521" incident
        (presumably the id of the affected production job), where Redis was
        in fact fully drained."""
        fraction = processed / total
        job_progress = job.progress

        collect_stage = job_progress.get_stage("collect")
        collect_stage.progress = 1.0
        collect_stage.status = JobState.SUCCESS

        job_progress.update_stage(
            "process",
            progress=fraction,
            status=JobState.STARTED,
            processed=processed,
            remaining=total - processed,
            failed=0,
        )
        job_progress.update_stage(
            "results",
            progress=fraction,
            status=JobState.STARTED,
            captures=processed,
        )
        job.save()

    def _set_progress_complete(self, job: Job) -> None:
        """Persist a ``Job.progress`` whose three stages are all at 100% / SUCCESS."""
        job_progress = job.progress
        for stage_key in ("collect", "process", "results"):
            finished = job_progress.get_stage(stage_key)
            finished.progress = 1.0
            finished.status = JobState.SUCCESS
        job.save()

    @patch("celery.result.AsyncResult")
    def test_async_celery_success_redis_empty_progress_clobbered_lands_success(self, mock_async_result):
        """The "2521" scenario: Redis is drained but the stored progress looks
        mid-flight. all_tasks_processed() reports True, so Celery SUCCESS must
        be honoured instead of the job being revoked."""
        from celery import states as celery_states

        from ami.jobs.tasks import check_stale_jobs

        job = self._stale_async_job()
        self._drain_redis_state(job)
        # Clobber: the stored progress claims 9/10 although Redis is empty.
        self._set_progress_clobbered(job, total=10, processed=9)
        self._mark_stale(job)
        mock_async_result.return_value.state = celery_states.SUCCESS

        check_stale_jobs()

        job.refresh_from_db()
        self.assertEqual(
            job.status,
            JobState.SUCCESS.value,
            f"clobbered progress should not block SUCCESS when Redis says drained; got {job.status}",
        )

    @patch("celery.result.AsyncResult")
    def test_async_celery_success_redis_pending_lands_revoked(self, mock_async_result):
        """Pending ids remain in Redis, so the work is genuinely unfinished and
        the reaper must revoke despite Celery SUCCESS."""
        from celery import states as celery_states

        from ami.jobs.tasks import check_stale_jobs

        job = self._stale_async_job()
        # Initialise only, with no update_state calls: the pending sets stay full.
        AsyncJobStateManager(job.pk).initialize_job([str(i) for i in range(10)])
        self._mark_stale(job)
        mock_async_result.return_value.state = celery_states.SUCCESS

        check_stale_jobs()

        job.refresh_from_db()
        self.assertEqual(job.status, JobState.REVOKED.value)

    @patch("celery.result.AsyncResult")
    def test_async_celery_success_redis_absent_progress_complete_lands_success(self, mock_async_result):
        """With no Redis state the reaper falls back to progress.is_complete();
        complete progress yields SUCCESS, accompanied by a WARNING."""
        from celery import states as celery_states

        from ami.jobs.tasks import check_stale_jobs

        job = self._stale_async_job()
        # Redis state is deliberately never initialised for this job.
        self._set_progress_complete(job)
        self._mark_stale(job)
        mock_async_result.return_value.state = celery_states.SUCCESS

        with self.assertLogs("ami.jobs.tasks", level="WARNING") as captured:
            check_stale_jobs()

        job.refresh_from_db()
        self.assertEqual(job.status, JobState.SUCCESS.value)
        self.assertTrue(any("Redis state unavailable" in m for m in captured.output))

    @patch("celery.result.AsyncResult")
    def test_async_celery_success_redis_absent_progress_incomplete_lands_revoked(self, mock_async_result):
        """With no Redis state and incomplete progress the fallback revokes the
        job, again logging a WARNING."""
        from celery import states as celery_states

        from ami.jobs.tasks import check_stale_jobs

        job = self._stale_async_job()
        # No Redis state, and the freshly created progress is not complete.
        self._mark_stale(job)
        mock_async_result.return_value.state = celery_states.SUCCESS

        with self.assertLogs("ami.jobs.tasks", level="WARNING") as captured:
            check_stale_jobs()

        job.refresh_from_db()
        self.assertEqual(job.status, JobState.REVOKED.value)
        self.assertTrue(any("Redis state unavailable" in m for m in captured.output))

    @patch("celery.result.AsyncResult")
    def test_sync_api_celery_success_lands_success_without_redis_check(self, mock_async_result):
        """sync_api jobs bypass the Redis guard: Celery's terminal state is
        authoritative. No Redis state exists here, so the job would be REVOKED
        if the async_api path leaked into sync_api handling."""
        from celery import states as celery_states

        from ami.jobs.tasks import check_stale_jobs

        job = self._started_job(
            name="reaper sync job",
            dispatch_mode=JobDispatchMode.SYNC_API,
            task_id="reaper-sync-task",
        )
        self._mark_stale(job)
        mock_async_result.return_value.state = celery_states.SUCCESS

        check_stale_jobs()

        job.refresh_from_db()
        self.assertEqual(job.status, JobState.SUCCESS.value)

    @patch("celery.result.AsyncResult")
    def test_async_celery_failure_redis_drained_lands_revoked_not_failure(self, mock_async_result):
        """Celery FAILURE on an async_api job is never promoted to a terminal
        FAILURE by the reaper, even when Redis reports the pending sets drained.

        Per the authors, update_job_failure() defers post-queue run_job
        failures for async_api jobs to the async result handler, which settles
        the outcome from the final processed/failed counts against
        FAILURE_THRESHOLD (a drained-but-failed task can still resolve to
        SUCCESS). The reaper must not pre-empt that decision, so the stale job
        ends up REVOKED.
        """
        from celery import states as celery_states

        from ami.jobs.tasks import check_stale_jobs

        job = self._stale_async_job(task_id="reaper-failure-task")
        self._drain_redis_state(job)
        self._mark_stale(job)
        mock_async_result.return_value.state = celery_states.FAILURE

        check_stale_jobs()

        job.refresh_from_db()
        self.assertEqual(
            job.status,
            JobState.REVOKED.value,
            f"Celery FAILURE must not be forced to terminal FAILURE for async_api; got {job.status}",
        )
30 changes: 30 additions & 0 deletions ami/ml/orchestration/async_job_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,6 +228,36 @@ def get_pending_image_ids(self) -> set[str]:
return set()
return {m.decode() if isinstance(m, (bytes, bytearray)) else str(m) for m in members}

def all_tasks_processed(self) -> bool | None:
"""Tri-state truth signal for NATS-task SREM completeness across both
process and results pending sets.

True — both pending sets empty AND total > 0 (or total == 0)
False — at least one pending set has members
None — Redis state absent (cleaned up, expired, never initialized,
or transient RedisError)

Scope: tracks NATS task lifecycle only; does not know about `collect`
or any future post-results stages.
"""
try:
redis = self._get_redis()
with redis.pipeline() as pipe:
for stage in self.STAGES:
pipe.scard(self._get_pending_key(stage))
pipe.get(self._total_key)
results = pipe.execute()
except RedisError as e:
logger.warning(f"Redis error reading all_tasks_processed for job {self.job_id}: {e}")
return None
Comment on lines +231 to +252

*pending_counts, total_raw = results
if total_raw is None:
return None
if int(total_raw) == 0:
return True
return all(count == 0 for count in pending_counts)

def cleanup(self) -> None:
"""
Delete all Redis keys associated with this job.
Expand Down
Loading
Loading