diff --git a/ami/exports/format_types.py b/ami/exports/format_types.py index 9c57b5c25..bbe2e5c6c 100644 --- a/ami/exports/format_types.py +++ b/ami/exports/format_types.py @@ -1,14 +1,17 @@ import csv +import datetime import json import logging +import os import tempfile +from django.conf import settings from django.core.serializers.json import DjangoJSONEncoder from rest_framework import serializers from ami.exports.base import BaseExporter from ami.exports.utils import get_data_in_batches -from ami.main.models import Occurrence, SourceImage, get_media_url +from ami.main.models import Occurrence, SourceImage, Taxon, get_media_url from ami.ml.schemas import BoundingBox logger = logging.getLogger(__name__) @@ -190,6 +193,26 @@ def get_best_detection_height(self, obj): bbox = BoundingBox.from_coords(getattr(obj, "best_detection_bbox", None), raise_on_error=False) return bbox.height if bbox else None + def get_best_detection_source_image_id(self, obj): + """Returns the source image id for the best detection.""" + return getattr(obj, "best_detection_source_image_id", None) + + def get_best_detection_capture_timestamp(self, obj): + """Returns the capture timestamp for the best detection source image.""" + return getattr(obj, "best_detection_capture_timestamp", None) + + def get_best_detection_capture_path(self, obj): + """Returns the capture path for the best detection source image.""" + return getattr(obj, "best_detection_capture_path", None) + + def get_best_detection_capture_width(self, obj): + """Returns the capture width for the best detection source image.""" + return getattr(obj, "best_detection_capture_width", None) + + def get_best_detection_capture_height(self, obj): + """Returns the capture height for the best detection source image.""" + return getattr(obj, "best_detection_capture_height", None) + def get_best_detection_capture_url(self, obj): """Returns the public URL to the source capture (original full-frame image). 
class OccurrenceCocoTabularSerializer(OccurrenceTabularSerializer):
    """Tabular occurrence row extended with the capture metadata a COCO export needs.

    On top of the fields declared by ``OccurrenceTabularSerializer`` this adds
    the id, timestamp, path, width and height of the source image behind the
    best detection, plus the taxon id of the best machine prediction.
    """

    # Columns describing the SourceImage behind the best detection.
    source_image_id = serializers.SerializerMethodField()
    capture_timestamp = serializers.SerializerMethodField()
    capture_path = serializers.SerializerMethodField()
    capture_width = serializers.SerializerMethodField()
    capture_height = serializers.SerializerMethodField()
    # Nullable integer column; None is the declared default.
    best_machine_prediction_taxon_id = serializers.IntegerField(allow_null=True, default=None)

    class Meta(OccurrenceTabularSerializer.Meta):
        # Parent columns first, COCO-specific columns appended after them.
        fields = OccurrenceTabularSerializer.Meta.fields + [
            "source_image_id",
            "capture_timestamp",
            "capture_path",
            "capture_width",
            "capture_height",
            "best_machine_prediction_taxon_id",
        ]

    def get_source_image_id(self, obj):
        """Return the source image id of the best detection (or None)."""
        return self.get_best_detection_source_image_id(obj)

    def get_capture_timestamp(self, obj):
        """Return the capture timestamp as a string, or None when it is missing.

        Values exposing ``isoformat`` are rendered with it; anything else
        falls back to ``str()``.
        """
        captured_at = self.get_best_detection_capture_timestamp(obj)
        if captured_at is None:
            return None
        if hasattr(captured_at, "isoformat"):
            return captured_at.isoformat()
        return str(captured_at)

    def get_capture_path(self, obj):
        """Return the capture path of the best detection's source image (or None)."""
        return self.get_best_detection_capture_path(obj)

    def get_capture_width(self, obj):
        """Return the capture width of the best detection's source image (or None)."""
        return self.get_best_detection_capture_width(obj)

    def get_capture_height(self, obj):
        """Return the capture height of the best detection's source image (or None)."""
        return self.get_best_detection_capture_height(obj)
def corner_bbox_to_coco_bbox(corner: list | None) -> tuple[list[float], float] | None:
    """Translate a corner-style box into a COCO box and its area.

    Args:
        corner: Coordinates in ``[x1, y1, x2, y2]`` order, or None.

    Returns:
        ``([x, y, width, height], area)`` with every number as a float, or
        None when ``BoundingBox.from_coords`` yields no box or the box has a
        missing or non-positive width/height.
    """
    parsed = BoundingBox.from_coords(corner, raise_on_error=False)
    if parsed is None:
        return None

    width = parsed.width
    height = parsed.height
    if width is None or height is None:
        return None
    if width <= 0 or height <= 0:
        # Degenerate boxes cannot be expressed as a COCO annotation.
        return None

    coco_box = [float(value) for value in (parsed.x1, parsed.y1, width, height)]
    return coco_box, float(width * height)
def _serialize_taxon_parents(parents_json) -> list[dict]:
    """Flatten a taxon's ``parents_json`` value into plain ``{id, name, rank}`` dicts.

    Anything that is not a list yields an empty list. Entries may be dicts or
    attribute-style objects; entries with no id, name and rank are dropped.
    """
    if not isinstance(parents_json, list):
        return []

    serialized = []
    for parent in parents_json:
        if isinstance(parent, dict):
            parent_id = parent.get("id")
            parent_name = parent.get("name")
            rank = parent.get("rank")
        else:
            # SchemaField(list[TaxonParent]) may return Pydantic objects rather than dicts.
            parent_id = getattr(parent, "id", None)
            parent_name = getattr(parent, "name", None)
            rank = getattr(parent, "rank", None)

        if parent_id is None and parent_name is None and rank is None:
            continue

        # Enum-like ranks expose ``.value``; plain values are used as they are.
        rank_value = rank.value if hasattr(rank, "value") else rank
        serialized.append(
            {
                "id": parent_id,
                "name": parent_name,
                "rank": str(rank_value) if rank_value is not None else None,
            }
        )
    return serialized


def build_coco_dict_from_occurrence_rows(rows: list[dict], project) -> dict:
    """Build a COCO-style detection dict from serialized occurrence rows.

    Each usable row becomes one annotation whose ``category_id`` is the row's
    determination and whose ``image_id`` is the source image of its best
    detection. Rows lacking a determination id, a usable bounding box, a
    source image id or an occurrence id are skipped with a warning, and they
    contribute neither a category nor an image entry.

    Categories are then filled in from the ``Taxon`` table (name, rank,
    parent_id, parents) for every determination taxon and every best machine
    prediction taxon referenced by the exported annotations.

    Args:
        rows: Serialized occurrence rows (dicts) as produced for the export.
        project: Object providing ``pk`` and ``name``; used for the ``info`` section.

    Returns:
        Dict with the ``info``, ``images``, ``annotations`` and ``categories``
        keys; categories are sorted by id.
    """
    categories_by_id: dict[int, dict] = {}
    images_by_id: dict[int, dict] = {}
    annotations: list[dict] = []
    category_taxon_ids: set[int] = set()

    for row in rows:
        determination_id = row.get("determination_id")
        if determination_id is None:
            logger.warning("No determination_id found for row: %s", row)
            continue

        coco_result = corner_bbox_to_coco_bbox(row.get("best_detection_bbox"))
        if coco_result is None:
            logger.warning("No coco_bbox found for row: %s", row)
            continue
        coco_bbox, area = coco_result

        source_image_id = row.get("source_image_id")
        if source_image_id is None:
            logger.warning("No source_image_id found for row: %s", row)
            continue

        # Checked before anything is registered so a skipped row cannot leave
        # behind an image or category that no annotation refers to.
        occ_id = row.get("id")
        if occ_id is None:
            logger.warning("No occ_id found for row: %s", row)
            continue

        # Normalise ids once so dict keys and membership tests use the same type.
        category_id = int(determination_id)
        image_id = int(source_image_id)

        det_name = row.get("determination_name") or ""
        category = categories_by_id.get(category_id)
        if category is None:
            categories_by_id[category_id] = {"id": category_id, "name": det_name}
            category_taxon_ids.add(category_id)
        elif category["name"] != det_name:
            # Reported instead of asserted: an assert vanishes under ``-O`` and
            # would otherwise abort the whole export. The first name seen is
            # kept and is replaced by the Taxon table below when the taxon exists.
            logger.warning("Determination name mismatch for id: %s", determination_id)

        if image_id not in images_by_id:
            cap_path = row.get("capture_path") or ""
            images_by_id[image_id] = {
                "id": image_id,
                "file_name": os.path.basename(cap_path),
                "width": row.get("capture_width"),
                "height": row.get("capture_height"),
                "coco_url": row.get("best_detection_capture_url"),
                "date_captured": row.get("capture_timestamp"),
            }

        prediction_taxon_id = row.get("best_machine_prediction_taxon_id")
        annotations.append(
            {
                "id": int(occ_id),
                "image_id": image_id,
                "category_id": category_id,
                "bbox": coco_bbox,
                "area": area,
                "iscrowd": 0,  # TODO: Could we use this field to indicate a crowd of insects?
                "determination_score": row.get("determination_score"),
                "verification_status": row.get("verification_status"),
                "determination_matches_machine_prediction": row.get("determination_matches_machine_prediction"),
                "best_machine_prediction_algorithm": row.get("best_machine_prediction_algorithm"),
                "best_machine_prediction_score": row.get("best_machine_prediction_score"),
                "best_machine_prediction_taxon_id": prediction_taxon_id,
                "best_detection_width": row.get("best_detection_width"),
                "best_detection_height": row.get("best_detection_height"),
            }
        )

        # Predicted taxa are also exported as categories so the prediction id
        # stored on the annotation can be resolved inside the file.
        if prediction_taxon_id is not None:
            try:
                category_taxon_ids.add(int(prediction_taxon_id))
            except (TypeError, ValueError):
                logger.warning("Invalid best_machine_prediction_taxon_id for row: %s", row)

    if category_taxon_ids:
        taxa = Taxon.objects.filter(id__in=category_taxon_ids).values(
            "id", "name", "rank", "parent_id", "parents_json"
        )
        for taxon in taxa:
            taxon_id = int(taxon["id"])
            categories_by_id[taxon_id] = {
                "id": taxon_id,
                # Fall back to the name taken from the rows when the taxon has none.
                "name": taxon.get("name") or categories_by_id.get(taxon_id, {}).get("name", ""),
                "rank": taxon.get("rank"),
                "parent_id": taxon.get("parent_id"),
                "parents": _serialize_taxon_parents(taxon.get("parents_json")),
            }

    base = getattr(settings, "EXTERNAL_BASE_URL", "") or ""
    info_url = ""
    if base.strip():
        info_url = f"{base.rstrip('/')}/projects/{project.pk}/summary"  # TODO: is there a better way to do this?

    return {
        "info": {
            "description": f"{project.name} ({project.pk}) Occurrences",
            "url": info_url,
            "date_created": datetime.datetime.now(datetime.timezone.utc).isoformat(),
        },
        "images": list(images_by_id.values()),
        "annotations": annotations,
        "categories": sorted(categories_by_id.values(), key=lambda c: c["id"]),
    }
class CocoJSONExporter(BaseExporter):
    """Exports occurrences as a COCO-style detection dataset (same rows as CSV plus capture metadata)."""

    file_format = "json"

    serializer_class = OccurrenceCocoTabularSerializer

    def get_queryset(self):
        """Return this project's valid occurrences with the annotations the serializer reads."""
        return (
            Occurrence.objects.valid()
            .filter(project=self.project)
            .select_related(
                "determination",
                "deployment",
                "event",
            )
            .with_timestamps()
            .with_detections_count()
            .with_identifications()
            .with_best_detection()
            .with_best_machine_prediction()
            .with_verification_info()
        )

    def export(self):
        """Serialize occurrences in batches, build COCO JSON, write to a temp file.

        All rows are collected in memory first because the COCO document is
        built from the complete row set. Job progress is updated after every
        batch.

        Returns:
            Path of the temporary ``.json`` file. It is created with
            ``delete=False``, so it stays on disk after this method returns.
        """
        rows: list[dict] = []
        records_exported = 0
        for batch in get_data_in_batches(self.queryset, self.serializer_class):
            rows.extend(batch)
            records_exported += len(batch)
            self.update_job_progress(records_exported)

        coco_payload = build_coco_dict_from_occurrence_rows(rows, self.project)

        # Write through the temp file's own handle and let the context manager
        # close it; reopening the path with open() left the first handle
        # unclosed for the lifetime of the process.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w", encoding="utf-8") as temp_file:
            json.dump(coco_payload, temp_file, cls=DjangoJSONEncoder)

        self.update_export_stats(file_temp_path=temp_file.name)
        return temp_file.name
@override_settings(EXTERNAL_BASE_URL="https://antenna.example.org")
def test_coco_json_export_structure(self):
    """COCO export is a single JSON object with images, annotations, categories; ids are consistent."""
    data_export = DataExport.objects.create(
        user=self.user,
        project=self.project,
        format="occurrences_coco_json",
        filters={"collection_id": self.collection.pk},
        job=None,
    )

    # Run the export; the returned URL points at the generated file.
    file_url = data_export.run_export()

    self.assertIsNotNone(file_url)
    file_path = file_url.replace("/media/", "")
    self.assertTrue(default_storage.exists(file_path))
    # Registered as a cleanup so the exported file is removed even when an
    # assertion further down fails.
    self.addCleanup(default_storage.delete, file_path)

    with default_storage.open(file_path, "r") as f:
        payload = json.load(f)

    # Top-level COCO sections must all be present.
    for section in ("info", "images", "annotations", "categories"):
        self.assertIn(section, payload)

    # The info URL points to the project summary page.
    # TODO: Change this to not be hardcoded?
    self.assertEqual(
        payload["info"]["url"],
        f"https://antenna.example.org/projects/{self.project.pk}/summary",
    )

    image_ids = {img["id"] for img in payload["images"]}
    category_ids = {c["id"] for c in payload["categories"]}
    # unittest assertions rather than bare asserts, which are skipped under -O.
    self.assertGreater(len(image_ids), 0, "No images found in the export")
    self.assertGreater(len(category_ids), 0, "No categories found in the export")
    self.assertGreater(len(payload["annotations"]), 0, "No annotations found in the export")

    # Every annotation must reference an exported image and category and
    # carry a box with a positive width and height.
    for ann in payload["annotations"]:
        self.assertIn(ann["image_id"], image_ids)
        self.assertIn(ann["category_id"], category_ids)
        _x, _y, w, h = ann["bbox"]
        self.assertGreater(w, 0)
        self.assertGreater(h, 0)

    # Number of annotations should equal the number of occurrences in
    # the collection, but excluding occurrences without a determination
    occurrences_with_determination = (
        Occurrence.objects.valid()
        .filter(detections__source_image__collections=self.collection, determination__isnull=False)
        .distinct()
        .count()
    )
    self.assertEqual(len(payload["annotations"]), occurrences_with_determination)
100644 --- a/ami/main/models.py +++ b/ami/main/models.py @@ -2881,6 +2881,9 @@ def with_best_detection(self): - best_detection_bbox: The bounding box of the detection as a list [x1, y1, x2, y2] - best_detection_capture_path: The path of the source capture image - best_detection_capture_public_base_url: The public base URL of the source capture image + - best_detection_source_image_id: Primary key of the SourceImage for that capture + - best_detection_capture_timestamp: Timestamp from SourceImage (capture time) + - best_detection_capture_width / best_detection_capture_height: Dimensions from SourceImage """ # Subquery to get the path of the best detection # Use id as secondary sort to ensure deterministic results @@ -2910,11 +2913,36 @@ def with_best_detection(self): .values("source_image__public_base_url")[:1] ) + best_detection_source_image_id_subquery = ( + Detection.objects.filter(occurrence=OuterRef("pk")) + .order_by("-classifications__score", "id") + .values("source_image_id")[:1] + ) + best_detection_capture_timestamp_subquery = ( + Detection.objects.filter(occurrence=OuterRef("pk")) + .order_by("-classifications__score", "id") + .values("source_image__timestamp")[:1] + ) + best_detection_capture_width_subquery = ( + Detection.objects.filter(occurrence=OuterRef("pk")) + .order_by("-classifications__score", "id") + .values("source_image__width")[:1] + ) + best_detection_capture_height_subquery = ( + Detection.objects.filter(occurrence=OuterRef("pk")) + .order_by("-classifications__score", "id") + .values("source_image__height")[:1] + ) + return self.annotate( best_detection_path=models.Subquery(best_detection_path_subquery), best_detection_bbox=models.Subquery(best_detection_bbox_subquery), best_detection_capture_path=models.Subquery(best_detection_capture_path_subquery), best_detection_capture_public_base_url=models.Subquery(best_detection_capture_public_base_url_subquery), + 
best_detection_source_image_id=models.Subquery(best_detection_source_image_id_subquery), + best_detection_capture_timestamp=models.Subquery(best_detection_capture_timestamp_subquery), + best_detection_capture_width=models.Subquery(best_detection_capture_width_subquery), + best_detection_capture_height=models.Subquery(best_detection_capture_height_subquery), ) def with_best_machine_prediction(self): diff --git a/ui/src/data-services/models/export.ts b/ui/src/data-services/models/export.ts index d96ee2e39..ec63dfc1f 100644 --- a/ui/src/data-services/models/export.ts +++ b/ui/src/data-services/models/export.ts @@ -6,6 +6,7 @@ import { JobDetails } from './job-details' export const SERVER_EXPORT_TYPES = [ 'occurrences_simple_csv', 'occurrences_api_json', + 'occurrences_coco_json', ] as const export type ServerExportType = (typeof SERVER_EXPORT_TYPES)[number] @@ -27,6 +28,7 @@ export class Export extends Entity { const label = { occurrences_simple_csv: 'Occurrences (simple CSV)', occurrences_api_json: 'Occurrences (API JSON)', + occurrences_coco_json: 'Occurrences (COCO JSON)', }[key] return { diff --git a/ui/src/utils/language.ts b/ui/src/utils/language.ts index c9fc33df0..f417043b9 100644 --- a/ui/src/utils/language.ts +++ b/ui/src/utils/language.ts @@ -527,7 +527,7 @@ const ENGLISH_STRINGS: { [key in STRING]: string } = { 'Are you sure you want to delete this {{type}}?', [STRING.MESSAGE_DRAFTS]: 'Drafts are private and limited to one user.', [STRING.MESSAGE_EXPORT_TIP]: - 'We support two export formats: one compact and easy to use, and one that includes all raw data. To include all data in the export, skip "Capture set".', + 'We support three export formats: simple CSV, API JSON (includes all raw data), and COCO JSON. COCO JSON follows the simple CSV occurrence fields with additional image metadata. 
To include all data in the export, skip "Capture set".', [STRING.MESSAGE_HAS_ACCOUNT]: 'Already have an account?', [STRING.MESSAGE_IMAGE_FORMAT]: 'Valid formats are PNG, GIF and JPEG.', [STRING.MESSAGE_IMAGE_SIZE]: