diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index f3c996f86..f8f199045 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -908,10 +908,15 @@ def check_stale_jobs(minutes: int | None = None, dry_run: bool = False) -> list[ healthy. Default cutoff is :attr:`Job.STALLED_JOBS_MAX_MINUTES`. For each stale job, checks Celery for a terminal task status. REVOKED is - always trusted. For async_api jobs, SUCCESS and FAILURE are only accepted - when job.progress.is_complete() — NATS workers may still be delivering - results after the Celery task finishes. All other cases result in revocation. - Async resources (NATS/Redis) are cleaned up in both branches. + always trusted. For async_api jobs, only SUCCESS is fast-pathed to a + terminal status, and only when AsyncJobStateManager.all_tasks_processed() + reports True (i.e. the Redis pending sets are drained); when Redis state is + unavailable it falls back to job.progress.is_complete(). Celery FAILURE is + not trusted for async_api jobs — update_job_failure() defers the terminal + outcome to the async result handler (FAILURE_THRESHOLD logic), so a stale + FAILED async_api job is revoked rather than forced to FAILURE. All other + cases result in revocation. Async resources (NATS/Redis) are cleaned up in + both branches. Returns a list of dicts describing what was done to each job. """ @@ -960,12 +965,40 @@ def check_stale_jobs(minutes: int | None = None, dry_run: bool = False) -> list[ ) # Treat as unknown state — job will be revoked below. - # Only trust terminal Celery states. For async_api jobs, SUCCESS and - # FAILURE are only accepted when progress is complete — NATS workers may - # still be delivering results after the Celery task finishes. + # Only trust terminal Celery states. For async_api jobs, a SUCCESS + # Celery state is only accepted when all NATS tasks are processed — + # workers may still be delivering results after the Celery task + # finishes. 
Consult Redis (source of truth for SREM completeness) + # directly rather than Job.progress.is_complete(), which mirrors a + # JSONB blob racy under concurrent _update_job_progress writes since + # #1261. + # + # Celery FAILURE is deliberately NOT fast-pathed to a terminal + # status here: update_job_failure() defers post-queue run_job + # failures for async_api jobs to the async result handler, which + # decides the terminal outcome from the final processed/failed + # counts against FAILURE_THRESHOLD (a drained-but-failed Celery task + # can still resolve to SUCCESS). Trusting Celery FAILURE here would + # force the job to FAILURE and bypass that threshold logic, so a + # stale async_api job whose Celery task ended FAILURE falls through + # to the revoke branch instead. is_terminal = celery_state in states.READY_STATES is_async_api = job.dispatch_mode == JobDispatchMode.ASYNC_API - if is_async_api and celery_state in {states.SUCCESS, states.FAILURE} and not job.progress.is_complete(): + if is_async_api and celery_state == states.SUCCESS: + processed = AsyncJobStateManager(job.pk).all_tasks_processed() + if processed is False: + is_terminal = False + elif processed is None: + logger.warning( + "Reaper for job %s: Redis state unavailable, falling back to " "progress.is_complete()", + job.pk, + ) + if not job.progress.is_complete(): + is_terminal = False + # processed is True -> trust Celery SUCCESS + elif is_async_api and celery_state == states.FAILURE: + # Don't treat Celery FAILURE as authoritative for async_api jobs + # (see comment above); revoke instead of forcing FAILURE. is_terminal = False previous_status = job.status @@ -973,7 +1006,13 @@ def check_stale_jobs(minutes: int | None = None, dry_run: bool = False) -> list[ if not dry_run: job.update_status(celery_state, save=False) job.finished_at = datetime.datetime.now() - job.save() + # Narrow the write to the fields the reaper actually mutates. 
+ # A full save() would re-write the whole row from the snapshot fetched + # at select_for_update() time, clobbering columns such as `logs` that a + # concurrent _update_job_progress (no row lock since #1261) may commit. + # update_status() touches status and progress.summary.status, so `progress` + # is still written whole; NOTE(review): `progress.errors` is not protected. + job.save(update_fields=["status", "progress", "finished_at", "updated_at"]) else: # Per-job diagnostic: surface enough state at revoke time that an # operator can answer "why was this stalled?" without grepping @@ -995,7 +1034,10 @@ def check_stale_jobs(minutes: int | None = None, dry_run: bool = False) -> list[ if not dry_run: job.update_status(JobState.REVOKED, save=False) job.finished_at = datetime.datetime.now() - job.save() + # See note on the terminal branch above: narrow the write so + # the reaper doesn't clobber a concurrent progress writer's `logs`; + # `progress` itself is still written whole. + job.save(update_fields=["status", "progress", "finished_at", "updated_at"]) # Async resource cleanup runs outside the transaction — it makes network # calls (NATS/Redis) that should not hold the DB row lock. diff --git a/ami/jobs/tests/test_tasks.py b/ami/jobs/tests/test_tasks.py index 89587f3f1..343d54d8a 100644 --- a/ami/jobs/tests/test_tasks.py +++ b/ami/jobs/tests/test_tasks.py @@ -1229,3 +1229,233 @@ def test_jobs_health_check_runs_lost_images_before_stale_jobs(self, mock_manager job.refresh_from_db() self.assertEqual(job.status, JobState.SUCCESS.value) + + +class TestCheckStaleJobsReaperGuard(TransactionTestCase): + """Reaper guard for async_api jobs: a SUCCESS Celery state is only accepted + when AsyncJobStateManager.all_tasks_processed() reports True. The earlier + guard read Job.progress.is_complete() — racy under concurrent + _update_job_progress writes since #1261 dropped select_for_update. A + production job once landed REVOKED with NATS+Redis fully drained because a + slower committer clobbered the SUCCESS write. 
This class verifies the new + Redis-direct path, the unavailable-state fallback to progress.is_complete() + with a WARNING, that Celery FAILURE is not fast-pathed to a terminal + FAILURE for async_api jobs, and that sync_api jobs are unaffected. + """ + + def setUp(self): + cache.clear() + self.project = Project.objects.create(name="Reaper Guard Project") + self.pipeline = Pipeline.objects.create(name="Reaper Pipeline", slug="reaper-pipeline") + self.pipeline.projects.add(self.project) + self.collection = SourceImageCollection.objects.create(name="Reaper Coll", project=self.project) + + def tearDown(self): + cache.clear() + + def _stale_async_job(self, *, task_id: str = "reaper-task") -> Job: + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="reaper async job", + pipeline=self.pipeline, + source_image_collection=self.collection, + dispatch_mode=JobDispatchMode.ASYNC_API, + ) + job.task_id = task_id + job.update_status(JobState.STARTED, save=True) + return job + + def _mark_stale(self, job: Job) -> None: + """Push updated_at back past STALLED_JOBS_MAX_MINUTES via raw update so + Job.save() side-effects (auto_now) don't undo it. 
Call AFTER any helper + that touches the job model.""" + Job.objects.filter(pk=job.pk).update( + updated_at=datetime.datetime.now() - datetime.timedelta(minutes=Job.STALLED_JOBS_MAX_MINUTES + 1) + ) + job.refresh_from_db() + + def _set_progress_clobbered(self, job: Job, total: int, processed: int) -> None: + """Mimic the 2521 Job.progress shape: process stage at processed/total + with status STARTED, even though Redis was actually fully drained.""" + progress = job.progress + collect = progress.get_stage("collect") + collect.progress = 1.0 + collect.status = JobState.SUCCESS + progress.update_stage( + "process", + progress=processed / total, + status=JobState.STARTED, + processed=processed, + remaining=total - processed, + failed=0, + ) + progress.update_stage( + "results", + progress=processed / total, + status=JobState.STARTED, + captures=processed, + ) + job.save() + + def _set_progress_complete(self, job: Job) -> None: + progress = job.progress + for key in ("collect", "process", "results"): + stage = progress.get_stage(key) + stage.progress = 1.0 + stage.status = JobState.SUCCESS + job.save() + + @patch("celery.result.AsyncResult") + def test_async_celery_success_redis_empty_progress_clobbered_lands_success(self, mock_async_result): + """The 2521 case. Pre-fix, this came back REVOKED because the reaper + consulted progress.is_complete() (False, due to clobber). Post-fix, + Redis says all_tasks_processed() → True, so SUCCESS is honored.""" + from ami.jobs.tasks import check_stale_jobs + + job = self._stale_async_job() + ids = [str(i) for i in range(10)] + manager = AsyncJobStateManager(job.pk) + manager.initialize_job(ids) + manager.update_state(set(ids), stage="process") + manager.update_state(set(ids), stage="results") + # Clobber: progress shows mid-flight even though Redis is drained. 
+ self._set_progress_clobbered(job, total=10, processed=9) + self._mark_stale(job) + + from celery import states as celery_states + + mock_async_result.return_value.state = celery_states.SUCCESS + + check_stale_jobs() + + job.refresh_from_db() + self.assertEqual( + job.status, + JobState.SUCCESS.value, + f"clobbered progress should not block SUCCESS when Redis says drained; got {job.status}", + ) + + @patch("celery.result.AsyncResult") + def test_async_celery_success_redis_pending_lands_revoked(self, mock_async_result): + """Redis still has pending ids → genuine in-flight; reaper revokes.""" + from ami.jobs.tasks import check_stale_jobs + + job = self._stale_async_job() + ids = [str(i) for i in range(10)] + AsyncJobStateManager(job.pk).initialize_job(ids) + # No SREMs — pending sets still full. + self._mark_stale(job) + + from celery import states as celery_states + + mock_async_result.return_value.state = celery_states.SUCCESS + + check_stale_jobs() + + job.refresh_from_db() + self.assertEqual(job.status, JobState.REVOKED.value) + + @patch("celery.result.AsyncResult") + def test_async_celery_success_redis_absent_progress_complete_lands_success(self, mock_async_result): + """Redis state is gone (cleaned up / never initialized). Reaper falls + back to progress.is_complete(). Complete progress → SUCCESS + WARNING.""" + from ami.jobs.tasks import check_stale_jobs + + job = self._stale_async_job() + # Don't initialize Redis state. 
+ self._set_progress_complete(job) + self._mark_stale(job) + + from celery import states as celery_states + + mock_async_result.return_value.state = celery_states.SUCCESS + + with self.assertLogs("ami.jobs.tasks", level="WARNING") as cm: + check_stale_jobs() + + job.refresh_from_db() + self.assertEqual(job.status, JobState.SUCCESS.value) + self.assertTrue(any("Redis state unavailable" in m for m in cm.output)) + + @patch("celery.result.AsyncResult") + def test_async_celery_success_redis_absent_progress_incomplete_lands_revoked(self, mock_async_result): + """Redis absent + progress incomplete → REVOKED via fallback + WARNING.""" + from ami.jobs.tasks import check_stale_jobs + + job = self._stale_async_job() + # Don't initialize Redis. Default progress is fresh (not complete). + self._mark_stale(job) + + from celery import states as celery_states + + mock_async_result.return_value.state = celery_states.SUCCESS + + with self.assertLogs("ami.jobs.tasks", level="WARNING") as cm: + check_stale_jobs() + + job.refresh_from_db() + self.assertEqual(job.status, JobState.REVOKED.value) + self.assertTrue(any("Redis state unavailable" in m for m in cm.output)) + + @patch("celery.result.AsyncResult") + def test_sync_api_celery_success_lands_success_without_redis_check(self, mock_async_result): + """sync_api jobs skip the Redis guard entirely — Celery's terminal + state is authoritative. 
No Redis state initialized; if the new path + leaked into sync_api this would REVOKE.""" + from ami.jobs.tasks import check_stale_jobs + + job = Job.objects.create( + job_type_key=MLJob.key, + project=self.project, + name="reaper sync job", + pipeline=self.pipeline, + source_image_collection=self.collection, + dispatch_mode=JobDispatchMode.SYNC_API, + ) + job.task_id = "reaper-sync-task" + job.update_status(JobState.STARTED, save=True) + self._mark_stale(job) + + from celery import states as celery_states + + mock_async_result.return_value.state = celery_states.SUCCESS + + check_stale_jobs() + + job.refresh_from_db() + self.assertEqual(job.status, JobState.SUCCESS.value) + + @patch("celery.result.AsyncResult") + def test_async_celery_failure_redis_drained_lands_revoked_not_failure(self, mock_async_result): + """A Celery FAILURE on an async_api job is never fast-pathed to a + terminal FAILURE here, even when Redis reports the pending sets drained. + + update_job_failure() deliberately defers post-queue run_job failures for + async_api jobs to the async result handler, which decides the terminal + outcome from the final processed/failed counts against FAILURE_THRESHOLD + (a drained-but-failed task can still resolve to SUCCESS). The reaper must + not pre-empt that by forcing FAILURE, so the stale job is REVOKED instead. 
+ """ + from ami.jobs.tasks import check_stale_jobs + + job = self._stale_async_job(task_id="reaper-failure-task") + ids = [str(i) for i in range(10)] + manager = AsyncJobStateManager(job.pk) + manager.initialize_job(ids) + manager.update_state(set(ids), stage="process") + manager.update_state(set(ids), stage="results") + self._mark_stale(job) + + from celery import states as celery_states + + mock_async_result.return_value.state = celery_states.FAILURE + + check_stale_jobs() + + job.refresh_from_db() + self.assertEqual( + job.status, + JobState.REVOKED.value, + f"Celery FAILURE must not be forced to terminal FAILURE for async_api; got {job.status}", + ) diff --git a/ami/ml/orchestration/async_job_state.py b/ami/ml/orchestration/async_job_state.py index 26e7c3024..4c9502a60 100644 --- a/ami/ml/orchestration/async_job_state.py +++ b/ami/ml/orchestration/async_job_state.py @@ -228,6 +228,36 @@ def get_pending_image_ids(self) -> set[str]: return set() return {m.decode() if isinstance(m, (bytes, bytearray)) else str(m) for m in members} + def all_tasks_processed(self) -> bool | None: + """Tri-state truth signal for NATS-task SREM completeness across both + process and results pending sets. + + True — total == 0, or both pending sets are empty + False — total > 0 and at least one pending set has members + None — total key absent (cleaned up, expired, never initialized), + or a RedisError occurred while reading + + Scope: tracks NATS task lifecycle only; does not know about `collect` + or any future post-results stages. 
+ """ + try: + redis = self._get_redis() + with redis.pipeline() as pipe: + for stage in self.STAGES: + pipe.scard(self._get_pending_key(stage)) + pipe.get(self._total_key) + results = pipe.execute() + except RedisError as e: + logger.warning(f"Redis error reading all_tasks_processed for job {self.job_id}: {e}") + return None + + *pending_counts, total_raw = results + if total_raw is None: + return None + if int(total_raw) == 0: + return True + return all(count == 0 for count in pending_counts) + def cleanup(self) -> None: """ Delete all Redis keys associated with this job. diff --git a/ami/ml/tests.py b/ami/ml/tests.py index ea375135e..5de78aaff 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -1630,6 +1630,60 @@ def test_update_state_returns_none_when_state_genuinely_missing(self): progress = self.manager.update_state({"img1", "img2"}, "process") self.assertIsNone(progress) + def test_all_tasks_processed_fresh_init_returns_false(self): + """Just-initialized job has all images pending in both stages.""" + self._init_and_verify(self.image_ids) + self.assertFalse(self.manager.all_tasks_processed()) + + def test_all_tasks_processed_after_drain_returns_true(self): + """SREM-ing every id from both stages → True.""" + self._init_and_verify(self.image_ids) + self.manager.update_state(set(self.image_ids), "process") + self.manager.update_state(set(self.image_ids), "results") + self.assertTrue(self.manager.all_tasks_processed()) + + def test_all_tasks_processed_partial_stage_returns_false(self): + """Process drained but results still pending → False.""" + self._init_and_verify(self.image_ids) + self.manager.update_state(set(self.image_ids), "process") + # results stage untouched + self.assertFalse(self.manager.all_tasks_processed()) + + def test_all_tasks_processed_zero_total_returns_true(self): + """A zero-images job is trivially complete.""" + self.manager.initialize_job([]) + self.assertTrue(self.manager.all_tasks_processed()) + + def 
test_all_tasks_processed_never_initialized_returns_none(self): + """No init → total key absent → None (caller falls back).""" + # Do NOT call initialize_job + self.assertIsNone(self.manager.all_tasks_processed()) + + def test_all_tasks_processed_after_cleanup_returns_none(self): + """After cleanup() the total key is gone → None.""" + self._init_and_verify(self.image_ids) + self.manager.cleanup() + self.assertIsNone(self.manager.all_tasks_processed()) + + def test_all_tasks_processed_returns_none_on_redis_error(self): + """Transient RedisError → WARNING + None (caller falls back).""" + from unittest.mock import MagicMock, patch + + from redis.exceptions import RedisError + + self._init_and_verify(self.image_ids) + + pipe = MagicMock() + pipe.execute.side_effect = RedisError("Connection reset by peer") + fake_redis = MagicMock() + fake_redis.pipeline.return_value.__enter__.return_value = pipe + + with patch.object(self.manager, "_get_redis", return_value=fake_redis): + with self.assertLogs("ami.ml.orchestration.async_job_state", level="WARNING") as cm: + result = self.manager.all_tasks_processed() + self.assertIsNone(result) + self.assertTrue(any("Redis error reading all_tasks_processed" in m for m in cm.output)) + class TestSaveResultsRefreshesDeploymentCounts(TestCase): """save_results must refresh Deployment cached counts, not just Event counts.