From 84f8ae8385405f48c729626de62a5f9c0107000d Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Date: Sun, 28 Jun 2026 11:28:20 +0530 Subject: [PATCH 1/2] Add voice input + async transcription(#547) --- app/api/routes/forms.py | 35 ++--- app/api/routes/input.py | 93 ++++++++++++- app/api/schemas/input.py | 10 ++ app/core/celery.py | 2 +- app/core/config.py | 9 +- app/db/repositories.py | 7 + app/services/input.py | 19 +++ app/services/whisper.py | 39 ++++++ app/tasks/transcribe.py | 109 ++++++++++++++++ tests/conftest.py | 6 + tests/test_v1_voice.py | 275 +++++++++++++++++++++++++++++++++++++++ 11 files changed, 572 insertions(+), 32 deletions(-) create mode 100644 app/services/whisper.py create mode 100644 app/tasks/transcribe.py create mode 100644 tests/test_v1_voice.py diff --git a/app/api/routes/forms.py b/app/api/routes/forms.py index 87e4a16..879fa39 100644 --- a/app/api/routes/forms.py +++ b/app/api/routes/forms.py @@ -11,7 +11,8 @@ ModelsResponse, TranscriptionResponse, ) -from app.core.config import OLLAMA_HOST, OLLAMA_MODEL, WHISPER_HOST, BASE_DIR, RETENTION_PERIOD_DAYS +from app.core.config import OLLAMA_HOST, OLLAMA_MODEL, BASE_DIR, RETENTION_PERIOD_DAYS +from app.services.whisper import call_whisper_asr from app.core.errors.base import AppError from app.db.repositories import create_form, get_template, get_form_submission, delete_form_submission from app.models import FormSubmission @@ -94,34 +95,16 @@ def transcribe(audio: UploadFile = File(...)): audio is streamed straight through to the local STT service and never persisted — no PII leaves the machine. """ - whisper_url = f"{WHISPER_HOST}/asr" - - files = { - "audio_file": ( - audio.filename or "audio.wav", + try: + text = call_whisper_asr( audio.file.read(), + audio.filename or "audio.wav", audio.content_type or "audio/wav", ) - } - params = {"task": "transcribe", "output": "json", "encode": "true"} - - try: - response = requests.post(whisper_url, params=params, files=files, timeout=120) - response.raise_for_status() - except requests.exceptions.ConnectionError: - raise AppError( - f"Could not connect to the speech-to-text service at {whisper_url}. " - "Please ensure the whisper service is running.", - status_code=503, - error_code="STT_UNAVAILABLE", - ) - except requests.exceptions.RequestException as e: - raise AppError(f"Transcription failed: {e}", status_code=502, error_code="TRANSCRIPTION_FAILED") - - try: - text = (response.json().get("text") or "").strip() - except ValueError: - text = response.text.strip() + except ConnectionError as exc: + raise AppError(str(exc), status_code=503, error_code="STT_UNAVAILABLE") + except RuntimeError as exc: + raise AppError(str(exc), status_code=502, error_code="TRANSCRIPTION_FAILED") return TranscriptionResponse(text=text) diff --git a/app/api/routes/input.py b/app/api/routes/input.py index 009bdf0..b677a34 100644 --- a/app/api/routes/input.py +++ b/app/api/routes/input.py @@ -1,19 +1,104 @@ +from datetime import date +from pathlib import Path from uuid import UUID -from fastapi import APIRouter, Depends +from fastapi import APIRouter, Depends, File, Form, UploadFile from fastapi.exceptions import RequestValidationError from sqlmodel import Session from app.api.deps import get_db from app.api.schemas.enums import InputStatus -from app.api.schemas.input import InputRecordResponse, TextInputRequest, TextInputResponse -from app.core.config import INPUT_POLL_INTERVAL_SECONDS +from app.api.schemas.input import ( + InputRecordResponse, + TextInputRequest, + TextInputResponse, + VoiceInputResponse, +) +from app.core.config import AUDIO_DIR, ESTIMATED_TRANSCRIPTION_SECONDS, INPUT_POLL_INTERVAL_SECONDS from app.core.errors.base import AppError -from app.db.repositories import create_input, get_input as repo_get_input +from app.db.repositories import create_input, create_job, get_input as repo_get_input, update_job +from app.models import Job from app.services.input import InputService +from app.services.whisper import check_whisper_available +from app.tasks.transcribe import transcribe_audio_task router = APIRouter(prefix="/input", tags=["input"]) +_ALLOWED_AUDIO_EXTS = {"wav", "mp3", "m4a", "ogg", "webm"} +_MAX_AUDIO_BYTES = 500 * 1024 * 1024 # 500 MB + + +@router.post("/voice", response_model=VoiceInputResponse, status_code=201) +def submit_voice_input( + audio_file: UploadFile = File(...), + station_id: str | None = Form(default=None), + responder_badge: str | None = Form(default=None), + incident_date_hint: date | None = Form(default=None), + db: Session = Depends(get_db), +): + # validate format + filename = audio_file.filename or "" + ext = Path(filename).suffix.lstrip(".").lower() + if not ext or ext not in _ALLOWED_AUDIO_EXTS: + raise AppError( + "Audio format not supported. Accepted formats: wav, mp3, m4a, ogg, webm", + status_code=415, + error_code="UNSUPPORTED_FORMAT", + detail={"accepted_formats": ["wav", "mp3", "m4a", "ogg", "webm"]}, + ) + + # validate size + content = audio_file.file.read() + if len(content) > _MAX_AUDIO_BYTES: + raise AppError( + "Audio file exceeds maximum size of 500MB", + status_code=413, + error_code="FILE_TOO_LARGE", + detail={"max_size_bytes": _MAX_AUDIO_BYTES, "received_size_bytes": len(content)}, + ) + + # check Whisper availability before queuing + if not check_whisper_available(): + raise AppError( + "Whisper transcription service is not available", + status_code=503, + error_code="LLM_UNAVAILABLE", + ) + + # build and persist the Input record + svc = InputService() + record = svc.build_voice_input( + original_filename=filename, + station_id=station_id, + responder_badge=responder_badge, + incident_date_hint=incident_date_hint, + ) + record = create_input(db, record) + + # save audio to DATA_DIR/audio/{input_id}.{ext} + AUDIO_DIR.mkdir(parents=True, exist_ok=True) + audio_path = AUDIO_DIR / f"{record.input_id}.{ext}" + audio_path.write_bytes(content) + + # create Job before dispatch to avoid race + job = Job(celery_task_id="", job_type="transcription", status="queued") + job = create_job(db, job) + + # dispatch and backfill celery_task_id + result = transcribe_audio_task.delay(str(record.input_id), str(audio_path), job.job_id) + job.celery_task_id = result.id + job = update_job(db, job) + + return VoiceInputResponse( + input_id=record.input_id, + status=record.status, + input_type=record.input_type, + job_id=job.job_id, + poll_url=f"/api/v1/input/{record.input_id}", + estimated_processing_seconds=ESTIMATED_TRANSCRIPTION_SECONDS, + created_at=record.created_at, + ) + @router.post("/text", response_model=TextInputResponse, status_code=201) def submit_text_input(body: TextInputRequest, db: Session = Depends(get_db)): diff --git a/app/api/schemas/input.py b/app/api/schemas/input.py index cc37daf..9d8b908 100644 --- a/app/api/schemas/input.py +++ b/app/api/schemas/input.py @@ -8,6 +8,16 @@ from app.api.schemas.enums import InputStatus, InputType +class VoiceInputResponse(BaseModel): + input_id: UUID + status: InputStatus + input_type: InputType + estimated_processing_seconds: int | None = None + created_at: datetime | None = None + job_id: str | None = None + poll_url: str | None = None + + class TextInputRequest(BaseModel): narrative: str = Field(min_length=20) station_id: str | None = None diff --git a/app/core/celery.py b/app/core/celery.py index f763702..a91146f 100644 --- a/app/core/celery.py +++ b/app/core/celery.py @@ -16,7 +16,7 @@ result_expires=86400, ) -celery_app.conf.include = ["app.tasks.fill", "app.tasks.purge"] +celery_app.conf.include = ["app.tasks.fill", "app.tasks.purge", "app.tasks.transcribe"] # Optional Celery Beat schedule — runs purge_old_submissions once a day. # Enable by running: celery -A app.core.celery beat diff --git a/app/core/config.py b/app/core/config.py index e0f41f8..2a11b56 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -65,4 +65,11 @@ FIREFORM_API_KEY = os.getenv("FIREFORM_API_KEY", "") # --- Data Retention -------------------------------------------------------- -RETENTION_PERIOD_DAYS = int(os.getenv("RETENTION_PERIOD_DAYS", "30")) \ No newline at end of file +RETENTION_PERIOD_DAYS = int(os.getenv("RETENTION_PERIOD_DAYS", "30")) + +# --- Audio storage -------------------------------------------------------- +# Voice input audio files land here: {AUDIO_DIR}/{input_id}.{ext} +AUDIO_DIR = DATA_DIR / "audio" + +# Advisory estimate returned in VoiceInputResponse.estimated_processing_seconds. +ESTIMATED_TRANSCRIPTION_SECONDS = int(os.getenv("ESTIMATED_TRANSCRIPTION_SECONDS", "30")) \ No newline at end of file diff --git a/app/db/repositories.py b/app/db/repositories.py index 9186d61..0935c13 100644 --- a/app/db/repositories.py +++ b/app/db/repositories.py @@ -81,3 +81,10 @@ def create_input(session: Session, input_obj: Input) -> Input: def get_input(session: Session, input_id: UUID) -> Input | None: return session.get(Input, input_id) + +def update_input(session: Session, input_obj: Input) -> Input: + session.add(input_obj) + session.commit() + session.refresh(input_obj) + return input_obj + diff --git a/app/services/input.py b/app/services/input.py index b04868c..a6fa229 100644 --- a/app/services/input.py +++ b/app/services/input.py @@ -5,6 +5,25 @@ class InputService: + def build_voice_input( + self, + original_filename: str, + station_id: str | None = None, + responder_badge: str | None = None, + incident_date_hint: date | None = None, + ) -> Input: + now = datetime.now(timezone.utc) + return Input( + input_type=InputType.voice, + status=InputStatus.queued, + original_filename=original_filename, + station_id=station_id, + responder_badge=responder_badge, + incident_date_hint=incident_date_hint, + created_at=now, + updated_at=now, + ) + def build_text_input( self, narrative: str, diff --git a/app/services/whisper.py b/app/services/whisper.py new file mode 100644 index 0000000..bbc13fb --- /dev/null +++ b/app/services/whisper.py @@ -0,0 +1,39 @@ +import requests + +from app.core.config import WHISPER_HOST + + +def call_whisper_asr(audio_bytes: bytes, filename: str, content_type: str) -> str: + """Post audio to the local Whisper ASR sidecar and return the transcript. + + Raises ConnectionError if the service is unreachable, RuntimeError for any + other HTTP failure. Callers map these to their own error codes. + """ + whisper_url = f"{WHISPER_HOST}/asr" + files = {"audio_file": (filename, audio_bytes, content_type)} + params = {"task": "transcribe", "output": "json", "encode": "true"} + + try: + response = requests.post(whisper_url, params=params, files=files, timeout=120) + response.raise_for_status() + except requests.exceptions.ConnectionError as exc: + raise ConnectionError( + f"Could not connect to the speech-to-text service at {whisper_url}. " + "Please ensure the whisper service is running." + ) from exc + except requests.exceptions.RequestException as exc: + raise RuntimeError(f"Transcription failed: {exc}") from exc + + try: + return (response.json().get("text") or "").strip() + except ValueError: + return response.text.strip() + + +def check_whisper_available() -> bool: + """Return True if the Whisper sidecar responds on its base URL.""" + try: + requests.get(WHISPER_HOST, timeout=3) + return True + except requests.exceptions.RequestException: + return False diff --git a/app/tasks/transcribe.py b/app/tasks/transcribe.py new file mode 100644 index 0000000..76517ec --- /dev/null +++ b/app/tasks/transcribe.py @@ -0,0 +1,109 @@ +import logging +import wave +from datetime import datetime, timezone +from pathlib import Path +from uuid import UUID + +from app.api.schemas.enums import InputStatus +from app.core.celery import celery_app +from app.db.database import get_session +from app.db.repositories import get_input, get_job_by_uuid, update_input, update_job +from app.services.whisper import call_whisper_asr + +logger = logging.getLogger(__name__) + +_CONTENT_TYPES: dict[str, str] = { + "wav": "audio/wav", + "mp3": "audio/mpeg", + "m4a": "audio/m4a", + "ogg": "audio/ogg", + "webm": "audio/webm", +} + + +def _wav_duration(path: Path) -> float | None: + """Return duration in seconds for a WAV file; None for all other formats or on error.""" + if path.suffix.lower() != ".wav": + return None + try: + with wave.open(str(path)) as wf: + return wf.getnframes() / wf.getframerate() + except Exception: + return None + + +def _fail_records( + session, + input_id: UUID, + job_id_str: str, + error_code: str, + message: str, +) -> None: + now = datetime.now(timezone.utc) + record = get_input(session, input_id) + if record: + record.status = InputStatus.failed + record.error_detail = message + record.updated_at = now + update_input(session, record) + job = get_job_by_uuid(session, job_id_str) + if job: + job.status = "failed" + job.error = {"error_code": error_code, "message": message} + job.updated_at = now + update_job(session, job) + + +@celery_app.task(name="transcribe_audio") +def transcribe_audio_task(input_id_str: str, audio_path: str, job_id_str: str) -> dict: + session = next(get_session()) + input_id = UUID(input_id_str) + try: + input_record = get_input(session, input_id) + job = get_job_by_uuid(session, job_id_str) + + # start: mark both records in-flight + now = datetime.now(timezone.utc) + input_record.status = InputStatus.transcribing + input_record.updated_at = now + update_input(session, input_record) + + job.status = "processing" + job.updated_at = now + update_job(session, job) + + # transcribe + p = Path(audio_path) + ext = p.suffix.lstrip(".") + text = call_whisper_asr(p.read_bytes(), p.name, _CONTENT_TYPES.get(ext, "audio/wav")) + + # success + words = text.split() + now = datetime.now(timezone.utc) + input_record.status = InputStatus.ready + input_record.transcript = text + input_record.character_count = len(text) + input_record.word_count = len(words) + input_record.audio_duration_seconds = _wav_duration(p) + input_record.updated_at = now + update_input(session, input_record) + + job.status = "completed" + job.result_url = f"/api/v1/input/{input_id_str}" + job.updated_at = now + update_job(session, job) + + return {"input_id": input_id_str, "job_id": job_id_str} + + except ConnectionError as exc: + logger.exception("transcribe_audio_task: Whisper unavailable for input %s", input_id_str) + _fail_records(session, input_id, job_id_str, "LLM_UNAVAILABLE", str(exc)) + raise + + except RuntimeError as exc: + logger.exception("transcribe_audio_task: transcription failed for input %s", input_id_str) + _fail_records(session, input_id, job_id_str, "TRANSCRIPTION_FAILED", str(exc)) + raise + + finally: + session.close() diff --git a/tests/conftest.py b/tests/conftest.py index 1f41d67..ab4370c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -55,6 +55,12 @@ def db(): yield session +@pytest.fixture +def test_engine(): + """Expose the shared in-memory engine for tests that need to open extra sessions.""" + return _engine + + # --------------------------------------------------------------------------- # Minimal PDF bytes (valid 1-page blank PDF) # --------------------------------------------------------------------------- diff --git a/tests/test_v1_voice.py b/tests/test_v1_voice.py new file mode 100644 index 0000000..fd2e7f3 --- /dev/null +++ b/tests/test_v1_voice.py @@ -0,0 +1,275 @@ +"""Tests for POST /api/v1/input/voice (endpoint) and transcribe_audio_task (task unit). + +Endpoint tests: dispatch is mocked — no broker, no Whisper, no filesystem writes. +Task tests: called directly with an injected test session and mocked call_whisper_asr. +""" + +import io +import pytest +from pathlib import Path +from unittest.mock import MagicMock, patch + +from sqlmodel import Session + +from app.api.schemas.enums import InputStatus, InputType +from app.db.repositories import get_input, get_job_by_uuid +from app.models import Input, Job +from app.tasks.transcribe import transcribe_audio_task + +VOICE_URL = "/api/v1/input/voice" + +# Minimal valid WAV bytes (44-byte header + silence) — passes extension check. +_WAV_BYTES = b"RIFF$\x00\x00\x00WAVEfmt \x10\x00\x00\x00\x01\x00\x01\x00\x80>\x00\x00\x00}\x00\x00\x02\x00\x10\x00data\x00\x00\x00\x00" + + +def _voice_post(client, filename="memo.wav", content=None, extra_fields=None): + """Helper: POST a voice upload, mocking dispatch and availability.""" + data = extra_fields or {} + files = {"audio_file": (filename, io.BytesIO(content or _WAV_BYTES), "audio/wav")} + return client, data, files + + +# --------------------------------------------------------------------------- +# POST /api/v1/input/voice — endpoint tests (dispatch mocked) +# --------------------------------------------------------------------------- + +class TestSubmitVoiceInput: + + def _post(self, client, filename="memo.wav", content=None, fields=None, whisper_up=True): + mock_result = MagicMock() + mock_result.id = "celery-task-uuid-001" + with patch("app.api.routes.input.check_whisper_available", return_value=whisper_up), \ + patch("app.api.routes.input.transcribe_audio_task") as mock_task, \ + patch("pathlib.Path.write_bytes"): + mock_task.delay.return_value = mock_result + resp = client.post( + VOICE_URL, + files={"audio_file": (filename, io.BytesIO(content or _WAV_BYTES), "audio/wav")}, + data=fields or {}, + ) + return resp, mock_task + + def test_201_returns_required_fields(self, client): + resp, _ = self._post(client) + assert resp.status_code == 201 + body = resp.json() + assert body["status"] == "queued" + assert body["input_type"] == "voice" + assert "input_id" in body + assert "job_id" in body + assert "poll_url" in body + assert "estimated_processing_seconds" in body + + def test_201_poll_url_points_to_input_record(self, client): + resp, _ = self._post(client) + body = resp.json() + assert body["poll_url"] == f"/api/v1/input/{body['input_id']}" + + def test_201_creates_queued_input_row(self, client, db): + resp, _ = self._post(client) + input_id = resp.json()["input_id"] + from uuid import UUID + record = get_input(db, UUID(input_id)) + assert record is not None + assert record.status == InputStatus.queued + assert record.input_type == InputType.voice + assert record.original_filename == "memo.wav" + + def test_201_creates_transcription_job_row(self, client, db): + resp, _ = self._post(client) + job_id = resp.json()["job_id"] + job = get_job_by_uuid(db, job_id) + assert job is not None + assert job.job_type == "transcription" + assert job.status == "queued" + assert job.celery_task_id == "celery-task-uuid-001" + + def test_201_dispatch_called_with_input_id_and_job_id(self, client, db): + resp, mock_task = self._post(client) + body = resp.json() + mock_task.delay.assert_called_once() + call_args = mock_task.delay.call_args[0] + assert call_args[0] == body["input_id"] # input_id_str + assert call_args[2] == body["job_id"] # job_id_str + # call_args[1] is the audio_path (filesystem path string) + assert body["input_id"] in call_args[1] + + def test_201_optional_metadata_stored_on_input(self, client, db): + resp, _ = self._post(client, fields={ + "station_id": "STA-045", + "responder_badge": "FD-7842", + "incident_date_hint": "2024-07-10", + }) + assert resp.status_code == 201 + from uuid import UUID + record = get_input(db, UUID(resp.json()["input_id"])) + assert record.station_id == "STA-045" + assert record.responder_badge == "FD-7842" + assert str(record.incident_date_hint) == "2024-07-10" + + def test_415_unsupported_format_rejected(self, client): + resp, _ = self._post(client, filename="memo.pdf") + assert resp.status_code == 415 + body = resp.json() + assert body["error_code"] == "UNSUPPORTED_FORMAT" + assert "accepted_formats" in body["detail"] + assert "wav" in body["detail"]["accepted_formats"] + + def test_415_no_extension_rejected(self, client): + resp, _ = self._post(client, filename="audiofile") + assert resp.status_code == 415 + assert resp.json()["error_code"] == "UNSUPPORTED_FORMAT" + + def test_413_oversize_file_rejected(self, client, monkeypatch): + monkeypatch.setattr("app.api.routes.input._MAX_AUDIO_BYTES", 5) + resp, _ = self._post(client, content=b"toolong") # 7 bytes > 5 + assert resp.status_code == 413 + body = resp.json() + assert body["error_code"] == "FILE_TOO_LARGE" + assert body["detail"]["max_size_bytes"] == 5 + assert body["detail"]["received_size_bytes"] == 7 + + def test_503_when_whisper_unavailable(self, client): + resp, _ = self._post(client, whisper_up=False) + assert resp.status_code == 503 + body = resp.json() + assert body["error_code"] == "LLM_UNAVAILABLE" + + def test_422_missing_audio_file(self, client): + with patch("app.api.routes.input.check_whisper_available", return_value=True): + resp = client.post(VOICE_URL, data={}) + assert resp.status_code == 422 + + +# --------------------------------------------------------------------------- +# transcribe_audio_task — unit tests (no broker, direct call, mocked Whisper) +# --------------------------------------------------------------------------- + +def _seed_input_and_job(session: Session): + """Insert a queued Input and Job, return both refreshed.""" + from datetime import datetime, timezone + now = datetime.now(timezone.utc) + inp = Input( + input_type=InputType.voice, + status=InputStatus.queued, + original_filename="test.wav", + created_at=now, + updated_at=now, + ) + session.add(inp) + job = Job(celery_task_id="celery-test-id", job_type="transcription", status="queued") + session.add(job) + session.commit() + session.refresh(inp) + session.refresh(job) + return inp, job + + +class TestTranscribeAudioTask: + + def _run_task(self, inp, job, test_engine, whisper_side_effect=None, whisper_return="Transcribed text here."): + """ + Call the task function directly against the in-memory DB. + Mocks get_session (so the task uses the test engine) and call_whisper_asr. + Also mocks Path.read_bytes so no filesystem access is needed. + """ + def _gen(): + with Session(test_engine) as s: + yield s + + with patch("app.tasks.transcribe.get_session", side_effect=_gen), \ + patch("app.tasks.transcribe.call_whisper_asr") as mock_whisper, \ + patch.object(Path, "read_bytes", return_value=b"fake_audio_bytes"): + if whisper_side_effect is not None: + mock_whisper.side_effect = whisper_side_effect + else: + mock_whisper.return_value = whisper_return + try: + transcribe_audio_task(str(inp.input_id), "/data/audio/test.wav", job.job_id) + except (ConnectionError, RuntimeError): + pass # expected on failure tests + + def test_success_input_becomes_ready_with_transcript(self, db, test_engine): + inp, job = _seed_input_and_job(db) + + self._run_task(inp, job, test_engine, whisper_return="Incident at Main St, two casualties.") + + db.refresh(inp) + assert inp.status == InputStatus.ready + assert inp.transcript == "Incident at Main St, two casualties." + assert inp.character_count == len("Incident at Main St, two casualties.") + assert inp.word_count == len("Incident at Main St, two casualties.".split()) + + def test_success_job_becomes_completed_with_result_url(self, db, test_engine): + inp, job = _seed_input_and_job(db) + + self._run_task(inp, job, test_engine, whisper_return="Incident at Main St.") + + db.refresh(job) + assert job.status == "completed" + assert job.result_url == f"/api/v1/input/{inp.input_id}" + + def test_connection_error_input_failed_error_detail(self, db, test_engine): + inp, job = _seed_input_and_job(db) + + self._run_task(inp, job, test_engine, whisper_side_effect=ConnectionError("Cannot reach Whisper")) + + db.refresh(inp) + assert inp.status == InputStatus.failed + assert "Cannot reach Whisper" in inp.error_detail + + def test_connection_error_job_failed_with_llm_unavailable_code(self, db, test_engine): + inp, job = _seed_input_and_job(db) + + self._run_task(inp, job, test_engine, whisper_side_effect=ConnectionError("Cannot reach Whisper")) + + db.refresh(job) + assert job.status == "failed" + assert job.error["error_code"] == "LLM_UNAVAILABLE" + assert "Cannot reach Whisper" in job.error["message"] + + def test_runtime_error_input_failed_error_detail(self, db, test_engine): + inp, job = _seed_input_and_job(db) + + self._run_task(inp, job, test_engine, whisper_side_effect=RuntimeError("HTTP 500 from Whisper")) + + db.refresh(inp) + assert inp.status == InputStatus.failed + assert "HTTP 500 from Whisper" in inp.error_detail + + def test_runtime_error_job_failed_with_transcription_failed_code(self, db, test_engine): + inp, job = _seed_input_and_job(db) + + self._run_task(inp, job, test_engine, whisper_side_effect=RuntimeError("HTTP 500 from Whisper")) + + db.refresh(job) + assert job.status == "failed" + assert job.error["error_code"] == "TRANSCRIPTION_FAILED" + assert "HTTP 500 from Whisper" in job.error["message"] + + def test_success_input_transitions_through_transcribing(self, db, test_engine): + """The task sets transcribing before calling Whisper and ready after.""" + inp, job = _seed_input_and_job(db) + observed_statuses: list[str] = [] + + original_update = __import__( + "app.db.repositories", fromlist=["update_input"] + ).update_input + + def tracking_update(session, input_obj): + observed_statuses.append(input_obj.status.value) + return original_update(session, input_obj) + + def _gen(): + with Session(test_engine) as s: + yield s + + with patch("app.tasks.transcribe.get_session", side_effect=_gen), \ + patch("app.tasks.transcribe.call_whisper_asr", return_value="Hello world text."), \ + patch("app.tasks.transcribe.update_input", side_effect=tracking_update), \ + patch.object(Path, "read_bytes", return_value=b"fake"): + transcribe_audio_task(str(inp.input_id), "/data/audio/test.wav", job.job_id) + + assert "transcribing" in observed_statuses + assert "ready" in observed_statuses + assert observed_statuses.index("transcribing") < observed_statuses.index("ready") From b623461dc8fe17c1fcd97d9fa927eb715ee757b2 Mon Sep 17 00:00:00 2001 From: Abhishek Kumar Date: Tue, 30 Jun 2026 00:08:05 +0530 Subject: [PATCH 2/2] Address #547 review --- app/api/routes/input.py | 54 ++++++++++++++++++--------------------- app/core/config.py | 14 +++++++++- app/services/input.py | 44 ++++++++++++++++++++++++++++++- app/services/whisper.py | 5 ++-- app/tasks/transcribe.py | 13 +++------- contracts/path/input.yaml | 6 ++--- tests/test_v1_voice.py | 16 ++++-------- 7 files changed, 94 insertions(+), 58 deletions(-) diff --git a/app/api/routes/input.py b/app/api/routes/input.py index b677a34..2cd2329 100644 --- a/app/api/routes/input.py +++ b/app/api/routes/input.py @@ -14,17 +14,19 @@ TextInputResponse, VoiceInputResponse, ) -from app.core.config import AUDIO_DIR, ESTIMATED_TRANSCRIPTION_SECONDS, INPUT_POLL_INTERVAL_SECONDS +from app.core.config import ( + ALLOWED_AUDIO_EXTENSIONS, + AUDIO_CONTENT_TYPES, + ESTIMATED_TRANSCRIPTION_SECONDS, + INPUT_POLL_INTERVAL_SECONDS, +) from app.core.errors.base import AppError -from app.db.repositories import create_input, create_job, get_input as repo_get_input, update_job -from app.models import Job +from app.db.repositories import create_input, get_input as repo_get_input from app.services.input import InputService from app.services.whisper import check_whisper_available -from app.tasks.transcribe import transcribe_audio_task router = APIRouter(prefix="/input", tags=["input"]) -_ALLOWED_AUDIO_EXTS = {"wav", "mp3", "m4a", "ogg", "webm"} _MAX_AUDIO_BYTES = 500 * 1024 * 1024 # 500 MB @@ -36,18 +38,25 @@ def submit_voice_input( incident_date_hint: date | None = Form(default=None), db: Session = Depends(get_db), ): - # validate format + # 415 — format check (HTTP concern, before any I/O) filename = audio_file.filename or "" ext = Path(filename).suffix.lstrip(".").lower() - if not ext or ext not in _ALLOWED_AUDIO_EXTS: + if not ext or ext not in ALLOWED_AUDIO_EXTENSIONS: raise AppError( "Audio format not supported. Accepted formats: wav, mp3, m4a, ogg, webm", status_code=415, error_code="UNSUPPORTED_FORMAT", - detail={"accepted_formats": ["wav", "mp3", "m4a", "ogg", "webm"]}, + detail={"accepted_formats": list(AUDIO_CONTENT_TYPES)}, ) - # validate size + # 413 — size check: fast path via Content-Length (no read), fallback via len(bytes) + if audio_file.size is not None and audio_file.size > _MAX_AUDIO_BYTES: + raise AppError( + "Audio file exceeds maximum size of 500MB", + status_code=413, + error_code="FILE_TOO_LARGE", + detail={"max_size_bytes": _MAX_AUDIO_BYTES, "received_size_bytes": audio_file.size}, + ) content = audio_file.file.read() if len(content) > _MAX_AUDIO_BYTES: raise AppError( @@ -57,37 +66,24 @@ def submit_voice_input( detail={"max_size_bytes": _MAX_AUDIO_BYTES, "received_size_bytes": len(content)}, ) - # check Whisper availability before queuing + # 503 — Whisper availability (HTTP concern) if not check_whisper_available(): raise AppError( "Whisper transcription service is not available", status_code=503, - error_code="LLM_UNAVAILABLE", + error_code="STT_UNAVAILABLE", ) - # build and persist the Input record svc = InputService() - record = svc.build_voice_input( - original_filename=filename, + record, job = svc.process_voice_upload( + session=db, + audio_content=content, + ext=ext, + filename=filename, station_id=station_id, responder_badge=responder_badge, incident_date_hint=incident_date_hint, ) - record = create_input(db, record) - - # save audio to DATA_DIR/audio/{input_id}.{ext} - AUDIO_DIR.mkdir(parents=True, exist_ok=True) - audio_path = AUDIO_DIR / f"{record.input_id}.{ext}" - audio_path.write_bytes(content) - - # create Job before dispatch to avoid race - job = Job(celery_task_id="", job_type="transcription", status="queued") - job = create_job(db, job) - - # dispatch and backfill celery_task_id - result = transcribe_audio_task.delay(str(record.input_id), str(audio_path), job.job_id) - job.celery_task_id = result.id - job = update_job(db, job) return VoiceInputResponse( input_id=record.input_id, diff --git a/app/core/config.py b/app/core/config.py index 2a11b56..4a7ca89 100644 --- a/app/core/config.py +++ b/app/core/config.py @@ -72,4 +72,16 @@ AUDIO_DIR = DATA_DIR / "audio" # Advisory estimate returned in VoiceInputResponse.estimated_processing_seconds. -ESTIMATED_TRANSCRIPTION_SECONDS = int(os.getenv("ESTIMATED_TRANSCRIPTION_SECONDS", "30")) \ No newline at end of file +ESTIMATED_TRANSCRIPTION_SECONDS = int(os.getenv("ESTIMATED_TRANSCRIPTION_SECONDS", "30")) + +# Canonical audio format mapping — single source of truth for both the route +# (membership check, 415 detail list) and the task (content-type lookup). +# Dict insertion order gives the stable list shown in error responses. +AUDIO_CONTENT_TYPES: dict[str, str] = { + "wav": "audio/wav", + "mp3": "audio/mpeg", + "m4a": "audio/m4a", + "ogg": "audio/ogg", + "webm": "audio/webm", +} +ALLOWED_AUDIO_EXTENSIONS: frozenset[str] = frozenset(AUDIO_CONTENT_TYPES) \ No newline at end of file diff --git a/app/services/input.py b/app/services/input.py index a6fa229..1aa2892 100644 --- a/app/services/input.py +++ b/app/services/input.py @@ -1,7 +1,12 @@ from datetime import date, datetime, timezone +from sqlmodel import Session + from app.api.schemas.enums import InputStatus, InputType -from app.models import Input +from app.core.config import AUDIO_CONTENT_TYPES, AUDIO_DIR +from app.db.repositories import create_input, create_job, update_job +from app.models import Input, Job +from app.tasks.transcribe import transcribe_audio_task class InputService: @@ -24,6 +29,43 @@ def build_voice_input( updated_at=now, ) + def process_voice_upload( + self, + session: Session, + audio_content: bytes, + ext: str, + filename: str, + station_id: str | None, + responder_badge: str | None, + incident_date_hint: date | None, + ) -> tuple[Input, Job]: + record = self.build_voice_input( + original_filename=filename, + station_id=station_id, + responder_badge=responder_badge, + incident_date_hint=incident_date_hint, + ) + record = create_input(session, record) + + AUDIO_DIR.mkdir(parents=True, exist_ok=True) + audio_path = AUDIO_DIR / f"{record.input_id}.{ext}" + audio_path.write_bytes(audio_content) + + # Create Job, dispatch, and backfill celery_task_id. + # On any failure after the file write, remove the orphaned audio file. + try: + job = Job(celery_task_id="", job_type="transcription", status="queued") + job = create_job(session, job) + result = transcribe_audio_task.delay(str(record.input_id), str(audio_path), job.job_id) + job.celery_task_id = result.id + job = update_job(session, job) + except Exception: + if audio_path.exists(): + audio_path.unlink() + raise + + return record, job + def build_text_input( self, narrative: str, diff --git a/app/services/whisper.py b/app/services/whisper.py index bbc13fb..118af8e 100644 --- a/app/services/whisper.py +++ b/app/services/whisper.py @@ -31,9 +31,8 @@ def call_whisper_asr(audio_bytes: bytes, filename: str, content_type: str) -> st def check_whisper_available() -> bool: - """Return True if the Whisper sidecar responds on its base URL.""" + """Return True if the Whisper sidecar responds with a successful status.""" try: - requests.get(WHISPER_HOST, timeout=3) - return True + return requests.get(WHISPER_HOST, timeout=3).ok except requests.exceptions.RequestException: return False diff --git a/app/tasks/transcribe.py b/app/tasks/transcribe.py index 76517ec..91999c6 100644 --- a/app/tasks/transcribe.py +++ b/app/tasks/transcribe.py @@ -6,20 +6,13 @@ from app.api.schemas.enums import InputStatus from app.core.celery import celery_app +from app.core.config import AUDIO_CONTENT_TYPES from app.db.database import get_session from app.db.repositories import get_input, get_job_by_uuid, update_input, update_job from app.services.whisper import call_whisper_asr logger = logging.getLogger(__name__) -_CONTENT_TYPES: dict[str, str] = { - "wav": "audio/wav", - "mp3": "audio/mpeg", - "m4a": "audio/m4a", - "ogg": "audio/ogg", - "webm": "audio/webm", -} - def _wav_duration(path: Path) -> float | None: """Return duration in seconds for a WAV file; None for all other formats or on error.""" @@ -75,7 +68,7 @@ def transcribe_audio_task(input_id_str: str, audio_path: str, job_id_str: str) - # transcribe p = Path(audio_path) ext = p.suffix.lstrip(".") - text = call_whisper_asr(p.read_bytes(), p.name, _CONTENT_TYPES.get(ext, "audio/wav")) + text = call_whisper_asr(p.read_bytes(), p.name, AUDIO_CONTENT_TYPES.get(ext, "audio/wav")) # success words = text.split() @@ -97,7 +90,7 @@ def transcribe_audio_task(input_id_str: str, audio_path: str, job_id_str: str) - except ConnectionError as exc: logger.exception("transcribe_audio_task: Whisper unavailable for input %s", input_id_str) - _fail_records(session, input_id, job_id_str, "LLM_UNAVAILABLE", str(exc)) + _fail_records(session, input_id, job_id_str, "STT_UNAVAILABLE", str(exc)) raise except RuntimeError as exc: diff --git a/contracts/path/input.yaml b/contracts/path/input.yaml index a2d6a57..def2f2c 100644 --- a/contracts/path/input.yaml +++ b/contracts/path/input.yaml @@ -89,14 +89,14 @@ voice: - ogg - webm "503": - description: Ollama/Whisper service unavailable + description: Whisper transcription service unavailable content: application/json: schema: $ref: "../schemas/common.yaml#/ErrorResponse" example: - error_code: "LLM_UNAVAILABLE" - message: "Ollama transcription service is not available" + error_code: "STT_UNAVAILABLE" + message: "Whisper transcription service is not available" retry_after_seconds: 30 text: diff --git a/tests/test_v1_voice.py b/tests/test_v1_voice.py index fd2e7f3..1ae2ab0 100644 --- a/tests/test_v1_voice.py +++ b/tests/test_v1_voice.py @@ -22,13 +22,6 @@ _WAV_BYTES = b"RIFF$\x00\x00\x00WAVEfmt \x10\x00\x00\x00\x01\x00\x01\x00\x80>\x00\x00\x00}\x00\x00\x02\x00\x10\x00data\x00\x00\x00\x00" -def _voice_post(client, filename="memo.wav", content=None, extra_fields=None): - """Helper: POST a voice upload, mocking dispatch and availability.""" - data = extra_fields or {} - files = {"audio_file": (filename, io.BytesIO(content or _WAV_BYTES), "audio/wav")} - return client, data, files - - # --------------------------------------------------------------------------- # POST /api/v1/input/voice — endpoint tests (dispatch mocked) # --------------------------------------------------------------------------- @@ -38,8 +31,9 @@ class TestSubmitVoiceInput: def _post(self, client, filename="memo.wav", content=None, fields=None, whisper_up=True): mock_result = MagicMock() mock_result.id = "celery-task-uuid-001" + # Dispatch is owned by InputService.process_voice_upload — patch there. with patch("app.api.routes.input.check_whisper_available", return_value=whisper_up), \ - patch("app.api.routes.input.transcribe_audio_task") as mock_task, \ + patch("app.services.input.transcribe_audio_task") as mock_task, \ patch("pathlib.Path.write_bytes"): mock_task.delay.return_value = mock_result resp = client.post( @@ -133,7 +127,7 @@ def test_503_when_whisper_unavailable(self, client): resp, _ = self._post(client, whisper_up=False) assert resp.status_code == 503 body = resp.json() - assert body["error_code"] == "LLM_UNAVAILABLE" + assert body["error_code"] == "STT_UNAVAILABLE" def test_422_missing_audio_file(self, client): with patch("app.api.routes.input.check_whisper_available", return_value=True): @@ -218,14 +212,14 @@ def test_connection_error_input_failed_error_detail(self, db, test_engine): assert inp.status == InputStatus.failed assert "Cannot reach Whisper" in inp.error_detail - def test_connection_error_job_failed_with_llm_unavailable_code(self, db, test_engine): + def test_connection_error_job_failed_with_stt_unavailable_code(self, db, test_engine): inp, job = _seed_input_and_job(db) self._run_task(inp, job, test_engine, whisper_side_effect=ConnectionError("Cannot reach Whisper")) db.refresh(job) assert job.status == "failed" - assert job.error["error_code"] == "LLM_UNAVAILABLE" + assert job.error["error_code"] == "STT_UNAVAILABLE" assert "Cannot reach Whisper" in job.error["message"] def test_runtime_error_input_failed_error_detail(self, db, test_engine):