Skip to content
Merged
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions ami/jobs/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,6 +384,36 @@ def test_filter_by_pipeline_slug(self):
self.assertEqual(data["count"], 1)
self.assertEqual(data["results"][0]["id"], job_with_pipeline.pk)

def test_filter_by_pipeline_slug_in(self):
    """Check that ``pipeline__slug__in`` narrows the job list to the listed slugs.

    Three pipelines each get one ML job. The request names only two of the
    slugs, so the third pipeline's job and the job created in setUp must
    both be missing from the response.
    """
    first_pipeline = self._create_pipeline("Pipeline A", "pipeline-a")

    second_pipeline = Pipeline.objects.create(name="Pipeline B", slug="pipeline-b", description="B")
    second_pipeline.projects.add(self.project)

    third_pipeline = Pipeline.objects.create(name="Pipeline C", slug="pipeline-c", description="C")
    third_pipeline.projects.add(self.project)

    wanted_job_a = self._create_ml_job("Job A", first_pipeline)
    wanted_job_b = self._create_ml_job("Job B", second_pipeline)
    unwanted_job = self._create_ml_job("Job C", third_pipeline)

    self.client.force_authenticate(user=self.user)

    # Request two of the three slugs as one comma-separated query value.
    query = {"project_id": self.project.pk, "pipeline__slug__in": "pipeline-a,pipeline-b"}
    response = self.client.get(reverse_with_params("api:job-list", params=query))

    self.assertEqual(response.status_code, 200)
    returned_ids = set()
    for entry in response.json()["results"]:
        returned_ids.add(entry["id"])

    for expected in (wanted_job_a, wanted_job_b):
        self.assertIn(expected.pk, returned_ids)

    # The third pipeline's job and the original setUp job (no pipeline)
    # must both be filtered out.
    for excluded in (unwanted_job, self.job):
        self.assertNotIn(excluded.pk, returned_ids)

def test_search_jobs(self):
"""Test searching jobs by name and pipeline name."""
pipeline = self._create_pipeline("SearchablePipeline", "searchable-pipeline")
Expand Down
19 changes: 10 additions & 9 deletions ami/jobs/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class JobFilterSet(filters.FilterSet):
"""Custom filterset to enable pipeline name filtering."""

pipeline__slug = filters.CharFilter(field_name="pipeline__slug", lookup_expr="exact")
pipeline__slug__in = filters.BaseInFilter(field_name="pipeline__slug", lookup_expr="in")

class Meta:
model = Job
Expand All @@ -55,11 +56,12 @@ def filter_queryset(self, request, queryset, view):
incomplete_only = url_boolean_param(request, "incomplete_only", default=False)
# Filter to incomplete jobs if requested (checks "results" stage status)
if incomplete_only:
# Create filters for each final state to exclude
# Exclude jobs with a terminal top-level status
queryset = queryset.exclude(status__in=JobState.final_states())

Comment thread
mihow marked this conversation as resolved.
# Also exclude jobs where the "results" stage has a final state status
final_states = JobState.final_states()
exclude_conditions = Q()

# Exclude jobs where the "results" stage has a final state status
for state in final_states:
# JSON path query to check if results stage status is in final states
# @TODO move to a QuerySet method on Job model if/when this needs to be reused elsewhere
Expand Down Expand Up @@ -233,6 +235,10 @@ def tasks(self, request, pk=None):
if job.dispatch_mode != JobDispatchMode.ASYNC_API:
raise ValidationError("Only async_api jobs have fetchable tasks")

# Don't fetch tasks from completed/failed/revoked jobs
if job.status in JobState.final_states():
return Response({"tasks": []})
Comment thread
mihow marked this conversation as resolved.

# Validate that the job has a pipeline
if not job.pipeline:
raise ValidationError("This job does not have a pipeline configured")
Expand All @@ -241,13 +247,8 @@ def tasks(self, request, pk=None):
from ami.ml.orchestration.nats_queue import TaskQueueManager

async def get_tasks():
tasks = []
async with TaskQueueManager() as manager:
for _ in range(batch):
task = await manager.reserve_task(job.pk, timeout=0.1)
if task:
tasks.append(task.dict())
return tasks
return [task.dict() for task in await manager.reserve_tasks(job.pk, count=batch)]
Comment thread
mihow marked this conversation as resolved.
Outdated

# Use async_to_sync to properly handle the async call
tasks = async_to_sync(get_tasks)()
Comment thread
mihow marked this conversation as resolved.
Outdated
Expand Down
61 changes: 28 additions & 33 deletions ami/ml/orchestration/nats_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,12 @@


async def get_connection(nats_url: str):
nc = await nats.connect(nats_url)
nc = await nats.connect(
nats_url,
connect_timeout=5,
allow_reconnect=False,
Comment thread
mihow marked this conversation as resolved.
max_reconnect_attempts=0,
)
Comment thread
mihow marked this conversation as resolved.
Outdated
js = nc.jetstream()
return nc, js

Expand All @@ -39,8 +44,8 @@ class TaskQueueManager:
Use as an async context manager:
async with TaskQueueManager() as manager:
await manager.publish_task('job123', {'data': 'value'})
task = await manager.reserve_task('job123')
await manager.acknowledge_task(task['reply_subject'])
tasks = await manager.reserve_tasks('job123', count=64)
await manager.acknowledge_task(tasks[0].reply_subject)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated
"""

def __init__(self, nats_url: str | None = None):
Expand Down Expand Up @@ -156,62 +161,52 @@ async def publish_task(self, job_id: int, data: PipelineProcessingTask) -> bool:
logger.error(f"Failed to publish task to stream for job '{job_id}': {e}")
return False

async def reserve_task(self, job_id: int, timeout: float | None = None) -> PipelineProcessingTask | None:
async def reserve_tasks(self, job_id: int, count: int, timeout: float = 5) -> list[PipelineProcessingTask]:
Comment thread
mihow marked this conversation as resolved.
"""
Reserve a task from the specified stream.
Reserve up to `count` tasks from the specified stream in a single NATS fetch.

Args:
job_id: The job ID (integer primary key) to pull tasks from
timeout: Timeout in seconds for reservation (default: 5 seconds)
count: Maximum number of tasks to reserve
timeout: Timeout in seconds waiting for messages (default: 5 seconds)

Returns:
PipelineProcessingTask with reply_subject set for acknowledgment, or None if no task available
List of PipelineProcessingTask objects with reply_subject set for acknowledgment.
May return fewer than `count` if the queue has fewer messages available.
"""
if self.js is None:
raise RuntimeError("Connection is not open. Use TaskQueueManager as an async context manager.")

if timeout is None:
timeout = 5

try:
# Ensure stream and consumer exist
await self._ensure_stream(job_id)
await self._ensure_consumer(job_id)

consumer_name = self._get_consumer_name(job_id)
subject = self._get_subject(job_id)

# Create ephemeral subscription for this pull
psub = await self.js.pull_subscribe(subject, consumer_name)

try:
# Fetch a single message
msgs = await psub.fetch(1, timeout=timeout)

if msgs:
msg = msgs[0]
task_data = json.loads(msg.data.decode())
metadata = msg.metadata

# Parse the task data into PipelineProcessingTask
task = PipelineProcessingTask(**task_data)
# Set the reply_subject for acknowledgment
task.reply_subject = msg.reply

logger.debug(f"Reserved task from stream for job '{job_id}', sequence {metadata.sequence.stream}")
return task

msgs = await psub.fetch(count, timeout=timeout)
except nats.errors.TimeoutError:
# No messages available
logger.debug(f"No tasks available in stream for job '{job_id}'")
return None
return []
finally:
# Always unsubscribe
await psub.unsubscribe()

tasks = []
for msg in msgs:
task_data = json.loads(msg.data.decode())
task = PipelineProcessingTask(**task_data)
task.reply_subject = msg.reply
tasks.append(task)

logger.info(f"Reserved {len(tasks)} tasks from stream for job '{job_id}'")
Comment thread
mihow marked this conversation as resolved.
Outdated
return tasks

except Exception as e:
logger.error(f"Failed to reserve task from stream for job '{job_id}': {e}")
return None
logger.error(f"Failed to reserve tasks from stream for job '{job_id}': {e}")
return []

async def acknowledge_task(self, reply_subject: str) -> bool:
"""
Expand Down
63 changes: 45 additions & 18 deletions ami/ml/orchestration/tests/test_nats_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,47 +62,74 @@ async def test_publish_task_creates_stream_and_consumer(self):
self.assertIn("job_456", str(js.add_stream.call_args))
js.add_consumer.assert_called_once()

async def test_reserve_task_success(self):
"""Test successful task reservation."""
async def test_reserve_tasks_success(self):
"""Test successful batch task reservation."""
nc, js = self._create_mock_nats_connection()
sample_task = self._create_sample_task()

# Mock message with task data
mock_msg = MagicMock()
mock_msg.data = sample_task.json().encode()
mock_msg.reply = "reply.subject.123"
mock_msg.metadata = MagicMock(sequence=MagicMock(stream=1))
# Mock messages with task data
mock_msg1 = MagicMock()
mock_msg1.data = sample_task.json().encode()
mock_msg1.reply = "reply.subject.1"

mock_msg2 = MagicMock()
mock_msg2.data = sample_task.json().encode()
mock_msg2.reply = "reply.subject.2"

mock_psub = MagicMock()
mock_psub.fetch = AsyncMock(return_value=[mock_msg])
mock_psub.fetch = AsyncMock(return_value=[mock_msg1, mock_msg2])
mock_psub.unsubscribe = AsyncMock()
js.pull_subscribe = AsyncMock(return_value=mock_psub)

with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
async with TaskQueueManager() as manager:
task = await manager.reserve_task(123)
tasks = await manager.reserve_tasks(123, count=5)

self.assertIsNotNone(task)
self.assertEqual(task.id, sample_task.id)
self.assertEqual(task.reply_subject, "reply.subject.123")
self.assertEqual(len(tasks), 2)
self.assertEqual(tasks[0].id, sample_task.id)
self.assertEqual(tasks[0].reply_subject, "reply.subject.1")
self.assertEqual(tasks[1].reply_subject, "reply.subject.2")
mock_psub.fetch.assert_called_once_with(5, timeout=5)
mock_psub.unsubscribe.assert_called_once()

async def test_reserve_task_no_messages(self):
"""Test reserve_task when no messages are available."""
async def test_reserve_tasks_no_messages(self):
"""Test reserve_tasks when no messages are available (timeout)."""
nc, js = self._create_mock_nats_connection()
import nats.errors

mock_psub = MagicMock()
mock_psub.fetch = AsyncMock(return_value=[])
mock_psub.fetch = AsyncMock(side_effect=nats.errors.TimeoutError)
mock_psub.unsubscribe = AsyncMock()
js.pull_subscribe = AsyncMock(return_value=mock_psub)

with patch("ami.ml.orchestration.nats_queue.get_connection", AsyncMock(return_value=(nc, js))):
async with TaskQueueManager() as manager:
task = await manager.reserve_task(123)
tasks = await manager.reserve_tasks(123, count=5)

self.assertIsNone(task)
self.assertEqual(tasks, [])
mock_psub.unsubscribe.assert_called_once()

async def test_reserve_tasks_single(self):
    """Reserve with ``count=1`` and expect a one-element list with the reply subject set."""
    nc, js = self._create_mock_nats_connection()
    task_fixture = self._create_sample_task()

    # A single queued message whose payload is the serialized sample task.
    queued_msg = MagicMock(
        data=task_fixture.json().encode(),
        reply="reply.subject.123",
    )

    # Pull subscription stub: fetch() yields that one message.
    pull_sub = MagicMock(
        fetch=AsyncMock(return_value=[queued_msg]),
        unsubscribe=AsyncMock(),
    )
    js.pull_subscribe = AsyncMock(return_value=pull_sub)

    fake_get_connection = AsyncMock(return_value=(nc, js))
    with patch("ami.ml.orchestration.nats_queue.get_connection", fake_get_connection):
        async with TaskQueueManager() as manager:
            reserved = await manager.reserve_tasks(123, count=1)

            self.assertEqual(len(reserved), 1)
            self.assertEqual(reserved[0].reply_subject, "reply.subject.123")

async def test_acknowledge_task_success(self):
"""Test successful task acknowledgment."""
nc, js = self._create_mock_nats_connection()
Expand Down Expand Up @@ -144,7 +171,7 @@ async def test_operations_without_connection_raise_error(self):
await manager.publish_task(123, sample_task)

with self.assertRaisesRegex(RuntimeError, "Connection is not open"):
await manager.reserve_task(123)
await manager.reserve_tasks(123, count=1)

with self.assertRaisesRegex(RuntimeError, "Connection is not open"):
await manager.delete_stream(123)