Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions ami/main/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,11 @@ class TaxonRank(OrderedEnum):
NULL_DETECTIONS_FILTER = Q(bbox__isnull=True) | Q(bbox=[])


def bbox_is_null(bbox) -> bool:
    """
    Report whether an already-fetched bbox value counts as a null detection.

    Mirrors NULL_DETECTIONS_FILTER for values held in memory: a bbox is
    considered null when it is None or an empty list.
    """
    if bbox is None:
        return True
    return bbox == []


def get_media_url(path: str) -> str:
"""
If path is a full URL, return it as-is.
Expand Down
124 changes: 86 additions & 38 deletions ami/ml/models/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

import collections
import dataclasses
import itertools
import logging
import time
import typing
Expand All @@ -23,7 +24,6 @@
from ami.base.models import BaseModel, BaseQuerySet
from ami.base.schemas import ConfigurableStage, default_stages
from ami.main.models import (
NULL_DETECTIONS_FILTER,
Classification,
Deployment,
Detection,
Expand All @@ -34,6 +34,7 @@
TaxaList,
Taxon,
TaxonRank,
bbox_is_null,
update_calculated_fields_for_events,
update_occurrence_determination,
)
Expand All @@ -57,10 +58,14 @@
logger = logging.getLogger(__name__)


FILTER_PROCESSED_BATCH_SIZE = 1000


def filter_processed_images(
images: typing.Iterable[SourceImage],
pipeline: Pipeline,
task_logger: logging.Logger = logger,
batch_size: int = FILTER_PROCESSED_BATCH_SIZE,
) -> typing.Iterable[SourceImage]:
"""
Return only images that need to be processed by a given pipeline.
Expand All @@ -79,56 +84,97 @@ def filter_processed_images(

Null detections are sentinels that mark an image as "processed, nothing found."
They are excluded from classification checks so they don't trigger reprocessing.

Input images are processed in batches so the work scales as O(image_count /
batch_size) database round-trips instead of O(image_count). This keeps the
Collect stage well under the Celery broker's AMQP heartbeat window for the
large-collection case (see issue #1321).
"""
pipeline_algorithms = pipeline.algorithms.all()
pipeline_algorithms = list(pipeline.algorithms.all())
pipeline_algorithm_ids = [a.id for a in pipeline_algorithms]

detection_type_keys = Algorithm.detection_task_types
detection_algorithms = pipeline_algorithms.filter(task_type__in=detection_type_keys)
if not detection_algorithms.exists():
detection_type_keys = set(Algorithm.detection_task_types)
has_detection_algorithm = any(a.task_type in detection_type_keys for a in pipeline_algorithms)
if not has_detection_algorithm:
task_logger.warning(f"Pipeline {pipeline} has no detection algorithms saved. Will reprocess all images.")
classification_algorithms = pipeline_algorithms.exclude(task_type__in=detection_type_keys)
if not classification_algorithms.exists():
pipeline_classifier_ids = {a.id for a in pipeline_algorithms if a.task_type not in detection_type_keys}
if not pipeline_classifier_ids:
task_logger.warning(f"Pipeline {pipeline} has no classification algorithms saved. Will reprocess all images.")
Comment thread
coderabbitai[bot] marked this conversation as resolved.

for image in images:
existing_detections = image.detections.filter(detection_algorithm__in=pipeline_algorithms)
if not existing_detections.exists():
task_logger.debug(f"Image {image} needs processing: has no existing detections from pipeline's detector")
# If there are no existing detections from this pipeline, send the image
yield image
elif not existing_detections.exclude(NULL_DETECTIONS_FILTER).exists(): # type: ignore
# All detections for this image are null (processed but nothing found) — skip
task_logger.debug(f"Image {image} has only null detections from pipeline {pipeline}, skipping!")
continue
elif existing_detections.exclude(NULL_DETECTIONS_FILTER).filter(classifications__isnull=True).exists():
# Check if any real detections (non-null) have no classifications
task_logger.debug(
f"Image {image} needs processing: has existing detections with no classifications "
"from pipeline {pipeline}"
image_iter = iter(images)
while True:
batch = list(itertools.islice(image_iter, batch_size))
Comment thread
mihow marked this conversation as resolved.
if not batch:
return

batch_ids = [image.pk for image in batch]

images_with_pipeline_detection: set[int] = set()
real_pipeline_detections_per_image: dict[int, set[int]] = collections.defaultdict(set)
detection_rows = (
Detection.objects.filter(
source_image_id__in=batch_ids,
detection_algorithm_id__in=pipeline_algorithm_ids,
)
yield image
else:
# If there are existing detections with classifications,
# Compare their classification algorithms to the current pipeline's algorithms
pipeline_algorithm_ids = set(classification_algorithms.values_list("id", flat=True))
detection_algorithm_ids = set(existing_detections.values_list("classifications__algorithm_id", flat=True))
.order_by()
.values_list("source_image_id", "id", "bbox")
)
for source_image_id, detection_id, bbox in detection_rows:
images_with_pipeline_detection.add(source_image_id)
if not bbox_is_null(bbox):
real_pipeline_detections_per_image[source_image_id].add(detection_id)

classifier_ids_per_detection: dict[int, set[int]] = collections.defaultdict(set)
real_pipeline_detection_ids = [d for dets in real_pipeline_detections_per_image.values() for d in dets]
if real_pipeline_detection_ids:
classification_rows = (
Classification.objects.filter(detection_id__in=real_pipeline_detection_ids)
.order_by()
.values_list("detection_id", "algorithm_id")
)
for detection_id, algorithm_id in classification_rows:
classifier_ids_per_detection[detection_id].add(algorithm_id)

for image in batch:
image_id = image.pk

if image_id not in images_with_pipeline_detection:
task_logger.debug(
f"Image {image} needs processing: has no existing detections from pipeline's detector"
)
yield image
continue

real_det_ids = real_pipeline_detections_per_image.get(image_id)
if not real_det_ids:
task_logger.debug(f"Image {image} has only null detections from pipeline {pipeline}, skipping!")
continue

if any(detection_id not in classifier_ids_per_detection for detection_id in real_det_ids):
task_logger.debug(
f"Image {image} needs processing: has existing detections with no classifications "
f"from pipeline {pipeline}"
)
yield image
continue

observed_classifier_ids: set[int] = set()
for detection_id in real_det_ids:
observed_classifier_ids.update(classifier_ids_per_detection[detection_id])

if not pipeline_algorithm_ids.issubset(detection_algorithm_ids):
if not pipeline_classifier_ids.issubset(observed_classifier_ids):
missing_algos = pipeline_classifier_ids - observed_classifier_ids
task_logger.debug(
f"Image {image} has existing detections that haven't been classified by the pipeline: {pipeline}:"
f" {detection_algorithm_ids} vs {pipeline_algorithm_ids}"
f"Image {image} has existing detections that haven't been classified by the pipeline: {pipeline}: "
f"{observed_classifier_ids} vs {pipeline_classifier_ids}. "
f"Since we do yet have a mechanism to reclassify detections, processing the image from scratch."
)
Comment thread
mihow marked this conversation as resolved.
# log all algorithms that are in the pipeline but not in the detection
missing_algos = pipeline_algorithm_ids - detection_algorithm_ids
task_logger.debug(f"Image #{image.pk} needs classification by pipeline's algorithms: {missing_algos}")
yield image
else:
# If all detections have been classified by the pipeline, skip the image
task_logger.debug(
f"Image {image} has existing detections classified by the pipeline: {pipeline}, skipping!"
)
continue


def collect_images(
Expand All @@ -151,13 +197,15 @@ def collect_images(
else:
job = None

# Set source to first argument that is not None
# Set source to first argument that is not None. For the collection and
# deployment branches, select_related() the deployment + data_source joins so
# later calls to image.url() in queue_images_to_nats don't trigger N+1 lookups
# (see issue #1321). Caller-supplied source_images are used as-is.
Comment thread
mihow marked this conversation as resolved.
Outdated
if collection:
images = collection.images.all()
images = collection.images.select_related("deployment__data_source")
elif source_images:
images = source_images
elif deployment:
images = SourceImage.objects.filter(deployment=deployment)
images = SourceImage.objects.filter(deployment=deployment).select_related("deployment__data_source")
else:
raise ValueError("Must specify a collection, deployment or a list of images")

Expand Down
61 changes: 45 additions & 16 deletions ami/ml/orchestration/jobs.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import logging

from asgiref.sync import async_to_sync
Expand All @@ -10,6 +11,14 @@

logger = logging.getLogger(__name__)

# Number of concurrent JetStream publishes per fanout chunk. The bottleneck on
# large-collection jobs is the per-message ack round-trip (~1.3ms each on
# Serbia 2026-05-27, sequential), so awaiting publishes one at a time scales
# linearly with image count and pushes >450k-image jobs past the reaper
# threshold. Issuing ~200 publishes per gather() lets NATS pipeline the acks
# back to us; the chunk boundary keeps memory and concurrent-task counts bounded.
NATS_PUBLISH_FANOUT_CHUNK_SIZE = 200


def cleanup_async_job_resources(job_id: int) -> bool:
"""
Expand Down Expand Up @@ -91,7 +100,11 @@ def queue_images_to_nats(job: "Job", images: list[SourceImage]):
skipped_count = 0
for image in images:
image_id = str(image.pk)
image_url = image.url() if hasattr(image, "url") and image.url() else ""
# Call image.url() exactly once per iteration — the implementation
# touches deployment + data_source and the call cost adds up across
# large collections. The upstream queryset in collect_images() also
# prefetches those joins so this stays cheap (see issue #1321).
image_url = image.url() if hasattr(image, "url") else None
if not image_url:
job.logger.warning(f"Image {image.pk} has no URL, skipping queuing to NATS for job '{job.pk}'")
skipped_count += 1
Expand Down Expand Up @@ -120,28 +133,44 @@ async def queue_all_images():
# the sync_to_async bridge for JobLogHandler's ORM save lives in one
# place instead of being re-implemented at every call site.
async with TaskQueueManager(job_logger=job.logger) as manager:
for image_pk, task in tasks:
# Warm the stream + consumer caches once so per-publish calls skip
# the cached-noop branches in publish_task -> _ensure_stream /
# _ensure_consumer. Even though those branches are O(1) after the
# first call, each one still runs inside the publish coroutine and
# serialises with the gather below.
try:
await manager._ensure_stream(job.pk)
await manager._ensure_consumer(job.pk)
Comment thread
mihow marked this conversation as resolved.
Outdated
except Exception as e:
await manager.log_async(
logging.ERROR,
f"Failed to set up NATS stream/consumer for job '{job.pk}': {e}",
exc_info=True,
)
return 0, len(tasks)

async def publish_one(image_pk: int, task: PipelineProcessingTask) -> bool:
try:
await manager.log_async(
logging.DEBUG,
f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}",
)
success = await manager.publish_task(
job_id=job.pk,
data=task,
)
return await manager.publish_task(job_id=job.pk, data=task)
except Exception as e:
await manager.log_async(
logging.ERROR,
f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}",
exc_info=True,
)
success = False

if success:
successful_queues += 1
else:
failed_queues += 1
return False

for chunk_start in range(0, len(tasks), NATS_PUBLISH_FANOUT_CHUNK_SIZE):
chunk = tasks[chunk_start : chunk_start + NATS_PUBLISH_FANOUT_CHUNK_SIZE]
results = await asyncio.gather(
*(publish_one(image_pk, task) for image_pk, task in chunk),
return_exceptions=False,
)
for success in results:
if success:
successful_queues += 1
else:
failed_queues += 1

return successful_queues, failed_queues

Expand Down
Loading
Loading