Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 50 additions & 9 deletions ami/jobs/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
from asgiref.sync import async_to_sync
from celery.signals import task_failure, task_postrun, task_prerun
from django.db import transaction
from redis.exceptions import ConnectionError as RedisConnectionError
from redis.exceptions import RedisError

from ami.ml.orchestration.async_job_state import AsyncJobStateManager
from ami.ml.orchestration.nats_queue import TaskQueueManager
Expand Down Expand Up @@ -47,7 +49,17 @@ def run_job(self, job_id: int) -> None:

@celery_app.task(
bind=True,
max_retries=0, # don't retry since we already have retry logic in the NATS queue
# Retry on transient Redis/connection errors so a single connection reset
# doesn't flip the job to FAILURE mid-processing. Backoff is bounded well
# below soft_time_limit so retries never leak past the task deadline.
# Terminal failures (e.g. PipelineResultsError validation) are raised
# from other exception types and are not retried here.
# See RolnickLab/antenna#1219.
autoretry_for=(RedisError, RedisConnectionError, ConnectionError),
retry_backoff=True,
Comment thread
mihow marked this conversation as resolved.
Outdated
retry_backoff_max=30,
retry_jitter=True,
Comment thread
mihow marked this conversation as resolved.
Outdated
max_retries=5,
soft_time_limit=300, # 5 minutes
time_limit=360, # 6 minutes
)
Expand Down Expand Up @@ -84,11 +96,24 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub

state_manager = AsyncJobStateManager(job_id)

progress_info = state_manager.update_state(processed_image_ids, stage="process", failed_image_ids=failed_image_ids)
try:
progress_info = state_manager.update_state(
processed_image_ids, stage="process", failed_image_ids=failed_image_ids
)
except RedisError as e:
# Transient (connection reset, broker blip, timeout). Celery will retry
# via autoretry_for; NATS hasn't been acked yet so redelivery is not
# a concern. Log so the real cause is visible in task logs rather than
# the misleading "Redis state missing" that users saw in #1219.
Comment thread
mihow marked this conversation as resolved.
Outdated
logger.warning(f"Transient Redis error updating job {job_id} state (stage=process); Celery will retry: {e}")
Comment thread
mihow marked this conversation as resolved.
Outdated
raise

if not progress_info:
# Acknowledge the task to prevent retries, since we don't know the state
# State keys genuinely missing (the total-images key returned None).
# Ack so NATS stops redelivering and fail the job — there's no state
# left to reconcile against.
_ack_task_via_nats(reply_subject, logger)
_fail_job(job_id, "Redis state missing for job")
_fail_job(job_id, "Job state keys not found in Redis (likely cleaned up concurrently)")
return

try:
Expand Down Expand Up @@ -146,13 +171,24 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
acked = True
# Update job stage with calculated progress

progress_info = state_manager.update_state(
processed_image_ids,
stage="results",
)
try:
progress_info = state_manager.update_state(
processed_image_ids,
stage="results",
)
except RedisError as e:
# Transient. NATS is acked and results are saved, but those writes
# are idempotent so a Celery retry is safe: save_results dedupes on
# re-run and SREM is a no-op on already-removed ids. Log to the
# job logger so the cause is visible in the UI task log, then let
# autoretry_for pick it up.
job.logger.warning(
f"Transient Redis error updating job {job_id} state (stage=results); Celery will retry: {e}"
)
raise

if not progress_info:
_fail_job(job_id, "Redis state missing for job")
_fail_job(job_id, "Job state keys not found in Redis (likely cleaned up concurrently)")
return

# update complete state based on latest progress info after saving results
Expand All @@ -170,6 +206,11 @@ def process_nats_pipeline_result(self, job_id: int, result_data: dict, reply_sub
captures=captures_count,
)

except RedisError:
# Logged above at the specific update_state call site; re-raise so
# Celery's autoretry_for handles the transient rather than this broad
# except swallowing it.
raise
except Exception as e:
error = f"Error processing pipeline result for job {job_id}: {e}"
if not acked:
Expand Down
62 changes: 62 additions & 0 deletions ami/jobs/tests/test_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,68 @@ def test_process_nats_pipeline_result_concurrent_updates(self, mock_manager_clas
self.assertEqual(progress.total, 3)
self.assertEqual(mock_manager.acknowledge_task.call_count, 2)

@patch("ami.jobs.tasks._fail_job")
@patch("ami.jobs.tasks._ack_task_via_nats")
@patch("ami.jobs.tasks.TaskQueueManager")
def test_transient_redis_error_does_not_fail_job_or_ack(self, mock_manager_class, mock_ack, mock_fail):
    """
    Regression test for #1219: a RedisError raised by update_state must
    escape the task body untouched.

    Expected outcome when update_state raises RedisError:
      * the exception propagates to the caller,
      * _fail_job is never invoked (the job is not flipped to FAILURE),
      * the NATS ack helper is never invoked.

    Retrying is left to Celery's autoretry_for. The task is invoked here as
    a plain function, bypassing Celery's retry machinery, so the raw
    behavior of a single pass through the body is what gets asserted.
    """
    from redis.exceptions import RedisError

    self._setup_mock_nats(mock_manager_class)
    result_payload = self._create_error_result(image_id=str(self.images[0].pk))

    # Every update_state call fails as if the Redis connection had dropped.
    redis_outage = patch.object(AsyncJobStateManager, "update_state", side_effect=RedisError("reset by peer"))

    with redis_outage, self.assertRaises(RedisError):
        # Direct call: the task body runs once, with no retry.
        process_nats_pipeline_result(
            job_id=self.job.pk,
            result_data=result_payload,
            reply_subject="reply.transient",
        )

    mock_fail.assert_not_called()
    mock_ack.assert_not_called()

@patch("ami.jobs.tasks._fail_job")
@patch("ami.jobs.tasks._ack_task_via_nats")
@patch("ami.jobs.tasks.TaskQueueManager")
def test_genuinely_missing_state_acks_and_fails_job(self, mock_manager_class, mock_ack, mock_fail):
    """
    Counterpart to the transient-error case from #1219.

    If the job's Redis state has really been removed (cleanup race or
    expiry), there is nothing left to reconcile against. The task is then
    expected to ack NATS, so redelivery stops, and to fail the job with
    the "Job state keys not found in Redis" message.
    """
    self._setup_mock_nats(mock_manager_class)
    result_payload = self._create_error_result(image_id=str(self.images[0].pk))

    # Remove the state created during setUp so that update_state reports
    # the state as missing (returns None) rather than raising.
    self.state_manager.cleanup()

    process_nats_pipeline_result(
        job_id=self.job.pk,
        result_data=result_payload,
        reply_subject="reply.missing",
    )

    mock_ack.assert_called_once()
    mock_fail.assert_called_once()

    # The failure reason is the second positional argument to _fail_job.
    # It must carry the new, accurate wording, not "Redis state missing".
    fail_positional_args = mock_fail.call_args[0]
    self.assertIn("Job state keys not found in Redis", fail_positional_args[1])

@patch("ami.jobs.tasks.TaskQueueManager")
def test_process_nats_pipeline_result_error_job_not_found(self, mock_manager_class):
"""
Expand Down
42 changes: 23 additions & 19 deletions ami/ml/orchestration/async_job_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,26 +127,30 @@ def update_state(
failed_image_ids: Set of image IDs that failed processing (optional)

Returns:
JobStateProgress snapshot, or None if Redis state is missing
(job expired or not yet initialized).
JobStateProgress snapshot, or None if the job's total-images key is
genuinely missing from Redis (job expired, cleaned up concurrently,
or never initialized).

Raises:
redis.exceptions.RedisError: on transient Redis failures (connection
reset, timeout, etc.). Callers should retry; swallowing this
here would conflate a fixable transient with the terminal
"state genuinely gone" signal expressed by the None return.
See RolnickLab/antenna#1219.
"""
try:
redis = self._get_redis()
pending_key = self._get_pending_key(stage)

with redis.pipeline() as pipe:
if processed_image_ids:
pipe.srem(pending_key, *processed_image_ids)
if failed_image_ids:
pipe.sadd(self._failed_key, *failed_image_ids)
pipe.expire(self._failed_key, self.TIMEOUT)
pipe.scard(pending_key)
pipe.scard(self._failed_key)
pipe.get(self._total_key)
results = pipe.execute()
except RedisError as e:
logger.error(f"Redis error updating job {self.job_id} state: {e}")
return None
redis = self._get_redis()
pending_key = self._get_pending_key(stage)

with redis.pipeline() as pipe:
if processed_image_ids:
pipe.srem(pending_key, *processed_image_ids)
if failed_image_ids:
pipe.sadd(self._failed_key, *failed_image_ids)
pipe.expire(self._failed_key, self.TIMEOUT)
pipe.scard(pending_key)
pipe.scard(self._failed_key)
pipe.get(self._total_key)
results = pipe.execute()

# Last 3 results are always scard(pending), scard(failed), get(total)
# regardless of whether SREM/SADD appear at the front.
Expand Down
38 changes: 38 additions & 0 deletions ami/ml/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1366,3 +1366,41 @@ def test_cleanup_removes_failed_set(self):
# Verify all state is gone (get_progress returns None when total_key is deleted)
progress = self.manager.get_progress("process")
self.assertIsNone(progress)

def test_update_state_raises_on_redis_error(self):
    """
    update_state must let a RedisError propagate instead of returning None.

    A None return is reserved for the case where the job state is actually
    gone. Treating a transient connection failure the same way is the
    #1219 bug, which turned connection resets into fatal job FAILUREs.
    """
    from unittest.mock import MagicMock, patch

    from redis.exceptions import RedisError

    self._init_and_verify(self.image_ids)

    # Only execute() needs to fail. The srem/sadd/scard/get calls that
    # precede it just queue commands on the pipeline, so a MagicMock
    # accepts them without further setup.
    exploding_pipe = MagicMock()
    exploding_pipe.execute.side_effect = RedisError("Connection reset by peer")

    # `with redis.pipeline() as pipe:` must yield the exploding pipeline.
    redis_stub = MagicMock()
    redis_stub.pipeline.return_value.__enter__.return_value = exploding_pipe

    with patch.object(self.manager, "_get_redis", return_value=redis_stub), self.assertRaises(RedisError):
        self.manager.update_state({"img1", "img2"}, "process")

def test_update_state_returns_none_when_state_genuinely_missing(self):
    """
    update_state returns None when the job's total-images key does not
    exist in Redis (never initialized, cleaned up, or TTL expired).

    This is the only outcome that should send the caller down its
    terminal "state missing" failure path.
    """
    # initialize_job is deliberately not called, so the total key is absent.
    self.assertIsNone(self.manager.update_state({"img1", "img2"}, "process"))
Loading