Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
257 changes: 255 additions & 2 deletions ami/exports/format_types.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,17 @@
import csv
import datetime
import json
import logging
import os
import tempfile

from django.conf import settings
from django.core.serializers.json import DjangoJSONEncoder
from rest_framework import serializers

from ami.exports.base import BaseExporter
from ami.exports.utils import get_data_in_batches
from ami.main.models import Occurrence, SourceImage, get_media_url
from ami.main.models import Occurrence, SourceImage, Taxon, get_media_url
from ami.ml.schemas import BoundingBox

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -190,20 +193,227 @@ def get_best_detection_height(self, obj):
bbox = BoundingBox.from_coords(getattr(obj, "best_detection_bbox", None), raise_on_error=False)
return bbox.height if bbox else None

def get_best_detection_source_image_id(self, obj):
"""Returns the source image id for the best detection."""
return getattr(obj, "best_detection_source_image_id", None)

def get_best_detection_capture_timestamp(self, obj):
"""Returns the capture timestamp for the best detection source image."""
return getattr(obj, "best_detection_capture_timestamp", None)

def get_best_detection_capture_path(self, obj):
"""Returns the capture path for the best detection source image."""
return getattr(obj, "best_detection_capture_path", None)

def get_best_detection_capture_width(self, obj):
"""Returns the capture width for the best detection source image."""
return getattr(obj, "best_detection_capture_width", None)

def get_best_detection_capture_height(self, obj):
"""Returns the capture height for the best detection source image."""
return getattr(obj, "best_detection_capture_height", None)

def get_best_detection_capture_url(self, obj):
"""Returns the public URL to the source capture (original full-frame image).

Built from annotated `path` + `public_base_url` to avoid loading the
capture (SourceImage) row per occurrence; presigned URLs for private
buckets aren't supported here for the same reason.
"""
path = getattr(obj, "best_detection_capture_path", None)
path = self.get_best_detection_capture_path(obj)
base_url = getattr(obj, "best_detection_capture_public_base_url", None)
if path and base_url:
return SourceImage.build_public_url(base_url, path)
return None


class OccurrenceCocoTabularSerializer(OccurrenceTabularSerializer):
"""CSV-shaped occurrence row plus SourceImage metadata for COCO exports."""

# Best detection SourceImage fields
source_image_id = serializers.SerializerMethodField()
capture_timestamp = serializers.SerializerMethodField()
capture_path = serializers.SerializerMethodField()
capture_width = serializers.SerializerMethodField()
capture_height = serializers.SerializerMethodField()
best_machine_prediction_taxon_id = serializers.IntegerField(allow_null=True, default=None)

class Meta(OccurrenceTabularSerializer.Meta):
fields = OccurrenceTabularSerializer.Meta.fields + [
"source_image_id",
"capture_timestamp",
"capture_path",
"capture_width",
"capture_height",
"best_machine_prediction_taxon_id",
]

def get_source_image_id(self, obj):
return self.get_best_detection_source_image_id(obj)

def get_capture_timestamp(self, obj):
ts = self.get_best_detection_capture_timestamp(obj)
if ts is None:
return None
return ts.isoformat() if hasattr(ts, "isoformat") else str(ts)

def get_capture_path(self, obj):
return self.get_best_detection_capture_path(obj)

def get_capture_width(self, obj):
return self.get_best_detection_capture_width(obj)

def get_capture_height(self, obj):
return self.get_best_detection_capture_height(obj)


def corner_bbox_to_coco_bbox(corner: list | None) -> tuple[list[float], float] | None:
"""Convert [x1, y1, x2, y2] to COCO [x, y, width, height] and area."""
bbox = BoundingBox.from_coords(corner, raise_on_error=False)
if bbox is None:
return None
w, h = bbox.width, bbox.height
if w is None or h is None or w <= 0 or h <= 0:
return None
return [float(bbox.x1), float(bbox.y1), float(w), float(h)], float(w * h)


def build_coco_dict_from_occurrence_rows(rows: list[dict], project) -> dict:
"""Build a COCO-style detection dict from serialized occurrence rows (determination categories only)."""
categories_by_id: dict[int, dict] = {}
images_by_id: dict[int, dict] = {}
annotations: list[dict] = []
category_taxon_ids: set[int] = set()

for row in rows:
determination_id = row.get("determination_id")
if determination_id is None:
logger.warning(f"No determination_id found for row: {row}")
continue

coco_result = corner_bbox_to_coco_bbox(row.get("best_detection_bbox"))
if coco_result is None:
logger.warning(f"No coco_bbox found for row: {row}")
continue

coco_bbox, area = coco_result
source_image_id = row.get("source_image_id")
if source_image_id is None:
logger.warning(f"No source_image_id found for row: {row}")
continue

det_name = row.get("determination_name") or ""
if int(determination_id) not in categories_by_id:
categories_by_id[int(determination_id)] = {
"id": int(determination_id),
"name": det_name,
}
category_taxon_ids.add(int(determination_id))
else:
assert (
categories_by_id[int(determination_id)]["name"] == det_name
), f"Determination name mismatch for id: {determination_id}"

if source_image_id not in images_by_id:
cap_path = row.get("capture_path") or ""
images_by_id[int(source_image_id)] = {
"id": int(source_image_id),
"file_name": os.path.basename(cap_path) if cap_path else "",
"width": row.get("capture_width"),
"height": row.get("capture_height"),
"coco_url": row.get("best_detection_capture_url"),
"date_captured": row.get("capture_timestamp"),
}

occ_id = row.get("id")
if occ_id is None:
logger.warning(f"No occ_id found for row: {row}")
continue

ann: dict = {
"id": int(occ_id),
"image_id": int(source_image_id),
"category_id": int(determination_id),
"bbox": coco_bbox,
"area": area,
"iscrowd": 0, # TODO: Could we use this field to indiate crowd of insects?
"determination_score": row.get("determination_score"),
"verification_status": row.get("verification_status"),
"determination_matches_machine_prediction": row.get("determination_matches_machine_prediction"),
"best_machine_prediction_algorithm": row.get("best_machine_prediction_algorithm"),
"best_machine_prediction_score": row.get("best_machine_prediction_score"),
"best_machine_prediction_taxon_id": row.get("best_machine_prediction_taxon_id"),
"best_detection_width": row.get("best_detection_width"),
"best_detection_height": row.get("best_detection_height"),
}
prediction_taxon_id = row.get("best_machine_prediction_taxon_id")
if prediction_taxon_id is not None:
try:
category_taxon_ids.add(int(prediction_taxon_id))
except (TypeError, ValueError):
logger.warning(f"Invalid best_machine_prediction_taxon_id for row: {row}")
annotations.append(ann)

def _serialize_parents_json(parents_json):
if not isinstance(parents_json, list):
return []
serialized = []
for parent in parents_json:
if isinstance(parent, dict):
parent_id = parent.get("id")
parent_name = parent.get("name")
rank = parent.get("rank")
else:
# SchemaField(list[TaxonParent]) may return Pydantic objects rather than dicts.
parent_id = getattr(parent, "id", None)
parent_name = getattr(parent, "name", None)
rank = getattr(parent, "rank", None)

if parent_id is None and parent_name is None and rank is None:
continue

rank_value = rank.value if hasattr(rank, "value") else rank
serialized.append(
{
"id": parent_id,
"name": parent_name,
"rank": str(rank_value) if rank_value is not None else None,
}
)
return serialized

if category_taxon_ids:
taxa = Taxon.objects.filter(id__in=category_taxon_ids).values(
"id", "name", "rank", "parent_id", "parents_json"
)
for taxon in taxa:
taxon_id = int(taxon["id"])
categories_by_id[taxon_id] = {
"id": taxon_id,
"name": taxon.get("name") or categories_by_id.get(taxon_id, {}).get("name", ""),
"rank": taxon.get("rank"),
"parent_id": taxon.get("parent_id"),
"parents": _serialize_parents_json(taxon.get("parents_json")),
}

base = getattr(settings, "EXTERNAL_BASE_URL", "") or ""
info_url = ""
if base.strip():
info_url = f"{base.rstrip('/')}/projects/{project.pk}/summary" # TODO: is there a better way to do this?

payload = {
"info": {
"description": f"{project.name} ({project.pk}) Occurrences",
"url": info_url,
"date_created": datetime.datetime.now(datetime.timezone.utc).isoformat(),
},
"images": list(images_by_id.values()),
"annotations": annotations,
"categories": sorted(categories_by_id.values(), key=lambda c: c["id"]),
}
return payload


class CSVExporter(BaseExporter):
"""Handles CSV export of occurrences."""

Expand Down Expand Up @@ -247,3 +457,46 @@ def export(self):
self.update_job_progress(records_exported)
self.update_export_stats(file_temp_path=temp_file.name)
return temp_file.name # Return the file path


class CocoJSONExporter(BaseExporter):
"""Exports occurrences as a COCO-style detection dataset (same rows as CSV plus capture metadata)."""

file_format = "json"

serializer_class = OccurrenceCocoTabularSerializer

def get_queryset(self):
return (
Occurrence.objects.valid()
.filter(project=self.project)
.select_related(
"determination",
"deployment",
"event",
)
.with_timestamps()
.with_detections_count()
.with_identifications()
.with_best_detection()
.with_best_machine_prediction()
.with_verification_info()
)

def export(self):
"""Serialize occurrences in batches, build COCO JSON, write to a temp file."""
rows: list[dict] = []
records_exported = 0
for batch in get_data_in_batches(self.queryset, self.serializer_class):
rows.extend(batch)
records_exported += len(batch)
self.update_job_progress(records_exported)

coco_payload = build_coco_dict_from_occurrence_rows(rows, self.project)

temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".json", mode="w", encoding="utf-8")
with open(temp_file.name, "w", encoding="utf-8") as f:
json.dump(coco_payload, f, cls=DjangoJSONEncoder)

self.update_export_stats(file_temp_path=temp_file.name)
return temp_file.name
1 change: 1 addition & 0 deletions ami/exports/registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,4 @@ def get_supported_formats(cls):

ExportRegistry.register("occurrences_api_json")(format_types.JSONExporter)
ExportRegistry.register("occurrences_simple_csv")(format_types.CSVExporter)
ExportRegistry.register("occurrences_coco_json")(format_types.CocoJSONExporter)
66 changes: 65 additions & 1 deletion ami/exports/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from django.core.files.base import ContentFile
from django.core.files.storage import default_storage
from django.test import TestCase
from django.test import TestCase, override_settings
from rest_framework.test import APIClient

from ami.exports.models import DataExport
Expand Down Expand Up @@ -210,6 +210,70 @@ def test_csv_export_has_detection_fields(self):
# Clean up the exported file after the test
default_storage.delete(file_path)

@override_settings(EXTERNAL_BASE_URL="https://antenna.example.org")
def test_coco_json_export_structure(self):
"""COCO export is a single JSON object with images, annotations, categories; ids are consistent."""
# Create a DataExport instance
data_export = DataExport.objects.create(
user=self.user,
project=self.project,
format="occurrences_coco_json",
filters={"collection_id": self.collection.pk},
job=None,
)

# Run export and get the file URL
file_url = data_export.run_export()

# Ensure the file is generated
self.assertIsNotNone(file_url)
file_path = file_url.replace("/media/", "")
self.assertTrue(default_storage.exists(file_path))

# Read and validate the exported data
with default_storage.open(file_path, "r") as f:
payload = json.load(f)

# Ensure necessary COCO fields are present
self.assertIn("info", payload)
self.assertIn("images", payload)
self.assertIn("annotations", payload)
self.assertIn("categories", payload)

# Ensure the info field points to the project summary page
# TODO: Change this to not be harcoded?
self.assertEqual(
payload["info"]["url"],
f"https://antenna.example.org/projects/{self.project.pk}/summary",
)

# Ensure the images field contains the correct number of images
image_ids = {img["id"] for img in payload["images"]}
category_ids = {c["id"] for c in payload["categories"]}
assert len(image_ids) > 0, "No images found in the export"
assert len(category_ids) > 0, "No categories found in the export"
assert len(payload["annotations"]) > 0, "No annotations found in the export"

for ann in payload["annotations"]:
self.assertIn(ann["image_id"], image_ids)
self.assertIn(ann["category_id"], category_ids)
x, y, w, h = ann["bbox"]
self.assertGreater(w, 0)
self.assertGreater(h, 0)

# Number of annotations should equal the number of occurrences in
# the collection, but excluding occurrences without a determination
occurrences_with_determination = (
Occurrence.objects.valid()
.filter(detections__source_image__collections=self.collection, determination__isnull=False)
.distinct()
.count()
)
self.assertEqual(len(payload["annotations"]), occurrences_with_determination)

# Clean up the exported file after the test
default_storage.delete(file_path)


class DataExportPermissionTest(TestCase):
"""Test data export permissions (create, update, delete)."""
Expand Down
2 changes: 1 addition & 1 deletion ami/main/api/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -702,7 +702,7 @@ class TaxaListTaxonSerializer(TaxonNoParentNestedSerializer):

class CaptureTaxonSerializer(DefaultSerializer):
parent = TaxonNoParentNestedSerializer(read_only=True)
parents = TaxonParentSerializer(many=True, read_only=True)
parents = TaxonParentSerializer(many=True, read_only=True, source="parents_json")

def get_permissions(self, instance, instance_data):
instance_data["user_permissions"] = []
Expand Down
Loading
Loading