diff --git a/ami/main/models.py b/ami/main/models.py index 3d21c07cd..1de9a01ee 100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -98,6 +98,11 @@ class TaxonRank(OrderedEnum): NULL_DETECTIONS_FILTER = Q(bbox__isnull=True) | Q(bbox=[]) +def bbox_is_null(bbox) -> bool: + """In-memory equivalent of NULL_DETECTIONS_FILTER for already-fetched bbox values.""" + return bbox is None or bbox == [] + + def get_media_url(path: str) -> str: """ If path is a full URL, return it as-is. diff --git a/ami/ml/models/pipeline.py b/ami/ml/models/pipeline.py index c259e4aea..7cb54d14f 100644 --- a/ami/ml/models/pipeline.py +++ b/ami/ml/models/pipeline.py @@ -8,6 +8,7 @@ import collections import dataclasses +import itertools import logging import time import typing @@ -23,7 +24,6 @@ from ami.base.models import BaseModel, BaseQuerySet from ami.base.schemas import ConfigurableStage, default_stages from ami.main.models import ( - NULL_DETECTIONS_FILTER, Classification, Deployment, Detection, @@ -34,6 +34,7 @@ TaxaList, Taxon, TaxonRank, + bbox_is_null, update_calculated_fields_for_events, update_occurrence_determination, ) @@ -57,10 +58,23 @@ logger = logging.getLogger(__name__) +FILTER_PROCESSED_BATCH_SIZE = 1000 +# Minimum wall-time (seconds) between in-loop Job.progress writes inside +# filter_processed_images. The reaper's "no forward progress" heuristic keys off +# job.updated_at, so we need to tick at least once per ~10min; 5s gives ample +# headroom while still amortising the JSONB UPDATE across many chunks. Caps at +# 0.99 so the caller still owns the final status=SUCCESS, progress=1 flip. +COLLECT_PROGRESS_SAVE_INTERVAL_SECONDS = 5.0 +COLLECT_PROGRESS_MAX_FRACTION = 0.99 + + def filter_processed_images( images: typing.Iterable[SourceImage], pipeline: Pipeline, task_logger: logging.Logger = logger, + batch_size: int = FILTER_PROCESSED_BATCH_SIZE, + job: Job | None = None, + total: int | None = None, ) -> typing.Iterable[SourceImage]: """ Return only images that need to be processed by a given pipeline. @@ -79,56 +93,131 @@ def filter_processed_images( Null detections are sentinels that mark an image as "processed, nothing found." They are excluded from classification checks so they don't trigger reprocessing. + + Input images are processed in batches so the work scales as O(image_count / + batch_size) database round-trips instead of O(image_count). This keeps the + Collect stage well under the Celery broker's AMQP heartbeat window for the + large-collection case (see issue #1321). """ - pipeline_algorithms = pipeline.algorithms.all() + if batch_size <= 0: + raise ValueError(f"batch_size must be > 0, got {batch_size}") + + pipeline_algorithms = list(pipeline.algorithms.all()) + pipeline_algorithm_ids = [a.id for a in pipeline_algorithms] - detection_type_keys = Algorithm.detection_task_types - detection_algorithms = pipeline_algorithms.filter(task_type__in=detection_type_keys) - if not detection_algorithms.exists(): + detection_type_keys = set(Algorithm.detection_task_types) + has_detection_algorithm = any(a.task_type in detection_type_keys for a in pipeline_algorithms) + if not has_detection_algorithm: task_logger.warning(f"Pipeline {pipeline} has no detection algorithms saved. Will reprocess all images.") - classification_algorithms = pipeline_algorithms.exclude(task_type__in=detection_type_keys) - if not classification_algorithms.exists(): + pipeline_classifier_ids = {a.id for a in pipeline_algorithms if a.task_type not in detection_type_keys} + if not pipeline_classifier_ids: task_logger.warning(f"Pipeline {pipeline} has no classification algorithms saved. Will reprocess all images.") - - for image in images: - existing_detections = image.detections.filter(detection_algorithm__in=pipeline_algorithms) - if not existing_detections.exists(): - task_logger.debug(f"Image {image} needs processing: has no existing detections from pipeline's detector") - # If there are no existing detections from this pipeline, send the image - yield image - elif not existing_detections.exclude(NULL_DETECTIONS_FILTER).exists(): # type: ignore - # All detections for this image are null (processed but nothing found) — skip - task_logger.debug(f"Image {image} has only null detections from pipeline {pipeline}, skipping!") - continue - elif existing_detections.exclude(NULL_DETECTIONS_FILTER).filter(classifications__isnull=True).exists(): - # Check if any real detections (non-null) have no classifications - task_logger.debug( - f"Image {image} needs processing: has existing detections with no classifications " - "from pipeline {pipeline}" + # set().issubset(anything) is vacuously True, so without this short-circuit + # every image with existing detections gets skipped by the subset check + # below — contradicting the warning above. Yield everything and exit. + yield from images + return + + image_iter = iter(images) + # Track how many of the input images we've inspected so far so we can emit + # a fractional `collect` progress to the Job row. Only used when both + # `job` and `total` are passed by the caller; legacy callers stay silent. + processed_count = 0 + last_progress_save_monotonic = time.monotonic() + while True: + batch = list(itertools.islice(image_iter, batch_size)) + if not batch: + return + + batch_ids = [image.pk for image in batch] + + images_with_pipeline_detection: set[int] = set() + real_pipeline_detections_per_image: dict[int, set[int]] = collections.defaultdict(set) + detection_rows = ( + Detection.objects.filter( + source_image_id__in=batch_ids, + detection_algorithm_id__in=pipeline_algorithm_ids, ) - yield image - else: - # If there are existing detections with classifications, - # Compare their classification algorithms to the current pipeline's algorithms - pipeline_algorithm_ids = set(classification_algorithms.values_list("id", flat=True)) - detection_algorithm_ids = set(existing_detections.values_list("classifications__algorithm_id", flat=True)) + .order_by() + .values_list("source_image_id", "id", "bbox") + ) + for source_image_id, detection_id, bbox in detection_rows: + images_with_pipeline_detection.add(source_image_id) + if not bbox_is_null(bbox): + real_pipeline_detections_per_image[source_image_id].add(detection_id) + + classifier_ids_per_detection: dict[int, set[int]] = collections.defaultdict(set) + real_pipeline_detection_ids = [d for dets in real_pipeline_detections_per_image.values() for d in dets] + if real_pipeline_detection_ids: + classification_rows = ( + Classification.objects.filter(detection_id__in=real_pipeline_detection_ids) + .order_by() + .values_list("detection_id", "algorithm_id") + ) + for detection_id, algorithm_id in classification_rows: + classifier_ids_per_detection[detection_id].add(algorithm_id) - if not pipeline_algorithm_ids.issubset(detection_algorithm_ids): + for image in batch: + image_id = image.pk + + if image_id not in images_with_pipeline_detection: task_logger.debug( - f"Image {image} has existing detections that haven't been classified by the pipeline: {pipeline}:" - f" {detection_algorithm_ids} vs {pipeline_algorithm_ids}" - f"Since we do yet have a mechanism to reclassify detections, processing the image from scratch." + f"Image {image} needs processing: has no existing detections from pipeline's detector" + ) + yield image + continue + + real_det_ids = real_pipeline_detections_per_image.get(image_id) + if not real_det_ids: + task_logger.debug(f"Image {image} has only null detections from pipeline {pipeline}, skipping!") + continue + + if any(detection_id not in classifier_ids_per_detection for detection_id in real_det_ids): + task_logger.debug( + f"Image {image} needs processing: has existing detections with no classifications " + f"from pipeline {pipeline}" + ) + yield image + continue + + observed_classifier_ids: set[int] = set() + for detection_id in real_det_ids: + observed_classifier_ids.update(classifier_ids_per_detection[detection_id]) + + if not pipeline_classifier_ids.issubset(observed_classifier_ids): + missing_algos = pipeline_classifier_ids - observed_classifier_ids + task_logger.debug( + f"Image {image} has existing detections that haven't been classified by the pipeline: {pipeline}: " + f"{observed_classifier_ids} vs {pipeline_classifier_ids}. " + f"Since we do not yet have a mechanism to reclassify detections, " + f"processing the image from scratch." ) - # log all algorithms that are in the pipeline but not in the detection - missing_algos = pipeline_algorithm_ids - detection_algorithm_ids task_logger.debug(f"Image #{image.pk} needs classification by pipeline's algorithms: {missing_algos}") yield image else: - # If all detections have been classified by the pipeline, skip the image task_logger.debug( f"Image {image} has existing detections classified by the pipeline: {pipeline}, skipping!" ) - continue + + # Throttled progress emit. Save only when both `job` and `total` are + # provided, the total is non-zero, and at least + # COLLECT_PROGRESS_SAVE_INTERVAL_SECONDS of wall time have passed since + # the last save. Capped at COLLECT_PROGRESS_MAX_FRACTION so the caller's + # final status=SUCCESS, progress=1 flip still owns the terminal value. + # + # `updated_at` is included in update_fields explicitly: Django only fires + # auto_now's pre_save hook for fields listed in update_fields, so without + # it the reaper's `Job.updated_at < cutoff` heuristic + # (`ami/jobs/tasks.py:929-944`) would not see this heartbeat and could + # still revoke the job mid-Collect. + processed_count += len(batch) + if job is not None and total: + now_monotonic = time.monotonic() + if now_monotonic - last_progress_save_monotonic >= COLLECT_PROGRESS_SAVE_INTERVAL_SECONDS: + fraction = min(processed_count / total, COLLECT_PROGRESS_MAX_FRACTION) + job.progress.update_stage("collect", progress=fraction) + job.save(update_fields=["progress", "updated_at"]) + last_progress_save_monotonic = now_monotonic def collect_images( @@ -151,13 +240,18 @@ def collect_images( else: job = None - # Set source to first argument that is not None + # Set source to first argument that is not None. The collection- and + # deployment-backed branches prefetch the deployment + data_source joins so + # later calls to image.url() in queue_images_to_nats don't trigger N+1 + # lookups (see issue #1321). The source_images branch passes the caller's + # iterable through unchanged — callers handing us a pre-built list are + # responsible for any prefetching they need. if collection: - images = collection.images.all() + images = collection.images.select_related("deployment__data_source") elif source_images: images = source_images elif deployment: - images = SourceImage.objects.filter(deployment=deployment) + images = SourceImage.objects.filter(deployment=deployment).select_related("deployment__data_source") else: raise ValueError("Must specify a collection, deployment or a list of images") @@ -165,7 +259,19 @@ def collect_images( if pipeline and not reprocess_all_images: msg = f"Filtering images that have already been processed by pipeline {pipeline}" task_logger.info(msg) - images = list(filter_processed_images(images, pipeline, task_logger=task_logger)) + # Pass `job` and `total` so filter_processed_images can emit throttled + # `collect` progress updates as it chews through the input — keeps the + # reaper "no forward progress" heuristic happy on large collections + # (issue #1321 follow-up). + images = list( + filter_processed_images( + images, + pipeline, + task_logger=task_logger, + job=job, + total=total_images, + ) + ) else: msg = "NOT filtering images that have already been processed" task_logger.info(msg) diff --git a/ami/ml/orchestration/jobs.py b/ami/ml/orchestration/jobs.py index 79f787e28..a965ed7c9 100644 --- a/ami/ml/orchestration/jobs.py +++ b/ami/ml/orchestration/jobs.py @@ -1,3 +1,4 @@ +import asyncio import logging from asgiref.sync import async_to_sync @@ -10,6 +11,14 @@ logger = logging.getLogger(__name__) +# Number of concurrent JetStream publishes per fanout chunk. The bottleneck on +# large-collection jobs is the per-message ack round-trip (~1.3ms each on +# Serbia 2026-05-27, sequential), so awaiting publishes one at a time scales +# linearly with image count and pushes >450k-image jobs past the reaper +# threshold. Issuing ~200 publishes per gather() lets NATS pipeline the acks +# back to us; the chunk boundary keeps memory and concurrent-task counts bounded. +NATS_PUBLISH_FANOUT_CHUNK_SIZE = 200 + def cleanup_async_job_resources(job_id: int) -> bool: """ @@ -91,7 +100,11 @@ def queue_images_to_nats(job: "Job", images: list[SourceImage]): skipped_count = 0 for image in images: image_id = str(image.pk) - image_url = image.url() if hasattr(image, "url") and image.url() else "" + # Call image.url() exactly once per iteration — the implementation + # touches deployment + data_source and the call cost adds up across + # large collections. The upstream queryset in collect_images() also + # prefetches those joins so this stays cheap (see issue #1321). + image_url = image.url() if not image_url: job.logger.warning(f"Image {image.pk} has no URL, skipping queuing to NATS for job '{job.pk}'") skipped_count += 1 @@ -120,28 +133,43 @@ async def queue_all_images(): # the sync_to_async bridge for JobLogHandler's ORM save lives in one # place instead of being re-implemented at every call site. async with TaskQueueManager(job_logger=job.logger) as manager: - for image_pk, task in tasks: + # Warm the stream + consumer caches once so per-publish calls skip + # the cached-noop branches in publish_task -> _ensure_stream / + # _ensure_consumer. Even though those branches are O(1) after the + # first call, each one still runs inside the publish coroutine and + # serialises with the gather below. + try: + await manager.ensure_job_resources(job.pk) + except Exception as e: + await manager.log_async( + logging.ERROR, + f"Failed to set up NATS stream/consumer for job '{job.pk}': {e}", + exc_info=True, + ) + return 0, len(tasks) + + async def publish_one(image_pk: int, task: PipelineProcessingTask) -> bool: try: - await manager.log_async( - logging.DEBUG, - f"Queueing image {image_pk} to stream for job '{job.pk}': {task.image_url}", - ) - success = await manager.publish_task( - job_id=job.pk, - data=task, - ) + return await manager.publish_task(job_id=job.pk, data=task) except Exception as e: await manager.log_async( logging.ERROR, f"Failed to queue image {image_pk} to stream for job '{job.pk}': {e}", exc_info=True, ) - success = False - - if success: - successful_queues += 1 - else: - failed_queues += 1 + return False + + for chunk_start in range(0, len(tasks), NATS_PUBLISH_FANOUT_CHUNK_SIZE): + chunk = tasks[chunk_start : chunk_start + NATS_PUBLISH_FANOUT_CHUNK_SIZE] + results = await asyncio.gather( + *(publish_one(image_pk, task) for image_pk, task in chunk), + return_exceptions=False, + ) + for success in results: + if success: + successful_queues += 1 + else: + failed_queues += 1 return successful_queues, failed_queues diff --git a/ami/ml/orchestration/nats_queue.py b/ami/ml/orchestration/nats_queue.py index 1338e05ff..7f85c1ad3 100644 --- a/ami/ml/orchestration/nats_queue.py +++ b/ami/ml/orchestration/nats_queue.py @@ -290,6 +290,20 @@ async def _stream_exists(self, stream_name: str) -> bool: except nats.js.errors.NotFoundError: return False + async def ensure_job_resources(self, job_id: int) -> None: + """Ensure both the stream and consumer for the given job exist. + + Public wrapper for the two ``_ensure_*`` calls. Callers that want to + pre-warm both before a hot loop (e.g. a publish fanout) should use this + rather than touching the private methods directly — keeps the + TaskQueueManager's stream/consumer setup story in one place. + + Subsequent calls in the same manager session skip the NATS round-trip + via the per-instance ``_streams_logged`` / ``_consumers_logged`` sets. + """ + await self._ensure_stream(job_id) + await self._ensure_consumer(job_id) + async def _ensure_stream(self, job_id: int): """Ensure stream exists for the given job. diff --git a/ami/ml/orchestration/tests/test_jobs.py b/ami/ml/orchestration/tests/test_jobs.py new file mode 100644 index 000000000..04aaf30ed --- /dev/null +++ b/ami/ml/orchestration/tests/test_jobs.py @@ -0,0 +1,127 @@ +from unittest.mock import AsyncMock, patch + +from django.test import TestCase + +from ami.jobs.models import Job, JobDispatchMode, MLJob +from ami.main.models import Deployment, Project, S3StorageSource, SourceImage, SourceImageCollection +from ami.ml.models import Pipeline +from ami.ml.orchestration.jobs import NATS_PUBLISH_FANOUT_CHUNK_SIZE, queue_images_to_nats + + +def _stub_manager(mock_manager_cls, publish_results=None): + """Return an AsyncMock TaskQueueManager instance with publish_task pre-stubbed. + + publish_results is an optional iterable of bools matching publish_task call + order. Defaults to always-True. + """ + instance = mock_manager_cls.return_value + instance.__aenter__ = AsyncMock(return_value=instance) + instance.__aexit__ = AsyncMock(return_value=False) + instance.ensure_job_resources = AsyncMock() + instance._ensure_stream = AsyncMock() + instance._ensure_consumer = AsyncMock() + instance.log_async = AsyncMock() + if publish_results is None: + instance.publish_task = AsyncMock(return_value=True) + else: + instance.publish_task = AsyncMock(side_effect=list(publish_results)) + return instance + + +@patch("ami.ml.orchestration.jobs.AsyncJobStateManager") +@patch("ami.ml.orchestration.jobs.TaskQueueManager") +class TestQueueImagesToNatsFanout(TestCase): + """Unit-test the chunk+gather behaviour of queue_images_to_nats. + + The integration test in test_cleanup.py uses a real NATS server. These + tests instead mock TaskQueueManager so they run without infrastructure and + can assert on call counts. + """ + + def setUp(self): + self.project = Project.objects.create(name="fanout test") + self.data_source = S3StorageSource.objects.create( + name="ds", + bucket="b", + access_key="x", + secret_key="y", # noqa: S106 - fixture value, never used as a real credential + public_base_url="https://example.invalid/", + project=self.project, + ) + self.deployment = Deployment.objects.create(name="d", project=self.project, data_source=self.data_source) + self.pipeline = Pipeline.objects.create(name="p", slug="p", version="1") + self.collection = SourceImageCollection.objects.create(name="c", project=self.project) + + def _make_images(self, n: int) -> list[SourceImage]: + # public_base_url is set so SourceImage.url() returns a truthy URL — + # otherwise queue_images_to_nats skips the image and the inner loop + # we're trying to test never runs. + rows = [ + SourceImage( + path=f"fanout/{i}.jpg", + deployment=self.deployment, + project=self.project, + public_base_url="https://example.invalid/", + ) + for i in range(n) + ] + SourceImage.objects.bulk_create(rows) + return list(SourceImage.objects.filter(project=self.project).order_by("pk")) + + def _make_job(self) -> Job: + return Job.objects.create( + name="fanout", + job_type_key=MLJob.key, + project=self.project, + pipeline=self.pipeline, + source_image_collection=self.collection, + dispatch_mode=JobDispatchMode.ASYNC_API, + ) + + def test_warms_stream_and_consumer_once_before_publishing(self, mock_mgr_cls, _state_cls): + instance = _stub_manager(mock_mgr_cls) + images = self._make_images(5) + job = self._make_job() + + queue_images_to_nats(job, images) + + self.assertEqual(instance.ensure_job_resources.await_count, 1) + self.assertEqual(instance.publish_task.await_count, 5) + + def test_chunks_across_fanout_boundary(self, mock_mgr_cls, _state_cls): + instance = _stub_manager(mock_mgr_cls) + # One full chunk + one short chunk: forces the inner loop to iterate twice. + n = NATS_PUBLISH_FANOUT_CHUNK_SIZE + 7 + images = self._make_images(n) + job = self._make_job() + + queue_images_to_nats(job, images) + + self.assertEqual(instance.publish_task.await_count, n) + # Warm-up still only runs once even when we span multiple chunks. + self.assertEqual(instance.ensure_job_resources.await_count, 1) + + def test_partial_failure_counted_as_failed_not_raised(self, mock_mgr_cls, _state_cls): + # 3 images, middle one fails. Method should return False (failure summary) + # but not raise — and the success count should reflect the two that worked. + instance = _stub_manager(mock_mgr_cls, publish_results=[True, False, True]) + images = self._make_images(3) + job = self._make_job() + + result = queue_images_to_nats(job, images) + + self.assertFalse(result) + self.assertEqual(instance.publish_task.await_count, 3) + + def test_stream_setup_failure_short_circuits_publishing(self, mock_mgr_cls, _state_cls): + # If ensure_job_resources raises, we should not attempt any publish_task + # calls — the whole batch is marked failed in one shot. + instance = _stub_manager(mock_mgr_cls) + instance.ensure_job_resources = AsyncMock(side_effect=RuntimeError("nats down")) + images = self._make_images(4) + job = self._make_job() + + result = queue_images_to_nats(job, images) + + self.assertFalse(result) + self.assertEqual(instance.publish_task.await_count, 0) diff --git a/ami/ml/tests.py b/ami/ml/tests.py index 0eaa0a237..ea375135e 100644 --- a/ami/ml/tests.py +++ b/ami/ml/tests.py @@ -554,6 +554,43 @@ def test_collect_images(self): images = list(collect_images(collection=self.image_collection, pipeline=self.pipeline)) assert len(images) == 2 + def test_collect_images_prefetches_deployment_and_data_source(self): + """ + collect_images() must hand back SourceImage rows with deployment and + deployment.data_source already joined, so the downstream queue_images_to_nats + loop doesn't trigger N+1 FK lookups inside image.url() (see issue #1321). + """ + from ami.main.models import S3StorageSource + + data_source = S3StorageSource.objects.create( + name="ds-prefetch-test", + bucket="prefetch-bucket", + access_key="x", + secret_key="y", # noqa: S106 - fixture value, never used as a real credential + public_base_url="https://example.invalid/", + project=self.project, + ) + deployment = Deployment.objects.create( + name="prefetch-deployment", project=self.project, data_source=data_source + ) + images = [ + SourceImage.objects.create(path=f"prefetch-{i}.jpg", deployment=deployment, project=self.project) + for i in range(3) + ] + collection = SourceImageCollection.objects.create(project=self.project, name="prefetch-collection") + collection.images.set(images) + + collected = list(collect_images(collection=collection, pipeline=self.pipeline)) + self.assertEqual(len(collected), 3) + + # Accessing deployment.data_source on the returned images should + # require zero extra queries because select_related joined both. + with self.assertNumQueries(0): + for image in collected: + self.assertEqual(image.deployment_id, deployment.pk) + self.assertEqual(image.deployment.data_source_id, data_source.pk) + self.assertEqual(image.deployment.data_source.public_base_url, "https://example.invalid/") + def fake_pipeline_results( self, source_images: list[SourceImage], @@ -962,6 +999,190 @@ def test_filter_processed_images_skips_null_and_fully_classified(self): result = list(filter_processed_images([image], self.pipeline)) self.assertEqual(result, [], "Fully classified image with null detection should be skipped") + def test_filter_processed_images_empty_input(self): + """An empty iterable should yield nothing and run no per-image queries.""" + from ami.ml.models.pipeline import filter_processed_images + + with self.assertNumQueries(1): # one query: pipeline.algorithms.all() + result = list(filter_processed_images([], self.pipeline)) + self.assertEqual(result, []) + + def test_filter_processed_images_yields_all_when_pipeline_has_no_classifiers(self): + """ + When a pipeline has no classifier algorithms registered, filter_processed_images + must yield every image (matching the "Will reprocess all images" warning). + Without the short-circuit, the empty `pipeline_classifier_ids` set makes + `set().issubset(observed) == True` and every image with existing detections + is silently skipped — directly contradicting the warning. + """ + from ami.ml.models.pipeline import filter_processed_images + + detector_only_pipeline = Pipeline.objects.create(name="Detector Only Pipeline") + detector_only_pipeline.algorithms.set([self.algorithms["random-detector"]]) + + # Image with a real, fully-processed-looking detection from the detector. + # Pre-short-circuit this would be skipped because the empty pipeline classifier + # set is vacuously a subset of any observed-classifier set. + image_with_detection = SourceImage.objects.create(path="no-classifier-with-det.jpg") + Detection.objects.create( + source_image=image_with_detection, + detection_algorithm=self.algorithms["random-detector"], + bbox=[0.1, 0.2, 0.3, 0.4], + ) + image_unprocessed = SourceImage.objects.create(path="no-classifier-unprocessed.jpg") + + result = list(filter_processed_images([image_with_detection, image_unprocessed], detector_only_pipeline)) + self.assertEqual(result, [image_with_detection, image_unprocessed]) + + def test_filter_processed_images_mixed_batch(self): + """ + A mixed batch of images covering all five branches should yield only + the ones that need processing, in input order. + """ + from ami.ml.models.pipeline import filter_processed_images + + detector = self.algorithms["random-detector"] + binary = self.algorithms["random-binary-classifier"] + species = self.algorithms["random-species-classifier"] + + unprocessed = SourceImage.objects.create(path="unprocessed.jpg") + null_only = SourceImage.objects.create(path="null_only.jpg") + unclassified = SourceImage.objects.create(path="unclassified.jpg") + fully_classified = SourceImage.objects.create(path="fully_classified.jpg") + + Detection.objects.create(source_image=null_only, detection_algorithm=detector, bbox=None) + + Detection.objects.create(source_image=unclassified, detection_algorithm=detector, bbox=[0.1, 0.2, 0.3, 0.4]) + + real_det = Detection.objects.create( + source_image=fully_classified, detection_algorithm=detector, bbox=[0.1, 0.2, 0.3, 0.4] + ) + taxon = Taxon.objects.create(name="Test Mixed Batch Taxon") + Classification.objects.create( + detection=real_det, taxon=taxon, algorithm=binary, score=0.9, timestamp=datetime.datetime.now() + ) + Classification.objects.create( + detection=real_det, taxon=taxon, algorithm=species, score=0.8, timestamp=datetime.datetime.now() + ) + + images = [unprocessed, null_only, unclassified, fully_classified] + result = list(filter_processed_images(images, self.pipeline)) + self.assertEqual(result, [unprocessed, unclassified]) + + def test_filter_processed_images_query_count_is_bounded_per_batch(self): + """ + With N images and batch_size=B, the query count should scale as + O(N / B), not O(N). Locks in the bulk-query rewrite from issue #1321. + + Each batch issues at most: + - 1 Detection bulk select + - 1 Classification bulk select (only when real detections exist) + Plus one initial pipeline.algorithms.all() that's shared across batches. + """ + from ami.ml.models.pipeline import filter_processed_images + + detector = self.algorithms["random-detector"] + binary = self.algorithms["random-binary-classifier"] + species = self.algorithms["random-species-classifier"] + taxon = Taxon.objects.create(name="Bounded Query Test Taxon") + + # 10 images. Batch 1: 5 fully-classified (triggers classification query). + # Batch 2: 5 unprocessed (no detections, classification query skipped). + images = [SourceImage.objects.create(path=f"bulk-{i}.jpg") for i in range(10)] + for image in images[:5]: + real_det = Detection.objects.create( + source_image=image, detection_algorithm=detector, bbox=[0.1, 0.2, 0.3, 0.4] + ) + Classification.objects.create( + detection=real_det, taxon=taxon, algorithm=binary, score=0.9, timestamp=datetime.datetime.now() + ) + Classification.objects.create( + detection=real_det, taxon=taxon, algorithm=species, score=0.8, timestamp=datetime.datetime.now() + ) + + # Expected: 1 (pipeline) + 2 (detection × 2 batches) + 1 (classification, batch 1 only) = 4. + with self.assertNumQueries(4): + result = list(filter_processed_images(images, self.pipeline, batch_size=5)) + + # First 5 fully classified → skipped. Last 5 fresh → yielded. + self.assertEqual(result, images[5:]) + + def test_filter_processed_images_emits_throttled_collect_progress(self): + """ + When `job` and `total` are passed, filter_processed_images should call + job.save(update_fields=["progress"]) at most once per + COLLECT_PROGRESS_SAVE_INTERVAL_SECONDS of wall time, capped at + COLLECT_PROGRESS_MAX_FRACTION. Keeps the reaper's "no forward progress" + heuristic happy on multi-minute Collect stages without hot-saving the + Job row on every chunk (issue #1321 follow-up). + """ + from unittest.mock import patch + + from ami.jobs.models import Job, MLJob + from ami.ml.models.pipeline import COLLECT_PROGRESS_MAX_FRACTION, filter_processed_images + + job = Job.objects.create( + project=self.project, + name="collect progress cadence test", + pipeline=self.pipeline, + job_type_key=MLJob.key, + ) + # First save triggered MLJob.setup → "collect" stage exists. + job.progress.get_stage("collect") + + images = [SourceImage.objects.create(path=f"cadence-{i}.jpg") for i in range(10)] + + # batch_size=3 over 10 images → 4 batches. Each monotonic() call + # advances the clock by 3s. Expected sequence: init=0, batch1=3 + # (gap 3, no save), batch2=6 (gap 6, SAVE → last=6), batch3=9 + # (gap 3, no save), batch4=12 (gap 6, SAVE → last=12). Two saves. + # + # Counter-based stub (vs an iter/next sequence) means extra + # monotonic() calls added to filter_processed_images later will + # advance time faster and the assertion will fail with a clear + # cadence mismatch, not StopIteration. + clock = {"t": 0.0} + + def fake_monotonic(): + t = clock["t"] + clock["t"] += 3.0 + return t + + with patch("ami.ml.models.pipeline.time.monotonic", side_effect=fake_monotonic): + with patch.object(Job, "save", autospec=True) as mock_save: + list(filter_processed_images(images, self.pipeline, batch_size=3, job=job, total=10)) + + self.assertEqual(mock_save.call_count, 2, "Expected throttle to allow exactly 2 saves") + for call in mock_save.call_args_list: + self.assertEqual( + call.kwargs.get("update_fields"), + ["progress", "updated_at"], + "Throttled saves must include `updated_at` so Django's auto_now fires " + "and the reaper's stale-job heuristic sees forward motion.", + ) + + # Final emitted fraction comes from the second save (batch 4): processed=10, + # total=10 → raw 1.0, capped at COLLECT_PROGRESS_MAX_FRACTION. + collect_stage = job.progress.get_stage("collect") + self.assertEqual(collect_stage.progress, COLLECT_PROGRESS_MAX_FRACTION) + + def test_filter_processed_images_skips_progress_emission_without_job(self): + """ + Legacy callers that omit `job` (the only callers before this change) + should see zero job.save() calls — the throttle block is fully gated + on both `job` and `total` being passed. + """ + from unittest.mock import patch + + from ami.jobs.models import Job + from ami.ml.models.pipeline import filter_processed_images + + images = [SourceImage.objects.create(path=f"nojob-{i}.jpg") for i in range(5)] + with patch.object(Job, "save", autospec=True) as mock_save: + list(filter_processed_images(images, self.pipeline, batch_size=2)) + + self.assertEqual(mock_save.call_count, 0) + def test_null_detections_are_algorithm_specific(self): """ Null detections from different pipelines/algorithms should not be shared.