diff --git a/app/api/routes/forms.py b/app/api/routes/forms.py index 87e4a16..879fa39 100644 --- a/app/api/routes/forms.py +++ b/app/api/routes/forms.py @@ -11,7 +11,8 @@ ModelsResponse, TranscriptionResponse, ) -from app.core.config import OLLAMA_HOST, OLLAMA_MODEL, WHISPER_HOST, BASE_DIR, RETENTION_PERIOD_DAYS +from app.core.config import OLLAMA_HOST, OLLAMA_MODEL, BASE_DIR, RETENTION_PERIOD_DAYS +from app.services.whisper import call_whisper_asr from app.core.errors.base import AppError from app.db.repositories import create_form, get_template, get_form_submission, delete_form_submission from app.models import FormSubmission @@ -94,34 +95,16 @@ def transcribe(audio: UploadFile = File(...)): audio is streamed straight through to the local STT service and never persisted — no PII leaves the machine. """ - whisper_url = f"{WHISPER_HOST}/asr" - - files = { - "audio_file": ( - audio.filename or "audio.wav", + try: + text = call_whisper_asr( audio.file.read(), + audio.filename or "audio.wav", audio.content_type or "audio/wav", ) - } - params = {"task": "transcribe", "output": "json", "encode": "true"} - - try: - response = requests.post(whisper_url, params=params, files=files, timeout=120) - response.raise_for_status() - except requests.exceptions.ConnectionError: - raise AppError( - f"Could not connect to the speech-to-text service at {whisper_url}. 
def _raise_file_too_large(received_size_bytes: int) -> None:
    """Raise the 413 ``FILE_TOO_LARGE`` error for an upload of the given size."""
    raise AppError(
        "Audio file exceeds maximum size of 500MB",
        status_code=413,
        error_code="FILE_TOO_LARGE",
        detail={"max_size_bytes": _MAX_AUDIO_BYTES, "received_size_bytes": received_size_bytes},
    )


def _read_capped(stream, limit: int) -> tuple[bytes, int]:
    """Read an upload without ever holding more than ``limit + 1`` bytes.

    Returns ``(content, total_size)``. When the stream is larger than ``limit``
    the returned content is truncated (the caller rejects it anyway) and the
    remainder is only counted, so the reported size stays exact.
    """
    content = stream.read(limit + 1)
    total = len(content)
    if total > limit:
        # Already over the limit: drain in small chunks purely to count bytes.
        while chunk := stream.read(1024 * 1024):
            total += len(chunk)
    return content, total


@router.post("/voice", response_model=VoiceInputResponse, status_code=201)
def submit_voice_input(
    audio_file: UploadFile = File(...),
    station_id: str | None = Form(default=None),
    responder_badge: str | None = Form(default=None),
    incident_date_hint: date | None = Form(default=None),
    db: Session = Depends(get_db),
):
    """Accept a voice recording and queue it for transcription.

    Checks run in this order, each raising ``AppError``:
      * 415 ``UNSUPPORTED_FORMAT`` - extension missing or not in
        ``ALLOWED_AUDIO_EXTENSIONS``.
      * 413 ``FILE_TOO_LARGE`` - upload larger than ``_MAX_AUDIO_BYTES``.
      * 503 ``STT_UNAVAILABLE`` - ``check_whisper_available()`` returned False.

    On success the upload is handed to ``InputService.process_voice_upload``
    and a ``VoiceInputResponse`` pointing at the poll URL is returned.
    """
    # 415 — format check (HTTP concern, before any I/O)
    filename = audio_file.filename or ""
    ext = Path(filename).suffix.lstrip(".").lower()
    if not ext or ext not in ALLOWED_AUDIO_EXTENSIONS:
        # Message and detail both come from AUDIO_CONTENT_TYPES so they cannot drift apart.
        raise AppError(
            f"Audio format not supported. Accepted formats: {', '.join(AUDIO_CONTENT_TYPES)}",
            status_code=415,
            error_code="UNSUPPORTED_FORMAT",
            detail={"accepted_formats": list(AUDIO_CONTENT_TYPES)},
        )

    # 413 — size check: fast path via the size reported on the upload (no read) ...
    if audio_file.size is not None and audio_file.size > _MAX_AUDIO_BYTES:
        _raise_file_too_large(audio_file.size)
    # ... fallback via a capped read, so an oversized body is never fully buffered.
    content, received_size = _read_capped(audio_file.file, _MAX_AUDIO_BYTES)
    if received_size > _MAX_AUDIO_BYTES:
        _raise_file_too_large(received_size)

    # 503 — Whisper availability (HTTP concern)
    if not check_whisper_available():
        raise AppError(
            "Whisper transcription service is not available",
            status_code=503,
            error_code="STT_UNAVAILABLE",
        )

    svc = InputService()
    record, job = svc.process_voice_upload(
        session=db,
        audio_content=content,
        ext=ext,
        filename=filename,
        station_id=station_id,
        responder_badge=responder_badge,
        incident_date_hint=incident_date_hint,
    )

    return VoiceInputResponse(
        input_id=record.input_id,
        status=record.status,
        input_type=record.input_type,
        job_id=job.job_id,
        poll_url=f"/api/v1/input/{record.input_id}",
        estimated_processing_seconds=ESTIMATED_TRANSCRIPTION_SECONDS,
        created_at=record.created_at,
    )
+ESTIMATED_TRANSCRIPTION_SECONDS = int(os.getenv("ESTIMATED_TRANSCRIPTION_SECONDS", "30")) + +# Canonical audio format mapping — single source of truth for both the route +# (membership check, 415 detail list) and the task (content-type lookup). +# Dict insertion order gives the stable list shown in error responses. +AUDIO_CONTENT_TYPES: dict[str, str] = { + "wav": "audio/wav", + "mp3": "audio/mpeg", + "m4a": "audio/m4a", + "ogg": "audio/ogg", + "webm": "audio/webm", +} +ALLOWED_AUDIO_EXTENSIONS: frozenset[str] = frozenset(AUDIO_CONTENT_TYPES) \ No newline at end of file diff --git a/app/db/repositories.py b/app/db/repositories.py index 9186d61..0935c13 100644 --- a/app/db/repositories.py +++ b/app/db/repositories.py @@ -81,3 +81,10 @@ def create_input(session: Session, input_obj: Input) -> Input: def get_input(session: Session, input_id: UUID) -> Input | None: return session.get(Input, input_id) + +def update_input(session: Session, input_obj: Input) -> Input: + session.add(input_obj) + session.commit() + session.refresh(input_obj) + return input_obj + diff --git a/app/services/input.py b/app/services/input.py index b04868c..1aa2892 100644 --- a/app/services/input.py +++ b/app/services/input.py @@ -1,10 +1,71 @@ from datetime import date, datetime, timezone +from sqlmodel import Session + from app.api.schemas.enums import InputStatus, InputType -from app.models import Input +from app.core.config import AUDIO_CONTENT_TYPES, AUDIO_DIR +from app.db.repositories import create_input, create_job, update_job +from app.models import Input, Job +from app.tasks.transcribe import transcribe_audio_task class InputService: + def build_voice_input( + self, + original_filename: str, + station_id: str | None = None, + responder_badge: str | None = None, + incident_date_hint: date | None = None, + ) -> Input: + now = datetime.now(timezone.utc) + return Input( + input_type=InputType.voice, + status=InputStatus.queued, + original_filename=original_filename, + station_id=station_id, 
def process_voice_upload(
    self,
    session: Session,
    audio_content: bytes,
    ext: str,
    filename: str,
    station_id: str | None,
    responder_badge: str | None,
    incident_date_hint: date | None,
) -> tuple[Input, Job]:
    """Persist a voice upload and queue its transcription.

    Steps: create the ``Input`` row, write the audio to
    ``{AUDIO_DIR}/{input_id}.{ext}``, create a ``Job`` row, dispatch
    ``transcribe_audio_task`` and store the Celery task id on the job.

    If anything fails after the ``Input`` row exists, the audio file is
    removed, the input is marked failed on a best-effort basis (so it does
    not stay ``queued`` forever) and the original exception is re-raised.

    Returns:
        The ``(input_record, job)`` pair.
    """
    record = self.build_voice_input(
        original_filename=filename,
        station_id=station_id,
        responder_badge=responder_badge,
        incident_date_hint=incident_date_hint,
    )
    record = create_input(session, record)

    audio_path = AUDIO_DIR / f"{record.input_id}.{ext}"
    try:
        # The write is inside the try so a partial file is cleaned up too.
        AUDIO_DIR.mkdir(parents=True, exist_ok=True)
        audio_path.write_bytes(audio_content)

        # Create Job, dispatch, and backfill celery_task_id.
        job = Job(celery_task_id="", job_type="transcription", status="queued")
        job = create_job(session, job)
        result = transcribe_audio_task.delay(str(record.input_id), str(audio_path), job.job_id)
        job.celery_task_id = result.id
        job = update_job(session, job)
    except Exception as exc:
        # missing_ok avoids the exists()/unlink() race of a separate check.
        audio_path.unlink(missing_ok=True)
        try:
            # Roll back first: the failure may have come from the session itself.
            # NOTE(review): assumes create_input has already committed the row — confirm.
            session.rollback()
            record.status = InputStatus.failed
            record.error_detail = f"Failed to queue transcription: {exc}"
            record.updated_at = datetime.now(timezone.utc)
            session.add(record)
            session.commit()
        except Exception:
            # Deliberately best-effort: the original failure below is what matters.
            session.rollback()
        raise

    return record, job
+ """ + whisper_url = f"{WHISPER_HOST}/asr" + files = {"audio_file": (filename, audio_bytes, content_type)} + params = {"task": "transcribe", "output": "json", "encode": "true"} + + try: + response = requests.post(whisper_url, params=params, files=files, timeout=120) + response.raise_for_status() + except requests.exceptions.ConnectionError as exc: + raise ConnectionError( + f"Could not connect to the speech-to-text service at {whisper_url}. " + "Please ensure the whisper service is running." + ) from exc + except requests.exceptions.RequestException as exc: + raise RuntimeError(f"Transcription failed: {exc}") from exc + + try: + return (response.json().get("text") or "").strip() + except ValueError: + return response.text.strip() + + +def check_whisper_available() -> bool: + """Return True if the Whisper sidecar responds with a successful status.""" + try: + return requests.get(WHISPER_HOST, timeout=3).ok + except requests.exceptions.RequestException: + return False diff --git a/app/tasks/transcribe.py b/app/tasks/transcribe.py new file mode 100644 index 0000000..91999c6 --- /dev/null +++ b/app/tasks/transcribe.py @@ -0,0 +1,102 @@ +import logging +import wave +from datetime import datetime, timezone +from pathlib import Path +from uuid import UUID + +from app.api.schemas.enums import InputStatus +from app.core.celery import celery_app +from app.core.config import AUDIO_CONTENT_TYPES +from app.db.database import get_session +from app.db.repositories import get_input, get_job_by_uuid, update_input, update_job +from app.services.whisper import call_whisper_asr + +logger = logging.getLogger(__name__) + + +def _wav_duration(path: Path) -> float | None: + """Return duration in seconds for a WAV file; None for all other formats or on error.""" + if path.suffix.lower() != ".wav": + return None + try: + with wave.open(str(path)) as wf: + return wf.getnframes() / wf.getframerate() + except Exception: + return None + + +def _fail_records( + session, + input_id: UUID, + 
def _fail_records(
    session,
    input_id: UUID,
    job_id_str: str,
    error_code: str,
    message: str,
) -> None:
    """Mark the Input and Job rows as failed and record why.

    The input gets ``status=failed`` and ``error_detail=message``; the job gets
    ``status="failed"`` and an ``error`` dict with ``error_code`` and
    ``message``. A row that cannot be found is skipped.
    """
    now = datetime.now(timezone.utc)
    record = get_input(session, input_id)
    if record:
        record.status = InputStatus.failed
        record.error_detail = message
        record.updated_at = now
        update_input(session, record)
    job = get_job_by_uuid(session, job_id_str)
    if job:
        job.status = "failed"
        job.error = {"error_code": error_code, "message": message}
        job.updated_at = now
        update_job(session, job)


@celery_app.task(name="transcribe_audio")
def transcribe_audio_task(input_id_str: str, audio_path: str, job_id_str: str) -> dict:
    """Transcribe a stored audio file and record the outcome on its Input and Job.

    Flow: mark the input ``transcribing`` and the job ``processing``, send the
    file to ``call_whisper_asr``, then store the transcript and counts on the
    input (``ready``) and the result URL on the job (``completed``).

    Every failure marks both rows as failed before re-raising, so nothing is
    left stuck in an in-flight state:
      * ``ConnectionError`` -> ``STT_UNAVAILABLE``
      * ``RuntimeError``    -> ``TRANSCRIPTION_FAILED``
      * anything else (missing rows, unreadable audio file, database errors)
        -> ``TRANSCRIPTION_FAILED``

    Returns:
        ``{"input_id": ..., "job_id": ...}`` on success.
    """
    # Parse before opening a session so a malformed id cannot leak one.
    input_id = UUID(input_id_str)
    session = next(get_session())
    try:
        input_record = get_input(session, input_id)
        job = get_job_by_uuid(session, job_id_str)
        if input_record is None or job is None:
            raise LookupError(
                f"Input {input_id_str} or job {job_id_str} not found; cannot transcribe"
            )

        # start: mark both records in-flight
        now = datetime.now(timezone.utc)
        input_record.status = InputStatus.transcribing
        input_record.updated_at = now
        update_input(session, input_record)

        job.status = "processing"
        job.updated_at = now
        update_job(session, job)

        # transcribe
        p = Path(audio_path)
        # Lower-case so an upper-case suffix still resolves to the right content type.
        ext = p.suffix.lstrip(".").lower()
        text = call_whisper_asr(p.read_bytes(), p.name, AUDIO_CONTENT_TYPES.get(ext, "audio/wav"))

        # success
        words = text.split()
        now = datetime.now(timezone.utc)
        input_record.status = InputStatus.ready
        input_record.transcript = text
        input_record.character_count = len(text)
        input_record.word_count = len(words)
        input_record.audio_duration_seconds = _wav_duration(p)
        input_record.updated_at = now
        update_input(session, input_record)

        job.status = "completed"
        job.result_url = f"/api/v1/input/{input_id_str}"
        job.updated_at = now
        update_job(session, job)

        return {"input_id": input_id_str, "job_id": job_id_str}

    except ConnectionError as exc:
        logger.exception("transcribe_audio_task: Whisper unavailable for input %s", input_id_str)
        _fail_records(session, input_id, job_id_str, "STT_UNAVAILABLE", str(exc))
        raise

    except RuntimeError as exc:
        logger.exception("transcribe_audio_task: transcription failed for input %s", input_id_str)
        _fail_records(session, input_id, job_id_str, "TRANSCRIPTION_FAILED", str(exc))
        raise

    except Exception as exc:
        # Task boundary: record the failure for anything unexpected, then re-raise.
        logger.exception("transcribe_audio_task: unexpected failure for input %s", input_id_str)
        # The failure may have come from the session itself; reset it before reuse.
        session.rollback()
        _fail_records(session, input_id, job_id_str, "TRANSCRIPTION_FAILED", str(exc))
        raise

    finally:
        session.close()
class TestSubmitVoiceInput:
    """Endpoint tests for POST /api/v1/input/voice with Celery dispatch mocked."""

    def _post(self, client, filename="memo.wav", content=None, fields=None, whisper_up=True):
        """POST an upload and return ``(response, mocked_task)``.

        Whisper availability and task dispatch are mocked. ``AUDIO_DIR`` is
        pointed at a throwaway directory so the service's ``mkdir`` never
        touches the real data directory, and ``Path.write_bytes`` is patched so
        no audio is written.
        """
        import tempfile

        mock_result = MagicMock()
        mock_result.id = "celery-task-uuid-001"
        # Dispatch is owned by InputService.process_voice_upload — patch there.
        with (
            tempfile.TemporaryDirectory() as scratch_dir,
            patch("app.api.routes.input.check_whisper_available", return_value=whisper_up),
            patch("app.services.input.AUDIO_DIR", Path(scratch_dir) / "audio"),
            patch("app.services.input.transcribe_audio_task") as mock_task,
            patch("pathlib.Path.write_bytes"),
        ):
            mock_task.delay.return_value = mock_result
            resp = client.post(
                VOICE_URL,
                files={"audio_file": (filename, io.BytesIO(content or _WAV_BYTES), "audio/wav")},
                data=fields or {},
            )
        return resp, mock_task

    def test_201_returns_required_fields(self, client):
        resp, _ = self._post(client)
        assert resp.status_code == 201
        body = resp.json()
        assert body["status"] == "queued"
        assert body["input_type"] == "voice"
        assert "input_id" in body
        assert "job_id" in body
        assert "poll_url" in body
        assert "estimated_processing_seconds" in body

    def test_201_poll_url_points_to_input_record(self, client):
        resp, _ = self._post(client)
        body = resp.json()
        assert body["poll_url"] == f"/api/v1/input/{body['input_id']}"

    def test_201_creates_queued_input_row(self, client, db):
        resp, _ = self._post(client)
        input_id = resp.json()["input_id"]
        from uuid import UUID
        record = get_input(db, UUID(input_id))
        assert record is not None
        assert record.status == InputStatus.queued
        assert record.input_type == InputType.voice
        assert record.original_filename == "memo.wav"

    def test_201_creates_transcription_job_row(self, client, db):
        resp, _ = self._post(client)
        job_id = resp.json()["job_id"]
        job = get_job_by_uuid(db, job_id)
        assert job is not None
        assert job.job_type == "transcription"
        assert job.status == "queued"
        assert job.celery_task_id == "celery-task-uuid-001"

    def test_201_dispatch_called_with_input_id_and_job_id(self, client, db):
        resp, mock_task = self._post(client)
        body = resp.json()
        mock_task.delay.assert_called_once()
        call_args = mock_task.delay.call_args[0]
        assert call_args[0] == body["input_id"]   # input_id_str
        assert call_args[2] == body["job_id"]     # job_id_str
        # call_args[1] is the audio_path (filesystem path string)
        assert body["input_id"] in call_args[1]

    def test_201_optional_metadata_stored_on_input(self, client, db):
        resp, _ = self._post(client, fields={
            "station_id": "STA-045",
            "responder_badge": "FD-7842",
            "incident_date_hint": "2024-07-10",
        })
        assert resp.status_code == 201
        from uuid import UUID
        record = get_input(db, UUID(resp.json()["input_id"]))
        assert record.station_id == "STA-045"
        assert record.responder_badge == "FD-7842"
        assert str(record.incident_date_hint) == "2024-07-10"

    def test_415_unsupported_format_rejected(self, client):
        resp, _ = self._post(client, filename="memo.pdf")
        assert resp.status_code == 415
        body = resp.json()
        assert body["error_code"] == "UNSUPPORTED_FORMAT"
        assert "accepted_formats" in body["detail"]
        assert "wav" in body["detail"]["accepted_formats"]

    def test_415_no_extension_rejected(self, client):
        resp, _ = self._post(client, filename="audiofile")
        assert resp.status_code == 415
        assert resp.json()["error_code"] == "UNSUPPORTED_FORMAT"

    def test_413_oversize_file_rejected(self, client, monkeypatch):
        monkeypatch.setattr("app.api.routes.input._MAX_AUDIO_BYTES", 5)
        resp, _ = self._post(client, content=b"toolong")  # 7 bytes > 5
        assert resp.status_code == 413
        body = resp.json()
        assert body["error_code"] == "FILE_TOO_LARGE"
        assert body["detail"]["max_size_bytes"] == 5
        assert body["detail"]["received_size_bytes"] == 7

    def test_503_when_whisper_unavailable(self, client):
        resp, _ = self._post(client, whisper_up=False)
        assert resp.status_code == 503
        body = resp.json()
        assert body["error_code"] == "STT_UNAVAILABLE"

    def test_422_missing_audio_file(self, client):
        with patch("app.api.routes.input.check_whisper_available", return_value=True):
            resp = client.post(VOICE_URL, data={})
        assert resp.status_code == 422
def _seed_input_and_job(session: Session):
    """Insert a queued Input and Job, return both refreshed."""
    from datetime import datetime, timezone
    now = datetime.now(timezone.utc)
    inp = Input(
        input_type=InputType.voice,
        status=InputStatus.queued,
        original_filename="test.wav",
        created_at=now,
        updated_at=now,
    )
    session.add(inp)
    job = Job(celery_task_id="celery-test-id", job_type="transcription", status="queued")
    session.add(job)
    session.commit()
    session.refresh(inp)
    session.refresh(job)
    return inp, job


class TestTranscribeAudioTask:
    """Unit tests for transcribe_audio_task: direct call, mocked Whisper, in-memory DB."""

    @staticmethod
    def _session_source(test_engine):
        """Return a get_session replacement that yields a session bound to the test engine."""
        def _gen():
            with Session(test_engine) as s:
                yield s
        return _gen

    def _run_task(self, inp, job, test_engine, whisper_side_effect=None, whisper_return="Transcribed text here."):
        """
        Call the task function directly against the in-memory DB.
        Mocks get_session (so the task uses the test engine) and call_whisper_asr.
        Also mocks Path.read_bytes so no filesystem access is needed.

        When ``whisper_side_effect`` is an exception, the task must re-raise a
        ConnectionError/RuntimeError. In every other case an exception from the
        task fails the test instead of being swallowed.
        """
        expects_failure = isinstance(whisper_side_effect, BaseException) or (
            isinstance(whisper_side_effect, type) and issubclass(whisper_side_effect, BaseException)
        )
        with patch("app.tasks.transcribe.get_session", side_effect=self._session_source(test_engine)), \
             patch("app.tasks.transcribe.call_whisper_asr") as mock_whisper, \
             patch.object(Path, "read_bytes", return_value=b"fake_audio_bytes"):
            if whisper_side_effect is not None:
                mock_whisper.side_effect = whisper_side_effect
            else:
                mock_whisper.return_value = whisper_return

            if expects_failure:
                with pytest.raises((ConnectionError, RuntimeError)):
                    transcribe_audio_task(str(inp.input_id), "/data/audio/test.wav", job.job_id)
            else:
                transcribe_audio_task(str(inp.input_id), "/data/audio/test.wav", job.job_id)

    def test_success_input_becomes_ready_with_transcript(self, db, test_engine):
        inp, job = _seed_input_and_job(db)

        self._run_task(inp, job, test_engine, whisper_return="Incident at Main St, two casualties.")

        db.refresh(inp)
        assert inp.status == InputStatus.ready
        assert inp.transcript == "Incident at Main St, two casualties."
        assert inp.character_count == len("Incident at Main St, two casualties.")
        assert inp.word_count == len("Incident at Main St, two casualties.".split())

    def test_success_job_becomes_completed_with_result_url(self, db, test_engine):
        inp, job = _seed_input_and_job(db)

        self._run_task(inp, job, test_engine, whisper_return="Incident at Main St.")

        db.refresh(job)
        assert job.status == "completed"
        assert job.result_url == f"/api/v1/input/{inp.input_id}"

    def test_connection_error_input_failed_error_detail(self, db, test_engine):
        inp, job = _seed_input_and_job(db)

        self._run_task(inp, job, test_engine, whisper_side_effect=ConnectionError("Cannot reach Whisper"))

        db.refresh(inp)
        assert inp.status == InputStatus.failed
        assert "Cannot reach Whisper" in inp.error_detail

    def test_connection_error_job_failed_with_stt_unavailable_code(self, db, test_engine):
        inp, job = _seed_input_and_job(db)

        self._run_task(inp, job, test_engine, whisper_side_effect=ConnectionError("Cannot reach Whisper"))

        db.refresh(job)
        assert job.status == "failed"
        assert job.error["error_code"] == "STT_UNAVAILABLE"
        assert "Cannot reach Whisper" in job.error["message"]

    def test_runtime_error_input_failed_error_detail(self, db, test_engine):
        inp, job = _seed_input_and_job(db)

        self._run_task(inp, job, test_engine, whisper_side_effect=RuntimeError("HTTP 500 from Whisper"))

        db.refresh(inp)
        assert inp.status == InputStatus.failed
        assert "HTTP 500 from Whisper" in inp.error_detail

    def test_runtime_error_job_failed_with_transcription_failed_code(self, db, test_engine):
        inp, job = _seed_input_and_job(db)

        self._run_task(inp, job, test_engine, whisper_side_effect=RuntimeError("HTTP 500 from Whisper"))

        db.refresh(job)
        assert job.status == "failed"
        assert job.error["error_code"] == "TRANSCRIPTION_FAILED"
        assert "HTTP 500 from Whisper" in job.error["message"]

    def test_success_input_transitions_through_transcribing(self, db, test_engine):
        """The task sets transcribing before calling Whisper and ready after."""
        from app.db.repositories import update_input as original_update

        inp, job = _seed_input_and_job(db)
        observed_statuses: list[str] = []

        def tracking_update(session, input_obj):
            observed_statuses.append(input_obj.status.value)
            return original_update(session, input_obj)

        with patch("app.tasks.transcribe.get_session", side_effect=self._session_source(test_engine)), \
             patch("app.tasks.transcribe.call_whisper_asr", return_value="Hello world text."), \
             patch("app.tasks.transcribe.update_input", side_effect=tracking_update), \
             patch.object(Path, "read_bytes", return_value=b"fake"):
            transcribe_audio_task(str(inp.input_id), "/data/audio/test.wav", job.job_id)

        assert "transcribing" in observed_statuses
        assert "ready" in observed_statuses
        assert observed_statuses.index("transcribing") < observed_statuses.index("ready")