diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..09a19c1 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,76 @@ +name: CI + +on: + push: + branches: [m, "claude/**"] + pull_request: + branches: [m] + +env: + CARGO_TERM_COLOR: always + RUSTFLAGS: -Dwarnings + +jobs: + check: + name: cargo check + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - run: cargo check --workspace --all-targets + + test: + name: cargo test + runs-on: ubuntu-latest + needs: check + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - run: cargo test --workspace + + clippy: + name: clippy + runs-on: ubuntu-latest + needs: check + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: clippy + - uses: Swatinem/rust-cache@v2 + - run: cargo clippy --workspace --all-targets -- -D warnings + + fmt: + name: rustfmt + runs-on: ubuntu-latest + env: + RUSTFLAGS: "" + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + with: + components: rustfmt + - run: cargo fmt --all -- --check + + doc: + name: cargo doc + runs-on: ubuntu-latest + needs: check + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@stable + - uses: Swatinem/rust-cache@v2 + - run: cargo doc --workspace --no-deps + env: + RUSTDOCFLAGS: -Dwarnings + + msrv: + name: minimum supported rust version + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - uses: dtolnay/rust-toolchain@1.85.0 + - uses: Swatinem/rust-cache@v2 + - run: cargo check --workspace diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..6686961 --- /dev/null +++ b/.gitignore @@ -0,0 +1,7 @@ +/target +Cargo.lock +*.swp +*.swo +.DS_Store +.idea/ +.vscode/ diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..0d63b65 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,86 @@ +[workspace] +resolver = "2" +members = [ + "crates/rusty_ai", + "crates/rusty_middleware", + "crates/rusty_ui_stream", + "crates/rusty_testing", + "crates/rusty_chatgpt", + "crates/rusty_claude", + "crates/rusty_gemini", + "crates/rusty_openai_compatible", + "crates/rusty_gemini_nano", + "crates/rusty_foundationmodels", + "crates/rusty_phi_silica", + "crates/rusty_browser", + "crates/rusty_ollama", + "examples/basic_text", + "examples/stream_text", + "examples/generate_object", + "examples/stream_object", + "examples/tool_loop", + "examples/multimodal", + "examples/local_android", + "examples/local_apple", + "examples/local_windows", + "examples/router", +] + +[workspace.package] +version = "0.1.0" +edition = "2021" +rust-version = "1.85" +license = "MPL-2.0" +repository = "https://github.com/undivisible/rusty_ai" + +[workspace.dependencies] +# Core +rusty_ai = { path = "crates/rusty_ai" } +rusty_middleware = { path = "crates/rusty_middleware" } +rusty_ui_stream = { path = "crates/rusty_ui_stream" } +rusty_testing = { path = "crates/rusty_testing" } + +# Cloud providers +rusty_chatgpt = { path = "crates/rusty_chatgpt" } +rusty_claude = { path = "crates/rusty_claude" } +rusty_gemini = { path = "crates/rusty_gemini" } +rusty_openai_compatible = { path = "crates/rusty_openai_compatible" } + +# Local / platform runtimes +rusty_gemini_nano = { path = "crates/rusty_gemini_nano" } +rusty_foundationmodels = { path = "crates/rusty_foundationmodels" } +rusty_phi_silica = { path = "crates/rusty_phi_silica" } +rusty_browser = { path = "crates/rusty_browser" } +rusty_ollama = { path = "crates/rusty_ollama" } + +# Async / streaming +tokio = { version = "1", features = ["full"] } +futures = "0.3" +async-trait = "0.1" +pin-project-lite = "0.2" +tokio-stream = "0.1" + +# Serialization +serde = { version = "1", features = ["derive"] } +serde_json = "1" +schemars = "0.8" + +# HTTP +reqwest = { version = "0.12", features = ["json", "stream"] } +reqwest-eventsource = "0.6" +eventsource-stream = "0.2" + +# Observability +tracing = "0.1" +uuid = { version = "1", features = ["v4", "serde"] } +chrono = { version = "0.4", features = ["serde"] } + +# Error handling +thiserror = "2" + +# Misc +bytes = "1" +url = "2" +secrecy = "0.10" +base64 = "0.22" +mime = "0.3" diff --git a/crates/rusty_ai/Cargo.toml b/crates/rusty_ai/Cargo.toml new file mode 100644 index 0000000..4a5dccd --- /dev/null +++ b/crates/rusty_ai/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "rusty_ai" +version.workspace = true +edition.workspace = true +license.workspace = true +description = "Core traits, types, and abstractions for the Rusty AI SDK" + +[dependencies] +async-trait = { workspace = true } +futures = { workspace = true } +pin-project-lite = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +schemars = { workspace = true } +thiserror = { workspace = true } +tracing = { workspace = true } +uuid = { workspace = true } +chrono = { workspace = true } +bytes = { workspace = true } +url = { workspace = true } +base64 = { workspace = true } +mime = { workspace = true } +secrecy = { workspace = true } diff --git a/crates/rusty_ai/src/capability.rs b/crates/rusty_ai/src/capability.rs new file mode 100644 index 0000000..4b6d0de --- /dev/null +++ b/crates/rusty_ai/src/capability.rs @@ -0,0 +1,67 @@ +use std::collections::BTreeSet; + +use serde::{Deserialize, Serialize}; + +/// A capability that a model may support. +#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Capability { + TextInput, + TextOutput, + ImageInput, + ImageOutput, + Streaming, + ToolCalling, + StructuredOutput, + Embeddings, + LocalExecution, + SessionSupport, + PlatformNative, + ExtendedThinking, + VideoInput, + AudioInput, + AudioOutput, +} + +/// An ordered set of capabilities. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct CapabilitySet { + inner: BTreeSet, +} + +impl CapabilitySet { + /// Create an empty capability set. + pub fn new() -> Self { + Self { + inner: BTreeSet::new(), + } + } + + /// Builder-style: add a capability and return self. + pub fn with(mut self, cap: Capability) -> Self { + self.inner.insert(cap); + self + } + + /// Check whether the set contains a specific capability. + pub fn has(&self, cap: &Capability) -> bool { + self.inner.contains(cap) + } + + /// Returns `true` if every capability in the slice is present. + pub fn supports_all(&self, caps: &[Capability]) -> bool { + caps.iter().all(|c| self.inner.contains(c)) + } + + /// Merge another set into this one. + pub fn merge(&mut self, other: &CapabilitySet) { + for cap in &other.inner { + self.inner.insert(cap.clone()); + } + } + + /// Iterate over the capabilities. + pub fn iter(&self) -> impl Iterator { + self.inner.iter() + } +} diff --git a/crates/rusty_ai/src/content.rs b/crates/rusty_ai/src/content.rs new file mode 100644 index 0000000..5aa73b0 --- /dev/null +++ b/crates/rusty_ai/src/content.rs @@ -0,0 +1,51 @@ +use serde::{Deserialize, Serialize}; + +use crate::tool::{ToolCallRequest, ToolCallResult}; + +/// A single part of a message's content. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type", rename_all = "snake_case")] +pub enum ContentPart { + Text { text: String }, + Image { data: ImageData }, + File { data: FileData }, + ToolCall { call: ToolCallRequest }, + ToolResult { result: ToolCallResult }, +} + +/// Image payload — either a URL reference or inline base64. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "source", rename_all = "snake_case")] +pub enum ImageData { + Url { + url: String, + detail: Option, + }, + Base64 { + media_type: String, + data: String, + }, +} + +/// Requested level of detail for image understanding. +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ImageDetail { + Auto, + Low, + High, +} + +/// File payload — either a URL reference or inline base64. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "source", rename_all = "snake_case")] +pub enum FileData { + Url { + url: String, + media_type: Option, + }, + Base64 { + media_type: String, + data: String, + }, +} diff --git a/crates/rusty_ai/src/embedding.rs b/crates/rusty_ai/src/embedding.rs new file mode 100644 index 0000000..f36c914 --- /dev/null +++ b/crates/rusty_ai/src/embedding.rs @@ -0,0 +1,47 @@ +/// Compute the cosine similarity between two embedding vectors. +/// +/// Returns 0.0 if either vector has zero magnitude. +pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 { + assert_eq!( + a.len(), + b.len(), + "embedding vectors must have the same length" + ); + + let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum(); + let mag_a: f32 = a.iter().map(|x| x * x).sum::().sqrt(); + let mag_b: f32 = b.iter().map(|x| x * x).sum::().sqrt(); + + if mag_a == 0.0 || mag_b == 0.0 { + return 0.0; + } + + dot / (mag_a * mag_b) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn identical_vectors() { + let v = vec![1.0, 2.0, 3.0]; + let sim = cosine_similarity(&v, &v); + assert!((sim - 1.0).abs() < 1e-6); + } + + #[test] + fn orthogonal_vectors() { + let a = vec![1.0, 0.0]; + let b = vec![0.0, 1.0]; + let sim = cosine_similarity(&a, &b); + assert!(sim.abs() < 1e-6); + } + + #[test] + fn zero_vector() { + let a = vec![0.0, 0.0]; + let b = vec![1.0, 2.0]; + assert_eq!(cosine_similarity(&a, &b), 0.0); + } +} diff --git a/crates/rusty_ai/src/error.rs b/crates/rusty_ai/src/error.rs new file mode 100644 index 0000000..60850b2 --- /dev/null +++ b/crates/rusty_ai/src/error.rs @@ -0,0 +1,70 @@ +use std::time::Duration; + +/// The primary error type for the Rusty AI SDK. +#[derive(Debug, thiserror::Error)] +pub enum AiError { + #[error("Unsupported capability `{capability}` for provider `{provider}`")] + UnsupportedCapability { + capability: String, + provider: String, + }, + + #[error("Provider `{provider}` error (status {status:?}): {message}")] + ProviderError { + provider: String, + status: Option, + message: String, + }, + + #[error("Platform `{platform}` is unavailable")] + PlatformUnavailable { platform: String }, + + #[error("Model `{model}` is unavailable")] + ModelUnavailable { model: String }, + + #[error("Authentication error: {message}")] + AuthError { message: String }, + + #[error("Rate limited (retry after {retry_after:?})")] + RateLimit { retry_after: Option }, + + #[error("Timeout")] + Timeout, + + #[error("Request was cancelled")] + Cancelled, + + #[error("Transport error: {message}")] + Transport { + message: String, + #[source] + source: Option>, + }, + + #[error("Serialization error: {0}")] + Serialization(String), + + #[error("Tool `{tool_name}` error: {message}")] + ToolError { tool_name: String, message: String }, + + #[error("Bridge `{bridge}` error: {message}")] + BridgeError { bridge: String, message: String }, + + #[error("Schema validation error: {message}")] + SchemaValidation { message: String }, + + #[error("Stream error: {message}")] + StreamError { message: String }, + + #[error("Exceeded maximum steps ({max_steps})")] + MaxStepsExceeded { max_steps: usize }, +} + +impl From for AiError { + fn from(err: serde_json::Error) -> Self { + AiError::Serialization(err.to_string()) + } +} + +/// Convenience result type alias. +pub type AiResult = Result; diff --git a/crates/rusty_ai/src/lib.rs b/crates/rusty_ai/src/lib.rs new file mode 100644 index 0000000..b610b25 --- /dev/null +++ b/crates/rusty_ai/src/lib.rs @@ -0,0 +1,74 @@ +//! Core traits, types, and abstractions for the Rusty AI SDK. + +pub mod capability; +pub mod content; +pub mod embedding; +pub mod error; +pub mod message; +pub mod middleware; +pub mod model; +pub mod prompt; +pub mod provider; +pub mod router; +pub mod schema; +pub mod stream; +pub mod structured; +pub mod tool; +pub mod types; +pub mod usage; + +// Re-exports for convenience. +pub use capability::{Capability, CapabilitySet}; +pub use content::{ContentPart, FileData, ImageData, ImageDetail}; +pub use embedding::cosine_similarity; +pub use error::{AiError, AiResult}; +pub use message::{Message, Role}; +pub use model::{ + EmbeddingModel, GenerateOptions, LanguageModel, Middleware, MiddlewareNext, ProviderInfo, + ReasoningEffort, SpeechToTextModel, TextToSpeechModel, ThinkingConfig, +}; +pub use prompt::Prompt; +pub use provider::Provider; +pub use router::{Route, Router}; +pub use schema::OutputSchema; +pub use stream::{AiStream, StreamCollector, StreamEvent, SyntheticStreamer}; +pub use structured::{ + AudioResult, EmbeddingResult, GenerateResult, ObjectResult, TranscriptionResult, TtsOptions, +}; +pub use tool::{ToolCallRequest, ToolCallResult, ToolChoice, ToolDefinition, ToolSet}; +pub use types::{FinishReason, ModelInfo, ModelRegistry, RequestMetadata, ResponseMetadata}; +pub use usage::Usage; + +/// Generate text from a language model with default options. +/// +/// Returns an error if the model responds with no text content (e.g. the +/// model responded with only tool calls). +pub async fn generate_text( + model: &dyn LanguageModel, + prompt: impl Into, +) -> AiResult { + let result = model + .generate(prompt.into(), GenerateOptions::default()) + .await?; + result.text.ok_or_else(|| AiError::ProviderError { + provider: model.provider_id().to_string(), + status: None, + message: "Response contained no text content (model responded with tool calls only)" + .to_string(), + }) +} + +/// Stream text from a language model with default options. +pub async fn stream_text( + model: &dyn LanguageModel, + prompt: impl Into, +) -> AiResult { + model + .stream(prompt.into(), GenerateOptions::default()) + .await +} + +/// Embed texts using an embedding model. +pub async fn embed(model: &dyn EmbeddingModel, texts: Vec) -> AiResult { + model.embed(texts).await +} diff --git a/crates/rusty_ai/src/message.rs b/crates/rusty_ai/src/message.rs new file mode 100644 index 0000000..62cb0a8 --- /dev/null +++ b/crates/rusty_ai/src/message.rs @@ -0,0 +1,93 @@ +use std::collections::HashMap; + +use serde::{Deserialize, Serialize}; + +use crate::content::{ContentPart, ImageData}; + +/// The role of a message participant. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum Role { + System, + User, + Assistant, + Tool, +} + +/// A single message in a conversation. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct Message { + pub role: Role, + pub content: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(default)] + pub metadata: HashMap, +} + +impl Message { + /// Create a system message from plain text. + pub fn system(text: impl Into) -> Self { + Self { + role: Role::System, + content: vec![ContentPart::Text { text: text.into() }], + name: None, + metadata: HashMap::new(), + } + } + + /// Create a user message from plain text. + pub fn user(text: impl Into) -> Self { + Self { + role: Role::User, + content: vec![ContentPart::Text { text: text.into() }], + name: None, + metadata: HashMap::new(), + } + } + + /// Create an assistant message from plain text. + pub fn assistant(text: impl Into) -> Self { + Self { + role: Role::Assistant, + content: vec![ContentPart::Text { text: text.into() }], + name: None, + metadata: HashMap::new(), + } + } + + /// Create a tool-result message. + pub fn tool_result(call_id: impl Into, content: impl Into) -> Self { + use crate::tool::ToolCallResult; + Self { + role: Role::Tool, + content: vec![ContentPart::ToolResult { + result: ToolCallResult { + call_id: call_id.into(), + content: content.into(), + is_error: false, + }, + }], + name: None, + metadata: HashMap::new(), + } + } + + /// Append an image to this message's content parts. + pub fn with_image(mut self, image_data: ImageData) -> Self { + self.content.push(ContentPart::Image { data: image_data }); + self + } + + /// Set the `name` field on this message. + pub fn with_name(mut self, name: impl Into) -> Self { + self.name = Some(name.into()); + self + } + + /// Insert an arbitrary metadata key/value. + pub fn with_metadata(mut self, key: impl Into, value: serde_json::Value) -> Self { + self.metadata.insert(key.into(), value); + self + } +} diff --git a/crates/rusty_ai/src/middleware.rs b/crates/rusty_ai/src/middleware.rs new file mode 100644 index 0000000..3c0576d --- /dev/null +++ b/crates/rusty_ai/src/middleware.rs @@ -0,0 +1,8 @@ +//! Middleware types for intercepting generation requests. +//! +//! The core [`crate::model::Middleware`] and [`crate::model::MiddlewareNext`] +//! types are defined in [`crate::model`] and re-exported from the crate root. + +// This module exists as a namespace placeholder. The primary middleware +// types live in `model.rs` because the linter/project convention co-locates +// them with the `LanguageModel` trait. diff --git a/crates/rusty_ai/src/model.rs b/crates/rusty_ai/src/model.rs new file mode 100644 index 0000000..d18e15f --- /dev/null +++ b/crates/rusty_ai/src/model.rs @@ -0,0 +1,279 @@ +use async_trait::async_trait; + +use crate::capability::CapabilitySet; +use crate::error::{AiError, AiResult}; +use crate::prompt::Prompt; +use crate::schema::OutputSchema; +use crate::stream::AiStream; +use crate::structured::{ + AudioResult, EmbeddingResult, GenerateResult, ObjectResult, TranscriptionResult, TtsOptions, +}; +use crate::tool::{ToolChoice, ToolDefinition}; +use crate::types::RequestMetadata; + +/// Extended-thinking / reasoning configuration. +/// +/// Supported by Anthropic (adaptive thinking), Gemini 2.5+ (thinking budget), +/// and Ollama reasoning models (think flag). +#[derive(Debug, Clone)] +pub enum ThinkingConfig { + /// Enable thinking with adaptive budget (Anthropic claude-3-7-sonnet and later). + Adaptive, + /// Enable thinking with a fixed token budget (Gemini 2.5+). + Budget { tokens: u32 }, + /// Simple on/off flag (Ollama, Gemini 3 Flash `think: true`). + Enabled, +} + +/// Reasoning effort level for models that support it (OpenAI Responses API). +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ReasoningEffort { + None, + Low, + Medium, + High, + XHigh, +} + +/// Options that control generation behaviour. +#[derive(Debug, Clone, Default)] +pub struct GenerateOptions { + pub temperature: Option, + pub max_tokens: Option, + pub top_p: Option, + pub top_k: Option, + pub stop_sequences: Vec, + pub frequency_penalty: Option, + pub presence_penalty: Option, + pub seed: Option, + pub tools: Option>, + pub tool_choice: Option, + pub output_schema: Option, + /// Extended thinking / reasoning configuration. + pub thinking: Option, + /// Reasoning effort (OpenAI Responses API). + pub reasoning_effort: Option, + pub metadata: RequestMetadata, +} + +impl GenerateOptions { + /// Set the temperature. + pub fn with_temperature(mut self, t: f64) -> Self { + self.temperature = Some(t); + self + } + + /// Set the maximum number of tokens to generate. + pub fn with_max_tokens(mut self, n: u32) -> Self { + self.max_tokens = Some(n); + self + } + + /// Set top-p (nucleus sampling). + pub fn with_top_p(mut self, p: f64) -> Self { + self.top_p = Some(p); + self + } + + /// Set top-k sampling. + pub fn with_top_k(mut self, k: u32) -> Self { + self.top_k = Some(k); + self + } + + /// Set stop sequences. + pub fn with_stop_sequences(mut self, seqs: Vec) -> Self { + self.stop_sequences = seqs; + self + } + + /// Set frequency penalty. + pub fn with_frequency_penalty(mut self, p: f64) -> Self { + self.frequency_penalty = Some(p); + self + } + + /// Set presence penalty. + pub fn with_presence_penalty(mut self, p: f64) -> Self { + self.presence_penalty = Some(p); + self + } + + /// Set the random seed for reproducibility. + pub fn with_seed(mut self, seed: u64) -> Self { + self.seed = Some(seed); + self + } + + /// Provide tool definitions. + pub fn with_tools(mut self, tools: Vec) -> Self { + self.tools = Some(tools); + self + } + + /// Set the tool choice strategy. + pub fn with_tool_choice(mut self, choice: ToolChoice) -> Self { + self.tool_choice = Some(choice); + self + } + + /// Set the output schema for structured generation. + pub fn with_output_schema(mut self, schema: OutputSchema) -> Self { + self.output_schema = Some(schema); + self + } + + /// Enable extended thinking with adaptive budget. + pub fn with_thinking(mut self, config: ThinkingConfig) -> Self { + self.thinking = Some(config); + self + } + + /// Set the reasoning effort level (OpenAI Responses API). + pub fn with_reasoning_effort(mut self, effort: ReasoningEffort) -> Self { + self.reasoning_effort = Some(effort); + self + } + + /// Set request metadata. + pub fn with_metadata(mut self, metadata: RequestMetadata) -> Self { + self.metadata = metadata; + self + } +} + +/// Describes a provider backend. +#[derive(Debug, Clone)] +pub struct ProviderInfo { + pub name: String, + pub default_base_url: Option, +} + +/// The core language model trait that all providers implement. +#[async_trait] +pub trait LanguageModel: Send + Sync { + /// Return the model identifier (e.g. "gpt-4o"). + fn model_id(&self) -> &str; + + /// Return the provider identifier (e.g. "openai"). + fn provider_id(&self) -> &str; + + /// Return the set of capabilities this model supports. + fn capabilities(&self) -> &CapabilitySet; + + /// Generate a complete response. + async fn generate(&self, prompt: Prompt, options: GenerateOptions) -> AiResult; + + /// Stream a response as a series of events. + async fn stream(&self, prompt: Prompt, options: GenerateOptions) -> AiResult; +} + +/// Generate a structured object from a language model response. +/// +/// Adds the JSON schema to the options, calls `generate`, and parses the text result. +pub async fn generate_object( + model: &dyn LanguageModel, + prompt: Prompt, + options: GenerateOptions, +) -> AiResult> { + let mut opts = options; + opts.output_schema = Some(OutputSchema::from_type::()); + let result = model.generate(prompt, opts).await?; + let text = result + .text + .as_deref() + .ok_or_else(|| AiError::Serialization("No text in response to parse as object".into()))?; + let object: T = serde_json::from_str(text) + .map_err(|e| AiError::Serialization(format!("Failed to parse response as object: {e}")))?; + Ok(ObjectResult { + object, + text: text.to_owned(), + usage: result.usage, + metadata: result.metadata, + }) +} + +/// A model that produces vector embeddings from text. +#[async_trait] +pub trait EmbeddingModel: Send + Sync { + /// Return the model identifier. + fn model_id(&self) -> &str; + + /// Return the provider identifier. + fn provider_id(&self) -> &str; + + /// Return the dimensionality of the embeddings, if known. + fn dimensions(&self) -> Option; + + /// Embed a batch of texts into vectors. + async fn embed(&self, texts: Vec) -> AiResult; +} + +/// Middleware sits between the caller and the model, intercepting generate calls. +#[async_trait] +pub trait Middleware: Send + Sync { + /// Process a generate request. + /// + /// Implementations should call `next.run(prompt, options).await` to + /// continue the chain, and may inspect / modify the prompt, options, + /// or result. + async fn process( + &self, + prompt: Prompt, + options: GenerateOptions, + next: MiddlewareNext<'_>, + ) -> AiResult; +} + +/// A handle to the next element in a middleware chain. +/// +/// Calling `run` will invoke either the next middleware or the final model. +pub struct MiddlewareNext<'a> { + pub middlewares: &'a [Box], + pub model: &'a dyn LanguageModel, +} + +impl<'a> MiddlewareNext<'a> { + /// Execute the next middleware (or the model if no middleware remains). + pub async fn run(self, prompt: Prompt, options: GenerateOptions) -> AiResult { + if let Some((first, rest)) = self.middlewares.split_first() { + let next = MiddlewareNext { + middlewares: rest, + model: self.model, + }; + first.process(prompt, options, next).await + } else { + self.model.generate(prompt, options).await + } + } +} + +/// A model that converts speech audio to text (e.g. OpenAI Whisper). +#[async_trait] +pub trait SpeechToTextModel: Send + Sync { + fn model_id(&self) -> &str; + fn provider_id(&self) -> &str; + + /// Transcribe audio bytes into text. + async fn transcribe( + &self, + audio: Vec, + mime_type: &str, + language: Option<&str>, + ) -> AiResult; +} + +/// A model that converts text to speech audio (e.g. OpenAI TTS). +#[async_trait] +pub trait TextToSpeechModel: Send + Sync { + fn model_id(&self) -> &str; + fn provider_id(&self) -> &str; + + /// Synthesize speech from text. Returns audio bytes. + async fn synthesize( + &self, + text: &str, + voice: &str, + options: TtsOptions, + ) -> AiResult; +} diff --git a/crates/rusty_ai/src/prompt.rs b/crates/rusty_ai/src/prompt.rs new file mode 100644 index 0000000..3d288e1 --- /dev/null +++ b/crates/rusty_ai/src/prompt.rs @@ -0,0 +1,40 @@ +use serde::{Deserialize, Serialize}; + +use crate::message::Message; + +/// A prompt that can be either raw text or a list of messages. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(untagged)] +pub enum Prompt { + Text(String), + Messages(Vec), +} + +impl Prompt { + /// Convert the prompt into a `Vec`. + /// A plain text prompt becomes a single user message. + pub fn into_messages(self) -> Vec { + match self { + Prompt::Text(text) => vec![Message::user(text)], + Prompt::Messages(msgs) => msgs, + } + } +} + +impl From<&str> for Prompt { + fn from(s: &str) -> Self { + Prompt::Text(s.to_owned()) + } +} + +impl From for Prompt { + fn from(s: String) -> Self { + Prompt::Text(s) + } +} + +impl From> for Prompt { + fn from(msgs: Vec) -> Self { + Prompt::Messages(msgs) + } +} diff --git a/crates/rusty_ai/src/provider.rs b/crates/rusty_ai/src/provider.rs new file mode 100644 index 0000000..7672d28 --- /dev/null +++ b/crates/rusty_ai/src/provider.rs @@ -0,0 +1,52 @@ +use async_trait::async_trait; + +use crate::error::{AiError, AiResult}; +use crate::model::{EmbeddingModel, LanguageModel, SpeechToTextModel, TextToSpeechModel}; +use crate::types::ModelInfo; + +/// A provider that exposes one or more language and/or embedding models. +#[async_trait] +pub trait Provider: Send + Sync { + /// A unique identifier for this provider (e.g. "openai", "anthropic"). + fn id(&self) -> &str; + + /// A human-readable display name. + fn name(&self) -> &str; + + /// Retrieve a language model by its identifier. + fn language_model(&self, model_id: &str) -> AiResult>; + + /// Retrieve an embedding model by its identifier. + fn embedding_model(&self, model_id: &str) -> AiResult>; + + /// List the locally-known models for this provider. + /// + /// This returns a static snapshot of registered models. For providers + /// that can dynamically discover models (cloud APIs, Ollama), prefer + /// [`Provider::fetch_models`] which queries the remote API. + fn available_models(&self) -> Vec; + + /// Fetch the list of models from the remote API. + /// + /// Not all providers support dynamic discovery. The default + /// implementation falls back to [`Provider::available_models`]. + async fn fetch_models(&self) -> AiResult> { + Ok(self.available_models()) + } + + /// Retrieve a speech-to-text model by its identifier. + fn speech_to_text_model(&self, _model_id: &str) -> AiResult> { + Err(AiError::UnsupportedCapability { + capability: "speech_to_text".into(), + provider: self.id().into(), + }) + } + + /// Retrieve a text-to-speech model by its identifier. + fn text_to_speech_model(&self, _model_id: &str) -> AiResult> { + Err(AiError::UnsupportedCapability { + capability: "text_to_speech".into(), + provider: self.id().into(), + }) + } +} diff --git a/crates/rusty_ai/src/router.rs b/crates/rusty_ai/src/router.rs new file mode 100644 index 0000000..f0bb796 --- /dev/null +++ b/crates/rusty_ai/src/router.rs @@ -0,0 +1,360 @@ +use async_trait::async_trait; + +use crate::capability::{Capability, CapabilitySet}; +use crate::error::{AiError, AiResult}; +use crate::model::{GenerateOptions, LanguageModel}; +use crate::prompt::Prompt; +use crate::stream::AiStream; +use crate::structured::GenerateResult; + +/// A condition function used to decide whether a route matches. +pub type RouteCondition = Box bool + Send + Sync>; + +/// A single route that maps a condition to a model. +pub struct Route { + pub model: Box, + pub condition: RouteCondition, + pub priority: i32, +} + +/// A router that dispatches generation requests to different models based on conditions. +pub struct Router { + routes: Vec, + fallback: Option>, +} + +impl Router { + /// Create an empty router. + pub fn new() -> Self { + Self { + routes: Vec::new(), + fallback: None, + } + } + + /// Add a route with the given condition. Higher priority routes are checked first. + pub fn add_route( + mut self, + model: Box, + condition: impl Fn(&Prompt, &GenerateOptions) -> bool + Send + Sync + 'static, + ) -> Self { + self.routes.push(Route { + model, + condition: Box::new(condition), + priority: 0, + }); + self + } + + /// Add a route with a specific priority. Higher priority routes are checked first. + pub fn add_route_with_priority( + mut self, + model: Box, + condition: impl Fn(&Prompt, &GenerateOptions) -> bool + Send + Sync + 'static, + priority: i32, + ) -> Self { + self.routes.push(Route { + model, + condition: Box::new(condition), + priority, + }); + self + } + + /// Set a fallback model to use when no route matches. + pub fn with_fallback(mut self, model: Box) -> Self { + self.fallback = Some(model); + self + } + + /// Create a router that prefers a local model and falls back to a cloud model. + /// + /// The local model is selected when its capabilities satisfy the request + /// (e.g. it supports tool calling when tools are provided). When the local + /// model cannot satisfy the request, the cloud model is used as a fallback. + pub fn local_first(local: Box, cloud: Box) -> Self { + let local_caps = local.capabilities().clone(); + Self::new() + .add_route_with_priority( + local, + move |_prompt, options| { + let mut needed = Vec::new(); + if options.tools.is_some() { + needed.push(Capability::ToolCalling); + } + if options.output_schema.is_some() { + needed.push(Capability::StructuredOutput); + } + local_caps.supports_all(&needed) + }, + 10, + ) + .with_fallback(cloud) + } + + /// Create a router that selects a model based on required capabilities. + /// + /// For each request, the first model whose capability set satisfies the + /// request's needs (e.g. tool calling, structured output) is selected. + pub fn capability_route(models: Vec>) -> Self { + let mut router = Self::new(); + for model in models { + let caps = model.capabilities().clone(); + router.routes.push(Route { + model, + condition: Box::new(move |_prompt, options| { + // Check if this model supports the required capabilities + let mut needed = Vec::new(); + if options.tools.is_some() { + needed.push(Capability::ToolCalling); + } + if options.output_schema.is_some() { + needed.push(Capability::StructuredOutput); + } + caps.supports_all(&needed) + }), + priority: 0, + }); + } + router + } + + /// Select the best model for the given prompt and options. + fn select_model<'a>( + &'a self, + prompt: &Prompt, + options: &GenerateOptions, + ) -> AiResult<&'a dyn LanguageModel> { + // Sort candidates by priority (highest first) + let mut candidates: Vec<&Route> = self + .routes + .iter() + .filter(|r| (r.condition)(prompt, options)) + .collect(); + candidates.sort_by(|a, b| b.priority.cmp(&a.priority)); + + if let Some(route) = candidates.first() { + return Ok(route.model.as_ref()); + } + + if let Some(ref fallback) = self.fallback { + return Ok(fallback.as_ref()); + } + + Err(AiError::ModelUnavailable { + model: "No matching route and no fallback configured".into(), + }) + } +} + +impl Default for Router { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl LanguageModel for Router { + fn model_id(&self) -> &str { + "router" + } + + fn provider_id(&self) -> &str { + "router" + } + + fn capabilities(&self) -> &CapabilitySet { + // A router's capabilities are conceptually the union, but we return + // an empty set since actual capabilities depend on the selected model. + // Callers should rely on the router's routing logic rather than this. + static EMPTY: std::sync::LazyLock = + std::sync::LazyLock::new(CapabilitySet::new); + &EMPTY + } + + async fn generate(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + let model = self.select_model(&prompt, &options)?; + model.generate(prompt, options).await + } + + async fn stream(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + let model = self.select_model(&prompt, &options)?; + model.stream(prompt, options).await + } +} + +#[cfg(test)] +mod tests { + use std::sync::{Arc, Mutex}; + + use async_trait::async_trait; + + use super::*; + use crate::schema::OutputSchema; + use crate::structured::GenerateResult; + use crate::tool::ToolDefinition; + use crate::types::{FinishReason, ResponseMetadata}; + use crate::usage::Usage; + + // Minimal inline mock that records which model was called. + struct TrackingModel { + id: &'static str, + capabilities: CapabilitySet, + called: Arc>>, + } + + impl TrackingModel { + fn new( + id: &'static str, + capabilities: CapabilitySet, + log: Arc>>, + ) -> Self { + Self { + id, + capabilities, + called: log, + } + } + } + + #[async_trait] + impl LanguageModel for TrackingModel { + fn model_id(&self) -> &str { + self.id + } + + fn provider_id(&self) -> &str { + "mock" + } + + fn capabilities(&self) -> &CapabilitySet { + &self.capabilities + } + + async fn generate( + &self, + _prompt: Prompt, + _options: GenerateOptions, + ) -> AiResult { + self.called.lock().unwrap().push(self.id.to_string()); + Ok(GenerateResult { + text: Some(self.id.to_string()), + tool_calls: Vec::new(), + finish_reason: FinishReason::Stop, + usage: Usage::default(), + metadata: ResponseMetadata::default(), + }) + } + + async fn stream(&self, _prompt: Prompt, _options: GenerateOptions) -> AiResult { + unimplemented!() + } + } + + fn local_caps_only() -> CapabilitySet { + CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + } + + fn full_caps() -> CapabilitySet { + CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::ToolCalling) + .with(Capability::StructuredOutput) + } + + #[tokio::test] + async fn local_first_selects_local_for_plain_text() { + let log = Arc::new(Mutex::new(Vec::new())); + let local = TrackingModel::new("local", local_caps_only(), log.clone()); + let cloud = TrackingModel::new("cloud", full_caps(), log.clone()); + + let router = Router::local_first(Box::new(local), Box::new(cloud)); + let result = router + .generate(Prompt::from("hello"), GenerateOptions::default()) + .await + .unwrap(); + + assert_eq!( + result.text.as_deref(), + Some("local"), + "plain text should route to local" + ); + assert_eq!(*log.lock().unwrap(), vec!["local"]); + } + + #[tokio::test] + async fn local_first_falls_back_to_cloud_when_tools_needed_and_local_lacks_tool_calling() { + let log = Arc::new(Mutex::new(Vec::new())); + let local = TrackingModel::new("local", local_caps_only(), log.clone()); + let cloud = TrackingModel::new("cloud", full_caps(), log.clone()); + + let router = Router::local_first(Box::new(local), Box::new(cloud)); + let options = GenerateOptions::default().with_tools(vec![ToolDefinition { + name: "search".into(), + description: "web search".into(), + parameters: serde_json::json!({}), + }]); + let result = router + .generate(Prompt::from("search for rust"), options) + .await + .unwrap(); + + assert_eq!( + result.text.as_deref(), + Some("cloud"), + "tool call should route to cloud when local lacks ToolCalling" + ); + assert_eq!(*log.lock().unwrap(), vec!["cloud"]); + } + + #[tokio::test] + async fn local_first_selects_local_when_local_supports_tools() { + let log = Arc::new(Mutex::new(Vec::new())); + let local = TrackingModel::new("local", full_caps(), log.clone()); + let cloud = TrackingModel::new("cloud", full_caps(), log.clone()); + + let router = Router::local_first(Box::new(local), Box::new(cloud)); + let options = GenerateOptions::default().with_tools(vec![ToolDefinition { + name: "search".into(), + description: "web search".into(), + parameters: serde_json::json!({}), + }]); + let result = router + .generate(Prompt::from("use tool"), options) + .await + .unwrap(); + + assert_eq!( + result.text.as_deref(), + Some("local"), + "should prefer local when it supports the required capabilities" + ); + assert_eq!(*log.lock().unwrap(), vec!["local"]); + } + + #[tokio::test] + async fn local_first_falls_back_when_structured_output_needed_and_local_lacks_it() { + let log = Arc::new(Mutex::new(Vec::new())); + let local = TrackingModel::new("local", local_caps_only(), log.clone()); + let cloud = TrackingModel::new("cloud", full_caps(), log.clone()); + + let router = Router::local_first(Box::new(local), Box::new(cloud)); + let options = GenerateOptions::default().with_output_schema(OutputSchema::from_value( + serde_json::json!({"type": "object"}), + )); + let result = router + .generate(Prompt::from("give me json"), options) + .await + .unwrap(); + + assert_eq!( + result.text.as_deref(), + Some("cloud"), + "structured output should route to cloud when local lacks StructuredOutput" + ); + assert_eq!(*log.lock().unwrap(), vec!["cloud"]); + } +} diff --git a/crates/rusty_ai/src/schema.rs b/crates/rusty_ai/src/schema.rs new file mode 100644 index 0000000..92a9582 --- /dev/null +++ b/crates/rusty_ai/src/schema.rs @@ -0,0 +1,34 @@ +use serde::{Deserialize, Serialize}; + +/// A JSON schema describing the expected output format. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct OutputSchema { + /// The schema name (used for labelling in prompts). + pub name: String, + /// The JSON Schema object. + pub schema: serde_json::Value, +} + +impl OutputSchema { + /// Derive an `OutputSchema` from a type that implements `JsonSchema`. + pub fn from_type() -> Self { + let schema = schemars::schema_for!(T); + Self { + name: T::schema_name().to_string(), + schema: serde_json::to_value(schema).unwrap_or_default(), + } + } + + /// Create an output schema from a raw JSON value. + pub fn from_value(value: serde_json::Value) -> Self { + Self { + name: String::new(), + schema: value, + } + } + + /// Return a reference to the underlying JSON Schema value. + pub fn as_value(&self) -> &serde_json::Value { + &self.schema + } +} diff --git a/crates/rusty_ai/src/stream.rs b/crates/rusty_ai/src/stream.rs new file mode 100644 index 0000000..9f20c13 --- /dev/null +++ b/crates/rusty_ai/src/stream.rs @@ -0,0 +1,178 @@ +use std::pin::Pin; + +use futures::stream::{self, Stream, StreamExt}; + +use crate::error::{AiError, AiResult}; +use crate::structured::GenerateResult; +use crate::tool::ToolCallRequest; +use crate::types::{FinishReason, ResponseMetadata}; +use crate::usage::Usage; + +/// Events emitted by a streaming response. +#[derive(Debug, Clone)] +pub enum StreamEvent { + MessageStart { + message_id: String, + }, + TextDelta { + delta: String, + }, + ToolCallStart { + call_id: String, + tool_name: String, + }, + ToolCallDelta { + call_id: String, + delta: String, + }, + ToolCallEnd { + call_id: String, + arguments: serde_json::Value, + }, + ToolResult { + call_id: String, + content: String, + is_error: bool, + }, + ObjectDelta { + delta: serde_json::Value, + }, + /// Emitted when an extended-thinking / reasoning model produces + /// intermediate "thinking" tokens (Anthropic, Gemini 2.5+, Ollama think). + ThinkingDelta { + delta: String, + }, + UsageDelta { + usage: Usage, + }, + Warning { + message: String, + }, + MessageEnd { + finish_reason: FinishReason, + usage: Option, + }, + Error { + error: String, + }, + /// Emitted once when a local runtime falls back to non-native streaming. + SyntheticStreamingNotice, +} + +/// A boxed, pinned, sendable stream of `StreamEvent` results. +pub type AiStream = Pin> + Send>>; + +/// Collects a full `AiStream` into a single `GenerateResult`. +pub struct StreamCollector; + +impl StreamCollector { + pub async fn collect(mut stream: AiStream) -> AiResult { + let mut text = String::new(); + let mut tool_calls: Vec = Vec::new(); + let mut finish_reason = FinishReason::Unknown; + let mut usage = Usage::default(); + + // Track in-progress tool calls by call_id + let mut pending_tool_calls: std::collections::HashMap = + std::collections::HashMap::new(); + + while let Some(event) = stream.next().await { + let event = event?; + match event { + StreamEvent::TextDelta { delta } => { + text.push_str(&delta); + } + StreamEvent::ToolCallStart { call_id, tool_name } => { + pending_tool_calls.insert(call_id, (tool_name, String::new())); + } + StreamEvent::ToolCallDelta { call_id, delta } => { + if let Some((_name, args)) = pending_tool_calls.get_mut(&call_id) { + args.push_str(&delta); + } + } + StreamEvent::ToolCallEnd { call_id, arguments } => { + if let Some((name, _partial_args)) = pending_tool_calls.remove(&call_id) { + tool_calls.push(ToolCallRequest { + id: call_id, + name, + arguments, + }); + } + } + StreamEvent::UsageDelta { usage: u } => { + usage.merge(&u); + } + StreamEvent::MessageEnd { + finish_reason: fr, + usage: u, + } => { + finish_reason = fr; + if let Some(u) = u { + usage.merge(&u); + } + } + StreamEvent::Error { error } => { + return Err(AiError::StreamError { message: error }); + } + _ => {} + } + } + + Ok(GenerateResult { + text: if text.is_empty() { None } else { Some(text) }, + tool_calls, + finish_reason, + usage, + metadata: ResponseMetadata::default(), + }) + } +} + +/// Produces a synthetic stream from a complete text response. +/// +/// Useful for providers that don't support real streaming. +pub struct SyntheticStreamer; + +impl SyntheticStreamer { + pub fn stream(text: String, chunk_size: usize) -> AiStream { + let chunk_size = chunk_size.max(1); + let chunks: Vec> = { + let mut v = Vec::new(); + v.push(Ok(StreamEvent::MessageStart { + message_id: uuid::Uuid::new_v4().to_string(), + })); + // Notify callers that this is simulated streaming. + v.push(Ok(StreamEvent::SyntheticStreamingNotice)); + + let mut pos = 0; + while pos < text.len() { + let end = (pos + chunk_size).min(text.len()); + // Ensure we don't split in the middle of a multi-byte char. + let end = if end < text.len() { + let mut e = end; + while !text.is_char_boundary(e) && e > pos { + e -= 1; + } + e + } else { + end + }; + if end == pos { + break; + } + v.push(Ok(StreamEvent::TextDelta { + delta: text[pos..end].to_owned(), + })); + pos = end; + } + + v.push(Ok(StreamEvent::MessageEnd { + finish_reason: FinishReason::Stop, + usage: None, + })); + v + }; + + Box::pin(stream::iter(chunks)) + } +} diff --git a/crates/rusty_ai/src/structured.rs b/crates/rusty_ai/src/structured.rs new file mode 100644 index 0000000..a7135aa --- /dev/null +++ b/crates/rusty_ai/src/structured.rs @@ -0,0 +1,70 @@ +use serde::{Deserialize, Serialize}; + +use crate::tool::ToolCallRequest; +use crate::types::{FinishReason, ResponseMetadata}; +use crate::usage::Usage; + +/// The result of a non-streaming generate call. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct GenerateResult { + /// Generated text (may be `None` when the model only produced tool calls). + pub text: Option, + /// Tool calls requested by the model. + #[serde(default)] + pub tool_calls: Vec, + /// Reason the model stopped generating. + pub finish_reason: FinishReason, + /// Token usage information. + pub usage: Usage, + /// Provider-specific response metadata. + #[serde(default)] + pub metadata: ResponseMetadata, +} + +/// The result of a structured object generation. +#[derive(Debug, Clone)] +pub struct ObjectResult { + /// The parsed object. + pub object: T, + /// The raw text that was parsed. + pub text: String, + /// Token usage information. + pub usage: Usage, + /// Provider-specific response metadata. + pub metadata: ResponseMetadata, +} + +/// The result of an embedding call. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct EmbeddingResult { + /// The embedding vectors. + pub embeddings: Vec>, + /// Token usage information. + pub usage: Usage, +} + +/// Result of a speech-to-text transcription. +#[derive(Debug, Clone)] +pub struct TranscriptionResult { + pub text: String, + pub language: Option, + pub duration_seconds: Option, + pub usage: Usage, +} + +/// Result of text-to-speech synthesis. +#[derive(Debug, Clone)] +pub struct AudioResult { + pub audio: Vec, + pub mime_type: String, + pub usage: Usage, +} + +/// Options for text-to-speech. +#[derive(Debug, Clone, Default)] +pub struct TtsOptions { + /// Speech speed multiplier (e.g. 1.0 = normal). + pub speed: Option, + /// Output audio format (e.g. "mp3", "opus", "aac", "flac", "wav", "pcm"). + pub response_format: Option, +} diff --git a/crates/rusty_ai/src/tool.rs b/crates/rusty_ai/src/tool.rs new file mode 100644 index 0000000..f94eb5d --- /dev/null +++ b/crates/rusty_ai/src/tool.rs @@ -0,0 +1,112 @@ +use std::collections::HashMap; + +use async_trait::async_trait; +use serde::{Deserialize, Serialize}; + +use crate::error::AiError; +use crate::error::AiResult; + +/// JSON-Schema based definition of a tool that a model can call. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolDefinition { + pub name: String, + pub description: String, + /// JSON Schema describing the parameters object. + pub parameters: serde_json::Value, +} + +/// A request from the model to invoke a tool. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallRequest { + pub id: String, + pub name: String, + pub arguments: serde_json::Value, +} + +/// The result of executing a tool call. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ToolCallResult { + pub call_id: String, + pub content: String, + pub is_error: bool, +} + +/// How the model should choose which tool to call. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum ToolChoice { + Auto, + None, + Required, + Specific(String), +} + +/// A tool that can be called by a model. +#[async_trait] +pub trait Tool: Send + Sync { + /// Return the JSON-schema definition for this tool. + fn definition(&self) -> ToolDefinition; + /// Execute the tool with the given JSON arguments. + async fn execute(&self, args: serde_json::Value) -> Result; +} + +/// A named collection of tools. +pub struct ToolSet { + tools: HashMap>, +} + +impl ToolSet { + /// Create an empty tool set. + pub fn new() -> Self { + Self { + tools: HashMap::new(), + } + } + + /// Register a tool. Returns `&mut Self` for chaining. + pub fn add(&mut self, tool: impl Tool + 'static) -> &mut Self { + let def = tool.definition(); + self.tools.insert(def.name.clone(), Box::new(tool)); + self + } + + /// Look up a tool by name. + pub fn get(&self, name: &str) -> Option<&dyn Tool> { + self.tools.get(name).map(|b| b.as_ref()) + } + + /// Return definitions for every registered tool. + pub fn definitions(&self) -> Vec { + self.tools.values().map(|t| t.definition()).collect() + } + + /// Execute a tool call request and return the result. + pub async fn execute(&self, call: &ToolCallRequest) -> AiResult { + let tool = self + .tools + .get(&call.name) + .ok_or_else(|| AiError::ToolError { + tool_name: call.name.clone(), + message: "Tool not found".into(), + })?; + + match tool.execute(call.arguments.clone()).await { + Ok(content) => Ok(ToolCallResult { + call_id: call.id.clone(), + content, + is_error: false, + }), + Err(e) => Ok(ToolCallResult { + call_id: call.id.clone(), + content: e.to_string(), + is_error: true, + }), + } + } +} + +impl Default for ToolSet { + fn default() -> Self { + Self::new() + } +} diff --git a/crates/rusty_ai/src/types.rs b/crates/rusty_ai/src/types.rs new file mode 100644 index 0000000..feead70 --- /dev/null +++ b/crates/rusty_ai/src/types.rs @@ -0,0 +1,125 @@ +use std::collections::HashMap; + +use chrono::{DateTime, Utc}; +use serde::{Deserialize, Serialize}; +use uuid::Uuid; + +use crate::capability::{Capability, CapabilitySet}; + +/// Reason the model stopped generating. +#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum FinishReason { + Stop, + Length, + ToolCall, + ContentFilter, + Error, + Unknown, +} + +/// Metadata describing a model exposed by a provider. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ModelInfo { + pub id: String, + pub provider: String, + pub display_name: String, + pub capabilities: CapabilitySet, +} + +/// Metadata attached to an outgoing request. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct RequestMetadata { + pub request_id: Uuid, + pub timestamp: DateTime, + #[serde(default)] + pub extra: HashMap, +} + +impl Default for RequestMetadata { + fn default() -> Self { + Self { + request_id: Uuid::new_v4(), + timestamp: Utc::now(), + extra: HashMap::new(), + } + } +} + +/// Metadata returned alongside a provider response. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct ResponseMetadata { + pub request_id: Uuid, + pub provider: String, + pub model: String, + pub latency_ms: Option, + #[serde(default)] + pub extra: HashMap, +} + +impl ResponseMetadata { + /// Create a new `ResponseMetadata` with the given provider and model. + pub fn new(provider: impl Into, model: impl Into) -> Self { + Self { + request_id: Uuid::new_v4(), + provider: provider.into(), + model: model.into(), + latency_ms: None, + extra: HashMap::new(), + } + } +} + +impl Default for ResponseMetadata { + fn default() -> Self { + Self { + request_id: Uuid::new_v4(), + provider: String::new(), + model: String::new(), + latency_ms: None, + extra: HashMap::new(), + } + } +} + +/// Caches model information fetched from remote APIs. +/// +/// Model IDs change frequently. Use `fetch_models()` on the provider +/// to populate the registry, then look up models by ID. +#[derive(Debug, Clone, Default)] +pub struct ModelRegistry { + models: Vec, +} + +impl ModelRegistry { + pub fn new() -> Self { + Self { models: Vec::new() } + } + + /// Merge freshly fetched models into the registry. + pub fn update(&mut self, models: Vec) { + for model in models { + if !self.models.iter().any(|m| m.id == model.id) { + self.models.push(model); + } + } + } + + /// Look up a model by ID. + pub fn get(&self, id: &str) -> Option<&ModelInfo> { + self.models.iter().find(|m| m.id == id) + } + + /// All known models. + pub fn all(&self) -> &[ModelInfo] { + &self.models + } + + /// Filter models by a required capability. + pub fn with_capability(&self, cap: &Capability) -> Vec<&ModelInfo> { + self.models + .iter() + .filter(|m| m.capabilities.has(cap)) + .collect() + } +} diff --git a/crates/rusty_ai/src/usage.rs b/crates/rusty_ai/src/usage.rs new file mode 100644 index 0000000..2bb552c --- /dev/null +++ b/crates/rusty_ai/src/usage.rs @@ -0,0 +1,27 @@ +use serde::{Deserialize, Serialize}; + +/// Token usage information returned by providers. +#[derive(Debug, Clone, Default, Serialize, Deserialize)] +pub struct Usage { + pub prompt_tokens: Option, + pub completion_tokens: Option, + pub total_tokens: Option, +} + +impl Usage { + /// Merge another `Usage` into this one by adding token counts. + pub fn merge(&mut self, other: &Usage) { + self.prompt_tokens = add_opt(self.prompt_tokens, other.prompt_tokens); + self.completion_tokens = add_opt(self.completion_tokens, other.completion_tokens); + self.total_tokens = add_opt(self.total_tokens, other.total_tokens); + } +} + +fn add_opt(a: Option, b: Option) -> Option { + match (a, b) { + (Some(x), Some(y)) => Some(x + y), + (Some(x), None) => Some(x), + (None, Some(y)) => Some(y), + (None, None) => None, + } +} diff --git a/crates/rusty_browser/Cargo.toml b/crates/rusty_browser/Cargo.toml new file mode 100644 index 0000000..9c8431e --- /dev/null +++ b/crates/rusty_browser/Cargo.toml @@ -0,0 +1,25 @@ +[package] +name = "rusty_browser" +version.workspace = true +edition.workspace = true +license.workspace = true +description = "Browser-based local AI model detection for WASM targets (Chrome/Edge built-in AI)" + +[dependencies] +rusty_ai = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tracing = { workspace = true } +uuid = { workspace = true } +chrono = { workspace = true } + +# WASM dependencies would be feature-gated +# [target.'cfg(target_arch = "wasm32")'.dependencies] +# wasm-bindgen = "0.2" +# wasm-bindgen-futures = "0.4" +# web-sys = { version = "0.3", features = ["Window", "Navigator"] } +# js-sys = "0.3" diff --git a/crates/rusty_browser/src/bridge.rs b/crates/rusty_browser/src/bridge.rs new file mode 100644 index 0000000..0862185 --- /dev/null +++ b/crates/rusty_browser/src/bridge.rs @@ -0,0 +1,51 @@ +use async_trait::async_trait; + +use crate::capabilities::{BrowserAiCapabilities, BrowserAiOptions}; + +/// Trait for browser AI bridge. +/// +/// On WASM targets, implement this via `wasm-bindgen` to call the browser's +/// built-in AI APIs (Chrome Prompt API, Edge AI, etc.). +/// +/// On non-WASM targets, a no-op implementation can be used for testing. +/// +/// # Browser compatibility notes +/// +/// - `window.ai` is deprecated since Chrome 138; use the `LanguageModel` global directly. +/// - Edge AI is backed by Phi Silica on Copilot+ PCs. +/// - Not available in Chromium/CEF builds, only official Chrome/Edge. +#[async_trait] +pub trait BrowserAiBridge: Send + Sync { + /// Detect if the browser has built-in AI capabilities. + async fn detect(&self) -> BrowserAiCapabilities; + + /// Generate text using the browser's AI. + async fn generate(&self, prompt: &str, options: &BrowserAiOptions) -> Result; + + /// Stream text using the browser's AI (if supported). + /// Returns chunks of text. + async fn stream(&self, prompt: &str, options: &BrowserAiOptions) + -> Result, String>; +} + +/// A no-op bridge for non-WASM targets, useful for testing. +pub struct NoOpBrowserBridge; + +#[async_trait] +impl BrowserAiBridge for NoOpBrowserBridge { + async fn detect(&self) -> BrowserAiCapabilities { + BrowserAiCapabilities::default() + } + + async fn generate(&self, _prompt: &str, _options: &BrowserAiOptions) -> Result { + Err("Browser AI not available on this target".into()) + } + + async fn stream( + &self, + _prompt: &str, + _options: &BrowserAiOptions, + ) -> Result, String> { + Err("Browser AI not available on this target".into()) + } +} diff --git a/crates/rusty_browser/src/capabilities.rs b/crates/rusty_browser/src/capabilities.rs new file mode 100644 index 0000000..33e4dfa --- /dev/null +++ b/crates/rusty_browser/src/capabilities.rs @@ -0,0 +1,62 @@ +/// Detected browser AI capabilities. +#[derive(Debug, Clone)] +pub struct BrowserAiCapabilities { + /// Whether any browser AI API is available. + pub available: bool, + /// Detected browser type. + pub browser: BrowserType, + /// Whether the browser AI supports streaming. + pub supports_streaming: bool, + /// Whether the browser AI supports system prompts. + pub supports_system_prompt: bool, + /// Maximum token limit, if known. + pub max_tokens: Option, + /// Whether the browser uses Gemini Nano (Chrome) or Phi Silica (Edge). + pub backing_model: BackingModel, + /// Whether structured output via responseConstraint is supported. + pub supports_response_constraint: bool, +} + +impl Default for BrowserAiCapabilities { + fn default() -> Self { + Self { + available: false, + browser: BrowserType::Unknown, + supports_streaming: false, + supports_system_prompt: false, + max_tokens: None, + backing_model: BackingModel::Unknown, + supports_response_constraint: false, + } + } +} + +/// Detected browser type. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum BrowserType { + Chrome, + Edge, + Other(String), + Unknown, +} + +/// The on-device model backing the browser AI. +#[derive(Debug, Clone, PartialEq, Eq, Default)] +pub enum BackingModel { + /// Gemini Nano (Chrome 138+). + GeminiNano, + /// Phi Silica (Microsoft Edge Copilot+ PCs). + PhiSilica, + #[default] + Unknown, +} + +/// Options for browser AI generation. +#[derive(Debug, Clone, Default)] +pub struct BrowserAiOptions { + pub system_prompt: Option, + pub temperature: Option, + pub top_k: Option, + /// Constrained JSON schema output (Chrome Prompt API responseConstraint). + pub response_constraint: Option, +} diff --git a/crates/rusty_browser/src/lib.rs b/crates/rusty_browser/src/lib.rs new file mode 100644 index 0000000..b003673 --- /dev/null +++ b/crates/rusty_browser/src/lib.rs @@ -0,0 +1,21 @@ +//! Browser-based local AI model detection for WASM targets. +//! +//! This crate detects and bridges to browser-based AI APIs: +//! - Chrome's built-in AI (Prompt API / `window.ai`) +//! - Edge's built-in AI +//! +//! Compatible with Rust WASM frameworks: Dioxus, Leptos, Yew. +//! +//! On WASM targets, the bridge is implemented via `wasm-bindgen` to call +//! the browser's AI APIs. On non-WASM targets, a no-op implementation is +//! provided for testing. + +mod bridge; +mod capabilities; +mod model; +mod provider; + +pub use bridge::*; +pub use capabilities::*; +pub use model::*; +pub use provider::*; diff --git a/crates/rusty_browser/src/model.rs b/crates/rusty_browser/src/model.rs new file mode 100644 index 0000000..19c4a39 --- /dev/null +++ b/crates/rusty_browser/src/model.rs @@ -0,0 +1,105 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use rusty_ai::{ + AiError, AiResult, AiStream, Capability, CapabilitySet, ContentPart, FinishReason, + GenerateOptions, GenerateResult, LanguageModel, Prompt, ResponseMetadata, SyntheticStreamer, + Usage, +}; + +use crate::bridge::BrowserAiBridge; +use crate::capabilities::BrowserAiOptions; + +/// A [`LanguageModel`] backed by the browser's built-in AI. +pub struct BrowserAiModel { + bridge: Arc, + capabilities: CapabilitySet, +} + +impl BrowserAiModel { + pub(crate) fn new(bridge: Arc) -> Self { + let capabilities = CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::LocalExecution) + .with(Capability::PlatformNative); + Self { + bridge, + capabilities, + } + } + + fn prompt_to_text(prompt: &Prompt) -> String { + match prompt { + Prompt::Text(t) => t.clone(), + Prompt::Messages(msgs) => msgs + .iter() + .flat_map(|m| { + m.content.iter().filter_map(|c| match c { + ContentPart::Text { text } => Some(text.clone()), + _ => None, + }) + }) + .collect::>() + .join("\n"), + } + } +} + +#[async_trait] +impl LanguageModel for BrowserAiModel { + fn model_id(&self) -> &str { + "browser-ai" + } + + fn provider_id(&self) -> &str { + "browser" + } + + fn capabilities(&self) -> &CapabilitySet { + &self.capabilities + } + + async fn generate(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + let caps = self.bridge.detect().await; + if !caps.available { + return Err(AiError::PlatformUnavailable { + platform: "Browser AI (no built-in AI detected)".into(), + }); + } + + let browser_options = BrowserAiOptions { + temperature: options.temperature, + top_k: options.top_k, + ..Default::default() + }; + + let text = Self::prompt_to_text(&prompt); + let response = self + .bridge + .generate(&text, &browser_options) + .await + .map_err(|e| AiError::BridgeError { + bridge: "browser".into(), + message: e, + })?; + + Ok(GenerateResult { + text: Some(response), + tool_calls: vec![], + finish_reason: FinishReason::Stop, + usage: Usage::default(), + metadata: ResponseMetadata { + provider: "browser".into(), + model: "browser-ai".into(), + ..Default::default() + }, + }) + } + + async fn stream(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + let result = self.generate(prompt, options).await?; + let text = result.text.unwrap_or_default(); + Ok(SyntheticStreamer::stream(text, 20)) + } +} diff --git a/crates/rusty_browser/src/provider.rs b/crates/rusty_browser/src/provider.rs new file mode 100644 index 0000000..1ae5d3e --- /dev/null +++ b/crates/rusty_browser/src/provider.rs @@ -0,0 +1,67 @@ +use std::sync::Arc; + +use rusty_ai::{ + AiError, AiResult, Capability, CapabilitySet, EmbeddingModel, LanguageModel, ModelInfo, + Provider, +}; + +use crate::bridge::BrowserAiBridge; +use crate::capabilities::BrowserAiCapabilities; +use crate::model::BrowserAiModel; + +/// Provider for browser-based built-in AI models. +pub struct BrowserAiProvider { + bridge: Arc, +} + +impl BrowserAiProvider { + pub fn new(bridge: impl BrowserAiBridge + 'static) -> Self { + Self { + bridge: Arc::new(bridge), + } + } + + /// Get the browser AI model. + pub fn model(&self) -> BrowserAiModel { + BrowserAiModel::new(self.bridge.clone()) + } + + /// Detect browser AI capabilities. + pub async fn detect(&self) -> BrowserAiCapabilities { + self.bridge.detect().await + } +} + +impl Provider for BrowserAiProvider { + fn id(&self) -> &str { + "browser" + } + + fn name(&self) -> &str { + "Browser AI" + } + + fn language_model(&self, _model_id: &str) -> AiResult> { + Ok(Box::new(self.model())) + } + + fn embedding_model(&self, model_id: &str) -> AiResult> { + Err(AiError::UnsupportedCapability { + capability: "embeddings".into(), + provider: format!("browser/{model_id}"), + }) + } + + fn available_models(&self) -> Vec { + vec![ModelInfo { + id: "browser-ai".into(), + provider: "browser".into(), + display_name: "Browser Built-in AI".into(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::LocalExecution) + .with(Capability::PlatformNative), + }] + } +} diff --git a/crates/rusty_chatgpt/Cargo.toml b/crates/rusty_chatgpt/Cargo.toml new file mode 100644 index 0000000..d20778f --- /dev/null +++ b/crates/rusty_chatgpt/Cargo.toml @@ -0,0 +1,15 @@ +[package] +name = "rusty_chatgpt" +version.workspace = true +edition.workspace = true +license.workspace = true +description = "OpenAI ChatGPT provider for the Rusty AI SDK" + +[dependencies] +rusty_ai = { workspace = true } +rusty_openai_compatible = { workspace = true } +async-trait = { workspace = true } +secrecy = { workspace = true } +reqwest = { workspace = true } +serde = { workspace = true } +tracing = { workspace = true } diff --git a/crates/rusty_chatgpt/src/lib.rs b/crates/rusty_chatgpt/src/lib.rs new file mode 100644 index 0000000..5b00448 --- /dev/null +++ b/crates/rusty_chatgpt/src/lib.rs @@ -0,0 +1,321 @@ +//! OpenAI ChatGPT provider for the Rusty AI SDK. +//! +//! This is a thin wrapper around [`rusty_openai_compatible`] that pre-configures +//! the adapter for the official OpenAI API with well-known ChatGPT models. +//! +//! # Example +//! +//! ```rust,no_run +//! use rusty_chatgpt::ChatGptProvider; +//! use rusty_ai::Provider; +//! +//! let provider = ChatGptProvider::new("sk-..."); +//! let model = provider.language_model("gpt-4o").unwrap(); +//! ``` + +use rusty_ai::capability::{Capability, CapabilitySet}; +use rusty_ai::error::AiResult; +use rusty_ai::model::{EmbeddingModel, LanguageModel}; +use rusty_ai::provider::Provider; +use rusty_ai::types::ModelInfo; +use rusty_openai_compatible::{ + OpenAiCompatibleConfig, OpenAiCompatibleModel, OpenAiCompatibleProvider, +}; + +// ── Latest model aliases ── + +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const GPT_4O_LATEST: &str = "gpt-4o"; +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const GPT_4O_MINI_LATEST: &str = "gpt-4o-mini"; +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const O3_MINI_LATEST: &str = "o3-mini"; +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const GPT_5_4_LATEST: &str = "gpt-5.4"; +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const GPT_5_4_MINI_LATEST: &str = "gpt-5.4-mini"; +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const GPT_5_4_NANO_LATEST: &str = "gpt-5.4-nano"; + +// ── Audio / voice model identifiers ── + +/// OpenAI Whisper speech-to-text. +pub const WHISPER: &str = "whisper-1"; +/// OpenAI TTS. +pub const TTS: &str = "tts-1"; +/// OpenAI TTS HD. +pub const TTS_HD: &str = "tts-1-hd"; +/// OpenAI GPT-4o Realtime (voice). +pub const GPT_4O_REALTIME: &str = "gpt-4o-realtime-preview"; +/// OpenAI GPT-4o Audio (audio modality in chat completions). +pub const GPT_4O_AUDIO: &str = "gpt-4o-audio-preview"; +/// OpenAI GPT-4o Mini Realtime. +pub const GPT_4O_MINI_REALTIME: &str = "gpt-4o-mini-realtime-preview"; + +/// A provider pre-configured for the official OpenAI ChatGPT API. +pub struct ChatGptProvider { + inner: OpenAiCompatibleProvider, + config: OpenAiCompatibleConfig, +} + +impl ChatGptProvider { + /// Create a new ChatGPT provider with the given API key. + pub fn new(api_key: impl Into) -> Self { + let config = OpenAiCompatibleConfig::openai(api_key); + let inner = OpenAiCompatibleProvider::new(config.clone(), "chatgpt", "ChatGPT") + .with_model_info(ModelInfo { + id: GPT_4O_LATEST.into(), + provider: "chatgpt".into(), + display_name: "GPT-4o".into(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::ImageInput) + .with(Capability::Streaming) + .with(Capability::ToolCalling) + .with(Capability::StructuredOutput), + }) + .with_model_info(ModelInfo { + id: GPT_4O_MINI_LATEST.into(), + provider: "chatgpt".into(), + display_name: "GPT-4o Mini".into(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::ImageInput) + .with(Capability::Streaming) + .with(Capability::ToolCalling) + .with(Capability::StructuredOutput), + }) + .with_model_info(ModelInfo { + id: O3_MINI_LATEST.into(), + provider: "chatgpt".into(), + display_name: "o3-mini".into(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::Streaming) + .with(Capability::ToolCalling), + }) + .with_model_info(ModelInfo { + id: GPT_5_4_LATEST.into(), + provider: "chatgpt".into(), + display_name: "GPT-5.4".into(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::ImageInput) + .with(Capability::Streaming) + .with(Capability::ToolCalling) + .with(Capability::StructuredOutput) + .with(Capability::ExtendedThinking), + }) + .with_model_info(ModelInfo { + id: GPT_5_4_MINI_LATEST.into(), + provider: "chatgpt".into(), + display_name: "GPT-5.4 Mini".into(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::ImageInput) + .with(Capability::Streaming) + .with(Capability::ToolCalling) + .with(Capability::StructuredOutput), + }) + .with_model_info(ModelInfo { + id: GPT_5_4_NANO_LATEST.into(), + provider: "chatgpt".into(), + display_name: "GPT-5.4 Nano".into(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::Streaming), + }) + .with_model_info(ModelInfo { + id: WHISPER.into(), + provider: "chatgpt".into(), + display_name: "Whisper".into(), + capabilities: CapabilitySet::new() + .with(Capability::AudioInput) + .with(Capability::TextOutput), + }) + .with_model_info(ModelInfo { + id: TTS.into(), + provider: "chatgpt".into(), + display_name: "TTS".into(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::AudioOutput), + }) + .with_model_info(ModelInfo { + id: TTS_HD.into(), + provider: "chatgpt".into(), + display_name: "TTS HD".into(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::AudioOutput), + }) + .with_model_info(ModelInfo { + id: GPT_4O_REALTIME.into(), + provider: "chatgpt".into(), + display_name: "GPT-4o Realtime".into(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::AudioInput) + .with(Capability::AudioOutput) + .with(Capability::Streaming), + }) + .with_model_info(ModelInfo { + id: GPT_4O_AUDIO.into(), + provider: "chatgpt".into(), + display_name: "GPT-4o Audio".into(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::AudioInput) + .with(Capability::AudioOutput) + .with(Capability::Streaming), + }) + .with_model_info(ModelInfo { + id: GPT_4O_MINI_REALTIME.into(), + provider: "chatgpt".into(), + display_name: "GPT-4o Mini Realtime".into(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::AudioInput) + .with(Capability::AudioOutput) + .with(Capability::Streaming), + }); + Self { inner, config } + } + + /// Set an optional OpenAI organization ID. + pub fn with_org(mut self, org_id: impl Into) -> Self { + self.config = self.config.with_org(org_id); + // Rebuild inner provider with the updated config so that models + // created via the Provider trait pick up the org header. + let models: Vec = self.inner.models().to_vec(); + let mut new_inner = + OpenAiCompatibleProvider::new(self.config.clone(), "chatgpt", "ChatGPT"); + for info in models { + new_inner = new_inner.with_model_info(info); + } + self.inner = new_inner; + self + } + + /// Get a specific model by ID, looking up known capabilities. + /// Any valid OpenAI model ID is accepted. + pub fn model(&self, model_id: &str) -> OpenAiCompatibleModel { + let caps = self + .inner + .models() + .iter() + .find(|m| m.id == model_id) + .map(|m| m.capabilities.clone()) + .unwrap_or_else(|| { + CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::Streaming) + }); + OpenAiCompatibleModel::new(self.config.clone(), model_id, "chatgpt").with_capabilities(caps) + } + + pub fn gpt4o(&self) -> OpenAiCompatibleModel { + self.model(GPT_4O_LATEST) + } + + pub fn gpt4o_mini(&self) -> OpenAiCompatibleModel { + self.model(GPT_4O_MINI_LATEST) + } + + pub fn gpt54(&self) -> OpenAiCompatibleModel { + self.model(GPT_5_4_LATEST) + } + + pub fn gpt54_mini(&self) -> OpenAiCompatibleModel { + self.model(GPT_5_4_MINI_LATEST) + } + + pub fn gpt54_nano(&self) -> OpenAiCompatibleModel { + self.model(GPT_5_4_NANO_LATEST) + } + + /// Fetch the list of models from the OpenAI API. + pub async fn list_remote_models(&self) -> rusty_ai::AiResult> { + use secrecy::ExposeSecret; + let client = reqwest::Client::new(); + let resp = client + .get("https://api.openai.com/v1/models") + .header( + "Authorization", + format!("Bearer {}", self.config.api_key().expose_secret()), + ) + .send() + .await + .map_err(|e| rusty_ai::AiError::Transport { + message: e.to_string(), + source: Some(Box::new(e)), + })?; + + let status = resp.status(); + if !status.is_success() { + let status_code = status.as_u16(); + let body = match resp.text().await { + Ok(body) => body, + Err(e) => { + tracing::warn!(status = status_code, error = %e, "Failed to read ChatGPT error response body"); + format!("") + } + }; + return Err(rusty_ai::AiError::ProviderError { + provider: "chatgpt".into(), + status: Some(status_code), + message: body, + }); + } + + #[derive(serde::Deserialize)] + struct ListModelsResponse { + data: Vec, + } + #[derive(serde::Deserialize)] + struct ModelEntry { + id: String, + } + + let list: ListModelsResponse = resp + .json() + .await + .map_err(|e| rusty_ai::AiError::Serialization(e.to_string()))?; + + Ok(list.data.into_iter().map(|m| m.id).collect()) + } +} + +impl Provider for ChatGptProvider { + fn id(&self) -> &str { + self.inner.id() + } + + fn name(&self) -> &str { + self.inner.name() + } + + fn language_model(&self, model_id: &str) -> AiResult> { + Ok(Box::new(self.model(model_id))) + } + + fn embedding_model(&self, _model_id: &str) -> AiResult> { + Err(rusty_ai::AiError::ModelUnavailable { + model: "ChatGPT does not support embedding models".into(), + }) + } + + fn available_models(&self) -> Vec { + self.inner.models().to_vec() + } +} diff --git a/crates/rusty_claude/Cargo.toml b/crates/rusty_claude/Cargo.toml new file mode 100644 index 0000000..a4a91d9 --- /dev/null +++ b/crates/rusty_claude/Cargo.toml @@ -0,0 +1,21 @@ +[package] +name = "rusty_claude" +version.workspace = true +edition.workspace = true +license.workspace = true +description = "Anthropic Claude provider for the Rusty AI SDK" + +[dependencies] +rusty_ai = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true } +reqwest = { workspace = true } +reqwest-eventsource = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tracing = { workspace = true } +uuid = { workspace = true } +chrono = { workspace = true } +secrecy = { workspace = true } diff --git a/crates/rusty_claude/src/api_types.rs b/crates/rusty_claude/src/api_types.rs new file mode 100644 index 0000000..b214b42 --- /dev/null +++ b/crates/rusty_claude/src/api_types.rs @@ -0,0 +1,176 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Serialize)] +pub(crate) struct MessagesRequest { + pub model: String, + pub max_tokens: u32, + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub system: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_sequences: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + pub stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub thinking: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub output_config: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub(crate) struct ApiMessage { + pub role: String, + pub content: ApiContent, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +#[serde(untagged)] +pub(crate) enum ApiContent { + Text(String), + Blocks(Vec), +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +#[serde(tag = "type")] +pub(crate) enum ContentBlock { + #[serde(rename = "text")] + Text { text: String }, + #[serde(rename = "image")] + Image { source: ImageSource }, + #[serde(rename = "tool_use")] + ToolUse { + id: String, + name: String, + input: serde_json::Value, + }, + #[serde(rename = "tool_result")] + ToolResult { + tool_use_id: String, + content: String, + #[serde(skip_serializing_if = "Option::is_none")] + is_error: Option, + }, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +#[serde(tag = "type", rename_all = "snake_case")] +pub(crate) enum ImageSource { + Base64 { media_type: String, data: String }, + Url { url: String }, +} + +#[derive(Serialize, Debug)] +pub(crate) struct ApiTool { + pub name: String, + pub description: String, + pub input_schema: serde_json::Value, +} + +#[derive(Serialize, Debug)] +pub(crate) struct ApiToolChoice { + #[serde(rename = "type")] + pub choice_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, +} + +#[derive(Serialize, Debug)] +pub(crate) struct ApiThinkingConfig { + #[serde(rename = "type")] + pub thinking_type: String, +} + +#[derive(Serialize, Debug)] +pub(crate) struct ApiOutputConfig { + pub format: ApiOutputFormat, +} + +#[derive(Serialize, Debug)] +pub(crate) struct ApiOutputFormat { + #[serde(rename = "type")] + pub format_type: String, + pub schema: serde_json::Value, +} + +// --- Response types --- + +#[derive(Deserialize, Debug)] +pub(crate) struct MessagesResponse { + pub id: String, + pub content: Vec, + pub model: String, + pub stop_reason: Option, + pub usage: ApiUsage, +} + +#[derive(Deserialize, Debug, Clone)] +pub(crate) struct ApiUsage { + pub input_tokens: u64, + pub output_tokens: u64, +} + +// --- Streaming event types --- + +#[derive(Deserialize, Debug)] +#[serde(tag = "type")] +pub(crate) enum StreamEvent { + #[serde(rename = "message_start")] + MessageStart { message: MessagesResponse }, + #[serde(rename = "content_block_start")] + ContentBlockStart { + index: usize, + content_block: ContentBlock, + }, + #[serde(rename = "content_block_delta")] + ContentBlockDelta { index: usize, delta: DeltaBlock }, + #[serde(rename = "content_block_stop")] + ContentBlockStop { index: usize }, + #[serde(rename = "message_delta")] + MessageDelta { + delta: MessageDeltaBody, + usage: Option, + }, + #[serde(rename = "message_stop")] + MessageStop, + #[serde(rename = "ping")] + Ping, + #[serde(rename = "error")] + Error { error: ApiError }, +} + +#[derive(Deserialize, Debug)] +#[serde(tag = "type")] +pub(crate) enum DeltaBlock { + #[serde(rename = "text_delta")] + Text { text: String }, + #[serde(rename = "input_json_delta")] + InputJson { partial_json: String }, + #[serde(rename = "thinking_delta")] + Thinking { thinking: String }, + #[serde(rename = "signature_delta")] + Signature { + #[allow(dead_code)] + signature: String, + }, +} + +#[derive(Deserialize, Debug)] +pub(crate) struct MessageDeltaBody { + pub stop_reason: Option, +} + +#[derive(Deserialize, Debug)] +pub(crate) struct ApiError { + pub message: String, + #[serde(rename = "type")] + pub error_type: String, +} diff --git a/crates/rusty_claude/src/convert.rs b/crates/rusty_claude/src/convert.rs new file mode 100644 index 0000000..1b67e51 --- /dev/null +++ b/crates/rusty_claude/src/convert.rs @@ -0,0 +1,314 @@ +use rusty_ai::content::{ContentPart, ImageData}; +use rusty_ai::message::{Message, Role}; +use rusty_ai::model::{GenerateOptions, ThinkingConfig}; +use rusty_ai::prompt::Prompt; +use rusty_ai::structured::GenerateResult; +use rusty_ai::tool::{ToolCallRequest, ToolChoice, ToolDefinition}; +use rusty_ai::types::{FinishReason, ResponseMetadata}; +use rusty_ai::usage::Usage; + +use crate::api_types::{ + ApiContent, ApiMessage, ApiOutputConfig, ApiOutputFormat, ApiThinkingConfig, ApiTool, + ApiToolChoice, ContentBlock, ImageSource, MessagesRequest, MessagesResponse, +}; + +/// Holds the separated system prompt and non-system messages. +pub(crate) struct ConvertedPrompt { + pub system: Option, + pub messages: Vec, +} + +/// Convert a `Prompt` into Anthropic-compatible parts. +/// +/// Anthropic requires the system message as a separate top-level field rather +/// than as a message in the conversation array. +pub(crate) fn convert_prompt(prompt: Prompt) -> ConvertedPrompt { + let messages = prompt.into_messages(); + + let mut system_parts: Vec = Vec::new(); + let mut api_messages: Vec = Vec::new(); + + for msg in messages { + match msg.role { + Role::System => { + for part in &msg.content { + if let ContentPart::Text { text } = part { + system_parts.push(text.clone()); + } + } + } + Role::User => { + let content = convert_content_parts(&msg.content); + api_messages.push(ApiMessage { + role: "user".to_string(), + content, + }); + } + Role::Assistant => { + let content = convert_assistant_content(&msg); + api_messages.push(ApiMessage { + role: "assistant".to_string(), + content, + }); + } + Role::Tool => { + // Tool results in Anthropic are sent as a user message with + // tool_result content blocks. + let blocks = convert_tool_result_content(&msg.content); + api_messages.push(ApiMessage { + role: "user".to_string(), + content: ApiContent::Blocks(blocks), + }); + } + } + } + + let system = if system_parts.is_empty() { + None + } else { + Some(system_parts.join("\n")) + }; + + ConvertedPrompt { + system, + messages: api_messages, + } +} + +/// Convert a list of `ContentPart` values into an `ApiContent`. +fn convert_content_parts(parts: &[ContentPart]) -> ApiContent { + // Fast path: single text part -> use the simple string variant. + if parts.len() == 1 { + if let ContentPart::Text { text } = &parts[0] { + return ApiContent::Text(text.clone()); + } + } + + let blocks: Vec = parts + .iter() + .filter_map(|part| match part { + ContentPart::Text { text } => Some(ContentBlock::Text { text: text.clone() }), + ContentPart::Image { data } => Some(convert_image(data)), + _ => None, + }) + .collect(); + + ApiContent::Blocks(blocks) +} + +/// Convert assistant message content, which may include tool calls. +fn convert_assistant_content(msg: &Message) -> ApiContent { + let blocks: Vec = msg + .content + .iter() + .filter_map(|part| match part { + ContentPart::Text { text } => Some(ContentBlock::Text { text: text.clone() }), + ContentPart::ToolCall { call } => Some(ContentBlock::ToolUse { + id: call.id.clone(), + name: call.name.clone(), + input: call.arguments.clone(), + }), + _ => None, + }) + .collect(); + + ApiContent::Blocks(blocks) +} + +/// Convert tool result content parts into ContentBlock::ToolResult blocks. +fn convert_tool_result_content(parts: &[ContentPart]) -> Vec { + parts + .iter() + .filter_map(|part| match part { + ContentPart::ToolResult { result } => Some(ContentBlock::ToolResult { + tool_use_id: result.call_id.clone(), + content: result.content.clone(), + is_error: if result.is_error { Some(true) } else { None }, + }), + _ => None, + }) + .collect() +} + +fn convert_image(data: &ImageData) -> ContentBlock { + match data { + ImageData::Base64 { media_type, data } => ContentBlock::Image { + source: ImageSource::Base64 { + media_type: media_type.clone(), + data: data.clone(), + }, + }, + ImageData::Url { url, .. } => ContentBlock::Image { + source: ImageSource::Url { url: url.clone() }, + }, + } +} + +/// Convert `ToolDefinition` values to `ApiTool` values. +pub(crate) fn convert_tools(tools: &[ToolDefinition]) -> Vec { + tools + .iter() + .map(|t| ApiTool { + name: t.name.clone(), + description: t.description.clone(), + input_schema: t.parameters.clone(), + }) + .collect() +} + +/// Convert `ToolChoice` to `ApiToolChoice`. +pub(crate) fn convert_tool_choice(choice: &ToolChoice) -> ApiToolChoice { + match choice { + ToolChoice::Auto => ApiToolChoice { + choice_type: "auto".to_string(), + name: None, + }, + ToolChoice::None => ApiToolChoice { + choice_type: "none".to_string(), + name: None, + }, + ToolChoice::Required => ApiToolChoice { + choice_type: "any".to_string(), + name: None, + }, + ToolChoice::Specific(name) => ApiToolChoice { + choice_type: "tool".to_string(), + name: Some(name.clone()), + }, + } +} + +/// Build a `MessagesRequest` from the prompt and options. +pub(crate) fn build_request( + model: &str, + prompt: Prompt, + options: &GenerateOptions, + stream: bool, +) -> MessagesRequest { + let converted = convert_prompt(prompt); + let max_tokens = options.max_tokens.unwrap_or(4096); + + let tools = options + .tools + .as_ref() + .filter(|t| !t.is_empty()) + .map(|t| convert_tools(t)); + + let tool_choice = if tools.is_some() { + options + .tool_choice + .as_ref() + .map(convert_tool_choice) + .or_else(|| { + Some(ApiToolChoice { + choice_type: "auto".to_string(), + name: None, + }) + }) + } else { + None + }; + + let stop_sequences = if options.stop_sequences.is_empty() { + None + } else { + Some(options.stop_sequences.clone()) + }; + + let thinking = options.thinking.as_ref().map(|t| match t { + ThinkingConfig::Adaptive => ApiThinkingConfig { + thinking_type: "adaptive".to_string(), + }, + ThinkingConfig::Enabled => ApiThinkingConfig { + thinking_type: "enabled".to_string(), + }, + ThinkingConfig::Budget { .. } => ApiThinkingConfig { + thinking_type: "adaptive".to_string(), + }, + }); + + let output_config = options + .output_schema + .as_ref() + .map(|schema| ApiOutputConfig { + format: ApiOutputFormat { + format_type: "json_schema".to_string(), + schema: schema.as_value().clone(), + }, + }); + + MessagesRequest { + model: model.to_string(), + max_tokens, + messages: converted.messages, + system: converted.system, + temperature: options.temperature, + top_p: options.top_p, + top_k: options.top_k, + stop_sequences, + tools, + tool_choice, + stream, + thinking, + output_config, + } +} + +/// Map the Anthropic stop_reason string to `FinishReason`. +pub(crate) fn map_stop_reason(reason: Option<&str>) -> FinishReason { + match reason { + Some("end_turn") | Some("stop") => FinishReason::Stop, + Some("max_tokens") => FinishReason::Length, + Some("tool_use") => FinishReason::ToolCall, + Some("content_filter") => FinishReason::ContentFilter, + _ => FinishReason::Unknown, + } +} + +/// Convert a `MessagesResponse` (non-streaming) into a `GenerateResult`. +pub(crate) fn convert_response(response: MessagesResponse) -> GenerateResult { + let mut text_parts: Vec = Vec::new(); + let mut tool_calls: Vec = Vec::new(); + + for block in &response.content { + match block { + ContentBlock::Text { text } => { + text_parts.push(text.clone()); + } + ContentBlock::ToolUse { id, name, input } => { + tool_calls.push(ToolCallRequest { + id: id.clone(), + name: name.clone(), + arguments: input.clone(), + }); + } + _ => {} + } + } + + let text = if text_parts.is_empty() { + None + } else { + Some(text_parts.join("")) + }; + + let finish_reason = map_stop_reason(response.stop_reason.as_deref()); + + let usage = Usage { + prompt_tokens: Some(response.usage.input_tokens), + completion_tokens: Some(response.usage.output_tokens), + total_tokens: Some(response.usage.input_tokens + response.usage.output_tokens), + }; + + GenerateResult { + text, + tool_calls, + finish_reason, + usage, + metadata: ResponseMetadata { + provider: "anthropic".to_string(), + model: response.model, + ..ResponseMetadata::default() + }, + } +} diff --git a/crates/rusty_claude/src/lib.rs b/crates/rusty_claude/src/lib.rs new file mode 100644 index 0000000..d9d9437 --- /dev/null +++ b/crates/rusty_claude/src/lib.rs @@ -0,0 +1,24 @@ +//! Anthropic Claude provider for the Rusty AI SDK. +//! +//! This crate implements the `LanguageModel` and `Provider` traits from +//! `rusty_ai` for the Anthropic Messages API. It is **not** a wrapper around +//! an OpenAI-compatible endpoint -- it speaks Anthropic's native API format +//! directly. +//! +//! # Quick start +//! +//! ```rust,no_run +//! use rusty_claude::ClaudeProvider; +//! +//! let provider = ClaudeProvider::new("sk-ant-..."); +//! let model = provider.claude_sonnet(); +//! ``` + +mod api_types; +mod convert; +mod model; +mod provider; +mod stream_parser; + +pub use model::*; +pub use provider::*; diff --git a/crates/rusty_claude/src/model.rs b/crates/rusty_claude/src/model.rs new file mode 100644 index 0000000..42de62e --- /dev/null +++ b/crates/rusty_claude/src/model.rs @@ -0,0 +1,157 @@ +use async_trait::async_trait; +use secrecy::{ExposeSecret, SecretString}; + +use rusty_ai::capability::{Capability, CapabilitySet}; +use rusty_ai::error::{AiError, AiResult}; +use rusty_ai::model::{GenerateOptions, LanguageModel}; +use rusty_ai::prompt::Prompt; +use rusty_ai::stream::AiStream; +use rusty_ai::structured::GenerateResult; + +use crate::convert; +use crate::stream_parser; + +const ANTHROPIC_VERSION: &str = "2023-06-01"; +const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; + +/// A Claude language model served by Anthropic. +pub struct ClaudeModel { + api_key: SecretString, + model_id: String, + capabilities: CapabilitySet, + client: reqwest::Client, + base_url: String, +} + +impl ClaudeModel { + /// Create a new `ClaudeModel`. + /// + /// `api_key` is the Anthropic API key and `model_id` is a model identifier + /// such as `"claude-sonnet-4-20250514"`. + pub fn new(api_key: impl Into, model_id: &str) -> Self { + let capabilities = CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::ImageInput) + .with(Capability::Streaming) + .with(Capability::ToolCalling) + .with(Capability::ExtendedThinking) + .with(Capability::StructuredOutput); + + Self { + api_key: SecretString::from(api_key.into()), + model_id: model_id.to_string(), + capabilities, + client: reqwest::Client::new(), + base_url: DEFAULT_BASE_URL.to_string(), + } + } + + /// Override the base URL (useful for proxies or testing). + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + /// Send a request to the Anthropic Messages API. + async fn send_request( + &self, + prompt: Prompt, + options: &GenerateOptions, + stream: bool, + ) -> AiResult { + let request = convert::build_request(&self.model_id, prompt, options, stream); + + let body = + serde_json::to_string(&request).map_err(|e| AiError::Serialization(e.to_string()))?; + + tracing::debug!(model = %self.model_id, stream = stream, "Sending request to Anthropic"); + + let response = self + .client + .post(format!("{}/v1/messages", self.base_url)) + .header("x-api-key", self.api_key.expose_secret()) + .header("anthropic-version", ANTHROPIC_VERSION) + .header("content-type", "application/json") + .body(body) + .send() + .await + .map_err(|e| AiError::Transport { + message: e.to_string(), + source: Some(Box::new(e)), + })?; + + let status = response.status(); + if !status.is_success() { + let status_code = status.as_u16(); + let body_text = match response.text().await { + Ok(body) => body, + Err(e) => { + tracing::warn!(status = status_code, error = %e, "Failed to read Anthropic error response body"); + format!("") + } + }; + + // Try to parse structured error from Anthropic. + let message = if let Ok(parsed) = serde_json::from_str::(&body_text) + { + parsed["error"]["message"] + .as_str() + .unwrap_or(&body_text) + .to_string() + } else { + body_text + }; + + if status_code == 401 { + return Err(AiError::AuthError { message }); + } + if status_code == 429 { + return Err(AiError::RateLimit { retry_after: None }); + } + + return Err(AiError::ProviderError { + provider: "anthropic".to_string(), + status: Some(status_code), + message, + }); + } + + Ok(response) + } +} + +#[async_trait] +impl LanguageModel for ClaudeModel { + fn model_id(&self) -> &str { + &self.model_id + } + + fn provider_id(&self) -> &str { + "anthropic" + } + + fn capabilities(&self) -> &CapabilitySet { + &self.capabilities + } + + async fn generate(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + let response = self.send_request(prompt, &options, false).await?; + let body = response.text().await.map_err(|e| AiError::Transport { + message: e.to_string(), + source: Some(Box::new(e)), + })?; + + let api_response: crate::api_types::MessagesResponse = serde_json::from_str(&body) + .map_err(|e| { + AiError::Serialization(format!("Failed to parse Anthropic response: {e}")) + })?; + + Ok(convert::convert_response(api_response)) + } + + async fn stream(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + let response = self.send_request(prompt, &options, true).await?; + Ok(stream_parser::parse_stream(response)) + } +} diff --git a/crates/rusty_claude/src/provider.rs b/crates/rusty_claude/src/provider.rs new file mode 100644 index 0000000..6a0eb3f --- /dev/null +++ b/crates/rusty_claude/src/provider.rs @@ -0,0 +1,118 @@ +use async_trait::async_trait; +use secrecy::{ExposeSecret, SecretString}; + +use rusty_ai::capability::{Capability, CapabilitySet}; +use rusty_ai::error::{AiError, AiResult}; +use rusty_ai::model::{EmbeddingModel, LanguageModel}; +use rusty_ai::provider::Provider; +use rusty_ai::types::ModelInfo; + +use crate::model::ClaudeModel; + +const DEFAULT_BASE_URL: &str = "https://api.anthropic.com"; + +// ── Latest model aliases ── + +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const CLAUDE_OPUS_LATEST: &str = "claude-opus-4-6"; +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const CLAUDE_SONNET_LATEST: &str = "claude-sonnet-4-6"; +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const CLAUDE_HAIKU_LATEST: &str = "claude-haiku-4-5-20251001"; + +/// Provider for Anthropic Claude models. +pub struct ClaudeProvider { + api_key: SecretString, + base_url: String, +} + +impl ClaudeProvider { + /// Create a new `ClaudeProvider` with the given API key. + pub fn new(api_key: impl Into) -> Self { + Self { + api_key: SecretString::from(api_key.into()), + base_url: DEFAULT_BASE_URL.to_string(), + } + } + + /// Override the base URL (useful for proxies or testing). + pub fn with_base_url(mut self, base_url: impl Into) -> Self { + self.base_url = base_url.into(); + self + } + + /// Get the Claude Sonnet 4.6 model. + pub fn claude_sonnet(&self) -> ClaudeModel { + self.model(CLAUDE_SONNET_LATEST) + } + + /// Get the Claude Opus 4.6 model. + pub fn claude_opus(&self) -> ClaudeModel { + self.model(CLAUDE_OPUS_LATEST) + } + + /// Get the Claude Haiku 4.5 model. + pub fn claude_haiku(&self) -> ClaudeModel { + self.model(CLAUDE_HAIKU_LATEST) + } + + /// Get a model by identifier. Any valid Anthropic model ID is accepted. + pub fn model(&self, model_id: &str) -> ClaudeModel { + ClaudeModel::new(self.api_key.expose_secret(), model_id) + .with_base_url(self.base_url.clone()) + } +} + +#[async_trait] +impl Provider for ClaudeProvider { + fn id(&self) -> &str { + "anthropic" + } + + fn name(&self) -> &str { + "Anthropic" + } + + fn language_model(&self, model_id: &str) -> AiResult> { + Ok(Box::new(self.model(model_id))) + } + + fn embedding_model(&self, model_id: &str) -> AiResult> { + // Anthropic does not offer embedding models. + Err(AiError::ModelUnavailable { + model: model_id.to_string(), + }) + } + + fn available_models(&self) -> Vec { + let caps = CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::ImageInput) + .with(Capability::Streaming) + .with(Capability::ToolCalling) + .with(Capability::ExtendedThinking) + .with(Capability::StructuredOutput); + + vec![ + ModelInfo { + id: CLAUDE_OPUS_LATEST.to_string(), + provider: "anthropic".to_string(), + display_name: "Claude Opus 4.6".to_string(), + capabilities: caps.clone(), + }, + ModelInfo { + id: CLAUDE_SONNET_LATEST.to_string(), + provider: "anthropic".to_string(), + display_name: "Claude Sonnet 4.6".to_string(), + capabilities: caps.clone(), + }, + ModelInfo { + id: CLAUDE_HAIKU_LATEST.to_string(), + provider: "anthropic".to_string(), + display_name: "Claude Haiku 4.5".to_string(), + capabilities: caps, + }, + ] + } +} diff --git a/crates/rusty_claude/src/stream_parser.rs b/crates/rusty_claude/src/stream_parser.rs new file mode 100644 index 0000000..216ff87 --- /dev/null +++ b/crates/rusty_claude/src/stream_parser.rs @@ -0,0 +1,296 @@ +use std::collections::HashMap; + +use futures::stream::{self, StreamExt}; +use reqwest::Response; +use rusty_ai::error::AiError; +use rusty_ai::stream::{AiStream, StreamEvent as RustyStreamEvent}; +use rusty_ai::Usage; + +use crate::api_types::{ContentBlock, DeltaBlock, StreamEvent as AnthropicEvent}; +use crate::convert::map_stop_reason; + +/// State tracked while parsing a streaming response from the Anthropic API. +struct StreamState { + message_id: String, + /// Maps content-block index to tool-call metadata (id, name, accumulated + /// JSON fragments). + active_tool_calls: HashMap, + input_tokens: u64, + output_tokens: u64, +} + +struct ToolCallState { + id: String, + #[allow(dead_code)] + name: String, + json_buf: String, +} + +impl StreamState { + fn new() -> Self { + Self { + message_id: String::new(), + active_tool_calls: HashMap::new(), + input_tokens: 0, + output_tokens: 0, + } + } +} + +/// Parse an SSE response from the Anthropic Messages API into an `AiStream`. +pub(crate) fn parse_stream(response: Response) -> AiStream { + let byte_stream = response.bytes_stream(); + + // We'll buffer bytes and split by SSE boundary ("\n\n"). + let event_stream = futures::stream::unfold( + (byte_stream, String::new()), + |(mut byte_stream, mut buffer)| async move { + loop { + // Try to extract a complete SSE event from the buffer. + if let Some(pos) = buffer.find("\n\n") { + let event_text = buffer[..pos].to_string(); + buffer = buffer[pos + 2..].to_string(); + let events = parse_sse_event(&event_text); + if !events.is_empty() { + return Some((events, (byte_stream, buffer))); + } + continue; + } + + // Need more data. + match byte_stream.next().await { + Some(Ok(chunk)) => { + let text = String::from_utf8_lossy(&chunk); + buffer.push_str(&text); + } + Some(Err(e)) => { + let msg = e.to_string(); + return Some(( + vec![Err(AiError::Transport { + message: msg, + source: Some(Box::new(e)), + })], + (byte_stream, buffer), + )); + } + None => { + // End of stream. Try to parse any remaining data. + if !buffer.trim().is_empty() { + let events = parse_sse_event(&buffer); + buffer.clear(); + if !events.is_empty() { + return Some((events, (byte_stream, buffer))); + } + } + return None; + } + } + } + }, + ) + .flat_map(stream::iter); + + // Now map Anthropic events to rusty_ai StreamEvents using stateful processing. + let mapped = futures::stream::unfold( + (Box::pin(event_stream), StreamState::new()), + |(mut event_stream, mut state)| async move { + loop { + match event_stream.next().await { + Some(Ok(anthropic_event)) => { + let rusty_events = map_event(anthropic_event, &mut state); + if !rusty_events.is_empty() { + let items: Vec> = + rusty_events.into_iter().map(Ok).collect(); + return Some((stream::iter(items), (event_stream, state))); + } + // Event produced no output (e.g. Ping), continue. + continue; + } + Some(Err(e)) => { + return Some((stream::iter(vec![Err(e)]), (event_stream, state))); + } + None => return None, + } + } + }, + ) + .flat_map(|items| items); + + Box::pin(mapped) +} + +/// Parse a single SSE text block (lines between double-newlines) into zero or +/// more `AnthropicEvent` results. +fn parse_sse_event(raw: &str) -> Vec> { + let mut event_type: Option<&str> = None; + let mut data_lines: Vec<&str> = Vec::new(); + + for line in raw.lines() { + let line = line.trim(); + if line.is_empty() { + continue; + } + if let Some(rest) = line.strip_prefix("event:") { + event_type = Some(rest.trim()); + } else if let Some(rest) = line.strip_prefix("data:") { + data_lines.push(rest.trim()); + } + // Ignore other SSE fields (id:, retry:, comments starting with :). + } + + if data_lines.is_empty() { + return Vec::new(); + } + + let data = data_lines.join("\n"); + let _ = event_type; // The type is embedded in the JSON payload. + + match serde_json::from_str::(&data) { + Ok(event) => vec![Ok(event)], + Err(e) => { + tracing::error!(data = %data, error = %e, "Failed to parse Anthropic SSE event; terminating stream"); + vec![Err(AiError::StreamError { + message: format!("Unparseable SSE event from Anthropic: {e}"), + })] + } + } +} + +/// Map a single `AnthropicEvent` into zero or more `RustyStreamEvent` values. +fn map_event(event: AnthropicEvent, state: &mut StreamState) -> Vec { + match event { + AnthropicEvent::MessageStart { message } => { + state.message_id = message.id.clone(); + state.input_tokens = message.usage.input_tokens; + state.output_tokens = message.usage.output_tokens; + vec![ + RustyStreamEvent::MessageStart { + message_id: message.id, + }, + RustyStreamEvent::UsageDelta { + usage: Usage { + prompt_tokens: Some(message.usage.input_tokens), + completion_tokens: Some(message.usage.output_tokens), + total_tokens: Some( + message.usage.input_tokens + message.usage.output_tokens, + ), + }, + }, + ] + } + + AnthropicEvent::ContentBlockStart { + index, + content_block, + } => match content_block { + ContentBlock::Text { text } => { + if text.is_empty() { + Vec::new() + } else { + vec![RustyStreamEvent::TextDelta { delta: text }] + } + } + ContentBlock::ToolUse { id, name, .. } => { + state.active_tool_calls.insert( + index, + ToolCallState { + id: id.clone(), + name: name.clone(), + json_buf: String::new(), + }, + ); + vec![RustyStreamEvent::ToolCallStart { + call_id: id, + tool_name: name, + }] + } + _ => Vec::new(), + }, + + AnthropicEvent::ContentBlockDelta { index, delta } => match delta { + DeltaBlock::Text { text } => { + vec![RustyStreamEvent::TextDelta { delta: text }] + } + DeltaBlock::InputJson { partial_json } => { + if let Some(tc) = state.active_tool_calls.get_mut(&index) { + tc.json_buf.push_str(&partial_json); + vec![RustyStreamEvent::ToolCallDelta { + call_id: tc.id.clone(), + delta: partial_json, + }] + } else { + Vec::new() + } + } + DeltaBlock::Thinking { thinking } => { + vec![RustyStreamEvent::ThinkingDelta { delta: thinking }] + } + DeltaBlock::Signature { .. } => Vec::new(), + }, + + AnthropicEvent::ContentBlockStop { index } => { + if let Some(tc) = state.active_tool_calls.remove(&index) { + match serde_json::from_str(&tc.json_buf) { + Ok(arguments) => vec![RustyStreamEvent::ToolCallEnd { + call_id: tc.id, + arguments, + }], + Err(e) => { + tracing::error!( + call_id = %tc.id, + tool_name = %tc.name, + error = %e, + "Malformed tool call JSON buffer; cannot reconstruct arguments" + ); + vec![RustyStreamEvent::Error { + error: format!( + "Malformed tool call arguments for `{}`: {e}", + tc.name + ), + }] + } + } + } else { + Vec::new() + } + } + + AnthropicEvent::MessageDelta { delta, usage } => { + let finish_reason = map_stop_reason(delta.stop_reason.as_deref()); + + let usage = usage.map(|u| { + state.output_tokens = u.output_tokens; + Usage { + prompt_tokens: None, + completion_tokens: Some(u.output_tokens), + total_tokens: None, + } + }); + + if let Some(u) = &usage { + return vec![ + RustyStreamEvent::UsageDelta { usage: u.clone() }, + RustyStreamEvent::MessageEnd { + finish_reason, + usage: None, + }, + ]; + } + + vec![RustyStreamEvent::MessageEnd { + finish_reason, + usage: None, + }] + } + + AnthropicEvent::MessageStop => Vec::new(), + + AnthropicEvent::Ping => Vec::new(), + + AnthropicEvent::Error { error } => { + vec![RustyStreamEvent::Error { + error: format!("{}: {}", error.error_type, error.message), + }] + } + } +} diff --git a/crates/rusty_foundationmodels/Cargo.toml b/crates/rusty_foundationmodels/Cargo.toml new file mode 100644 index 0000000..6c4c565 --- /dev/null +++ b/crates/rusty_foundationmodels/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "rusty_foundationmodels" +version.workspace = true +edition.workspace = true +license.workspace = true +description = "Apple Foundation Models bridge for the Rusty AI SDK (see undivisible/rusty_foundationmodels)" + +[dependencies] +rusty_ai = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tracing = { workspace = true } +uuid = { workspace = true } +chrono = { workspace = true } diff --git a/crates/rusty_foundationmodels/src/bridge.rs b/crates/rusty_foundationmodels/src/bridge.rs new file mode 100644 index 0000000..74258fa --- /dev/null +++ b/crates/rusty_foundationmodels/src/bridge.rs @@ -0,0 +1,28 @@ +use async_trait::async_trait; + +use crate::types::{AppleModelAvailability, FoundationModelConfig}; + +/// Trait that must be implemented by the host application to bridge +/// to Apple's Foundation Models framework via Swift/ObjC interop. +/// +/// See `undivisible/rusty_foundationmodels` for the reference Swift bridge. +#[async_trait] +pub trait FoundationModelBridge: Send + Sync { + /// Check model availability on this device. + async fn availability(&self) -> AppleModelAvailability; + + /// Generate text from a prompt. + async fn generate( + &self, + prompt: &str, + config: &FoundationModelConfig, + ) -> Result; + + /// Stream text from a prompt, returning chunks. + /// Apple Foundation Models supports streaming via AsyncSequence. + async fn stream( + &self, + prompt: &str, + config: &FoundationModelConfig, + ) -> Result, String>; +} diff --git a/crates/rusty_foundationmodels/src/lib.rs b/crates/rusty_foundationmodels/src/lib.rs new file mode 100644 index 0000000..77abfd6 --- /dev/null +++ b/crates/rusty_foundationmodels/src/lib.rs @@ -0,0 +1,17 @@ +//! Apple Foundation Models bridge for the Rusty AI SDK. +//! +//! This crate provides integration with Apple's Foundation Models framework. +//! The actual Swift/ObjC bridge lives in the separate `undivisible/rusty_foundationmodels` +//! repository. This crate provides the [`rusty_ai::LanguageModel`] integration layer. +//! +//! Host applications must implement the [`FoundationModelBridge`] trait. + +mod bridge; +mod model; +mod provider; +mod types; + +pub use bridge::*; +pub use model::*; +pub use provider::*; +pub use types::*; diff --git a/crates/rusty_foundationmodels/src/model.rs b/crates/rusty_foundationmodels/src/model.rs new file mode 100644 index 0000000..5013623 --- /dev/null +++ b/crates/rusty_foundationmodels/src/model.rs @@ -0,0 +1,149 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use futures::stream; + +use rusty_ai::{ + AiError, AiResult, AiStream, Capability, CapabilitySet, ContentPart, FinishReason, + GenerateOptions, GenerateResult, LanguageModel, Prompt, ResponseMetadata, StreamEvent, + SyntheticStreamer, Usage, +}; + +use crate::bridge::FoundationModelBridge; +use crate::types::{AppleModelAvailability, FoundationModelConfig}; + +/// A [`LanguageModel`] backed by Apple Foundation Models running on-device. +/// +/// When the bridge supports streaming, real chunks are emitted as stream +/// events. Otherwise, [`SyntheticStreamer`] is used as a fallback. +pub struct FoundationModel { + bridge: Arc, + capabilities: CapabilitySet, +} + +impl FoundationModel { + pub(crate) fn new(bridge: Arc) -> Self { + let capabilities = CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::LocalExecution) + .with(Capability::PlatformNative); + Self { + bridge, + capabilities, + } + } + + fn build_config(options: &GenerateOptions) -> FoundationModelConfig { + FoundationModelConfig { + temperature: options.temperature, + max_tokens: options.max_tokens, + } + } + + fn prompt_to_text(prompt: &Prompt) -> String { + match prompt { + Prompt::Text(t) => t.clone(), + Prompt::Messages(msgs) => msgs + .iter() + .flat_map(|m| { + m.content.iter().filter_map(|c| match c { + ContentPart::Text { text } => Some(text.clone()), + _ => None, + }) + }) + .collect::>() + .join("\n"), + } + } + + async fn ensure_available(&self) -> AiResult<()> { + match self.bridge.availability().await { + AppleModelAvailability::Available => Ok(()), + AppleModelAvailability::Unavailable { reason } => Err(AiError::PlatformUnavailable { + platform: format!("apple/foundation_models: {reason}"), + }), + AppleModelAvailability::NeedsDownload => Err(AiError::ModelUnavailable { + model: "apple-foundation-model (needs download)".into(), + }), + } + } +} + +#[async_trait] +impl LanguageModel for FoundationModel { + fn model_id(&self) -> &str { + "apple-foundation-model" + } + + fn provider_id(&self) -> &str { + "foundationmodels" + } + + fn capabilities(&self) -> &CapabilitySet { + &self.capabilities + } + + async fn generate(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + self.ensure_available().await?; + + let config = Self::build_config(&options); + let prompt_text = Self::prompt_to_text(&prompt); + let start = std::time::Instant::now(); + + let response = self + .bridge + .generate(&prompt_text, &config) + .await + .map_err(|e| AiError::BridgeError { + bridge: "foundationmodels".into(), + message: e, + })?; + + let latency_ms = start.elapsed().as_millis() as u64; + + Ok(GenerateResult { + text: Some(response), + tool_calls: vec![], + finish_reason: FinishReason::Stop, + usage: Usage::default(), + metadata: ResponseMetadata { + provider: "foundationmodels".into(), + model: "apple-foundation-model".into(), + latency_ms: Some(latency_ms), + ..Default::default() + }, + }) + } + + async fn stream(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + self.ensure_available().await?; + + let config = Self::build_config(&options); + let prompt_text = Self::prompt_to_text(&prompt); + + // Try the bridge's streaming method first (AsyncSequence-backed). + match self.bridge.stream(&prompt_text, &config).await { + Ok(chunks) if !chunks.is_empty() => { + let mut events: Vec> = Vec::new(); + events.push(Ok(StreamEvent::MessageStart { + message_id: uuid::Uuid::new_v4().to_string(), + })); + for chunk in chunks { + events.push(Ok(StreamEvent::TextDelta { delta: chunk })); + } + events.push(Ok(StreamEvent::MessageEnd { + finish_reason: FinishReason::Stop, + usage: None, + })); + Ok(Box::pin(stream::iter(events))) + } + _ => { + // Fallback: generate fully and use synthetic streaming. + let result = self.generate(Prompt::Text(prompt_text), options).await?; + let text = result.text.unwrap_or_default(); + Ok(SyntheticStreamer::stream(text, 20)) + } + } + } +} diff --git a/crates/rusty_foundationmodels/src/provider.rs b/crates/rusty_foundationmodels/src/provider.rs new file mode 100644 index 0000000..c3ff778 --- /dev/null +++ b/crates/rusty_foundationmodels/src/provider.rs @@ -0,0 +1,67 @@ +use std::sync::Arc; + +use rusty_ai::{ + AiError, AiResult, Capability, CapabilitySet, EmbeddingModel, LanguageModel, ModelInfo, + Provider, +}; + +use crate::bridge::FoundationModelBridge; +use crate::model::FoundationModel; +use crate::types::AppleModelAvailability; + +/// Provider for Apple Foundation Models on-device inference. +pub struct FoundationModelProvider { + bridge: Arc, +} + +impl FoundationModelProvider { + pub fn new(bridge: impl FoundationModelBridge + 'static) -> Self { + Self { + bridge: Arc::new(bridge), + } + } + + /// Get the Foundation Model language model. + pub fn model(&self) -> FoundationModel { + FoundationModel::new(self.bridge.clone()) + } + + /// Check model availability on this device. + pub async fn availability(&self) -> AppleModelAvailability { + self.bridge.availability().await + } +} + +impl Provider for FoundationModelProvider { + fn id(&self) -> &str { + "foundationmodels" + } + + fn name(&self) -> &str { + "Apple Foundation Models" + } + + fn language_model(&self, _model_id: &str) -> AiResult> { + Ok(Box::new(self.model())) + } + + fn embedding_model(&self, model_id: &str) -> AiResult> { + Err(AiError::UnsupportedCapability { + capability: "embeddings".into(), + provider: format!("foundationmodels/{model_id}"), + }) + } + + fn available_models(&self) -> Vec { + vec![ModelInfo { + id: "apple-foundation-model".into(), + provider: "foundationmodels".into(), + display_name: "Apple Foundation Model (On-Device)".into(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::LocalExecution) + .with(Capability::PlatformNative), + }] + } +} diff --git a/crates/rusty_foundationmodels/src/types.rs b/crates/rusty_foundationmodels/src/types.rs new file mode 100644 index 0000000..1f4a3d3 --- /dev/null +++ b/crates/rusty_foundationmodels/src/types.rs @@ -0,0 +1,17 @@ +/// Availability state of Apple Foundation Models on this device. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum AppleModelAvailability { + /// The model is available and ready to use. + Available, + /// The model is not available. + Unavailable { reason: String }, + /// The model needs to be downloaded first. + NeedsDownload, +} + +/// Configuration for Foundation Model generation. +#[derive(Debug, Clone, Default)] +pub struct FoundationModelConfig { + pub temperature: Option, + pub max_tokens: Option, +} diff --git a/crates/rusty_gemini/Cargo.toml b/crates/rusty_gemini/Cargo.toml new file mode 100644 index 0000000..ce2ea4b --- /dev/null +++ b/crates/rusty_gemini/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "rusty_gemini" +version.workspace = true +edition.workspace = true +license.workspace = true +description = "Google Gemini provider for the Rusty AI SDK" + +[dependencies] +rusty_ai = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true } +reqwest = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tracing = { workspace = true } +uuid = { workspace = true } +chrono = { workspace = true } +secrecy = { workspace = true } diff --git a/crates/rusty_gemini/src/api_types.rs b/crates/rusty_gemini/src/api_types.rs new file mode 100644 index 0000000..6a2dff2 --- /dev/null +++ b/crates/rusty_gemini/src/api_types.rs @@ -0,0 +1,131 @@ +use serde::{Deserialize, Serialize}; + +#[derive(Serialize)] +pub(crate) struct GenerateContentRequest { + pub contents: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub system_instruction: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub generation_config: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_config: Option, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub(crate) struct GeminiContent { + #[serde(skip_serializing_if = "Option::is_none")] + pub role: Option, + pub parts: Vec, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +#[serde(untagged)] +pub(crate) enum GeminiPart { + Text { text: String }, + Thought { thought: bool, text: String }, + InlineData { inline_data: InlineData }, + FunctionCall { function_call: FunctionCall }, + FunctionResponse { function_response: FunctionResponse }, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub(crate) struct InlineData { + pub mime_type: String, + pub data: String, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub(crate) struct FunctionCall { + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + pub name: String, + pub args: serde_json::Value, +} + +#[derive(Serialize, Deserialize, Clone, Debug)] +pub(crate) struct FunctionResponse { + #[serde(skip_serializing_if = "Option::is_none")] + pub id: Option, + pub name: String, + pub response: serde_json::Value, +} + +#[derive(Serialize)] +pub(crate) struct GenerationConfig { + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_output_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop_sequences: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_mime_type: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_schema: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub thinking_config: Option, +} + +#[derive(Serialize)] +pub(crate) struct ThinkingConfig { + /// Token budget for thinking (Gemini 2.5+). + #[serde(skip_serializing_if = "Option::is_none")] + pub thinking_budget: Option, + /// Thinking level for Gemini 3 Flash: "minimal", "low", "medium", "high". + #[serde(skip_serializing_if = "Option::is_none")] + pub thinking_level: Option, +} + +#[derive(Serialize)] +pub(crate) struct GeminiTool { + pub function_declarations: Vec, +} + +#[derive(Serialize)] +pub(crate) struct FunctionDeclaration { + pub name: String, + pub description: String, + pub parameters: serde_json::Value, +} + +#[derive(Serialize)] +pub(crate) struct ToolConfig { + pub function_calling_config: FunctionCallingConfig, +} + +#[derive(Serialize)] +pub(crate) struct FunctionCallingConfig { + pub mode: String, +} + +// Response types + +#[derive(Deserialize, Debug)] +pub(crate) struct GenerateContentResponse { + pub candidates: Option>, + #[serde(rename = "usageMetadata")] + pub usage_metadata: Option, +} + +#[derive(Deserialize, Debug)] +pub(crate) struct Candidate { + pub content: Option, + #[serde(rename = "finishReason")] + pub finish_reason: Option, +} + +#[derive(Deserialize, Debug)] +pub(crate) struct UsageMetadata { + #[serde(rename = "promptTokenCount")] + pub prompt_token_count: Option, + #[serde(rename = "candidatesTokenCount")] + pub candidates_token_count: Option, + #[serde(rename = "totalTokenCount")] + pub total_token_count: Option, +} diff --git a/crates/rusty_gemini/src/convert.rs b/crates/rusty_gemini/src/convert.rs new file mode 100644 index 0000000..8be1ebb --- /dev/null +++ b/crates/rusty_gemini/src/convert.rs @@ -0,0 +1,289 @@ +use rusty_ai::{ + ContentPart, FinishReason, GenerateOptions, GenerateResult, ImageData, Prompt, + ResponseMetadata, Role, ThinkingConfig as CoreThinkingConfig, ToolCallRequest, ToolChoice, + Usage, +}; + +use crate::api_types::*; + +/// The parts of a Gemini API request built from a `Prompt` and `GenerateOptions`. +pub(crate) struct GeminiRequestParts { + pub contents: Vec, + pub system_instruction: Option, + pub generation_config: Option, + pub tools: Option>, + pub tool_config: Option, +} + +/// Separate system messages from conversation messages and build the Gemini request parts. +pub(crate) fn build_request(prompt: Prompt, options: &GenerateOptions) -> GeminiRequestParts { + let messages = prompt.into_messages(); + + let mut system_parts: Vec = Vec::new(); + let mut contents: Vec = Vec::new(); + + for msg in messages { + match msg.role { + Role::System => { + for part in &msg.content { + if let Some(gp) = content_part_to_gemini(part) { + system_parts.push(gp); + } + } + } + Role::User => { + let parts = msg + .content + .iter() + .filter_map(content_part_to_gemini) + .collect(); + contents.push(GeminiContent { + role: Some("user".to_string()), + parts, + }); + } + Role::Assistant => { + let parts = msg + .content + .iter() + .filter_map(content_part_to_gemini) + .collect(); + contents.push(GeminiContent { + role: Some("model".to_string()), + parts, + }); + } + Role::Tool => { + let parts = msg + .content + .iter() + .filter_map(content_part_to_gemini) + .collect(); + contents.push(GeminiContent { + role: Some("user".to_string()), + parts, + }); + } + } + } + + let system_instruction = if system_parts.is_empty() { + None + } else { + Some(GeminiContent { + role: None, + parts: system_parts, + }) + }; + + let generation_config = build_generation_config(options); + let (tools, tool_config) = build_tools(options); + + GeminiRequestParts { + contents, + system_instruction, + generation_config, + tools, + tool_config, + } +} + +fn content_part_to_gemini(part: &ContentPart) -> Option { + match part { + ContentPart::Text { text } => Some(GeminiPart::Text { text: text.clone() }), + ContentPart::Image { data } => match data { + ImageData::Base64 { media_type, data } => Some(GeminiPart::InlineData { + inline_data: InlineData { + mime_type: media_type.clone(), + data: data.clone(), + }, + }), + ImageData::Url { url, .. } => { + // Gemini doesn't natively support image URLs in inline_data; + // pass as text reference as a fallback. + Some(GeminiPart::Text { + text: format!("[Image URL: {}]", url), + }) + } + }, + ContentPart::ToolCall { call } => Some(GeminiPart::FunctionCall { + function_call: FunctionCall { + id: None, + name: call.name.clone(), + args: call.arguments.clone(), + }, + }), + ContentPart::ToolResult { result } => { + let response = serde_json::json!({ + "content": result.content, + "is_error": result.is_error, + }); + Some(GeminiPart::FunctionResponse { + function_response: FunctionResponse { + id: None, + name: result.call_id.clone(), + response, + }, + }) + } + ContentPart::File { .. } => None, + } +} + +fn build_generation_config(options: &GenerateOptions) -> Option { + let max_tokens = options.max_tokens.or(Some(8192)); + + let stop_sequences = if options.stop_sequences.is_empty() { + None + } else { + Some(options.stop_sequences.clone()) + }; + + let (response_mime_type, response_schema) = if let Some(schema) = &options.output_schema { + ( + Some("application/json".to_string()), + Some(schema.as_value().clone()), + ) + } else { + (None, None) + }; + + let thinking_config = options.thinking.as_ref().map(|t| match t { + CoreThinkingConfig::Budget { tokens } => ThinkingConfig { + thinking_budget: Some(*tokens), + thinking_level: None, + }, + CoreThinkingConfig::Adaptive | CoreThinkingConfig::Enabled => ThinkingConfig { + thinking_budget: Some(8192), + thinking_level: None, + }, + }); + + Some(GenerationConfig { + temperature: options.temperature, + max_output_tokens: max_tokens, + top_p: options.top_p, + top_k: options.top_k, + stop_sequences, + response_mime_type, + response_schema, + thinking_config, + }) +} + +fn build_tools(options: &GenerateOptions) -> (Option>, Option) { + let tool_defs = match &options.tools { + Some(tools) if !tools.is_empty() => tools, + _ => return (None, None), + }; + + let declarations: Vec = tool_defs + .iter() + .map(|t| FunctionDeclaration { + name: t.name.clone(), + description: t.description.clone(), + parameters: t.parameters.clone(), + }) + .collect(); + + let tools = vec![GeminiTool { + function_declarations: declarations, + }]; + + let mode = match &options.tool_choice { + Some(ToolChoice::None) => "NONE", + Some(ToolChoice::Required) => "ANY", + Some(ToolChoice::Auto) | None => "AUTO", + Some(ToolChoice::Specific(_)) => "ANY", + }; + + let tool_config = ToolConfig { + function_calling_config: FunctionCallingConfig { + mode: mode.to_string(), + }, + }; + + (Some(tools), Some(tool_config)) +} + +/// Convert a Gemini API response into our GenerateResult. +pub(crate) fn response_to_result( + response: GenerateContentResponse, + model_id: &str, +) -> GenerateResult { + let mut text_parts: Vec = Vec::new(); + let mut tool_calls: Vec = Vec::new(); + let mut finish_reason = FinishReason::Unknown; + + if let Some(candidates) = &response.candidates { + if let Some(candidate) = candidates.first() { + if let Some(ref reason) = candidate.finish_reason { + finish_reason = map_finish_reason(reason); + } + + if let Some(ref content) = candidate.content { + for part in &content.parts { + match part { + GeminiPart::Text { text } => { + text_parts.push(text.clone()); + } + GeminiPart::FunctionCall { function_call } => { + tool_calls.push(ToolCallRequest { + id: uuid::Uuid::new_v4().to_string(), + name: function_call.name.clone(), + arguments: function_call.args.clone(), + }); + } + _ => {} + } + } + } + } + } + + let usage = map_usage(response.usage_metadata.as_ref()); + + let combined_text = if text_parts.is_empty() { + None + } else { + Some(text_parts.join("")) + }; + + if !tool_calls.is_empty() && finish_reason == FinishReason::Stop { + finish_reason = FinishReason::ToolCall; + } + + GenerateResult { + text: combined_text, + tool_calls, + finish_reason, + usage, + metadata: ResponseMetadata { + request_id: uuid::Uuid::new_v4(), + provider: "gemini".to_string(), + model: model_id.to_string(), + latency_ms: None, + extra: Default::default(), + }, + } +} + +pub(crate) fn map_finish_reason(reason: &str) -> FinishReason { + match reason { + "STOP" => FinishReason::Stop, + "MAX_TOKENS" => FinishReason::Length, + "SAFETY" | "RECITATION" | "BLOCKLIST" | "PROHIBITED_CONTENT" => FinishReason::ContentFilter, + _ => FinishReason::Unknown, + } +} + +pub(crate) fn map_usage(meta: Option<&UsageMetadata>) -> Usage { + match meta { + Some(m) => Usage { + prompt_tokens: m.prompt_token_count, + completion_tokens: m.candidates_token_count, + total_tokens: m.total_token_count, + }, + None => Usage::default(), + } +} diff --git a/crates/rusty_gemini/src/lib.rs b/crates/rusty_gemini/src/lib.rs new file mode 100644 index 0000000..840cbc6 --- /dev/null +++ b/crates/rusty_gemini/src/lib.rs @@ -0,0 +1,10 @@ +//! Google Gemini provider for the Rusty AI SDK. + +mod api_types; +mod convert; +mod model; +mod provider; +mod stream_parser; + +pub use model::*; +pub use provider::*; diff --git a/crates/rusty_gemini/src/model.rs b/crates/rusty_gemini/src/model.rs new file mode 100644 index 0000000..5c626aa --- /dev/null +++ b/crates/rusty_gemini/src/model.rs @@ -0,0 +1,174 @@ +use async_trait::async_trait; +use secrecy::{ExposeSecret, SecretString}; + +use rusty_ai::{ + AiError, AiResult, AiStream, Capability, CapabilitySet, GenerateOptions, GenerateResult, + LanguageModel, Prompt, +}; + +use crate::api_types::GenerateContentRequest; +use crate::convert::{build_request, response_to_result, GeminiRequestParts}; +use crate::stream_parser; + +const BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta/models"; + +/// A Google Gemini language model. +pub struct GeminiModel { + api_key: SecretString, + model_id: String, + capabilities: CapabilitySet, + client: reqwest::Client, +} + +impl GeminiModel { + /// Create a new Gemini model instance. + pub fn new(api_key: impl Into, model_id: &str) -> Self { + let capabilities = CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::ImageInput) + .with(Capability::Streaming) + .with(Capability::ToolCalling) + .with(Capability::StructuredOutput) + .with(Capability::ExtendedThinking) + .with(Capability::VideoInput) + .with(Capability::AudioInput); + + Self { + api_key: SecretString::from(api_key.into()), + model_id: model_id.to_string(), + capabilities, + client: reqwest::Client::new(), + } + } + + fn generate_url(&self) -> String { + format!( + "{}/{}:generateContent?key={}", + BASE_URL, + self.model_id, + self.api_key.expose_secret() + ) + } + + fn stream_url(&self) -> String { + format!( + "{}/{}:streamGenerateContent?key={}&alt=sse", + BASE_URL, + self.model_id, + self.api_key.expose_secret() + ) + } + + fn build_api_request( + &self, + prompt: Prompt, + options: &GenerateOptions, + ) -> GenerateContentRequest { + let GeminiRequestParts { + contents, + system_instruction, + generation_config, + tools, + tool_config, + } = build_request(prompt, options); + + GenerateContentRequest { + contents, + system_instruction, + generation_config, + tools, + tool_config, + } + } +} + +#[async_trait] +impl LanguageModel for GeminiModel { + fn model_id(&self) -> &str { + &self.model_id + } + + fn provider_id(&self) -> &str { + "gemini" + } + + fn capabilities(&self) -> &CapabilitySet { + &self.capabilities + } + + async fn generate(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + let request_body = self.build_api_request(prompt, &options); + + let response = self + .client + .post(self.generate_url()) + .header("Content-Type", "application/json") + .json(&request_body) + .send() + .await + .map_err(|e| AiError::Transport { + message: e.to_string(), + source: Some(Box::new(e)), + })?; + + let status = response.status(); + if !status.is_success() { + let status_code = status.as_u16(); + let body = match response.text().await { + Ok(body) => body, + Err(e) => { + tracing::warn!(status = status_code, error = %e, "Failed to read Gemini error response body"); + format!("") + } + }; + return Err(AiError::ProviderError { + provider: "gemini".to_string(), + status: Some(status_code), + message: body, + }); + } + + let api_response: crate::api_types::GenerateContentResponse = response + .json() + .await + .map_err(|e| AiError::Serialization(e.to_string()))?; + + Ok(response_to_result(api_response, &self.model_id)) + } + + async fn stream(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + let request_body = self.build_api_request(prompt, &options); + + let response = self + .client + .post(self.stream_url()) + .header("Content-Type", "application/json") + .json(&request_body) + .send() + .await + .map_err(|e| AiError::Transport { + message: e.to_string(), + source: Some(Box::new(e)), + })?; + + let status = response.status(); + if !status.is_success() { + let status_code = status.as_u16(); + let body = match response.text().await { + Ok(body) => body, + Err(e) => { + tracing::warn!(status = status_code, error = %e, "Failed to read Gemini error response body"); + format!("") + } + }; + return Err(AiError::ProviderError { + provider: "gemini".to_string(), + status: Some(status_code), + message: body, + }); + } + + Ok(stream_parser::parse_stream(response)) + } +} diff --git a/crates/rusty_gemini/src/provider.rs b/crates/rusty_gemini/src/provider.rs new file mode 100644 index 0000000..83a3ad6 --- /dev/null +++ b/crates/rusty_gemini/src/provider.rs @@ -0,0 +1,127 @@ +use secrecy::SecretString; + +use crate::model::GeminiModel; + +// ── Latest model aliases ── + +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const GEMINI_PRO_LATEST: &str = "gemini-2.5-pro"; +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const GEMINI_FLASH_LATEST: &str = "gemini-2.5-flash"; +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const GEMINI_FLASH_LITE_LATEST: &str = "gemini-2.5-flash-lite"; +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const GEMINI_PRO_PREVIEW_LATEST: &str = "gemini-3.1-pro-preview"; +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const GEMINI_3_FLASH_LATEST: &str = "gemini-3-flash"; +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const GEMINI_FLASH_LIVE_LATEST: &str = "gemini-3.1-flash-live-preview"; +/// Convenience alias. Use `fetch_models()` or pass any model ID string to `model()`. +pub const GEMINI_EMBEDDING_LATEST: &str = "gemini-embedding-2-preview"; + +/// Provider for Google Gemini models. +pub struct GeminiProvider { + api_key: SecretString, +} + +impl GeminiProvider { + /// Create a new Gemini provider with the given API key. + pub fn new(api_key: impl Into) -> Self { + Self { + api_key: SecretString::from(api_key.into()), + } + } + + /// Get Gemini 2.5 Pro (most capable reasoning). + pub fn gemini_pro(&self) -> GeminiModel { + self.model(GEMINI_PRO_LATEST) + } + + /// Get Gemini 2.5 Flash (best price/performance). + pub fn gemini_flash(&self) -> GeminiModel { + self.model(GEMINI_FLASH_LATEST) + } + + /// Get Gemini 2.5 Flash Lite (fastest/cheapest). + pub fn gemini_flash_lite(&self) -> GeminiModel { + self.model(GEMINI_FLASH_LITE_LATEST) + } + + /// Get Gemini 3.1 Pro Preview (latest preview). + pub fn gemini_31_pro(&self) -> GeminiModel { + self.model(GEMINI_PRO_PREVIEW_LATEST) + } + + /// Get Gemini 3 Flash (frontier-class preview). + pub fn gemini_3_flash(&self) -> GeminiModel { + self.model(GEMINI_3_FLASH_LATEST) + } + + /// Get a Gemini model by its model ID. Any valid Gemini model ID is accepted. + pub fn model(&self, model_id: &str) -> GeminiModel { + use secrecy::ExposeSecret; + GeminiModel::new(self.api_key.expose_secret(), model_id) + } + + /// Fetch the list of models from the Gemini API. + /// + /// Calls `GET /v1beta/models?key=...` and returns model names. + pub async fn list_remote_models(&self) -> rusty_ai::AiResult> { + use secrecy::ExposeSecret; + let url = format!( + "https://generativelanguage.googleapis.com/v1beta/models?key={}", + self.api_key.expose_secret() + ); + let client = reqwest::Client::new(); + let resp = client + .get(&url) + .send() + .await + .map_err(|e| rusty_ai::AiError::Transport { + message: e.to_string(), + source: Some(Box::new(e)), + })?; + + let status = resp.status(); + if !status.is_success() { + let status_code = status.as_u16(); + let body = match resp.text().await { + Ok(body) => body, + Err(e) => { + tracing::warn!(status = status_code, error = %e, "Failed to read Gemini error response body"); + format!("") + } + }; + return Err(rusty_ai::AiError::ProviderError { + provider: "gemini".into(), + status: Some(status_code), + message: body, + }); + } + + #[derive(serde::Deserialize)] + struct ListModelsResponse { + models: Vec, + } + #[derive(serde::Deserialize)] + struct ModelEntry { + name: String, + } + + let list: ListModelsResponse = resp + .json() + .await + .map_err(|e| rusty_ai::AiError::Serialization(e.to_string()))?; + + Ok(list + .models + .into_iter() + .map(|m| { + m.name + .strip_prefix("models/") + .unwrap_or(&m.name) + .to_string() + }) + .collect()) + } +} diff --git a/crates/rusty_gemini/src/stream_parser.rs b/crates/rusty_gemini/src/stream_parser.rs new file mode 100644 index 0000000..596c2fe --- /dev/null +++ b/crates/rusty_gemini/src/stream_parser.rs @@ -0,0 +1,191 @@ +use crate::api_types::*; +use crate::convert::{map_finish_reason, map_usage}; +use futures::stream::{self, StreamExt}; +use reqwest::Response; +use rusty_ai::error::AiError; +use rusty_ai::stream::{AiStream, StreamEvent}; + +/// Parse Gemini's Server-Sent Events streaming format into an `AiStream`. +/// +/// The streaming endpoint (with `alt=sse`) returns standard SSE where each +/// `data:` line contains a JSON `GenerateContentResponse` object. +pub(crate) fn parse_stream(response: Response) -> AiStream { + let byte_stream = response.bytes_stream(); + + // Phase 1: Parse raw bytes into SSE data payloads (JSON strings). + let json_stream = futures::stream::unfold( + (byte_stream, String::new()), + |(mut byte_stream, mut buffer)| async move { + loop { + // Try to extract a complete SSE data line from the buffer. + if let Some(json_str) = extract_next_data_line(&mut buffer) { + return Some((Ok(json_str), (byte_stream, buffer))); + } + + // Need more data. + match byte_stream.next().await { + Some(Ok(chunk)) => { + let text = String::from_utf8_lossy(&chunk); + buffer.push_str(&text); + } + Some(Err(e)) => { + let msg = e.to_string(); + return Some(( + Err(AiError::Transport { + message: msg, + source: Some(Box::new(e)), + }), + (byte_stream, buffer), + )); + } + None => { + // End of stream. Try to parse any remaining data lines. + if let Some(json_str) = extract_next_data_line(&mut buffer) { + return Some((Ok(json_str), (byte_stream, buffer))); + } + return None; + } + } + } + }, + ); + + // Phase 2: Parse JSON strings into GenerateContentResponse and then into + // StreamEvents, emitting MessageStart at the beginning. + let event_stream = futures::stream::unfold( + (Box::pin(json_stream), false), + |(mut json_stream, mut sent_start)| async move { + loop { + // Emit MessageStart before the first real event. + if !sent_start { + sent_start = true; + let start_event: Vec> = + vec![Ok(StreamEvent::MessageStart { + message_id: uuid::Uuid::new_v4().to_string(), + })]; + return Some((stream::iter(start_event), (json_stream, sent_start))); + } + + match json_stream.next().await { + Some(Ok(json_str)) => { + match serde_json::from_str::(&json_str) { + Ok(response) => { + let events = response_to_stream_events(response); + if !events.is_empty() { + let items: Vec> = + events.into_iter().map(Ok).collect(); + return Some((stream::iter(items), (json_stream, sent_start))); + } + continue; + } + Err(e) => { + tracing::error!( + data = %json_str, + error = %e, + "Failed to parse Gemini SSE event; terminating stream" + ); + return Some(( + stream::iter(vec![Err(AiError::StreamError { + message: format!("Unparseable SSE event from Gemini: {e}"), + })]), + (json_stream, sent_start), + )); + } + } + } + Some(Err(e)) => { + return Some((stream::iter(vec![Err(e)]), (json_stream, sent_start))); + } + None => return None, + } + } + }, + ) + .flat_map(|items| items); + + Box::pin(event_stream) +} + +/// Extract the next complete `data: ...` line from the buffer. +/// Removes the consumed portion from the buffer. +fn extract_next_data_line(buffer: &mut String) -> Option { + loop { + let newline_pos = buffer.find('\n')?; + let line = buffer[..newline_pos].to_string(); + *buffer = buffer[newline_pos + 1..].to_string(); + + let trimmed = line.trim(); + if trimmed.is_empty() || trimmed.starts_with(':') { + continue; + } + + if let Some(data) = trimmed.strip_prefix("data:") { + let data = data.trim(); + if !data.is_empty() { + return Some(data.to_string()); + } + } + // Skip non-data SSE fields (event:, id:, retry:). + } +} + +/// Convert a single Gemini response chunk into stream events. +fn response_to_stream_events(response: GenerateContentResponse) -> Vec { + let mut events = Vec::new(); + + if let Some(candidates) = &response.candidates { + if let Some(candidate) = candidates.first() { + if let Some(ref content) = candidate.content { + for part in &content.parts { + match part { + GeminiPart::Text { text } => { + events.push(StreamEvent::TextDelta { + delta: text.clone(), + }); + } + GeminiPart::Thought { text, .. } => { + events.push(StreamEvent::ThinkingDelta { + delta: text.clone(), + }); + } + GeminiPart::FunctionCall { function_call } => { + let call_id = uuid::Uuid::new_v4().to_string(); + events.push(StreamEvent::ToolCallStart { + call_id: call_id.clone(), + tool_name: function_call.name.clone(), + }); + events.push(StreamEvent::ToolCallEnd { + call_id, + arguments: function_call.args.clone(), + }); + } + _ => {} + } + } + } + + if let Some(ref reason) = candidate.finish_reason { + let finish = map_finish_reason(reason); + let usage = map_usage(response.usage_metadata.as_ref()); + events.push(StreamEvent::MessageEnd { + finish_reason: finish, + usage: Some(usage), + }); + } + } + } + + // Emit usage if present and no MessageEnd was emitted yet. + if let Some(ref meta) = response.usage_metadata { + let already_has_end = events + .iter() + .any(|e| matches!(e, StreamEvent::MessageEnd { .. })); + if !already_has_end { + events.push(StreamEvent::UsageDelta { + usage: map_usage(Some(meta)), + }); + } + } + + events +} diff --git a/crates/rusty_gemini_nano/Cargo.toml b/crates/rusty_gemini_nano/Cargo.toml new file mode 100644 index 0000000..d4aff0e --- /dev/null +++ b/crates/rusty_gemini_nano/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "rusty_gemini_nano" +version.workspace = true +edition.workspace = true +license.workspace = true +description = "Gemini Nano (Android Prompt API) local runtime for the Rusty AI SDK" + +[dependencies] +rusty_ai = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tracing = { workspace = true } +uuid = { workspace = true } +chrono = { workspace = true } + +[target.'cfg(target_os = "android")'.dependencies] +# JNI bridge would go here in a real build +# jni = "0.21" diff --git a/crates/rusty_gemini_nano/src/bridge.rs b/crates/rusty_gemini_nano/src/bridge.rs new file mode 100644 index 0000000..cd47cd0 --- /dev/null +++ b/crates/rusty_gemini_nano/src/bridge.rs @@ -0,0 +1,37 @@ +use async_trait::async_trait; + +use crate::types::{ModelDownloadState, NanoCapabilities, NanoSessionConfig}; + +/// Trait that must be implemented by the host application to bridge +/// to the Android Prompt API via JNI/Kotlin interop. +/// +/// The host app provides a concrete implementation that calls into the +/// Android platform SDK. This crate consumes that implementation to +/// expose a standard [`rusty_ai::LanguageModel`]. +#[async_trait] +pub trait GeminiNanoBridge: Send + Sync { + /// Check if Gemini Nano is available on this device. + async fn is_available(&self) -> bool; + + /// Get the current download state of the model. + async fn download_state(&self) -> ModelDownloadState; + + /// Request model download if not already downloaded. + async fn request_download(&self) -> Result<(), String>; + + /// Get device capabilities. + async fn capabilities(&self) -> NanoCapabilities; + + /// Generate text from a prompt (single-turn). + async fn generate(&self, prompt: &str, config: &NanoSessionConfig) -> Result; + + /// Create a new session for multi-turn conversation. + /// Returns a session identifier. + async fn create_session(&self, config: &NanoSessionConfig) -> Result; + + /// Send a message in an existing session. + async fn send_message(&self, session_id: &str, message: &str) -> Result; + + /// Close/destroy a session. + async fn close_session(&self, session_id: &str) -> Result<(), String>; +} diff --git a/crates/rusty_gemini_nano/src/lib.rs b/crates/rusty_gemini_nano/src/lib.rs new file mode 100644 index 0000000..521ec02 --- /dev/null +++ b/crates/rusty_gemini_nano/src/lib.rs @@ -0,0 +1,17 @@ +//! Gemini Nano (Android Prompt API) local runtime for the Rusty AI SDK. +//! +//! This crate provides a bridge-based integration with Google's Gemini Nano +//! model running on-device via the Android Prompt API. Host applications must +//! implement the [`GeminiNanoBridge`] trait to connect the JNI/Kotlin layer. + +mod bridge; +mod model; +mod provider; +mod session; +mod types; + +pub use bridge::*; +pub use model::*; +pub use provider::*; +pub use session::*; +pub use types::*; diff --git a/crates/rusty_gemini_nano/src/model.rs b/crates/rusty_gemini_nano/src/model.rs new file mode 100644 index 0000000..4f17efe --- /dev/null +++ b/crates/rusty_gemini_nano/src/model.rs @@ -0,0 +1,162 @@ +use std::sync::Arc; + +use async_trait::async_trait; + +use rusty_ai::{ + AiError, AiResult, AiStream, Capability, CapabilitySet, ContentPart, FinishReason, + GenerateOptions, GenerateResult, LanguageModel, Message, Prompt, ResponseMetadata, Role, + SyntheticStreamer, Usage, +}; + +use crate::bridge::GeminiNanoBridge; +use crate::types::NanoSessionConfig; + +/// A language model backed by Gemini Nano running on-device via the Android +/// Prompt API. +/// +/// Streaming is synthetic -- the full response is generated first and then +/// chunked into a stream, because Gemini Nano does not support native +/// streaming on Android. +pub struct GeminiNanoModel { + bridge: Arc, + capabilities: CapabilitySet, +} + +impl GeminiNanoModel { + /// Create a new `GeminiNanoModel` wrapping the given bridge. + pub fn new(bridge: Arc) -> Self { + let capabilities = CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::LocalExecution) + .with(Capability::SessionSupport) + .with(Capability::PlatformNative); + + Self { + bridge, + capabilities, + } + } + + /// Build a [`NanoSessionConfig`] from the generic [`GenerateOptions`]. + fn build_config(options: &GenerateOptions) -> NanoSessionConfig { + NanoSessionConfig { + temperature: options.temperature, + top_k: options.top_k, + max_tokens: options.max_tokens, + } + } + + /// Check availability and return an error if the model is not ready. + async fn ensure_available(&self) -> AiResult<()> { + if !self.bridge.is_available().await { + return Err(AiError::PlatformUnavailable { + platform: "android/gemini_nano".into(), + }); + } + + let state = self.bridge.download_state().await; + match state { + crate::types::ModelDownloadState::Downloaded => Ok(()), + crate::types::ModelDownloadState::NotDownloaded => Err(AiError::ModelUnavailable { + model: "gemini-nano (not downloaded)".into(), + }), + crate::types::ModelDownloadState::Downloading { progress_percent } => { + Err(AiError::ModelUnavailable { + model: format!("gemini-nano (downloading: {progress_percent}%)"), + }) + } + crate::types::ModelDownloadState::Failed { reason } => Err(AiError::BridgeError { + bridge: "gemini_nano".into(), + message: format!("Model download failed: {reason}"), + }), + } + } +} + +#[async_trait] +impl LanguageModel for GeminiNanoModel { + fn model_id(&self) -> &str { + "gemini-nano" + } + + fn provider_id(&self) -> &str { + "gemini_nano" + } + + fn capabilities(&self) -> &CapabilitySet { + &self.capabilities + } + + async fn generate(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + self.ensure_available().await?; + + let config = Self::build_config(&options); + let prompt_text = extract_prompt_text(prompt); + let start = std::time::Instant::now(); + + let text = self + .bridge + .generate(&prompt_text, &config) + .await + .map_err(|e| AiError::BridgeError { + bridge: "gemini_nano".into(), + message: e, + })?; + + let latency_ms = start.elapsed().as_millis() as u64; + + Ok(GenerateResult { + text: Some(text), + tool_calls: Vec::new(), + finish_reason: FinishReason::Stop, + usage: Usage::default(), + metadata: ResponseMetadata { + provider: "gemini_nano".into(), + model: "gemini-nano".into(), + latency_ms: Some(latency_ms), + ..Default::default() + }, + }) + } + + async fn stream(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + let result = self.generate(prompt, options).await?; + let text = result.text.unwrap_or_default(); + Ok(SyntheticStreamer::stream(text, 20)) + } +} + +/// Extract a single prompt string from a [`Prompt`] for the bridge. +fn extract_prompt_text(prompt: Prompt) -> String { + match prompt { + Prompt::Text(text) => text, + Prompt::Messages(messages) => messages_to_text(&messages), + } +} + +/// Concatenate messages into a single string. +fn messages_to_text(messages: &[Message]) -> String { + let mut parts = Vec::new(); + + for msg in messages { + let prefix = match msg.role { + Role::System => "System: ", + Role::User => "", + Role::Assistant => "Assistant: ", + Role::Tool => "Tool: ", + }; + + for part in &msg.content { + if let ContentPart::Text { text } = part { + if prefix.is_empty() { + parts.push(text.clone()); + } else { + parts.push(format!("{prefix}{text}")); + } + } + } + } + + parts.join("\n") +} diff --git a/crates/rusty_gemini_nano/src/provider.rs b/crates/rusty_gemini_nano/src/provider.rs new file mode 100644 index 0000000..c0eb457 --- /dev/null +++ b/crates/rusty_gemini_nano/src/provider.rs @@ -0,0 +1,98 @@ +use std::sync::Arc; + +use rusty_ai::{ + AiError, AiResult, Capability, CapabilitySet, EmbeddingModel, LanguageModel, ModelInfo, + Provider, +}; + +use crate::bridge::GeminiNanoBridge; +use crate::model::GeminiNanoModel; +use crate::session::NanoSession; +use crate::types::{ModelDownloadState, NanoSessionConfig}; + +/// Provider for Gemini Nano on-device inference via the Android Prompt API. +pub struct GeminiNanoProvider { + bridge: Arc, +} + +impl GeminiNanoProvider { + pub fn new(bridge: impl GeminiNanoBridge + 'static) -> Self { + Self { + bridge: Arc::new(bridge), + } + } + + /// Get the Gemini Nano language model. + pub fn model(&self) -> GeminiNanoModel { + GeminiNanoModel::new(self.bridge.clone()) + } + + /// Check if Gemini Nano is available on this device. + pub async fn is_available(&self) -> bool { + self.bridge.is_available().await + } + + /// Get the current download state of the model. + pub async fn download_state(&self) -> ModelDownloadState { + self.bridge.download_state().await + } + + /// Request model download if not already downloaded. + pub async fn request_download(&self) -> AiResult<()> { + self.bridge + .request_download() + .await + .map_err(|e| AiError::BridgeError { + bridge: "gemini_nano".into(), + message: e, + }) + } + + /// Create a new multi-turn session. + pub async fn create_session(&self, config: NanoSessionConfig) -> AiResult { + let session_id = + self.bridge + .create_session(&config) + .await + .map_err(|e| AiError::BridgeError { + bridge: "gemini_nano".into(), + message: e, + })?; + Ok(NanoSession::new(session_id, self.bridge.clone(), config)) + } +} + +impl Provider for GeminiNanoProvider { + fn id(&self) -> &str { + "gemini_nano" + } + + fn name(&self) -> &str { + "Gemini Nano (Android)" + } + + fn language_model(&self, _model_id: &str) -> AiResult> { + Ok(Box::new(self.model())) + } + + fn embedding_model(&self, model_id: &str) -> AiResult> { + Err(AiError::UnsupportedCapability { + capability: "embeddings".into(), + provider: format!("gemini_nano/{model_id}"), + }) + } + + fn available_models(&self) -> Vec { + vec![ModelInfo { + id: "gemini-nano".into(), + provider: "gemini_nano".into(), + display_name: "Gemini Nano (On-Device)".into(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::LocalExecution) + .with(Capability::SessionSupport) + .with(Capability::PlatformNative), + }] + } +} diff --git a/crates/rusty_gemini_nano/src/session.rs b/crates/rusty_gemini_nano/src/session.rs new file mode 100644 index 0000000..d70d78f --- /dev/null +++ b/crates/rusty_gemini_nano/src/session.rs @@ -0,0 +1,59 @@ +use std::sync::Arc; + +use crate::bridge::GeminiNanoBridge; +use crate::types::NanoSessionConfig; + +/// A multi-turn conversation session backed by Gemini Nano. +/// +/// When dropped, the session is automatically closed via a background task. +pub struct NanoSession { + session_id: String, + bridge: Arc, + config: NanoSessionConfig, +} + +impl NanoSession { + pub(crate) fn new( + session_id: String, + bridge: Arc, + config: NanoSessionConfig, + ) -> Self { + Self { + session_id, + bridge, + config, + } + } + + /// Send a message within this session and receive the model's response. + pub async fn send(&self, message: &str) -> Result { + self.bridge + .send_message(&self.session_id, message) + .await + .map_err(|e| rusty_ai::AiError::BridgeError { + bridge: "gemini_nano".into(), + message: e, + }) + } + + /// Returns the identifier for this session. + pub fn session_id(&self) -> &str { + &self.session_id + } + + /// Returns the configuration used to create this session. + pub fn config(&self) -> &NanoSessionConfig { + &self.config + } +} + +impl Drop for NanoSession { + fn drop(&mut self) { + let bridge = self.bridge.clone(); + let session_id = self.session_id.clone(); + // Fire-and-forget cleanup + tokio::spawn(async move { + let _ = bridge.close_session(&session_id).await; + }); + } +} diff --git a/crates/rusty_gemini_nano/src/types.rs b/crates/rusty_gemini_nano/src/types.rs new file mode 100644 index 0000000..079e6c0 --- /dev/null +++ b/crates/rusty_gemini_nano/src/types.rs @@ -0,0 +1,40 @@ +/// The download state of the Gemini Nano model on the device. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum ModelDownloadState { + /// The model has not been downloaded to the device. + NotDownloaded, + /// The model is currently being downloaded. + Downloading { + /// Download progress as a percentage (0-100). + progress_percent: u8, + }, + /// The model has been downloaded and is ready for use. + Downloaded, + /// The model download failed. + Failed { + /// A human-readable reason for the failure. + reason: String, + }, +} + +/// Capabilities exposed by Gemini Nano on the current device. +#[derive(Debug, Clone)] +pub struct NanoCapabilities { + /// Whether text generation is supported. + pub text_generation: bool, + /// Whether summarization is supported. + pub summarization: bool, + /// Whether rewriting is supported. + pub rewriting: bool, +} + +/// Configuration for a Gemini Nano session. +#[derive(Debug, Clone, Default)] +pub struct NanoSessionConfig { + /// Sampling temperature (0.0 - 1.0). + pub temperature: Option, + /// Top-k sampling parameter. + pub top_k: Option, + /// Maximum number of tokens to generate. + pub max_tokens: Option, +} diff --git a/crates/rusty_middleware/Cargo.toml b/crates/rusty_middleware/Cargo.toml new file mode 100644 index 0000000..392483c --- /dev/null +++ b/crates/rusty_middleware/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "rusty_middleware" +version.workspace = true +edition.workspace = true +license.workspace = true +description = "Middleware components for the Rusty AI SDK" + +[dependencies] +rusty_ai = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +tokio = { workspace = true } +tracing = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +chrono = { workspace = true } + +[dev-dependencies] +rusty_testing = { workspace = true } +tokio = { workspace = true } diff --git a/crates/rusty_middleware/src/cache.rs b/crates/rusty_middleware/src/cache.rs new file mode 100644 index 0000000..bb7c442 --- /dev/null +++ b/crates/rusty_middleware/src/cache.rs @@ -0,0 +1,247 @@ +use std::collections::hash_map::DefaultHasher; +use std::collections::HashMap; +use std::hash::{Hash, Hasher}; +use std::sync::{Arc, Mutex}; +use std::time::{Duration, Instant}; + +use async_trait::async_trait; +use rusty_ai::{AiResult, GenerateOptions, GenerateResult, Middleware, MiddlewareNext, Prompt}; + +/// A cached generate result together with its insertion timestamp. +struct CacheEntry { + result: GenerateResult, + inserted_at: Instant, +} + +/// In-memory caching middleware for non-streaming generate calls. +/// +/// Caches responses keyed on a hash of the prompt. Entries expire after +/// the configured TTL. The cache is thread-safe via `Arc>`. +pub struct CacheMiddleware { + cache: Arc>>, + ttl: Duration, +} + +impl CacheMiddleware { + /// Create a new `CacheMiddleware` with the specified time-to-live. + pub fn new(ttl: Duration) -> Self { + Self { + cache: Arc::new(Mutex::new(HashMap::new())), + ttl, + } + } + + /// Compute a deterministic cache key from the prompt and generation options. + /// + /// Returns `None` if the prompt cannot be serialized, in which case the + /// caller should bypass the cache entirely to avoid hash collisions. + /// + /// Both the prompt content and all generation-affecting options are included + /// so that requests with the same prompt but different parameters (temperature, + /// tools, output schema, etc.) are treated as distinct cache entries. + /// The request metadata (which contains a per-request UUID) is excluded. + fn cache_key(prompt: &Prompt, options: &GenerateOptions) -> Option { + let mut hasher = DefaultHasher::new(); + + let prompt_json = serde_json::to_string(prompt) + .map_err(|e| tracing::error!(error = %e, "Failed to serialize prompt for cache key; bypassing cache")) + .ok()?; + prompt_json.hash(&mut hasher); + + // Numeric options — hash the bit pattern to keep f64 deterministic. + options.temperature.map(f64::to_bits).hash(&mut hasher); + options.max_tokens.hash(&mut hasher); + options.top_p.map(f64::to_bits).hash(&mut hasher); + options.top_k.hash(&mut hasher); + options.stop_sequences.hash(&mut hasher); + options + .frequency_penalty + .map(f64::to_bits) + .hash(&mut hasher); + options.presence_penalty.map(f64::to_bits).hash(&mut hasher); + options.seed.hash(&mut hasher); + + // Complex types — serialize to JSON for a stable, content-based hash. + if let Ok(json) = serde_json::to_string(&options.tools) { + json.hash(&mut hasher); + } + if let Ok(json) = serde_json::to_string(&options.tool_choice) { + json.hash(&mut hasher); + } + if let Ok(json) = serde_json::to_string(&options.output_schema) { + json.hash(&mut hasher); + } + + // Enum options without Serialize — use Debug, which is stable for owned enums. + format!("{:?}", options.thinking).hash(&mut hasher); + format!("{:?}", options.reasoning_effort).hash(&mut hasher); + + Some(hasher.finish()) + } +} + +#[async_trait] +impl Middleware for CacheMiddleware { + async fn process( + &self, + prompt: Prompt, + options: GenerateOptions, + next: MiddlewareNext<'_>, + ) -> AiResult { + let Some(key) = Self::cache_key(&prompt, &options) else { + // Prompt could not be serialized; bypass cache to avoid collisions. + return next.run(prompt, options).await; + }; + + // Check cache. + { + let cache = self.cache.lock().expect("cache lock poisoned"); + if let Some(entry) = cache.get(&key) { + if entry.inserted_at.elapsed() < self.ttl { + tracing::debug!(cache_key = key, "cache hit"); + return Ok(entry.result.clone()); + } + } + } + + tracing::debug!(cache_key = key, "cache miss"); + + // Execute the downstream chain. + let result = next.run(prompt, options).await?; + + // Store in cache. + { + let mut cache = self.cache.lock().expect("cache lock poisoned"); + + // Evict expired entries opportunistically. + cache.retain(|_, entry| entry.inserted_at.elapsed() < self.ttl); + + cache.insert( + key, + CacheEntry { + result: result.clone(), + inserted_at: Instant::now(), + }, + ); + } + + Ok(result) + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use rusty_ai::tool::ToolDefinition; + use rusty_ai::{GenerateOptions, Prompt}; + use rusty_testing::{MockLanguageModel, MockResponse}; + + use crate::chain::MiddlewareChain; + + use super::CacheMiddleware; + + fn text_response(s: &str) -> MockResponse { + MockResponse::Text(s.to_owned()) + } + + #[tokio::test] + async fn same_prompt_and_options_hits_cache() { + let model = MockLanguageModel::new("test").with_response(text_response("response-1")); + + let chain = MiddlewareChain::new(model).with(CacheMiddleware::new(Duration::from_secs(60))); + + let prompt = Prompt::from("hello"); + let opts = GenerateOptions::default(); + + let r1 = chain.generate(prompt.clone(), opts.clone()).await.unwrap(); + let r2 = chain.generate(prompt.clone(), opts.clone()).await.unwrap(); + + // Both should return the same cached text; the model was only called once. + assert_eq!(r1.text, r2.text); + assert_eq!(r1.text.as_deref(), Some("response-1")); + } + + #[tokio::test] + async fn different_temperature_is_a_cache_miss() { + let model = MockLanguageModel::new("test") + .with_response(text_response("response-cold")) + .with_response(text_response("response-hot")); + + let chain = MiddlewareChain::new(model).with(CacheMiddleware::new(Duration::from_secs(60))); + + let prompt = Prompt::from("same prompt"); + let cold = chain + .generate( + prompt.clone(), + GenerateOptions::default().with_temperature(0.0), + ) + .await + .unwrap(); + let hot = chain + .generate( + prompt.clone(), + GenerateOptions::default().with_temperature(1.0), + ) + .await + .unwrap(); + + assert_eq!(cold.text.as_deref(), Some("response-cold")); + assert_eq!(hot.text.as_deref(), Some("response-hot")); + } + + #[tokio::test] + async fn different_tools_is_a_cache_miss() { + let model = MockLanguageModel::new("test") + .with_response(text_response("no-tools")) + .with_response(text_response("with-tools")); + + let chain = MiddlewareChain::new(model).with(CacheMiddleware::new(Duration::from_secs(60))); + + let prompt = Prompt::from("same prompt"); + + let r_plain = chain + .generate(prompt.clone(), GenerateOptions::default()) + .await + .unwrap(); + let r_tools = chain + .generate( + prompt.clone(), + GenerateOptions::default().with_tools(vec![ToolDefinition { + name: "search".into(), + description: "search the web".into(), + parameters: serde_json::json!({}), + }]), + ) + .await + .unwrap(); + + assert_eq!(r_plain.text.as_deref(), Some("no-tools")); + assert_eq!(r_tools.text.as_deref(), Some("with-tools")); + } + + #[tokio::test] + async fn expired_entry_is_not_returned() { + let model = MockLanguageModel::new("test") + .with_response(text_response("first")) + .with_response(text_response("second")); + + // TTL of 1ms so the entry expires immediately. + let chain = + MiddlewareChain::new(model).with(CacheMiddleware::new(Duration::from_millis(1))); + + let prompt = Prompt::from("hello"); + let opts = GenerateOptions::default(); + + let r1 = chain.generate(prompt.clone(), opts.clone()).await.unwrap(); + tokio::time::sleep(Duration::from_millis(5)).await; + let r2 = chain.generate(prompt.clone(), opts.clone()).await.unwrap(); + + assert_eq!(r1.text.as_deref(), Some("first")); + assert_eq!( + r2.text.as_deref(), + Some("second"), + "expired entry should trigger a fresh model call" + ); + } +} diff --git a/crates/rusty_middleware/src/chain.rs b/crates/rusty_middleware/src/chain.rs new file mode 100644 index 0000000..462c61e --- /dev/null +++ b/crates/rusty_middleware/src/chain.rs @@ -0,0 +1,54 @@ +use rusty_ai::{ + AiResult, GenerateOptions, GenerateResult, LanguageModel, Middleware, MiddlewareNext, Prompt, +}; + +/// A chain of middleware wrapping a language model. +/// +/// Middlewares are executed in the order they are added (first added = outermost). +/// Each middleware may inspect or modify the prompt, options, or result before +/// and after calling the next element in the chain. +/// +/// # Example +/// +/// ```ignore +/// let chain = MiddlewareChain::new(my_model) +/// .with(LoggingMiddleware::new()) +/// .with(RetryMiddleware::new(RetryConfig::default())); +/// +/// let result = chain.generate(prompt, options).await?; +/// ``` +pub struct MiddlewareChain { + middlewares: Vec>, + model: Box, +} + +impl MiddlewareChain { + /// Create a new chain wrapping the given language model. + pub fn new(model: impl LanguageModel + 'static) -> Self { + Self { + middlewares: Vec::new(), + model: Box::new(model), + } + } + + /// Append a middleware to the chain and return `self` for fluent building. + /// + /// Middleware added first will be executed first (outermost). + pub fn with(mut self, mw: impl Middleware + 'static) -> Self { + self.middlewares.push(Box::new(mw)); + self + } + + /// Execute the middleware chain, producing a [`GenerateResult`]. + pub async fn generate( + &self, + prompt: Prompt, + options: GenerateOptions, + ) -> AiResult { + let next = MiddlewareNext { + middlewares: &self.middlewares, + model: self.model.as_ref(), + }; + next.run(prompt, options).await + } +} diff --git a/crates/rusty_middleware/src/lib.rs b/crates/rusty_middleware/src/lib.rs new file mode 100644 index 0000000..a146b56 --- /dev/null +++ b/crates/rusty_middleware/src/lib.rs @@ -0,0 +1,14 @@ +//! Middleware components for the Rusty AI SDK. +//! +//! Provides reusable middleware implementations that can be composed into +//! chains around any [`rusty_ai::LanguageModel`]. + +mod cache; +mod chain; +mod logging; +mod retry; + +pub use cache::*; +pub use chain::*; +pub use logging::*; +pub use retry::*; diff --git a/crates/rusty_middleware/src/logging.rs b/crates/rusty_middleware/src/logging.rs new file mode 100644 index 0000000..a4535df --- /dev/null +++ b/crates/rusty_middleware/src/logging.rs @@ -0,0 +1,136 @@ +use std::time::Instant; + +use async_trait::async_trait; +use rusty_ai::{AiResult, GenerateOptions, GenerateResult, Middleware, MiddlewareNext, Prompt}; + +/// Middleware that logs request and response details using the `tracing` crate. +pub struct LoggingMiddleware { + level: tracing::Level, +} + +impl LoggingMiddleware { + /// Create a new `LoggingMiddleware` that logs at `INFO` level. + pub fn new() -> Self { + Self { + level: tracing::Level::INFO, + } + } + + /// Create a new `LoggingMiddleware` that logs at the specified level. + pub fn with_level(level: tracing::Level) -> Self { + Self { level } + } + + /// Produce a short human-readable summary of the prompt. + fn summarize_prompt(prompt: &Prompt) -> String { + match prompt { + Prompt::Text(t) => { + let preview: String = t.chars().take(80).collect(); + if t.len() > 80 { + format!("\"{}...\" ({} chars)", preview, t.len()) + } else { + format!("\"{}\"", preview) + } + } + Prompt::Messages(msgs) => { + format!("{} message(s)", msgs.len()) + } + } + } +} + +impl Default for LoggingMiddleware { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl Middleware for LoggingMiddleware { + async fn process( + &self, + prompt: Prompt, + options: GenerateOptions, + next: MiddlewareNext<'_>, + ) -> AiResult { + let summary = Self::summarize_prompt(&prompt); + let temp = options.temperature; + let max_tokens = options.max_tokens; + + match self.level { + tracing::Level::TRACE => { + tracing::trace!( + prompt = %summary, + temperature = ?temp, + max_tokens = ?max_tokens, + "generate request" + ); + } + tracing::Level::DEBUG => { + tracing::debug!( + prompt = %summary, + temperature = ?temp, + max_tokens = ?max_tokens, + "generate request" + ); + } + tracing::Level::WARN => { + tracing::warn!( + prompt = %summary, + temperature = ?temp, + max_tokens = ?max_tokens, + "generate request" + ); + } + tracing::Level::ERROR => { + tracing::error!( + prompt = %summary, + temperature = ?temp, + max_tokens = ?max_tokens, + "generate request" + ); + } + _ => { + tracing::info!( + prompt = %summary, + temperature = ?temp, + max_tokens = ?max_tokens, + "generate request" + ); + } + } + + let start = Instant::now(); + let result = next.run(prompt, options).await; + let elapsed = start.elapsed(); + + match &result { + Ok(res) => { + let prompt_tokens = res.usage.prompt_tokens; + let completion_tokens = res.usage.completion_tokens; + let finish_reason = &res.finish_reason; + let latency_ms = elapsed.as_millis() as u64; + + match self.level { + tracing::Level::TRACE => tracing::trace!(latency_ms, prompt_tokens = ?prompt_tokens, completion_tokens = ?completion_tokens, finish_reason = ?finish_reason, "generate response"), + tracing::Level::DEBUG => tracing::debug!(latency_ms, prompt_tokens = ?prompt_tokens, completion_tokens = ?completion_tokens, finish_reason = ?finish_reason, "generate response"), + tracing::Level::WARN => tracing::warn!(latency_ms, prompt_tokens = ?prompt_tokens, completion_tokens = ?completion_tokens, finish_reason = ?finish_reason, "generate response"), + tracing::Level::ERROR => tracing::error!(latency_ms, prompt_tokens = ?prompt_tokens, completion_tokens = ?completion_tokens, finish_reason = ?finish_reason, "generate response"), + _ => tracing::info!(latency_ms, prompt_tokens = ?prompt_tokens, completion_tokens = ?completion_tokens, finish_reason = ?finish_reason, "generate response"), + } + } + Err(e) => { + let latency_ms = elapsed.as_millis() as u64; + // Use ?e (Debug) to preserve the full error source chain. + match self.level { + tracing::Level::TRACE => tracing::trace!(latency_ms, error = ?e, "generate failed"), + tracing::Level::DEBUG => tracing::debug!(latency_ms, error = ?e, "generate failed"), + tracing::Level::WARN => tracing::warn!(latency_ms, error = ?e, "generate failed"), + _ => tracing::error!(latency_ms, error = ?e, "generate failed"), + } + } + } + + result + } +} diff --git a/crates/rusty_middleware/src/retry.rs b/crates/rusty_middleware/src/retry.rs new file mode 100644 index 0000000..f88ccaf --- /dev/null +++ b/crates/rusty_middleware/src/retry.rs @@ -0,0 +1,100 @@ +use async_trait::async_trait; +use rusty_ai::{ + AiError, AiResult, GenerateOptions, GenerateResult, Middleware, MiddlewareNext, Prompt, +}; + +/// Configuration for retry behaviour. +#[derive(Debug, Clone)] +pub struct RetryConfig { + /// Maximum number of retry attempts (not counting the initial request). + pub max_retries: u32, + /// Initial delay in milliseconds before the first retry. + pub initial_delay_ms: u64, + /// Multiplier applied to the delay after each retry. + pub backoff_multiplier: f64, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + max_retries: 3, + initial_delay_ms: 1000, + backoff_multiplier: 2.0, + } + } +} + +/// Middleware that retries failed requests with exponential backoff. +/// +/// Only errors deemed "retryable" (rate limits, timeouts, transport failures) +/// trigger a retry. All other errors are propagated immediately. +pub struct RetryMiddleware { + config: RetryConfig, +} + +impl RetryMiddleware { + /// Create a new `RetryMiddleware` with the given configuration. + pub fn new(config: RetryConfig) -> Self { + Self { config } + } + + /// Returns `true` for error variants that are safe to retry. + pub fn default_retryable(error: &AiError) -> bool { + matches!( + error, + AiError::RateLimit { .. } | AiError::Timeout | AiError::Transport { .. } + ) + } +} + +#[async_trait] +impl Middleware for RetryMiddleware { + async fn process( + &self, + prompt: Prompt, + options: GenerateOptions, + next: MiddlewareNext<'_>, + ) -> AiResult { + let mut delay_ms = self.config.initial_delay_ms; + + // Save references so we can rebuild MiddlewareNext after consumption. + let remaining_middlewares = next.middlewares; + let model = next.model; + + let mut last_error: Option = None; + + for attempt in 1..=(self.config.max_retries + 1) { + if attempt > 1 { + tracing::info!(attempt, delay_ms, "retrying request"); + tokio::time::sleep(std::time::Duration::from_millis(delay_ms)).await; + delay_ms = (delay_ms as f64 * self.config.backoff_multiplier) as u64; + } + + let next_handle = MiddlewareNext { + middlewares: remaining_middlewares, + model, + }; + + match next_handle.run(prompt.clone(), options.clone()).await { + Ok(result) => return Ok(result), + Err(e) => { + if !Self::default_retryable(&e) { + return Err(e); + } + tracing::warn!( + attempt, + max_retries = self.config.max_retries, + error = %e, + "retryable error encountered" + ); + last_error = Some(e); + } + } + } + + Err(last_error.unwrap_or_else(|| AiError::Transport { + message: "Retry loop exhausted without capturing an error (this is a bug in RetryMiddleware)".to_string(), + source: None, + })) + } +} diff --git a/crates/rusty_ollama/Cargo.toml b/crates/rusty_ollama/Cargo.toml new file mode 100644 index 0000000..c148389 --- /dev/null +++ b/crates/rusty_ollama/Cargo.toml @@ -0,0 +1,20 @@ +[package] +name = "rusty_ollama" +version.workspace = true +edition.workspace = true +license.workspace = true +description = "Ollama local runtime provider for the Rusty AI SDK" + +[dependencies] +rusty_ai = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true } +reqwest = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tracing = { workspace = true } +uuid = { workspace = true } +chrono = { workspace = true } +bytes = { workspace = true } diff --git a/crates/rusty_ollama/src/api_types.rs b/crates/rusty_ollama/src/api_types.rs new file mode 100644 index 0000000..04bc738 --- /dev/null +++ b/crates/rusty_ollama/src/api_types.rs @@ -0,0 +1,121 @@ +use serde::{Deserialize, Serialize}; + +// ── Chat request ── + +#[derive(Debug, Serialize)] +pub(crate) struct OllamaChatRequest { + pub model: String, + pub messages: Vec, + pub stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub options: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub format: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub think: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct OllamaMessage { + pub role: String, + pub content: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub images: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub thinking: Option, +} + +#[derive(Debug, Serialize)] +pub(crate) struct OllamaOptions { + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub num_predict: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_k: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, +} + +// ── Tools ── + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct OllamaTool { + #[serde(rename = "type")] + pub tool_type: String, + pub function: OllamaFunction, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct OllamaFunction { + pub name: String, + pub description: String, + pub parameters: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct OllamaToolCall { + pub function: OllamaFunctionCall, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct OllamaFunctionCall { + pub name: String, + pub arguments: serde_json::Value, +} + +// ── Chat response ── + +#[derive(Debug, Deserialize)] +pub(crate) struct OllamaChatResponse { + #[allow(dead_code)] + pub model: String, + pub message: OllamaMessage, + pub done: bool, + #[serde(default)] + pub done_reason: Option, + #[serde(default)] + pub eval_count: Option, + #[serde(default)] + pub prompt_eval_count: Option, + /// Thinking content from reasoning models (e.g. deepseek-r1, qwen3 with think=true) + #[serde(default)] + pub thinking: Option, +} + +// ── Embedding ── + +#[derive(Debug, Serialize)] +pub(crate) struct OllamaEmbedRequest { + pub model: String, + pub input: Vec, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct OllamaEmbedResponse { + pub embeddings: Vec>, +} + +// ── List models ── + +#[derive(Debug, Deserialize)] +pub(crate) struct OllamaListResponse { + pub models: Vec, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct OllamaModelEntry { + pub name: String, +} diff --git a/crates/rusty_ollama/src/convert.rs b/crates/rusty_ollama/src/convert.rs new file mode 100644 index 0000000..aa75172 --- /dev/null +++ b/crates/rusty_ollama/src/convert.rs @@ -0,0 +1,141 @@ +use rusty_ai::{ + ContentPart, GenerateOptions, ImageData, Message, Role, ToolCallRequest, ToolDefinition, +}; + +use crate::api_types::{ + OllamaFunction, OllamaFunctionCall, OllamaMessage, OllamaOptions, OllamaTool, OllamaToolCall, +}; + +/// Convert a slice of `rusty_ai::Message` into Ollama messages. +pub(crate) fn convert_messages(messages: &[Message]) -> Vec { + messages.iter().map(convert_message).collect() +} + +fn convert_message(msg: &Message) -> OllamaMessage { + let role = match msg.role { + Role::System => "system", + Role::User => "user", + Role::Assistant => "assistant", + Role::Tool => "tool", + }; + + let mut text_parts: Vec = Vec::new(); + let mut images: Vec = Vec::new(); + let mut tool_calls: Vec = Vec::new(); + + for part in &msg.content { + match part { + ContentPart::Text { text } => { + text_parts.push(text.clone()); + } + ContentPart::Image { data } => { + match data { + ImageData::Base64 { data, .. } => { + images.push(data.clone()); + } + ImageData::Url { url, .. } => { + // Ollama only supports base64 images natively. Pass the URL as text + // as a fallback; real usage should provide base64-encoded images. + text_parts.push(format!("[image: {}]", url)); + } + } + } + ContentPart::ToolCall { call } => { + tool_calls.push(OllamaToolCall { + function: OllamaFunctionCall { + name: call.name.clone(), + arguments: call.arguments.clone(), + }, + }); + } + ContentPart::ToolResult { result } => { + text_parts.push(result.content.clone()); + } + ContentPart::File { .. } => { + // Ollama does not support file parts; skip. + } + } + } + + OllamaMessage { + role: role.to_string(), + content: text_parts.join("\n"), + images: if images.is_empty() { + None + } else { + Some(images) + }, + tool_calls: if tool_calls.is_empty() { + None + } else { + Some(tool_calls) + }, + thinking: None, + } +} + +/// Convert `GenerateOptions` into `OllamaOptions` (returns `None` if nothing is set). +pub(crate) fn convert_options(opts: &GenerateOptions) -> Option { + let o = OllamaOptions { + temperature: opts.temperature, + num_predict: opts.max_tokens, + top_p: opts.top_p, + top_k: opts.top_k, + stop: if opts.stop_sequences.is_empty() { + None + } else { + Some(opts.stop_sequences.clone()) + }, + seed: opts.seed, + frequency_penalty: opts.frequency_penalty, + presence_penalty: opts.presence_penalty, + }; + + // Check if everything is None / empty. + if o.temperature.is_none() + && o.num_predict.is_none() + && o.top_p.is_none() + && o.top_k.is_none() + && o.stop.is_none() + && o.seed.is_none() + && o.frequency_penalty.is_none() + && o.presence_penalty.is_none() + { + return None; + } + Some(o) +} + +/// Convert `rusty_ai::ToolDefinition` to Ollama tools. +pub(crate) fn convert_tools(tools: Option<&[ToolDefinition]>) -> Option> { + let tools = tools?; + if tools.is_empty() { + return None; + } + Some( + tools + .iter() + .map(|t| OllamaTool { + tool_type: "function".to_string(), + function: OllamaFunction { + name: t.name.clone(), + description: t.description.clone(), + parameters: t.parameters.clone(), + }, + }) + .collect(), + ) +} + +/// Convert Ollama tool calls to `rusty_ai::ToolCallRequest`. +pub(crate) fn convert_tool_calls(calls: &[OllamaToolCall]) -> Vec { + calls + .iter() + .enumerate() + .map(|(i, tc)| ToolCallRequest { + id: format!("call_{}", i), + name: tc.function.name.clone(), + arguments: tc.function.arguments.clone(), + }) + .collect() +} diff --git a/crates/rusty_ollama/src/lib.rs b/crates/rusty_ollama/src/lib.rs new file mode 100644 index 0000000..0de7949 --- /dev/null +++ b/crates/rusty_ollama/src/lib.rs @@ -0,0 +1,9 @@ +//! Ollama local runtime provider for the Rusty AI SDK. + +mod api_types; +mod convert; +mod model; +mod provider; + +pub use model::*; +pub use provider::*; diff --git a/crates/rusty_ollama/src/model.rs b/crates/rusty_ollama/src/model.rs new file mode 100644 index 0000000..61bd580 --- /dev/null +++ b/crates/rusty_ollama/src/model.rs @@ -0,0 +1,424 @@ +use async_trait::async_trait; +use futures::stream::{self, StreamExt}; + +use rusty_ai::{ + AiError, AiResult, AiStream, Capability, CapabilitySet, EmbeddingModel, EmbeddingResult, + FinishReason, GenerateOptions, GenerateResult, LanguageModel, ResponseMetadata, StreamEvent, + Usage, +}; + +use crate::api_types::{ + OllamaChatRequest, OllamaChatResponse, OllamaEmbedRequest, OllamaEmbedResponse, +}; +use crate::convert; + +/// An Ollama model that can perform chat completions and embeddings against a +/// local (or remote) Ollama server. +pub struct OllamaModel { + base_url: String, + model_id: String, + capabilities: CapabilitySet, + client: reqwest::Client, +} + +impl OllamaModel { + /// Create a new `OllamaModel` pointing at `http://localhost:11434`. + pub fn new(model_id: &str) -> Self { + Self { + base_url: "http://localhost:11434".to_string(), + model_id: model_id.to_string(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::Streaming) + .with(Capability::ToolCalling) + .with(Capability::ImageInput) + .with(Capability::Embeddings) + .with(Capability::LocalExecution), + client: reqwest::Client::new(), + } + } + + /// Override the base URL (e.g. `http://myserver:11434`). + pub fn with_base_url(mut self, url: &str) -> Self { + self.base_url = url.trim_end_matches('/').to_string(); + self + } + + /// Build the chat request body from a prompt and options. + fn build_chat_request( + &self, + prompt: rusty_ai::Prompt, + options: &GenerateOptions, + stream: bool, + ) -> OllamaChatRequest { + let conversation = prompt.into_messages(); + let messages = convert::convert_messages(&conversation); + + OllamaChatRequest { + model: self.model_id.clone(), + messages, + stream, + options: convert::convert_options(options), + format: options.output_schema.as_ref().map(|s| s.as_value().clone()), + tools: convert::convert_tools(options.tools.as_deref()), + think: options.thinking.as_ref().map(|_| true), + } + } + + /// Parse token usage from an Ollama chat response. + fn parse_usage(resp: &OllamaChatResponse) -> Usage { + Usage { + prompt_tokens: resp.prompt_eval_count, + completion_tokens: resp.eval_count, + total_tokens: match (resp.prompt_eval_count, resp.eval_count) { + (Some(p), Some(c)) => Some(p + c), + _ => None, + }, + } + } + + /// Determine the finish reason from an Ollama response. + fn parse_finish_reason(resp: &OllamaChatResponse, has_tool_calls: bool) -> FinishReason { + if has_tool_calls { + FinishReason::ToolCall + } else if resp.done_reason.as_deref() == Some("length") { + FinishReason::Length + } else { + FinishReason::Stop + } + } + + /// Send a non-streaming chat request and parse the response. + async fn do_generate(&self, request: OllamaChatRequest) -> AiResult { + let url = format!("{}/api/chat", self.base_url); + + let resp = self + .client + .post(&url) + .json(&request) + .send() + .await + .map_err(|e| AiError::Transport { + message: e.to_string(), + source: Some(Box::new(e)), + })?; + + let status = resp.status(); + if !status.is_success() { + let status_code = status.as_u16(); + let body = match resp.text().await { + Ok(body) => body, + Err(e) => { + tracing::warn!(status = status_code, error = %e, "Failed to read Ollama error response body"); + format!("") + } + }; + return Err(AiError::ProviderError { + provider: "ollama".to_string(), + status: Some(status_code), + message: body, + }); + } + + let chat_resp: OllamaChatResponse = resp + .json() + .await + .map_err(|e| AiError::Serialization(e.to_string()))?; + + let tool_calls = chat_resp + .message + .tool_calls + .as_deref() + .map(convert::convert_tool_calls) + .unwrap_or_default(); + + let finish_reason = Self::parse_finish_reason(&chat_resp, !tool_calls.is_empty()); + + let text = if chat_resp.message.content.is_empty() { + None + } else { + Some(chat_resp.message.content.clone()) + }; + + let usage = Self::parse_usage(&chat_resp); + + Ok(GenerateResult { + text, + tool_calls, + finish_reason, + usage, + metadata: ResponseMetadata { + provider: "ollama".to_string(), + model: self.model_id.clone(), + ..ResponseMetadata::default() + }, + }) + } +} + +#[async_trait] +impl LanguageModel for OllamaModel { + fn model_id(&self) -> &str { + &self.model_id + } + + fn provider_id(&self) -> &str { + "ollama" + } + + fn capabilities(&self) -> &CapabilitySet { + &self.capabilities + } + + async fn generate( + &self, + prompt: rusty_ai::Prompt, + options: GenerateOptions, + ) -> AiResult { + let request = self.build_chat_request(prompt, &options, false); + self.do_generate(request).await + } + + async fn stream( + &self, + prompt: rusty_ai::Prompt, + options: GenerateOptions, + ) -> AiResult { + let request = self.build_chat_request(prompt, &options, true); + let url = format!("{}/api/chat", self.base_url); + + let resp = self + .client + .post(&url) + .json(&request) + .send() + .await + .map_err(|e| AiError::Transport { + message: e.to_string(), + source: Some(Box::new(e)), + })?; + + let status = resp.status(); + if !status.is_success() { + let status_code = status.as_u16(); + let body = match resp.text().await { + Ok(body) => body, + Err(e) => { + tracing::warn!(status = status_code, error = %e, "Failed to read Ollama error response body"); + format!("") + } + }; + return Err(AiError::ProviderError { + provider: "ollama".to_string(), + status: Some(status_code), + message: body, + }); + } + + // Ollama streams NDJSON: one JSON object per line. + let byte_stream = resp.bytes_stream(); + Ok(build_ndjson_stream(byte_stream)) + } +} + +/// Build an `AiStream` from Ollama's NDJSON byte stream. +/// +/// Ollama emits one JSON object per line. We accumulate bytes until we see a +/// newline, then parse each complete line as an `OllamaChatResponse`. +fn build_ndjson_stream( + byte_stream: impl futures::Stream> + Send + 'static, +) -> AiStream { + let message_id = uuid::Uuid::new_v4().to_string(); + + // State: (buffer of partial bytes, whether we already sent MessageStart) + let event_stream = byte_stream + .scan( + (Vec::::new(), false, message_id), + |state, chunk_result: Result| { + let (ref mut buf, ref mut sent_start, ref message_id) = *state; + + let chunk = match chunk_result { + Ok(c) => c, + Err(e) => { + return std::future::ready(Some(vec![Err(AiError::StreamError { + message: e.to_string(), + })])); + } + }; + + buf.extend_from_slice(&chunk); + + let mut events: Vec> = Vec::new(); + + // Process all complete lines in the buffer. + while let Some(newline_pos) = buf.iter().position(|&b| b == b'\n') { + let line_bytes: Vec = buf.drain(..=newline_pos).collect(); + let line = match std::str::from_utf8(&line_bytes) { + Ok(s) => s.trim().to_string(), + Err(_) => continue, + }; + + if line.is_empty() { + continue; + } + + let resp: OllamaChatResponse = match serde_json::from_str(&line) { + Ok(r) => r, + Err(e) => { + tracing::error!( + line = %line, + error = %e, + "Failed to parse Ollama stream chunk; terminating stream" + ); + events.push(Err(AiError::StreamError { + message: format!("Unparseable NDJSON chunk from Ollama: {e}"), + })); + return std::future::ready(Some(events)); + } + }; + + if !*sent_start { + *sent_start = true; + events.push(Ok(StreamEvent::MessageStart { + message_id: message_id.clone(), + })); + } + + // Emit thinking tokens if present (reasoning models) + if let Some(ref thinking) = resp.thinking { + if !thinking.is_empty() { + events.push(Ok(StreamEvent::ThinkingDelta { + delta: thinking.clone(), + })); + } + } + + if !resp.message.content.is_empty() { + events.push(Ok(StreamEvent::TextDelta { + delta: resp.message.content.clone(), + })); + } + + if let Some(ref calls) = resp.message.tool_calls { + for (i, tc) in calls.iter().enumerate() { + let call_id = format!("call_{}", i); + events.push(Ok(StreamEvent::ToolCallStart { + call_id: call_id.clone(), + tool_name: tc.function.name.clone(), + })); + events.push(Ok(StreamEvent::ToolCallEnd { + call_id, + arguments: tc.function.arguments.clone(), + })); + } + } + + if resp.done { + let usage = Usage { + prompt_tokens: resp.prompt_eval_count, + completion_tokens: resp.eval_count, + total_tokens: match (resp.prompt_eval_count, resp.eval_count) { + (Some(p), Some(c)) => Some(p + c), + _ => None, + }, + }; + + let has_tools = resp + .message + .tool_calls + .as_ref() + .is_some_and(|v| !v.is_empty()); + + let finish_reason = if has_tools { + FinishReason::ToolCall + } else if resp.done_reason.as_deref() == Some("length") { + FinishReason::Length + } else { + FinishReason::Stop + }; + + events.push(Ok(StreamEvent::MessageEnd { + finish_reason, + usage: Some(usage), + })); + } + } + + std::future::ready(Some(events)) + }, + ) + .flat_map(stream::iter); + + Box::pin(event_stream) +} + +#[async_trait] +impl EmbeddingModel for OllamaModel { + fn model_id(&self) -> &str { + &self.model_id + } + + fn provider_id(&self) -> &str { + "ollama" + } + + fn dimensions(&self) -> Option { + // Ollama does not report embedding dimensions upfront. + None + } + + async fn embed(&self, texts: Vec) -> AiResult { + let url = format!("{}/api/embed", self.base_url); + + let request = OllamaEmbedRequest { + model: self.model_id.clone(), + input: texts, + }; + + let resp = self + .client + .post(&url) + .json(&request) + .send() + .await + .map_err(|e| AiError::Transport { + message: e.to_string(), + source: Some(Box::new(e)), + })?; + + let status = resp.status(); + if !status.is_success() { + let status_code = status.as_u16(); + let body = match resp.text().await { + Ok(body) => body, + Err(e) => { + tracing::warn!(status = status_code, error = %e, "Failed to read Ollama error response body"); + format!("") + } + }; + return Err(AiError::ProviderError { + provider: "ollama".to_string(), + status: Some(status_code), + message: body, + }); + } + + let embed_resp: OllamaEmbedResponse = resp + .json() + .await + .map_err(|e| AiError::Serialization(e.to_string()))?; + + // Convert f32 -> f64 to match the trait signature. + let embeddings = embed_resp + .embeddings + .into_iter() + .map(|v| v.into_iter().map(|x| x as f64).collect()) + .collect(); + + Ok(EmbeddingResult { + embeddings, + usage: Usage::default(), + }) + } +} diff --git a/crates/rusty_ollama/src/provider.rs b/crates/rusty_ollama/src/provider.rs new file mode 100644 index 0000000..aa5a9d7 --- /dev/null +++ b/crates/rusty_ollama/src/provider.rs @@ -0,0 +1,97 @@ +use rusty_ai::{AiError, AiResult, ModelInfo, Provider}; + +use crate::api_types::OllamaListResponse; +use crate::model::OllamaModel; + +/// A provider handle for a local Ollama server. +/// +/// Use this to discover available models and create `OllamaModel` instances. +pub struct OllamaProvider { + base_url: String, + client: reqwest::Client, +} + +impl OllamaProvider { + /// Create a provider pointing at the default Ollama address + /// (`http://localhost:11434`). + pub fn new() -> Self { + Self { + base_url: "http://localhost:11434".to_string(), + client: reqwest::Client::new(), + } + } + + /// Override the base URL. + pub fn with_base_url(mut self, url: &str) -> Self { + self.base_url = url.trim_end_matches('/').to_string(); + self + } + + /// Create an `OllamaModel` for the given model identifier (e.g. + /// `"llama3"`, `"mistral"`, `"nomic-embed-text"`). + pub fn model(&self, id: &str) -> OllamaModel { + OllamaModel::new(id).with_base_url(&self.base_url) + } + + /// List models that are currently available on the Ollama server. + pub async fn list_models(&self) -> AiResult> { + let url = format!("{}/api/tags", self.base_url); + + let resp = self + .client + .get(&url) + .send() + .await + .map_err(|e| AiError::Transport { + message: e.to_string(), + source: Some(Box::new(e)), + })?; + + let status = resp.status(); + if !status.is_success() { + let body = resp.text().await.unwrap_or_default(); + return Err(AiError::ProviderError { + provider: "ollama".to_string(), + status: Some(status.as_u16()), + message: body, + }); + } + + let list: OllamaListResponse = resp + .json() + .await + .map_err(|e| AiError::Serialization(e.to_string()))?; + + Ok(list.models.into_iter().map(|m| m.name).collect()) + } +} + +impl Default for OllamaProvider { + fn default() -> Self { + Self::new() + } +} + +impl Provider for OllamaProvider { + fn id(&self) -> &str { + "ollama" + } + + fn name(&self) -> &str { + "Ollama" + } + + fn language_model(&self, model_id: &str) -> AiResult> { + Ok(Box::new(self.model(model_id))) + } + + fn embedding_model(&self, model_id: &str) -> AiResult> { + Ok(Box::new(self.model(model_id))) + } + + fn available_models(&self) -> Vec { + // Ollama's model list requires an async call; for the sync trait we + // return an empty list. Use `list_models()` for the async version. + Vec::new() + } +} diff --git a/crates/rusty_openai_compatible/Cargo.toml b/crates/rusty_openai_compatible/Cargo.toml new file mode 100644 index 0000000..353ddf3 --- /dev/null +++ b/crates/rusty_openai_compatible/Cargo.toml @@ -0,0 +1,22 @@ +[package] +name = "rusty_openai_compatible" +version.workspace = true +edition.workspace = true +license.workspace = true +description = "OpenAI-compatible API adapter for the Rusty AI SDK" + +[dependencies] +rusty_ai = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true, features = ["sync"] } +reqwest = { workspace = true } +reqwest-eventsource = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tracing = { workspace = true } +uuid = { workspace = true } +chrono = { workspace = true } +secrecy = { workspace = true } +url = { workspace = true } diff --git a/crates/rusty_openai_compatible/src/api_types.rs b/crates/rusty_openai_compatible/src/api_types.rs new file mode 100644 index 0000000..29c756a --- /dev/null +++ b/crates/rusty_openai_compatible/src/api_types.rs @@ -0,0 +1,157 @@ +use serde::{Deserialize, Serialize}; + +// --------------------------------------------------------------------------- +// Request types +// --------------------------------------------------------------------------- + +#[derive(Debug, Serialize)] +pub(crate) struct ChatCompletionRequest { + pub model: String, + pub messages: Vec, + #[serde(skip_serializing_if = "Option::is_none")] + pub temperature: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub max_tokens: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub top_p: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub frequency_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub presence_penalty: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub seed: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub stop: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tools: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_choice: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub response_format: Option, + #[serde(skip_serializing_if = "std::ops::Not::not")] + pub stream: bool, + #[serde(skip_serializing_if = "Option::is_none")] + pub stream_options: Option, +} + +#[derive(Debug, Serialize)] +pub(crate) struct StreamOptions { + pub include_usage: bool, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct ChatMessage { + pub role: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub content: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub name: Option, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_calls: Option>, + #[serde(skip_serializing_if = "Option::is_none")] + pub tool_call_id: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct ChatTool { + #[serde(rename = "type")] + pub tool_type: String, + pub function: ChatFunction, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct ChatFunction { + pub name: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub description: Option, + pub parameters: serde_json::Value, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct ChatToolCall { + pub id: String, + #[serde(rename = "type")] + pub call_type: String, + pub function: ChatFunctionCall, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct ChatFunctionCall { + pub name: String, + pub arguments: String, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct ResponseFormat { + #[serde(rename = "type")] + pub format_type: String, + #[serde(skip_serializing_if = "Option::is_none")] + pub json_schema: Option, +} + +#[derive(Debug, Clone, Serialize, Deserialize)] +pub(crate) struct JsonSchemaFormat { + pub name: String, + pub schema: serde_json::Value, + #[serde(skip_serializing_if = "Option::is_none")] + pub strict: Option, +} + +// --------------------------------------------------------------------------- +// Response types +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize)] +#[allow(dead_code)] +pub(crate) struct ChatCompletionResponse { + pub id: String, + pub choices: Vec, + pub usage: Option, + pub model: String, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct ChatChoice { + #[allow(dead_code)] + pub index: u32, + pub message: Option, + pub delta: Option, + pub finish_reason: Option, +} + +#[derive(Debug, Clone, Deserialize)] +pub(crate) struct ApiUsage { + pub prompt_tokens: u64, + pub completion_tokens: u64, + pub total_tokens: u64, +} + +// --------------------------------------------------------------------------- +// Streaming chunk +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize)] +pub(crate) struct ChatCompletionChunk { + pub id: String, + pub choices: Vec, + pub usage: Option, +} + +// --------------------------------------------------------------------------- +// Error response body +// --------------------------------------------------------------------------- + +#[derive(Debug, Deserialize)] +pub(crate) struct ApiErrorResponse { + pub error: ApiErrorBody, +} + +#[derive(Debug, Deserialize)] +pub(crate) struct ApiErrorBody { + pub message: String, + #[serde(rename = "type")] + #[allow(dead_code)] + pub error_type: Option, + #[allow(dead_code)] + pub code: Option, +} diff --git a/crates/rusty_openai_compatible/src/config.rs b/crates/rusty_openai_compatible/src/config.rs new file mode 100644 index 0000000..74ac58c --- /dev/null +++ b/crates/rusty_openai_compatible/src/config.rs @@ -0,0 +1,59 @@ +use secrecy::SecretString; + +/// Configuration for connecting to an OpenAI-compatible API. +#[derive(Clone)] +pub struct OpenAiCompatibleConfig { + pub(crate) base_url: String, + pub(crate) api_key: SecretString, + pub(crate) org_id: Option, + pub(crate) default_headers: Vec<(String, String)>, +} + +impl OpenAiCompatibleConfig { + /// Create a new configuration with a custom base URL. + pub fn new(base_url: impl Into, api_key: impl Into) -> Self { + Self { + base_url: base_url.into(), + api_key: SecretString::from(api_key.into()), + org_id: None, + default_headers: Vec::new(), + } + } + + /// Create a configuration pre-pointed at the official OpenAI API. + pub fn openai(api_key: impl Into) -> Self { + Self::new("https://api.openai.com/v1", api_key) + } + + /// Set an optional organization ID header. + pub fn with_org(mut self, org_id: impl Into) -> Self { + self.org_id = Some(org_id.into()); + self + } + + /// Add a default header that will be sent with every request. + pub fn with_header(mut self, name: impl Into, value: impl Into) -> Self { + self.default_headers.push((name.into(), value.into())); + self + } + + /// The base URL (e.g. `https://api.openai.com/v1`). + pub fn base_url(&self) -> &str { + &self.base_url + } + + /// Expose the API key for building HTTP headers. + pub fn api_key(&self) -> &SecretString { + &self.api_key + } +} + +impl std::fmt::Debug for OpenAiCompatibleConfig { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("OpenAiCompatibleConfig") + .field("base_url", &self.base_url) + .field("api_key", &"[REDACTED]") + .field("org_id", &self.org_id) + .finish() + } +} diff --git a/crates/rusty_openai_compatible/src/convert.rs b/crates/rusty_openai_compatible/src/convert.rs new file mode 100644 index 0000000..39f5fa2 --- /dev/null +++ b/crates/rusty_openai_compatible/src/convert.rs @@ -0,0 +1,288 @@ +use rusty_ai::{ + ContentPart, FinishReason, GenerateOptions, GenerateResult, ImageData, Message, Prompt, + ResponseMetadata, Role, ToolCallRequest, ToolChoice, Usage, +}; + +use crate::api_types::*; + +// --------------------------------------------------------------------------- +// Prompt -> API messages +// --------------------------------------------------------------------------- + +pub(crate) fn prompt_to_messages(prompt: &Prompt) -> Vec { + match prompt { + Prompt::Text(text) => vec![ChatMessage { + role: "user".to_string(), + content: Some(serde_json::Value::String(text.clone())), + name: None, + tool_calls: None, + tool_call_id: None, + }], + Prompt::Messages(msgs) => msgs.iter().map(message_to_chat).collect(), + } +} + +fn message_to_chat(msg: &Message) -> ChatMessage { + let role = match msg.role { + Role::System => "system", + Role::User => "user", + Role::Assistant => "assistant", + Role::Tool => "tool", + }; + + // Collect tool calls from content parts (for assistant messages). + let mut tool_calls: Vec = Vec::new(); + // For tool-result messages, extract the call_id. + let mut tool_call_id: Option = None; + + let content_parts: Vec = msg + .content + .iter() + .filter_map(|part| match part { + ContentPart::Text { text } => Some(serde_json::json!({ "type": "text", "text": text })), + ContentPart::Image { data } => Some(image_to_json(data)), + ContentPart::ToolCall { call } => { + tool_calls.push(ChatToolCall { + id: call.id.clone(), + call_type: "function".to_string(), + function: ChatFunctionCall { + name: call.name.clone(), + arguments: call.arguments.to_string(), + }, + }); + None + } + ContentPart::ToolResult { result } => { + tool_call_id = Some(result.call_id.clone()); + Some(serde_json::Value::String(result.content.clone())) + } + ContentPart::File { .. } => None, // files not supported by OpenAI chat API + }) + .collect(); + + // Build the content field. + let content = if msg.role == Role::Tool { + // Tool messages use a plain string content. + content_parts.into_iter().next() + } else if content_parts.len() == 1 { + // Single text part -> use a plain string for compatibility. + if let Some(serde_json::Value::Object(ref obj)) = content_parts.first() { + if obj.get("type").and_then(|v| v.as_str()) == Some("text") { + obj.get("text").cloned() + } else { + Some(serde_json::Value::Array(content_parts)) + } + } else { + Some(serde_json::Value::Array(content_parts)) + } + } else if content_parts.is_empty() { + None + } else { + Some(serde_json::Value::Array(content_parts)) + }; + + ChatMessage { + role: role.to_string(), + content, + name: msg.name.clone(), + tool_calls: if tool_calls.is_empty() { + None + } else { + Some(tool_calls) + }, + tool_call_id, + } +} + +fn image_to_json(data: &ImageData) -> serde_json::Value { + match data { + ImageData::Url { url, detail } => { + let mut image_url = serde_json::json!({ "url": url }); + if let Some(d) = detail { + let detail_str = match d { + rusty_ai::ImageDetail::Auto => "auto", + rusty_ai::ImageDetail::Low => "low", + rusty_ai::ImageDetail::High => "high", + }; + image_url["detail"] = serde_json::Value::String(detail_str.to_string()); + } + serde_json::json!({ + "type": "image_url", + "image_url": image_url, + }) + } + ImageData::Base64 { media_type, data } => { + let data_url = format!("data:{};base64,{}", media_type, data); + serde_json::json!({ + "type": "image_url", + "image_url": { "url": data_url }, + }) + } + } +} + +// --------------------------------------------------------------------------- +// Build the full API request +// --------------------------------------------------------------------------- + +pub(crate) fn options_to_request( + model_id: &str, + prompt: &Prompt, + options: &GenerateOptions, + stream: bool, +) -> ChatCompletionRequest { + let messages = prompt_to_messages(prompt); + + let tools: Option> = options.tools.as_ref().map(|tool_defs| { + tool_defs + .iter() + .map(|td| ChatTool { + tool_type: "function".to_string(), + function: ChatFunction { + name: td.name.clone(), + description: if td.description.is_empty() { + None + } else { + Some(td.description.clone()) + }, + parameters: td.parameters.clone(), + }, + }) + .collect() + }); + + let tool_choice = options.tool_choice.as_ref().map(|tc| match tc { + ToolChoice::Auto => serde_json::json!("auto"), + ToolChoice::None => serde_json::json!("none"), + ToolChoice::Required => serde_json::json!("required"), + ToolChoice::Specific(name) => serde_json::json!({ + "type": "function", + "function": { "name": name } + }), + }); + + let response_format = options.output_schema.as_ref().map(|schema| ResponseFormat { + format_type: "json_schema".to_string(), + json_schema: Some(JsonSchemaFormat { + name: schema.name.clone(), + schema: schema.schema.clone(), + strict: Some(true), + }), + }); + + let stop = if options.stop_sequences.is_empty() { + None + } else { + Some(options.stop_sequences.clone()) + }; + + let stream_options = if stream { + Some(StreamOptions { + include_usage: true, + }) + } else { + None + }; + + ChatCompletionRequest { + model: model_id.to_string(), + messages, + temperature: options.temperature, + max_tokens: options.max_tokens, + top_p: options.top_p, + frequency_penalty: options.frequency_penalty, + presence_penalty: options.presence_penalty, + seed: options.seed, + stop, + tools, + tool_choice, + response_format, + stream, + stream_options, + } +} + +// --------------------------------------------------------------------------- +// API response -> GenerateResult +// --------------------------------------------------------------------------- + +pub(crate) fn response_to_result( + response: ChatCompletionResponse, + provider: &str, +) -> GenerateResult { + let choice = response.choices.into_iter().next(); + + let (text, tool_calls, finish_reason) = match choice { + Some(c) => { + let msg = c.message.unwrap_or(ChatMessage { + role: "assistant".to_string(), + content: None, + name: None, + tool_calls: None, + tool_call_id: None, + }); + + let text = msg.content.and_then(|v| match v { + serde_json::Value::String(s) => Some(s), + _ => None, + }); + + let tool_calls: Vec = msg + .tool_calls + .unwrap_or_default() + .into_iter() + .map(|tc| { + let args: serde_json::Value = + serde_json::from_str(&tc.function.arguments).unwrap_or_default(); + ToolCallRequest { + id: tc.id, + name: tc.function.name, + arguments: args, + } + }) + .collect(); + + let finish_reason = parse_finish_reason(c.finish_reason.as_deref()); + (text, tool_calls, finish_reason) + } + None => (None, Vec::new(), FinishReason::Unknown), + }; + + let usage = response + .usage + .map(|u| Usage { + prompt_tokens: Some(u.prompt_tokens), + completion_tokens: Some(u.completion_tokens), + total_tokens: Some(u.total_tokens), + }) + .unwrap_or_default(); + + GenerateResult { + text, + tool_calls, + finish_reason, + usage, + metadata: ResponseMetadata { + request_id: uuid::Uuid::new_v4(), + provider: provider.to_string(), + model: response.model, + latency_ms: None, + extra: Default::default(), + }, + } +} + +// --------------------------------------------------------------------------- +// Finish reason mapping +// --------------------------------------------------------------------------- + +pub(crate) fn parse_finish_reason(reason: Option<&str>) -> FinishReason { + match reason { + Some("stop") => FinishReason::Stop, + Some("length") => FinishReason::Length, + Some("tool_calls") => FinishReason::ToolCall, + Some("content_filter") => FinishReason::ContentFilter, + Some("error") => FinishReason::Error, + _ => FinishReason::Unknown, + } +} diff --git a/crates/rusty_openai_compatible/src/lib.rs b/crates/rusty_openai_compatible/src/lib.rs new file mode 100644 index 0000000..2c8062b --- /dev/null +++ b/crates/rusty_openai_compatible/src/lib.rs @@ -0,0 +1,17 @@ +//! Generic adapter for any OpenAI-compatible chat-completions API. +//! +//! This crate provides [`OpenAiCompatibleProvider`] and [`OpenAiCompatibleModel`] +//! which implement the core `rusty_ai` traits and can be pointed at any API that +//! follows the OpenAI chat-completions wire format (OpenAI, Azure, Together, +//! Groq, local vLLM, etc.). + +mod api_types; +mod config; +mod convert; +mod model; +mod provider; +mod stream_parser; + +pub use config::OpenAiCompatibleConfig; +pub use model::OpenAiCompatibleModel; +pub use provider::OpenAiCompatibleProvider; diff --git a/crates/rusty_openai_compatible/src/model.rs b/crates/rusty_openai_compatible/src/model.rs new file mode 100644 index 0000000..6ef876d --- /dev/null +++ b/crates/rusty_openai_compatible/src/model.rs @@ -0,0 +1,232 @@ +use async_trait::async_trait; +use futures::StreamExt; +use reqwest::header::{HeaderMap, HeaderName, HeaderValue, AUTHORIZATION, CONTENT_TYPE}; +use secrecy::ExposeSecret; + +use rusty_ai::{ + AiError, AiResult, AiStream, CapabilitySet, GenerateOptions, GenerateResult, LanguageModel, + Prompt, +}; + +use crate::api_types::{ApiErrorResponse, ChatCompletionRequest, ChatCompletionResponse}; +use crate::config::OpenAiCompatibleConfig; +use crate::convert; +use crate::stream_parser; + +/// A language model backed by an OpenAI-compatible HTTP API. +pub struct OpenAiCompatibleModel { + config: OpenAiCompatibleConfig, + model_id: String, + provider_id: String, + capabilities: CapabilitySet, + client: reqwest::Client, +} + +impl OpenAiCompatibleModel { + /// Create a new model instance. + /// + /// # Panics + /// + /// Panics if the system TLS stack cannot be initialized. This is a + /// system-level failure (e.g. missing TLS libraries). Use [`try_new`] to + /// handle this case without panicking. + /// + /// [`try_new`]: OpenAiCompatibleModel::try_new + pub fn new(config: OpenAiCompatibleConfig, model_id: &str, provider_id: &str) -> Self { + Self::try_new(config, model_id, provider_id) + .expect("Failed to initialize HTTP client (TLS unavailable or system misconfigured)") + } + + /// Create a new model instance, returning an error if the HTTP client + /// cannot be initialized (e.g. TLS is unavailable on the system). + pub fn try_new( + config: OpenAiCompatibleConfig, + model_id: &str, + provider_id: &str, + ) -> AiResult { + let mut headers = HeaderMap::new(); + headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json")); + + let auth_value = format!("Bearer {}", config.api_key.expose_secret()); + if let Ok(v) = HeaderValue::from_str(&auth_value) { + headers.insert(AUTHORIZATION, v); + } + + if let Some(ref org) = config.org_id { + if let Ok(v) = HeaderValue::from_str(org) { + headers.insert("OpenAI-Organization", v); + } + } + + for (name, value) in &config.default_headers { + if let (Ok(n), Ok(v)) = ( + HeaderName::try_from(name.as_str()), + HeaderValue::from_str(value), + ) { + headers.insert(n, v); + } + } + + let client = reqwest::Client::builder() + .default_headers(headers) + .build() + .map_err(|e| AiError::Transport { + message: format!("Failed to build HTTP client: {e}"), + source: Some(Box::new(e)), + })?; + + Ok(Self { + config, + model_id: model_id.to_string(), + provider_id: provider_id.to_string(), + capabilities: CapabilitySet::new(), + client, + }) + } + + /// Builder-style setter for capabilities. + pub fn with_capabilities(mut self, caps: CapabilitySet) -> Self { + self.capabilities = caps; + self + } + + /// The chat completions endpoint URL. + fn endpoint(&self) -> String { + let base = self.config.base_url.as_str().trim_end_matches('/'); + format!("{}/chat/completions", base) + } + + /// Execute a non-streaming request and return the parsed response. + async fn do_request(&self, request: ChatCompletionRequest) -> AiResult { + let url = self.endpoint(); + tracing::debug!(url = %url, model = %request.model, "sending chat completion request"); + + let response = self + .client + .post(&url) + .json(&request) + .send() + .await + .map_err(|e| AiError::Transport { + message: e.to_string(), + source: Some(Box::new(e)), + })?; + + let status = response.status(); + if !status.is_success() { + let status_code = status.as_u16(); + let body = match response.text().await { + Ok(body) => body, + Err(e) => { + tracing::warn!(status = status_code, error = %e, "Failed to read error response body"); + format!("") + } + }; + + // Attempt to parse structured error. + if let Ok(api_err) = serde_json::from_str::(&body) { + return Err(map_api_error(status_code, &api_err.error.message)); + } + + return Err(map_api_error(status_code, &body)); + } + + let body = response.text().await.map_err(|e| AiError::Transport { + message: e.to_string(), + source: Some(Box::new(e)), + })?; + + serde_json::from_str::(&body).map_err(|e| { + tracing::error!(body = %body, error = %e, "failed to deserialize response"); + AiError::Serialization(e.to_string()) + }) + } + + /// Execute a streaming request and return a boxed stream. + async fn do_stream_request(&self, request: ChatCompletionRequest) -> AiResult { + let url = self.endpoint(); + tracing::debug!(url = %url, model = %request.model, "sending streaming chat completion request"); + + let req = self.client.post(&url).json(&request); + + let mut es = + reqwest_eventsource::EventSource::new(req).map_err(|e| AiError::Transport { + message: e.to_string(), + source: Some(Box::new(e)), + })?; + + // Spawn the EventSource into a channel-based stream to avoid lifetime + // issues. EventSource must be polled to completion. + let (tx, rx) = tokio::sync::mpsc::channel(64); + + tokio::spawn(async move { + while let Some(event) = es.next().await { + match &event { + Ok(reqwest_eventsource::Event::Open) => {} + Ok(reqwest_eventsource::Event::Message(msg)) => { + if msg.data.trim() == "[DONE]" { + let _ = tx.send(event).await; + es.close(); + break; + } + } + Err(_) => { + let _ = tx.send(event).await; + es.close(); + break; + } + } + if tx.send(event).await.is_err() { + es.close(); + break; + } + } + }); + + let receiver_stream = tokio_stream::wrappers::ReceiverStream::new(rx); + let parsed = stream_parser::parse_sse_stream(receiver_stream); + Ok(Box::pin(parsed)) + } +} + +#[async_trait] +impl LanguageModel for OpenAiCompatibleModel { + fn model_id(&self) -> &str { + &self.model_id + } + + fn provider_id(&self) -> &str { + &self.provider_id + } + + fn capabilities(&self) -> &CapabilitySet { + &self.capabilities + } + + async fn generate(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + let request = convert::options_to_request(&self.model_id, &prompt, &options, false); + let response = self.do_request(request).await?; + Ok(convert::response_to_result(response, &self.provider_id)) + } + + async fn stream(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + let request = convert::options_to_request(&self.model_id, &prompt, &options, true); + self.do_stream_request(request).await + } +} + +/// Map an HTTP status code and message to the appropriate `AiError`. +fn map_api_error(status: u16, message: &str) -> AiError { + match status { + 401 => AiError::AuthError { + message: message.to_string(), + }, + 429 => AiError::RateLimit { retry_after: None }, + 408 | 504 => AiError::Timeout, + _ => AiError::ProviderError { + provider: "openai_compatible".to_string(), + status: Some(status), + message: message.to_string(), + }, + } +} diff --git a/crates/rusty_openai_compatible/src/provider.rs b/crates/rusty_openai_compatible/src/provider.rs new file mode 100644 index 0000000..107f4fb --- /dev/null +++ b/crates/rusty_openai_compatible/src/provider.rs @@ -0,0 +1,67 @@ +use rusty_ai::{CapabilitySet, LanguageModel, ModelInfo}; + +use crate::config::OpenAiCompatibleConfig; +use crate::model::OpenAiCompatibleModel; + +/// A provider backed by an OpenAI-compatible API. +/// +/// This is not a trait impl — it is a concrete registry that knows how to +/// create [`OpenAiCompatibleModel`] instances for its registered models. +pub struct OpenAiCompatibleProvider { + config: OpenAiCompatibleConfig, + provider_id: String, + provider_name: String, + models: Vec, +} + +impl OpenAiCompatibleProvider { + /// Create a new provider. + pub fn new(config: OpenAiCompatibleConfig, id: &str, name: &str) -> Self { + Self { + config, + provider_id: id.to_string(), + provider_name: name.to_string(), + models: Vec::new(), + } + } + + /// Register a model's metadata. Returns `self` for chaining. + pub fn with_model_info(mut self, info: ModelInfo) -> Self { + self.models.push(info); + self + } + + /// Return the provider identifier. + pub fn id(&self) -> &str { + &self.provider_id + } + + /// Return the human-readable provider name. + pub fn name(&self) -> &str { + &self.provider_name + } + + /// List the registered model metadata. + pub fn models(&self) -> &[ModelInfo] { + &self.models + } + + /// Create a [`LanguageModel`] for the given model id. + /// + /// If the model id matches registered metadata, the capabilities from that + /// metadata are attached. Otherwise a model with an empty capability set is + /// returned (OpenAI-compatible APIs typically accept any model string). + pub fn language_model(&self, model_id: &str) -> Box { + let caps = self + .models + .iter() + .find(|m| m.id == model_id) + .map(|m| m.capabilities.clone()) + .unwrap_or_else(CapabilitySet::new); + + Box::new( + OpenAiCompatibleModel::new(self.config.clone(), model_id, &self.provider_id) + .with_capabilities(caps), + ) + } +} diff --git a/crates/rusty_openai_compatible/src/stream_parser.rs b/crates/rusty_openai_compatible/src/stream_parser.rs new file mode 100644 index 0000000..f28bb6a --- /dev/null +++ b/crates/rusty_openai_compatible/src/stream_parser.rs @@ -0,0 +1,250 @@ +use std::collections::HashMap; + +use futures::stream::{self, Stream, StreamExt}; +use rusty_ai::{AiError, FinishReason, StreamEvent, Usage}; + +use crate::api_types::ChatCompletionChunk; +use crate::convert::parse_finish_reason; + +/// Parse an SSE event stream from the OpenAI-compatible API into a stream of +/// [`StreamEvent`] values. +/// +/// Tool-call arguments are streamed incrementally by the API, so we accumulate +/// them here and emit a [`StreamEvent::ToolCallEnd`] once the chunk indicates +/// the call is complete (via `finish_reason == "tool_calls"`) or the next +/// tool-call index appears. +pub(crate) fn parse_sse_stream( + event_stream: impl Stream> + + Send + + 'static, +) -> impl Stream> + Send { + // We keep mutable state across events via `stream::unfold`. + struct State { + inner: S, + message_id: Option, + /// In-progress tool calls keyed by index. Stores (call_id, name, args_buffer). + pending_tools: HashMap, + done: bool, + } + + let state = State { + inner: Box::pin(event_stream), + message_id: None, + pending_tools: HashMap::new(), + done: false, + }; + + stream::unfold(state, |mut state| async move { + if state.done { + return None; + } + + loop { + let event = match state.inner.next().await { + Some(Ok(ev)) => ev, + Some(Err(e)) => { + state.done = true; + return Some(( + vec![Err(AiError::StreamError { + message: e.to_string(), + })], + state, + )); + } + None => { + return None; + } + }; + + match event { + reqwest_eventsource::Event::Open => continue, + reqwest_eventsource::Event::Message(msg) => { + let data = msg.data.trim(); + + if data == "[DONE]" { + // Flush any remaining pending tool calls. + let mut events = flush_pending_tools(&mut state.pending_tools); + events.push(Ok(StreamEvent::MessageEnd { + finish_reason: FinishReason::Stop, + usage: None, + })); + state.done = true; + return Some((events, state)); + } + + let chunk: ChatCompletionChunk = match serde_json::from_str(data) { + Ok(c) => c, + Err(e) => { + tracing::error!(data, error = %e, "failed to parse SSE chunk; terminating stream"); + state.done = true; + return Some(( + vec![Err(AiError::StreamError { + message: format!("Unparseable SSE chunk: {e}"), + })], + state, + )); + } + }; + + let mut events = Vec::new(); + + // Emit MessageStart on first chunk. + if state.message_id.is_none() { + state.message_id = Some(chunk.id.clone()); + events.push(Ok(StreamEvent::MessageStart { + message_id: chunk.id.clone(), + })); + } + + for choice in &chunk.choices { + if let Some(ref delta) = choice.delta { + // Text content delta. + if let Some(ref content) = delta.content { + if let Some(text) = content.as_str() { + if !text.is_empty() { + events.push(Ok(StreamEvent::TextDelta { + delta: text.to_string(), + })); + } + } + } + + // Tool call deltas. + if let Some(ref tool_calls) = delta.tool_calls { + for tc in tool_calls { + // The API sends an index in the `id` field + // position for deltas. We parse the index + // from the serialized chunk instead. The + // ChatToolCall struct reuses id/function + // fields, so we determine index from the + // order we see new ids. + + // Determine index: if the tc has a non-empty + // id, it is the start of a new tool call. + let idx = choice.index; + + if !tc.id.is_empty() { + // New tool call starting — flush any + // previous one at the same index. + if let Some((old_id, _old_name, old_args)) = + state.pending_tools.remove(&idx) + { + match serde_json::from_str(&old_args) { + Ok(arguments) => events.push(Ok(StreamEvent::ToolCallEnd { + call_id: old_id, + arguments, + })), + Err(e) => events.push(Err(AiError::StreamError { + message: format!( + "Malformed tool call arguments (JSON parse error: {e})" + ), + })), + } + } + + state.pending_tools.insert( + idx, + ( + tc.id.clone(), + tc.function.name.clone(), + tc.function.arguments.clone(), + ), + ); + + events.push(Ok(StreamEvent::ToolCallStart { + call_id: tc.id.clone(), + tool_name: tc.function.name.clone(), + })); + } else if let Some((ref id, _name, ref mut args)) = + state.pending_tools.get_mut(&idx) + { + // Continuation of an existing tool call. + let arg_delta = &tc.function.arguments; + if !arg_delta.is_empty() { + args.push_str(arg_delta); + let call_id = id.clone(); + events.push(Ok(StreamEvent::ToolCallDelta { + call_id, + delta: arg_delta.clone(), + })); + } + } + } + } + } + + // Check for finish reason. + if let Some(ref reason) = choice.finish_reason { + let fr = parse_finish_reason(Some(reason.as_str())); + + // If finishing with tool_calls, flush pending. + if fr == FinishReason::ToolCall { + events.append(&mut flush_pending_tools(&mut state.pending_tools)); + } + + // Emit usage if present on this chunk. + let usage = chunk.usage.as_ref().map(|u| Usage { + prompt_tokens: Some(u.prompt_tokens), + completion_tokens: Some(u.completion_tokens), + total_tokens: Some(u.total_tokens), + }); + + events.push(Ok(StreamEvent::MessageEnd { + finish_reason: fr, + usage, + })); + + state.done = true; + return Some((events, state)); + } + } + + // Usage-only chunk (when stream_options.include_usage is set + // and choices is empty). + if chunk.choices.is_empty() { + if let Some(ref u) = chunk.usage { + events.push(Ok(StreamEvent::UsageDelta { + usage: Usage { + prompt_tokens: Some(u.prompt_tokens), + completion_tokens: Some(u.completion_tokens), + total_tokens: Some(u.total_tokens), + }, + })); + } + } + + if !events.is_empty() { + return Some((events, state)); + } + // If no events were produced, continue reading. + } + } + } + }) + .flat_map(stream::iter) +} + +/// Flush all pending tool calls into `ToolCallEnd` (or `StreamError`) events. +fn flush_pending_tools( + pending: &mut HashMap, +) -> Vec> { + let mut events = Vec::new(); + let mut indices: Vec = pending.keys().cloned().collect(); + indices.sort(); + for idx in indices { + if let Some((id, _name, args)) = pending.remove(&idx) { + match serde_json::from_str(&args) { + Ok(arguments) => events.push(Ok(StreamEvent::ToolCallEnd { + call_id: id, + arguments, + })), + Err(e) => events.push(Err(AiError::StreamError { + message: format!( + "Malformed tool call arguments for call `{id}` (JSON parse error: {e})" + ), + })), + } + } + } + events +} diff --git a/crates/rusty_phi_silica/Cargo.toml b/crates/rusty_phi_silica/Cargo.toml new file mode 100644 index 0000000..cf15d86 --- /dev/null +++ b/crates/rusty_phi_silica/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "rusty_phi_silica" +version.workspace = true +edition.workspace = true +license.workspace = true +description = "Windows Phi Silica local runtime for the Rusty AI SDK" + +[dependencies] +rusty_ai = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +tracing = { workspace = true } +uuid = { workspace = true } +chrono = { workspace = true } diff --git a/crates/rusty_phi_silica/src/bridge.rs b/crates/rusty_phi_silica/src/bridge.rs new file mode 100644 index 0000000..4113166 --- /dev/null +++ b/crates/rusty_phi_silica/src/bridge.rs @@ -0,0 +1,32 @@ +use async_trait::async_trait; + +use crate::types::PhiSilicaAvailability; + +/// Trait that must be implemented by the host application to bridge +/// to the Windows Phi Silica runtime. +#[async_trait] +pub trait PhiSilicaBridge: Send + Sync { + /// Check Phi Silica availability on this device. + async fn availability(&self) -> PhiSilicaAvailability; + + /// Generate text from a prompt. + async fn generate(&self, prompt: &str, max_tokens: Option) -> Result; + + /// Stream generated text in chunks. + /// + /// The Windows App SDK exposes `GenerateResponseWithUpdatesAsync` which + /// yields partial text results. This method should call that and return + /// each partial text chunk. + /// + /// Default implementation falls back to calling `generate()` and returning + /// a single-element Vec. + async fn stream_tokens( + &self, + prompt: &str, + max_tokens: Option, + ) -> Result, String> { + // Default: call generate and return as single chunk + let result = self.generate(prompt, max_tokens).await?; + Ok(vec![result]) + } +} diff --git a/crates/rusty_phi_silica/src/lib.rs b/crates/rusty_phi_silica/src/lib.rs new file mode 100644 index 0000000..8ec46c4 --- /dev/null +++ b/crates/rusty_phi_silica/src/lib.rs @@ -0,0 +1,15 @@ +//! Windows Phi Silica local runtime for the Rusty AI SDK. +//! +//! This crate provides integration with Microsoft's Phi Silica model +//! running on Windows devices with NPU support. Host applications must +//! implement the [`PhiSilicaBridge`] trait. + +mod bridge; +mod model; +mod provider; +mod types; + +pub use bridge::*; +pub use model::*; +pub use provider::*; +pub use types::*; diff --git a/crates/rusty_phi_silica/src/model.rs b/crates/rusty_phi_silica/src/model.rs new file mode 100644 index 0000000..f1af909 --- /dev/null +++ b/crates/rusty_phi_silica/src/model.rs @@ -0,0 +1,133 @@ +use std::sync::Arc; + +use async_trait::async_trait; +use rusty_ai::{ + AiError, AiResult, AiStream, Capability, CapabilitySet, ContentPart, FinishReason, + GenerateOptions, GenerateResult, LanguageModel, Prompt, ResponseMetadata, StreamEvent, Usage, +}; + +use crate::bridge::PhiSilicaBridge; +use crate::types::PhiSilicaAvailability; + +/// A [`LanguageModel`] backed by Windows Phi Silica. +pub struct PhiSilicaModel { + bridge: Arc, + capabilities: CapabilitySet, +} + +impl PhiSilicaModel { + pub(crate) fn new(bridge: Arc) -> Self { + let capabilities = CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::LocalExecution) + .with(Capability::PlatformNative); + Self { + bridge, + capabilities, + } + } + + fn prompt_to_text(prompt: &Prompt) -> String { + match prompt { + Prompt::Text(t) => t.clone(), + Prompt::Messages(msgs) => msgs + .iter() + .flat_map(|m| { + m.content.iter().filter_map(|c| match c { + ContentPart::Text { text } => Some(text.clone()), + _ => None, + }) + }) + .collect::>() + .join("\n"), + } + } +} + +#[async_trait] +impl LanguageModel for PhiSilicaModel { + fn model_id(&self) -> &str { + "phi-silica" + } + + fn provider_id(&self) -> &str { + "phi_silica" + } + + fn capabilities(&self) -> &CapabilitySet { + &self.capabilities + } + + async fn generate(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + match self.bridge.availability().await { + PhiSilicaAvailability::Available => {} + other => { + return Err(AiError::PlatformUnavailable { + platform: format!("Windows Phi Silica: {other:?}"), + }); + } + } + + let text = Self::prompt_to_text(&prompt); + let response = self + .bridge + .generate(&text, options.max_tokens) + .await + .map_err(|e| AiError::BridgeError { + bridge: "phi_silica".into(), + message: e, + })?; + + Ok(GenerateResult { + text: Some(response), + tool_calls: vec![], + finish_reason: FinishReason::Stop, + usage: Usage::default(), + metadata: ResponseMetadata { + provider: "phi_silica".into(), + model: "phi-silica".into(), + ..Default::default() + }, + }) + } + + async fn stream(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + match self.bridge.availability().await { + PhiSilicaAvailability::Available => {} + other => { + return Err(AiError::PlatformUnavailable { + platform: format!("Windows Phi Silica: {other:?}"), + }); + } + } + + let text = Self::prompt_to_text(&prompt); + let chunks = self + .bridge + .stream_tokens(&text, options.max_tokens) + .await + .map_err(|e| AiError::BridgeError { + bridge: "phi_silica".into(), + message: e, + })?; + + let message_id = uuid::Uuid::new_v4().to_string(); + let events: Vec> = { + let mut v = Vec::new(); + v.push(Ok(StreamEvent::MessageStart { message_id })); + for chunk in chunks { + if !chunk.is_empty() { + v.push(Ok(StreamEvent::TextDelta { delta: chunk })); + } + } + v.push(Ok(StreamEvent::MessageEnd { + finish_reason: FinishReason::Stop, + usage: None, + })); + v + }; + + Ok(Box::pin(futures::stream::iter(events))) + } +} diff --git a/crates/rusty_phi_silica/src/provider.rs b/crates/rusty_phi_silica/src/provider.rs new file mode 100644 index 0000000..d181650 --- /dev/null +++ b/crates/rusty_phi_silica/src/provider.rs @@ -0,0 +1,67 @@ +use std::sync::Arc; + +use rusty_ai::{ + AiError, AiResult, Capability, CapabilitySet, EmbeddingModel, LanguageModel, ModelInfo, + Provider, +}; + +use crate::bridge::PhiSilicaBridge; +use crate::model::PhiSilicaModel; +use crate::types::PhiSilicaAvailability; + +/// Provider for Windows Phi Silica on-device inference. +pub struct PhiSilicaProvider { + bridge: Arc, +} + +impl PhiSilicaProvider { + pub fn new(bridge: impl PhiSilicaBridge + 'static) -> Self { + Self { + bridge: Arc::new(bridge), + } + } + + /// Get the Phi Silica language model. + pub fn model(&self) -> PhiSilicaModel { + PhiSilicaModel::new(self.bridge.clone()) + } + + /// Check Phi Silica availability. + pub async fn availability(&self) -> PhiSilicaAvailability { + self.bridge.availability().await + } +} + +impl Provider for PhiSilicaProvider { + fn id(&self) -> &str { + "phi_silica" + } + + fn name(&self) -> &str { + "Phi Silica (Windows)" + } + + fn language_model(&self, _model_id: &str) -> AiResult> { + Ok(Box::new(self.model())) + } + + fn embedding_model(&self, model_id: &str) -> AiResult> { + Err(AiError::UnsupportedCapability { + capability: "embeddings".into(), + provider: format!("phi_silica/{model_id}"), + }) + } + + fn available_models(&self) -> Vec { + vec![ModelInfo { + id: "phi-silica".into(), + provider: "phi_silica".into(), + display_name: "Phi Silica (Windows NPU)".into(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::LocalExecution) + .with(Capability::PlatformNative), + }] + } +} diff --git a/crates/rusty_phi_silica/src/types.rs b/crates/rusty_phi_silica/src/types.rs new file mode 100644 index 0000000..5cb6eb0 --- /dev/null +++ b/crates/rusty_phi_silica/src/types.rs @@ -0,0 +1,12 @@ +/// Availability state of Phi Silica on this Windows device. +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum PhiSilicaAvailability { + /// Phi Silica is available and ready. + Available, + /// Phi Silica is not available on this device. + Unavailable, + /// Windows version does not support Phi Silica. + WindowsVersionTooOld, + /// NPU hardware was not detected. + NpuNotDetected, +} diff --git a/crates/rusty_testing/Cargo.toml b/crates/rusty_testing/Cargo.toml new file mode 100644 index 0000000..b516121 --- /dev/null +++ b/crates/rusty_testing/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "rusty_testing" +version.workspace = true +edition.workspace = true +license.workspace = true +description = "Testing utilities and mock providers for the Rusty AI SDK" + +[dependencies] +rusty_ai = { workspace = true } +async-trait = { workspace = true } +futures = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +schemars = { workspace = true } +uuid = { workspace = true } +chrono = { workspace = true } diff --git a/crates/rusty_testing/src/assertions.rs b/crates/rusty_testing/src/assertions.rs new file mode 100644 index 0000000..4b35eee --- /dev/null +++ b/crates/rusty_testing/src/assertions.rs @@ -0,0 +1,80 @@ +//! Test helper functions for asserting on AI SDK types. + +use futures::StreamExt; + +use rusty_ai::error::AiResult; +use rusty_ai::stream::{AiStream, StreamEvent}; +use rusty_ai::structured::GenerateResult; + +/// Collect all text deltas from a stream into a single string. +pub async fn collect_text(mut stream: AiStream) -> AiResult { + let mut text = String::new(); + while let Some(event) = stream.next().await { + let event = event?; + if let StreamEvent::TextDelta { delta } = event { + text.push_str(&delta); + } + } + Ok(text) +} + +/// Collect all events from a stream into a vector. +pub async fn collect_events(mut stream: AiStream) -> AiResult> { + let mut events = Vec::new(); + while let Some(event) = stream.next().await { + events.push(event?); + } + Ok(events) +} + +/// Assert that a stream produces the expected text (all text deltas concatenated). +/// +/// # Panics +/// +/// Panics if the stream produces an error or the collected text does not match. +pub async fn assert_stream_text(stream: AiStream, expected: &str) { + let text = collect_text(stream) + .await + .expect("stream should not produce an error"); + assert_eq!( + text, expected, + "stream text mismatch: expected {:?}, got {:?}", + expected, text + ); +} + +/// Assert that a `GenerateResult` contains text matching the expected value. +/// +/// # Panics +/// +/// Panics if the result has no text or the text does not match. +pub fn assert_result_text(result: &GenerateResult, expected: &str) { + let text = result + .text + .as_deref() + .expect("GenerateResult should contain text"); + assert_eq!( + text, expected, + "result text mismatch: expected {:?}, got {:?}", + expected, text + ); +} + +/// Assert that a `GenerateResult` contains tool calls with the given names. +/// +/// The order of `tool_names` does not matter; all listed names must be present. +/// +/// # Panics +/// +/// Panics if any of the expected tool names are missing. +pub fn assert_has_tool_calls(result: &GenerateResult, tool_names: &[&str]) { + let actual_names: Vec<&str> = result.tool_calls.iter().map(|c| c.name.as_str()).collect(); + for expected in tool_names { + assert!( + actual_names.contains(expected), + "expected tool call {:?} not found in {:?}", + expected, + actual_names, + ); + } +} diff --git a/crates/rusty_testing/src/lib.rs b/crates/rusty_testing/src/lib.rs new file mode 100644 index 0000000..68f2ba2 --- /dev/null +++ b/crates/rusty_testing/src/lib.rs @@ -0,0 +1,9 @@ +//! Testing utilities and mock providers for the Rusty AI SDK. + +mod assertions; +mod mock_model; +mod mock_provider; + +pub use assertions::*; +pub use mock_model::*; +pub use mock_provider::*; diff --git a/crates/rusty_testing/src/mock_model.rs b/crates/rusty_testing/src/mock_model.rs new file mode 100644 index 0000000..f59448e --- /dev/null +++ b/crates/rusty_testing/src/mock_model.rs @@ -0,0 +1,313 @@ +use std::sync::{Arc, Mutex}; + +use async_trait::async_trait; +use futures::stream; + +use rusty_ai::capability::{Capability, CapabilitySet}; +use rusty_ai::error::{AiError, AiResult}; +use rusty_ai::model::{EmbeddingModel, GenerateOptions, LanguageModel}; +use rusty_ai::prompt::Prompt; +use rusty_ai::stream::{AiStream, StreamEvent, SyntheticStreamer}; +use rusty_ai::structured::{EmbeddingResult, GenerateResult}; +use rusty_ai::tool::ToolCallRequest; +use rusty_ai::types::{FinishReason, ResponseMetadata}; +use rusty_ai::usage::Usage; + +/// A pre-configured response for the mock model to return. +pub enum MockResponse { + /// Return a plain text response. + Text(String), + /// Return a set of tool calls. + ToolCalls(Vec), + /// Return an error. + Error(AiError), + /// Return a JSON object (serialised into the text field). + Object(serde_json::Value), +} + +/// A recorded invocation of the mock model. +#[derive(Debug, Clone)] +pub struct RecordedCall { + /// The prompt that was passed to the model. + pub prompt: Prompt, + /// The options that were passed to the model. + pub options: GenerateOptions, + /// When the call was made. + pub timestamp: chrono::DateTime, +} + +/// A mock language model for testing. +/// +/// Responses are consumed in FIFO order. When no responses remain the model +/// returns an error. +pub struct MockLanguageModel { + id: String, + provider: String, + capabilities: CapabilitySet, + responses: Arc>>, + calls: Arc>>, +} + +impl MockLanguageModel { + /// Create a new mock model with the given identifier. + pub fn new(id: &str) -> Self { + Self { + id: id.to_owned(), + provider: "mock".to_owned(), + capabilities: CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::Streaming) + .with(Capability::ToolCalling), + responses: Arc::new(Mutex::new(Vec::new())), + calls: Arc::new(Mutex::new(Vec::new())), + } + } + + /// Queue a response (builder style). + pub fn with_response(self, response: MockResponse) -> Self { + self.responses.lock().unwrap().push(response); + self + } + + /// Queue a text response (builder style). + pub fn with_text(self, text: &str) -> Self { + self.with_response(MockResponse::Text(text.to_owned())) + } + + /// Queue an error response (builder style). + pub fn with_error(self, error: AiError) -> Self { + self.with_response(MockResponse::Error(error)) + } + + /// Queue a tool-calls response (builder style). + pub fn with_tool_calls(self, calls: Vec) -> Self { + self.with_response(MockResponse::ToolCalls(calls)) + } + + /// Queue a JSON object response (builder style). + pub fn with_object(self, value: serde_json::Value) -> Self { + self.with_response(MockResponse::Object(value)) + } + + /// Override the provider name (builder style). + pub fn with_provider(mut self, provider: &str) -> Self { + self.provider = provider.to_owned(); + self + } + + /// Override the capability set (builder style). + pub fn with_capabilities(mut self, capabilities: CapabilitySet) -> Self { + self.capabilities = capabilities; + self + } + + /// Return a snapshot of all recorded calls. + pub fn calls(&self) -> Vec { + self.calls.lock().unwrap().clone() + } + + /// Return the number of times this model was invoked. + pub fn call_count(&self) -> usize { + self.calls.lock().unwrap().len() + } + + /// Get the model identifier. + pub fn id(&self) -> &str { + &self.id + } + + /// Get the provider name. + pub fn provider(&self) -> &str { + &self.provider + } + + // ------- internal helpers ------- + + fn record_call(&self, prompt: &Prompt, options: &GenerateOptions) { + self.calls.lock().unwrap().push(RecordedCall { + prompt: prompt.clone(), + options: options.clone(), + timestamp: chrono::Utc::now(), + }); + } + + fn next_response(&self) -> MockResponse { + let mut responses = self.responses.lock().unwrap(); + if responses.is_empty() { + MockResponse::Error(AiError::StreamError { + message: "MockLanguageModel: no more queued responses".into(), + }) + } else { + responses.remove(0) + } + } + + fn response_to_result(&self, response: MockResponse) -> AiResult { + match response { + MockResponse::Text(text) => Ok(GenerateResult { + text: Some(text), + tool_calls: Vec::new(), + finish_reason: FinishReason::Stop, + usage: Usage::default(), + metadata: ResponseMetadata::default(), + }), + MockResponse::ToolCalls(calls) => Ok(GenerateResult { + text: None, + tool_calls: calls, + finish_reason: FinishReason::ToolCall, + usage: Usage::default(), + metadata: ResponseMetadata::default(), + }), + MockResponse::Object(value) => { + let text = serde_json::to_string(&value) + .map_err(|e| AiError::Serialization(e.to_string()))?; + Ok(GenerateResult { + text: Some(text), + tool_calls: Vec::new(), + finish_reason: FinishReason::Stop, + usage: Usage::default(), + metadata: ResponseMetadata::default(), + }) + } + MockResponse::Error(err) => Err(err), + } + } + + fn response_to_stream(&self, response: MockResponse) -> AiResult { + match response { + MockResponse::Text(text) => Ok(SyntheticStreamer::stream(text, 20)), + MockResponse::ToolCalls(calls) => { + let mut events: Vec> = Vec::new(); + events.push(Ok(StreamEvent::MessageStart { + message_id: uuid::Uuid::new_v4().to_string(), + })); + for call in calls { + events.push(Ok(StreamEvent::ToolCallStart { + call_id: call.id.clone(), + tool_name: call.name.clone(), + })); + events.push(Ok(StreamEvent::ToolCallEnd { + call_id: call.id.clone(), + arguments: call.arguments.clone(), + })); + } + events.push(Ok(StreamEvent::MessageEnd { + finish_reason: FinishReason::ToolCall, + usage: None, + })); + Ok(Box::pin(stream::iter(events))) + } + MockResponse::Object(value) => { + let text = serde_json::to_string(&value) + .map_err(|e| AiError::Serialization(e.to_string()))?; + Ok(SyntheticStreamer::stream(text, 20)) + } + MockResponse::Error(err) => Err(err), + } + } +} + +#[async_trait] +impl LanguageModel for MockLanguageModel { + fn model_id(&self) -> &str { + &self.id + } + + fn provider_id(&self) -> &str { + &self.provider + } + + fn capabilities(&self) -> &CapabilitySet { + &self.capabilities + } + + async fn generate(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + self.record_call(&prompt, &options); + let response = self.next_response(); + self.response_to_result(response) + } + + async fn stream(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + self.record_call(&prompt, &options); + let response = self.next_response(); + self.response_to_stream(response) + } +} + +// --------------------------------------------------------------------------- +// Mock embedding model +// --------------------------------------------------------------------------- + +/// A mock embedding model for testing. +/// +/// Returns pre-configured embedding vectors in FIFO order. When the queue is +/// exhausted it produces zero vectors. +pub struct MockEmbeddingModel { + id: String, + provider: String, + dims: usize, + embeddings: Arc>>>, +} + +impl MockEmbeddingModel { + /// Create a new mock embedding model. + pub fn new(id: &str, dims: usize) -> Self { + Self { + id: id.to_owned(), + provider: "mock".to_owned(), + dims, + embeddings: Arc::new(Mutex::new(Vec::new())), + } + } + + /// Queue embedding vectors (builder style). + pub fn with_embeddings(self, embeddings: Vec>) -> Self { + let mut store = self.embeddings.lock().unwrap(); + store.extend(embeddings); + drop(store); + self + } + + /// Get the model identifier. + pub fn id(&self) -> &str { + &self.id + } + + /// Get the provider name. + pub fn provider(&self) -> &str { + &self.provider + } +} + +#[async_trait] +impl EmbeddingModel for MockEmbeddingModel { + fn model_id(&self) -> &str { + &self.id + } + + fn provider_id(&self) -> &str { + &self.provider + } + + fn dimensions(&self) -> Option { + Some(self.dims) + } + + async fn embed(&self, texts: Vec) -> AiResult { + let mut store = self.embeddings.lock().unwrap(); + let mut result = Vec::with_capacity(texts.len()); + for _ in &texts { + if store.is_empty() { + // Return zero vectors when queue is exhausted. + result.push(vec![0.0f64; self.dims]); + } else { + result.push(store.remove(0)); + } + } + Ok(EmbeddingResult { + embeddings: result, + usage: Usage::default(), + }) + } +} diff --git a/crates/rusty_testing/src/mock_provider.rs b/crates/rusty_testing/src/mock_provider.rs new file mode 100644 index 0000000..76548a2 --- /dev/null +++ b/crates/rusty_testing/src/mock_provider.rs @@ -0,0 +1,118 @@ +use std::collections::HashMap; + +use async_trait::async_trait; + +use rusty_ai::error::{AiError, AiResult}; +use rusty_ai::model::{EmbeddingModel, LanguageModel}; +use rusty_ai::provider::Provider; +use rusty_ai::types::ModelInfo; + +use crate::mock_model::{MockEmbeddingModel, MockLanguageModel}; + +/// A mock provider for testing that holds pre-configured mock models. +pub struct MockProvider { + id: String, + name: String, + models: HashMap, + embedding_models: HashMap, +} + +impl MockProvider { + /// Create a new, empty mock provider. + pub fn new() -> Self { + Self { + id: "mock".to_owned(), + name: "Mock Provider".to_owned(), + models: HashMap::new(), + embedding_models: HashMap::new(), + } + } + + /// Register a language model (builder style). + pub fn with_model(mut self, model: MockLanguageModel) -> Self { + self.models.insert(model.id().to_owned(), model); + self + } + + /// Register an embedding model (builder style). + pub fn with_embedding_model(mut self, model: MockEmbeddingModel) -> Self { + self.embedding_models.insert(model.id().to_owned(), model); + self + } + + /// Override the provider id (builder style). + pub fn with_id(mut self, id: &str) -> Self { + self.id = id.to_owned(); + self + } + + /// Override the provider display name (builder style). + pub fn with_name(mut self, name: &str) -> Self { + self.name = name.to_owned(); + self + } +} + +impl Default for MockProvider { + fn default() -> Self { + Self::new() + } +} + +#[async_trait] +impl Provider for MockProvider { + fn id(&self) -> &str { + &self.id + } + + fn name(&self) -> &str { + &self.name + } + + fn language_model(&self, model_id: &str) -> AiResult> { + // Since MockLanguageModel is not Clone we cannot hand out the stored + // instance directly. Instead we create a new MockLanguageModel that + // shares the same internal Arc state. To achieve this we would need + // interior sharing. For simplicity, return an error if the model is + // not registered -- callers should use `MockLanguageModel` directly + // in most tests. + Err(AiError::ModelUnavailable { + model: format!( + "MockProvider does not support runtime model lookup; \ + use MockLanguageModel directly. Requested: {model_id}" + ), + }) + } + + fn embedding_model(&self, model_id: &str) -> AiResult> { + Err(AiError::ModelUnavailable { + model: format!( + "MockProvider does not support runtime model lookup; \ + use MockEmbeddingModel directly. Requested: {model_id}" + ), + }) + } + + fn available_models(&self) -> Vec { + let mut infos: Vec = Vec::new(); + for model in self.models.values() { + infos.push(ModelInfo { + id: model.id().to_owned(), + provider: model.provider().to_owned(), + display_name: format!("Mock {}", model.id()), + capabilities: rusty_ai::CapabilitySet::new() + .with(rusty_ai::Capability::TextInput) + .with(rusty_ai::Capability::TextOutput), + }); + } + for model in self.embedding_models.values() { + infos.push(ModelInfo { + id: model.id().to_owned(), + provider: model.provider().to_owned(), + display_name: format!("Mock Embedding {}", model.id()), + capabilities: rusty_ai::CapabilitySet::new().with(rusty_ai::Capability::Embeddings), + }); + } + infos + } +} diff --git a/crates/rusty_ui_stream/Cargo.toml b/crates/rusty_ui_stream/Cargo.toml new file mode 100644 index 0000000..fbb74b9 --- /dev/null +++ b/crates/rusty_ui_stream/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "rusty_ui_stream" +version.workspace = true +edition.workspace = true +license.workspace = true +description = "UI stream protocol for frontend integration with the Rusty AI SDK" + +[dependencies] +rusty_ai = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +futures = { workspace = true } +tokio = { workspace = true } +tokio-stream = { workspace = true } +bytes = { workspace = true } +pin-project-lite = { workspace = true } diff --git a/crates/rusty_ui_stream/src/event.rs b/crates/rusty_ui_stream/src/event.rs new file mode 100644 index 0000000..581f2cc --- /dev/null +++ b/crates/rusty_ui_stream/src/event.rs @@ -0,0 +1,191 @@ +use rusty_ai::StreamEvent; +use serde::{Deserialize, Serialize}; + +/// The protocol version for the UI stream format. +pub const PROTOCOL_VERSION: &str = "1.0"; + +/// A UI-facing stream event. This is a versioned, frontend-friendly +/// representation of the core [`StreamEvent`] type. +/// +/// Events are tagged with `"type"` for easy dispatch on the client side. +#[derive(Debug, Clone, Serialize, Deserialize)] +#[serde(tag = "type")] +pub enum UiStreamEvent { + /// Signals that generation has started. + #[serde(rename = "start")] + Start { + message_id: String, + model: String, + version: String, + }, + + /// A chunk of generated text. + #[serde(rename = "text")] + Text { delta: String }, + + /// A tool call has begun. + #[serde(rename = "tool_call_start")] + ToolCallStart { call_id: String, tool_name: String }, + + /// A chunk of tool call arguments (partial JSON). + #[serde(rename = "tool_call_args")] + ToolCallArgs { call_id: String, delta: String }, + + /// A tool call has finished accumulating arguments. + #[serde(rename = "tool_call_end")] + ToolCallEnd { call_id: String }, + + /// Result from executing a tool. + #[serde(rename = "tool_result")] + ToolResult { + call_id: String, + content: String, + is_error: bool, + }, + + /// A partial structured-output object. + #[serde(rename = "object")] + Object { delta: serde_json::Value }, + + /// Token usage information. + #[serde(rename = "usage")] + Usage { + prompt_tokens: Option, + completion_tokens: Option, + }, + + /// An error occurred during generation. + #[serde(rename = "error")] + Error { code: String, message: String }, + + /// Intermediate thinking / reasoning tokens from a reasoning model. + #[serde(rename = "thinking")] + Thinking { delta: String }, + + /// Generation is complete. + #[serde(rename = "done")] + Done { finish_reason: String }, +} + +impl From for UiStreamEvent { + fn from(event: StreamEvent) -> Self { + match event { + StreamEvent::MessageStart { message_id } => UiStreamEvent::Start { + message_id, + model: String::new(), + version: PROTOCOL_VERSION.to_string(), + }, + StreamEvent::TextDelta { delta } => UiStreamEvent::Text { delta }, + StreamEvent::ToolCallStart { call_id, tool_name } => { + UiStreamEvent::ToolCallStart { call_id, tool_name } + } + StreamEvent::ToolCallDelta { call_id, delta } => { + UiStreamEvent::ToolCallArgs { call_id, delta } + } + StreamEvent::ToolCallEnd { call_id, .. } => UiStreamEvent::ToolCallEnd { call_id }, + StreamEvent::ToolResult { + call_id, + content, + is_error, + } => UiStreamEvent::ToolResult { + call_id, + content, + is_error, + }, + StreamEvent::ObjectDelta { delta } => UiStreamEvent::Object { delta }, + StreamEvent::UsageDelta { usage } => UiStreamEvent::Usage { + prompt_tokens: usage.prompt_tokens, + completion_tokens: usage.completion_tokens, + }, + StreamEvent::Warning { message } => UiStreamEvent::Error { + code: "warning".to_string(), + message, + }, + StreamEvent::MessageEnd { + finish_reason, + usage: _, + } => { + let reason = format!("{:?}", finish_reason).to_lowercase(); + UiStreamEvent::Done { + finish_reason: reason, + } + } + StreamEvent::Error { error } => UiStreamEvent::Error { + code: "stream_error".to_string(), + message: error, + }, + StreamEvent::ThinkingDelta { delta } => UiStreamEvent::Thinking { delta }, + StreamEvent::SyntheticStreamingNotice => UiStreamEvent::Error { + code: "notice".to_string(), + message: "synthetic streaming".to_string(), + }, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use rusty_ai::types::FinishReason; + + #[test] + fn test_start_conversion() { + let stream_event = StreamEvent::MessageStart { + message_id: "msg-123".into(), + }; + let ui_event: UiStreamEvent = stream_event.into(); + match ui_event { + UiStreamEvent::Start { + message_id, + version, + .. + } => { + assert_eq!(message_id, "msg-123"); + assert_eq!(version, PROTOCOL_VERSION); + } + _ => panic!("expected Start event"), + } + } + + #[test] + fn test_text_delta_conversion() { + let stream_event = StreamEvent::TextDelta { + delta: "Hello".into(), + }; + let ui_event: UiStreamEvent = stream_event.into(); + match ui_event { + UiStreamEvent::Text { delta } => assert_eq!(delta, "Hello"), + _ => panic!("expected Text event"), + } + } + + #[test] + fn test_serde_roundtrip() { + let event = UiStreamEvent::Error { + code: "rate_limit".into(), + message: "Too many requests".into(), + }; + let json = serde_json::to_string(&event).unwrap(); + let deserialized: UiStreamEvent = serde_json::from_str(&json).unwrap(); + match deserialized { + UiStreamEvent::Error { code, message } => { + assert_eq!(code, "rate_limit"); + assert_eq!(message, "Too many requests"); + } + _ => panic!("expected Error event"), + } + } + + #[test] + fn test_done_conversion() { + let stream_event = StreamEvent::MessageEnd { + finish_reason: FinishReason::Stop, + usage: None, + }; + let ui_event: UiStreamEvent = stream_event.into(); + match ui_event { + UiStreamEvent::Done { finish_reason } => assert_eq!(finish_reason, "stop"), + _ => panic!("expected Done event"), + } + } +} diff --git a/crates/rusty_ui_stream/src/lib.rs b/crates/rusty_ui_stream/src/lib.rs new file mode 100644 index 0000000..8ac9e3b --- /dev/null +++ b/crates/rusty_ui_stream/src/lib.rs @@ -0,0 +1,14 @@ +//! UI stream protocol for frontend integration with the Rusty AI SDK. +//! +//! This crate provides typed UI stream events and encoders for two common +//! wire formats: **SSE** (Server-Sent Events) and **NDJSON** (Newline-Delimited +//! JSON). Both encoders accept an [`rusty_ai::AiStream`] from a Rusty AI provider and +//! produce a byte stream suitable for sending over HTTP. + +mod event; +mod ndjson; +mod sse; + +pub use event::*; +pub use ndjson::*; +pub use sse::*; diff --git a/crates/rusty_ui_stream/src/ndjson.rs b/crates/rusty_ui_stream/src/ndjson.rs new file mode 100644 index 0000000..b5193ec --- /dev/null +++ b/crates/rusty_ui_stream/src/ndjson.rs @@ -0,0 +1,147 @@ +use bytes::Bytes; +use futures::Stream; +use rusty_ai::{AiError, AiStream}; +use tokio_stream::StreamExt; + +use crate::UiStreamEvent; + +/// Encoder for the [NDJSON](http://ndjson.org/) (Newline-Delimited JSON) +/// wire format. +/// +/// Each event is serialized as a single line of JSON followed by a newline: +/// ```text +/// {"type":"text","delta":"Hello"}\n +/// ``` +pub struct NdjsonEncoder; + +impl NdjsonEncoder { + /// Encode a single [`UiStreamEvent`] as NDJSON (one JSON object followed + /// by `\n`). + pub fn encode(event: &UiStreamEvent) -> String { + let json = serde_json::to_string(event).unwrap_or_else(|e| { + format!( + r#"{{"type":"error","code":"serialization_error","message":"{}"}}"#, + e.to_string().replace('"', "\\\"") + ) + }); + format!("{json}\n") + } + + /// Transform an [`AiStream`] into a stream of NDJSON-encoded [`Bytes`] + /// chunks. + /// + /// Each `Ok` item from the source stream is converted to a + /// [`UiStreamEvent`] and then NDJSON-encoded. Errors from the source + /// stream are converted into [`UiStreamEvent::Error`] events so the + /// client always receives well-formed NDJSON. + pub fn encode_stream(stream: AiStream) -> impl Stream> { + stream.map(|result| match result { + Ok(event) => { + let ui_event: UiStreamEvent = event.into(); + let encoded = Self::encode(&ui_event); + Ok(Bytes::from(encoded)) + } + Err(e) => { + let error_event = UiStreamEvent::Error { + code: "stream_error".to_string(), + message: e.to_string(), + }; + let encoded = Self::encode(&error_event); + Ok(Bytes::from(encoded)) + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::PROTOCOL_VERSION; + + #[test] + fn test_encode_text_event() { + let event = UiStreamEvent::Text { + delta: "hello".into(), + }; + let ndjson = NdjsonEncoder::encode(&event); + assert!(ndjson.ends_with('\n')); + // Should be exactly one line. + assert_eq!(ndjson.matches('\n').count(), 1); + let parsed: UiStreamEvent = serde_json::from_str(ndjson.trim()).unwrap(); + match parsed { + UiStreamEvent::Text { delta } => assert_eq!(delta, "hello"), + _ => panic!("unexpected variant"), + } + } + + #[test] + fn test_encode_start_event() { + let event = UiStreamEvent::Start { + message_id: "abc".into(), + model: "test-model".into(), + version: PROTOCOL_VERSION.into(), + }; + let ndjson = NdjsonEncoder::encode(&event); + assert!(ndjson.contains(r#""type":"start""#)); + assert!(ndjson.ends_with('\n')); + } + + #[tokio::test] + async fn test_encode_stream() { + use futures::stream; + use rusty_ai::StreamEvent; + + use tokio_stream::StreamExt; + + let events: Vec> = vec![ + Ok(StreamEvent::MessageStart { + message_id: "m1".into(), + }), + Ok(StreamEvent::TextDelta { delta: "Hi".into() }), + Ok(StreamEvent::MessageEnd { + finish_reason: rusty_ai::FinishReason::Stop, + usage: None, + }), + ]; + + let ai_stream: AiStream = Box::pin(stream::iter(events)); + + let encoded: Vec<_> = NdjsonEncoder::encode_stream(ai_stream).collect().await; + + assert_eq!(encoded.len(), 3); + for item in &encoded { + assert!(item.is_ok()); + let bytes = item.as_ref().unwrap(); + let s = std::str::from_utf8(bytes).unwrap(); + assert!(s.ends_with('\n')); + // Each line should be valid JSON. + let _: serde_json::Value = serde_json::from_str(s.trim()).unwrap(); + } + } + + #[tokio::test] + async fn test_encode_stream_with_error() { + use futures::stream; + use rusty_ai::StreamEvent; + + use tokio_stream::StreamExt; + + let events: Vec> = vec![ + Ok(StreamEvent::TextDelta { delta: "Hi".into() }), + Err(AiError::StreamError { + message: "timeout".into(), + }), + ]; + + let ai_stream: AiStream = Box::pin(stream::iter(events)); + + let encoded: Vec<_> = NdjsonEncoder::encode_stream(ai_stream).collect().await; + + assert_eq!(encoded.len(), 2); + assert!(encoded[1].is_ok()); + let bytes = encoded[1].as_ref().unwrap(); + let s = std::str::from_utf8(bytes).unwrap(); + assert!(s.contains("stream_error")); + assert!(s.contains("timeout")); + } +} diff --git a/crates/rusty_ui_stream/src/sse.rs b/crates/rusty_ui_stream/src/sse.rs new file mode 100644 index 0000000..0c126b7 --- /dev/null +++ b/crates/rusty_ui_stream/src/sse.rs @@ -0,0 +1,162 @@ +use bytes::Bytes; +use futures::Stream; +use rusty_ai::{AiError, AiStream}; +use tokio_stream::StreamExt; + +use crate::UiStreamEvent; + +/// Encoder for the [Server-Sent Events](https://html.spec.whatwg.org/multipage/server-sent-events.html) +/// wire format. +/// +/// Each event is serialized as: +/// ```text +/// data: {"type":"text","delta":"Hello"}\n\n +/// ``` +pub struct SseEncoder; + +impl SseEncoder { + /// Encode a single [`UiStreamEvent`] into SSE format. + /// + /// The output has the form `data: {json}\n\n` and is ready to be written + /// directly to an HTTP response body. + pub fn encode(event: &UiStreamEvent) -> String { + // serde_json::to_string never fails for our types (no maps with + // non-string keys, etc.), but we handle the error defensively. + let json = serde_json::to_string(event).unwrap_or_else(|e| { + // Fall back to an error event so the client always gets valid SSE. + format!( + r#"{{"type":"error","code":"serialization_error","message":"{}"}}"#, + e.to_string().replace('"', "\\\"") + ) + }); + format!("data: {json}\n\n") + } + + /// Transform an [`AiStream`] into a stream of SSE-encoded [`Bytes`] chunks. + /// + /// Each `Ok` item from the source stream is converted to a + /// [`UiStreamEvent`] and then SSE-encoded. Errors from the source stream + /// are forwarded as [`UiStreamEvent::Error`] events so the client always + /// receives well-formed SSE data, followed by the original error being + /// propagated. + pub fn encode_stream(stream: AiStream) -> impl Stream> { + stream.map(|result| match result { + Ok(event) => { + let ui_event: UiStreamEvent = event.into(); + let encoded = Self::encode(&ui_event); + Ok(Bytes::from(encoded)) + } + Err(e) => { + // Emit an SSE error frame so the client can handle it, then + // propagate the original error to let the transport layer + // know the stream is unhealthy. + let error_event = UiStreamEvent::Error { + code: "stream_error".to_string(), + message: e.to_string(), + }; + let encoded = Self::encode(&error_event); + // We choose to return the encoded error event as a successful + // byte chunk so the client receives it. The stream will + // naturally end after the source yields this error. + // + // If callers prefer to propagate the error, they can inspect + // the UiStreamEvent::Error on the client side. + Ok(Bytes::from(encoded)) + } + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::PROTOCOL_VERSION; + + #[test] + fn test_encode_text_event() { + let event = UiStreamEvent::Text { + delta: "world".into(), + }; + let sse = SseEncoder::encode(&event); + assert!(sse.starts_with("data: ")); + assert!(sse.ends_with("\n\n")); + // The JSON payload should be valid. + let json_str = sse.trim_start_matches("data: ").trim(); + let parsed: UiStreamEvent = serde_json::from_str(json_str).unwrap(); + match parsed { + UiStreamEvent::Text { delta } => assert_eq!(delta, "world"), + _ => panic!("unexpected variant"), + } + } + + #[test] + fn test_encode_start_event() { + let event = UiStreamEvent::Start { + message_id: "abc".into(), + model: "test-model".into(), + version: PROTOCOL_VERSION.into(), + }; + let sse = SseEncoder::encode(&event); + assert!(sse.contains(r#""type":"start""#)); + assert!(sse.contains(r#""version":"1.0""#)); + } + + #[tokio::test] + async fn test_encode_stream() { + use futures::stream; + use rusty_ai::StreamEvent; + + use tokio_stream::StreamExt; + + let events: Vec> = vec![ + Ok(StreamEvent::MessageStart { + message_id: "m1".into(), + }), + Ok(StreamEvent::TextDelta { delta: "Hi".into() }), + Ok(StreamEvent::MessageEnd { + finish_reason: rusty_ai::FinishReason::Stop, + usage: None, + }), + ]; + + let ai_stream: AiStream = Box::pin(stream::iter(events)); + + let encoded: Vec<_> = SseEncoder::encode_stream(ai_stream).collect().await; + + assert_eq!(encoded.len(), 3); + for item in &encoded { + assert!(item.is_ok()); + let bytes = item.as_ref().unwrap(); + let s = std::str::from_utf8(bytes).unwrap(); + assert!(s.starts_with("data: ")); + assert!(s.ends_with("\n\n")); + } + } + + #[tokio::test] + async fn test_encode_stream_with_error() { + use futures::stream; + use rusty_ai::StreamEvent; + + use tokio_stream::StreamExt; + + let events: Vec> = vec![ + Ok(StreamEvent::TextDelta { delta: "Hi".into() }), + Err(AiError::StreamError { + message: "connection lost".into(), + }), + ]; + + let ai_stream: AiStream = Box::pin(stream::iter(events)); + + let encoded: Vec<_> = SseEncoder::encode_stream(ai_stream).collect().await; + + assert_eq!(encoded.len(), 2); + // The error should have been encoded as an SSE event (Ok bytes). + assert!(encoded[1].is_ok()); + let bytes = encoded[1].as_ref().unwrap(); + let s = std::str::from_utf8(bytes).unwrap(); + assert!(s.contains("stream_error")); + assert!(s.contains("connection lost")); + } +} diff --git a/examples/basic_text/Cargo.toml b/examples/basic_text/Cargo.toml new file mode 100644 index 0000000..16a2788 --- /dev/null +++ b/examples/basic_text/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "example_basic_text" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +rusty_ai = { workspace = true } +rusty_chatgpt = { workspace = true } +rusty_claude = { workspace = true } +tokio = { workspace = true } diff --git a/examples/basic_text/src/main.rs b/examples/basic_text/src/main.rs new file mode 100644 index 0000000..ab9e722 --- /dev/null +++ b/examples/basic_text/src/main.rs @@ -0,0 +1,25 @@ +use rusty_ai::*; +use rusty_chatgpt::ChatGptProvider; +use rusty_claude::ClaudeProvider; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Example 1: Using ChatGPT + let chatgpt = + ChatGptProvider::new(std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY required")); + let model = chatgpt.gpt4o_mini(); + + let result = generate_text(&model, "What is Rust programming language?").await?; + println!("ChatGPT says: {result}"); + + // Example 2: Using Claude + let claude = ClaudeProvider::new( + std::env::var("ANTHROPIC_API_KEY").expect("ANTHROPIC_API_KEY required"), + ); + let model = claude.claude_sonnet(); + + let result = generate_text(&model, "What is Rust programming language?").await?; + println!("Claude says: {result}"); + + Ok(()) +} diff --git a/examples/generate_object/Cargo.toml b/examples/generate_object/Cargo.toml new file mode 100644 index 0000000..d25523a --- /dev/null +++ b/examples/generate_object/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "example_generate_object" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +rusty_ai = { workspace = true } +rusty_chatgpt = { workspace = true } +tokio = { workspace = true } +serde = { workspace = true } +schemars = { workspace = true } diff --git a/examples/generate_object/src/main.rs b/examples/generate_object/src/main.rs new file mode 100644 index 0000000..704b972 --- /dev/null +++ b/examples/generate_object/src/main.rs @@ -0,0 +1,40 @@ +use rusty_ai::model::generate_object; +use rusty_ai::*; +use rusty_chatgpt::ChatGptProvider; +use schemars::JsonSchema; +use serde::Deserialize; + +#[derive(Debug, Deserialize, JsonSchema)] +struct Recipe { + name: String, + ingredients: Vec, + steps: Vec, + prep_time_minutes: u32, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let provider = + ChatGptProvider::new(std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY required")); + let model = provider.gpt4o_mini(); + + let result: ObjectResult = generate_object( + &model, + Prompt::from("Generate a recipe for chocolate chip cookies"), + GenerateOptions::default(), + ) + .await?; + + println!("Recipe: {}", result.object.name); + println!("Prep time: {} minutes", result.object.prep_time_minutes); + println!("\nIngredients:"); + for ingredient in &result.object.ingredients { + println!(" - {ingredient}"); + } + println!("\nSteps:"); + for (i, step) in result.object.steps.iter().enumerate() { + println!(" {}. {step}", i + 1); + } + + Ok(()) +} diff --git a/examples/local_android/Cargo.toml b/examples/local_android/Cargo.toml new file mode 100644 index 0000000..3233aa1 --- /dev/null +++ b/examples/local_android/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "example_local_android" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +rusty_ai = { workspace = true } +rusty_gemini_nano = { workspace = true } +tokio = { workspace = true } +async-trait = { workspace = true } diff --git a/examples/local_android/src/main.rs b/examples/local_android/src/main.rs new file mode 100644 index 0000000..fce897c --- /dev/null +++ b/examples/local_android/src/main.rs @@ -0,0 +1,70 @@ +use async_trait::async_trait; +use rusty_ai::*; +use rusty_gemini_nano::*; + +/// Example bridge implementation (in a real app, this would call JNI) +struct MockNanoBridge; + +#[async_trait] +impl GeminiNanoBridge for MockNanoBridge { + async fn is_available(&self) -> bool { + true + } + + async fn download_state(&self) -> ModelDownloadState { + ModelDownloadState::Downloaded + } + + async fn request_download(&self) -> Result<(), String> { + Ok(()) + } + + async fn capabilities(&self) -> NanoCapabilities { + NanoCapabilities { + text_generation: true, + summarization: true, + rewriting: true, + } + } + + async fn generate(&self, prompt: &str, _config: &NanoSessionConfig) -> Result { + Ok(format!("[Gemini Nano mock response to: {prompt}]")) + } + + async fn create_session(&self, _config: &NanoSessionConfig) -> Result { + Ok("mock-session-1".into()) + } + + async fn send_message(&self, _session_id: &str, message: &str) -> Result { + Ok(format!("[Session reply to: {message}]")) + } + + async fn close_session(&self, _session_id: &str) -> Result<(), String> { + Ok(()) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let provider = GeminiNanoProvider::new(MockNanoBridge); + + // Check availability + println!("Gemini Nano available: {}", provider.is_available().await); + println!("Download state: {:?}", provider.download_state().await); + + // Single-turn generation + let model = provider.model(); + let result = generate_text(&model, "Summarize the Rust programming language").await?; + println!("Response: {result}"); + + // Multi-turn session + let session = provider + .create_session(NanoSessionConfig::default()) + .await?; + let reply1 = session.send("What is Rust?").await?; + println!("Session reply 1: {reply1}"); + let reply2 = session.send("What about its memory safety?").await?; + println!("Session reply 2: {reply2}"); + + Ok(()) +} diff --git a/examples/local_apple/Cargo.toml b/examples/local_apple/Cargo.toml new file mode 100644 index 0000000..de9d3b7 --- /dev/null +++ b/examples/local_apple/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "example_local_apple" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +rusty_ai = { workspace = true } +rusty_foundationmodels = { workspace = true } +tokio = { workspace = true } +async-trait = { workspace = true } diff --git a/examples/local_apple/src/main.rs b/examples/local_apple/src/main.rs new file mode 100644 index 0000000..eb5e241 --- /dev/null +++ b/examples/local_apple/src/main.rs @@ -0,0 +1,162 @@ +//! Example: Apple Foundation Models integration. +//! +//! This example demonstrates how the Apple Foundation Models provider would be +//! used once the `rusty_foundationmodels` crate is fully implemented. +//! +//! On a real Apple device the bridge would call into the Foundation Models +//! framework via Swift/Objective-C interop. Here we show the intended usage +//! pattern with a mock bridge that mirrors the Gemini Nano example. + +use async_trait::async_trait; +use rusty_ai::capability::{Capability, CapabilitySet}; +use rusty_ai::error::{AiError, AiResult}; +use rusty_ai::model::{GenerateOptions, LanguageModel}; +use rusty_ai::prompt::Prompt; +use rusty_ai::stream::{AiStream, SyntheticStreamer}; +use rusty_ai::structured::GenerateResult; +use rusty_ai::types::{FinishReason, ResponseMetadata}; +use rusty_ai::usage::Usage; +use rusty_ai::*; + +/// Trait representing the bridge to Apple Foundation Models on-device runtime. +#[async_trait] +trait FoundationModelsBridge: Send + Sync { + async fn is_available(&self) -> bool; + async fn generate(&self, prompt: &str) -> Result; +} + +/// Mock bridge for demonstration purposes. +struct MockAppleBridge; + +#[async_trait] +impl FoundationModelsBridge for MockAppleBridge { + async fn is_available(&self) -> bool { + true + } + + async fn generate(&self, prompt: &str) -> Result { + Ok(format!( + "[Apple Foundation Models mock response to: {prompt}]" + )) + } +} + +/// A language model backed by Apple Foundation Models. +struct FoundationModel { + bridge: Box, + capabilities: CapabilitySet, +} + +impl FoundationModel { + fn new(bridge: impl FoundationModelsBridge + 'static) -> Self { + let capabilities = CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::LocalExecution) + .with(Capability::PlatformNative); + Self { + bridge: Box::new(bridge), + capabilities, + } + } +} + +#[async_trait] +impl LanguageModel for FoundationModel { + fn model_id(&self) -> &str { + "apple-foundation-model" + } + + fn provider_id(&self) -> &str { + "apple" + } + + fn capabilities(&self) -> &CapabilitySet { + &self.capabilities + } + + async fn generate( + &self, + prompt: Prompt, + _options: GenerateOptions, + ) -> AiResult { + if !self.bridge.is_available().await { + return Err(AiError::PlatformUnavailable { + platform: "apple/foundation_models".into(), + }); + } + + let prompt_text = match prompt { + Prompt::Text(t) => t, + Prompt::Messages(msgs) => msgs + .into_iter() + .filter_map(|m| { + m.content.into_iter().find_map(|part| { + if let ContentPart::Text { text } = part { + Some(text) + } else { + None + } + }) + }) + .collect::>() + .join("\n"), + }; + + let text = self + .bridge + .generate(&prompt_text) + .await + .map_err(|e| AiError::BridgeError { + bridge: "apple_foundation_models".into(), + message: e, + })?; + + Ok(GenerateResult { + text: Some(text), + tool_calls: Vec::new(), + finish_reason: FinishReason::Stop, + usage: Usage::default(), + metadata: ResponseMetadata { + provider: "apple".into(), + model: "apple-foundation-model".into(), + ..Default::default() + }, + }) + } + + async fn stream(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + let result = self.generate(prompt, options).await?; + let text = result.text.unwrap_or_default(); + Ok(SyntheticStreamer::stream(text, 20)) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let model = FoundationModel::new(MockAppleBridge); + + // Check availability + println!( + "Apple Foundation Model available: {}", + model.bridge.is_available().await + ); + + // Generate text + let result = generate_text(&model, "Explain Swift concurrency in one paragraph").await?; + println!("Response: {result}"); + + // Use via the LanguageModel trait with custom options + let options = GenerateOptions::default().with_temperature(0.7); + let result = model + .generate( + Prompt::from("What is the latest version of macOS?"), + options, + ) + .await?; + if let Some(text) = &result.text { + println!("With options: {text}"); + } + + Ok(()) +} diff --git a/examples/local_windows/Cargo.toml b/examples/local_windows/Cargo.toml new file mode 100644 index 0000000..0d8d846 --- /dev/null +++ b/examples/local_windows/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "example_local_windows" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +rusty_ai = { workspace = true } +rusty_phi_silica = { workspace = true } +tokio = { workspace = true } +async-trait = { workspace = true } diff --git a/examples/local_windows/src/main.rs b/examples/local_windows/src/main.rs new file mode 100644 index 0000000..512720a --- /dev/null +++ b/examples/local_windows/src/main.rs @@ -0,0 +1,160 @@ +//! Example: Windows Phi Silica integration. +//! +//! This example demonstrates how the Phi Silica provider would be used once the +//! `rusty_phi_silica` crate is fully implemented. +//! +//! On a real Windows device the bridge would call into the Windows Copilot +//! Runtime via the Windows App SDK. Here we show the intended usage pattern +//! with a mock bridge that mirrors the Gemini Nano example. + +use async_trait::async_trait; +use rusty_ai::capability::{Capability, CapabilitySet}; +use rusty_ai::error::{AiError, AiResult}; +use rusty_ai::model::{GenerateOptions, LanguageModel}; +use rusty_ai::prompt::Prompt; +use rusty_ai::stream::{AiStream, SyntheticStreamer}; +use rusty_ai::structured::GenerateResult; +use rusty_ai::types::{FinishReason, ResponseMetadata}; +use rusty_ai::usage::Usage; +use rusty_ai::*; + +/// Trait representing the bridge to Phi Silica on Windows. +#[async_trait] +trait PhiSilicaBridge: Send + Sync { + async fn is_available(&self) -> bool; + async fn generate(&self, prompt: &str) -> Result; +} + +/// Mock bridge for demonstration purposes. +struct MockPhiSilicaBridge; + +#[async_trait] +impl PhiSilicaBridge for MockPhiSilicaBridge { + async fn is_available(&self) -> bool { + true + } + + async fn generate(&self, prompt: &str) -> Result { + Ok(format!("[Phi Silica mock response to: {prompt}]")) + } +} + +/// A language model backed by Phi Silica on Windows. +struct PhiSilicaModel { + bridge: Box, + capabilities: CapabilitySet, +} + +impl PhiSilicaModel { + fn new(bridge: impl PhiSilicaBridge + 'static) -> Self { + let capabilities = CapabilitySet::new() + .with(Capability::TextInput) + .with(Capability::TextOutput) + .with(Capability::LocalExecution) + .with(Capability::PlatformNative); + Self { + bridge: Box::new(bridge), + capabilities, + } + } +} + +#[async_trait] +impl LanguageModel for PhiSilicaModel { + fn model_id(&self) -> &str { + "phi-silica" + } + + fn provider_id(&self) -> &str { + "windows" + } + + fn capabilities(&self) -> &CapabilitySet { + &self.capabilities + } + + async fn generate( + &self, + prompt: Prompt, + _options: GenerateOptions, + ) -> AiResult { + if !self.bridge.is_available().await { + return Err(AiError::PlatformUnavailable { + platform: "windows/phi_silica".into(), + }); + } + + let prompt_text = match prompt { + Prompt::Text(t) => t, + Prompt::Messages(msgs) => msgs + .into_iter() + .filter_map(|m| { + m.content.into_iter().find_map(|part| { + if let ContentPart::Text { text } = part { + Some(text) + } else { + None + } + }) + }) + .collect::>() + .join("\n"), + }; + + let text = self + .bridge + .generate(&prompt_text) + .await + .map_err(|e| AiError::BridgeError { + bridge: "phi_silica".into(), + message: e, + })?; + + Ok(GenerateResult { + text: Some(text), + tool_calls: Vec::new(), + finish_reason: FinishReason::Stop, + usage: Usage::default(), + metadata: ResponseMetadata { + provider: "windows".into(), + model: "phi-silica".into(), + ..Default::default() + }, + }) + } + + async fn stream(&self, prompt: Prompt, options: GenerateOptions) -> AiResult { + let result = self.generate(prompt, options).await?; + let text = result.text.unwrap_or_default(); + Ok(SyntheticStreamer::stream(text, 20)) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let model = PhiSilicaModel::new(MockPhiSilicaBridge); + + // Check availability + println!( + "Phi Silica available: {}", + model.bridge.is_available().await + ); + + // Generate text + let result = generate_text(&model, "What is the Windows Copilot Runtime?").await?; + println!("Response: {result}"); + + // Use via the LanguageModel trait with custom options + let options = GenerateOptions::default().with_max_tokens(256); + let result = model + .generate( + Prompt::from("Describe the Phi model family in one paragraph"), + options, + ) + .await?; + if let Some(text) = &result.text { + println!("With options: {text}"); + } + + Ok(()) +} diff --git a/examples/multimodal/Cargo.toml b/examples/multimodal/Cargo.toml new file mode 100644 index 0000000..cdd59eb --- /dev/null +++ b/examples/multimodal/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "example_multimodal" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +rusty_ai = { workspace = true } +rusty_chatgpt = { workspace = true } +tokio = { workspace = true } diff --git a/examples/multimodal/src/main.rs b/examples/multimodal/src/main.rs new file mode 100644 index 0000000..d3dfb49 --- /dev/null +++ b/examples/multimodal/src/main.rs @@ -0,0 +1,28 @@ +use rusty_ai::*; +use rusty_chatgpt::ChatGptProvider; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let provider = + ChatGptProvider::new(std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY required")); + let model = provider.gpt4o(); + + // Create a message with an image URL + let message = Message::user("Describe this image in detail.").with_image(ImageData::Url { + url: "https://upload.wikimedia.org/wikipedia/commons/thumb/d/d5/\ + Rust_programming_language_black_logo.svg/\ + 1200px-Rust_programming_language_black_logo.svg.png" + .into(), + detail: Some(ImageDetail::Auto), + }); + + let result = model + .generate(Prompt::Messages(vec![message]), GenerateOptions::default()) + .await?; + + if let Some(text) = &result.text { + println!("Description: {text}"); + } + + Ok(()) +} diff --git a/examples/router/Cargo.toml b/examples/router/Cargo.toml new file mode 100644 index 0000000..cb9c0c6 --- /dev/null +++ b/examples/router/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "example_router" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +rusty_ai = { workspace = true } +rusty_testing = { workspace = true } +tokio = { workspace = true } diff --git a/examples/router/src/main.rs b/examples/router/src/main.rs new file mode 100644 index 0000000..6d21f25 --- /dev/null +++ b/examples/router/src/main.rs @@ -0,0 +1,34 @@ +use rusty_ai::*; +use rusty_testing::*; + +#[tokio::main] +async fn main() -> Result<(), Box> { + // Create mock models + let local_model = MockLanguageModel::new("local-llm").with_text("Hello from the local model!"); + + let cloud_model = MockLanguageModel::new("cloud-llm").with_text("Hello from the cloud model!"); + + // Create a local-first router: tries the local model first, falls back to cloud + let router = Router::local_first(Box::new(local_model), Box::new(cloud_model)); + + // The router will try the local model first + let result = generate_text(&router, "Hello!").await?; + println!("Router result: {result}"); + + // Create a capability-based router + let text_model = MockLanguageModel::new("text-model").with_text("I handle text!"); + + let tool_model = MockLanguageModel::new("tool-model").with_text("I handle tools!"); + + let router = Router::new() + .add_route(Box::new(tool_model), |_prompt, options| { + options.tools.is_some() && !options.tools.as_ref().unwrap().is_empty() + }) + .with_fallback(Box::new(text_model)); + + // Simple text request -- no tools, goes to fallback + let result = generate_text(&router, "Hello!").await?; + println!("Simple text: {result}"); + + Ok(()) +} diff --git a/examples/stream_object/Cargo.toml b/examples/stream_object/Cargo.toml new file mode 100644 index 0000000..ff1436d --- /dev/null +++ b/examples/stream_object/Cargo.toml @@ -0,0 +1,14 @@ +[package] +name = "example_stream_object" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +rusty_ai = { workspace = true } +rusty_chatgpt = { workspace = true } +tokio = { workspace = true } +futures = { workspace = true } +serde = { workspace = true } +serde_json = { workspace = true } +schemars = { workspace = true } diff --git a/examples/stream_object/src/main.rs b/examples/stream_object/src/main.rs new file mode 100644 index 0000000..db83538 --- /dev/null +++ b/examples/stream_object/src/main.rs @@ -0,0 +1,48 @@ +use futures::StreamExt; +use rusty_ai::*; +use rusty_chatgpt::ChatGptProvider; +use schemars::JsonSchema; +use serde::Deserialize; + +#[derive(Debug, Deserialize, JsonSchema)] +#[allow(dead_code)] +struct MovieReview { + title: String, + rating: f32, + summary: String, + pros: Vec, + cons: Vec, +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let provider = + ChatGptProvider::new(std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY required")); + let model = provider.gpt4o(); + + let options = GenerateOptions { + output_schema: Some(OutputSchema::from_type::()), + ..Default::default() + }; + + let mut stream = model + .stream( + Prompt::from("Write a review for the movie 'Inception'"), + options, + ) + .await?; + + println!("Streaming object deltas:"); + while let Some(event) = stream.next().await { + match event? { + StreamEvent::TextDelta { delta } => print!("{delta}"), + StreamEvent::ObjectDelta { delta } => { + println!("Object delta: {delta}"); + } + StreamEvent::MessageEnd { .. } => println!("\n\nDone!"), + _ => {} + } + } + + Ok(()) +} diff --git a/examples/stream_text/Cargo.toml b/examples/stream_text/Cargo.toml new file mode 100644 index 0000000..ddc6473 --- /dev/null +++ b/examples/stream_text/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "example_stream_text" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +rusty_ai = { workspace = true } +rusty_chatgpt = { workspace = true } +tokio = { workspace = true } +futures = { workspace = true } diff --git a/examples/stream_text/src/main.rs b/examples/stream_text/src/main.rs new file mode 100644 index 0000000..717c3f5 --- /dev/null +++ b/examples/stream_text/src/main.rs @@ -0,0 +1,31 @@ +use futures::StreamExt; +use rusty_ai::*; +use rusty_chatgpt::ChatGptProvider; + +#[tokio::main] +async fn main() -> Result<(), Box> { + let provider = + ChatGptProvider::new(std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY required")); + let model = provider.gpt4o_mini(); + + let mut stream = stream_text(&model, "Write a haiku about Rust programming").await?; + + print!("Streaming: "); + while let Some(event) = stream.next().await { + match event? { + StreamEvent::TextDelta { delta } => print!("{delta}"), + StreamEvent::MessageEnd { + finish_reason, + usage, + } => { + println!("\n\nFinished: {finish_reason:?}"); + if let Some(usage) = usage { + println!("Tokens used: {:?}", usage.total_tokens); + } + } + _ => {} + } + } + + Ok(()) +} diff --git a/examples/tool_loop/Cargo.toml b/examples/tool_loop/Cargo.toml new file mode 100644 index 0000000..8cbc0e4 --- /dev/null +++ b/examples/tool_loop/Cargo.toml @@ -0,0 +1,12 @@ +[package] +name = "example_tool_loop" +version.workspace = true +edition.workspace = true +license.workspace = true + +[dependencies] +rusty_ai = { workspace = true } +rusty_chatgpt = { workspace = true } +tokio = { workspace = true } +async-trait = { workspace = true } +serde_json = { workspace = true } diff --git a/examples/tool_loop/src/main.rs b/examples/tool_loop/src/main.rs new file mode 100644 index 0000000..867d551 --- /dev/null +++ b/examples/tool_loop/src/main.rs @@ -0,0 +1,141 @@ +use async_trait::async_trait; +use rusty_ai::tool::Tool; +use rusty_ai::*; +use rusty_chatgpt::ChatGptProvider; + +// Define a calculator tool +struct CalculatorTool; + +#[async_trait] +impl Tool for CalculatorTool { + fn definition(&self) -> ToolDefinition { + ToolDefinition { + name: "calculator".into(), + description: "Performs basic arithmetic. Input: JSON with 'operation' (add/subtract/multiply/divide) and 'a', 'b' numbers.".into(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "operation": { "type": "string", "enum": ["add", "subtract", "multiply", "divide"] }, + "a": { "type": "number" }, + "b": { "type": "number" } + }, + "required": ["operation", "a", "b"] + }), + } + } + + async fn execute(&self, args: serde_json::Value) -> Result { + let op = args["operation"].as_str().unwrap_or("add"); + let a = args["a"].as_f64().unwrap_or(0.0); + let b = args["b"].as_f64().unwrap_or(0.0); + let result = match op { + "add" => a + b, + "subtract" => a - b, + "multiply" => a * b, + "divide" => { + if b != 0.0 { + a / b + } else { + return Ok("Error: division by zero".into()); + } + } + _ => return Ok(format!("Unknown operation: {op}")), + }; + Ok(format!("{result}")) + } +} + +// Define a weather tool +struct WeatherTool; + +#[async_trait] +impl Tool for WeatherTool { + fn definition(&self) -> ToolDefinition { + ToolDefinition { + name: "get_weather".into(), + description: "Get the current weather for a city.".into(), + parameters: serde_json::json!({ + "type": "object", + "properties": { + "city": { "type": "string" } + }, + "required": ["city"] + }), + } + } + + async fn execute(&self, args: serde_json::Value) -> Result { + let city = args["city"].as_str().unwrap_or("Unknown"); + // Fake weather data for demonstration + Ok(format!( + "Weather in {city}: 72\u{00b0}F, sunny with light clouds" + )) + } +} + +#[tokio::main] +async fn main() -> Result<(), Box> { + let provider = + ChatGptProvider::new(std::env::var("OPENAI_API_KEY").expect("OPENAI_API_KEY required")); + let model = provider.gpt4o_mini(); + + let mut tools = ToolSet::new(); + tools.add(CalculatorTool); + tools.add(WeatherTool); + + let prompt = Prompt::from("What's the weather in Tokyo? Also, what is 42 * 17?"); + let max_steps = 5; + + // Tool calling loop + let mut messages = prompt.into_messages(); + for step in 0..max_steps { + println!("--- Step {step} ---"); + + let options = GenerateOptions { + tools: Some(tools.definitions()), + tool_choice: Some(ToolChoice::Auto), + ..Default::default() + }; + + let result = model + .generate(Prompt::Messages(messages.clone()), options) + .await?; + + if let Some(text) = &result.text { + println!("Assistant: {text}"); + } + + if result.tool_calls.is_empty() || result.finish_reason != FinishReason::ToolCall { + println!("No more tool calls. Done!"); + break; + } + + // Add assistant message carrying the tool calls + let mut assistant_parts: Vec = Vec::new(); + if let Some(text) = &result.text { + assistant_parts.push(ContentPart::Text { text: text.clone() }); + } + for call in &result.tool_calls { + assistant_parts.push(ContentPart::ToolCall { call: call.clone() }); + } + messages.push(Message { + role: Role::Assistant, + content: assistant_parts, + name: None, + metadata: Default::default(), + }); + + // Execute each tool call and add results + for call in &result.tool_calls { + println!("Calling tool '{}' with args: {}", call.name, call.arguments); + let tool_result = tools.execute(call).await?; + println!("Tool result: {}", tool_result.content); + messages.push(Message::tool_result( + &tool_result.call_id, + &tool_result.content, + )); + } + } + + Ok(()) +} diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 0000000..3a26366 --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1 @@ +edition = "2021"