Skip to content
Merged
Show file tree
Hide file tree
Changes from 10 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ami/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ class TaxonRank(OrderedEnum):
NULL_DETECTIONS_FILTER = Q(bbox__isnull=True) | Q(bbox=[])


def bbox_is_null(bbox) -> bool:
"""In-memory equivalent of NULL_DETECTIONS_FILTER for already-fetched bbox values."""
return bbox is None or bbox == []


def get_media_url(path: str) -> str:
"""
If path is a full URL, return it as-is.
Expand Down
188 changes: 147 additions & 41 deletions ami/ml/models/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import collections
import dataclasses
import itertools
import logging
import time
import typing
Expand All @@ -23,7 +24,6 @@
from ami.base.models import BaseModel, BaseQuerySet
from ami.base.schemas import ConfigurableStage, default_stages
from ami.main.models import (
NULL_DETECTIONS_FILTER,
Classification,
Deployment,
Detection,
Expand All @@ -34,6 +34,7 @@
TaxaList,
Taxon,
TaxonRank,
bbox_is_null,
update_calculated_fields_for_events,
update_occurrence_determination,
)
Expand All @@ -57,10 +58,23 @@
logger = logging.getLogger(__name__)


FILTER_PROCESSED_BATCH_SIZE = 1000
# Minimum wall-time (seconds) between in-loop Job.progress writes inside
# filter_processed_images. The reaper's "no forward progress" heuristic keys off
# job.updated_at, so we need to tick at least once per ~10min; 5s gives ample
# headroom while still amortising the JSONB UPDATE across many chunks. Caps at
# 0.99 so the caller still owns the final status=SUCCESS, progress=1 flip.
COLLECT_PROGRESS_SAVE_INTERVAL_SECONDS = 5.0
COLLECT_PROGRESS_MAX_FRACTION = 0.99


def filter_processed_images(
images: typing.Iterable[SourceImage],
pipeline: Pipeline,
task_logger: logging.Logger = logger,
batch_size: int = FILTER_PROCESSED_BATCH_SIZE,
job: Job | None = None,
total: int | None = None,
) -> typing.Iterable[SourceImage]:
"""
Return only images that need to be processed by a given pipeline.
Expand All @@ -79,56 +93,131 @@ def filter_processed_images(

Null detections are sentinels that mark an image as "processed, nothing found."
They are excluded from classification checks so they don't trigger reprocessing.

Input images are processed in batches so the work scales as O(image_count /
batch_size) database round-trips instead of O(image_count). This keeps the
Collect stage well under the Celery broker's AMQP heartbeat window for the
large-collection case (see issue #1321).
"""
pipeline_algorithms = pipeline.algorithms.all()
if batch_size <= 0:
raise ValueError(f"batch_size must be > 0, got {batch_size}")

pipeline_algorithms = list(pipeline.algorithms.all())
pipeline_algorithm_ids = [a.id for a in pipeline_algorithms]

detection_type_keys = Algorithm.detection_task_types
detection_algorithms = pipeline_algorithms.filter(task_type__in=detection_type_keys)
if not detection_algorithms.exists():
detection_type_keys = set(Algorithm.detection_task_types)
has_detection_algorithm = any(a.task_type in detection_type_keys for a in pipeline_algorithms)
if not has_detection_algorithm:
task_logger.warning(f"Pipeline {pipeline} has no detection algorithms saved. Will reprocess all images.")
classification_algorithms = pipeline_algorithms.exclude(task_type__in=detection_type_keys)
if not classification_algorithms.exists():
pipeline_classifier_ids = {a.id for a in pipeline_algorithms if a.task_type not in detection_type_keys}
if not pipeline_classifier_ids:
task_logger.warning(f"Pipeline {pipeline} has no classification algorithms saved. Will reprocess all images.")
Comment thread
coderabbitai[bot] marked this conversation as resolved.

for image in images:
existing_detections = image.detections.filter(detection_algorithm__in=pipeline_algorithms)
if not existing_detections.exists():
task_logger.debug(f"Image {image} needs processing: has no existing detections from pipeline's detector")
# If there are no existing detections from this pipeline, send the image
yield image
elif not existing_detections.exclude(NULL_DETECTIONS_FILTER).exists(): # type: ignore
# All detections for this image are null (processed but nothing found) — skip
task_logger.debug(f"Image {image} has only null detections from pipeline {pipeline}, skipping!")
continue
elif existing_detections.exclude(NULL_DETECTIONS_FILTER).filter(classifications__isnull=True).exists():
# Check if any real detections (non-null) have no classifications
task_logger.debug(
f"Image {image} needs processing: has existing detections with no classifications "
"from pipeline {pipeline}"
# set().issubset(anything) is vacuously True, so without this short-circuit
# every image with existing detections gets skipped by the subset check
# below — contradicting the warning above. Yield everything and exit.
yield from images
return

image_iter = iter(images)
# Track how many of the input images we've inspected so far so we can emit
# a fractional `collect` progress to the Job row. Only used when both
# `job` and `total` are passed by the caller; legacy callers stay silent.
processed_count = 0
last_progress_save_monotonic = time.monotonic()
while True:
batch = list(itertools.islice(image_iter, batch_size))
Comment thread
mihow marked this conversation as resolved.
if not batch:
return

batch_ids = [image.pk for image in batch]

images_with_pipeline_detection: set[int] = set()
real_pipeline_detections_per_image: dict[int, set[int]] = collections.defaultdict(set)
detection_rows = (
Detection.objects.filter(
source_image_id__in=batch_ids,
detection_algorithm_id__in=pipeline_algorithm_ids,
)
yield image
else:
# If there are existing detections with classifications,
# Compare their classification algorithms to the current pipeline's algorithms
pipeline_algorithm_ids = set(classification_algorithms.values_list("id", flat=True))
detection_algorithm_ids = set(existing_detections.values_list("classifications__algorithm_id", flat=True))
.order_by()
.values_list("source_image_id", "id", "bbox")
)
for source_image_id, detection_id, bbox in detection_rows:
images_with_pipeline_detection.add(source_image_id)
if not bbox_is_null(bbox):
real_pipeline_detections_per_image[source_image_id].add(detection_id)

classifier_ids_per_detection: dict[int, set[int]] = collections.defaultdict(set)
real_pipeline_detection_ids = [d for dets in real_pipeline_detections_per_image.values() for d in dets]
if real_pipeline_detection_ids:
classification_rows = (
Classification.objects.filter(detection_id__in=real_pipeline_detection_ids)
.order_by()
.values_list("detection_id", "algorithm_id")
)
for detection_id, algorithm_id in classification_rows:
classifier_ids_per_detection[detection_id].add(algorithm_id)

if not pipeline_algorithm_ids.issubset(detection_algorithm_ids):
for image in batch:
image_id = image.pk

if image_id not in images_with_pipeline_detection:
task_logger.debug(
f"Image {image} has existing detections that haven't been classified by the pipeline: {pipeline}:"
f" {detection_algorithm_ids} vs {pipeline_algorithm_ids}"
f"Since we do yet have a mechanism to reclassify detections, processing the image from scratch."
f"Image {image} needs processing: has no existing detections from pipeline's detector"
)
yield image
continue

real_det_ids = real_pipeline_detections_per_image.get(image_id)
if not real_det_ids:
task_logger.debug(f"Image {image} has only null detections from pipeline {pipeline}, skipping!")
continue

if any(detection_id not in classifier_ids_per_detection for detection_id in real_det_ids):
task_logger.debug(
f"Image {image} needs processing: has existing detections with no classifications "
f"from pipeline {pipeline}"
)
yield image
continue

observed_classifier_ids: set[int] = set()
for detection_id in real_det_ids:
observed_classifier_ids.update(classifier_ids_per_detection[detection_id])

if not pipeline_classifier_ids.issubset(observed_classifier_ids):
missing_algos = pipeline_classifier_ids - observed_classifier_ids
task_logger.debug(
f"Image {image} has existing detections that haven't been classified by the pipeline: {pipeline}: "
f"{observed_classifier_ids} vs {pipeline_classifier_ids}. "
f"Since we do not yet have a mechanism to reclassify detections, "
f"processing the image from scratch."
)
Comment thread
mihow marked this conversation as resolved.
# log all algorithms that are in the pipeline but not in the detection
missing_algos = pipeline_algorithm_ids - detection_algorithm_ids
task_logger.debug(f"Image #{image.pk} needs classification by pipeline's algorithms: {missing_algos}")
yield image
else:
# If all detections have been classified by the pipeline, skip the image
task_logger.debug(
f"Image {image} has existing detections classified by the pipeline: {pipeline}, skipping!"
)
continue

# Throttled progress emit. Save only when both `job` and `total` are
# provided, the total is non-zero, and at least
# COLLECT_PROGRESS_SAVE_INTERVAL_SECONDS of wall time have passed since
# the last save. Capped at COLLECT_PROGRESS_MAX_FRACTION so the caller's
# final status=SUCCESS, progress=1 flip still owns the terminal value.
#
# `updated_at` is included in update_fields explicitly: Django only fires
# auto_now's pre_save hook for fields listed in update_fields, so without
# it the reaper's `Job.updated_at < cutoff` heuristic
# (`ami/jobs/tasks.py:929-944`) would not see this heartbeat and could
# still revoke the job mid-Collect.
processed_count += len(batch)
if job is not None and total:
now_monotonic = time.monotonic()
if now_monotonic - last_progress_save_monotonic >= COLLECT_PROGRESS_SAVE_INTERVAL_SECONDS:
fraction = min(processed_count / total, COLLECT_PROGRESS_MAX_FRACTION)
job.progress.update_stage("collect", progress=fraction)
job.save(update_fields=["progress", "updated_at"])
last_progress_save_monotonic = now_monotonic


def collect_images(
Expand All @@ -151,21 +240,38 @@ def collect_images(
else:
job = None

# Set source to first argument that is not None
# Set source to first argument that is not None. The collection- and
# deployment-backed branches prefetch the deployment + data_source joins so
# later calls to image.url() in queue_images_to_nats don't trigger N+1
# lookups (see issue #1321). The source_images branch passes the caller's
# iterable through unchanged — callers handing us a pre-built list are
# responsible for any prefetching they need.
if collection:
images = collection.images.all()
images = collection.images.select_related("deployment__data_source")
elif source_images:
images = source_images
elif deployment:
images = SourceImage.objects.filter(deployment=deployment)
images = SourceImage.objects.filter(deployment=deployment).select_related("deployment__data_source")
else:
raise ValueError("Must specify a collection, deployment or a list of images")

total_images = len(images)
if pipeline and not reprocess_all_images:
msg = f"Filtering images that have already been processed by pipeline {pipeline}"
task_logger.info(msg)
images = list(filter_processed_images(images, pipeline, task_logger=task_logger))
# Pass `job` and `total` so filter_processed_images can emit throttled
# `collect` progress updates as it chews through the input — keeps the
# reaper "no forward progress" heuristic happy on large collections
# (issue #1321 follow-up).
images = list(
filter_processed_images(
images,
pipeline,
task_logger=task_logger,
job=job,
total=total_images,
)
)
else:
msg = "NOT filtering images that have already been processed"
task_logger.info(msg)
Expand Down
60 changes: 44 additions & 16 deletions ami/ml/orchestration/jobs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging

from asgiref.sync import async_to_sync
Expand All @@ -10,6 +11,14 @@

logger = logging.getLogger(__name__)

# Number of concurrent JetStream publishes per fanout chunk. The bottleneck on
# large-collection jobs is the per-message ack round-trip (~1.3ms each on
# Serbia 2026-05-27, sequential), so awaiting publishes one at a time scales
# linearly with image count and pushes >450k-image jobs past the reaper
# threshold. Issuing ~200 publishes per gather() lets NATS pipeline the acks
# back to us; the chunk boundary keeps memory and concurrent-task counts bounded.
NATS_PUBLISH_FANOUT_CHUNK_SIZE = 200


def cleanup_async_job_resources(job_id: int) -> bool:
"""
Expand Down Expand Up @@ -91,7 +100,11 @@ def queue_images_to_nats(job: "Job", images: list[SourceImage]):
skipped_count = 0
for image in images:
image_id = str(image.pk)
image_url = image.url() if hasattr(image, "url") and image.url() else ""
# Call image.url() exactly once per iteration — the implementation
# touches deployment + data_source and the call cost adds up across
# large collections. The upstream queryset in collect_images() also
# prefetches those joins so this stays cheap (see issue #1321).
image_url = image.url()
if not image_url:
job.logger.warning(f"Image {image.pk} has no URL, skipping queuing to NATS for job '{job.pk}'")
skipped_count += 1
Expand Down Expand Up @@ -120,28 +133,43 @@ async def queue_all_images():
# the sync_to_async bridge for JobLogHandler's ORM save lives in one
# place instead of being re-implemented at every call site.
async with TaskQueueManager(job_logger=job.logger) as manager:
for image_pk, task in tasks:
# Warm the stream + consumer caches once so per-publish calls skip
# the cached-noop branches in publish_task -> _ensure_stream /
# _ensure_consumer. Even though those branches are O(1) after the
# first call, each one still runs inside the publish coroutine and
# serialises with the gather below.
try:
await manager.ensure_job_resources(job.pk)
except Exception as e:
await manager.log_async(
logging.ERROR,
f"Failed to set up NATS stream/consumer for job '{job.pk}': {e}",
exc_info=True,
)
return 0, len(tasks)

async def publish_one(image_pk: int, task: PipelineProcessingTask) -> bool:
try:
await manager.log_async(
logging.DEBUG,
f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}",
)
success = await manager.publish_task(
job_id=job.pk,
data=task,
)
return await manager.publish_task(job_id=job.pk, data=task)
except Exception as e:
await manager.log_async(
logging.ERROR,
f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}",
exc_info=True,
)
success = False

if success:
successful_queues += 1
else:
failed_queues += 1
return False

for chunk_start in range(0, len(tasks), NATS_PUBLISH_FANOUT_CHUNK_SIZE):
chunk = tasks[chunk_start : chunk_start + NATS_PUBLISH_FANOUT_CHUNK_SIZE]
results = await asyncio.gather(
*(publish_one(image_pk, task) for image_pk, task in chunk),
return_exceptions=False,
)
for success in results:
if success:
successful_queues += 1
else:
failed_queues += 1

return successful_queues, failed_queues

Expand Down
14 changes: 14 additions & 0 deletions ami/ml/orchestration/nats_queue.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,6 +290,20 @@ async def _stream_exists(self, stream_name: str) -> bool:
except nats.js.errors.NotFoundError:
return False

async def ensure_job_resources(self, job_id: int) -> None:
"""Ensure both the stream and consumer for the given job exist.

Public wrapper for the two ``_ensure_*`` calls. Callers that want to
pre-warm both before a hot loop (e.g. a publish fanout) should use this
rather than touching the private methods directly — keeps the
TaskQueueManager's stream/consumer setup story in one place.

Subsequent calls in the same manager session skip the NATS round-trip
via the per-instance ``_streams_logged`` / ``_consumers_logged`` sets.
"""
await self._ensure_stream(job_id)
await self._ensure_consumer(job_id)

async def _ensure_stream(self, job_id: int):
"""Ensure stream exists for the given job.

Expand Down
Loading
Loading