Skip to content
Open
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 28 additions & 20 deletions ami/main/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@

import ami.utils
from ami import tasks
from ami.jobs.models import Job
from ami.ml.models.project_pipeline_config import ProjectPipelineConfig
from ami.ml.post_processing.admin.actions import make_post_processing_action
from ami.ml.post_processing.admin.small_size_filter_form import SmallSizeFilterActionForm
from ami.ml.post_processing.small_size_filter import SmallSizeFilterTask
from ami.ml.tasks import remove_duplicate_classifications

from .models import (
Expand Down Expand Up @@ -413,6 +415,19 @@ def get_queryset(self, request: HttpRequest) -> QuerySet[Any]:
def detections_count(self, obj) -> int:
return obj.detections_count

# Per-occurrence post-processing trigger. Same factory as the capture-set
# action on SourceImageCollectionAdmin, scoped to one occurrence — the fast
# spot/dev path for iterating on a filter without running a whole collection.
# New per-occurrence tasks add their own action here the same way.
run_small_size_filter = make_post_processing_action(
SmallSizeFilterTask,
SmallSizeFilterActionForm,
scope_resolver=lambda occurrence: {"occurrence_id": occurrence.pk},
name_resolver=lambda task_cls, occurrence: (f"Post-processing: {task_cls.name} on Occurrence {occurrence.pk}"),
)

actions = [run_small_size_filter]

ordering = ("-created_at",)

# Add classifications as inline
Expand Down Expand Up @@ -651,25 +666,18 @@ def populate_collection_async(self, request: HttpRequest, queryset: QuerySet[Sou
f"Populating {len(queued_tasks)} capture set(s) background tasks: {queued_tasks}.",
)

@admin.action(description="Run Small Size Filter post-processing task (async)")
def run_small_size_filter(self, request: HttpRequest, queryset: QuerySet[SourceImageCollection]) -> None:
jobs = []
for collection in queryset:
job = Job.objects.create(
name=f"Post-processing: SmallSizeFilter on Capture Set {collection.pk}",
project=collection.project,
job_type_key="post_processing",
params={
"task": "small_size_filter",
"config": {
"source_image_collection_id": collection.pk,
},
},
)
job.enqueue()
jobs.append(job.pk)

self.message_user(request, f"Queued Small Size Filter for {queryset.count()} capture set(s). Jobs: {jobs}")
# Built from the shared post-processing action factory: renders an intermediate
# confirmation page with the task's knob form, validates each selection against
# SmallSizeFilterConfig, then enqueues one Job per capture set. New post-processing
# tasks declare their own trigger the same way (task class + form + scope_resolver).
run_small_size_filter = make_post_processing_action(
SmallSizeFilterTask,
SmallSizeFilterActionForm,
scope_resolver=lambda collection: {"source_image_collection_id": collection.pk},
name_resolver=lambda task_cls, collection: (
f"Post-processing: {task_cls.name} on Capture Set {collection.pk}"
),
)

actions = [
populate_collection,
Expand Down
Empty file.
287 changes: 287 additions & 0 deletions ami/ml/post_processing/admin/actions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,287 @@
"""Shared admin-action machinery for triggering post-processing tasks.

Every post-processing task surfaces the same admin flow:

1. The operator selects rows and picks the action.
2. An intermediate confirmation page renders the task's knob form.
3. On submit, each row's config is validated against the task's pydantic
``config_schema`` and a Job is enqueued.

``make_post_processing_action`` builds the action callable for that flow so
each task only declares what varies: its task class, its knob form, and how a
selected row maps to a Job (scope + project + name). Tasks whose row→Job
mapping doesn't fit the default one-Job-per-row shape (e.g. partitioning events
across projects) pass their own ``build_jobs`` callable.

Validation lives in one place: the task's ``config_schema``. The knob form only
declares fields (label, help text, widget); it does not re-encode the schema's
rules. Schema errors raised while building Jobs are mapped back onto the form so
the operator sees them inline on the confirmation page.
"""
from __future__ import annotations

import logging
from collections.abc import Callable
from typing import Any, Protocol

import pydantic
from django.contrib import admin, messages
from django.db import transaction
from django.db.models import Model
from django.db.models.query import QuerySet
from django.http import HttpRequest, HttpResponse
from django.template.response import TemplateResponse
from django.urls import reverse

from ami.jobs.models import Job
from ami.ml.post_processing.admin.forms import BasePostProcessingActionForm
from ami.ml.post_processing.base import BasePostProcessingTask

logger = logging.getLogger(__name__)

CONFIRMATION_TEMPLATE = "admin/post_processing/confirmation.html"


class _ModelAdminProto(Protocol):
"""The slice of ``ModelAdmin`` the generic action touches."""

model: type[Model]
admin_site: Any

def message_user(self, request: HttpRequest, message: str, level: Any = ..., **kwargs: Any) -> None:
...


class ConfigValidationErrors(Exception):
"""Raised by ``build_jobs`` when one or more rows produce invalid config.

Carries ``(field_name_or_None, message)`` pairs so the caller can attach
them to the knob form and re-render the confirmation page instead of
creating any Jobs.
"""

def __init__(self, errors: list[tuple[str | None, str]]):
self.errors = errors
super().__init__(f"{len(errors)} invalid config(s)")


def _schema_errors_to_form_fields(
exc: pydantic.ValidationError,
form_field_names: set[str],
) -> list[tuple[str | None, str]]:
"""Map a pydantic ``ValidationError`` onto form field names where possible.

Errors on a field the form renders are attached to that field; everything
else (e.g. an injected scope field) becomes a non-field error.
"""
mapped: list[tuple[str | None, str]] = []
for err in exc.errors():
loc = err.get("loc") or ()
field = str(loc[0]) if loc else None
target = field if field in form_field_names else None
mapped.append((target, err.get("msg", "Invalid value")))
return mapped


def default_build_jobs(
*,
model_admin: _ModelAdminProto,
request: HttpRequest,
config: dict[str, Any],
queryset: QuerySet,
task_cls: type[BasePostProcessingTask],
form_field_names: set[str],
scope_resolver: Callable[[Any], dict[str, Any]],
project_resolver: Callable[[Any], Any],
name_resolver: Callable[[type[BasePostProcessingTask], Any], str],
) -> list[int]:
"""Validate every selected row, then enqueue one Job per row (all-or-nothing).

Each row's full config is ``{**config, **scope_resolver(row)}`` validated
against ``task_cls.config_schema``. If any row fails, nothing is created and
``ConfigValidationErrors`` is raised so the form can re-render with the
errors inline.
"""
validated: list[tuple[Any, pydantic.BaseModel]] = []
errors: list[tuple[str | None, str]] = []

for obj in queryset:
full_config = {**config, **scope_resolver(obj)}
try:
validated.append((obj, task_cls.config_schema(**full_config)))
except pydantic.ValidationError as exc:
errors.extend(_schema_errors_to_form_fields(exc, form_field_names))

if errors:
raise ConfigValidationErrors(errors)

# Create all Jobs in one transaction so the operation stays all-or-nothing even
# if a create fails mid-loop. (Admin requests are already atomic via
# ATOMIC_REQUESTS, but this helper may also be called outside a request — e.g. a
# management command — where there's no ambient transaction.) Job.enqueue() uses
# transaction.on_commit, so enqueues fire only once the block commits.
job_pks: list[int] = []
with transaction.atomic():
for obj, model in validated:
job = Job.objects.create(
name=name_resolver(task_cls, obj),
project=project_resolver(obj),
job_type_key="post_processing",
params={"task": task_cls.key, "config": model.dict()},
)
job.enqueue()
job_pks.append(job.pk)
return job_pks


def render_confirmation(
model_admin: _ModelAdminProto,
request: HttpRequest,
queryset: QuerySet,
*,
task_cls: type[BasePostProcessingTask],
form: BasePostProcessingActionForm,
action_name: str,
title: str,
submit_label: str,
) -> TemplateResponse:
"""Render the shared intermediate confirmation page for ``task_cls``."""
opts = model_admin.model._meta
# Resolve the selection once; count from the materialized list (one query, not two).
selected_pks = [str(pk) for pk in queryset.values_list("pk", flat=True)]
return TemplateResponse(
request,
CONFIRMATION_TEMPLATE,
{
**model_admin.admin_site.each_context(request),
"title": title,
"task_label": task_cls.name,
"form": form,
"selected_count": len(selected_pks),
"selected_pks": selected_pks,
"action_name": action_name,
"submit_label": submit_label,
"changelist_url": reverse(f"admin:{opts.app_label}_{opts.model_name}_changelist"),
"model_meta": opts,
"opts": opts,
"action_checkbox_name": admin.helpers.ACTION_CHECKBOX_NAME,
},
)


def _default_name_resolver(task_cls: type[BasePostProcessingTask], obj: Any) -> str:
return f"Post-processing: {task_cls.name} on {obj._meta.verbose_name} {obj.pk}"


def make_post_processing_action(
task_cls: type[BasePostProcessingTask],
form_class: type[BasePostProcessingActionForm],
*,
scope_resolver: Callable[[Any], dict[str, Any]] | None = None,
project_resolver: Callable[[Any], Any] = lambda obj: obj.project,
name_resolver: Callable[[type[BasePostProcessingTask], Any], str] = _default_name_resolver,
build_jobs: Callable[..., list[int]] | None = None,
description: str | None = None,
title: str | None = None,
submit_label: str | None = None,
) -> Callable[[_ModelAdminProto, HttpRequest, QuerySet], HttpResponse | None]:
"""Build a Django admin action that triggers ``task_cls`` via the shared flow.

Args:
task_cls: the post-processing task. ``key``/``name``/``config_schema``
drive the action name, labels, and config validation.
form_class: the knob form rendered on the confirmation page.
scope_resolver: maps a selected row to the config fields identifying its
scope, e.g. ``lambda c: {"source_image_collection_id": c.pk}``.
Required unless a custom ``build_jobs`` is supplied.
project_resolver: maps a row to the Job's project (default ``obj.project``).
name_resolver: maps ``(task_cls, row)`` to the Job name.
build_jobs: escape hatch for tasks whose row→Job mapping isn't one
Job per row (e.g. partitioning across projects). Receives the same
keyword arguments as ``default_build_jobs`` and returns created Job
pks; raise ``ConfigValidationErrors`` to re-render the form.
description: admin action dropdown label.
title / submit_label: confirmation-page heading and button text.

The returned callable's ``__name__`` is ``run_<task key>`` so Django
registers it under that name and the confirmation page's hidden ``action``
field round-trips correctly.
"""
if build_jobs is None and scope_resolver is None:
raise ValueError("make_post_processing_action requires scope_resolver unless build_jobs is supplied")

action_name = f"run_{task_cls.key}"
resolved_title = title or f"Run {task_cls.name}"
resolved_submit = submit_label or resolved_title
resolved_description = description or f"Run {task_cls.name} post-processing task (async)"

def action(
model_admin: _ModelAdminProto,
request: HttpRequest,
queryset: QuerySet,
) -> HttpResponse | None:
def _render(form: BasePostProcessingActionForm) -> TemplateResponse:
return render_confirmation(
model_admin,
request,
queryset,
task_cls=task_cls,
form=form,
action_name=action_name,
title=resolved_title,
submit_label=resolved_submit,
)

# "Select all across pages" hands us the entire filtered table as the
# queryset and would serialize every pk into hidden inputs on the
# confirmation page (a huge POST body, possibly over request limits).
# This admin trigger is for explicit, bounded selections; refuse the
# across-pages case rather than render an unbounded form.
if request.POST.get("select_across") == "1":
model_admin.message_user(
request,
f'"Select all across pages" is not supported for {task_cls.name}. '
"Select the specific rows you want to process.",
level=messages.WARNING,
)
return None

if not request.POST.get("confirm"):
return _render(form_class())

form = form_class(request.POST)
if not form.is_valid():
return _render(form)

runner = build_jobs or default_build_jobs
kwargs: dict[str, Any] = dict(
model_admin=model_admin,
request=request,
config=form.to_config(),
queryset=queryset,
task_cls=task_cls,
form_field_names=set(form.fields),
project_resolver=project_resolver,
name_resolver=name_resolver,
)
# Only forward scope_resolver when set. A custom build_jobs supplied without
# a scope_resolver should not receive a None it might try to call.
if scope_resolver is not None:
kwargs["scope_resolver"] = scope_resolver
try:
job_pks = runner(**kwargs)
except ConfigValidationErrors as exc:
for field, message in exc.errors:
form.add_error(field, message)
return _render(form)

model_admin.message_user(
request,
f"Queued {task_cls.name} for {len(job_pks)} {model_admin.model._meta.verbose_name}(s). Jobs: {job_pks}",
level=messages.SUCCESS,
)
return None

action.__name__ = action_name
action.__qualname__ = action_name
return admin.action(description=resolved_description)(action)
25 changes: 25 additions & 0 deletions ami/ml/post_processing/admin/forms.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
"""Form base for admin actions that trigger post-processing tasks.

Each post-processing task surfaces its tunable knobs as a Django form. The
form's ``cleaned_data`` becomes the ``config`` payload on the resulting Job
(after validation against the task's pydantic ``config_schema``).

Algorithm scope (which queryset/events/collection the action runs against)
lives outside the form because it varies per admin entry-point.
"""
from __future__ import annotations

from django import forms


class BasePostProcessingActionForm(forms.Form):
"""Marker base for post-processing admin action forms.

Subclasses declare task-specific fields. Override ``to_config()`` if the
1:1 ``cleaned_data → config`` mapping needs adjustment (e.g. drop empty
optional fields, derive computed values, rename keys).
"""

def to_config(self) -> dict:
"""Return ``cleaned_data`` shaped for ``Job.params['config']``."""
return dict(self.cleaned_data)
Loading
Loading