From a31b0dee13f12bb5cfaca3d8d20b1d488515047b Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 29 Mar 2023 01:56:26 -0700 Subject: [PATCH 01/53] Quick hacks to serialize data with real image urls --- trapdata/cli/export.py | 2 +- trapdata/db/models/detections.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/trapdata/cli/export.py b/trapdata/cli/export.py index 281c2e12..675c7e72 100644 --- a/trapdata/cli/export.py +++ b/trapdata/cli/export.py @@ -202,7 +202,7 @@ def captures( events = get_monitoring_session_by_date( db_path=settings.database_url, base_directory=settings.image_base_path, - event_dates=[str(date.date())], + event_dates=[date.date()], ) if not len(events): raise Exception(f"No Monitoring Event with date: {date.date()}") diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 63dd3548..a3192296 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -15,6 +15,21 @@ from trapdata.db import models from trapdata.db.models.images import completely_classified +from pydantic import BaseModel + + +class DetectionListItem(BaseModel): + id: int + cropped_image_path: Optional[str] + bbox: Optional[list[int]] + area_pixels: Optional[int] + last_detected: Optional[datetime.datetime] + label: Optional[str] + score: Optional[int] + model_name: Optional[str] + in_queue: bool + notes: Optional[str] + class DetectionListItem(BaseModel): id: int From 4f7ddfdb7e3039af6ff8f92e5de7427ba78fc7c8 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 29 Mar 2023 17:30:44 -0700 Subject: [PATCH 02/53] Use dataclass for Deployment --- trapdata/cli/show.py | 2 +- trapdata/db/models/deployments.py | 12 ++++++++++++ 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/trapdata/cli/show.py b/trapdata/cli/show.py index e87213af..8bffdbf1 100644 --- a/trapdata/cli/show.py +++ b/trapdata/cli/show.py @@ -70,7 +70,7 @@ def deployments(): deployments = list_deployments(session) table = Table( "Image Base Path", - "Events", + "Sessions", "Images", "Detections", ) diff --git a/trapdata/db/models/deployments.py b/trapdata/db/models/deployments.py index 5b75b074..1cec4864 100644 --- a/trapdata/db/models/deployments.py +++ b/trapdata/db/models/deployments.py @@ -14,6 +14,18 @@ from trapdata.common.types import FilePath from trapdata.db import models +from pydantic import BaseModel + + +class DeploymentListItem(BaseModel): + # id: int + name: str + num_events: int + num_source_images: int + num_detections: int + # num_occurrences: int + # num_species: int + class DeploymentListItem(BaseModel): # id: int From 534347150e79ed27d792b81d31c9088691e36eae Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 5 Apr 2023 20:51:37 -0700 Subject: [PATCH 03/53] Incomplete start of an API implementation --- poetry.lock | 76 +++++++++++++++++++- pyproject.toml | 1 + trapdata/api/__init__.py | 11 +++ trapdata/api/deployments.py | 27 ++++++++ trapdata/api/deps/db.py | 15 ++++ trapdata/api/deps/request_params.py | 46 +++++++++++++ trapdata/api/items.py | 103 ++++++++++++++++++++++++++++ trapdata/api/occurrences.bak.py | 47 +++++++++++++ trapdata/api/occurrences.py | 71 +++++++++++++++++++ trapdata/api/request_params.py | 9 +++ trapdata/api/users.py | 30 ++++++++ trapdata/api/utils.py | 17 +++++ trapdata/db/__init__.py | 4 +- trapdata/db/base.py | 16 +++++ trapdata/db/models/deployments.py | 12 +--- trapdata/db/models/images.py | 1 - 16 files changed, 473 insertions(+), 13 deletions(-) create mode 100644 trapdata/api/__init__.py create mode 100644 trapdata/api/deployments.py create mode 100644 trapdata/api/deps/db.py create mode 100644 trapdata/api/deps/request_params.py create mode 100644 trapdata/api/items.py create mode 100644 trapdata/api/occurrences.bak.py create mode 100644 trapdata/api/occurrences.py create mode 100644 trapdata/api/request_params.py create mode 100644 trapdata/api/users.py create mode 100644 trapdata/api/utils.py diff --git a/poetry.lock b/poetry.lock index d0939916..da8cbc61 100644 --- a/poetry.lock +++ b/poetry.lock @@ -20,6 +20,27 @@ typing-extensions = ">=4" [package.extras] tz = ["python-dateutil"] +[[package]] +name = "anyio" +version = "3.6.2" +description = "High level compatibility layer for multiple asynchronous event loop implementations" +category = "main" +optional = false +python-versions = ">=3.6.2" +files = [ + {file = "anyio-3.6.2-py3-none-any.whl", hash = "sha256:fbbe32bd270d2a2ef3ed1c5d45041250284e31fc0a4df4a5a6071842051a51e3"}, + {file = "anyio-3.6.2.tar.gz", hash = "sha256:25ea0d673ae30af41a0c442f81cf3b38c7e79fdc7b60335a4c14e05eb0947421"}, +] + +[package.dependencies] +idna = ">=2.8" +sniffio = ">=1.1" + +[package.extras] +doc = ["packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["contextlib2", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (<0.15)", "uvloop (>=0.15)"] +trio = ["trio (>=0.16,<0.22)"] + [[package]] name = "appnope" version = "0.1.3" @@ -406,6 +427,28 @@ files = [ [package.extras] tests = ["asttokens", "littleutils", "pytest", "rich"] +[[package]] +name = "fastapi" +version = "0.95.0" +description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "fastapi-0.95.0-py3-none-any.whl", hash = "sha256:daf73bbe844180200be7966f68e8ec9fd8be57079dff1bacb366db32729e6eb5"}, + {file = "fastapi-0.95.0.tar.gz", hash = "sha256:99d4fdb10e9dd9a24027ac1d0bd4b56702652056ca17a6c8721eec4ad2f14e18"}, +] + +[package.dependencies] +pydantic = ">=1.6.2,<1.7 || >1.7,<1.7.1 || >1.7.1,<1.7.2 || >1.7.2,<1.7.3 || >1.7.3,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0" +starlette = ">=0.26.1,<0.27.0" + +[package.extras] +all = ["email-validator (>=1.1.1)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +dev = ["pre-commit (>=2.17.0,<3.0.0)", "ruff (==0.0.138)", "uvicorn[standard] (>=0.12.0,<0.21.0)"] +doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.3.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pyyaml (>=5.3.1,<7.0.0)", "typer-cli (>=0.0.13,<0.0.14)", "typer[all] (>=0.6.1,<0.8.0)"] +test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==23.1.0)", "coverage[toml] (>=6.5.0,<8.0)", "databases[sqlite] (>=0.3.2,<0.7.0)", "email-validator (>=1.1.1,<2.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.23.0,<0.24.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.982)", "orjson (>=3.2.1,<4.0.0)", "passlib[bcrypt] (>=1.7.2,<2.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=7.1.3,<8.0.0)", "python-jose[cryptography] (>=3.3.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.7)", "pyyaml (>=5.3.1,<7.0.0)", "ruff (==0.0.138)", "sqlalchemy (>=1.3.18,<1.4.43)", "types-orjson (==3.6.2)", "types-ujson (==5.7.0.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0)"] + [[package]] name = "filelock" version = "3.10.7" @@ -1980,6 +2023,18 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "sniffio" +version = "1.3.0" +description = "Sniff out which async library your code is running under" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, + {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, +] + [[package]] name = "sqlalchemy" version = "2.0.8" @@ -2107,6 +2162,25 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "starlette" +version = "0.26.1" +description = "The little ASGI library that shines." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "starlette-0.26.1-py3-none-any.whl", hash = "sha256:e87fce5d7cbdde34b76f0ac69013fd9d190d581d80681493016666e6f96c6d5e"}, + {file = "starlette-0.26.1.tar.gz", hash = "sha256:41da799057ea8620e4667a3e69a5b1923ebd32b1819c8fa75634bbe8d8bea9bd"}, +] + +[package.dependencies] +anyio = ">=3.4.0,<5" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} + +[package.extras] +full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"] + [[package]] name = "structlog" version = "22.3.0" @@ -2420,4 +2494,4 @@ test = ["pytest (>=6.0.0)"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "be789d9689e5a8a9feeb55857c26f300acf7b99a95ab555fffe45239a14a0086" +content-hash = "235424e527001bb5aabde6b4f5381d9dcf97bfd0b1a9dc7207f4d177e4143e5b" diff --git a/pyproject.toml b/pyproject.toml index 1ddde4ce..e6ad75c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ ipython = "^8.11.0" pytest-cov = "^4.0.0" pytest-asyncio = "^0.21.0" pytest = "*" +fastapi = "^0.95.0" [tool.pytest.ini_options] diff --git a/trapdata/api/__init__.py b/trapdata/api/__init__.py new file mode 100644 index 00000000..eec3f78c --- /dev/null +++ b/trapdata/api/__init__.py @@ -0,0 +1,11 @@ +from fastapi import APIRouter + +from trapdata.api import deployments, items, occurrences, users, utils + +api_router = APIRouter() + +api_router.include_router(utils.router, tags=["utils"]) +api_router.include_router(users.router, tags=["users"]) +api_router.include_router(items.router, tags=["items"]) +api_router.include_router(occurrences.router, tags=["occurrences"]) +api_router.include_router(deployments.router, tags=["deployments"]) diff --git a/trapdata/api/deployments.py b/trapdata/api/deployments.py new file mode 100644 index 00000000..4030e460 --- /dev/null +++ b/trapdata/api/deployments.py @@ -0,0 +1,27 @@ +from typing import Any, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import func, orm, select +from starlette.responses import Response + +from trapdata.api.deps.db import get_async_session +from trapdata.api.deps.request_params import parse_react_admin_params +from trapdata.api.request_params import RequestParams +from trapdata.db import Base, get_session +from trapdata.db.models.deployments import DeploymentListItem, list_deployments + +router = APIRouter(prefix="/deployments") + + +@router.get("", response_model=List[DeploymentListItem]) +async def get_deployments( + response: Response, + session: orm.Session = Depends(get_session), + request_params: RequestParams = Depends(parse_react_admin_params(Base)), +) -> Any: + total = await session.scalar(select(func.count(Deployment.id))) + deployments = list_deployments(session) + response.headers[ + "Content-Range" + ] = f"{request_params.skip}-{request_params.skip + len(deployments)}/{total}" + return deployments diff --git a/trapdata/api/deps/db.py b/trapdata/api/deps/db.py new file mode 100644 index 00000000..52a49423 --- /dev/null +++ b/trapdata/api/deps/db.py @@ -0,0 +1,15 @@ +from typing import Generator + +from sqlalchemy import orm + +from trapdata.cli import read_settings +from trapdata.db.base import get_session_class + +settings = read_settings() + + +def get_session() -> Generator[orm.Session, None]: + Session = get_session_class(settings.database_url) + with Session() as session: + yield session + session.close() diff --git a/trapdata/api/deps/request_params.py b/trapdata/api/deps/request_params.py new file mode 100644 index 00000000..113fdbbf --- /dev/null +++ b/trapdata/api/deps/request_params.py @@ -0,0 +1,46 @@ +import json +from typing import Callable, Optional, Type + +from fastapi import HTTPException, Query +from sqlalchemy import UnaryExpression, asc, desc + +from trapdata.api.request_params import RequestParams +from trapdata.db import Base + + +def parse_react_admin_params(model: Type[Base]) -> Callable: + """Parses sort and range parameters coming from a react-admin request""" + + def inner( + sort_: Optional[str] = Query( + None, + alias="sort", + description='Format: `["field_name", "direction"]`', + example='["id", "ASC"]', + ), + range_: Optional[str] = Query( + None, + alias="range", + description="Format: `[start, end]`", + example="[0, 10]", + ), + ) -> RequestParams: + skip, limit = 0, 10 + if range_: + start, end = json.loads(range_) + skip, limit = start, (end - start + 1) + + order_by: UnaryExpression = desc(model.id) + if sort_: + sort_column, sort_order = json.loads(sort_) + if sort_order.lower() == "asc": + direction = asc + elif sort_order.lower() == "desc": + direction = desc + else: + raise HTTPException(400, f"Invalid sort direction {sort_order}") + order_by = direction(model.__table__.c[sort_column]) + + return RequestParams(skip=skip, limit=limit, order_by=order_by) + + return inner diff --git a/trapdata/api/items.py b/trapdata/api/items.py new file mode 100644 index 00000000..e679e857 --- /dev/null +++ b/trapdata/api/items.py @@ -0,0 +1,103 @@ +from typing import Any, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio.session import AsyncSession +from starlette.responses import Response + +from app.deps.db import get_async_session +from app.deps.request_params import parse_react_admin_params +from app.deps.users import current_user +from app.models.item import Item +from app.models.user import User +from app.schemas.item import Item as ItemSchema +from app.schemas.item import ItemCreate, ItemUpdate +from app.schemas.request_params import RequestParams + +router = APIRouter(prefix="/items") + + +@router.get("", response_model=List[ItemSchema]) +async def get_items( + response: Response, + session: AsyncSession = Depends(get_async_session), + request_params: RequestParams = Depends(parse_react_admin_params(Item)), + user: User = Depends(current_user), +) -> Any: + total = await session.scalar( + select(func.count(Item.id).filter(Item.user_id == user.id)) + ) + items = ( + ( + await session.execute( + select(Item) + .offset(request_params.skip) + .limit(request_params.limit) + .order_by(request_params.order_by) + .filter(Item.user_id == user.id) + ) + ) + .scalars() + .all() + ) + response.headers[ + "Content-Range" + ] = f"{request_params.skip}-{request_params.skip + len(items)}/{total}" + return items + + +@router.post("", response_model=ItemSchema, status_code=201) +async def create_item( + item_in: ItemCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_user), +) -> Any: + item = Item(**item_in.dict()) + item.user_id = user.id + session.add(item) + await session.commit() + return item + + +@router.put("/{item_id}", response_model=ItemSchema) +async def update_item( + item_id: int, + item_in: ItemUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_user), +) -> Any: + item: Optional[Item] = await session.get(Item, item_id) + if not item or item.user_id != user.id: + raise HTTPException(404) + update_data = item_in.dict(exclude_unset=True) + for field, value in update_data.items(): + setattr(item, field, value) + session.add(item) + await session.commit() + return item + + +@router.get("/{item_id}", response_model=ItemSchema) +async def get_item( + item_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_user), +) -> Any: + item: Optional[Item] = await session.get(Item, item_id) + if not item or item.user_id != user.id: + raise HTTPException(404) + return item + + +@router.delete("/{item_id}") +async def delete_item( + item_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_user), +) -> Any: + item: Optional[Item] = await session.get(Item, item_id) + if not item or item.user_id != user.id: + raise HTTPException(404) + await session.delete(item) + await session.commit() + return {"success": True} diff --git a/trapdata/api/occurrences.bak.py b/trapdata/api/occurrences.bak.py new file mode 100644 index 00000000..d91a252d --- /dev/null +++ b/trapdata/api/occurrences.bak.py @@ -0,0 +1,47 @@ +import random +from typing import Any, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio.session import AsyncSession +from starlette.responses import Response + +from app.deps.db import get_async_session +from app.deps.request_params import parse_react_admin_params +from app.deps.users import current_user + +# from app.models.occurrence import Occurrence +from app.models.user import User +from app.schemas.occurrence import test_data +from app.schemas.occurrence import Occurrence +from app.schemas.request_params import RequestParams + +router = APIRouter(prefix="/occurrences") + + +@router.get("", response_model=List[Occurrence]) +async def get_occurrences( + response: Response, + session: AsyncSession = Depends(get_async_session), + request_params: RequestParams = Depends(parse_react_admin_params(Occurrence)), +) -> Any: + occurrences = test_data[ + request_params.skip : request_params.skip + request_params.limit + ] + total = len(occurrences) + response.headers[ + "Content-Range" + ] = f"{request_params.skip}-{request_params.skip + len(occurrences)}/{total}" + return occurrences + + +@router.get("/{occurrence_id}", response_model=Occurrence) +async def get_occurrence( + occurrence_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_user), +) -> Any: + occurrence: Optional[Occurrence] = await session.get(Occurrence, occurrence_id) + if not occurrence or occurrence.user_id != user.id: + raise HTTPException(404) + return occurrence diff --git a/trapdata/api/occurrences.py b/trapdata/api/occurrences.py new file mode 100644 index 00000000..856013a4 --- /dev/null +++ b/trapdata/api/occurrences.py @@ -0,0 +1,71 @@ +from typing import Any, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio.session import AsyncSession +from starlette.responses import Response + +from app.deps.db import get_async_session +from app.deps.request_params import parse_react_admin_params +from app.deps.users import current_user +from app.models.occurrence import Occurrence +from app.models.user import User +from app.schemas.occurrence import Occurrence as OccurrenceSchema +from app.schemas.occurrence import OccurrenceCreate, OccurrenceUpdate +from app.schemas.request_params import RequestParams + +router = APIRouter(prefix="/occurrences") + + +@router.get("", response_model=List[OccurrenceSchema]) +async def get_occurrences( + response: Response, + session: AsyncSession = Depends(get_async_session), + request_params: RequestParams = Depends(parse_react_admin_params(Occurrence)), +) -> Any: + total = await session.scalar( + select( + func.count(Occurrence.id).filter( + Occurrence.deployment_id == request_params.deployment_id + ) + ) + ) + items = ( + ( + await session.execute( + select(Occurrence) + .offset(request_params.skip) + .limit(request_params.limit) + .order_by(request_params.order_by) + .filter(Occurrence.deployment_id == request_params.deployment_id) + ) + ) + .scalars() + .all() + ) + response.headers[ + "Content-Range" + ] = f"{request_params.skip}-{request_params.skip + len(items)}/{total}" + return items + + +@router.get("/{item_id}", response_model=OccurrenceSchema) +async def get_occurrence( + item_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_user), +) -> Any: + item: Optional[Occurrence] = await session.get(Occurrence, item_id) + return item + + +@router.delete("/{item_id}") +async def delete_occurrence( + item_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_user), +) -> Any: + item: Optional[Occurrence] = await session.get(Occurrence, item_id) + await session.delete(item) + await session.commit() + return {"success": True} diff --git a/trapdata/api/request_params.py b/trapdata/api/request_params.py new file mode 100644 index 00000000..43d5af5b --- /dev/null +++ b/trapdata/api/request_params.py @@ -0,0 +1,9 @@ +from typing import Any + +from pydantic.main import BaseModel + + +class RequestParams(BaseModel): + skip: int + limit: int + order_by: Any diff --git a/trapdata/api/users.py b/trapdata/api/users.py new file mode 100644 index 00000000..06ca854e --- /dev/null +++ b/trapdata/api/users.py @@ -0,0 +1,30 @@ +from typing import Any, List + +from fastapi.params import Depends +from fastapi.routing import APIRouter +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio.session import AsyncSession +from starlette.responses import Response + +from app.deps.db import get_async_session +from app.deps.users import current_superuser +from app.models.user import User +from app.schemas.user import UserRead + +router = APIRouter() + + +@router.get("/users", response_model=List[UserRead]) +async def get_users( + response: Response, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_superuser), + skip: int = 0, + limit: int = 100, +) -> Any: + total = await session.scalar(select(func.count(User.id))) + users = ( + (await session.execute(select(User).offset(skip).limit(limit))).scalars().all() + ) + response.headers["Content-Range"] = f"{skip}-{skip + len(users)}/{total}" + return users diff --git a/trapdata/api/utils.py b/trapdata/api/utils.py new file mode 100644 index 00000000..e9957285 --- /dev/null +++ b/trapdata/api/utils.py @@ -0,0 +1,17 @@ +from typing import Any + +from fastapi import APIRouter + +from app.schemas.msg import Msg + +router = APIRouter() + + +@router.get( + "/hello-world", + response_model=Msg, + status_code=200, + include_in_schema=False, +) +def test_hello_world() -> Any: + return {"msg": "Hello world!"} diff --git a/trapdata/db/__init__.py b/trapdata/db/__init__.py index a5493f58..5c390e50 100644 --- a/trapdata/db/__init__.py +++ b/trapdata/db/__init__.py @@ -1,3 +1,5 @@ +from typing import Any + import sqlalchemy as sa from sqlalchemy import orm @@ -16,4 +18,4 @@ class Base(orm.DeclarativeBase): - pass + id: Any diff --git a/trapdata/db/base.py b/trapdata/db/base.py index 84d02223..2da8f780 100644 --- a/trapdata/db/base.py +++ b/trapdata/db/base.py @@ -231,3 +231,19 @@ def get_or_create(session, model, defaults=None, **kwargs): return instance, False else: return instance, True + + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + + +def get_async_session_class(db_path: str) -> async_sessionmaker[AsyncSession]: + async_engine = create_async_engine(db_path, pool_pre_ping=True) + + async_session_maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autocommit=False, + autoflush=False, + ) + return async_session_maker diff --git a/trapdata/db/models/deployments.py b/trapdata/db/models/deployments.py index 1cec4864..b0f41fe1 100644 --- a/trapdata/db/models/deployments.py +++ b/trapdata/db/models/deployments.py @@ -14,8 +14,6 @@ from trapdata.common.types import FilePath from trapdata.db import models -from pydantic import BaseModel - class DeploymentListItem(BaseModel): # id: int @@ -27,14 +25,8 @@ class DeploymentListItem(BaseModel): # num_species: int -class DeploymentListItem(BaseModel): - # id: int - name: str - num_events: int - num_source_images: int - num_detections: int - # num_occurrences: int - # num_species: int +class DeploymentDetail(DeploymentListItem): + pass def deployment_name(image_base_path: FilePath) -> str: diff --git a/trapdata/db/models/images.py b/trapdata/db/models/images.py index c2707697..166fa798 100644 --- a/trapdata/db/models/images.py +++ b/trapdata/db/models/images.py @@ -27,7 +27,6 @@ class CaptureListItem(BaseModel): class CaptureDetail(CaptureListItem): - id: int event: object notes: Optional[str] detections: list From 65b125a76d363a50853bc538413ee2be3c1a3a6a Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 10 Apr 2023 01:21:23 -0700 Subject: [PATCH 04/53] Working API example --- poetry.lock | 35 ++++++++- pyproject.toml | 1 + trapdata/api/__init__.py | 11 --- trapdata/api/config.py | 39 ++++++++++ trapdata/api/deployments.py | 27 ------- trapdata/api/deps/db.py | 4 +- trapdata/api/factory.py | 82 +++++++++++++++++++++ trapdata/api/main.py | 20 +++++ trapdata/api/views/__init__.py | 8 ++ trapdata/api/views/deployments.py | 39 ++++++++++ trapdata/api/{ => views}/items.py | 0 trapdata/api/{ => views}/occurrences.bak.py | 0 trapdata/api/{ => views}/occurrences.py | 0 trapdata/api/{utils.py => views/stats.py} | 11 ++- trapdata/api/{ => views}/users.py | 0 trapdata/cli/base.py | 10 +++ trapdata/db/models/deployments.py | 3 +- trapdata/settings.py | 7 +- trapdata/webui/public/index.html | 1 + 19 files changed, 250 insertions(+), 48 deletions(-) delete mode 100644 trapdata/api/__init__.py create mode 100644 trapdata/api/config.py delete mode 100644 trapdata/api/deployments.py create mode 100644 trapdata/api/factory.py create mode 100644 trapdata/api/main.py create mode 100644 trapdata/api/views/__init__.py create mode 100644 trapdata/api/views/deployments.py rename trapdata/api/{ => views}/items.py (100%) rename trapdata/api/{ => views}/occurrences.bak.py (100%) rename trapdata/api/{ => views}/occurrences.py (100%) rename trapdata/api/{utils.py => views/stats.py} (65%) rename trapdata/api/{ => views}/users.py (100%) create mode 100644 trapdata/webui/public/index.html diff --git a/poetry.lock b/poetry.lock index da8cbc61..5894bb32 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. [[package]] name = "alembic" @@ -556,6 +556,18 @@ files = [ docs = ["Sphinx", "docutils (<0.18)"] test = ["objgraph", "psutil"] +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + [[package]] name = "huggingface-hub" version = "0.13.3" @@ -2464,6 +2476,25 @@ brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +[[package]] +name = "uvicorn" +version = "0.21.1" +description = "The lightning-fast ASGI server." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "uvicorn-0.21.1-py3-none-any.whl", hash = "sha256:e47cac98a6da10cd41e6fd036d472c6f58ede6c5dbee3dbee3ef7a100ed97742"}, + {file = "uvicorn-0.21.1.tar.gz", hash = "sha256:0fac9cb342ba099e0d582966005f3fdba5b0290579fed4a6266dc702ca7bb032"}, +] + +[package.dependencies] +click = ">=7.0" +h11 = ">=0.8" + +[package.extras] +standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] + [[package]] name = "wcwidth" version = "0.2.6" @@ -2494,4 +2525,4 @@ test = ["pytest (>=6.0.0)"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "235424e527001bb5aabde6b4f5381d9dcf97bfd0b1a9dc7207f4d177e4143e5b" +content-hash = "453accf6c7a42a5ccebff574cedeef2a4a4797d04a8b155375023c901b40fe37" diff --git a/pyproject.toml b/pyproject.toml index e6ad75c2..1409d445 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ pytest-cov = "^4.0.0" pytest-asyncio = "^0.21.0" pytest = "*" fastapi = "^0.95.0" +uvicorn = "^0.21.1" [tool.pytest.ini_options] diff --git a/trapdata/api/__init__.py b/trapdata/api/__init__.py deleted file mode 100644 index eec3f78c..00000000 --- a/trapdata/api/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from fastapi import APIRouter - -from trapdata.api import deployments, items, occurrences, users, utils - -api_router = APIRouter() - -api_router.include_router(utils.router, tags=["utils"]) -api_router.include_router(users.router, tags=["users"]) -api_router.include_router(items.router, tags=["items"]) -api_router.include_router(occurrences.router, tags=["occurrences"]) -api_router.include_router(deployments.router, tags=["deployments"]) diff --git a/trapdata/api/config.py b/trapdata/api/config.py new file mode 100644 index 00000000..d4810984 --- /dev/null +++ b/trapdata/api/config.py @@ -0,0 +1,39 @@ +import pathlib +from typing import Any, Dict, List, Optional + +from pydantic import BaseSettings, HttpUrl, PostgresDsn, validator +from pydantic.networks import AnyHttpUrl + +from trapdata.cli import read_settings +from trapdata.settings import Settings as BaseSettings + + +class Settings(BaseSettings): + PROJECT_NAME: str = "AMI Data Manager" + + SENTRY_DSN: Optional[HttpUrl] = None + + API_PATH: str = "/api/v1" + + ACCESS_TOKEN_EXPIRE_MINUTES: int = 7 * 24 * 60 # 7 days + + BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = [] + + # The following variables need to be defined in environment + + TEST_DATABASE_URL: Optional[PostgresDsn] + + SECRET_KEY: str + # END: required environment variables + + # STATIC_ROOT: str = "static" + + # @validator("STATIC_ROOT") + # def validate_static_root(cls, v): + # path = cls.user_data_path / v + # path.mkdir(parents=True, exist_ok=True) + # return path + + +# settings = read_settings(SettingsClass=Settings, SECRET_KEY="secret") +settings = Settings(SECRET_KEY="secret") diff --git a/trapdata/api/deployments.py b/trapdata/api/deployments.py deleted file mode 100644 index 4030e460..00000000 --- a/trapdata/api/deployments.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Any, List, Optional - -from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import func, orm, select -from starlette.responses import Response - -from trapdata.api.deps.db import get_async_session -from trapdata.api.deps.request_params import parse_react_admin_params -from trapdata.api.request_params import RequestParams -from trapdata.db import Base, get_session -from trapdata.db.models.deployments import DeploymentListItem, list_deployments - -router = APIRouter(prefix="/deployments") - - -@router.get("", response_model=List[DeploymentListItem]) -async def get_deployments( - response: Response, - session: orm.Session = Depends(get_session), - request_params: RequestParams = Depends(parse_react_admin_params(Base)), -) -> Any: - total = await session.scalar(select(func.count(Deployment.id))) - deployments = list_deployments(session) - response.headers[ - "Content-Range" - ] = f"{request_params.skip}-{request_params.skip + len(deployments)}/{total}" - return deployments diff --git a/trapdata/api/deps/db.py b/trapdata/api/deps/db.py index 52a49423..c800a29a 100644 --- a/trapdata/api/deps/db.py +++ b/trapdata/api/deps/db.py @@ -8,8 +8,8 @@ settings = read_settings() -def get_session() -> Generator[orm.Session, None]: - Session = get_session_class(settings.database_url) +def get_session() -> Generator[orm.Session, None, None]: + Session = get_session_class(db_path=settings.database_url) with Session() as session: yield session session.close() diff --git a/trapdata/api/factory.py b/trapdata/api/factory.py new file mode 100644 index 00000000..b7a8f1ae --- /dev/null +++ b/trapdata/api/factory.py @@ -0,0 +1,82 @@ +from fastapi import FastAPI +from fastapi.routing import APIRoute +from fastapi.staticfiles import StaticFiles +from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request +from starlette.responses import FileResponse, RedirectResponse + +from trapdata.api.config import settings +from trapdata.api.views import api_router + + +def create_app(): + description = f"{settings.PROJECT_NAME} API" + app = FastAPI( + title=settings.PROJECT_NAME, + openapi_url=f"{settings.API_PATH}/openapi.json", + docs_url="/docs/", + description=description, + redoc_url="/redoc/", + ) + setup_routers(app) + setup_cors_middleware(app) + serve_static_app(app) + return app + + +def setup_routers(app: FastAPI) -> None: + app.include_router(api_router, prefix=settings.API_PATH) + # The following operation needs to be at the end of this function + use_route_names_as_operation_ids(app) + + +def serve_static_app(app): + app.mount( + "/static/crops", + StaticFiles(directory=settings.user_data_path / "crops"), + name="crops", + ) + app.mount( + "/", + StaticFiles(directory="trapdata/webui/public"), + name="static", + ) + + @app.middleware("http") + async def _add_404_middleware(request: Request, call_next): + """Serves static assets on 404""" + response = await call_next(request) + path = request["path"] + if path.startswith(settings.API_PATH) or path.startswith("/docs"): + return response + if response.status_code == 404: + return FileResponse("trapdata/webui/public/index.html") + return response + + +def setup_cors_middleware(app): + if settings.BACKEND_CORS_ORIGINS: + app.add_middleware( + CORSMiddleware, + allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS], + allow_credentials=True, + allow_methods=["*"], + expose_headers=["Content-Range", "Range"], + allow_headers=["Authorization", "Range", "Content-Range"], + ) + + +def use_route_names_as_operation_ids(app: FastAPI) -> None: + """ + Simplify operation IDs so that generated API clients have simpler function + names. + + Should be called only after all routes have been added. + """ + route_names = set() + for route in app.routes: + if isinstance(route, APIRoute): + if route.name in route_names: + raise Exception("Route function names should be unique") + route.operation_id = route.name + route_names.add(route.name) diff --git a/trapdata/api/main.py b/trapdata/api/main.py new file mode 100644 index 00000000..1d1cd547 --- /dev/null +++ b/trapdata/api/main.py @@ -0,0 +1,20 @@ +from trapdata import logger +from trapdata.api.factory import create_app + +app = create_app() + + +def run(): + import uvicorn + + logger.info("Starting uvicorn in reload mode") + uvicorn.run( + "main:app", + host="0.0.0.0", + reload=True, + port=int("8000"), + ) + + +if __name__ == "__main__": + run() diff --git a/trapdata/api/views/__init__.py b/trapdata/api/views/__init__.py new file mode 100644 index 00000000..ed7afc4a --- /dev/null +++ b/trapdata/api/views/__init__.py @@ -0,0 +1,8 @@ +from fastapi import APIRouter + +from trapdata.api.views import deployments, stats + +api_router = APIRouter() + +api_router.include_router(stats.router, tags=["stats"]) +api_router.include_router(deployments.router, tags=["deployments"]) diff --git a/trapdata/api/views/deployments.py b/trapdata/api/views/deployments.py new file mode 100644 index 00000000..f3bd15d8 --- /dev/null +++ b/trapdata/api/views/deployments.py @@ -0,0 +1,39 @@ +from typing import Any, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import func, orm, select +from starlette.responses import Response + +from trapdata.api.config import settings +from trapdata.api.deps.db import get_session +from trapdata.api.deps.request_params import parse_react_admin_params +from trapdata.api.request_params import RequestParams +from trapdata.db import Base +from trapdata.db.models.deployments import DeploymentListItem, list_deployments + +router = APIRouter(prefix="/deployments") + + +@router.get("", response_model=List[DeploymentListItem]) +async def get_deployments( + response: Response, + session: orm.Session = Depends(get_session), + # request_params: RequestParams = Depends(parse_react_admin_params(Base)), +) -> Any: + deployments = list_deployments(session) + return deployments + + +@router.post("/process", response_model=List[DeploymentListItem]) +async def process_deployment( + response: Response, + session: orm.Session = Depends(get_session), + # request_params: RequestParams = Depends(parse_react_admin_params(Base)), +) -> Any: + from trapdata.ml.pipeline import start_pipeline + + start_pipeline( + session=session, image_base_path=settings.image_base_path, settings=settings + ) + deployments = list_deployments(session) + return deployments diff --git a/trapdata/api/items.py b/trapdata/api/views/items.py similarity index 100% rename from trapdata/api/items.py rename to trapdata/api/views/items.py diff --git a/trapdata/api/occurrences.bak.py b/trapdata/api/views/occurrences.bak.py similarity index 100% rename from trapdata/api/occurrences.bak.py rename to trapdata/api/views/occurrences.bak.py diff --git a/trapdata/api/occurrences.py b/trapdata/api/views/occurrences.py similarity index 100% rename from trapdata/api/occurrences.py rename to trapdata/api/views/occurrences.py diff --git a/trapdata/api/utils.py b/trapdata/api/views/stats.py similarity index 65% rename from trapdata/api/utils.py rename to trapdata/api/views/stats.py index e9957285..c3e4d43e 100644 --- a/trapdata/api/utils.py +++ b/trapdata/api/views/stats.py @@ -2,13 +2,18 @@ from fastapi import APIRouter -from app.schemas.msg import Msg +router = APIRouter(prefix="/stats") -router = APIRouter() + +from pydantic import BaseModel + + +class Msg(BaseModel): + msg: str @router.get( - "/hello-world", + "/", response_model=Msg, status_code=200, include_in_schema=False, diff --git a/trapdata/api/users.py b/trapdata/api/views/users.py similarity index 100% rename from trapdata/api/users.py rename to trapdata/api/views/users.py diff --git a/trapdata/cli/base.py b/trapdata/cli/base.py index b124b86f..771f5139 100644 --- a/trapdata/cli/base.py +++ b/trapdata/cli/base.py @@ -30,6 +30,16 @@ def gui(): run() +@cli.command() +def api(): + """ + Launch API server + """ + from trapdata.api.main import run as start_api + + start_api() + + @cli.command("import") def import_data(image_base_path: Optional[pathlib.Path] = None, queue: bool = True): """ diff --git a/trapdata/db/models/deployments.py b/trapdata/db/models/deployments.py index b0f41fe1..8e5a37f0 100644 --- a/trapdata/db/models/deployments.py +++ b/trapdata/db/models/deployments.py @@ -6,6 +6,7 @@ is used as the deployment name. """ import pathlib +from typing import Optional import sqlalchemy as sa from pydantic import BaseModel @@ -16,7 +17,7 @@ class DeploymentListItem(BaseModel): - # id: int + id: Optional[int] = None name: str num_events: int num_source_images: int diff --git a/trapdata/settings.py b/trapdata/settings.py index 2d22cc8d..e765cfdb 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -6,6 +6,7 @@ import sqlalchemy from pydantic import BaseSettings, Field, ValidationError, validator +from pydantic.main import ModelMetaclass from rich import print as rprint from trapdata import ml @@ -199,9 +200,11 @@ def kivy_settings_source(settings: BaseSettings) -> dict[str, str]: @lru_cache -def read_settings(*args, **kwargs): +def read_settings( + settings_class: ModelMetaclass = Settings, *args, **kwargs +) -> ModelMetaclass: try: - return Settings(*args, **kwargs) + return settings_class(*args, **kwargs) except ValidationError as e: # @TODO the validation errors could be printed in a more helpful way: rprint(cli_help_message) diff --git a/trapdata/webui/public/index.html b/trapdata/webui/public/index.html new file mode 100644 index 00000000..f944b384 --- /dev/null +++ b/trapdata/webui/public/index.html @@ -0,0 +1 @@ +:) From 3df9e3a0424443c5b96afd432f786cb4352e192d Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 29 Mar 2023 01:56:26 -0700 Subject: [PATCH 05/53] Quick hacks to serialize data with real image urls --- trapdata/cli/export.py | 2 +- trapdata/db/models/detections.py | 15 +++++++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/trapdata/cli/export.py b/trapdata/cli/export.py index fdc07322..98d6dfad 100644 --- a/trapdata/cli/export.py +++ b/trapdata/cli/export.py @@ -202,7 +202,7 @@ def captures( events = get_monitoring_session_by_date( db_path=settings.database_url, base_directory=settings.image_base_path, - event_dates=[str(date.date())], + event_dates=[date.date()], ) if not len(events): raise Exception(f"No Monitoring Event with date: {date.date()}") diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 37c30b77..c584f52f 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -15,6 +15,21 @@ from trapdata.db import models from trapdata.db.models.images import completely_classified +from pydantic import BaseModel + + +class DetectionListItem(BaseModel): + id: int + cropped_image_path: Optional[str] + bbox: Optional[list[int]] + area_pixels: Optional[int] + last_detected: Optional[datetime.datetime] + label: Optional[str] + score: Optional[int] + model_name: Optional[str] + in_queue: bool + notes: Optional[str] + class DetectionListItem(BaseModel): id: int From b1215b182d5932178a795fd30668d594a4a5b30a Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 29 Mar 2023 17:30:44 -0700 Subject: [PATCH 06/53] Use dataclass for Deployment --- trapdata/db/models/deployments.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/trapdata/db/models/deployments.py b/trapdata/db/models/deployments.py index 5b75b074..1cec4864 100644 --- a/trapdata/db/models/deployments.py +++ b/trapdata/db/models/deployments.py @@ -14,6 +14,18 @@ from trapdata.common.types import FilePath from trapdata.db import models +from pydantic import BaseModel + + +class DeploymentListItem(BaseModel): + # id: int + name: str + num_events: int + num_source_images: int + num_detections: int + # num_occurrences: int + # num_species: int + class DeploymentListItem(BaseModel): # id: int From fe9d6e2d8dd2392ca81d9745d7194dd42c4e9dc6 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 5 Apr 2023 20:51:37 -0700 Subject: [PATCH 07/53] Incomplete start of an API implementation --- poetry.lock | 76 +++++++++++++++++++- pyproject.toml | 1 + trapdata/api/__init__.py | 11 +++ trapdata/api/deployments.py | 27 ++++++++ trapdata/api/deps/db.py | 15 ++++ trapdata/api/deps/request_params.py | 46 +++++++++++++ trapdata/api/items.py | 103 ++++++++++++++++++++++++++++ trapdata/api/occurrences.bak.py | 47 +++++++++++++ trapdata/api/occurrences.py | 71 +++++++++++++++++++ trapdata/api/request_params.py | 9 +++ trapdata/api/users.py | 30 ++++++++ trapdata/api/utils.py | 17 +++++ trapdata/db/__init__.py | 4 +- trapdata/db/base.py | 16 +++++ trapdata/db/models/deployments.py | 12 +--- trapdata/db/models/images.py | 1 - 16 files changed, 473 insertions(+), 13 deletions(-) create mode 100644 trapdata/api/__init__.py create mode 100644 trapdata/api/deployments.py create mode 100644 trapdata/api/deps/db.py create mode 100644 trapdata/api/deps/request_params.py create mode 100644 trapdata/api/items.py create mode 100644 trapdata/api/occurrences.bak.py create mode 100644 trapdata/api/occurrences.py create mode 100644 trapdata/api/request_params.py create mode 100644 trapdata/api/users.py create mode 100644 trapdata/api/utils.py diff --git a/poetry.lock b/poetry.lock index d0939916..da8cbc61 100644 --- a/poetry.lock +++ b/poetry.lock @@ -20,6 +20,27 @@ typing-extensions = ">=4" [package.extras] tz = ["python-dateutil"] +[[package]] +name = "anyio" +version = "3.6.2" +description = "High level compatibility layer for multiple asynchronous event loop implementations" +category = "main" +optional = false +python-versions = ">=3.6.2" +files = [ + {file = "anyio-3.6.2-py3-none-any.whl", hash = "sha256:fbbe32bd270d2a2ef3ed1c5d45041250284e31fc0a4df4a5a6071842051a51e3"}, + {file = "anyio-3.6.2.tar.gz", hash = "sha256:25ea0d673ae30af41a0c442f81cf3b38c7e79fdc7b60335a4c14e05eb0947421"}, +] + +[package.dependencies] +idna = ">=2.8" +sniffio = ">=1.1" + +[package.extras] +doc = ["packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["contextlib2", "coverage[toml] (>=4.5)", "hypothesis (>=4.0)", "mock (>=4)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (<0.15)", "uvloop (>=0.15)"] +trio = ["trio (>=0.16,<0.22)"] + [[package]] name = "appnope" version = "0.1.3" @@ -406,6 +427,28 @@ files = [ [package.extras] tests = ["asttokens", "littleutils", "pytest", "rich"] +[[package]] +name = "fastapi" +version = "0.95.0" +description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "fastapi-0.95.0-py3-none-any.whl", hash = "sha256:daf73bbe844180200be7966f68e8ec9fd8be57079dff1bacb366db32729e6eb5"}, + {file = "fastapi-0.95.0.tar.gz", hash = "sha256:99d4fdb10e9dd9a24027ac1d0bd4b56702652056ca17a6c8721eec4ad2f14e18"}, +] + +[package.dependencies] +pydantic = ">=1.6.2,<1.7 || >1.7,<1.7.1 || >1.7.1,<1.7.2 || >1.7.2,<1.7.3 || >1.7.3,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0" +starlette = ">=0.26.1,<0.27.0" + +[package.extras] +all = ["email-validator (>=1.1.1)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"] +dev = ["pre-commit (>=2.17.0,<3.0.0)", "ruff (==0.0.138)", "uvicorn[standard] (>=0.12.0,<0.21.0)"] +doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.3.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pyyaml (>=5.3.1,<7.0.0)", "typer-cli (>=0.0.13,<0.0.14)", "typer[all] (>=0.6.1,<0.8.0)"] +test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==23.1.0)", "coverage[toml] (>=6.5.0,<8.0)", "databases[sqlite] (>=0.3.2,<0.7.0)", "email-validator (>=1.1.1,<2.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.23.0,<0.24.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.982)", "orjson (>=3.2.1,<4.0.0)", "passlib[bcrypt] (>=1.7.2,<2.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=7.1.3,<8.0.0)", "python-jose[cryptography] (>=3.3.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.7)", "pyyaml (>=5.3.1,<7.0.0)", "ruff (==0.0.138)", "sqlalchemy (>=1.3.18,<1.4.43)", "types-orjson (==3.6.2)", "types-ujson (==5.7.0.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0)"] + [[package]] name = "filelock" version = "3.10.7" @@ -1980,6 +2023,18 @@ files = [ {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, ] +[[package]] +name = "sniffio" +version = "1.3.0" +description = "Sniff out which async library your code is running under" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.0-py3-none-any.whl", hash = "sha256:eecefdce1e5bbfb7ad2eeaabf7c1eeb404d7757c379bd1f7e5cce9d8bf425384"}, + {file = "sniffio-1.3.0.tar.gz", hash = "sha256:e60305c5e5d314f5389259b7f22aaa33d8f7dee49763119234af3755c55b9101"}, +] + [[package]] name = "sqlalchemy" version = "2.0.8" @@ -2107,6 +2162,25 @@ pure-eval = "*" [package.extras] tests = ["cython", "littleutils", "pygments", "pytest", "typeguard"] +[[package]] +name = "starlette" +version = "0.26.1" +description = "The little ASGI library that shines." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "starlette-0.26.1-py3-none-any.whl", hash = "sha256:e87fce5d7cbdde34b76f0ac69013fd9d190d581d80681493016666e6f96c6d5e"}, + {file = "starlette-0.26.1.tar.gz", hash = "sha256:41da799057ea8620e4667a3e69a5b1923ebd32b1819c8fa75634bbe8d8bea9bd"}, +] + +[package.dependencies] +anyio = ">=3.4.0,<5" +typing-extensions = {version = ">=3.10.0", markers = "python_version < \"3.10\""} + +[package.extras] +full = ["httpx (>=0.22.0)", "itsdangerous", "jinja2", "python-multipart", "pyyaml"] + [[package]] name = "structlog" version = "22.3.0" @@ -2420,4 +2494,4 @@ test = ["pytest (>=6.0.0)"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "be789d9689e5a8a9feeb55857c26f300acf7b99a95ab555fffe45239a14a0086" +content-hash = "235424e527001bb5aabde6b4f5381d9dcf97bfd0b1a9dc7207f4d177e4143e5b" diff --git a/pyproject.toml b/pyproject.toml index 1ddde4ce..e6ad75c2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -45,6 +45,7 @@ ipython = "^8.11.0" pytest-cov = "^4.0.0" pytest-asyncio = "^0.21.0" pytest = "*" +fastapi = "^0.95.0" [tool.pytest.ini_options] diff --git a/trapdata/api/__init__.py b/trapdata/api/__init__.py new file mode 100644 index 00000000..eec3f78c --- /dev/null +++ b/trapdata/api/__init__.py @@ -0,0 +1,11 @@ +from fastapi import APIRouter + +from trapdata.api import deployments, items, occurrences, users, utils + +api_router = APIRouter() + +api_router.include_router(utils.router, tags=["utils"]) +api_router.include_router(users.router, tags=["users"]) +api_router.include_router(items.router, tags=["items"]) +api_router.include_router(occurrences.router, tags=["occurrences"]) +api_router.include_router(deployments.router, tags=["deployments"]) diff --git a/trapdata/api/deployments.py b/trapdata/api/deployments.py new file mode 100644 index 00000000..4030e460 --- /dev/null +++ b/trapdata/api/deployments.py @@ -0,0 +1,27 @@ +from typing import Any, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import func, orm, select +from starlette.responses import Response + +from trapdata.api.deps.db import get_async_session +from trapdata.api.deps.request_params import parse_react_admin_params +from trapdata.api.request_params import RequestParams +from trapdata.db import Base, get_session +from trapdata.db.models.deployments import DeploymentListItem, list_deployments + +router = APIRouter(prefix="/deployments") + + +@router.get("", response_model=List[DeploymentListItem]) +async def get_deployments( + response: Response, + session: orm.Session = Depends(get_session), + request_params: RequestParams = Depends(parse_react_admin_params(Base)), +) -> Any: + total = await session.scalar(select(func.count(Deployment.id))) + deployments = list_deployments(session) + response.headers[ + "Content-Range" + ] = f"{request_params.skip}-{request_params.skip + len(deployments)}/{total}" + return deployments diff --git a/trapdata/api/deps/db.py b/trapdata/api/deps/db.py new file mode 100644 index 00000000..52a49423 --- /dev/null +++ b/trapdata/api/deps/db.py @@ -0,0 +1,15 @@ +from typing import Generator + +from sqlalchemy import orm + +from trapdata.cli import read_settings +from trapdata.db.base import get_session_class + +settings = read_settings() + + +def get_session() -> Generator[orm.Session, None]: + Session = get_session_class(settings.database_url) + with Session() as session: + yield session + session.close() diff --git a/trapdata/api/deps/request_params.py b/trapdata/api/deps/request_params.py new file mode 100644 index 00000000..113fdbbf --- /dev/null +++ b/trapdata/api/deps/request_params.py @@ -0,0 +1,46 @@ +import json +from typing import Callable, Optional, Type + +from fastapi import HTTPException, Query +from sqlalchemy import UnaryExpression, asc, desc + +from trapdata.api.request_params import RequestParams +from trapdata.db import Base + + +def parse_react_admin_params(model: Type[Base]) -> Callable: + """Parses sort and range parameters coming from a react-admin request""" + + def inner( + sort_: Optional[str] = Query( + None, + alias="sort", + description='Format: `["field_name", "direction"]`', + example='["id", "ASC"]', + ), + range_: Optional[str] = Query( + None, + alias="range", + description="Format: `[start, end]`", + example="[0, 10]", + ), + ) -> RequestParams: + skip, limit = 0, 10 + if range_: + start, end = json.loads(range_) + skip, limit = start, (end - start + 1) + + order_by: UnaryExpression = desc(model.id) + if sort_: + sort_column, sort_order = json.loads(sort_) + if sort_order.lower() == "asc": + direction = asc + elif sort_order.lower() == "desc": + direction = desc + else: + raise HTTPException(400, f"Invalid sort direction {sort_order}") + order_by = direction(model.__table__.c[sort_column]) + + return RequestParams(skip=skip, limit=limit, order_by=order_by) + + return inner diff --git a/trapdata/api/items.py b/trapdata/api/items.py new file mode 100644 index 00000000..e679e857 --- /dev/null +++ b/trapdata/api/items.py @@ -0,0 +1,103 @@ +from typing import Any, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio.session import AsyncSession +from starlette.responses import Response + +from app.deps.db import get_async_session +from app.deps.request_params import parse_react_admin_params +from app.deps.users import current_user +from app.models.item import Item +from app.models.user import User +from app.schemas.item import Item as ItemSchema +from app.schemas.item import ItemCreate, ItemUpdate +from app.schemas.request_params import RequestParams + +router = APIRouter(prefix="/items") + + +@router.get("", response_model=List[ItemSchema]) +async def get_items( + response: Response, + session: AsyncSession = Depends(get_async_session), + request_params: RequestParams = Depends(parse_react_admin_params(Item)), + user: User = Depends(current_user), +) -> Any: + total = await session.scalar( + select(func.count(Item.id).filter(Item.user_id == user.id)) + ) + items = ( + ( + await session.execute( + select(Item) + .offset(request_params.skip) + .limit(request_params.limit) + .order_by(request_params.order_by) + .filter(Item.user_id == user.id) + ) + ) + .scalars() + .all() + ) + response.headers[ + "Content-Range" + ] = f"{request_params.skip}-{request_params.skip + len(items)}/{total}" + return items + + +@router.post("", response_model=ItemSchema, status_code=201) +async def create_item( + item_in: ItemCreate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_user), +) -> Any: + item = Item(**item_in.dict()) + item.user_id = user.id + session.add(item) + await session.commit() + return item + + +@router.put("/{item_id}", response_model=ItemSchema) +async def update_item( + item_id: int, + item_in: ItemUpdate, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_user), +) -> Any: + item: Optional[Item] = await session.get(Item, item_id) + if not item or item.user_id != user.id: + raise HTTPException(404) + update_data = item_in.dict(exclude_unset=True) + for field, value in update_data.items(): + setattr(item, field, value) + session.add(item) + await session.commit() + return item + + +@router.get("/{item_id}", response_model=ItemSchema) +async def get_item( + item_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_user), +) -> Any: + item: Optional[Item] = await session.get(Item, item_id) + if not item or item.user_id != user.id: + raise HTTPException(404) + return item + + +@router.delete("/{item_id}") +async def delete_item( + item_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_user), +) -> Any: + item: Optional[Item] = await session.get(Item, item_id) + if not item or item.user_id != user.id: + raise HTTPException(404) + await session.delete(item) + await session.commit() + return {"success": True} diff --git a/trapdata/api/occurrences.bak.py b/trapdata/api/occurrences.bak.py new file mode 100644 index 00000000..d91a252d --- /dev/null +++ b/trapdata/api/occurrences.bak.py @@ -0,0 +1,47 @@ +import random +from typing import Any, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio.session import AsyncSession +from starlette.responses import Response + +from app.deps.db import get_async_session +from app.deps.request_params import parse_react_admin_params +from app.deps.users import current_user + +# from app.models.occurrence import Occurrence +from app.models.user import User +from app.schemas.occurrence import test_data +from app.schemas.occurrence import Occurrence +from app.schemas.request_params import RequestParams + +router = APIRouter(prefix="/occurrences") + + +@router.get("", response_model=List[Occurrence]) +async def get_occurrences( + response: Response, + session: AsyncSession = Depends(get_async_session), + request_params: RequestParams = Depends(parse_react_admin_params(Occurrence)), +) -> Any: + occurrences = test_data[ + request_params.skip : request_params.skip + request_params.limit + ] + total = len(occurrences) + response.headers[ + "Content-Range" + ] = f"{request_params.skip}-{request_params.skip + len(occurrences)}/{total}" + return occurrences + + +@router.get("/{occurrence_id}", response_model=Occurrence) +async def get_occurrence( + occurrence_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_user), +) -> Any: + occurrence: Optional[Occurrence] = await session.get(Occurrence, occurrence_id) + if not occurrence or occurrence.user_id != user.id: + raise HTTPException(404) + return occurrence diff --git a/trapdata/api/occurrences.py b/trapdata/api/occurrences.py new file mode 100644 index 00000000..856013a4 --- /dev/null +++ b/trapdata/api/occurrences.py @@ -0,0 +1,71 @@ +from typing import Any, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio.session import AsyncSession +from starlette.responses import Response + +from app.deps.db import get_async_session +from app.deps.request_params import parse_react_admin_params +from app.deps.users import current_user +from app.models.occurrence import Occurrence +from app.models.user import User +from app.schemas.occurrence import Occurrence as OccurrenceSchema +from app.schemas.occurrence import OccurrenceCreate, OccurrenceUpdate +from app.schemas.request_params import RequestParams + +router = APIRouter(prefix="/occurrences") + + +@router.get("", response_model=List[OccurrenceSchema]) +async def get_occurrences( + response: Response, + session: AsyncSession = Depends(get_async_session), + request_params: RequestParams = Depends(parse_react_admin_params(Occurrence)), +) -> Any: + total = await session.scalar( + select( + func.count(Occurrence.id).filter( + Occurrence.deployment_id == request_params.deployment_id + ) + ) + ) + items = ( + ( + await session.execute( + select(Occurrence) + .offset(request_params.skip) + .limit(request_params.limit) + .order_by(request_params.order_by) + .filter(Occurrence.deployment_id == request_params.deployment_id) + ) + ) + .scalars() + .all() + ) + response.headers[ + "Content-Range" + ] = f"{request_params.skip}-{request_params.skip + len(items)}/{total}" + return items + + +@router.get("/{item_id}", response_model=OccurrenceSchema) +async def get_occurrence( + item_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_user), +) -> Any: + item: Optional[Occurrence] = await session.get(Occurrence, item_id) + return item + + +@router.delete("/{item_id}") +async def delete_occurrence( + item_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_user), +) -> Any: + item: Optional[Occurrence] = await session.get(Occurrence, item_id) + await session.delete(item) + await session.commit() + return {"success": True} diff --git a/trapdata/api/request_params.py b/trapdata/api/request_params.py new file mode 100644 index 00000000..43d5af5b --- /dev/null +++ b/trapdata/api/request_params.py @@ -0,0 +1,9 @@ +from typing import Any + +from pydantic.main import BaseModel + + +class RequestParams(BaseModel): + skip: int + limit: int + order_by: Any diff --git a/trapdata/api/users.py b/trapdata/api/users.py new file mode 100644 index 00000000..06ca854e --- /dev/null +++ b/trapdata/api/users.py @@ -0,0 +1,30 @@ +from typing import Any, List + +from fastapi.params import Depends +from fastapi.routing import APIRouter +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio.session import AsyncSession +from starlette.responses import Response + +from app.deps.db import get_async_session +from app.deps.users import current_superuser +from app.models.user import User +from app.schemas.user import UserRead + +router = APIRouter() + + +@router.get("/users", response_model=List[UserRead]) +async def get_users( + response: Response, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_superuser), + skip: int = 0, + limit: int = 100, +) -> Any: + total = await session.scalar(select(func.count(User.id))) + users = ( + (await session.execute(select(User).offset(skip).limit(limit))).scalars().all() + ) + response.headers["Content-Range"] = f"{skip}-{skip + len(users)}/{total}" + return users diff --git a/trapdata/api/utils.py b/trapdata/api/utils.py new file mode 100644 index 00000000..e9957285 --- /dev/null +++ b/trapdata/api/utils.py @@ -0,0 +1,17 @@ +from typing import Any + +from fastapi import APIRouter + +from app.schemas.msg import Msg + +router = APIRouter() + + +@router.get( + "/hello-world", + response_model=Msg, + status_code=200, + include_in_schema=False, +) +def test_hello_world() -> Any: + return {"msg": "Hello world!"} diff --git a/trapdata/db/__init__.py b/trapdata/db/__init__.py index a5493f58..5c390e50 100644 --- a/trapdata/db/__init__.py +++ b/trapdata/db/__init__.py @@ -1,3 +1,5 @@ +from typing import Any + import sqlalchemy as sa from sqlalchemy import orm @@ -16,4 +18,4 @@ class Base(orm.DeclarativeBase): - pass + id: Any diff --git a/trapdata/db/base.py b/trapdata/db/base.py index b104d169..e50f4210 100644 --- a/trapdata/db/base.py +++ b/trapdata/db/base.py @@ -232,3 +232,19 @@ def get_or_create(session, model, defaults=None, **kwargs): return instance, False else: return instance, True + + +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + + +def get_async_session_class(db_path: str) -> async_sessionmaker[AsyncSession]: + async_engine = create_async_engine(db_path, pool_pre_ping=True) + + async_session_maker = async_sessionmaker( + async_engine, + class_=AsyncSession, + expire_on_commit=False, + autocommit=False, + autoflush=False, + ) + return async_session_maker diff --git a/trapdata/db/models/deployments.py b/trapdata/db/models/deployments.py index 1cec4864..b0f41fe1 100644 --- a/trapdata/db/models/deployments.py +++ b/trapdata/db/models/deployments.py @@ -14,8 +14,6 @@ from trapdata.common.types import FilePath from trapdata.db import models -from pydantic import BaseModel - class DeploymentListItem(BaseModel): # id: int @@ -27,14 +25,8 @@ class DeploymentListItem(BaseModel): # num_species: int -class DeploymentListItem(BaseModel): - # id: int - name: str - num_events: int - num_source_images: int - num_detections: int - # num_occurrences: int - # num_species: int +class DeploymentDetail(DeploymentListItem): + pass def deployment_name(image_base_path: FilePath) -> str: diff --git a/trapdata/db/models/images.py b/trapdata/db/models/images.py index 87eb1ac9..d7980f00 100644 --- a/trapdata/db/models/images.py +++ b/trapdata/db/models/images.py @@ -27,7 +27,6 @@ class CaptureListItem(BaseModel): class CaptureDetail(CaptureListItem): - id: int event: object notes: Optional[str] detections: list From 2b7a37277fc224109345db292c8a9846951b7652 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 10 Apr 2023 01:21:23 -0700 Subject: [PATCH 08/53] Working API example --- poetry.lock | 35 ++++++++- pyproject.toml | 1 + trapdata/api/__init__.py | 11 --- trapdata/api/config.py | 39 ++++++++++ trapdata/api/deployments.py | 27 ------- trapdata/api/deps/db.py | 4 +- trapdata/api/factory.py | 82 +++++++++++++++++++++ trapdata/api/main.py | 20 +++++ trapdata/api/views/__init__.py | 8 ++ trapdata/api/views/deployments.py | 39 ++++++++++ trapdata/api/{ => views}/items.py | 0 trapdata/api/{ => views}/occurrences.bak.py | 0 trapdata/api/{ => views}/occurrences.py | 0 trapdata/api/{utils.py => views/stats.py} | 11 ++- trapdata/api/{ => views}/users.py | 0 trapdata/cli/base.py | 10 +++ trapdata/db/models/deployments.py | 3 +- trapdata/settings.py | 7 +- trapdata/webui/public/index.html | 1 + 19 files changed, 250 insertions(+), 48 deletions(-) delete mode 100644 trapdata/api/__init__.py create mode 100644 trapdata/api/config.py delete mode 100644 trapdata/api/deployments.py create mode 100644 trapdata/api/factory.py create mode 100644 trapdata/api/main.py create mode 100644 trapdata/api/views/__init__.py create mode 100644 trapdata/api/views/deployments.py rename trapdata/api/{ => views}/items.py (100%) rename trapdata/api/{ => views}/occurrences.bak.py (100%) rename trapdata/api/{ => views}/occurrences.py (100%) rename trapdata/api/{utils.py => views/stats.py} (65%) rename trapdata/api/{ => views}/users.py (100%) create mode 100644 trapdata/webui/public/index.html diff --git a/poetry.lock b/poetry.lock index da8cbc61..5894bb32 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry and should not be changed by hand. +# This file is automatically @generated by Poetry 1.4.1 and should not be changed by hand. [[package]] name = "alembic" @@ -556,6 +556,18 @@ files = [ docs = ["Sphinx", "docutils (<0.18)"] test = ["objgraph", "psutil"] +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, +] + [[package]] name = "huggingface-hub" version = "0.13.3" @@ -2464,6 +2476,25 @@ brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)", "brotlipy (>=0.6.0)"] secure = ["certifi", "cryptography (>=1.3.4)", "idna (>=2.0.0)", "ipaddress", "pyOpenSSL (>=0.14)", "urllib3-secure-extra"] socks = ["PySocks (>=1.5.6,!=1.5.7,<2.0)"] +[[package]] +name = "uvicorn" +version = "0.21.1" +description = "The lightning-fast ASGI server." +category = "main" +optional = false +python-versions = ">=3.7" +files = [ + {file = "uvicorn-0.21.1-py3-none-any.whl", hash = "sha256:e47cac98a6da10cd41e6fd036d472c6f58ede6c5dbee3dbee3ef7a100ed97742"}, + {file = "uvicorn-0.21.1.tar.gz", hash = "sha256:0fac9cb342ba099e0d582966005f3fdba5b0290579fed4a6266dc702ca7bb032"}, +] + +[package.dependencies] +click = ">=7.0" +h11 = ">=0.8" + +[package.extras] +standard = ["colorama (>=0.4)", "httptools (>=0.5.0)", "python-dotenv (>=0.13)", "pyyaml (>=5.1)", "uvloop (>=0.14.0,!=0.15.0,!=0.15.1)", "watchfiles (>=0.13)", "websockets (>=10.4)"] + [[package]] name = "wcwidth" version = "0.2.6" @@ -2494,4 +2525,4 @@ test = ["pytest (>=6.0.0)"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "235424e527001bb5aabde6b4f5381d9dcf97bfd0b1a9dc7207f4d177e4143e5b" +content-hash = "453accf6c7a42a5ccebff574cedeef2a4a4797d04a8b155375023c901b40fe37" diff --git a/pyproject.toml b/pyproject.toml index e6ad75c2..1409d445 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ pytest-cov = "^4.0.0" pytest-asyncio = "^0.21.0" pytest = "*" fastapi = "^0.95.0" +uvicorn = "^0.21.1" [tool.pytest.ini_options] diff --git a/trapdata/api/__init__.py b/trapdata/api/__init__.py deleted file mode 100644 index eec3f78c..00000000 --- a/trapdata/api/__init__.py +++ /dev/null @@ -1,11 +0,0 @@ -from fastapi import APIRouter - -from trapdata.api import deployments, items, occurrences, users, utils - -api_router = APIRouter() - -api_router.include_router(utils.router, tags=["utils"]) -api_router.include_router(users.router, tags=["users"]) -api_router.include_router(items.router, tags=["items"]) -api_router.include_router(occurrences.router, tags=["occurrences"]) -api_router.include_router(deployments.router, tags=["deployments"]) diff --git a/trapdata/api/config.py b/trapdata/api/config.py new file mode 100644 index 00000000..d4810984 --- /dev/null +++ b/trapdata/api/config.py @@ -0,0 +1,39 @@ +import pathlib +from typing import Any, Dict, List, Optional + +from pydantic import BaseSettings, HttpUrl, PostgresDsn, validator +from pydantic.networks import AnyHttpUrl + +from trapdata.cli import read_settings +from trapdata.settings import Settings as BaseSettings + + +class Settings(BaseSettings): + PROJECT_NAME: str = "AMI Data Manager" + + SENTRY_DSN: Optional[HttpUrl] = None + + API_PATH: str = "/api/v1" + + ACCESS_TOKEN_EXPIRE_MINUTES: int = 7 * 24 * 60 # 7 days + + BACKEND_CORS_ORIGINS: List[AnyHttpUrl] = [] + + # The following variables need to be defined in environment + + TEST_DATABASE_URL: Optional[PostgresDsn] + + SECRET_KEY: str + # END: required environment variables + + # STATIC_ROOT: str = "static" + + # @validator("STATIC_ROOT") + # def validate_static_root(cls, v): + # path = cls.user_data_path / v + # path.mkdir(parents=True, exist_ok=True) + # return path + + +# settings = read_settings(SettingsClass=Settings, SECRET_KEY="secret") +settings = Settings(SECRET_KEY="secret") diff --git a/trapdata/api/deployments.py b/trapdata/api/deployments.py deleted file mode 100644 index 4030e460..00000000 --- a/trapdata/api/deployments.py +++ /dev/null @@ -1,27 +0,0 @@ -from typing import Any, List, Optional - -from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import func, orm, select -from starlette.responses import Response - -from trapdata.api.deps.db import get_async_session -from trapdata.api.deps.request_params import parse_react_admin_params -from trapdata.api.request_params import RequestParams -from trapdata.db import Base, get_session -from trapdata.db.models.deployments import DeploymentListItem, list_deployments - -router = APIRouter(prefix="/deployments") - - -@router.get("", response_model=List[DeploymentListItem]) -async def get_deployments( - response: Response, - session: orm.Session = Depends(get_session), - request_params: RequestParams = Depends(parse_react_admin_params(Base)), -) -> Any: - total = await session.scalar(select(func.count(Deployment.id))) - deployments = list_deployments(session) - response.headers[ - "Content-Range" - ] = f"{request_params.skip}-{request_params.skip + len(deployments)}/{total}" - return deployments diff --git a/trapdata/api/deps/db.py b/trapdata/api/deps/db.py index 52a49423..c800a29a 100644 --- a/trapdata/api/deps/db.py +++ b/trapdata/api/deps/db.py @@ -8,8 +8,8 @@ settings = read_settings() -def get_session() -> Generator[orm.Session, None]: - Session = get_session_class(settings.database_url) +def get_session() -> Generator[orm.Session, None, None]: + Session = get_session_class(db_path=settings.database_url) with Session() as session: yield session session.close() diff --git a/trapdata/api/factory.py b/trapdata/api/factory.py new file mode 100644 index 00000000..b7a8f1ae --- /dev/null +++ b/trapdata/api/factory.py @@ -0,0 +1,82 @@ +from fastapi import FastAPI +from fastapi.routing import APIRoute +from fastapi.staticfiles import StaticFiles +from starlette.middleware.cors import CORSMiddleware +from starlette.requests import Request +from starlette.responses import FileResponse, RedirectResponse + +from trapdata.api.config import settings +from trapdata.api.views import api_router + + +def create_app(): + description = f"{settings.PROJECT_NAME} API" + app = FastAPI( + title=settings.PROJECT_NAME, + openapi_url=f"{settings.API_PATH}/openapi.json", + docs_url="/docs/", + description=description, + redoc_url="/redoc/", + ) + setup_routers(app) + setup_cors_middleware(app) + serve_static_app(app) + return app + + +def setup_routers(app: FastAPI) -> None: + app.include_router(api_router, prefix=settings.API_PATH) + # The following operation needs to be at the end of this function + use_route_names_as_operation_ids(app) + + +def serve_static_app(app): + app.mount( + "/static/crops", + StaticFiles(directory=settings.user_data_path / "crops"), + name="crops", + ) + app.mount( + "/", + StaticFiles(directory="trapdata/webui/public"), + name="static", + ) + + @app.middleware("http") + async def _add_404_middleware(request: Request, call_next): + """Serves static assets on 404""" + response = await call_next(request) + path = request["path"] + if path.startswith(settings.API_PATH) or path.startswith("/docs"): + return response + if response.status_code == 404: + return FileResponse("trapdata/webui/public/index.html") + return response + + +def setup_cors_middleware(app): + if settings.BACKEND_CORS_ORIGINS: + app.add_middleware( + CORSMiddleware, + allow_origins=[str(origin) for origin in settings.BACKEND_CORS_ORIGINS], + allow_credentials=True, + allow_methods=["*"], + expose_headers=["Content-Range", "Range"], + allow_headers=["Authorization", "Range", "Content-Range"], + ) + + +def use_route_names_as_operation_ids(app: FastAPI) -> None: + """ + Simplify operation IDs so that generated API clients have simpler function + names. + + Should be called only after all routes have been added. + """ + route_names = set() + for route in app.routes: + if isinstance(route, APIRoute): + if route.name in route_names: + raise Exception("Route function names should be unique") + route.operation_id = route.name + route_names.add(route.name) diff --git a/trapdata/api/main.py b/trapdata/api/main.py new file mode 100644 index 00000000..1d1cd547 --- /dev/null +++ b/trapdata/api/main.py @@ -0,0 +1,20 @@ +from trapdata import logger +from trapdata.api.factory import create_app + +app = create_app() + + +def run(): + import uvicorn + + logger.info("Starting uvicorn in reload mode") + uvicorn.run( + "main:app", + host="0.0.0.0", + reload=True, + port=int("8000"), + ) + + +if __name__ == "__main__": + run() diff --git a/trapdata/api/views/__init__.py b/trapdata/api/views/__init__.py new file mode 100644 index 00000000..ed7afc4a --- /dev/null +++ b/trapdata/api/views/__init__.py @@ -0,0 +1,8 @@ +from fastapi import APIRouter + +from trapdata.api.views import deployments, stats + +api_router = APIRouter() + +api_router.include_router(stats.router, tags=["stats"]) +api_router.include_router(deployments.router, tags=["deployments"]) diff --git a/trapdata/api/views/deployments.py b/trapdata/api/views/deployments.py new file mode 100644 index 00000000..f3bd15d8 --- /dev/null +++ b/trapdata/api/views/deployments.py @@ -0,0 +1,39 @@ +from typing import Any, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import func, orm, select +from starlette.responses import Response + +from trapdata.api.config import settings +from trapdata.api.deps.db import get_session +from trapdata.api.deps.request_params import parse_react_admin_params +from trapdata.api.request_params import RequestParams +from trapdata.db import Base +from trapdata.db.models.deployments import DeploymentListItem, list_deployments + +router = APIRouter(prefix="/deployments") + + +@router.get("", response_model=List[DeploymentListItem]) +async def get_deployments( + response: Response, + session: orm.Session = Depends(get_session), + # request_params: RequestParams = Depends(parse_react_admin_params(Base)), +) -> Any: + deployments = list_deployments(session) + return deployments + + +@router.post("/process", response_model=List[DeploymentListItem]) +async def process_deployment( + response: Response, + session: orm.Session = Depends(get_session), + # request_params: RequestParams = Depends(parse_react_admin_params(Base)), +) -> Any: + from trapdata.ml.pipeline import start_pipeline + + start_pipeline( + session=session, image_base_path=settings.image_base_path, settings=settings + ) + deployments = list_deployments(session) + return deployments diff --git a/trapdata/api/items.py b/trapdata/api/views/items.py similarity index 100% rename from trapdata/api/items.py rename to trapdata/api/views/items.py diff --git a/trapdata/api/occurrences.bak.py b/trapdata/api/views/occurrences.bak.py similarity index 100% rename from trapdata/api/occurrences.bak.py rename to trapdata/api/views/occurrences.bak.py diff --git a/trapdata/api/occurrences.py b/trapdata/api/views/occurrences.py similarity index 100% rename from trapdata/api/occurrences.py rename to trapdata/api/views/occurrences.py diff --git a/trapdata/api/utils.py b/trapdata/api/views/stats.py similarity index 65% rename from trapdata/api/utils.py rename to trapdata/api/views/stats.py index e9957285..c3e4d43e 100644 --- a/trapdata/api/utils.py +++ b/trapdata/api/views/stats.py @@ -2,13 +2,18 @@ from fastapi import APIRouter -from app.schemas.msg import Msg +router = APIRouter(prefix="/stats") -router = APIRouter() + +from pydantic import BaseModel + + +class Msg(BaseModel): + msg: str @router.get( - "/hello-world", + "/", response_model=Msg, status_code=200, include_in_schema=False, diff --git a/trapdata/api/users.py b/trapdata/api/views/users.py similarity index 100% rename from trapdata/api/users.py rename to trapdata/api/views/users.py diff --git a/trapdata/cli/base.py b/trapdata/cli/base.py index b124b86f..771f5139 100644 --- a/trapdata/cli/base.py +++ b/trapdata/cli/base.py @@ -30,6 +30,16 @@ def gui(): run() +@cli.command() +def api(): + """ + Launch API server + """ + from trapdata.api.main import run as start_api + + start_api() + + @cli.command("import") def import_data(image_base_path: Optional[pathlib.Path] = None, queue: bool = True): """ diff --git a/trapdata/db/models/deployments.py b/trapdata/db/models/deployments.py index b0f41fe1..8e5a37f0 100644 --- a/trapdata/db/models/deployments.py +++ b/trapdata/db/models/deployments.py @@ -6,6 +6,7 @@ is used as the deployment name. """ import pathlib +from typing import Optional import sqlalchemy as sa from pydantic import BaseModel @@ -16,7 +17,7 @@ class DeploymentListItem(BaseModel): - # id: int + id: Optional[int] = None name: str num_events: int num_source_images: int diff --git a/trapdata/settings.py b/trapdata/settings.py index 2d22cc8d..e765cfdb 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -6,6 +6,7 @@ import sqlalchemy from pydantic import BaseSettings, Field, ValidationError, validator +from pydantic.main import ModelMetaclass from rich import print as rprint from trapdata import ml @@ -199,9 +200,11 @@ def kivy_settings_source(settings: BaseSettings) -> dict[str, str]: @lru_cache -def read_settings(*args, **kwargs): +def read_settings( + settings_class: ModelMetaclass = Settings, *args, **kwargs +) -> ModelMetaclass: try: - return Settings(*args, **kwargs) + return settings_class(*args, **kwargs) except ValidationError as e: # @TODO the validation errors could be printed in a more helpful way: rprint(cli_help_message) diff --git a/trapdata/webui/public/index.html b/trapdata/webui/public/index.html new file mode 100644 index 00000000..f944b384 --- /dev/null +++ b/trapdata/webui/public/index.html @@ -0,0 +1 @@ +:) From bb9675c4e08bdb1388b5fec9025555e2c9de24f7 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 11 Apr 2023 01:55:43 -0700 Subject: [PATCH 09/53] Add additional endpoints and supporting changes --- trapdata/api/views/__init__.py | 5 +- trapdata/api/views/deployments.py | 26 ++++---- trapdata/api/views/occurrences.bak2.py | 58 +++++++++++++++++ trapdata/api/views/occurrences.py | 86 +++++++++----------------- trapdata/api/views/sessions.py | 39 ++++++++++++ trapdata/api/views/settings.py | 22 +++++++ trapdata/api/views/status.py | 68 ++++++++++++++++++++ trapdata/common/filemanagement.py | 2 +- trapdata/db/models/occurrences.py | 29 ++++++--- trapdata/db/models/queue.py | 26 +++++++- trapdata/settings.py | 14 +++-- 11 files changed, 289 insertions(+), 86 deletions(-) create mode 100644 trapdata/api/views/occurrences.bak2.py create mode 100644 trapdata/api/views/sessions.py create mode 100644 trapdata/api/views/settings.py create mode 100644 trapdata/api/views/status.py diff --git a/trapdata/api/views/__init__.py b/trapdata/api/views/__init__.py index ed7afc4a..58497375 100644 --- a/trapdata/api/views/__init__.py +++ b/trapdata/api/views/__init__.py @@ -1,8 +1,11 @@ from fastapi import APIRouter -from trapdata.api.views import deployments, stats +from trapdata.api.views import deployments, occurrences, settings, stats, status api_router = APIRouter() api_router.include_router(stats.router, tags=["stats"]) +api_router.include_router(status.router, tags=["status"]) api_router.include_router(deployments.router, tags=["deployments"]) +api_router.include_router(occurrences.router, tags=["occurrences"]) +api_router.include_router(settings.router, tags=["settings"]) diff --git a/trapdata/api/views/deployments.py b/trapdata/api/views/deployments.py index f3bd15d8..95557331 100644 --- a/trapdata/api/views/deployments.py +++ b/trapdata/api/views/deployments.py @@ -24,16 +24,16 @@ async def get_deployments( return deployments -@router.post("/process", response_model=List[DeploymentListItem]) -async def process_deployment( - response: Response, - session: orm.Session = Depends(get_session), - # request_params: RequestParams = Depends(parse_react_admin_params(Base)), -) -> Any: - from trapdata.ml.pipeline import start_pipeline - - start_pipeline( - session=session, image_base_path=settings.image_base_path, settings=settings - ) - deployments = list_deployments(session) - return deployments +# @router.post("/process", response_model=List[DeploymentListItem]) +# async def process_deployment( +# response: Response, +# session: orm.Session = Depends(get_session), +# # request_params: RequestParams = Depends(parse_react_admin_params(Base)), +# ) -> Any: +# from trapdata.ml.pipeline import start_pipeline +# +# start_pipeline( +# session=session, image_base_path=settings.image_base_path, settings=settings +# ) +# deployments = list_deployments(session) +# return deployments diff --git a/trapdata/api/views/occurrences.bak2.py b/trapdata/api/views/occurrences.bak2.py new file mode 100644 index 00000000..466c9ef2 --- /dev/null +++ b/trapdata/api/views/occurrences.bak2.py @@ -0,0 +1,58 @@ +from typing import Any, List, Optional + +from app.deps.db import get_async_session +from app.deps.request_params import parse_react_admin_params +from app.deps.users import current_user +from app.models.occurrence import Occurrence +from app.models.user import User +from app.schemas.occurrence import Occurrence as OccurrenceSchema +from app.schemas.occurrence import OccurrenceCreate, OccurrenceUpdate +from app.schemas.request_params import RequestParams +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import func, select +from sqlalchemy.ext.asyncio.session import AsyncSession +from starlette.responses import Response + +router = APIRouter(prefix="/occurrences") + + +@router.get("", response_model=List[OccurrenceSchema]) +async def get_occurrences( + response: Response, + session: AsyncSession = Depends(get_async_session), + request_params: RequestParams = Depends(parse_react_admin_params(Occurrence)), +) -> Any: + total = await session.scalar( + select( + func.count(Occurrence.id).filter( + Occurrence.deployment_id == request_params.deployment_id + ) + ) + ) + items = ( + ( + await session.execute( + select(Occurrence) + .offset(request_params.skip) + .limit(request_params.limit) + .order_by(request_params.order_by) + .filter(Occurrence.deployment_id == request_params.deployment_id) + ) + ) + .scalars() + .all() + ) + response.headers[ + "Content-Range" + ] = f"{request_params.skip}-{request_params.skip + len(items)}/{total}" + return items + + +@router.get("/{item_id}", response_model=OccurrenceSchema) +async def get_occurrence( + item_id: int, + session: AsyncSession = Depends(get_async_session), + user: User = Depends(current_user), +) -> Any: + item: Optional[Occurrence] = await session.get(Occurrence, item_id) + return item diff --git a/trapdata/api/views/occurrences.py b/trapdata/api/views/occurrences.py index 856013a4..d7a39ea0 100644 --- a/trapdata/api/views/occurrences.py +++ b/trapdata/api/views/occurrences.py @@ -1,71 +1,41 @@ from typing import Any, List, Optional from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import func, select -from sqlalchemy.ext.asyncio.session import AsyncSession +from sqlalchemy import func, orm, select from starlette.responses import Response -from app.deps.db import get_async_session -from app.deps.request_params import parse_react_admin_params -from app.deps.users import current_user -from app.models.occurrence import Occurrence -from app.models.user import User -from app.schemas.occurrence import Occurrence as OccurrenceSchema -from app.schemas.occurrence import OccurrenceCreate, OccurrenceUpdate -from app.schemas.request_params import RequestParams +from trapdata.api.config import settings +from trapdata.api.deps.db import get_session +from trapdata.api.deps.request_params import parse_react_admin_params +from trapdata.api.request_params import RequestParams +from trapdata.db import Base +from trapdata.db.models.occurrences import Occurrence, list_occurrences router = APIRouter(prefix="/occurrences") -@router.get("", response_model=List[OccurrenceSchema]) +@router.get("", response_model=List[Occurrence]) async def get_occurrences( response: Response, - session: AsyncSession = Depends(get_async_session), - request_params: RequestParams = Depends(parse_react_admin_params(Occurrence)), + # request_params: RequestParams = Depends(parse_react_admin_params(Base)), ) -> Any: - total = await session.scalar( - select( - func.count(Occurrence.id).filter( - Occurrence.deployment_id == request_params.deployment_id - ) - ) + occurrences = list_occurrences( + settings.database_url, + classification_threshold=settings.classification_threshold, ) - items = ( - ( - await session.execute( - select(Occurrence) - .offset(request_params.skip) - .limit(request_params.limit) - .order_by(request_params.order_by) - .filter(Occurrence.deployment_id == request_params.deployment_id) - ) - ) - .scalars() - .all() - ) - response.headers[ - "Content-Range" - ] = f"{request_params.skip}-{request_params.skip + len(items)}/{total}" - return items - - -@router.get("/{item_id}", response_model=OccurrenceSchema) -async def get_occurrence( - item_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_user), -) -> Any: - item: Optional[Occurrence] = await session.get(Occurrence, item_id) - return item - - -@router.delete("/{item_id}") -async def delete_occurrence( - item_id: int, - session: AsyncSession = Depends(get_async_session), - user: User = Depends(current_user), -) -> Any: - item: Optional[Occurrence] = await session.get(Occurrence, item_id) - await session.delete(item) - await session.commit() - return {"success": True} + return occurrences + + +# @router.post("/process", response_model=List[DeploymentListItem]) +# async def process_deployment( +# response: Response, +# session: orm.Session = Depends(get_session), +# # request_params: RequestParams = Depends(parse_react_admin_params(Base)), +# ) -> Any: +# from trapdata.ml.pipeline import start_pipeline +# +# start_pipeline( +# session=session, image_base_path=settings.image_base_path, settings=settings +# ) +# deployments = list_deployments(session) +# return deployments diff --git a/trapdata/api/views/sessions.py b/trapdata/api/views/sessions.py new file mode 100644 index 00000000..95557331 --- /dev/null +++ b/trapdata/api/views/sessions.py @@ -0,0 +1,39 @@ +from typing import Any, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import func, orm, select +from starlette.responses import Response + +from trapdata.api.config import settings +from trapdata.api.deps.db import get_session +from trapdata.api.deps.request_params import parse_react_admin_params +from trapdata.api.request_params import RequestParams +from trapdata.db import Base +from trapdata.db.models.deployments import DeploymentListItem, list_deployments + +router = APIRouter(prefix="/deployments") + + +@router.get("", response_model=List[DeploymentListItem]) +async def get_deployments( + response: Response, + session: orm.Session = Depends(get_session), + # request_params: RequestParams = Depends(parse_react_admin_params(Base)), +) -> Any: + deployments = list_deployments(session) + return deployments + + +# @router.post("/process", response_model=List[DeploymentListItem]) +# async def process_deployment( +# response: Response, +# session: orm.Session = Depends(get_session), +# # request_params: RequestParams = Depends(parse_react_admin_params(Base)), +# ) -> Any: +# from trapdata.ml.pipeline import start_pipeline +# +# start_pipeline( +# session=session, image_base_path=settings.image_base_path, settings=settings +# ) +# deployments = list_deployments(session) +# return deployments diff --git a/trapdata/api/views/settings.py b/trapdata/api/views/settings.py new file mode 100644 index 00000000..8f02ec86 --- /dev/null +++ b/trapdata/api/views/settings.py @@ -0,0 +1,22 @@ +from typing import Any, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import func, orm, select +from starlette.responses import Response + +from trapdata.api.config import settings +from trapdata.api.deps.db import get_session +from trapdata.api.deps.request_params import parse_react_admin_params +from trapdata.api.request_params import RequestParams +from trapdata.db import Base +from trapdata.db.models.deployments import DeploymentListItem, list_deployments +from trapdata.settings import UserSettings + +router = APIRouter(prefix="/settings") + + +@router.get("", response_model=UserSettings) +async def get_settings( + response: Response, +) -> Any: + return settings diff --git a/trapdata/api/views/status.py b/trapdata/api/views/status.py new file mode 100644 index 00000000..fc90e732 --- /dev/null +++ b/trapdata/api/views/status.py @@ -0,0 +1,68 @@ +import datetime +from typing import Any, List + +import sqlalchemy as sa +from fastapi import APIRouter, Depends +from pydantic import BaseModel +from sqlalchemy import orm +from starlette.responses import Response + +from trapdata.api.config import settings +from trapdata.api.deps.db import get_session +from trapdata.db import models +from trapdata.db.models.queue import QueueListItem, list_queues + +router = APIRouter(prefix="/status") + + +@router.get("/queues", response_model=List[QueueListItem]) +async def get_queues( + response: Response, +) -> Any: + queues = list_queues(settings.database_url, settings.image_base_path) + return queues + + +class NavSummary(BaseModel): + num_deployments: int + num_captures: int + num_sessions: int + num_detections: int + num_occurrences: int + num_species: int + last_updated: datetime.datetime + + +@router.get("/nav_summary", response_model=NavSummary) +async def get_nav_summary( + response: Response, + session: orm.Session = Depends(get_session), +) -> Any: + stmt = ( + sa.select( + sa.func.count(models.MonitoringSession.base_directory.distinct()).label( + "num_deployments" + ), + sa.func.count(models.MonitoringSession.id.distinct()).label("num_sessions"), + sa.func.sum(models.MonitoringSession.num_images).label("num_captures"), + sa.func.sum(models.MonitoringSession.num_detected_objects).label( + "num_detections" + ), + sa.func.count(models.DetectedObject.sequence_id.distinct()).label( + "num_occurrences" + ), # @TODO does not filter based on classification threshold, among other things! + sa.func.count(models.DetectedObject.specific_label.distinct()).label( + "num_species" + ), + ) + .join( + models.DetectedObject, + models.MonitoringSession.id == models.DetectedObject.monitoring_session_id, + ) + .group_by(models.MonitoringSession.base_directory) + ) + summary = session.execute(stmt).first() + if summary: + summary = NavSummary(**summary._mapping, last_updated=datetime.datetime.now()) + + return summary diff --git a/trapdata/common/filemanagement.py b/trapdata/common/filemanagement.py index a94ac9c1..9260d0e5 100644 --- a/trapdata/common/filemanagement.py +++ b/trapdata/common/filemanagement.py @@ -408,7 +408,7 @@ def get_app_dir(app_name: Optional[str] = None) -> pathlib.Path: data_dir = pathlib.Path(os.environ.get("XDG_CONFIG_HOME", "~/.config")) data_dir = data_dir.expanduser().resolve() / app_name if not data_dir.exists(): - data_dir.mkdir() + data_dir.mkdir(parents=True) return data_dir diff --git a/trapdata/db/models/occurrences.py b/trapdata/db/models/occurrences.py index bb09e23a..0b60936c 100644 --- a/trapdata/db/models/occurrences.py +++ b/trapdata/db/models/occurrences.py @@ -43,7 +43,7 @@ class SpeciesSummaryListItem(BaseModel): def list_occurrences( db_path: str, - monitoring_session: models.MonitoringSession, + monitoring_session: Optional[models.MonitoringSession] = None, classification_threshold: float = -1, num_examples: int = 3, limit: Optional[int] = None, @@ -60,8 +60,10 @@ def list_occurrences( ): prepped = {k.split("sequence_", 1)[-1]: v for k, v in item.items()} if prepped["id"]: - prepped["event"] = monitoring_session.day.isoformat() - prepped["deployment"] = monitoring_session.deployment + prepped["event"] = item["monitoring_session_day"].isoformat() + prepped["deployment"] = models.deployments.deployment_name( + item["monitoring_session_base_directory"] + ) occur = Occurrence(**prepped) occurrences.append(occur) return occurrences @@ -115,10 +117,18 @@ def get_unique_species_by_track( Session = db.get_session_class(db_path) session = Session() + filter_args = {} + if monitoring_session: + filter_args["monitoring_session_id"] = monitoring_session.id + # Select all sequences where at least one example is above the score threshold sequences = session.execute( sa.select( models.DetectedObject.sequence_id, + models.MonitoringSession.day.label("monitoring_session_day"), + models.MonitoringSession.base_directory.label( + "monitoring_session_base_directory" + ), sa.func.count(models.DetectedObject.id).label( "sequence_frame_count" ), # frames in track @@ -128,8 +138,12 @@ def get_unique_species_by_track( sa.func.min(models.DetectedObject.timestamp).label("sequence_start_time"), sa.func.max(models.DetectedObject.timestamp).label("sequence_end_time"), ) + .join( + models.MonitoringSession, + models.DetectedObject.monitoring_session_id == models.MonitoringSession.id, + ) .group_by("sequence_id") - .where((models.DetectedObject.monitoring_session_id == monitoring_session.id)) + .filter_by(**filter_args) .having( sa.func.max(models.DetectedObject.specific_label_score) >= classification_threshold, @@ -141,6 +155,7 @@ def get_unique_species_by_track( rows = [] for sequence in sequences: + print(sequence) frames = session.execute( sa.select( models.DetectedObject.id, @@ -152,10 +167,8 @@ def get_unique_species_by_track( models.DetectedObject.sequence_id, models.DetectedObject.timestamp, ) - .where( - (models.DetectedObject.monitoring_session_id == monitoring_session.id) - & (models.DetectedObject.sequence_id == sequence.sequence_id) - ) + .where(models.DetectedObject.sequence_id == sequence.sequence_id) + .filter_by(**filter_args) .join( models.TrapImage, models.TrapImage.id == models.DetectedObject.image_id ) diff --git a/trapdata/db/models/queue.py b/trapdata/db/models/queue.py index 2635d4c7..60be1161 100644 --- a/trapdata/db/models/queue.py +++ b/trapdata/db/models/queue.py @@ -1,10 +1,12 @@ +import pathlib from collections import OrderedDict from typing import Sequence, Union import sqlalchemy as sa +from pydantic import BaseModel from trapdata import constants, logger -from trapdata.common.types import FilePath +from trapdata.common.types import DatabaseURL, FilePath from trapdata.db import get_session from trapdata.db.models.detections import DetectedObject from trapdata.db.models.events import MonitoringSession @@ -655,6 +657,28 @@ def all_queues(db_path, base_directory) -> OrderedDict[str, QueueManager]: ) +class QueueListItem(BaseModel): + name: str + unprocessed_count: int + queue_count: int + done_count: int + + +def list_queues( + db_path: DatabaseURL, image_base_path: pathlib.Path +) -> Sequence[QueueListItem]: + queues = all_queues(db_path, image_base_path) + return [ + QueueListItem( + name=q.name, + unprocessed_count=q.unprocessed_count(), + queue_count=q.queue_count(), + done_count=q.done_count(), + ) + for q in queues.values() + ] + + def add_image_to_queue(db_path, image_id): with get_session(db_path) as sesh: logger.info(f"Adding image id {image_id} to queue") diff --git a/trapdata/settings.py b/trapdata/settings.py index e765cfdb..1412cf21 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -14,10 +14,7 @@ from trapdata.common.types import FilePath -class Settings(BaseSettings): - # Can't use PyDantic DSN validator for database_url if sqlite filepath has spaces, see custom validator below - database_url: Union[str, sqlalchemy.engine.URL] = default_database_dsn() - user_data_path: pathlib.Path = get_app_dir() +class UserSettings(BaseSettings): image_base_path: Optional[pathlib.Path] localization_model: ml.models.ObjectDetectorChoice = Field( default=ml.models.DEFAULT_OBJECT_DETECTOR @@ -32,6 +29,15 @@ class Settings(BaseSettings): default=ml.models.DEFAULT_FEATURE_EXTRACTOR ) classification_threshold: float = 0.6 + + class Config: + extra = "ignore" + + +class Settings(UserSettings): + # Can't use PyDantic DSN validator for database_url if sqlite filepath has spaces, see custom validator below + database_url: Union[str, sqlalchemy.engine.URL] = default_database_dsn() + user_data_path: pathlib.Path = get_app_dir() localization_batch_size: int = 2 classification_batch_size: int = 20 num_workers: int = 1 From 826236c39bf7ad538d26e56dffec38ec5525d724 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 11 Apr 2023 09:23:36 +0000 Subject: [PATCH 10/53] Fix summary endpoint --- trapdata/api/views/status.py | 36 +++++++++++++++++++----------------- 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/trapdata/api/views/status.py b/trapdata/api/views/status.py index fc90e732..0eead271 100644 --- a/trapdata/api/views/status.py +++ b/trapdata/api/views/status.py @@ -1,5 +1,5 @@ import datetime -from typing import Any, List +from typing import Any, List, Optional import sqlalchemy as sa from fastapi import APIRouter, Depends @@ -23,18 +23,18 @@ async def get_queues( return queues -class NavSummary(BaseModel): - num_deployments: int - num_captures: int - num_sessions: int - num_detections: int - num_occurrences: int - num_species: int - last_updated: datetime.datetime +class SummaryCounts(BaseModel): + num_deployments: Optional[int] = 0 + num_captures: Optional[int] = 0 + num_sessions: Optional[int] = 0 + num_detections: Optional[int] = 0 + num_occurrences: Optional[int] = 0 + num_species: Optional[int] = 0 + last_updated: Optional[datetime.datetime] = None -@router.get("/nav_summary", response_model=NavSummary) -async def get_nav_summary( +@router.get("/summary", response_model=SummaryCounts) +async def get_summary_counts( response: Response, session: orm.Session = Depends(get_session), ) -> Any: @@ -44,8 +44,8 @@ async def get_nav_summary( "num_deployments" ), sa.func.count(models.MonitoringSession.id.distinct()).label("num_sessions"), - sa.func.sum(models.MonitoringSession.num_images).label("num_captures"), - sa.func.sum(models.MonitoringSession.num_detected_objects).label( + sa.func.count(models.TrapImage.id.distinct()).label("num_captures"), + sa.func.count(models.DetectedObject.id.distinct()).label( "num_detections" ), sa.func.count(models.DetectedObject.sequence_id.distinct()).label( @@ -55,14 +55,16 @@ async def get_nav_summary( "num_species" ), ) + .join( + models.TrapImage, + models.MonitoringSession.id == models.TrapImage.monitoring_session_id, + ) .join( models.DetectedObject, models.MonitoringSession.id == models.DetectedObject.monitoring_session_id, ) - .group_by(models.MonitoringSession.base_directory) ) - summary = session.execute(stmt).first() - if summary: - summary = NavSummary(**summary._mapping, last_updated=datetime.datetime.now()) + summary = session.execute(stmt).one() + summary = SummaryCounts(**summary._mapping, last_updated=datetime.datetime.now()) return summary From e517a8520cc054af02e87c415b75854c37319558 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 11 Apr 2023 09:24:17 +0000 Subject: [PATCH 11/53] Persist docker db --- scripts/start_db_container.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/start_db_container.sh b/scripts/start_db_container.sh index b12e1a6f..1d13682c 100755 --- a/scripts/start_db_container.sh +++ b/scripts/start_db_container.sh @@ -8,7 +8,7 @@ HOST_PORT=5432 POSTGRES_VERSION=14 POSTGRES_DB=ami -docker run -d -i --name $CONTAINER_NAME -p $HOST_PORT:5432 -e POSTGRES_HOST_AUTH_METHOD=trust -e POSTGRES_DB=$POSTGRES_DB postgres:$POSTGRES_VERSION +docker run -d -i --name $CONTAINER_NAME -v ./db_data:/var/lib/postgresql/data --restart always -p $HOST_PORT:5432 -e POSTGRES_HOST_AUTH_METHOD=trust -e POSTGRES_DB=$POSTGRES_DB postgres:$POSTGRES_VERSION docker logs ami-db --tail 100 From 9069ae9cc37b139ce8c549250fdea80b7ab4d7de Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 11 Apr 2023 09:24:39 +0000 Subject: [PATCH 12/53] Debian service for running API server --- trapdata/api/ami.service | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) create mode 100644 trapdata/api/ami.service diff --git a/trapdata/api/ami.service b/trapdata/api/ami.service new file mode 100644 index 00000000..8d861f8a --- /dev/null +++ b/trapdata/api/ami.service @@ -0,0 +1,21 @@ +[Unit] + +Description=AMI Data Manager API + +After=network.target + + +[Service] + +User=debian + +Group=www-data + +WorkingDirectory=/home/debian/ami-data-manager + +ExecStart=/home/debian/miniconda3/bin/gunicorn trapdata.api.main:app --bind 0.0.0.0:8000 --worker-class "uvicorn.workers.UvicornWorker" --log-syslog + + +[Install] + +WantedBy=multi-user.target From 31e1e662510c544c9b9c590565d301093beec599 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 11 Apr 2023 18:28:44 -0700 Subject: [PATCH 13/53] Allow monitoring session to be optional --- trapdata/cli/show.py | 35 +++++++---- trapdata/db/models/occurrences.py | 98 ++++++++++++++++++++----------- 2 files changed, 85 insertions(+), 48 deletions(-) diff --git a/trapdata/cli/show.py b/trapdata/cli/show.py index 71229450..0ed9aecf 100644 --- a/trapdata/cli/show.py +++ b/trapdata/cli/show.py @@ -20,7 +20,6 @@ ) from trapdata.db.models.events import ( get_monitoring_session_by_date, - get_monitoring_sessions_from_db, update_all_aggregates, ) from trapdata.db.models.occurrences import list_occurrences, list_species @@ -184,19 +183,29 @@ def detections( @cli.command() -def occurrences(limit: Optional[int] = 100, offset: int = 0): - events = get_monitoring_sessions_from_db( - db_path=settings.database_url, base_directory=settings.image_base_path - ) - occurrences: list[models.occurrences.Occurrence] = [] - for event in events: - occurrences += list_occurrences( - settings.database_url, - event, - classification_threshold=settings.classification_threshold, - limit=limit, - offset=offset, +def occurrences( + session_day: Optional[datetime.datetime] = None, + limit: Optional[int] = 100, + offset: int = 0, +): + event = None + if session_day: + events = get_monitoring_session_by_date( + db_path=settings.database_url, event_dates=[session_day] ) + if not events: + logger.info(f"No events found for {session_day}") + return [] + else: + event = events[0] + + occurrences = list_occurrences( + settings.database_url, + event, + classification_threshold=settings.classification_threshold, + limit=limit, + offset=offset, + ) table = Table("Event", "Label", "Detections", "Score", "Appearance", "Duration") for occurrence in occurrences: diff --git a/trapdata/db/models/occurrences.py b/trapdata/db/models/occurrences.py index 0b60936c..8f95e893 100644 --- a/trapdata/db/models/occurrences.py +++ b/trapdata/db/models/occurrences.py @@ -69,6 +69,30 @@ def list_occurrences( return occurrences +def get_valid_sequence_ids( + monitoring_session: Optional[models.MonitoringSession] = None, + confidence_threshold: float = 0, +) -> sa.ScalarSelect: + """ + Sequence IDs that have a detection with a score above the confidence threshold. + + Intended to be used as a subquery in a larger query. + """ + stmt = sa.select( + models.DetectedObject.sequence_id.distinct().label("id"), + ).where(models.DetectedObject.specific_label_score >= confidence_threshold) + if monitoring_session: + stmt = stmt.where( + models.DetectedObject.monitoring_session_id == monitoring_session.id + ) + stmt = ( + stmt.group_by(models.DetectedObject.sequence_id) + .order_by(models.DetectedObject.sequence_id) + .scalar_subquery() + ) + return stmt + + def list_species( db_path: str, image_base_path: pathlib.Path, @@ -108,7 +132,7 @@ def list_species( def get_unique_species_by_track( db_path: str, - monitoring_session=None, + monitoring_session: Optional[models.MonitoringSession] = None, classification_threshold: float = -1, num_examples: int = 3, limit: Optional[int] = None, @@ -117,46 +141,45 @@ def get_unique_species_by_track( Session = db.get_session_class(db_path) session = Session() - filter_args = {} + # Select all sequences where at least one example is above the score threshold + stmt = sa.select( + models.DetectedObject.sequence_id, + models.MonitoringSession.day.label("monitoring_session_day"), + models.MonitoringSession.base_directory.label( + "monitoring_session_base_directory" + ), + sa.func.count(models.DetectedObject.id).label( + "sequence_frame_count" + ), # frames in track + sa.func.max(models.DetectedObject.specific_label_score).label( + "sequence_best_score" + ), + sa.func.min(models.DetectedObject.timestamp).label("sequence_start_time"), + sa.func.max(models.DetectedObject.timestamp).label("sequence_end_time"), + ).join( + models.MonitoringSession, + models.DetectedObject.monitoring_session_id == models.MonitoringSession.id, + ) if monitoring_session: - filter_args["monitoring_session_id"] = monitoring_session.id + stmt = stmt.where(models.MonitoringSession.id == monitoring_session.id) - # Select all sequences where at least one example is above the score threshold - sequences = session.execute( - sa.select( - models.DetectedObject.sequence_id, - models.MonitoringSession.day.label("monitoring_session_day"), - models.MonitoringSession.base_directory.label( - "monitoring_session_base_directory" - ), - sa.func.count(models.DetectedObject.id).label( - "sequence_frame_count" - ), # frames in track - sa.func.max(models.DetectedObject.specific_label_score).label( - "sequence_best_score" - ), - sa.func.min(models.DetectedObject.timestamp).label("sequence_start_time"), - sa.func.max(models.DetectedObject.timestamp).label("sequence_end_time"), - ) - .join( - models.MonitoringSession, - models.DetectedObject.monitoring_session_id == models.MonitoringSession.id, + stmt = ( + stmt.group_by( + "sequence_id", "monitoring_session_day", "monitoring_session_base_directory" ) - .group_by("sequence_id") - .filter_by(**filter_args) .having( sa.func.max(models.DetectedObject.specific_label_score) >= classification_threshold, ) - .order_by(models.DetectedObject.timestamp.asc()) + .order_by("sequence_id") .limit(limit) .offset(offset) - ).all() + ) + sequences = session.execute(stmt).all() rows = [] for sequence in sequences: - print(sequence) - frames = session.execute( + stmt = ( sa.select( models.DetectedObject.id, models.DetectedObject.image_id.label("source_image_id"), @@ -167,15 +190,20 @@ def get_unique_species_by_track( models.DetectedObject.sequence_id, models.DetectedObject.timestamp, ) - .where(models.DetectedObject.sequence_id == sequence.sequence_id) - .filter_by(**filter_args) .join( models.TrapImage, models.TrapImage.id == models.DetectedObject.image_id ) - # .order_by(sa.func.random()) - .order_by(sa.desc("score")) - .limit(num_examples) - ).all() + .where(models.DetectedObject.sequence_id == sequence.sequence_id) + ) + + if monitoring_session: + stmt = stmt.where( + models.DetectedObject.monitoring_session_id == monitoring_session.id + ) + stmt = stmt.order_by(sa.desc("score")).limit(num_examples) + + frames = session.execute(stmt).all() + row = dict(sequence._mapping) if frames: best_example = frames[0] From 67a20fdb50756022a29c402e24958708d89eee51 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 11 Apr 2023 19:18:11 -0700 Subject: [PATCH 14/53] Standardize list_monitoring_sessions --- .flake8 | 2 +- .pre-commit-config.yaml | 2 +- trapdata/api/views/__init__.py | 3 +- trapdata/api/views/{sessions.py => events.py} | 24 ++--- trapdata/cli/export.py | 34 ++----- trapdata/cli/show.py | 41 +++------ trapdata/db/models/events.py | 91 ++++++++++++++++++- trapdata/db/models/images.py | 8 +- 8 files changed, 126 insertions(+), 79 deletions(-) rename trapdata/api/views/{sessions.py => events.py} (59%) diff --git a/.flake8 b/.flake8 index 9c7d08f6..011ecfa8 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,3 @@ [flake8] max-line-length = 160 -ignore = E203, E402, W503 +ignore = E203, E402, W503, B008 diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 20901bdc..b57fbfe2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -19,7 +19,7 @@ repos: - id: autoflake args: - --in-place - - --imports=sqlalchemy,pydantic + - --imports=trapdata,sqlalchemy,pydantic,fastapi files: . types: [file, python] diff --git a/trapdata/api/views/__init__.py b/trapdata/api/views/__init__.py index 58497375..921cbbfb 100644 --- a/trapdata/api/views/__init__.py +++ b/trapdata/api/views/__init__.py @@ -1,11 +1,12 @@ from fastapi import APIRouter -from trapdata.api.views import deployments, occurrences, settings, stats, status +from trapdata.api.views import deployments, events, occurrences, settings, stats, status api_router = APIRouter() api_router.include_router(stats.router, tags=["stats"]) api_router.include_router(status.router, tags=["status"]) api_router.include_router(deployments.router, tags=["deployments"]) +api_router.include_router(events.router, tags=["events"]) api_router.include_router(occurrences.router, tags=["occurrences"]) api_router.include_router(settings.router, tags=["settings"]) diff --git a/trapdata/api/views/sessions.py b/trapdata/api/views/events.py similarity index 59% rename from trapdata/api/views/sessions.py rename to trapdata/api/views/events.py index 95557331..b7f13304 100644 --- a/trapdata/api/views/sessions.py +++ b/trapdata/api/views/events.py @@ -1,27 +1,27 @@ -from typing import Any, List, Optional +from typing import Any, List -from fastapi import APIRouter, Depends, HTTPException -from sqlalchemy import func, orm, select +from fastapi import APIRouter, Depends +from sqlalchemy import orm from starlette.responses import Response from trapdata.api.config import settings from trapdata.api.deps.db import get_session -from trapdata.api.deps.request_params import parse_react_admin_params -from trapdata.api.request_params import RequestParams -from trapdata.db import Base -from trapdata.db.models.deployments import DeploymentListItem, list_deployments +from trapdata.db.models.events import ( + MonitoringSessionListItem, + list_monitoring_sessions, +) -router = APIRouter(prefix="/deployments") +router = APIRouter(prefix="/events") -@router.get("", response_model=List[DeploymentListItem]) -async def get_deployments( +@router.get("", response_model=List[MonitoringSessionListItem]) +async def get_monitoring_sessions( response: Response, session: orm.Session = Depends(get_session), # request_params: RequestParams = Depends(parse_react_admin_params(Base)), ) -> Any: - deployments = list_deployments(session) - return deployments + items = list_monitoring_sessions(session, settings.image_base_path) + return items # @router.post("/process", response_model=List[DeploymentListItem]) diff --git a/trapdata/cli/export.py b/trapdata/cli/export.py index 98d6dfad..4f6a7cb0 100644 --- a/trapdata/cli/export.py +++ b/trapdata/cli/export.py @@ -14,15 +14,12 @@ from trapdata.cli import settings from trapdata.db import get_session_class from trapdata.db.models.deployments import list_deployments -from trapdata.db.models.detections import ( - get_detected_objects, - num_occurrences_for_event, - num_species_for_event, -) +from trapdata.db.models.detections import get_detected_objects from trapdata.db.models.events import ( get_monitoring_session_by_date, get_monitoring_session_images, get_monitoring_sessions_from_db, + list_monitoring_sessions, ) from trapdata.db.models.occurrences import list_occurrences @@ -161,28 +158,11 @@ def sessions( """ Export a summary of monitoring sessions from database in the specified format. """ - monitoring_events = get_monitoring_sessions_from_db( - db_path=settings.database_url, base_directory=settings.image_base_path - ) - items = [] - for event in monitoring_events: - event_data = event.report_data() - num_occurrences = num_occurrences_for_event( - db_path=settings.database_url, monitoring_session=event - ) - num_species = num_species_for_event( - db_path=settings.database_url, monitoring_session=event - ) - example_captures = get_monitoring_session_images( - settings.database_url, event, limit=5, offset=int(event.num_images / 2) - ) - event_data["example_captures"] = [ - img.report_data().dict() for img in example_captures - ] - event_data["num_occurrences"] = num_occurrences - event_data["num_species"] = num_species - items.append(event_data) - df = pd.DataFrame(items) + Session = get_session_class(settings.database_url) + session = Session() + + items = list_monitoring_sessions(session, settings.image_base_path) + df = pd.DataFrame([item.dict() for item in items]) return export(df=df, format=format, outfile=outfile) diff --git a/trapdata/cli/show.py b/trapdata/cli/show.py index 0ed9aecf..ed1d92ff 100644 --- a/trapdata/cli/show.py +++ b/trapdata/cli/show.py @@ -13,11 +13,7 @@ from trapdata.db import models from trapdata.db.base import get_session_class from trapdata.db.models.deployments import list_deployments -from trapdata.db.models.detections import ( - get_detected_objects, - num_occurrences_for_event, - num_species_for_event, -) +from trapdata.db.models.detections import get_detected_objects from trapdata.db.models.events import ( get_monitoring_session_by_date, update_all_aggregates, @@ -101,39 +97,26 @@ def sessions(): """ Show all monitoring events that have been interpreted from image timestamps. """ + from trapdata.db.models.events import list_monitoring_sessions + Session = get_session_class(settings.database_url) session = Session() # image_base_path = str(settings.image_base_path.resolve()) - update_all_aggregates(session, settings.image_base_path) - logger.info(f"Show monitoring events for images in {settings.image_base_path}") - events = ( - session.execute( - select(models.MonitoringSession).where( - models.MonitoringSession.base_directory == str(settings.image_base_path) - ) - ) - .unique() - .scalars() - .all() - ) + events = list_monitoring_sessions(session, settings.image_base_path) - table = Table("ID", "Day", "Images", "Detections", "Occurrences", "Species") + table = Table( + "ID", "Day", "Duration", "Captures", "Detections", "Occurrences", "Species" + ) for event in events: - event.update_aggregates(session) - num_occurrences = num_occurrences_for_event( - db_path=settings.database_url, monitoring_session=event - ) - num_species = num_species_for_event( - db_path=settings.database_url, monitoring_session=event - ) row_values = [ event.id, event.day, - event.num_images, - event.num_detected_objects, - num_occurrences, - num_species, + event.duration_label, + event.num_captures, + event.num_detections, + event.num_occurrences, + event.num_species, ] table.add_row(*[str(val) for val in row_values]) console.print(table) diff --git a/trapdata/db/models/events.py b/trapdata/db/models/events.py index ea899e59..65395bfa 100644 --- a/trapdata/db/models/events.py +++ b/trapdata/db/models/events.py @@ -5,6 +5,7 @@ import sqlalchemy as sa from pydantic import BaseModel from sqlalchemy import orm +from sqlalchemy.ext.hybrid import hybrid_method from sqlalchemy_utils import aggregated from trapdata.common.filemanagement import find_images, group_images_by_day @@ -15,10 +16,25 @@ # @TODO Rename to TrapEvent? CapturePeriod? less confusing with other types of Sessions. CaptureSession? Or SurveyEvent or Survey? -class Event(BaseModel): +class MonitoringSessionListItem(BaseModel): id: str - frames: list[dict] - example_frames: list[dict] + day: datetime.date + num_captures: int + num_detections: int + num_occurrences: int + num_species: int + example_captures: list[dict] + start_time: datetime.datetime + end_time: datetime.datetime + duration: datetime.timedelta + duration_label: str + + +class MonitoringSessionDetail(MonitoringSessionListItem): + notes: Optional[str] + detections: list + # @TODO add more info about the session, like the number of images, the number of detected objects, etc + # @TODO add the number of species detected in this session class MonitoringSession(Base): @@ -41,6 +57,26 @@ def num_images(self): def num_detected_objects(self): return sa.func.count("1") + def num_occurrences(self, session: orm.Session) -> int: + return ( + session.execute( + sa.select( + sa.func.count(models.DetectedObject.sequence_id.distinct()) + ).where(models.DetectedObject.monitoring_session_id == self.id) + ).scalar() + or 0 + ) + + def num_species(self, session: orm.Session) -> int: + return ( + session.execute( + sa.select( + sa.func.count(models.DetectedObject.specific_label.distinct()) + ).where(models.DetectedObject.monitoring_session_id == self.id) + ).scalar() + or 0 + ) + # This runs an expensive/slow query every time an image is updated # @observes("images") # def image_observer(self, images): @@ -101,11 +137,12 @@ def update_aggregates(self, session: orm.Session, commit=True): if commit: session.commit() - def duration(self) -> Optional[datetime.timedelta]: + @hybrid_method + def duration(self) -> datetime.timedelta: if self.start_time and self.end_time: return self.end_time - self.start_time else: - return None + return datetime.timedelta(0) @property def duration_label(self): @@ -344,3 +381,47 @@ def export_monitoring_sessions( ): records = [item.report_data() for item in items] return export_report(records, report_name, directory) + + +def list_monitoring_sessions( + session: orm.Session, + image_base_path: FilePath, + limit: Optional[int] = None, + offset: int = 0, +) -> list[MonitoringSessionListItem]: + """ """ + + update_all_aggregates(session, image_base_path) + logger.info(f"Fetching monitoring events for images in {image_base_path}") + events = ( + session.execute( + sa.select(models.MonitoringSession) + .where(models.MonitoringSession.base_directory == str(image_base_path)) + .order_by(models.MonitoringSession.day) + .limit(limit) + .offset(offset) + ) + .unique() + .scalars() + .all() + ) + + list_items = [] + for event in events: + event.update_aggregates(session) + list_items.append( + MonitoringSessionListItem( + id=event.id, + day=event.day, + start_time=event.start_time, + end_time=event.end_time, + num_captures=event.num_images, + num_detections=event.num_detected_objects, + num_occurrences=event.num_occurrences(session), + num_species=event.num_species(session), + example_captures=[], # @TODO + duration=event.duration(), + duration_label=event.duration_label, + ) + ) + return list_items diff --git a/trapdata/db/models/images.py b/trapdata/db/models/images.py index d7980f00..0ccc193a 100644 --- a/trapdata/db/models/images.py +++ b/trapdata/db/models/images.py @@ -19,9 +19,7 @@ class CaptureListItem(BaseModel): id: int timestamp: datetime.datetime - source_image: str - last_read: Optional[datetime.datetime] - last_processed: Optional[datetime.datetime] + path: pathlib.Path num_detections: Optional[int] in_queue: bool @@ -33,6 +31,10 @@ class CaptureDetail(CaptureListItem): filesize: int width: int height: int + last_read: Optional[datetime.datetime] + last_processed: Optional[datetime.datetime] + next_capture: Optional[CaptureListItem] + prev_capture: Optional[CaptureListItem] class TrapImage(Base): From bc66cf24268fafe6be9537b001f1966e35023650 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 11 Apr 2023 19:39:03 -0700 Subject: [PATCH 15/53] Update method for counting everything --- trapdata/api/views/status.py | 46 +++++++++++-------------------- trapdata/db/models/deployments.py | 16 +++++++---- trapdata/db/models/events.py | 5 ++++ 3 files changed, 31 insertions(+), 36 deletions(-) diff --git a/trapdata/api/views/status.py b/trapdata/api/views/status.py index 0eead271..290ba385 100644 --- a/trapdata/api/views/status.py +++ b/trapdata/api/views/status.py @@ -1,7 +1,6 @@ import datetime from typing import Any, List, Optional -import sqlalchemy as sa from fastapi import APIRouter, Depends from pydantic import BaseModel from sqlalchemy import orm @@ -9,7 +8,8 @@ from trapdata.api.config import settings from trapdata.api.deps.db import get_session -from trapdata.db import models +from trapdata.db.models.deployments import list_deployments +from trapdata.db.models.events import list_monitoring_sessions from trapdata.db.models.queue import QueueListItem, list_queues router = APIRouter(prefix="/status") @@ -29,7 +29,7 @@ class SummaryCounts(BaseModel): num_sessions: Optional[int] = 0 num_detections: Optional[int] = 0 num_occurrences: Optional[int] = 0 - num_species: Optional[int] = 0 + num_species: Optional[int] = 0 last_updated: Optional[datetime.datetime] = None @@ -38,33 +38,19 @@ async def get_summary_counts( response: Response, session: orm.Session = Depends(get_session), ) -> Any: - stmt = ( - sa.select( - sa.func.count(models.MonitoringSession.base_directory.distinct()).label( - "num_deployments" - ), - sa.func.count(models.MonitoringSession.id.distinct()).label("num_sessions"), - sa.func.count(models.TrapImage.id.distinct()).label("num_captures"), - sa.func.count(models.DetectedObject.id.distinct()).label( - "num_detections" - ), - sa.func.count(models.DetectedObject.sequence_id.distinct()).label( - "num_occurrences" - ), # @TODO does not filter based on classification threshold, among other things! - sa.func.count(models.DetectedObject.specific_label.distinct()).label( - "num_species" - ), - ) - .join( - models.TrapImage, - models.MonitoringSession.id == models.TrapImage.monitoring_session_id, - ) - .join( - models.DetectedObject, - models.MonitoringSession.id == models.DetectedObject.monitoring_session_id, - ) + deployments = list_deployments(session) + events = [] + for deployment in deployments: + events += list_monitoring_sessions(session, deployment.image_base_path) + + summary = SummaryCounts( + num_deployments=len({e.deployment for e in events}), + num_sessions=len(events), + num_captures=sum(e.num_captures for e in events), + num_detections=sum(e.num_detections for e in events), + num_occurrences=sum(e.num_occurrences for e in events), + num_species=sum(e.num_species for e in events), + last_updated=datetime.datetime.now(), ) - summary = session.execute(stmt).one() - summary = SummaryCounts(**summary._mapping, last_updated=datetime.datetime.now()) return summary diff --git a/trapdata/db/models/deployments.py b/trapdata/db/models/deployments.py index 8e5a37f0..d0aa9f5c 100644 --- a/trapdata/db/models/deployments.py +++ b/trapdata/db/models/deployments.py @@ -19,6 +19,7 @@ class DeploymentListItem(BaseModel): id: Optional[int] = None name: str + image_base_path: FilePath num_events: int num_source_images: int num_detections: int @@ -43,17 +44,20 @@ def list_deployments(session: orm.Session) -> list[DeploymentListItem]: A proxy for "registered trap deployments". """ stmt = sa.select( - models.MonitoringSession.base_directory.label("name"), + models.MonitoringSession.base_directory.label("image_base_path"), sa.func.count(models.MonitoringSession.id).label("num_events"), sa.func.sum(models.MonitoringSession.num_images).label("num_source_images"), sa.func.sum(models.MonitoringSession.num_detected_objects).label( "num_detections" ), ).group_by(models.MonitoringSession.base_directory) - deployments = [ - DeploymentListItem(**d._mapping) for d in session.execute(stmt).all() - ] - for deployment in deployments: - deployment.name = deployment_name(deployment.name) + deployments = [] + for deployment in session.execute(stmt).all(): + deployments.append( + DeploymentListItem( + **deployment._mapping, + name=deployment_name(deployment.image_base_path), + ) + ) return deployments diff --git a/trapdata/db/models/events.py b/trapdata/db/models/events.py index 65395bfa..9d1ba984 100644 --- a/trapdata/db/models/events.py +++ b/trapdata/db/models/events.py @@ -13,12 +13,15 @@ from trapdata.common.types import FilePath from trapdata.common.utils import export_report from trapdata.db import Base, get_session, models +from trapdata.db.models.deployments import deployment_name # @TODO Rename to TrapEvent? CapturePeriod? less confusing with other types of Sessions. CaptureSession? Or SurveyEvent or Survey? class MonitoringSessionListItem(BaseModel): id: str day: datetime.date + image_base_path: str + deployment: str num_captures: int num_detections: int num_occurrences: int @@ -413,6 +416,8 @@ def list_monitoring_sessions( MonitoringSessionListItem( id=event.id, day=event.day, + image_base_path=str(event.base_directory), + deployment=deployment_name(str(event.base_directory)), start_time=event.start_time, end_time=event.end_time, num_captures=event.num_images, From 1c1576883995d58be0c74d54acb4b9b75d8c1b56 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 11 Apr 2023 23:58:08 -0700 Subject: [PATCH 16/53] Try serving image crops --- trapdata/api/factory.py | 5 +++++ trapdata/api/views/events.py | 4 +++- trapdata/api/views/occurrences.py | 1 + trapdata/common/filemanagement.py | 8 ++++++++ trapdata/db/models/events.py | 1 + trapdata/db/models/occurrences.py | 9 +++++++++ 6 files changed, 27 insertions(+), 1 deletion(-) diff --git a/trapdata/api/factory.py b/trapdata/api/factory.py index b7a8f1ae..6f190f39 100644 --- a/trapdata/api/factory.py +++ b/trapdata/api/factory.py @@ -36,6 +36,11 @@ def serve_static_app(app): StaticFiles(directory=settings.user_data_path / "crops"), name="crops", ) + app.mount( + "/static/captures", + StaticFiles(directory=settings.image_base_path), + name="captures", + ) app.mount( "/", StaticFiles(directory="trapdata/webui/public"), diff --git a/trapdata/api/views/events.py b/trapdata/api/views/events.py index b7f13304..febd2c9b 100644 --- a/trapdata/api/views/events.py +++ b/trapdata/api/views/events.py @@ -20,7 +20,9 @@ async def get_monitoring_sessions( session: orm.Session = Depends(get_session), # request_params: RequestParams = Depends(parse_react_admin_params(Base)), ) -> Any: - items = list_monitoring_sessions(session, settings.image_base_path) + items = list_monitoring_sessions( + session, settings.image_base_path, media_url_base="/static/" + ) return items diff --git a/trapdata/api/views/occurrences.py b/trapdata/api/views/occurrences.py index d7a39ea0..ab604f14 100644 --- a/trapdata/api/views/occurrences.py +++ b/trapdata/api/views/occurrences.py @@ -22,6 +22,7 @@ async def get_occurrences( occurrences = list_occurrences( settings.database_url, classification_threshold=settings.classification_threshold, + media_url_base="/static/", ) return occurrences diff --git a/trapdata/common/filemanagement.py b/trapdata/common/filemanagement.py index 9260d0e5..bf231141 100644 --- a/trapdata/common/filemanagement.py +++ b/trapdata/common/filemanagement.py @@ -427,3 +427,11 @@ def initial_directory_choice(): than "." which is the directory of this python package. """ return pathlib.Path("~/") + + +def media_url(local_path: str, delim: str, media_url_base: Optional[str] = None) -> str: + relative_path = f"{delim}{local_path.split(delim)[-1]}" + if media_url_base: + return f"{media_url_base}{relative_path}" + else: + return relative_path diff --git a/trapdata/db/models/events.py b/trapdata/db/models/events.py index 9d1ba984..87acfef0 100644 --- a/trapdata/db/models/events.py +++ b/trapdata/db/models/events.py @@ -391,6 +391,7 @@ def list_monitoring_sessions( image_base_path: FilePath, limit: Optional[int] = None, offset: int = 0, + media_url_base: Optional[str] = None, ) -> list[MonitoringSessionListItem]: """ """ diff --git a/trapdata/db/models/occurrences.py b/trapdata/db/models/occurrences.py index 8f95e893..a9d142cd 100644 --- a/trapdata/db/models/occurrences.py +++ b/trapdata/db/models/occurrences.py @@ -14,6 +14,7 @@ from pydantic import BaseModel from trapdata import db +from trapdata.common.filemanagement import media_url from trapdata.db import models @@ -48,6 +49,7 @@ def list_occurrences( num_examples: int = 3, limit: Optional[int] = None, offset: int = 0, + media_url_base: Optional[str] = None, ) -> list[Occurrence]: occurrences = [] for item in get_unique_species_by_track( @@ -64,6 +66,13 @@ def list_occurrences( prepped["deployment"] = models.deployments.deployment_name( item["monitoring_session_base_directory"] ) + if media_url_base: + examples = [dict(example) for example in prepped["examples"]] + for example in examples: + example["cropped_image_path"] = media_url( + example["cropped_image_path"], "crops", media_url_base + ) + prepped["examples"] = examples occur = Occurrence(**prepped) occurrences.append(occur) return occurrences From 98b3964b11713e07bde876e50905dd97d83d6ba9 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 12 Apr 2023 06:58:47 +0000 Subject: [PATCH 17/53] Fix show deployments table --- trapdata/cli/show.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trapdata/cli/show.py b/trapdata/cli/show.py index ed1d92ff..10091dd1 100644 --- a/trapdata/cli/show.py +++ b/trapdata/cli/show.py @@ -71,6 +71,8 @@ def deployments(): update_all_aggregates(session, settings.image_base_path) deployments = list_deployments(session) table = Table( + "ID", + "Name", "Image Base Path", "Sessions", "Images", From 65766fb8e3b98125f9ebe6cfefb2ec3a4bd97b6e Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 12 Apr 2023 00:06:17 -0700 Subject: [PATCH 18/53] Update media URL for source images --- trapdata/common/filemanagement.py | 3 +++ trapdata/db/models/occurrences.py | 3 +++ 2 files changed, 6 insertions(+) diff --git a/trapdata/common/filemanagement.py b/trapdata/common/filemanagement.py index bf231141..ad27f253 100644 --- a/trapdata/common/filemanagement.py +++ b/trapdata/common/filemanagement.py @@ -430,6 +430,9 @@ def initial_directory_choice(): def media_url(local_path: str, delim: str, media_url_base: Optional[str] = None) -> str: + """ + Given a local path to a file, return a URL to that file. @TODO rework this and handle slashes better. + """ relative_path = f"{delim}{local_path.split(delim)[-1]}" if media_url_base: return f"{media_url_base}{relative_path}" diff --git a/trapdata/db/models/occurrences.py b/trapdata/db/models/occurrences.py index a9d142cd..b2e4c7bc 100644 --- a/trapdata/db/models/occurrences.py +++ b/trapdata/db/models/occurrences.py @@ -72,6 +72,9 @@ def list_occurrences( example["cropped_image_path"] = media_url( example["cropped_image_path"], "crops", media_url_base ) + example["source_image_path"] = media_url( + example["source_image_path"], "captures/", media_url_base + ) prepped["examples"] = examples occur = Occurrence(**prepped) occurrences.append(occur) From 758b2ce1ad9c5923673e8be0008aa61207eb3d47 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 13 Apr 2023 00:26:03 -0700 Subject: [PATCH 19/53] Add event data to occurrence list --- trapdata/db/models/occurrences.py | 34 ++++++++++++++++++++++++------- 1 file changed, 27 insertions(+), 7 deletions(-) diff --git a/trapdata/db/models/occurrences.py b/trapdata/db/models/occurrences.py index b2e4c7bc..7f8c0b72 100644 --- a/trapdata/db/models/occurrences.py +++ b/trapdata/db/models/occurrences.py @@ -18,7 +18,18 @@ from trapdata.db import models -class Occurrence(BaseModel): +class OccurrenceNestedEvent(BaseModel): + id: int + day: datetime.date + url: Optional[str] = None + + +class OccurrenceNestedDetection(BaseModel): + id: int + cropped_image_path: str + + +class OccurrenceListItem(BaseModel): id: str label: str best_score: float @@ -26,14 +37,15 @@ class Occurrence(BaseModel): end_time: datetime.datetime duration: datetime.timedelta deployment: str - event: str + event: OccurrenceNestedEvent num_frames: int # cropped_image_path: pathlib.Path # source_image_id: int examples: list[dict] # detections: list[object] # deployment: object - # captures: list[object] + # captures: list[object] = + url: Optional[str] = None class SpeciesSummaryListItem(BaseModel): @@ -50,7 +62,7 @@ def list_occurrences( limit: Optional[int] = None, offset: int = 0, media_url_base: Optional[str] = None, -) -> list[Occurrence]: +) -> list[OccurrenceListItem]: occurrences = [] for item in get_unique_species_by_track( db_path, @@ -62,12 +74,12 @@ def list_occurrences( ): prepped = {k.split("sequence_", 1)[-1]: v for k, v in item.items()} if prepped["id"]: - prepped["event"] = item["monitoring_session_day"].isoformat() prepped["deployment"] = models.deployments.deployment_name( item["monitoring_session_base_directory"] ) if media_url_base: examples = [dict(example) for example in prepped["examples"]] + # @TODO use OccurrenceNestedDetection for example in examples: example["cropped_image_path"] = media_url( example["cropped_image_path"], "crops", media_url_base @@ -76,7 +88,11 @@ def list_occurrences( example["source_image_path"], "captures/", media_url_base ) prepped["examples"] = examples - occur = Occurrence(**prepped) + + prepped["event"] = OccurrenceNestedEvent( + id=item["monitoring_session_id"], day=item["monitoring_session_day"] + ) + occur = OccurrenceListItem(**prepped) occurrences.append(occur) return occurrences @@ -156,6 +172,7 @@ def get_unique_species_by_track( # Select all sequences where at least one example is above the score threshold stmt = sa.select( models.DetectedObject.sequence_id, + models.MonitoringSession.id.label("monitoring_session_id"), models.MonitoringSession.day.label("monitoring_session_day"), models.MonitoringSession.base_directory.label( "monitoring_session_base_directory" @@ -177,7 +194,10 @@ def get_unique_species_by_track( stmt = ( stmt.group_by( - "sequence_id", "monitoring_session_day", "monitoring_session_base_directory" + "sequence_id", + "monitoring_session_id", + "monitoring_session_day", + "monitoring_session_base_directory", ) .having( sa.func.max(models.DetectedObject.specific_label_score) From 53a3b6f68913a80994a9f4cf0d4dcfb5bd9a753d Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 13 Apr 2023 00:27:00 -0700 Subject: [PATCH 20/53] Add capture examples to event list items --- trapdata/api/views/occurrences.py | 4 ++-- trapdata/common/filemanagement.py | 5 ++++- trapdata/db/models/events.py | 37 ++++++++++++++++++++++++++----- 3 files changed, 38 insertions(+), 8 deletions(-) diff --git a/trapdata/api/views/occurrences.py b/trapdata/api/views/occurrences.py index ab604f14..46cb4a9a 100644 --- a/trapdata/api/views/occurrences.py +++ b/trapdata/api/views/occurrences.py @@ -9,12 +9,12 @@ from trapdata.api.deps.request_params import parse_react_admin_params from trapdata.api.request_params import RequestParams from trapdata.db import Base -from trapdata.db.models.occurrences import Occurrence, list_occurrences +from trapdata.db.models.occurrences import OccurrenceListItem, list_occurrences router = APIRouter(prefix="/occurrences") -@router.get("", response_model=List[Occurrence]) +@router.get("", response_model=List[OccurrenceListItem]) async def get_occurrences( response: Response, # request_params: RequestParams = Depends(parse_react_admin_params(Base)), diff --git a/trapdata/common/filemanagement.py b/trapdata/common/filemanagement.py index ad27f253..ae1dc7be 100644 --- a/trapdata/common/filemanagement.py +++ b/trapdata/common/filemanagement.py @@ -17,6 +17,7 @@ from . import constants from .logs import logger +from .types import FilePath APP_NAME_SLUG = "AMI" EXIF_DATETIME_STR_FORMAT = "%Y:%m:%d %H:%M:%S" @@ -429,7 +430,9 @@ def initial_directory_choice(): return pathlib.Path("~/") -def media_url(local_path: str, delim: str, media_url_base: Optional[str] = None) -> str: +def media_url( + local_path: FilePath, delim: str, media_url_base: Optional[str] = None +) -> str: """ Given a local path to a file, return a URL to that file. @TODO rework this and handle slashes better. """ diff --git a/trapdata/db/models/events.py b/trapdata/db/models/events.py index 87acfef0..3f6740b5 100644 --- a/trapdata/db/models/events.py +++ b/trapdata/db/models/events.py @@ -8,25 +8,35 @@ from sqlalchemy.ext.hybrid import hybrid_method from sqlalchemy_utils import aggregated -from trapdata.common.filemanagement import find_images, group_images_by_day +from trapdata.common.filemanagement import find_images, group_images_by_day, media_url from trapdata.common.logs import logger from trapdata.common.types import FilePath from trapdata.common.utils import export_report from trapdata.db import Base, get_session, models from trapdata.db.models.deployments import deployment_name - # @TODO Rename to TrapEvent? CapturePeriod? less confusing with other types of Sessions. CaptureSession? Or SurveyEvent or Survey? + + +class MonitoringSessionNestedCapture(BaseModel): + id: str + path: pathlib.Path + timestamp: datetime.datetime + # example["source_image_path"] = media_url( + # example["source_image_path"], "captures/", media_url_base + # ) + + class MonitoringSessionListItem(BaseModel): id: str day: datetime.date - image_base_path: str + # image_base_path: str deployment: str num_captures: int num_detections: int num_occurrences: int num_species: int - example_captures: list[dict] + example_captures: list[MonitoringSessionNestedCapture] start_time: datetime.datetime end_time: datetime.datetime duration: datetime.timedelta @@ -391,6 +401,7 @@ def list_monitoring_sessions( image_base_path: FilePath, limit: Optional[int] = None, offset: int = 0, + num_examples: int = 5, media_url_base: Optional[str] = None, ) -> list[MonitoringSessionListItem]: """ """ @@ -413,6 +424,22 @@ def list_monitoring_sessions( list_items = [] for event in events: event.update_aggregates(session) + rows = session.execute( + sa.select( + models.TrapImage.id, models.TrapImage.path, models.TrapImage.timestamp + ) + .where(models.TrapImage.monitoring_session_id == event.id) + .order_by(models.TrapImage.filesize.desc()) + .limit(num_examples) + ).all() + example_captures = [ + MonitoringSessionNestedCapture( + id=row.id, + path=media_url(row.path, "captures/", media_url_base), + timestamp=row.timestamp, + ) + for row in rows + ] list_items.append( MonitoringSessionListItem( id=event.id, @@ -425,7 +452,7 @@ def list_monitoring_sessions( num_detections=event.num_detected_objects, num_occurrences=event.num_occurrences(session), num_species=event.num_species(session), - example_captures=[], # @TODO + example_captures=example_captures, duration=event.duration(), duration_label=event.duration_label, ) From b621b29f970faadd9929474192b2b30156e8f5a6 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 13 Apr 2023 00:27:14 -0700 Subject: [PATCH 21/53] Optional URL on list items --- trapdata/db/models/images.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trapdata/db/models/images.py b/trapdata/db/models/images.py index 0ccc193a..43378000 100644 --- a/trapdata/db/models/images.py +++ b/trapdata/db/models/images.py @@ -22,6 +22,7 @@ class CaptureListItem(BaseModel): path: pathlib.Path num_detections: Optional[int] in_queue: bool + url: Optional[str] = None class CaptureDetail(CaptureListItem): From cafb3d6a094c2958841179f1009674a6026f49d6 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 13 Apr 2023 00:41:40 -0700 Subject: [PATCH 22/53] Hide image base path --- trapdata/db/models/deployments.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trapdata/db/models/deployments.py b/trapdata/db/models/deployments.py index d0aa9f5c..33ec18c4 100644 --- a/trapdata/db/models/deployments.py +++ b/trapdata/db/models/deployments.py @@ -19,7 +19,7 @@ class DeploymentListItem(BaseModel): id: Optional[int] = None name: str - image_base_path: FilePath + # image_base_path: FilePath num_events: int num_source_images: int num_detections: int From b667f45e94666f83fbc100ce3cedf995379d1755 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 13 Apr 2023 00:54:34 -0700 Subject: [PATCH 23/53] Fix group by in PostgreSQL --- .gitignore | 4 +++- scripts/start_db_container.sh | 2 +- trapdata/db/models/deployments.py | 2 +- trapdata/db/models/occurrences.py | 2 +- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/.gitignore b/.gitignore index 0e3fdd56..61c9031c 100644 --- a/.gitignore +++ b/.gitignore @@ -135,4 +135,6 @@ trapdata.ini # macOS -.DS_Store \ No newline at end of file +.DS_Store + +db_data diff --git a/scripts/start_db_container.sh b/scripts/start_db_container.sh index 1d13682c..307e5acb 100755 --- a/scripts/start_db_container.sh +++ b/scripts/start_db_container.sh @@ -8,7 +8,7 @@ HOST_PORT=5432 POSTGRES_VERSION=14 POSTGRES_DB=ami -docker run -d -i --name $CONTAINER_NAME -v ./db_data:/var/lib/postgresql/data --restart always -p $HOST_PORT:5432 -e POSTGRES_HOST_AUTH_METHOD=trust -e POSTGRES_DB=$POSTGRES_DB postgres:$POSTGRES_VERSION +docker run -d -i --name $CONTAINER_NAME -v "$(pwd)/db_data":/var/lib/postgresql/data --restart always -p $HOST_PORT:5432 -e POSTGRES_HOST_AUTH_METHOD=trust -e POSTGRES_DB=$POSTGRES_DB postgres:$POSTGRES_VERSION docker logs ami-db --tail 100 diff --git a/trapdata/db/models/deployments.py b/trapdata/db/models/deployments.py index 33ec18c4..d0aa9f5c 100644 --- a/trapdata/db/models/deployments.py +++ b/trapdata/db/models/deployments.py @@ -19,7 +19,7 @@ class DeploymentListItem(BaseModel): id: Optional[int] = None name: str - # image_base_path: FilePath + image_base_path: FilePath num_events: int num_source_images: int num_detections: int diff --git a/trapdata/db/models/occurrences.py b/trapdata/db/models/occurrences.py index 7f8c0b72..7b0d73f5 100644 --- a/trapdata/db/models/occurrences.py +++ b/trapdata/db/models/occurrences.py @@ -172,7 +172,7 @@ def get_unique_species_by_track( # Select all sequences where at least one example is above the score threshold stmt = sa.select( models.DetectedObject.sequence_id, - models.MonitoringSession.id.label("monitoring_session_id"), + models.DetectedObject.monitoring_session_id, models.MonitoringSession.day.label("monitoring_session_day"), models.MonitoringSession.base_directory.label( "monitoring_session_base_directory" From 5a5342c9456ec237293bcae14e634c540f360902 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Fri, 14 Apr 2023 14:10:20 +0100 Subject: [PATCH 24/53] Add initial session detail endpoint, no captures --- trapdata/api/views/events.py | 17 +++++++- trapdata/db/models/events.py | 78 +++++++++++++++++++++++++++++++++++- 2 files changed, 92 insertions(+), 3 deletions(-) diff --git a/trapdata/api/views/events.py b/trapdata/api/views/events.py index febd2c9b..612bb6fa 100644 --- a/trapdata/api/views/events.py +++ b/trapdata/api/views/events.py @@ -1,13 +1,15 @@ from typing import Any, List -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, HTTPException from sqlalchemy import orm from starlette.responses import Response from trapdata.api.config import settings from trapdata.api.deps.db import get_session from trapdata.db.models.events import ( + MonitoringSessionDetail, MonitoringSessionListItem, + get_monitoring_session_by_id, list_monitoring_sessions, ) @@ -26,6 +28,19 @@ async def get_monitoring_sessions( return items +@router.get("/{event_id}", response_model=MonitoringSessionDetail) +async def get_monitoring_session( + event_id: int, + response: Response, + session: orm.Session = Depends(get_session), + # request_params: RequestParams = Depends(parse_react_admin_params(Base)), +) -> Any: + event = get_monitoring_session_by_id(session, event_id, media_url_base="/static/") + if not event: + raise HTTPException(404) + return event + + # @router.post("/process", response_model=List[DeploymentListItem]) # async def process_deployment( # response: Response, diff --git a/trapdata/db/models/events.py b/trapdata/db/models/events.py index 3f6740b5..e1240dcc 100644 --- a/trapdata/db/models/events.py +++ b/trapdata/db/models/events.py @@ -45,7 +45,9 @@ class MonitoringSessionListItem(BaseModel): class MonitoringSessionDetail(MonitoringSessionListItem): notes: Optional[str] - detections: list + captures: list[ + MonitoringSessionNestedCapture + ] # Too many! @TODO include summary data to generate the timeline instead # @TODO add more info about the session, like the number of images, the number of detected objects, etc # @TODO add the number of species detected in this session @@ -396,6 +398,79 @@ def export_monitoring_sessions( return export_report(records, report_name, directory) +def event_response( + session: orm.Session, + event: MonitoringSession, +) -> MonitoringSessionListItem: + """ + Reusable method to create a MonitoringSession Schema from a MonitoringSession model. + + @TODO decide if this is helpful or not to reuse in get_monitoring_sessions and get_monitoring_session_by_id + """ + + event.update_aggregates(session) + event_response = MonitoringSessionListItem( + id=event.id, + day=event.day, + deployment=deployment_name(str(event.base_directory)), + start_time=event.start_time, + end_time=event.end_time, + num_captures=event.num_images, + num_detections=event.num_detected_objects, + num_occurrences=event.num_occurrences(session), + num_species=event.num_species(session), + example_captures=[], + duration=event.duration(), + duration_label=event.duration_label, + ) + + return event_response + + +def get_monitoring_session_by_id( + session: orm.Session, + event_id: int, + media_url_base: str, +) -> Optional[MonitoringSessionDetail]: + event: Optional[MonitoringSession] = session.get(MonitoringSession, event_id) + if event: + event.update_aggregates(session) + captures = session.execute( + sa.select( + models.TrapImage.id, models.TrapImage.path, models.TrapImage.timestamp + ) + .where(models.TrapImage.monitoring_session_id == event.id) + .order_by(models.TrapImage.timestamp) + ).all() + nested_captures = [ + MonitoringSessionNestedCapture( + id=row.id, + path=media_url(row.path, "captures/", media_url_base), + timestamp=row.timestamp, + ) + for row in captures + ] + event_detail = MonitoringSessionDetail( + id=event.id, + day=event.day, + deployment=deployment_name(str(event.base_directory)), + start_time=event.start_time, + end_time=event.end_time, + num_captures=event.num_images, + num_detections=event.num_detected_objects, + num_occurrences=event.num_occurrences(session), + num_species=event.num_species(session), + example_captures=[], + duration=event.duration(), + duration_label=event.duration_label, + notes=event.notes, + captures=[], + ) + return event_detail + else: + return None + + def list_monitoring_sessions( session: orm.Session, image_base_path: FilePath, @@ -444,7 +519,6 @@ def list_monitoring_sessions( MonitoringSessionListItem( id=event.id, day=event.day, - image_base_path=str(event.base_directory), deployment=deployment_name(str(event.base_directory)), start_time=event.start_time, end_time=event.end_time, From 7fad50b727424f8756cabf24b3099674654a7821 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Thu, 13 Apr 2023 00:54:34 -0700 Subject: [PATCH 25/53] Fix group by in PostgreSQL --- .gitignore | 4 +++- scripts/start_db_container.sh | 2 +- trapdata/db/base.py | 7 ++++++- trapdata/db/models/deployments.py | 2 +- trapdata/db/models/occurrences.py | 2 +- 5 files changed, 12 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index 0e3fdd56..61c9031c 100644 --- a/.gitignore +++ b/.gitignore @@ -135,4 +135,6 @@ trapdata.ini # macOS -.DS_Store \ No newline at end of file +.DS_Store + +db_data diff --git a/scripts/start_db_container.sh b/scripts/start_db_container.sh index 1d13682c..307e5acb 100755 --- a/scripts/start_db_container.sh +++ b/scripts/start_db_container.sh @@ -8,7 +8,7 @@ HOST_PORT=5432 POSTGRES_VERSION=14 POSTGRES_DB=ami -docker run -d -i --name $CONTAINER_NAME -v ./db_data:/var/lib/postgresql/data --restart always -p $HOST_PORT:5432 -e POSTGRES_HOST_AUTH_METHOD=trust -e POSTGRES_DB=$POSTGRES_DB postgres:$POSTGRES_VERSION +docker run -d -i --name $CONTAINER_NAME -v "$(pwd)/db_data":/var/lib/postgresql/data --restart always -p $HOST_PORT:5432 -e POSTGRES_HOST_AUTH_METHOD=trust -e POSTGRES_DB=$POSTGRES_DB postgres:$POSTGRES_VERSION docker logs ami-db --tail 100 diff --git a/trapdata/db/base.py b/trapdata/db/base.py index e50f4210..bb453809 100644 --- a/trapdata/db/base.py +++ b/trapdata/db/base.py @@ -13,12 +13,16 @@ from trapdata import logger from trapdata.common.types import DatabaseURL +DATABASE_SCHEMA_NAMESPACE = "trapdata" + DIALECT_CONNECTION_ARGS = { "sqlite": { "timeout": 10, # A longer timeout is necessary for SQLite and multiple PyTorch workers "check_same_thread": False, }, - "postgresql": {}, + "postgresql": { + 'options': f'-csearch_path={DATABASE_SCHEMA_NAMESPACE}' + }, } SUPPORTED_DIALECTS = list(DIALECT_CONNECTION_ARGS.keys()) @@ -71,6 +75,7 @@ def create_db(db_path: DatabaseURL) -> None: db = get_db(db_path) from . import Base + Base.metadata.schema = DATABASE_SCHEMA_NAMESPACE Base.metadata.create_all(db, checkfirst=True) alembic_cfg = get_alembic_config(db_path) diff --git a/trapdata/db/models/deployments.py b/trapdata/db/models/deployments.py index 33ec18c4..d0aa9f5c 100644 --- a/trapdata/db/models/deployments.py +++ b/trapdata/db/models/deployments.py @@ -19,7 +19,7 @@ class DeploymentListItem(BaseModel): id: Optional[int] = None name: str - # image_base_path: FilePath + image_base_path: FilePath num_events: int num_source_images: int num_detections: int diff --git a/trapdata/db/models/occurrences.py b/trapdata/db/models/occurrences.py index 7f8c0b72..7b0d73f5 100644 --- a/trapdata/db/models/occurrences.py +++ b/trapdata/db/models/occurrences.py @@ -172,7 +172,7 @@ def get_unique_species_by_track( # Select all sequences where at least one example is above the score threshold stmt = sa.select( models.DetectedObject.sequence_id, - models.MonitoringSession.id.label("monitoring_session_id"), + models.DetectedObject.monitoring_session_id, models.MonitoringSession.day.label("monitoring_session_day"), models.MonitoringSession.base_directory.label( "monitoring_session_base_directory" From b217e69d0b570844ae8d1c9fd721261f4f306861 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sat, 15 Apr 2023 07:23:30 +0100 Subject: [PATCH 26/53] Add species endpoint --- trapdata/api/views/__init__.py | 11 ++- trapdata/api/views/species.py | 28 ++++++++ trapdata/common/filemanagement.py | 2 +- trapdata/db/models/detections.py | 108 +++++++++++++++++++++++++++--- 4 files changed, 137 insertions(+), 12 deletions(-) create mode 100644 trapdata/api/views/species.py diff --git a/trapdata/api/views/__init__.py b/trapdata/api/views/__init__.py index 921cbbfb..fa891a78 100644 --- a/trapdata/api/views/__init__.py +++ b/trapdata/api/views/__init__.py @@ -1,6 +1,14 @@ from fastapi import APIRouter -from trapdata.api.views import deployments, events, occurrences, settings, stats, status +from trapdata.api.views import ( + deployments, + events, + occurrences, + settings, + species, + stats, + status, +) api_router = APIRouter() @@ -9,4 +17,5 @@ api_router.include_router(deployments.router, tags=["deployments"]) api_router.include_router(events.router, tags=["events"]) api_router.include_router(occurrences.router, tags=["occurrences"]) +api_router.include_router(species.router, tags=["species"]) api_router.include_router(settings.router, tags=["settings"]) diff --git a/trapdata/api/views/species.py b/trapdata/api/views/species.py new file mode 100644 index 00000000..ccd23bde --- /dev/null +++ b/trapdata/api/views/species.py @@ -0,0 +1,28 @@ +from typing import Any, List, Optional + +from fastapi import APIRouter, Depends, HTTPException +from sqlalchemy import func, orm, select +from starlette.responses import Response + +from trapdata.api.config import settings +from trapdata.api.deps.db import get_session +from trapdata.api.deps.request_params import parse_react_admin_params +from trapdata.api.request_params import RequestParams +from trapdata.db import Base +from trapdata.db.models.detections import TaxonListItem, list_species + +router = APIRouter(prefix="/species") + + +@router.get("", response_model=List[TaxonListItem]) +async def get_species( + response: Response, + session: orm.Session = Depends(get_session), + # request_params: RequestParams = Depends(parse_react_admin_params(Base)), +) -> Any: + species = list_species( + session=session, + classification_threshold=settings.classification_threshold, + media_url_base="/static/", + ) + return species diff --git a/trapdata/common/filemanagement.py b/trapdata/common/filemanagement.py index ae1dc7be..c6af295c 100644 --- a/trapdata/common/filemanagement.py +++ b/trapdata/common/filemanagement.py @@ -438,6 +438,6 @@ def media_url( """ relative_path = f"{delim}{local_path.split(delim)[-1]}" if media_url_base: - return f"{media_url_base}{relative_path}" + return os.path.join(media_url_base, relative_path) else: return relative_path diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 6a9ebc9b..d68c148e 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -8,7 +8,12 @@ from sqlalchemy import orm from trapdata import constants, db -from trapdata.common.filemanagement import absolute_path, construct_exif, save_image +from trapdata.common.filemanagement import ( + absolute_path, + construct_exif, + media_url, + save_image, +) from trapdata.common.logs import logger from trapdata.common.types import FilePath from trapdata.common.utils import bbox_area, bbox_center, export_report @@ -18,15 +23,17 @@ class DetectionListItem(BaseModel): id: int - cropped_image_path: Optional[pathlib.Path] - bbox: Optional[tuple[float, float, float, float]] - area_pixels: Optional[float] - last_detected: Optional[datetime.datetime] - label: Optional[str] - score: Optional[int] - model_name: Optional[str] - in_queue: bool - notes: Optional[str] + cropped_image_path: Optional[FilePath] = None + bbox: Optional[tuple[float, float, float, float]] = None + area_pixels: Optional[float] = None + width: Optional[int] = None + height: Optional[int] = None + last_detected: Optional[datetime.datetime] = None + label: Optional[str] = None + score: Optional[int] = None + model_name: Optional[str] = None + in_queue: bool = False + notes: Optional[str] = "" class DetectionDetail(DetectionListItem): @@ -532,6 +539,87 @@ def num_occurrences_for_event( return sesh.execute(query).scalar_one() +class TaxonListItem(BaseModel): + name: str + genus: Optional[str] = None + family: Optional[str] = None + num_occurrences: Optional[int] = None + num_detections: Optional[int] = None + examples: list[DetectionListItem] = list() + score_stats: Optional[dict[str, float]] = None + training_examples: Optional[int] = None + + +def list_species( + session: orm.Session, + classification_threshold: int = 0, + num_examples: int = 3, + media_url_base: Optional[str] = None, +) -> list[TaxonListItem]: + """ + Return a list of unique species and example detections. + """ + species = session.execute( + sa.select( + DetectedObject.specific_label.label("name"), + sa.func.count(DetectedObject.specific_label.distinct()).label( + "num_detections" + ), + sa.func.count(DetectedObject.sequence_id.distinct()).label( + "num_occurrences" + ), # @TODO handle sequences with None + sa.func.max(DetectedObject.specific_label_score).label("score_max"), + sa.func.min(DetectedObject.specific_label_score).label("score_min"), + sa.func.avg(DetectedObject.specific_label_score).label("score_mean"), + ) + .where( + DetectedObject.specific_label_score >= classification_threshold, + ) + .group_by(DetectedObject.specific_label) + ).all() + + examples = ( + sa.select(DetectedObject) + .where(DetectedObject.specific_label.in_([sp.name for sp in species])) + .limit(num_examples) + .order_by(DetectedObject.specific_label_score.desc()) + ) + + examples_by_name = {} + for detection in session.execute(examples).unique().scalars().all(): + examples_by_name.setdefault(detection.specific_label, []).append(detection) + + metadata_by_name = {sp.name: sp for sp in species} + + taxa = [ + TaxonListItem( + name=name, + num_occurrences=metadata_by_name[name].num_occurrences, + num_detections=metadata_by_name[name].num_detections, + score_stats={ + "max": metadata_by_name[name].score_max, + "min": metadata_by_name[name].score_min, + "mean": metadata_by_name[name].score_mean, + }, + examples=[ + DetectionListItem( + id=detection.id, + cropped_image_path=media_url( + detection.path, + "crops", + media_url_base=media_url_base, + ), + height=detection.height(), + width=detection.width(), + ) + for detection in examples + ], + ) + for name, examples in examples_by_name.items() + ] + return taxa + + def get_unique_species( db_path, monitoring_session=None, classification_threshold: float = -1 ): From 202f5c1c48dc235476ab30696f7a58e86bc5c0b5 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sat, 15 Apr 2023 23:46:41 +0100 Subject: [PATCH 27/53] Support fetching ocur. for all events in a dep. --- trapdata/api/views/occurrences.py | 21 ++++++++++++++ trapdata/cli/export.py | 35 ++++++++++++++++++----- trapdata/cli/show.py | 1 + trapdata/cli/test.py | 2 ++ trapdata/db/models/occurrences.py | 46 ++++++++++++++++++------------- trapdata/ui/summary.py | 3 +- 6 files changed, 81 insertions(+), 27 deletions(-) diff --git a/trapdata/api/views/occurrences.py b/trapdata/api/views/occurrences.py index 46cb4a9a..68fd9190 100644 --- a/trapdata/api/views/occurrences.py +++ b/trapdata/api/views/occurrences.py @@ -21,12 +21,33 @@ async def get_occurrences( ) -> Any: occurrences = list_occurrences( settings.database_url, + settings.image_base_path, classification_threshold=settings.classification_threshold, media_url_base="/static/", ) return occurrences +@router.get("", response_model=List[OccurrenceListItem]) +async def get_occurrence( + item_id: int, + response: Response, + # request_params: RequestParams = Depends(parse_react_admin_params(Base)), +) -> Any: + """ + @TODO placeholder! replace this with an actual get single occurrence method. + """ + occurrences = list_occurrences( + settings.database_url, + settings.image_base_path, + classification_threshold=settings.classification_threshold, + media_url_base="/static/", + limit=1, + offset=item_id, + ) + return occurrences + + # @router.post("/process", response_model=List[DeploymentListItem]) # async def process_deployment( # response: Response, diff --git a/trapdata/cli/export.py b/trapdata/cli/export.py index 4f6a7cb0..cad216f7 100644 --- a/trapdata/cli/export.py +++ b/trapdata/cli/export.py @@ -14,12 +14,15 @@ from trapdata.cli import settings from trapdata.db import get_session_class from trapdata.db.models.deployments import list_deployments -from trapdata.db.models.detections import get_detected_objects +from trapdata.db.models.detections import ( + get_detected_objects, + num_occurrences_for_event, + num_species_for_event, +) from trapdata.db.models.events import ( get_monitoring_session_by_date, get_monitoring_session_images, get_monitoring_sessions_from_db, - list_monitoring_sessions, ) from trapdata.db.models.occurrences import list_occurrences @@ -89,6 +92,7 @@ def occurrences( for event in events: occurrences += list_occurrences( settings.database_url, + settings.image_base_path, monitoring_session=event, classification_threshold=settings.classification_threshold, num_examples=num_examples, @@ -158,11 +162,28 @@ def sessions( """ Export a summary of monitoring sessions from database in the specified format. """ - Session = get_session_class(settings.database_url) - session = Session() - - items = list_monitoring_sessions(session, settings.image_base_path) - df = pd.DataFrame([item.dict() for item in items]) + monitoring_events = get_monitoring_sessions_from_db( + db_path=settings.database_url, base_directory=settings.image_base_path + ) + items = [] + for event in monitoring_events: + event_data = event.report_data() + num_occurrences = num_occurrences_for_event( + db_path=settings.database_url, monitoring_session=event + ) + num_species = num_species_for_event( + db_path=settings.database_url, monitoring_session=event + ) + example_captures = get_monitoring_session_images( + settings.database_url, event, limit=5, offset=int(event.num_images / 2) + ) + event_data["example_captures"] = [ + img.report_data().dict() for img in example_captures + ] + event_data["num_occurrences"] = num_occurrences + event_data["num_species"] = num_species + items.append(event_data) + df = pd.DataFrame(items) return export(df=df, format=format, outfile=outfile) diff --git a/trapdata/cli/show.py b/trapdata/cli/show.py index 10091dd1..0b6a39a6 100644 --- a/trapdata/cli/show.py +++ b/trapdata/cli/show.py @@ -186,6 +186,7 @@ def occurrences( occurrences = list_occurrences( settings.database_url, + settings.image_base_path, event, classification_threshold=settings.classification_threshold, limit=limit, diff --git a/trapdata/cli/test.py b/trapdata/cli/test.py index d63f2841..2e5c96b7 100644 --- a/trapdata/cli/test.py +++ b/trapdata/cli/test.py @@ -5,6 +5,7 @@ import typer from rich import print from sqlalchemy import select + from trapdata.cli import settings from trapdata.db.base import check_db, get_session_class from trapdata.db.models import MonitoringSession @@ -54,6 +55,7 @@ def species_by_track(event_day: datetime.datetime): print(f"Matched of event: {event}") get_unique_species_by_track( settings.database_url, + image_base_path=event.base_directory, monitoring_session=event, classification_threshold=0.1, ) diff --git a/trapdata/db/models/occurrences.py b/trapdata/db/models/occurrences.py index 7b0d73f5..f580c598 100644 --- a/trapdata/db/models/occurrences.py +++ b/trapdata/db/models/occurrences.py @@ -15,6 +15,7 @@ from trapdata import db from trapdata.common.filemanagement import media_url +from trapdata.common.types import FilePath from trapdata.db import models @@ -56,6 +57,7 @@ class SpeciesSummaryListItem(BaseModel): def list_occurrences( db_path: str, + image_base_path: FilePath, monitoring_session: Optional[models.MonitoringSession] = None, classification_threshold: float = -1, num_examples: int = 3, @@ -66,7 +68,8 @@ def list_occurrences( occurrences = [] for item in get_unique_species_by_track( db_path, - monitoring_session, + monitoring_session=monitoring_session, + image_base_path=image_base_path, classification_threshold=classification_threshold, num_examples=num_examples, limit=limit, @@ -160,6 +163,7 @@ def list_species( def get_unique_species_by_track( db_path: str, + image_base_path: FilePath, monitoring_session: Optional[models.MonitoringSession] = None, classification_threshold: float = -1, num_examples: int = 3, @@ -170,24 +174,28 @@ def get_unique_species_by_track( session = Session() # Select all sequences where at least one example is above the score threshold - stmt = sa.select( - models.DetectedObject.sequence_id, - models.DetectedObject.monitoring_session_id, - models.MonitoringSession.day.label("monitoring_session_day"), - models.MonitoringSession.base_directory.label( - "monitoring_session_base_directory" - ), - sa.func.count(models.DetectedObject.id).label( - "sequence_frame_count" - ), # frames in track - sa.func.max(models.DetectedObject.specific_label_score).label( - "sequence_best_score" - ), - sa.func.min(models.DetectedObject.timestamp).label("sequence_start_time"), - sa.func.max(models.DetectedObject.timestamp).label("sequence_end_time"), - ).join( - models.MonitoringSession, - models.DetectedObject.monitoring_session_id == models.MonitoringSession.id, + stmt = ( + sa.select( + models.DetectedObject.sequence_id, + models.DetectedObject.monitoring_session_id, + models.MonitoringSession.day.label("monitoring_session_day"), + models.MonitoringSession.base_directory.label( + "monitoring_session_base_directory" + ), + sa.func.count(models.DetectedObject.id).label( + "sequence_frame_count" + ), # frames in track + sa.func.max(models.DetectedObject.specific_label_score).label( + "sequence_best_score" + ), + sa.func.min(models.DetectedObject.timestamp).label("sequence_start_time"), + sa.func.max(models.DetectedObject.timestamp).label("sequence_end_time"), + ) + .join( + models.MonitoringSession, + models.DetectedObject.monitoring_session_id == models.MonitoringSession.id, + ) + .where(models.MonitoringSession.base_directory == image_base_path) ) if monitoring_session: stmt = stmt.where(models.MonitoringSession.id == monitoring_session.id) diff --git a/trapdata/ui/summary.py b/trapdata/ui/summary.py index bfb7f6ec..1782f3ef 100644 --- a/trapdata/ui/summary.py +++ b/trapdata/ui/summary.py @@ -180,7 +180,8 @@ def load_species(self, ms): # ) classification_summary = get_unique_species_by_track( app.db_path, - ms, + image_base_path=ms.base_directory, + monitoring_session=ms, classification_threshold=classification_threshold, num_examples=NUM_EXAMPLES_PER_ROW, ) From 9d58d4941c5a1ceade92425066d495f201cbdef8 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sat, 15 Apr 2023 23:47:12 +0100 Subject: [PATCH 28/53] Renable type hints for settings class --- trapdata/settings.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/trapdata/settings.py b/trapdata/settings.py index 1412cf21..7e9738f7 100644 --- a/trapdata/settings.py +++ b/trapdata/settings.py @@ -206,11 +206,9 @@ def kivy_settings_source(settings: BaseSettings) -> dict[str, str]: @lru_cache -def read_settings( - settings_class: ModelMetaclass = Settings, *args, **kwargs -) -> ModelMetaclass: +def read_settings(*args, **kwargs) -> Settings: try: - return settings_class(*args, **kwargs) + return Settings(*args, **kwargs) except ValidationError as e: # @TODO the validation errors could be printed in a more helpful way: rprint(cli_help_message) From a249356ec6c666f248a6c66f30caa6905c4e1650 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sun, 16 Apr 2023 00:00:34 +0100 Subject: [PATCH 29/53] Fix list occurrences --- trapdata/cli/show.py | 2 +- trapdata/db/models/occurrences.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/trapdata/cli/show.py b/trapdata/cli/show.py index 395d050c..729508fc 100644 --- a/trapdata/cli/show.py +++ b/trapdata/cli/show.py @@ -198,7 +198,7 @@ def occurrences( ) for occurrence in occurrences: table.add_row( - occurrence.event, + str(occurrence.event.day), occurrence.id, occurrence.label, str(occurrence.num_frames), diff --git a/trapdata/db/models/occurrences.py b/trapdata/db/models/occurrences.py index f580c598..206883f3 100644 --- a/trapdata/db/models/occurrences.py +++ b/trapdata/db/models/occurrences.py @@ -195,7 +195,7 @@ def get_unique_species_by_track( models.MonitoringSession, models.DetectedObject.monitoring_session_id == models.MonitoringSession.id, ) - .where(models.MonitoringSession.base_directory == image_base_path) + .where(models.MonitoringSession.base_directory == str(image_base_path)) ) if monitoring_session: stmt = stmt.where(models.MonitoringSession.id == monitoring_session.id) From 855beb3a4364d0e1eb93b214305e0865f7d3f23e Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sun, 16 Apr 2023 00:06:08 +0100 Subject: [PATCH 30/53] Filter species by current deployment --- trapdata/api/views/species.py | 1 + trapdata/db/models/detections.py | 8 ++++++++ 2 files changed, 9 insertions(+) diff --git a/trapdata/api/views/species.py b/trapdata/api/views/species.py index ccd23bde..ce260e0b 100644 --- a/trapdata/api/views/species.py +++ b/trapdata/api/views/species.py @@ -22,6 +22,7 @@ async def get_species( ) -> Any: species = list_species( session=session, + image_base_path=settings.image_base_path, classification_threshold=settings.classification_threshold, media_url_base="/static/", ) diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index d68c148e..b3794708 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -552,12 +552,15 @@ class TaxonListItem(BaseModel): def list_species( session: orm.Session, + image_base_path: FilePath, classification_threshold: int = 0, num_examples: int = 3, media_url_base: Optional[str] = None, ) -> list[TaxonListItem]: """ Return a list of unique species and example detections. + + @TODO compare this with list_species in occurrences.py """ species = session.execute( sa.select( @@ -572,9 +575,14 @@ def list_species( sa.func.min(DetectedObject.specific_label_score).label("score_min"), sa.func.avg(DetectedObject.specific_label_score).label("score_mean"), ) + .join( + models.MonitoringSession, + models.MonitoringSession.id == DetectedObject.monitoring_session_id, + ) .where( DetectedObject.specific_label_score >= classification_threshold, ) + .where(DetectedObject.monitoring_session.base_directory == str(image_base_path)) .group_by(DetectedObject.specific_label) ).all() From 8de725e3f6f31bf87ae71ce845ab75a4a3f75c1e Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sun, 16 Apr 2023 00:08:39 +0100 Subject: [PATCH 31/53] Fix join --- trapdata/db/models/detections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index b3794708..77279e87 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -582,7 +582,7 @@ def list_species( .where( DetectedObject.specific_label_score >= classification_threshold, ) - .where(DetectedObject.monitoring_session.base_directory == str(image_base_path)) + .where(models.MonitoringSession.base_directory == str(image_base_path)) .group_by(DetectedObject.specific_label) ).all() From 4816025bec5aa39c043dade3659ff0ffcc0e5c0d Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sun, 16 Apr 2023 00:13:57 +0100 Subject: [PATCH 32/53] Add limits & offset from API --- trapdata/api/views/events.py | 2 ++ trapdata/api/views/occurrences.py | 4 ++++ trapdata/api/views/species.py | 4 ++++ trapdata/db/models/detections.py | 4 ++++ 4 files changed, 14 insertions(+) diff --git a/trapdata/api/views/events.py b/trapdata/api/views/events.py index 612bb6fa..9f491eb0 100644 --- a/trapdata/api/views/events.py +++ b/trapdata/api/views/events.py @@ -21,6 +21,8 @@ async def get_monitoring_sessions( response: Response, session: orm.Session = Depends(get_session), # request_params: RequestParams = Depends(parse_react_admin_params(Base)), + limit: int = 100, + offset: int = 100, ) -> Any: items = list_monitoring_sessions( session, settings.image_base_path, media_url_base="/static/" diff --git a/trapdata/api/views/occurrences.py b/trapdata/api/views/occurrences.py index 68fd9190..654817f8 100644 --- a/trapdata/api/views/occurrences.py +++ b/trapdata/api/views/occurrences.py @@ -17,6 +17,8 @@ @router.get("", response_model=List[OccurrenceListItem]) async def get_occurrences( response: Response, + limit: int = 100, + offset: int = 0, # request_params: RequestParams = Depends(parse_react_admin_params(Base)), ) -> Any: occurrences = list_occurrences( @@ -24,6 +26,8 @@ async def get_occurrences( settings.image_base_path, classification_threshold=settings.classification_threshold, media_url_base="/static/", + limit=limit, + offset=offset, ) return occurrences diff --git a/trapdata/api/views/species.py b/trapdata/api/views/species.py index ce260e0b..845c3287 100644 --- a/trapdata/api/views/species.py +++ b/trapdata/api/views/species.py @@ -19,11 +19,15 @@ async def get_species( response: Response, session: orm.Session = Depends(get_session), # request_params: RequestParams = Depends(parse_react_admin_params(Base)), + limit: int = 100, + offset: int = 0, ) -> Any: species = list_species( session=session, image_base_path=settings.image_base_path, classification_threshold=settings.classification_threshold, media_url_base="/static/", + limit=limit, + offset=offset, ) return species diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 77279e87..285aae5a 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -556,6 +556,8 @@ def list_species( classification_threshold: int = 0, num_examples: int = 3, media_url_base: Optional[str] = None, + limit: int = 100, + offset: int = 0, ) -> list[TaxonListItem]: """ Return a list of unique species and example detections. @@ -584,6 +586,8 @@ def list_species( ) .where(models.MonitoringSession.base_directory == str(image_base_path)) .group_by(DetectedObject.specific_label) + .limit(limit) + .offset(offset) ).all() examples = ( From f3d09c00617a3bfd49656e008f091e643c99985b Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sun, 16 Apr 2023 00:29:33 +0100 Subject: [PATCH 33/53] Only show stats for current deployment --- trapdata/api/views/status.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/trapdata/api/views/status.py b/trapdata/api/views/status.py index 290ba385..392cb438 100644 --- a/trapdata/api/views/status.py +++ b/trapdata/api/views/status.py @@ -38,10 +38,11 @@ async def get_summary_counts( response: Response, session: orm.Session = Depends(get_session), ) -> Any: - deployments = list_deployments(session) - events = [] - for deployment in deployments: - events += list_monitoring_sessions(session, deployment.image_base_path) + # deployments = list_deployments(session) + # events = [] + # for deployment in deployments: + # events += list_monitoring_sessions(session, deployment.image_base_path) + events = list_monitoring_sessions(session, settings.image_base_path) summary = SummaryCounts( num_deployments=len({e.deployment for e in events}), From d73d411eba303a7aa9f735104455aff0f93f87c2 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sun, 16 Apr 2023 00:29:42 +0100 Subject: [PATCH 34/53] Formatting --- trapdata/db/base.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/trapdata/db/base.py b/trapdata/db/base.py index bb453809..ed19bd3a 100644 --- a/trapdata/db/base.py +++ b/trapdata/db/base.py @@ -20,9 +20,7 @@ "timeout": 10, # A longer timeout is necessary for SQLite and multiple PyTorch workers "check_same_thread": False, }, - "postgresql": { - 'options': f'-csearch_path={DATABASE_SCHEMA_NAMESPACE}' - }, + "postgresql": {"options": f"-csearch_path={DATABASE_SCHEMA_NAMESPACE}"}, } SUPPORTED_DIALECTS = list(DIALECT_CONNECTION_ARGS.keys()) @@ -75,8 +73,12 @@ def create_db(db_path: DatabaseURL) -> None: db = get_db(db_path) from . import Base - Base.metadata.schema = DATABASE_SCHEMA_NAMESPACE + with db.connect() as con: + if not db.dialect.has_schema(con, DATABASE_SCHEMA_NAMESPACE): + print("CREATING SCHEMS") + con.execute(sqlalchemy.schema.CreateSchema(DATABASE_SCHEMA_NAMESPACE)) + Base.metadata.schema = DATABASE_SCHEMA_NAMESPACE Base.metadata.create_all(db, checkfirst=True) alembic_cfg = get_alembic_config(db_path) alembic.stamp(alembic_cfg, "head") From f21efe03b5b11dda8cb7536a84a737a3e12c954f Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sun, 16 Apr 2023 00:32:32 +0100 Subject: [PATCH 35/53] Show total deployments --- trapdata/api/views/status.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/trapdata/api/views/status.py b/trapdata/api/views/status.py index 392cb438..f76067d4 100644 --- a/trapdata/api/views/status.py +++ b/trapdata/api/views/status.py @@ -38,14 +38,14 @@ async def get_summary_counts( response: Response, session: orm.Session = Depends(get_session), ) -> Any: - # deployments = list_deployments(session) + deployments = list_deployments(session) # events = [] # for deployment in deployments: # events += list_monitoring_sessions(session, deployment.image_base_path) events = list_monitoring_sessions(session, settings.image_base_path) summary = SummaryCounts( - num_deployments=len({e.deployment for e in events}), + num_deployments=len(deployments), num_sessions=len(events), num_captures=sum(e.num_captures for e in events), num_detections=sum(e.num_detections for e in events), From 1b6f1847f6771c7b7c901a3aa4ee8395354b9fcc Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sun, 16 Apr 2023 00:41:21 +0100 Subject: [PATCH 36/53] Fix num species in summary --- trapdata/api/views/status.py | 5 ++++- trapdata/db/models/detections.py | 15 +++++++++++++++ 2 files changed, 19 insertions(+), 1 deletion(-) diff --git a/trapdata/api/views/status.py b/trapdata/api/views/status.py index f76067d4..d88dab1c 100644 --- a/trapdata/api/views/status.py +++ b/trapdata/api/views/status.py @@ -9,6 +9,7 @@ from trapdata.api.config import settings from trapdata.api.deps.db import get_session from trapdata.db.models.deployments import list_deployments +from trapdata.db.models.detections import num_species_for_deployment from trapdata.db.models.events import list_monitoring_sessions from trapdata.db.models.queue import QueueListItem, list_queues @@ -50,7 +51,9 @@ async def get_summary_counts( num_captures=sum(e.num_captures for e in events), num_detections=sum(e.num_detections for e in events), num_occurrences=sum(e.num_occurrences for e in events), - num_species=sum(e.num_species for e in events), + num_species=num_species_for_deployment( + session, image_base_path=settings.image_base_path + ), last_updated=datetime.datetime.now(), ) diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 285aae5a..5d9e660e 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -632,6 +632,21 @@ def list_species( return taxa +def num_species_for_deployment(session: orm.Session, image_base_path: FilePath) -> int: + return ( + session.execute( + sa.select(sa.func.count(models.DetectedObject.specific_label.distinct())) + .join( + models.MonitoringSession, + models.MonitoringSession.id + == models.DetectedObject.monitoring_session_id, + ) + .where(models.MonitoringSession.base_directory == str(image_base_path)) + ).scalar() + or 0 + ) + + def get_unique_species( db_path, monitoring_session=None, classification_threshold: float = -1 ): From 7ecef36b1b40b4cbd75eb6632f5b6b4a058e310b Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sun, 16 Apr 2023 00:45:25 +0100 Subject: [PATCH 37/53] Fix species detection count --- trapdata/db/models/detections.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 5d9e660e..9ee8ea52 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -567,9 +567,7 @@ def list_species( species = session.execute( sa.select( DetectedObject.specific_label.label("name"), - sa.func.count(DetectedObject.specific_label.distinct()).label( - "num_detections" - ), + sa.func.count(DetectedObject.id).label("num_detections"), sa.func.count(DetectedObject.sequence_id.distinct()).label( "num_occurrences" ), # @TODO handle sequences with None From 11b61a31ef6db4db81f2573b5676e99cd10a6d9d Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sun, 16 Apr 2023 00:59:52 +0100 Subject: [PATCH 38/53] Fix num events count --- trapdata/db/models/deployments.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/trapdata/db/models/deployments.py b/trapdata/db/models/deployments.py index d0aa9f5c..835ea47d 100644 --- a/trapdata/db/models/deployments.py +++ b/trapdata/db/models/deployments.py @@ -45,7 +45,6 @@ def list_deployments(session: orm.Session) -> list[DeploymentListItem]: """ stmt = sa.select( models.MonitoringSession.base_directory.label("image_base_path"), - sa.func.count(models.MonitoringSession.id).label("num_events"), sa.func.sum(models.MonitoringSession.num_images).label("num_source_images"), sa.func.sum(models.MonitoringSession.num_detected_objects).label( "num_detections" @@ -53,9 +52,19 @@ def list_deployments(session: orm.Session) -> list[DeploymentListItem]: ).group_by(models.MonitoringSession.base_directory) deployments = [] for deployment in session.execute(stmt).all(): + num_events = ( + session.scalar( + sa.select(sa.func.count(models.MonitoringSession.id)).where( + models.MonitoringSession.base_directory + == str(deployment.image_base_path) + ) + ) + or 0 + ) deployments.append( DeploymentListItem( **deployment._mapping, + num_events=num_events, name=deployment_name(deployment.image_base_path), ) ) From 58d9ec52596f10fc59ae5dfde5a0f0407586797c Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Sun, 16 Apr 2023 01:05:51 +0100 Subject: [PATCH 39/53] Add height width props to detection --- trapdata/db/models/detections.py | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 9ee8ea52..90bfb7a0 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -6,6 +6,7 @@ import sqlalchemy as sa from pydantic import BaseModel from sqlalchemy import orm +from sqlalchemy.ext.hybrid import hybrid_property from trapdata import constants, db from trapdata.common.filemanagement import ( @@ -160,11 +161,15 @@ def save_cropped_image_data( self.path = str(fpath) return fpath - def width(self): - pass # Use bbox + @hybrid_property + def width(self) -> int: + x1, y1, x2, y2 = self.bbox + return x2 - x1 - def height(self): - pass # Use bbox + @hybrid_property + def height(self) -> int: + x1, y1, x2, y2 = self.bbox + return y2 - y1 def previous_frame_detections( self, session: orm.Session @@ -619,8 +624,8 @@ def list_species( "crops", media_url_base=media_url_base, ), - height=detection.height(), - width=detection.width(), + height=detection.height, + width=detection.width, ) for detection in examples ], From 17394682888842cf82f6491658bda43f25ad24ec Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 17 Apr 2023 22:53:45 +0100 Subject: [PATCH 40/53] Fix occurrence detail path --- trapdata/api/views/occurrences.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trapdata/api/views/occurrences.py b/trapdata/api/views/occurrences.py index 654817f8..6bd64895 100644 --- a/trapdata/api/views/occurrences.py +++ b/trapdata/api/views/occurrences.py @@ -32,7 +32,7 @@ async def get_occurrences( return occurrences -@router.get("", response_model=List[OccurrenceListItem]) +@router.get("/{item_id}", response_model=List[OccurrenceListItem]) async def get_occurrence( item_id: int, response: Response, From ac2618f1b60903d74b261954a94db5ea15009b8f Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 17 Apr 2023 22:54:07 +0100 Subject: [PATCH 41/53] Update aggregates before showing deployments --- trapdata/api/views/deployments.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/trapdata/api/views/deployments.py b/trapdata/api/views/deployments.py index 95557331..c212cec2 100644 --- a/trapdata/api/views/deployments.py +++ b/trapdata/api/views/deployments.py @@ -10,6 +10,7 @@ from trapdata.api.request_params import RequestParams from trapdata.db import Base from trapdata.db.models.deployments import DeploymentListItem, list_deployments +from trapdata.db.models.events import update_all_aggregates router = APIRouter(prefix="/deployments") @@ -20,6 +21,7 @@ async def get_deployments( session: orm.Session = Depends(get_session), # request_params: RequestParams = Depends(parse_react_admin_params(Base)), ) -> Any: + update_all_aggregates(session, settings.image_base_path) deployments = list_deployments(session) return deployments From 605ce96b1aa9e6288fee213cec3411326b44af83 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 17 Apr 2023 23:40:54 +0100 Subject: [PATCH 42/53] Fix species list --- trapdata/db/models/detections.py | 11 ++++------- 1 file changed, 4 insertions(+), 7 deletions(-) diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 90bfb7a0..3a0d1fd0 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -559,7 +559,7 @@ def list_species( session: orm.Session, image_base_path: FilePath, classification_threshold: int = 0, - num_examples: int = 3, + num_examples: int = 5, media_url_base: Optional[str] = None, limit: int = 100, offset: int = 0, @@ -580,14 +580,11 @@ def list_species( sa.func.min(DetectedObject.specific_label_score).label("score_min"), sa.func.avg(DetectedObject.specific_label_score).label("score_mean"), ) - .join( - models.MonitoringSession, - models.MonitoringSession.id == DetectedObject.monitoring_session_id, - ) .where( - DetectedObject.specific_label_score >= classification_threshold, + (models.TrapImage.base_path == str(image_base_path)) + & (models.DetectedObject.specific_label_score >= classification_threshold) ) - .where(models.MonitoringSession.base_directory == str(image_base_path)) + .join(models.TrapImage, models.DetectedObject.image_id == models.TrapImage.id) .group_by(DetectedObject.specific_label) .limit(limit) .offset(offset) From 09ec6ad4f5fcd91d2d8d486a2f0a8af1072fd074 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 17 Apr 2023 23:41:11 +0100 Subject: [PATCH 43/53] Don't create schema for sqlite --- trapdata/db/base.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/trapdata/db/base.py b/trapdata/db/base.py index ed19bd3a..397d5dcd 100644 --- a/trapdata/db/base.py +++ b/trapdata/db/base.py @@ -74,10 +74,10 @@ def create_db(db_path: DatabaseURL) -> None: from . import Base - with db.connect() as con: - if not db.dialect.has_schema(con, DATABASE_SCHEMA_NAMESPACE): - print("CREATING SCHEMS") - con.execute(sqlalchemy.schema.CreateSchema(DATABASE_SCHEMA_NAMESPACE)) + if db.dialect.name != "sqlite": + with db.connect() as con: + if not db.dialect.has_schema(con, DATABASE_SCHEMA_NAMESPACE): + con.execute(sqlalchemy.schema.CreateSchema(DATABASE_SCHEMA_NAMESPACE)) Base.metadata.schema = DATABASE_SCHEMA_NAMESPACE Base.metadata.create_all(db, checkfirst=True) alembic_cfg = get_alembic_config(db_path) From b979be56e1756f113b088252ede6384951636d23 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Mon, 17 Apr 2023 23:43:17 +0100 Subject: [PATCH 44/53] Update limits and examples --- trapdata/api/views/occurrences.py | 2 +- trapdata/api/views/species.py | 2 +- trapdata/db/models/detections.py | 2 +- trapdata/db/models/occurrences.py | 2 +- 4 files changed, 4 insertions(+), 4 deletions(-) diff --git a/trapdata/api/views/occurrences.py b/trapdata/api/views/occurrences.py index 6bd64895..ce3698b8 100644 --- a/trapdata/api/views/occurrences.py +++ b/trapdata/api/views/occurrences.py @@ -17,7 +17,7 @@ @router.get("", response_model=List[OccurrenceListItem]) async def get_occurrences( response: Response, - limit: int = 100, + limit: int = 20, offset: int = 0, # request_params: RequestParams = Depends(parse_react_admin_params(Base)), ) -> Any: diff --git a/trapdata/api/views/species.py b/trapdata/api/views/species.py index 845c3287..ac163807 100644 --- a/trapdata/api/views/species.py +++ b/trapdata/api/views/species.py @@ -19,7 +19,7 @@ async def get_species( response: Response, session: orm.Session = Depends(get_session), # request_params: RequestParams = Depends(parse_react_admin_params(Base)), - limit: int = 100, + limit: int = 20, offset: int = 0, ) -> Any: species = list_species( diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 3a0d1fd0..9dc4ae7a 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -561,7 +561,7 @@ def list_species( classification_threshold: int = 0, num_examples: int = 5, media_url_base: Optional[str] = None, - limit: int = 100, + limit: Optional[int] = None, offset: int = 0, ) -> list[TaxonListItem]: """ diff --git a/trapdata/db/models/occurrences.py b/trapdata/db/models/occurrences.py index 206883f3..9c1ec7a2 100644 --- a/trapdata/db/models/occurrences.py +++ b/trapdata/db/models/occurrences.py @@ -60,7 +60,7 @@ def list_occurrences( image_base_path: FilePath, monitoring_session: Optional[models.MonitoringSession] = None, classification_threshold: float = -1, - num_examples: int = 3, + num_examples: int = 5, limit: Optional[int] = None, offset: int = 0, media_url_base: Optional[str] = None, From fe4404949f7f5641dbb780e7b59d6d9bc8711e21 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 18 Apr 2023 00:05:55 +0100 Subject: [PATCH 45/53] Fix list of examples --- trapdata/db/models/detections.py | 24 ++++++++++++++++-------- 1 file changed, 16 insertions(+), 8 deletions(-) diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 9dc4ae7a..91c41e77 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -572,6 +572,7 @@ def list_species( species = session.execute( sa.select( DetectedObject.specific_label.label("name"), + sa.func.min(DetectedObject.sequence_id).label("sequence_id"), sa.func.count(DetectedObject.id).label("num_detections"), sa.func.count(DetectedObject.sequence_id.distinct()).label( "num_occurrences" @@ -591,17 +592,24 @@ def list_species( ).all() examples = ( - sa.select(DetectedObject) - .where(DetectedObject.specific_label.in_([sp.name for sp in species])) - .limit(num_examples) - .order_by(DetectedObject.specific_label_score.desc()) + session.execute( + sa.select(DetectedObject) + .where(DetectedObject.specific_label.in_([sp.name for sp in species])) + .limit(num_examples) + .order_by(DetectedObject.specific_label_score.desc()) + ) + .unique() + .scalars() + .all() ) + metadata_by_name = {} examples_by_name = {} - for detection in session.execute(examples).unique().scalars().all(): - examples_by_name.setdefault(detection.specific_label, []).append(detection) - - metadata_by_name = {sp.name: sp for sp in species} + for sp in species: + metadata_by_name[sp.name] = sp + examples_by_name[sp.name] = [ + ex for ex in examples if ex.sequence_id == sp.sequence_id + ] taxa = [ TaxonListItem( From 7a0af6667a32aa0ab60191b3cba158e021c91a37 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 18 Apr 2023 00:34:17 +0100 Subject: [PATCH 46/53] Fix species examples, but slow --- trapdata/db/models/detections.py | 51 +++++++++++++++++++++++--------- 1 file changed, 37 insertions(+), 14 deletions(-) diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 91c41e77..54520f26 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -591,25 +591,48 @@ def list_species( .offset(offset) ).all() - examples = ( - session.execute( - sa.select(DetectedObject) - .where(DetectedObject.specific_label.in_([sp.name for sp in species])) - .limit(num_examples) - .order_by(DetectedObject.specific_label_score.desc()) - ) - .unique() - .scalars() - .all() - ) + # examples = ( + # session.execute( + # sa.select(DetectedObject) + # .where(DetectedObject.specific_label.in_(sp.name for sp in species])) # @TODO not working! + # .where(models.TrapImage.base_path == str(image_base_path)) + # .where(DetectedObject.specific_label_score >= classification_threshold) + # .join( + # models.TrapImage, models.DetectedObject.image_id == models.TrapImage.id + # ) + # .limit(num_examples) + # .order_by(DetectedObject.specific_label_score.desc()) + # ) + # .unique() + # .scalars() + # .all() + # ) metadata_by_name = {} examples_by_name = {} for sp in species: metadata_by_name[sp.name] = sp - examples_by_name[sp.name] = [ - ex for ex in examples if ex.sequence_id == sp.sequence_id - ] + # matching_examples = [ex for ex in examples if ex.specific_label == sp.name] + matching_examples = ( + session.execute( + sa.select(DetectedObject) + .where(DetectedObject.specific_label == sp.name) + .where(models.TrapImage.base_path == str(image_base_path)) + .where(DetectedObject.specific_label_score >= classification_threshold) + .join( + models.TrapImage, + models.DetectedObject.image_id == models.TrapImage.id, + ) + .limit(num_examples) + .order_by(DetectedObject.specific_label_score.desc()) + .limit(num_examples) + ) + .unique() + .scalars() + .all() + ) + examples_by_name[sp.name] = matching_examples + print(sp.name, len(matching_examples)) taxa = [ TaxonListItem( From 08e5a3e54cf166123b143aa129343bdfb36c2760 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 18 Apr 2023 00:37:09 +0100 Subject: [PATCH 47/53] Todo --- trapdata/db/models/detections.py | 1 + 1 file changed, 1 insertion(+) diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 54520f26..cae9a6e5 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -568,6 +568,7 @@ def list_species( Return a list of unique species and example detections. @TODO compare this with list_species in occurrences.py + @TODO prefetch related and speed this up """ species = session.execute( sa.select( From 4ae15503a5e82e0894eaa772e3065944ae431e61 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 18 Apr 2023 00:39:50 +0100 Subject: [PATCH 48/53] Increase number of examples --- trapdata/db/models/detections.py | 2 +- trapdata/db/models/occurrences.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index cae9a6e5..9b10170b 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -559,7 +559,7 @@ def list_species( session: orm.Session, image_base_path: FilePath, classification_threshold: int = 0, - num_examples: int = 5, + num_examples: int = 10, media_url_base: Optional[str] = None, limit: Optional[int] = None, offset: int = 0, diff --git a/trapdata/db/models/occurrences.py b/trapdata/db/models/occurrences.py index 9c1ec7a2..6dbe7a72 100644 --- a/trapdata/db/models/occurrences.py +++ b/trapdata/db/models/occurrences.py @@ -60,7 +60,7 @@ def list_occurrences( image_base_path: FilePath, monitoring_session: Optional[models.MonitoringSession] = None, classification_threshold: float = -1, - num_examples: int = 5, + num_examples: int = 10, limit: Optional[int] = None, offset: int = 0, media_url_base: Optional[str] = None, From 40619248d47e986dd2fd8134bfaad93ca151069c Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 18 Apr 2023 00:50:33 +0100 Subject: [PATCH 49/53] Show species occurrences, not detections --- trapdata/db/models/detections.py | 45 ++++++++++++++++---------------- 1 file changed, 22 insertions(+), 23 deletions(-) diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 9b10170b..d4fef281 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -23,7 +23,7 @@ class DetectionListItem(BaseModel): - id: int + id: Optional[int] = None cropped_image_path: Optional[FilePath] = None bbox: Optional[tuple[float, float, float, float]] = None area_pixels: Optional[float] = None @@ -544,13 +544,18 @@ def num_occurrences_for_event( return sesh.execute(query).scalar_one() +class TaxonOccurrenceListItem(BaseModel): + id: str + cropped_image_path: Optional[FilePath] = None + + class TaxonListItem(BaseModel): name: str genus: Optional[str] = None family: Optional[str] = None num_occurrences: Optional[int] = None num_detections: Optional[int] = None - examples: list[DetectionListItem] = list() + examples: list[TaxonOccurrenceListItem] = list() score_stats: Optional[dict[str, float]] = None training_examples: Optional[int] = None @@ -614,24 +619,20 @@ def list_species( for sp in species: metadata_by_name[sp.name] = sp # matching_examples = [ex for ex in examples if ex.specific_label == sp.name] - matching_examples = ( - session.execute( - sa.select(DetectedObject) - .where(DetectedObject.specific_label == sp.name) - .where(models.TrapImage.base_path == str(image_base_path)) - .where(DetectedObject.specific_label_score >= classification_threshold) - .join( - models.TrapImage, - models.DetectedObject.image_id == models.TrapImage.id, - ) - .limit(num_examples) - .order_by(DetectedObject.specific_label_score.desc()) - .limit(num_examples) + matching_examples = session.execute( + sa.select(DetectedObject.sequence_id, DetectedObject.path) + .where(DetectedObject.specific_label == sp.name) + .where(models.TrapImage.base_path == str(image_base_path)) + .where(DetectedObject.specific_label_score >= classification_threshold) + .join( + models.TrapImage, + models.DetectedObject.image_id == models.TrapImage.id, ) - .unique() - .scalars() - .all() - ) + .group_by(DetectedObject.sequence_id) + .limit(num_examples) + .order_by(DetectedObject.specific_label_score.desc()) + .limit(num_examples) + ).all() examples_by_name[sp.name] = matching_examples print(sp.name, len(matching_examples)) @@ -646,15 +647,13 @@ def list_species( "mean": metadata_by_name[name].score_mean, }, examples=[ - DetectionListItem( - id=detection.id, + TaxonOccurrenceListItem( + id=detection.sequence_id, cropped_image_path=media_url( detection.path, "crops", media_url_base=media_url_base, ), - height=detection.height, - width=detection.width, ) for detection in examples ], From 4c69ce55df768294ba09ab411ec2709ec27b02c5 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 18 Apr 2023 00:52:27 +0100 Subject: [PATCH 50/53] Fix psql query --- trapdata/db/models/detections.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index d4fef281..136ffbfd 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -628,7 +628,7 @@ def list_species( models.TrapImage, models.DetectedObject.image_id == models.TrapImage.id, ) - .group_by(DetectedObject.sequence_id) + .group_by(DetectedObject.sequence_id, DetectedObject.path) .limit(num_examples) .order_by(DetectedObject.specific_label_score.desc()) .limit(num_examples) From a0718e3c18ad854faf6ed852232048b8e195be70 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 18 Apr 2023 00:31:07 +0000 Subject: [PATCH 51/53] Hacky way to get unique species occurrences --- trapdata/db/models/detections.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/trapdata/db/models/detections.py b/trapdata/db/models/detections.py index 136ffbfd..62a64205 100644 --- a/trapdata/db/models/detections.py +++ b/trapdata/db/models/detections.py @@ -619,18 +619,24 @@ def list_species( for sp in species: metadata_by_name[sp.name] = sp # matching_examples = [ex for ex in examples if ex.specific_label == sp.name] - matching_examples = session.execute( - sa.select(DetectedObject.sequence_id, DetectedObject.path) + matching_example_ids = session.execute( + sa.select( + DetectedObject.sequence_id, sa.func.min(DetectedObject.id).label("id") + ) .where(DetectedObject.specific_label == sp.name) .where(models.TrapImage.base_path == str(image_base_path)) .where(DetectedObject.specific_label_score >= classification_threshold) + .group_by(DetectedObject.sequence_id) .join( models.TrapImage, models.DetectedObject.image_id == models.TrapImage.id, ) - .group_by(DetectedObject.sequence_id, DetectedObject.path) - .limit(num_examples) - .order_by(DetectedObject.specific_label_score.desc()) + ).all() + matching_example_ids = [row.id for row in matching_example_ids] + # .group_by(DetectedObject.sequence_id, DetectedObject.path, DetectedObject.specific_label_score, models.TrapImage.base_path) + matching_examples = session.execute( + sa.select(DetectedObject.sequence_id, DetectedObject.path) + .where(DetectedObject.id.in_(matching_example_ids)) .limit(num_examples) ).all() examples_by_name[sp.name] = matching_examples From 93786010a5d8a18a344cb56897f62ebdbe25a11a Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Tue, 11 Jul 2023 04:20:03 +0000 Subject: [PATCH 52/53] Add gunicorn conf for API server --- trapdata/api/gunicorn_conf.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) create mode 100644 trapdata/api/gunicorn_conf.py diff --git a/trapdata/api/gunicorn_conf.py b/trapdata/api/gunicorn_conf.py new file mode 100644 index 00000000..8cabec15 --- /dev/null +++ b/trapdata/api/gunicorn_conf.py @@ -0,0 +1,14 @@ +# gunicorn_conf.py +from multiprocessing import cpu_count + +bind = "0.0.0.0:8000" + +# Worker Options +workers = cpu_count() + 1 +worker_class = 'uvicorn.workers.UvicornWorker' +timeout = 120 + +# Logging Options +loglevel = 'debug' +accesslog = '/home/debian/logs/access_log' +errorlog = '/home/debian/logs/error_log' From e14fe1a01c2daba72230f2018473528b64e352b9 Mon Sep 17 00:00:00 2001 From: Michael Bunsen Date: Wed, 9 Aug 2023 20:03:54 -0700 Subject: [PATCH 53/53] Updates for exporting occurrences (#54) * Support data import to AMI platform DB * Fix imports --- trapdata/cli/export.py | 22 +++++++++++++++++----- trapdata/common/filemanagement.py | 2 +- trapdata/common/logs.py | 2 +- trapdata/db/models/images.py | 7 ++++++- trapdata/db/models/occurrences.py | 6 +++++- 5 files changed, 30 insertions(+), 9 deletions(-) diff --git a/trapdata/cli/export.py b/trapdata/cli/export.py index cad216f7..789b2f1b 100644 --- a/trapdata/cli/export.py +++ b/trapdata/cli/export.py @@ -189,7 +189,7 @@ def sessions( @cli.command() def captures( - date: datetime.datetime, + date: Optional[datetime.datetime] = None, format: ExportFormat = ExportFormat.json, outfile: Optional[pathlib.Path] = None, ) -> Optional[str]: @@ -200,16 +200,28 @@ def captures( """ Session = get_session_class(settings.database_url) session = Session() + if date is not None: + event_dates = [date.date()] + else: + event_dates = [ + event.day + for event in get_monitoring_sessions_from_db( + db_path=settings.database_url, base_directory=settings.image_base_path + ) + ] events = get_monitoring_session_by_date( db_path=settings.database_url, base_directory=settings.image_base_path, - event_dates=[date.date()], + event_dates=event_dates, ) - if not len(events): + if date and not len(events): raise Exception(f"No Monitoring Event with date: {date.date()}") - event = events[0] - captures = get_monitoring_session_images(settings.database_url, event, limit=100) + captures = [] + for event in events: + captures += get_monitoring_session_images( + settings.database_url, event, limit=100 + ) [session.add(img) for img in captures] df = pd.DataFrame([img.report_detail().dict() for img in captures]) diff --git a/trapdata/common/filemanagement.py b/trapdata/common/filemanagement.py index 48840ae5..26a6a017 100644 --- a/trapdata/common/filemanagement.py +++ b/trapdata/common/filemanagement.py @@ -20,7 +20,7 @@ from . import constants from .logs import logger -from .types import FilePath +from .schemas import FilePath APP_NAME_SLUG = "AMI" EXIF_DATETIME_STR_FORMAT = "%Y:%m:%d %H:%M:%S" diff --git a/trapdata/common/logs.py b/trapdata/common/logs.py index e0c2f7cb..cf4b4e7e 100644 --- a/trapdata/common/logs.py +++ b/trapdata/common/logs.py @@ -3,7 +3,7 @@ import structlog structlog.configure( - wrapper_class=structlog.make_filtering_bound_logger(logging.INFO), + wrapper_class=structlog.make_filtering_bound_logger(logging.DEBUG), ) diff --git a/trapdata/db/models/images.py b/trapdata/db/models/images.py index 43378000..457f0bca 100644 --- a/trapdata/db/models/images.py +++ b/trapdata/db/models/images.py @@ -27,6 +27,9 @@ class CaptureListItem(BaseModel): class CaptureDetail(CaptureListItem): event: object + url: Optional[str] = None + event: object + deployment: str notes: Optional[str] detections: list filesize: int @@ -123,17 +126,19 @@ def report_data(self) -> CaptureListItem: return CaptureListItem( id=self.id, source_image=f"{constants.IMAGE_BASE_URL}vermont/snapshots/{self.path}", + path=self.path, timestamp=self.timestamp, last_read=self.last_read, last_processed=self.last_processed, in_queue=self.in_queue, num_detections=self.num_detected_objects, + event=self.monitoring_session.day, + deployment=self.monitoring_session.deployment, ) def report_detail(self) -> CaptureDetail: return CaptureDetail( **self.report_data().dict(), - event=self.monitoring_session.day, width=self.width, height=self.height, filesize=self.filesize, diff --git a/trapdata/db/models/occurrences.py b/trapdata/db/models/occurrences.py index 6dbe7a72..fc8bb072 100644 --- a/trapdata/db/models/occurrences.py +++ b/trapdata/db/models/occurrences.py @@ -15,7 +15,7 @@ from trapdata import db from trapdata.common.filemanagement import media_url -from trapdata.common.types import FilePath +from trapdata.common.schemas import FilePath from trapdata.db import models @@ -224,11 +224,15 @@ def get_unique_species_by_track( models.DetectedObject.id, models.DetectedObject.image_id.label("source_image_id"), models.TrapImage.path.label("source_image_path"), + models.TrapImage.width.label("source_image_width"), + models.TrapImage.height.label("source_image_height"), + models.TrapImage.filesize.label("source_image_filesize"), models.DetectedObject.specific_label.label("label"), models.DetectedObject.specific_label_score.label("score"), models.DetectedObject.path.label("cropped_image_path"), models.DetectedObject.sequence_id, models.DetectedObject.timestamp, + models.DetectedObject.bbox, ) .join( models.TrapImage, models.TrapImage.id == models.DetectedObject.image_id