diff --git a/backend/openapi-schema.yml b/backend/openapi-schema.yml index a01dd4c7..1b4f46db 100644 --- a/backend/openapi-schema.yml +++ b/backend/openapi-schema.yml @@ -33,7 +33,7 @@ components: type: string media_files: items: - $ref: '#/components/schemas/DocumentMedia' + $ref: '#/components/schemas/RemoteDocumentMedia' title: Media Files type: array name: @@ -66,7 +66,7 @@ components: title: Dependencies type: array document: - $ref: '#/components/schemas/Document' + $ref: '#/components/schemas/RemoteDocument' document_id: format: uuid title: Document Id @@ -229,7 +229,9 @@ components: type: string media_files: items: - $ref: '#/components/schemas/DocumentMedia' + anyOf: + - $ref: '#/components/schemas/LocalDocumentMedia' + - $ref: '#/components/schemas/RemoteDocumentMedia' title: Media Files type: array name: @@ -243,25 +245,6 @@ components: - media_files title: Document type: object - DocumentMedia: - properties: - content_type: - title: Content Type - type: string - tags: - items: - type: string - title: Tags - type: array - url: - title: Url - type: string - required: - - url - - content_type - - tags - title: DocumentMedia - type: object DocumentShareTokenBase: properties: can_write: @@ -325,7 +308,7 @@ components: type: string media_files: items: - $ref: '#/components/schemas/DocumentMedia' + $ref: '#/components/schemas/RemoteDocumentMedia' title: Media Files type: array name: @@ -422,6 +405,21 @@ components: title: Progress title: KeepaliveBody type: object + LocalDocumentMedia: + properties: + path: + title: Path + type: string + tags: + items: + type: string + title: Tags + type: array + required: + - tags + - path + title: LocalDocumentMedia + type: object LoginResponse: properties: token: @@ -458,6 +456,53 @@ components: title: Logged Out Redirect Url title: PublicConfig type: object + RemoteDocument: + properties: + changed_at: + title: Changed At + type: string + created_at: + title: Created At + type: string + id: + format: uuid + title: Id + type: string + media_files: + items: + $ref: '#/components/schemas/RemoteDocumentMedia' + title: Media Files + type: array + name: + title: Name + type: string + required: + - id + - name + - created_at + - changed_at + - media_files + title: RemoteDocument + type: object + RemoteDocumentMedia: + properties: + content_type: + title: Content Type + type: string + tags: + items: + type: string + title: Tags + type: array + url: + title: Url + type: string + required: + - tags + - url + - content_type + title: RemoteDocumentMedia + type: object SetDurationRequest: properties: duration: @@ -1176,7 +1221,7 @@ paths: application/json: schema: items: - $ref: '#/components/schemas/DocumentMedia' + $ref: '#/components/schemas/RemoteDocumentMedia' title: Response Get Document Media Api V1 Documents Document Id Media Files Get type: array diff --git a/backend/transcribee_backend/models/document.py b/backend/transcribee_backend/models/document.py index 06f1a4c8..1ac1e4c7 100644 --- a/backend/transcribee_backend/models/document.py +++ b/backend/transcribee_backend/models/document.py @@ -3,8 +3,8 @@ from pydantic.types import AwareDatetime from sqlmodel import DateTime, Field, Relationship, SQLModel -from transcribee_proto.api import Document as ApiDocument -from transcribee_proto.api import DocumentMedia as ApiDocumentMedia +from transcribee_proto.api import RemoteDocument as ApiDocument +from transcribee_proto.api import RemoteDocumentMedia as ApiDocumentMedia from transcribee_backend import media_storage diff --git a/backend/transcribee_backend/models/task.py b/backend/transcribee_backend/models/task.py index 2d925cc9..b73bb485 100644 --- a/backend/transcribee_backend/models/task.py +++ b/backend/transcribee_backend/models/task.py @@ -5,8 +5,8 @@ from typing import Any, Dict, List, Literal, Optional from sqlmodel import JSON, Column, Field, ForeignKey, Relationship, SQLModel, Uuid -from transcribee_proto.api import Document as ApiDocument from transcribee_proto.api import ExportTaskParameters, TaskType +from transcribee_proto.api import RemoteDocument as ApiDocument from typing_extensions import Self from transcribee_backend.config import settings diff --git a/backend/transcribee_backend/routers/document.py b/backend/transcribee_backend/routers/document.py index 9c7f8636..c76246d2 100644 --- a/backend/transcribee_backend/routers/document.py +++ b/backend/transcribee_backend/routers/document.py @@ -27,8 +27,8 @@ from sqlalchemy.sql.expression import desc from sqlmodel import Session, col, select from transcribee_proto.api import Document as ApiDocument -from transcribee_proto.api import DocumentMedia, ExportTaskParameters from transcribee_proto.api import DocumentWithAccessInfo as ApiDocumentWithAccessInfo +from transcribee_proto.api import ExportTaskParameters, RemoteDocumentMedia from transcribee_backend.auth import ( generate_share_token, @@ -436,7 +436,7 @@ def get_document( @document_router.get("/{document_id}/media_files/") def get_document_media( auth: AuthInfo = Depends(get_doc_min_readonly_auth), -) -> List[DocumentMedia]: +) -> List[RemoteDocumentMedia]: return auth.document.as_api_document().media_files diff --git a/desktop/desktop-backend/Cargo.lock b/desktop/desktop-backend/Cargo.lock index 51ecdd68..a111bdf4 100644 --- a/desktop/desktop-backend/Cargo.lock +++ b/desktop/desktop-backend/Cargo.lock @@ -11,6 +11,15 @@ dependencies = [ "memchr", ] +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anstream" version = "1.0.0" @@ -73,6 +82,12 @@ version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" +[[package]] +name = "autocfg" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" + [[package]] name = "axum" version = "0.8.9" @@ -168,12 +183,36 @@ version = "1.11.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33" +[[package]] +name = "cc" +version = "1.2.62" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a1dce859f0832a7d088c4f1119888ab94ef4b5d6795d1ce05afb7fe159d79f98" +dependencies = [ + "find-msvc-tools", + "shlex", +] + [[package]] name = "cfg-if" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "chrono" +version = "0.4.44" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" +dependencies = [ + "iana-time-zone", + "js-sys", + "num-traits", + "serde", + "wasm-bindgen", + "windows-link", +] + [[package]] name = "clap" version = "4.6.1" @@ -181,6 +220,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ddb117e43bbf7dacf0a4190fef4d345b9bad68dfc649cb349e7d17d28428e51" dependencies = [ "clap_builder", + "clap_derive", ] [[package]] @@ -195,6 +235,18 @@ dependencies = [ "strsim", ] +[[package]] +name = "clap_derive" +version = "4.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2ce8604710f6733aa641a2b3731eaa1e8b3d9973d5e3565da11800813f997a9" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "clap_lex" version = "1.1.0" @@ -207,12 +259,19 @@ version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1d07550c9036bf2ae0c684c4297d503f838287c83c53686d05370d0e139ae570" +[[package]] +name = "core-foundation-sys" +version = "0.8.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773648b94d0e5d620f64f280777445740e61fe701025087ec8b57f45c791888b" + [[package]] name = "desktop-backend" version = "0.1.0" dependencies = [ "axum", "axum-extra", + "chrono", "clap", "env_logger", "log", @@ -250,6 +309,12 @@ version = "1.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "877a4ace8713b0bcf2a4e7eec82529c029f1d0619886d18145fea96c3ffe5c0f" +[[package]] +name = "find-msvc-tools" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5baebc0774151f905a1a2cc41989300b1e6fbb29aff0ceffa1064fdd3088d582" + [[package]] name = "foldhash" version = "0.1.5" @@ -412,6 +477,30 @@ dependencies = [ "tower-service", ] +[[package]] +name = "iana-time-zone" +version = "0.1.65" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e31bc9ad994ba00e440a8aa5c9ef0ec67d5cb5e5cb0cc7f8b744a35b389cc470" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "log", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "id-arena" version = "2.3.0" @@ -525,6 +614,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + [[package]] name = "once_cell" version = "1.21.4" @@ -724,6 +822,12 @@ dependencies = [ "serde", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "slab" version = "0.4.12" @@ -975,12 +1079,65 @@ dependencies = [ "semver", ] +[[package]] +name = "windows-core" +version = "0.62.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b8e83a14d34d0623b51dce9581199302a221863196a1dde71a7663a4c2be9deb" +dependencies = [ + "windows-implement", + "windows-interface", + "windows-link", + "windows-result", + "windows-strings", +] + +[[package]] +name = "windows-implement" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "053e2e040ab57b9dc951b72c264860db7eb3b0200ba345b4e4c3b14f67855ddf" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "windows-interface" +version = "0.59.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f316c4a2570ba26bbec722032c4099d8c8bc095efccdc15688708623367e358" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "windows-link" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f0805222e57f7521d6a62e36fa9163bc891acd422f971defe97d64e70d0a4fe5" +[[package]] +name = "windows-result" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7781fa89eaf60850ac3d2da7af8e5242a5ea78d1a11c49bf2910bb5a73853eb5" +dependencies = [ + "windows-link", +] + +[[package]] +name = "windows-strings" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7837d08f69c77cf6b07689544538e017c1bfcf57e34b4c0ff58e6c2cd3b37091" +dependencies = [ + "windows-link", +] + [[package]] name = "windows-sys" version = "0.61.2" diff --git a/desktop/desktop-backend/Cargo.toml b/desktop/desktop-backend/Cargo.toml index 1758511c..3f6443b3 100644 --- a/desktop/desktop-backend/Cargo.toml +++ b/desktop/desktop-backend/Cargo.toml @@ -6,7 +6,8 @@ edition = "2024" [dependencies] axum = "0.8.9" axum-extra = { version = "0.12.6", features = ["query"] } -clap = { version = "4.6.1", optional = true, features = ["cargo"] } +chrono = { version = "0.4.44", features = ["serde"] } +clap = { version = "4.6.1", optional = true, features = ["cargo", "derive"] } env_logger = { version = "0.11.10", optional = true } log = "0.4.29" serde = { version = "1.0.228", features = ["derive", "rc"] } @@ -17,5 +18,5 @@ uuid = { version = "1.23.1", features = ["serde", "v4"] } standalone = ["dep:env_logger", "dep:clap"] [[bin]] -name = "standalone" +name = "standalone-backend" required-features = ["standalone"] diff --git a/desktop/desktop-backend/src/backend.rs b/desktop/desktop-backend/src/backend.rs index a605c1b6..8946e8b6 100644 --- a/desktop/desktop-backend/src/backend.rs +++ b/desktop/desktop-backend/src/backend.rs @@ -1,116 +1,39 @@ -use axum::Json; -use axum::extract::State; +use axum::extract::{Request, State}; +use axum::http::{StatusCode, header}; +use axum::middleware; +use axum::middleware::Next; +use axum::response::Response; use axum::routing::post; use axum::{Router, routing::get}; -use axum_extra::extract::Query; -use serde::{Deserialize, Serialize}; -use std::collections::HashMap; use std::net::{Ipv4Addr, SocketAddr, SocketAddrV4}; -use std::sync::{Arc, Mutex}; -use uuid::Uuid; -#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] -enum TaskState { - NEW, - ASSIGNED, - COMPLETED, - FAILED, -} - -#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -enum TaskType { - IdentifySpeakers, - Transcribe, - Align, - Reencode, - Export, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -struct TaskAttempt { - progress: Option, -} - -#[derive(Clone, Debug, Deserialize, Serialize)] -struct Task { - id: Uuid, - task_type: TaskType, - state: TaskState, - dependencies: Vec, - current_attempt: Option, -} - -#[derive(Clone, Debug, Serialize)] -struct BackendState { - documents: HashMap>, - tasks: HashMap, -} - -#[derive(Clone, Debug, Serialize)] -struct ApiState { +use crate::handlers::{claim_unassigned_task, dump_state, noop}; +#[derive(Clone, Debug)] +struct ApiConfig { token: String, - state: Arc>, } -impl BackendState { - fn add_task(&mut self, task: Task) { - self.tasks.insert(task.id, task); +async fn worker_auth( + State(state): State, + req: Request, + next: Next, +) -> Result { + let auth_header = req + .headers() + .get(header::AUTHORIZATION) + .and_then(|header| header.to_str().ok()) + .ok_or(StatusCode::UNAUTHORIZED)?; + if auth_header == format!("Worker {}", state.token) { + Ok(next.run(req).await) + } else { + Err(StatusCode::UNAUTHORIZED) } } -#[derive(Deserialize)] -struct GetUnassingedTaskQuery { - #[serde(rename = "task_type")] - task_types: Vec, -} - -fn get_ready_task(tasks: &HashMap, task_types: &[TaskType]) -> Option { - 'task_loop: for task in tasks.values() { - if !task_types.contains(&task.task_type) { - continue; - } - for dependency in &task.dependencies { - if let Some(dep_task) = tasks.get(dependency) - && dep_task.state != TaskState::COMPLETED - { - continue 'task_loop; - } - } - if task.current_attempt.is_some() || task.state != TaskState::NEW { - continue; - } - return Some(task.id); - } - None -} - -async fn claim_unassigned_task( - State(state): State, - Query(query): Query, -) -> Json> { - let mut state = state.state.lock().unwrap(); - if let Some(task_id) = get_ready_task(&state.tasks, &query.task_types) { - let task = state.tasks.get_mut(&task_id).unwrap(); - (*task).current_attempt = Some(TaskAttempt { progress: None }); - (*task).state = TaskState::ASSIGNED; - return Json(Some(task.clone())); - } - return Json(None); -} - -async fn noop() -> Json<()> { - Json(()) -} -async fn dump_state(State(state): State) -> Json { - return Json(state.state.lock().unwrap().clone()); -} - pub struct BackendBuilder { port: Option, listener: Option, token: Option, - state: Arc>, } impl BackendBuilder { @@ -119,10 +42,6 @@ impl BackendBuilder { port: None, listener: None, token: None, - state: Arc::new(Mutex::new(BackendState { - documents: HashMap::new(), - tasks: HashMap::new(), - })), }; } pub fn with_token(mut self, token: String) -> Self { @@ -151,16 +70,8 @@ impl BackendBuilder { } fn get_router(&self) -> Router { - self.state.lock().unwrap().add_task(Task { - id: Uuid::new_v4(), - task_type: TaskType::Reencode, - state: TaskState::NEW, - dependencies: Vec::new(), - current_attempt: None, - }); - let state = ApiState { + let state = ApiConfig { token: self.token.clone().unwrap(), - state: self.state.clone(), }; let app = Router::new() .route("/", get(dump_state)) @@ -169,11 +80,15 @@ impl BackendBuilder { post(claim_unassigned_task), ) .route("/api/v1/tasks/{task_id}/keepalive/", post(noop)) + .route("/api/v1/tasks/{task_id}/mark_completed/", post(noop)) + .route("/api/v1/tasks/{task_id}/mark_failed/", post(noop)) + .route("/api/v1/documents/{document_id}/set_duration/", post(noop)) + .route( + "/api/v1/documents/{document_id}/add_media_file/", + post(noop), + ) + .route_layer(middleware::from_fn_with_state(state.clone(), worker_auth)) .with_state(state); - // f"documents/{task.document.id}/add_media_file/", - // self.api_client.post(url=f"tasks/{task_id}/mark_completed/", json=body) - // self.api_client.post(url=f"tasks/{task_id}/mark_failed/", json=body) - // self.api_client.post(f"tasks/{task_id}/keepalive/", json=body) return app; } diff --git a/desktop/desktop-backend/src/bin/standalone-backend.rs b/desktop/desktop-backend/src/bin/standalone-backend.rs new file mode 100644 index 00000000..d4842bd0 --- /dev/null +++ b/desktop/desktop-backend/src/bin/standalone-backend.rs @@ -0,0 +1,65 @@ +use std::collections::HashMap; + +use clap::Parser; +use clap::arg; +use clap::command; +use clap::value_parser; +use desktop_backend::BackendBuilder; +use desktop_backend::state::BACKEND_STATE; +use desktop_backend::state::Document; +use desktop_backend::state::MediaFile; +use desktop_backend::state::Task; +use desktop_backend::state::TaskParameters; +use desktop_backend::state::TaskState; +use desktop_backend::state::TaskType; +use uuid::Uuid; + +#[derive(Parser, Debug)] +#[command(version, about, long_about = None)] +struct Args { + #[arg(short, long)] + port: Option, + + #[arg(short, long)] + token: Option, + + #[arg(short, long)] + media_file: Option, +} + +#[tokio::main] +async fn main() -> std::io::Result<()> { + env_logger::init(); + + let args = Args::parse(); + + let token = args.token.unwrap_or("SECRET_TOKEN".to_string()); // TODO: generate random + let media_files = if let Some(path) = args.media_file { + vec![MediaFile::new(path)] + } else { + Vec::new() + }; + + let document = Document::new("Test".to_string(), media_files); + + let id = document.id; + BACKEND_STATE.lock().unwrap().add_document(document); + BACKEND_STATE.lock().unwrap().add_task(Task { + id: Uuid::new_v4(), + task_type: TaskType::Reencode, + state: TaskState::New, + dependencies: Vec::new(), + current_attempt: None, + document: id, + task_parameters: TaskParameters::NoParameters(HashMap::new()), + }); + + let mut backend = BackendBuilder::new().with_token(token.clone()); + if let Some(port) = args.port { + backend = backend.with_port(port); + } + + let local_addr = backend.bind().unwrap(); + log::info!("starting backend on http://{:?}", local_addr); + backend.serve().await +} diff --git a/desktop/desktop-backend/src/bin/standalone.rs b/desktop/desktop-backend/src/bin/standalone.rs deleted file mode 100644 index ed20c6c1..00000000 --- a/desktop/desktop-backend/src/bin/standalone.rs +++ /dev/null @@ -1,38 +0,0 @@ -use clap::arg; -use clap::command; -use clap::value_parser; -use desktop_backend::BackendBuilder; -#[tokio::main] -async fn main() -> std::io::Result<()> { - env_logger::init(); - let matches = command!() - .arg( - arg!( - -p --port "Backend port" - ) - .value_parser(value_parser!(u16)) - .required(false), - ) - .arg( - arg!( - -t --token "Worker token" - ) - .required(false), - ) - .get_matches(); - - let token = if let Some(token) = matches.get_one::("token") { - token.clone() - } else { - "SECRET_TOKEN".to_string() // TODO: generate random - }; - - let mut backend = BackendBuilder::new().with_token(token.clone()); - if let Some(port) = matches.get_one::("port") { - backend = backend.with_port(*port); - } - - let local_addr = backend.bind().unwrap(); - log::info!("starting backend on http://{:?}", local_addr); - backend.serve().await -} diff --git a/desktop/desktop-backend/src/handlers.rs b/desktop/desktop-backend/src/handlers.rs new file mode 100644 index 00000000..cce33905 --- /dev/null +++ b/desktop/desktop-backend/src/handlers.rs @@ -0,0 +1,25 @@ +use crate::state::{BACKEND_STATE, BackendState, Task, TaskType}; +use axum::{Json, body::Bytes}; +use axum_extra::extract::Query; +use serde::Deserialize; + +#[derive(Deserialize)] +pub struct GetUnassingedTaskQuery { + #[serde(rename = "task_type")] + task_types: Vec, +} + +pub async fn claim_unassigned_task( + Query(query): Query, +) -> Json> { + let mut state = BACKEND_STATE.lock().unwrap(); + return Json(state.claim_unassigned_task(&query.task_types)); +} + +pub async fn noop(body: Bytes) -> Json<()> { + log::debug!("noop req: {:?}", body); + Json(()) +} +pub async fn dump_state() -> Json { + return Json(BACKEND_STATE.lock().unwrap().clone()); +} diff --git a/desktop/desktop-backend/src/lib.rs b/desktop/desktop-backend/src/lib.rs index af218c03..cf6281bd 100644 --- a/desktop/desktop-backend/src/lib.rs +++ b/desktop/desktop-backend/src/lib.rs @@ -1,2 +1,4 @@ pub mod backend; pub use backend::BackendBuilder; +mod handlers; +pub mod state; diff --git a/desktop/desktop-backend/src/state.rs b/desktop/desktop-backend/src/state.rs new file mode 100644 index 00000000..17b6d547 --- /dev/null +++ b/desktop/desktop-backend/src/state.rs @@ -0,0 +1,145 @@ +use serde::{Deserialize, Serialize, Serializer}; +use std::collections::HashMap; +use std::sync::{LazyLock, Mutex}; +use uuid::Uuid; + +pub static BACKEND_STATE: LazyLock> = LazyLock::new(|| { + Mutex::new(BackendState { + documents: HashMap::new(), + tasks: HashMap::new(), + }) +}); + +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum TaskState { + New, + Assigned, + Completed, + Failed, +} + +#[derive(Clone, Debug, Deserialize, Serialize, PartialEq)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum TaskType { + IdentifySpeakers, + Transcribe, + Align, + Reencode, + Export, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct TaskAttempt { + progress: Option, +} + +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct MediaFile { + tags: Vec, + path: String, +} + +impl MediaFile { + pub fn new(path: String) -> Self { + MediaFile { + tags: Vec::new(), + path, + } + } +} +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Document { + pub id: uuid::Uuid, + pub name: String, + pub created_at: chrono::DateTime, + pub changed_at: chrono::DateTime, + pub media_files: Vec, +} + +impl Document { + pub fn new(name: String, media_files: Vec) -> Self { + Document { + id: Uuid::new_v4(), + name, + created_at: chrono::Local::now(), + changed_at: chrono::Local::now(), + media_files, + } + } +} + +fn get_document(doc_uuid: &Uuid, serializer: S) -> Result +where + S: Serializer, +{ + return BACKEND_STATE + .lock() + .unwrap() + .documents + .get(doc_uuid) + .serialize(serializer); +} +#[derive(Clone, Debug, Deserialize, Serialize)] +#[serde(untagged)] +pub enum TaskParameters { + NoParameters(HashMap<(), ()>), +} +#[derive(Clone, Debug, Deserialize, Serialize)] +pub struct Task { + pub id: Uuid, + pub task_type: TaskType, + pub state: TaskState, + pub dependencies: Vec, + pub current_attempt: Option, + #[serde(serialize_with = "get_document")] + pub document: Uuid, + pub task_parameters: TaskParameters, +} + +#[derive(Clone, Debug, Serialize)] +pub struct BackendState { + documents: HashMap, + tasks: HashMap, +} + +impl BackendState { + pub fn add_task(&mut self, task: Task) { + self.tasks.insert(task.id, task); + } + pub fn add_document(&mut self, document: Document) { + self.documents.insert(document.id, document); + } + + fn get_ready_task<'a>(&'a mut self, task_types: &[TaskType]) -> Option<&'a mut Task> { + let uncompleted_tasks: Vec = self + .tasks + .values() + .filter(|t| t.state != TaskState::Completed) + .map(|x| x.id) + .collect(); + 'task_loop: for task in self.tasks.values_mut() { + if !task_types.contains(&task.task_type) { + continue; + } + for dependency in &task.dependencies { + if uncompleted_tasks.contains(dependency) { + continue 'task_loop; + } + } + if task.current_attempt.is_some() || task.state != TaskState::New { + continue; + } + return Some(task); + } + None + } + pub fn claim_unassigned_task(&mut self, task_types: &[TaskType]) -> Option { + if let Some(task) = self.get_ready_task(task_types) { + (*task).current_attempt = Some(TaskAttempt { progress: None }); + (*task).state = TaskState::Assigned; + return Some(task.clone()); + } + return None; + } +} diff --git a/desktop/src-tauri/Cargo.lock b/desktop/src-tauri/Cargo.lock index e011cf6d..58902843 100644 --- a/desktop/src-tauri/Cargo.lock +++ b/desktop/src-tauri/Cargo.lock @@ -650,8 +650,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c673075a2e0e5f4a1dde27ce9dee1ea4558c7ffe648f576438a20ca1d2acc4b0" dependencies = [ "iana-time-zone", + "js-sys", "num-traits", "serde", + "wasm-bindgen", "windows-link 0.2.1", ] @@ -903,6 +905,7 @@ version = "0.1.0" dependencies = [ "axum", "axum-extra", + "chrono", "log", "serde", "tokio", diff --git a/frontend/src/openapi-schema.ts b/frontend/src/openapi-schema.ts index 191ba9c4..12aa5ab7 100644 --- a/frontend/src/openapi-schema.ts +++ b/frontend/src/openapi-schema.ts @@ -176,7 +176,7 @@ export interface components { */ id: string; /** Media Files */ - media_files: components["schemas"]["DocumentMedia"][]; + media_files: components["schemas"]["RemoteDocumentMedia"][]; /** Name */ name: string; /** Tasks */ @@ -187,7 +187,7 @@ export interface components { current_attempt: components["schemas"]["TaskAttemptResponse"] | null; /** Dependencies */ dependencies: string[]; - document: components["schemas"]["Document"]; + document: components["schemas"]["RemoteDocument"]; /** * Document Id * Format: uuid @@ -286,19 +286,10 @@ export interface components { */ id: string; /** Media Files */ - media_files: components["schemas"]["DocumentMedia"][]; + media_files: (components["schemas"]["LocalDocumentMedia"] | components["schemas"]["RemoteDocumentMedia"])[]; /** Name */ name: string; }; - /** DocumentMedia */ - DocumentMedia: { - /** Content Type */ - content_type: string; - /** Tags */ - tags: string[]; - /** Url */ - url: string; - }; /** DocumentShareTokenBase */ DocumentShareTokenBase: { /** Can Write */ @@ -341,7 +332,7 @@ export interface components { */ id: string; /** Media Files */ - media_files: components["schemas"]["DocumentMedia"][]; + media_files: components["schemas"]["RemoteDocumentMedia"][]; /** Name */ name: string; }; @@ -395,6 +386,13 @@ export interface components { /** Progress */ progress?: number | null; }; + /** LocalDocumentMedia */ + LocalDocumentMedia: { + /** Path */ + path: string; + /** Tags */ + tags: string[]; + }; /** LoginResponse */ LoginResponse: { /** Token */ @@ -414,6 +412,31 @@ export interface components { /** Logged Out Redirect Url */ logged_out_redirect_url?: string | null; }; + /** RemoteDocument */ + RemoteDocument: { + /** Changed At */ + changed_at: string; + /** Created At */ + created_at: string; + /** + * Id + * Format: uuid + */ + id: string; + /** Media Files */ + media_files: components["schemas"]["RemoteDocumentMedia"][]; + /** Name */ + name: string; + }; + /** RemoteDocumentMedia */ + RemoteDocumentMedia: { + /** Content Type */ + content_type: string; + /** Tags */ + tags: string[]; + /** Url */ + url: string; + }; /** SetDurationRequest */ SetDurationRequest: { /** Duration */ @@ -883,7 +906,7 @@ export interface operations { /** @description Successful Response */ 200: { content: { - "application/json": components["schemas"]["DocumentMedia"][]; + "application/json": components["schemas"]["RemoteDocumentMedia"][]; }; }; /** @description Validation Error */ diff --git a/proto/transcribee_proto/api.py b/proto/transcribee_proto/api.py index 9e6bb691..fa648439 100644 --- a/proto/transcribee_proto/api.py +++ b/proto/transcribee_proto/api.py @@ -1,7 +1,7 @@ from __future__ import annotations import enum -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, Literal, Optional from uuid import UUID from pydantic import BaseModel @@ -15,10 +15,20 @@ class TaskType(str, enum.Enum): EXPORT = "EXPORT" -class DocumentMedia(BaseModel): +class BaseDocumentMedia(BaseModel): + tags: list[str] + + +class RemoteDocumentMedia(BaseDocumentMedia): url: str content_type: str - tags: List[str] + + +class LocalDocumentMedia(BaseDocumentMedia): + path: str + + +DocumentMedia = LocalDocumentMedia | RemoteDocumentMedia class Document(BaseModel): @@ -26,10 +36,18 @@ class Document(BaseModel): name: str created_at: str changed_at: str - media_files: List[DocumentMedia] + media_files: list[DocumentMedia] + + +class RemoteDocument(BaseModel): + id: UUID + name: str + created_at: str + changed_at: str + media_files: list[RemoteDocumentMedia] -class DocumentWithAccessInfo(Document): +class DocumentWithAccessInfo(RemoteDocument): can_write: bool has_full_access: bool diff --git a/worker/transcribee_worker/config.py b/worker/transcribee_worker/config.py index 158b3ba1..cdaf518f 100644 --- a/worker/transcribee_worker/config.py +++ b/worker/transcribee_worker/config.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Dict, Optional +from typing import Dict, Literal, Optional from pydantic_settings import BaseSettings, SettingsConfigDict @@ -42,6 +42,8 @@ class Settings(BaseSettings): COMPUTE_TYPE: str = "int8" + WORKER_TYPE: Literal["web", "desktop"] = "web" + model_config = SettingsConfigDict(env_file=".env") def setup_env_vars(self): diff --git a/worker/transcribee_worker/worker.py b/worker/transcribee_worker/worker.py index 1c38f563..ac4a85f5 100644 --- a/worker/transcribee_worker/worker.py +++ b/worker/transcribee_worker/worker.py @@ -6,7 +6,7 @@ import traceback from contextlib import asynccontextmanager from pathlib import Path -from typing import Any, AsyncGenerator, Optional, Tuple +from typing import Any, AsyncGenerator, Optional from uuid import UUID import automerge @@ -16,9 +16,12 @@ from transcribee_proto.api import ( AlignTask, AssignedTask, + BaseDocumentMedia, ExportFormat, ExportTask, + LocalDocumentMedia, ReencodeTask, + RemoteDocumentMedia, SpeakerIdentificationTask, TaskType, TranscribeTask, @@ -137,29 +140,36 @@ def _get_tmpfile(self, filename: str) -> Path: raise ValueError("`tmpdir` must be set") return self.tmpdir / filename - def get_document_audio_bytes( + def download_media(self, media_file: RemoteDocumentMedia) -> Path: + logging.debug(f"loading audio. {media_file=}") + response = self.api_client.get(media_file.url) + extension = mimetypes.guess_extension(media_file.content_type) + path = self._get_tmpfile(f"doc_audio{extension}") + with open(path, "wb") as f: + f.write(response.content) + return path + + def find_document_audio_media_file( self, document: ApiDocument - ) -> Optional[Tuple[bytes, str]]: + ) -> BaseDocumentMedia | None: logging.debug(f"Getting audio. {document=}") if not document.media_files: return - media_file = document.media_files[0] + media_file: BaseDocumentMedia = document.media_files[0] for mf in document.media_files: if "profile:mp3" in mf.tags: media_file = mf break - response = self.api_client.get(media_file.url) - return response.content, media_file.content_type + return media_file def get_document_audio_path(self, document: ApiDocument) -> Optional[Path]: - b = self.get_document_audio_bytes(document=document) - if b is not None: - b, ct = b - extension = mimetypes.guess_extension(ct) - path = self._get_tmpfile(f"doc_audio{extension}") - with open(path, "wb") as f: - f.write(b) - return path + mf = self.find_document_audio_media_file(document) + if not mf: + return + if settings.WORKER_TYPE == "web" and isinstance(mf, RemoteDocumentMedia): + return self.download_media(mf) + elif settings.WORKER_TYPE == "desktop" and isinstance(mf, LocalDocumentMedia): + return Path(mf.path) if mf.path else None def load_document_audio(self, document: ApiDocument) -> npt.NDArray: document_audio = self.get_document_audio_path(document)