diff --git a/ami/jobs/tasks.py b/ami/jobs/tasks.py index dbf57fa38..90e1aaae4 100644 --- a/ami/jobs/tasks.py +++ b/ami/jobs/tasks.py @@ -146,7 +146,15 @@ def run_job(self, job_id: int) -> None: raise e # self.retry(exc=e, countdown=1, max_retries=1) else: - job.logger.info(f"Running job {job}") + # Log the Redis DB index at task start so cross-host DB-index drift — the misconfig + # that surfaces as silent process_nats_pipeline_result FAILUREs on whichever worker + # reads state from the wrong DB — is visible per job. Every host's "default" Redis + # connection (the one the job state manager uses) must select the same DB number, or + # initialize_job and update_state operate on different DBs. The public job log carries + # only the DB index; the full host:port goes to the server log, since the host names + # internal infrastructure. + job.logger.info(f"Running job {job} on Redis db {_redis_db_index()}") + logger.info("Job %s Redis target: %s", job.pk, _describe_redis_target()) try: job.run() except Exception as e: @@ -281,10 +289,18 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub if not progress_info: # State keys genuinely missing (the total-images key returned None). # Ack so NATS stops redelivering and fail the job — there's no state - # left to reconcile against. + # left to reconcile against. The reason string is built from a live + # Redis snapshot (DB index, keys present under job:{id}:*) so the + # FAILURE log and the UI progress.errors entry name the actual cause + # instead of the previous hardcoded "likely cleaned up concurrently" + # guess — which conflated DB-index misconfig, eviction, and genuine + # concurrent cleanup into a single misleading string. 
def _default_redis_connection_kwargs() -> dict:
    """Return the connection kwargs of the "default" django-redis connection.

    Shared by :func:`_redis_db_index` and :func:`_describe_redis_target` so both read
    the target from the same place. Any exception raised while importing django-redis
    or resolving the connection propagates to the caller, which decides how much of it
    may be exposed.
    """
    # Imported lazily, as the original helpers did, so importing this module does not
    # require django-redis to be importable.
    from django_redis import get_redis_connection

    client = get_redis_connection("default")
    # A pool without ``connection_kwargs`` (or with a falsy value) yields an empty dict,
    # which makes every field below fall back to "?".
    return getattr(client.connection_pool, "connection_kwargs", {}) or {}


def _redis_db_index() -> str:
    """Return just the DB index of the "default" Redis connection (e.g. ``"1"``).

    Safe for the public job log: the result is either the configured DB index, ``"?"``
    when the pool does not report one, or the fixed string ``"(unavailable)"`` when the
    lookup fails. The exception text is deliberately NOT included in the return value,
    because error messages can name the Redis host and port; it is logged through the
    module logger instead. Pair with :func:`_describe_redis_target` when an operator
    needs the host.
    """
    try:
        return str(_default_redis_connection_kwargs().get("db", "?"))
    except Exception as e:
        # Best-effort diagnostic: never let it break the task, and never let the
        # exception message reach the public job log.
        import logging

        logging.getLogger(__name__).warning("Could not determine Redis DB index: %s", e)
        return "(unavailable)"


def _describe_redis_target() -> str:
    """Return a ``redis=host:port/dbN`` string for the "default" Redis connection.

    For server-side logs only: it names the internal Redis host, so it must not reach
    the public job log (use :func:`_redis_db_index` there). On failure it returns
    ``redis=(unavailable: <error>)``; including the error text is acceptable here
    because the string is operator-facing.
    """
    try:
        kwargs = _default_redis_connection_kwargs()
        return f"redis={kwargs.get('host', '?')}:{kwargs.get('port', '?')}/db{kwargs.get('db', '?')}"
    except Exception as e:
        return f"redis=(unavailable: {e})"
+ logger.warning("Job %s: could not append failure reason to progress.errors: %s", job_id, e) job.update_status(JobState.FAILURE, save=False) job.finished_at = datetime.datetime.now() job.save(update_fields=["status", "progress", "finished_at"]) diff --git a/ami/jobs/tests/test_tasks.py b/ami/jobs/tests/test_tasks.py index 8dc99da29..fce7a4f46 100644 --- a/ami/jobs/tests/test_tasks.py +++ b/ami/jobs/tests/test_tasks.py @@ -12,7 +12,7 @@ from unittest.mock import AsyncMock, MagicMock, patch from django.core.cache import cache -from django.test import TransactionTestCase +from django.test import SimpleTestCase, TransactionTestCase from rest_framework.test import APITestCase from ami.base.serializers import reverse_with_params @@ -348,10 +348,69 @@ def test_genuinely_missing_state_acks_and_fails_job(self, mock_manager_class, mo mock_ack.assert_called_once() mock_fail.assert_called_once() - # New, accurate message — no longer the misleading "Redis state missing" - # that users saw in the UI for transient connection drops. + # The reason string passed to _fail_job identifies the stage and embeds + # a live Redis snapshot (from diagnose_missing_state) so the FAILURE + # log and UI progress.errors distinguish DB-index drift, eviction, and + # never-initialized state rather than collapsing them into one message. args, _ = mock_fail.call_args - self.assertIn("Job state keys not found in Redis", args[1]) + self.assertIn("Job state missing from Redis", args[1]) + self.assertIn("stage=process", args[1]) + + @patch("ami.jobs.tasks._fail_job") + @patch("ami.jobs.tasks._ack_task_via_nats") + @patch("ami.jobs.tasks.TaskQueueManager") + def test_genuinely_missing_state_results_stage_acks_and_fails_job(self, mock_manager_class, mock_ack, mock_fail): + """ + Mirror of test_genuinely_missing_state_acks_and_fails_job for the + stage=results path (tasks.py lines 378-388). 
When the total-images key + is gone at the results-stage update_state call, the task must ack NATS + to stop redelivery and fail the job — there is no state to reconcile. + The reason string must identify stage=results. + """ + self._setup_mock_nats(mock_manager_class) + + # save_results requires at least one algorithm on the pipeline. + detection_algorithm = Algorithm.objects.create( + name="results-missing-state-detector", + key="results-missing-state-detector", + task_type=AlgorithmTaskType.LOCALIZATION, + ) + self.pipeline.algorithms.add(detection_algorithm) + + # Use a success result so the process-stage path succeeds and + # save_results runs before the results-stage update_state is reached. + success_data = PipelineResultsResponse( + pipeline="test-pipeline", + algorithms={}, + total_time=1.0, + source_images=[SourceImageResponse(id=str(self.images[0].pk), url="http://example.com/test_image_0.jpg")], + detections=[], + errors=None, + ).dict() + + real_update_state = AsyncJobStateManager.update_state + + def none_on_results_stage(self_inner, processed_image_ids, stage, failed_image_ids=None): + if stage == "results": + return None + return real_update_state(self_inner, processed_image_ids, stage, failed_image_ids) + + with patch.object(AsyncJobStateManager, "update_state", none_on_results_stage): + process_nats_pipeline_result( + job_id=self.job.pk, + result_data=success_data, + reply_subject="reply.missing-results", + ) + + mock_ack.assert_called_once() + mock_fail.assert_called_once() + # The reason string passed to _fail_job identifies the stage and embeds + # a live Redis snapshot (from diagnose_missing_state) so the FAILURE + # log and UI progress.errors distinguish DB-index drift, eviction, and + # never-initialized state rather than collapsing them into one message. 
class TestFailJob(TransactionTestCase):
    """
    Regression coverage for ``_fail_job`` mirroring its reason into ``progress.errors``.

    The UI reads ``progress.errors``; a FAILURE whose reason only reaches the logs
    leaves operators without a visible cause. These tests pin both the happy path
    (reason persisted alongside the FAILURE status) and the early-return path (a job
    already in a final state is left untouched).
    """

    # Patch target for the async-resource cleanup that ``_fail_job`` imports.
    _CLEANUP_TARGET = "ami.ml.orchestration.jobs.cleanup_async_job_resources"

    def setUp(self):
        cache.clear()
        self.project = Project.objects.create(name="FailJob Test Project")
        self.pipeline = Pipeline.objects.create(name="FailJob Pipeline", slug="fail-job-pipeline")
        self.pipeline.projects.add(self.project)
        self.collection = SourceImageCollection.objects.create(name="FailJob Collection", project=self.project)

    def tearDown(self):
        cache.clear()

    def _make_job(self, dispatch_mode: JobDispatchMode = JobDispatchMode.ASYNC_API) -> Job:
        """Create an ML job in the STARTED state for the given dispatch mode."""
        fields = {
            "job_type_key": MLJob.key,
            "project": self.project,
            "name": f"{dispatch_mode} fail-job test",
            "pipeline": self.pipeline,
            "source_image_collection": self.collection,
            "dispatch_mode": dispatch_mode,
        }
        started = Job.objects.create(**fields)
        started.update_status(JobState.STARTED, save=True)
        return started

    @patch(_CLEANUP_TARGET)
    def test_fail_job_appends_reason_to_progress_errors(self, mock_cleanup):
        """
        The reason passed to ``_fail_job`` must be persisted in ``progress.errors``
        together with the FAILURE status, and cleanup must run once for the job.
        Call-site tests that mock ``_fail_job`` entirely cannot catch a regression here.
        """
        from ami.jobs.tasks import _fail_job

        failed_job = self._make_job()
        reason = "Job state missing from Redis (stage=process): redis=host:6379/db1 keys_for_job="

        _fail_job(failed_job.pk, reason)

        failed_job.refresh_from_db()
        self.assertEqual(failed_job.status, JobState.FAILURE)
        self.assertIn(
            reason,
            failed_job.progress.errors,
            f"expected reason in progress.errors, got: {failed_job.progress.errors!r}",
        )
        # Load a brand-new instance as well: proves the reason was written to the
        # database and is not just present on the refreshed in-memory object.
        from_db = Job.objects.get(pk=failed_job.pk)
        self.assertIn(reason, from_db.progress.errors)
        mock_cleanup.assert_called_once_with(failed_job.pk)

    @patch(_CLEANUP_TARGET)
    def test_fail_job_is_noop_on_already_final_job(self, mock_cleanup):
        """
        A job already in a final state must be left alone: status and
        ``progress.errors`` stay as they were and no cleanup is triggered.
        """
        from ami.jobs.tasks import _fail_job

        finished_job = self._make_job()
        finished_job.update_status(JobState.SUCCESS, save=True)
        errors_snapshot = list(finished_job.progress.errors)

        _fail_job(finished_job.pk, "should be ignored")

        finished_job.refresh_from_db()
        self.assertEqual(finished_job.status, JobState.SUCCESS)
        self.assertEqual(finished_job.progress.errors, errors_snapshot)
        mock_cleanup.assert_not_called()
def diagnose_missing_state(self) -> str:
    """
    One-line snapshot of what Redis holds for this job, safe for public surfaces.

    Reports the Redis DB index and every key matching this job's pattern, with the
    cardinality of each key other than the total key. This lets a missing-state FAILURE
    distinguish DB-index mismatch, key eviction, and never-initialized state.

    The Redis host/port is deliberately omitted: callers surface this string in the
    job's public ``progress.errors``. For the same reason the failure fallback reports
    only the exception class name, never its message, because connection error messages
    can name the Redis host and port. The full error goes to the module logger.

    Cost: one ``SCAN`` walk plus one ``SCARD`` per matching key. It is intended for the
    missing-state failure path only.

    Returns:
        ``"redis db<N>: keys_for_job=<summary>"``, where the summary is ``<none>`` when
        no key matches, or a comma-separated list of ``key=SCARD:<n>`` entries
        (``key=<present>`` for the total key, ``key=<unreadable>`` when ``SCARD``
        raised ``RedisError``). If collecting diagnostics fails, returns
        ``"(diagnostics failed: <ExceptionClass>)"``. This method never raises, so it
        cannot mask the failure the caller is already reporting.
    """
    try:
        redis = self._get_redis()
        kwargs = getattr(redis.connection_pool, "connection_kwargs", {}) or {}
        db = kwargs.get("db", "?")
        keys = sorted(k.decode() if isinstance(k, bytes) else k for k in redis.scan_iter(match=self._pattern()))
        sizes: list[str] = []
        for key in keys:
            if key == self._total_key:
                # The total key is reported by presence only and is never sized with
                # SCARD (presumably it holds a scalar rather than a set).
                sizes.append(f"{key}=<present>")
                continue
            try:
                sizes.append(f"{key}=SCARD:{redis.scard(key)}")
            except RedisError:
                sizes.append(f"{key}=<unreadable>")
        # Explicit marker so "no keys at all" cannot be mistaken for truncated output.
        keys_summary = ", ".join(sizes) if sizes else "<none>"
        return f"redis db{db}: keys_for_job={keys_summary}"
    except Exception as e:
        # Keep the detailed error out of the returned (public) string.
        import logging

        logging.getLogger(__name__).warning("Job %s: missing-state diagnostics failed: %s", self.job_id, e)
        return f"(diagnostics failed: {type(e).__name__})"


def connection_target(self) -> str:
    """
    Redis target as ``host:port/dbN``, for server-side operator logs only.

    Names the internal Redis host, so the result must NOT go into job logs or
    ``progress.errors``; use :meth:`diagnose_missing_state` for anything shown to the
    user. Fields the connection pool does not report are rendered as ``?``. On failure
    returns ``"(unavailable: <error>)"`` instead of raising.
    """
    try:
        redis = self._get_redis()
        kwargs = getattr(redis.connection_pool, "connection_kwargs", {}) or {}
        return f"{kwargs.get('host', '?')}:{kwargs.get('port', '?')}/db{kwargs.get('db', '?')}"
    except Exception as e:
        return f"(unavailable: {e})"


def _pattern(self) -> str:
    """Return the glob pattern matching every Redis key that belongs to this job."""
    return f"job:{self.job_id}:*"
def test_diagnose_missing_state_lists_present_keys(self):
    """
    Partial-cleanup / eviction case: the total key is gone but other keys under
    ``job:{id}:*`` survive. The diagnostic must list the surviving pending set with
    its SCARD and must not mention the deleted total key, so an operator can tell
    "total key evicted" apart from "this DB never saw the job".
    """
    self.manager.initialize_job(self.image_ids)
    # Remove only the total key; everything else initialize_job wrote stays behind.
    self.manager._get_redis().delete(self.manager._total_key)

    snapshot = self.manager.diagnose_missing_state()

    surviving_entry = f"job:{self.job_id}:pending_images:process=SCARD:"
    self.assertIn(surviving_entry, snapshot)
    self.assertNotIn(self.manager._total_key, snapshot)