From 008de13b4d6fe34086fc0484c95bc10561dec88a Mon Sep 17 00:00:00 2001 From: shiwk Date: Tue, 2 Jun 2026 21:38:10 +0800 Subject: [PATCH 1/5] feat(namespace)!: add pluggable authentication provider framework BREAKING CHANGE: RestNamespaceBuilder::build() now returns Result; direct Rust callers must add `?`. ConnectBuilder::connect() users are unaffected. --- rust/lance-namespace-impls/src/connect.rs | 2 +- rust/lance-namespace-impls/src/lib.rs | 3 + rust/lance-namespace-impls/src/rest.rs | 404 ++++++++++++++---- .../lance-namespace-impls/src/rest_adapter.rs | 4 +- rust/lance-namespace-impls/src/rest_auth.rs | 97 +++++ 5 files changed, 421 insertions(+), 89 deletions(-) create mode 100644 rust/lance-namespace-impls/src/rest_auth.rs diff --git a/rust/lance-namespace-impls/src/connect.rs b/rust/lance-namespace-impls/src/connect.rs index c44eb2de219..63378367a96 100644 --- a/rust/lance-namespace-impls/src/connect.rs +++ b/rust/lance-namespace-impls/src/connect.rs @@ -184,7 +184,7 @@ impl ConnectBuilder { if let Some(provider) = self.context_provider { builder = builder.context_provider(provider); } - Ok(Arc::new(builder.build()) as Arc) + Ok(Arc::new(builder.build()?) as Arc) } #[cfg(not(feature = "rest"))] "rest" => Err(NamespaceError::Unsupported { diff --git a/rust/lance-namespace-impls/src/lib.rs b/rust/lance-namespace-impls/src/lib.rs index 58e29aca5ef..a17002c68f7 100644 --- a/rust/lance-namespace-impls/src/lib.rs +++ b/rust/lance-namespace-impls/src/lib.rs @@ -75,6 +75,7 @@ pub mod connect; pub mod context; pub mod credentials; pub mod dir; +pub mod rest_auth; #[cfg(feature = "rest")] pub mod rest; @@ -89,6 +90,8 @@ pub use dir::{ DirectoryNamespace, DirectoryNamespaceBuilder, OpsMetrics, manifest::ManifestNamespace, }; +pub use rest_auth::{NoopAuthProvider, RequestContext, RestAuthProvider}; + // Re-export credential vending pub use credentials::{ CredentialVendor, DEFAULT_CREDENTIAL_DURATION_MILLIS, VendedCredentials, diff --git a/rust/lance-namespace-impls/src/rest.rs b/rust/lance-namespace-impls/src/rest.rs index 27a563d2807..2771c527e2c 100644 --- a/rust/lance-namespace-impls/src/rest.rs +++ b/rust/lance-namespace-impls/src/rest.rs @@ -13,6 +13,10 @@ use async_trait::async_trait; use bytes::Bytes; use reqwest::header::{HeaderName, HeaderValue}; +use crate::rest_auth::{ + AUTH_PROPERTY_PREFIX, AUTH_TYPE_KEY, RequestContext, RestAuthProvider, create_auth_provider, +}; + use crate::context::{DynamicContextProvider, OperationInfo}; use lance_namespace::apis::urlencode; @@ -66,6 +70,7 @@ struct RestClient { base_path: String, base_headers: HashMap, context_provider: Option>, + auth_provider: Option>, } impl std::fmt::Debug for RestClient { @@ -77,58 +82,101 @@ impl std::fmt::Debug for RestClient { "context_provider", &self.context_provider.as_ref().map(|_| "Some(...)"), ) + .field( + "auth_provider", + &self.auth_provider.as_ref().map(|_| "Some(...)"), + ) .finish() } } +fn reqwest_to_lance_error(e: reqwest::Error) -> lance_core::Error { + let message = format!("Failed to execute request: {e:?}"); + if e.is_timeout() || e.is_connect() { + NamespaceError::ServiceUnavailable { message }.into() + } else { + NamespaceError::Internal { message }.into() + } +} + +fn apply_string_headers(headers: &mut reqwest::header::HeaderMap, pairs: I) +where + I: IntoIterator, + K: AsRef, + V: AsRef, +{ + for (k, v) in pairs { + if let (Ok(name), Ok(val)) = ( + HeaderName::from_str(k.as_ref()), + HeaderValue::from_str(v.as_ref()), + ) { + headers.insert(name, val); + } + } +} + impl RestClient { - /// Apply base headers and dynamic context headers to a request. - /// - /// This method mutates the request's headers directly, which is more efficient - /// than creating a new client with default_headers for each request. - fn apply_headers(&self, request: &mut reqwest::Request, operation: &str, object_id: &str) { - let request_headers = request.headers_mut(); - - // First apply base headers - for (key, value) in &self.base_headers { - if let (Ok(header_name), Ok(header_value)) = - (HeaderName::from_str(key), HeaderValue::from_str(value)) - { - request_headers.insert(header_name, header_value); - } + fn build_auth_context(request: &reqwest::Request) -> RequestContext { + let headers = request + .headers() + .iter() + .filter_map(|(k, v)| v.to_str().ok().map(|s| (k.to_string(), s.to_string()))) + .collect(); + RequestContext { + method: request.method().to_string(), + url: request.url().to_string(), + headers, + } + } + + /// Apply base, auth, then context headers. Context wins on conflict so it + /// can intentionally override auth headers per request. + async fn apply_headers( + &self, + request: &mut reqwest::Request, + operation: &str, + object_id: &str, + ) -> Result<()> { + apply_string_headers(request.headers_mut(), &self.base_headers); + + if let Some(auth) = &self.auth_provider { + let ctx = Self::build_auth_context(request); + let auth_headers = + auth.authenticate(&ctx) + .await + .map_err(|e| NamespaceError::Unauthenticated { + message: format!( + "auth provider failed for operation '{operation}' on '{object_id}': {e}" + ), + })?; + apply_string_headers(request.headers_mut(), auth_headers); } - // Then apply context headers (override base headers if conflict) if let Some(provider) = &self.context_provider { let info = OperationInfo::new(operation, object_id); - let context = provider.provide_context(&info); - const HEADERS_PREFIX: &str = "headers."; - for (key, value) in context { - if let Some(header_name) = key.strip_prefix(HEADERS_PREFIX) - && let (Ok(header_name), Ok(header_value)) = ( - HeaderName::from_str(header_name), - HeaderValue::from_str(&value), - ) - { - request_headers.insert(header_name, header_value); - } - } + let context_headers = provider + .provide_context(&info) + .into_iter() + .filter_map(|(k, v)| k.strip_prefix(HEADERS_PREFIX).map(|n| (n.to_string(), v))); + apply_string_headers(request.headers_mut(), context_headers); } + Ok(()) } - /// Execute a request with dynamic headers applied. - /// - /// This method builds the request, applies headers, and executes it. async fn execute( &self, req_builder: reqwest::RequestBuilder, operation: &str, object_id: &str, - ) -> std::result::Result { - let mut request = req_builder.build()?; - self.apply_headers(&mut request, operation, object_id); - self.client.execute(request).await + ) -> Result { + let mut request = req_builder.build().map_err(reqwest_to_lance_error)?; + self.apply_headers(&mut request, operation, object_id) + .await?; + self.client + .execute(request) + .await + .map_err(reqwest_to_lance_error) } /// Get the base path URL @@ -156,7 +204,7 @@ impl RestClient { /// let namespace = RestNamespaceBuilder::new("http://localhost:8080") /// .delimiter(".") /// .header("Authorization", "Bearer token") -/// .build(); +/// .build()?; /// # Ok(()) /// # } /// ``` @@ -172,6 +220,8 @@ pub struct RestNamespaceBuilder { context_provider: Option>, /// When true, tracks operation metrics. Default: false. ops_metrics_enabled: bool, + auth_provider: Option>, + auth_properties: HashMap, } impl std::fmt::Debug for RestNamespaceBuilder { @@ -189,6 +239,10 @@ impl std::fmt::Debug for RestNamespaceBuilder { &self.context_provider.as_ref().map(|_| "Some(...)"), ) .field("ops_metrics_enabled", &self.ops_metrics_enabled) + .field( + "auth_provider", + &self.auth_provider.as_ref().map(|_| "Some(...)"), + ) .finish() } } @@ -213,6 +267,8 @@ impl RestNamespaceBuilder { assert_hostname: true, context_provider: None, ops_metrics_enabled: false, + auth_provider: None, + auth_properties: HashMap::new(), } } @@ -252,7 +308,7 @@ impl RestNamespaceBuilder { /// properties.insert("header.Authorization".to_string(), "Bearer token".to_string()); /// /// let namespace = RestNamespaceBuilder::from_properties(properties)? - /// .build(); + /// .build()?; /// # Ok(()) /// # } /// ``` @@ -296,6 +352,12 @@ impl RestNamespaceBuilder { .and_then(|v| v.parse::().ok()) .unwrap_or(false); + let auth_properties: HashMap = properties + .iter() + .filter(|(k, _)| k.starts_with(AUTH_PROPERTY_PREFIX)) + .map(|(k, v)| (k.clone(), v.clone())) + .collect(); + Ok(Self { uri, delimiter, @@ -306,6 +368,8 @@ impl RestNamespaceBuilder { assert_hostname, context_provider: None, ops_metrics_enabled, + auth_provider: None, + auth_properties, }) } @@ -411,13 +475,19 @@ impl RestNamespaceBuilder { /// /// let namespace = RestNamespaceBuilder::new("http://localhost:8080") /// .context_provider(Arc::new(MyProvider)) - /// .build(); + /// .build()?; /// ``` pub fn context_provider(mut self, provider: Arc) -> Self { self.context_provider = Some(provider); self } + /// Set the auth provider directly. Takes precedence over `rest.auth.*` properties. + pub fn auth_provider(mut self, provider: Arc) -> Self { + self.auth_provider = Some(provider); + self + } + /// Enable or disable operation metrics tracking. /// /// When enabled, the namespace will track how many times each API operation @@ -430,13 +500,17 @@ impl RestNamespaceBuilder { self } - /// Build the RestNamespace. - /// - /// # Returns - /// - /// Returns a `RestNamespace` instance. - pub fn build(self) -> RestNamespace { - RestNamespace::from_builder(self) + /// Build the RestNamespace. Auth provider precedence: + /// `auth_provider()` setter > `rest.auth.*` properties > none. + pub fn build(self) -> Result { + let auth = if let Some(p) = self.auth_provider.clone() { + Some(p) + } else if self.auth_properties.contains_key(AUTH_TYPE_KEY) { + Some(create_auth_provider(&self.auth_properties)?) + } else { + None + }; + Ok(RestNamespace::from_builder(self, auth)) } } @@ -461,7 +535,7 @@ fn object_id_str(id: &Option>, delimiter: &str) -> Result { /// # fn example() -> Result<(), Box> { /// // Use the builder to create a namespace /// let namespace = RestNamespaceBuilder::new("http://localhost:8080") -/// .build(); +/// .build()?; /// # Ok(()) /// # } /// ``` @@ -487,8 +561,11 @@ impl std::fmt::Display for RestNamespace { } impl RestNamespace { - /// Create a new REST namespace from builder - pub(crate) fn from_builder(builder: RestNamespaceBuilder) -> Self { + /// Create a new REST namespace from builder + resolved auth provider. + pub(crate) fn from_builder( + builder: RestNamespaceBuilder, + auth_provider: Option>, + ) -> Self { // Build reqwest client WITHOUT default headers - we'll apply headers per-request let mut client_builder = reqwest::Client::builder(); @@ -521,6 +598,7 @@ impl RestNamespace { base_path: builder.uri, base_headers: builder.headers, context_provider: builder.context_provider, + auth_provider, }; let ops_metrics = if builder.ops_metrics_enabled { @@ -536,19 +614,6 @@ impl RestNamespace { } } - /// Map a reqwest::Error to the appropriate NamespaceError variant. - /// - /// Timeout and connection errors are mapped to `ServiceUnavailable`, - /// while other errors are mapped to `Internal`. - fn request_error(e: reqwest::Error) -> lance_core::Error { - let message = format!("Failed to execute request: {:?}", e); - if e.is_timeout() || e.is_connect() { - NamespaceError::ServiceUnavailable { message }.into() - } else { - NamespaceError::Internal { message }.into() - } - } - /// Parse an error response body and return the appropriate NamespaceError. /// /// Deserializes the response as an `ErrorResponse` model (the spec-defined @@ -585,8 +650,7 @@ impl RestNamespace { let resp = self .rest_client .execute(req_builder, operation, object_id) - .await - .map_err(Self::request_error)?; + .await?; let status = resp.status(); let content = resp.text().await.map_err(|e| { @@ -622,8 +686,7 @@ impl RestNamespace { let resp = self .rest_client .execute(req_builder, operation, object_id) - .await - .map_err(Self::request_error)?; + .await?; let status = resp.status(); let content = resp.text().await.map_err(|e| { @@ -659,8 +722,7 @@ impl RestNamespace { let resp = self .rest_client .execute(req_builder, operation, object_id) - .await - .map_err(Self::request_error)?; + .await?; let status = resp.status(); if status.is_success() { @@ -690,8 +752,7 @@ impl RestNamespace { let resp = self .rest_client .execute(req_builder, operation, object_id) - .await - .map_err(Self::request_error)?; + .await?; let status = resp.status(); let content = resp.text().await.map_err(|e| { @@ -1107,8 +1168,7 @@ impl LanceNamespace for RestNamespace { let resp = self .rest_client .execute(req_builder, operation, &id) - .await - .map_err(Self::request_error)?; + .await?; let status = resp.status(); if status.is_success() { @@ -1582,7 +1642,8 @@ mod tests { let _namespace = RestNamespaceBuilder::from_properties(properties) .expect("Failed to create namespace builder") - .build(); + .build() + .unwrap(); // Successfully created the namespace - test passes if no panic } @@ -1599,7 +1660,8 @@ mod tests { let _namespace = RestNamespaceBuilder::from_properties(properties) .expect("Failed to create namespace builder") - .build(); + .build() + .unwrap(); } #[tokio::test] @@ -1638,7 +1700,8 @@ mod tests { let namespace = RestNamespaceBuilder::from_properties(properties) .expect("Failed to create namespace builder") - .build(); + .build() + .unwrap(); let request = ListNamespacesRequest { id: Some(vec!["test".to_string()]), @@ -1657,7 +1720,8 @@ mod tests { properties.insert("uri".to_string(), "http://localhost:8080".to_string()); let _namespace = RestNamespaceBuilder::from_properties(properties) .expect("Failed to create namespace builder") - .build(); + .build() + .unwrap(); // The default delimiter should be "$" - test passes if no panic } @@ -1669,7 +1733,8 @@ mod tests { let _namespace = RestNamespaceBuilder::from_properties(properties) .expect("Failed to create namespace builder") - .build(); + .build() + .unwrap(); // Test passes if no panic } @@ -1736,7 +1801,8 @@ mod tests { // Should not panic even with nonexistent files (they're just ignored) let _namespace = RestNamespaceBuilder::from_properties(properties) .expect("Failed to create namespace builder") - .build(); + .build() + .unwrap(); } #[tokio::test] @@ -1757,7 +1823,9 @@ mod tests { .await; // Create namespace with mock server URL - let namespace = RestNamespaceBuilder::new(mock_server.uri()).build(); + let namespace = RestNamespaceBuilder::new(mock_server.uri()) + .build() + .unwrap(); let request = ListNamespacesRequest { id: Some(vec!["test".to_string()]), @@ -1793,7 +1861,9 @@ mod tests { .await; // Create namespace with mock server URL - let namespace = RestNamespaceBuilder::new(mock_server.uri()).build(); + let namespace = RestNamespaceBuilder::new(mock_server.uri()) + .build() + .unwrap(); let request = ListNamespacesRequest { id: Some(vec!["test".to_string()]), @@ -1826,7 +1896,9 @@ mod tests { .await; // Create namespace with mock server URL - let namespace = RestNamespaceBuilder::new(mock_server.uri()).build(); + let namespace = RestNamespaceBuilder::new(mock_server.uri()) + .build() + .unwrap(); let request = CreateNamespaceRequest { id: Some(vec!["test".to_string(), "newnamespace".to_string()]), @@ -1859,7 +1931,9 @@ mod tests { .await; // Create namespace with mock server URL - let namespace = RestNamespaceBuilder::new(mock_server.uri()).build(); + let namespace = RestNamespaceBuilder::new(mock_server.uri()) + .build() + .unwrap(); let request = CreateTableRequest { id: Some(vec![ @@ -1898,7 +1972,9 @@ mod tests { .mount(&mock_server) .await; - let namespace = RestNamespaceBuilder::new(mock_server.uri()).build(); + let namespace = RestNamespaceBuilder::new(mock_server.uri()) + .build() + .unwrap(); let request = CreateTableRequest { id: Some(vec![ @@ -1963,7 +2039,9 @@ mod tests { .await; // Create namespace with mock server URL - let namespace = RestNamespaceBuilder::new(mock_server.uri()).build(); + let namespace = RestNamespaceBuilder::new(mock_server.uri()) + .build() + .unwrap(); let request = InsertIntoTableRequest { id: Some(vec![ @@ -2026,7 +2104,8 @@ mod tests { let namespace = RestNamespaceBuilder::new(mock_server.uri()) .context_provider(provider) - .build(); + .build() + .unwrap(); let request = ListNamespacesRequest { id: Some(vec!["test".to_string()]), @@ -2072,7 +2151,8 @@ mod tests { let namespace = RestNamespaceBuilder::new(mock_server.uri()) .header("Authorization", "Bearer base-token") .context_provider(provider) - .build(); + .build() + .unwrap(); let request = ListNamespacesRequest { id: Some(vec!["test".to_string()]), @@ -2114,7 +2194,8 @@ mod tests { let namespace = RestNamespaceBuilder::new(mock_server.uri()) .header("Authorization", "Bearer base-token") .context_provider(provider) - .build(); + .build() + .unwrap(); let request = ListNamespacesRequest { id: Some(vec!["test".to_string()]), @@ -2145,7 +2226,8 @@ mod tests { // Create namespace WITHOUT context provider, only base headers let namespace = RestNamespaceBuilder::new(mock_server.uri()) .header("Authorization", "Bearer base-only") - .build(); + .build() + .unwrap(); let request = ListNamespacesRequest { id: Some(vec!["test".to_string()]), @@ -2155,4 +2237,154 @@ mod tests { let result = namespace.list_namespaces(request).await; assert!(result.is_ok(), "Failed: {:?}", result.err()); } + + // Compatibility tests: the RestAuthProvider framework must not change + // existing behaviour for users on the static `header.*` path. + + #[tokio::test] + async fn rest_auth_type_none_outbound_headers_identical_to_no_config() { + let mock_server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/v1/namespace/ns/list")) + .respond_with( + ResponseTemplate::new(200).set_body_json(serde_json::json!({ "namespaces": [] })), + ) + .mount(&mock_server) + .await; + + let ns_no_auth = RestNamespaceBuilder::new(mock_server.uri()) + .build() + .unwrap(); + let req = ListNamespacesRequest { + id: Some(vec!["ns".to_string()]), + ..Default::default() + }; + ns_no_auth.list_namespaces(req.clone()).await.unwrap(); + + let mut props = HashMap::new(); + props.insert("uri".to_string(), mock_server.uri()); + props.insert("rest.auth.type".to_string(), "none".to_string()); + let ns_none = RestNamespaceBuilder::from_properties(props) + .unwrap() + .build() + .unwrap(); + ns_none.list_namespaces(req).await.unwrap(); + + let requests = mock_server.received_requests().await.unwrap(); + assert_eq!(requests.len(), 2); + let h0: HashMap<_, _> = requests[0] + .headers + .iter() + .map(|(k, v)| (k.as_str().to_lowercase(), v.to_str().unwrap().to_string())) + .collect(); + let h1: HashMap<_, _> = requests[1] + .headers + .iter() + .map(|(k, v)| (k.as_str().to_lowercase(), v.to_str().unwrap().to_string())) + .collect(); + assert_eq!( + h0, h1, + "rest.auth.type=none should produce identical headers to no config" + ); + } + + #[tokio::test] + async fn legacy_header_authorization_unchanged_with_auth_framework() { + let mock_server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/v1/namespace/ns/list")) + .and(wiremock::matchers::header( + "Authorization", + "Bearer legacy-static-token", + )) + .respond_with( + ResponseTemplate::new(200).set_body_json(serde_json::json!({ "namespaces": [] })), + ) + .mount(&mock_server) + .await; + + let ns = RestNamespaceBuilder::new(mock_server.uri()) + .header("Authorization", "Bearer legacy-static-token") + .build() + .unwrap(); + let req = ListNamespacesRequest { + id: Some(vec!["ns".to_string()]), + ..Default::default() + }; + ns.list_namespaces(req).await.unwrap(); + } + + #[test] + fn unknown_rest_auth_type_returns_error_at_build_time() { + let mut props = HashMap::new(); + props.insert("uri".to_string(), "http://127.0.0.1:1".to_string()); + props.insert( + "rest.auth.type".to_string(), + "definitely-not-a-real-scheme".to_string(), + ); + let result = RestNamespaceBuilder::from_properties(props) + .unwrap() + .build(); + assert!( + result.is_err(), + "expected build() to fail for unknown auth type" + ); + let err_str = result.err().unwrap().to_string(); + assert!( + err_str.contains("definitely-not-a-real-scheme"), + "error should mention the offending type, got: {err_str}" + ); + assert!( + err_str.contains("none"), + "error should list supported types, got: {err_str}" + ); + } + + /// Auth provider errors must surface as hard failures, not be swallowed. + #[tokio::test] + async fn auth_provider_failure_surfaces_as_error() { + #[derive(Debug)] + struct AlwaysFailAuth; + #[async_trait::async_trait] + impl crate::rest_auth::RestAuthProvider for AlwaysFailAuth { + async fn authenticate( + &self, + _ctx: &crate::rest_auth::RequestContext, + ) -> Result> { + Err(NamespaceError::Unauthenticated { + message: "synthetic-token-expired".to_string(), + } + .into()) + } + } + + let mock_server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/v1/namespace/test/list")) + .respond_with( + ResponseTemplate::new(200).set_body_json(serde_json::json!({"namespaces": []})), + ) + .mount(&mock_server) + .await; + + let ns = RestNamespaceBuilder::new(mock_server.uri()) + .auth_provider(std::sync::Arc::new(AlwaysFailAuth)) + .build() + .unwrap(); + let req = ListNamespacesRequest { + id: Some(vec!["test".to_string()]), + ..Default::default() + }; + let result = ns.list_namespaces(req).await; + assert!(result.is_err(), "auth failure should bubble up"); + let err_msg = result.err().unwrap().to_string(); + assert!( + err_msg.contains("synthetic-token-expired"), + "underlying error must be preserved: {err_msg}" + ); + assert!( + err_msg.contains("list_namespaces"), + "error should include operation context: {err_msg}" + ); + } } diff --git a/rust/lance-namespace-impls/src/rest_adapter.rs b/rust/lance-namespace-impls/src/rest_adapter.rs index 6a3875ebf29..85d98a14574 100644 --- a/rust/lance-namespace-impls/src/rest_adapter.rs +++ b/rust/lance-namespace-impls/src/rest_adapter.rs @@ -1483,7 +1483,7 @@ mod tests { let server_url = format!("http://127.0.0.1:{}", actual_port); let namespace = RestNamespaceBuilder::new(&server_url) .delimiter("$") - .build(); + .build().unwrap(); Self { _temp_dir: temp_dir, @@ -3047,7 +3047,7 @@ mod tests { .delimiter("$") .header("X-Base-Header", "base-value") .context_provider(provider) - .build(); + .build().unwrap(); // Create a namespace - should work with context provider let create_req = CreateNamespaceRequest { diff --git a/rust/lance-namespace-impls/src/rest_auth.rs b/rust/lance-namespace-impls/src/rest_auth.rs new file mode 100644 index 00000000000..0e3f2a9df4f --- /dev/null +++ b/rust/lance-namespace-impls/src/rest_auth.rs @@ -0,0 +1,97 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! Authentication provider abstraction for REST Namespace HTTP requests. + +use std::collections::HashMap; +use std::sync::Arc; + +use async_trait::async_trait; +use lance_core::Result; +use lance_namespace::error::NamespaceError; + +pub const AUTH_TYPE_KEY: &str = "rest.auth.type"; +pub const AUTH_PROPERTY_PREFIX: &str = "rest.auth."; +pub const AUTH_TYPE_NONE: &str = "none"; + +/// Request snapshot handed to [`RestAuthProvider::authenticate`]. +#[derive(Debug, Clone)] +pub struct RequestContext { + pub method: String, + pub url: String, + pub headers: HashMap, +} + +/// Per-request authentication provider. Implementations own their own +/// credential lifecycle (initial fetch, refresh, caching, expiry). +#[async_trait] +pub trait RestAuthProvider: Send + Sync + std::fmt::Debug { + async fn authenticate(&self, ctx: &RequestContext) -> Result>; +} + +#[derive(Debug, Default)] +pub struct NoopAuthProvider; + +#[async_trait] +impl RestAuthProvider for NoopAuthProvider { + async fn authenticate(&self, _ctx: &RequestContext) -> Result> { + Ok(HashMap::new()) + } +} + +/// Dispatch on [`AUTH_TYPE_KEY`] to build a [`RestAuthProvider`]. Currently +/// only `"none"` (or missing key) is accepted; concrete providers extend this +/// behind feature flags. +pub fn create_auth_provider( + properties: &HashMap, +) -> Result> { + let auth_type = properties + .get(AUTH_TYPE_KEY) + .map(|s| s.as_str()) + .unwrap_or(AUTH_TYPE_NONE); + match auth_type { + AUTH_TYPE_NONE => Ok(Arc::new(NoopAuthProvider)), + other => Err(NamespaceError::InvalidInput { + message: format!( + "unsupported {AUTH_TYPE_KEY} '{other}' (supported: {AUTH_TYPE_NONE})" + ), + } + .into()), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[tokio::test] + async fn noop_returns_empty_headers() { + let ctx = RequestContext { + method: "GET".to_string(), + url: "http://example.com/v1/test".to_string(), + headers: HashMap::new(), + }; + assert!(NoopAuthProvider.authenticate(&ctx).await.unwrap().is_empty()); + } + + #[test] + fn factory_accepts_missing_auth_type() { + assert!(create_auth_provider(&HashMap::new()).is_ok()); + } + + #[test] + fn factory_accepts_explicit_none() { + let mut props = HashMap::new(); + props.insert(AUTH_TYPE_KEY.to_string(), AUTH_TYPE_NONE.to_string()); + assert!(create_auth_provider(&props).is_ok()); + } + + #[test] + fn factory_rejects_unknown_with_helpful_error() { + let mut props = HashMap::new(); + props.insert(AUTH_TYPE_KEY.to_string(), "sigv4-typo".to_string()); + let msg = create_auth_provider(&props).unwrap_err().to_string(); + assert!(msg.contains("sigv4-typo")); + assert!(msg.contains(AUTH_TYPE_NONE)); + } +} From 60cb5f3204e1c6a78b684d53842e6047d3a0f68f Mon Sep 17 00:00:00 2001 From: shiwk Date: Wed, 3 Jun 2026 14:30:36 +0800 Subject: [PATCH 2/5] feat(rest-auth): add AWS SigV4 authentication provider --- Cargo.lock | 3 + rust/lance-namespace-impls/Cargo.toml | 6 +- rust/lance-namespace-impls/src/connect.rs | 4 +- rust/lance-namespace-impls/src/lib.rs | 2 + rust/lance-namespace-impls/src/rest.rs | 138 +++++- rust/lance-namespace-impls/src/rest_auth.rs | 63 ++- .../src/rest_auth/sigv4.rs | 449 ++++++++++++++++++ 7 files changed, 641 insertions(+), 24 deletions(-) create mode 100644 rust/lance-namespace-impls/src/rest_auth/sigv4.rs diff --git a/Cargo.lock b/Cargo.lock index d9d7588827e..f02849aecda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5039,12 +5039,15 @@ dependencies = [ "arrow-schema", "async-trait", "aws-config", + "aws-credential-types", "aws-sdk-sts", + "aws-sigv4", "axum", "base64 0.22.1", "bytes", "chrono", "futures", + "hex", "hmac 0.12.1", "lance", "lance-arrow", diff --git a/rust/lance-namespace-impls/Cargo.toml b/rust/lance-namespace-impls/Cargo.toml index 53ff79fb333..2858476d157 100644 --- a/rust/lance-namespace-impls/Cargo.toml +++ b/rust/lance-namespace-impls/Cargo.toml @@ -13,8 +13,9 @@ rust-version.workspace = true [features] default = ["dir-aws", "dir-azure", "dir-gcp", "dir-oss", "dir-huggingface"] -rest = ["dep:reqwest", "dep:serde"] +rest = ["dep:reqwest", "dep:serde", "dep:sha2", "dep:hex"] rest-adapter = ["dep:axum", "dep:tower", "dep:tower-http", "dep:serde"] +rest-auth-sigv4 = ["rest", "dep:aws-sigv4", "dep:aws-credential-types", "dep:aws-config"] # Cloud storage features for directory implementation - align with lance-io dir-gcp = ["lance-io/gcp", "lance/gcp"] dir-aws = ["lance-io/aws", "lance/aws"] @@ -69,11 +70,14 @@ rand.workspace = true # Shared credential vending dependencies sha2 = { version = "0.10", optional = true } +hex = { version = "0.4", optional = true } base64 = { version = "0.22", optional = true } # AWS credential vending dependencies (optional, enabled by "credential-vendor-aws" feature) aws-sdk-sts = { version = "1.38.0", optional = true, default-features = false, features = ["default-https-client", "rt-tokio"] } aws-config = { workspace = true, optional = true } +aws-sigv4 = { version = "1", optional = true } +aws-credential-types = { version = "1", optional = true } # GCP credential vending dependencies (optional, enabled by "credential-vendor-gcp" feature) ring = { version = "0.17", optional = true } diff --git a/rust/lance-namespace-impls/src/connect.rs b/rust/lance-namespace-impls/src/connect.rs index 63378367a96..d7d52c83da3 100644 --- a/rust/lance-namespace-impls/src/connect.rs +++ b/rust/lance-namespace-impls/src/connect.rs @@ -184,7 +184,9 @@ impl ConnectBuilder { if let Some(provider) = self.context_provider { builder = builder.context_provider(provider); } - Ok(Arc::new(builder.build()?) as Arc) + let ns = builder.build()?; + ns.warm_up_auth().await?; + Ok(Arc::new(ns) as Arc) } #[cfg(not(feature = "rest"))] "rest" => Err(NamespaceError::Unsupported { diff --git a/rust/lance-namespace-impls/src/lib.rs b/rust/lance-namespace-impls/src/lib.rs index a17002c68f7..436aad1e4bb 100644 --- a/rust/lance-namespace-impls/src/lib.rs +++ b/rust/lance-namespace-impls/src/lib.rs @@ -75,6 +75,7 @@ pub mod connect; pub mod context; pub mod credentials; pub mod dir; +#[cfg(feature = "rest")] pub mod rest_auth; #[cfg(feature = "rest")] @@ -90,6 +91,7 @@ pub use dir::{ DirectoryNamespace, DirectoryNamespaceBuilder, OpsMetrics, manifest::ManifestNamespace, }; +#[cfg(feature = "rest")] pub use rest_auth::{NoopAuthProvider, RequestContext, RestAuthProvider}; // Re-export credential vending diff --git a/rust/lance-namespace-impls/src/rest.rs b/rust/lance-namespace-impls/src/rest.rs index 2771c527e2c..b4086f5e63d 100644 --- a/rust/lance-namespace-impls/src/rest.rs +++ b/rust/lance-namespace-impls/src/rest.rs @@ -12,6 +12,7 @@ use crate::OpsMetrics; use async_trait::async_trait; use bytes::Bytes; use reqwest::header::{HeaderName, HeaderValue}; +use sha2::{Digest, Sha256}; use crate::rest_auth::{ AUTH_PROPERTY_PREFIX, AUTH_TYPE_KEY, RequestContext, RestAuthProvider, create_auth_provider, @@ -106,31 +107,57 @@ where V: AsRef, { for (k, v) in pairs { - if let (Ok(name), Ok(val)) = ( + match ( HeaderName::from_str(k.as_ref()), HeaderValue::from_str(v.as_ref()), ) { - headers.insert(name, val); + (Ok(name), Ok(val)) => { + headers.insert(name, val); + } + _ => { + log::warn!("dropping invalid header: {:?}: {:?}", k.as_ref(), v.as_ref()); + } } } } +pub(crate) const EMPTY_BODY_SHA256: &str = + "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"; + +/// `None` for streaming bodies (currently unreachable). +fn body_sha256_hex(request: &reqwest::Request) -> Option { + match request.body() { + None => Some(EMPTY_BODY_SHA256.to_string()), + Some(body) => match body.as_bytes() { + None => None, + Some([]) => Some(EMPTY_BODY_SHA256.to_string()), + Some(bytes) => Some(hex::encode(Sha256::digest(bytes))), + }, + } +} + impl RestClient { fn build_auth_context(request: &reqwest::Request) -> RequestContext { + // Lossy decode keeps non-ASCII headers in the signer's view. let headers = request .headers() .iter() - .filter_map(|(k, v)| v.to_str().ok().map(|s| (k.to_string(), s.to_string()))) + .map(|(k, v)| { + ( + k.as_str().to_string(), + String::from_utf8_lossy(v.as_bytes()).into_owned(), + ) + }) .collect(); RequestContext { method: request.method().to_string(), url: request.url().to_string(), headers, + body_sha256: body_sha256_hex(request), } } - /// Apply base, auth, then context headers. Context wins on conflict so it - /// can intentionally override auth headers per request. + /// Apply headers: base → auth (signed) → context (unsigned). async fn apply_headers( &self, request: &mut reqwest::Request, @@ -614,6 +641,13 @@ impl RestNamespace { } } + pub async fn warm_up_auth(&self) -> Result<()> { + if let Some(auth) = &self.rest_client.auth_provider { + auth.initialize().await?; + } + Ok(()) + } + /// Parse an error response body and return the appropriate NamespaceError. /// /// Deserializes the response as an `ErrorResponse` model (the spec-defined @@ -1629,6 +1663,11 @@ mod tests { use wiremock::matchers::{method, path}; use wiremock::{Mock, MockServer, ResponseTemplate}; + #[test] + fn empty_body_sha256_const_matches_computed() { + assert_eq!(EMPTY_BODY_SHA256, hex::encode(Sha256::digest(b""))); + } + #[test] fn test_rest_namespace_creation() { let mut properties = HashMap::new(); @@ -2238,9 +2277,6 @@ mod tests { assert!(result.is_ok(), "Failed: {:?}", result.err()); } - // Compatibility tests: the RestAuthProvider framework must not change - // existing behaviour for users on the static `header.*` path. - #[tokio::test] async fn rest_auth_type_none_outbound_headers_identical_to_no_config() { let mock_server = MockServer::start().await; @@ -2340,7 +2376,6 @@ mod tests { ); } - /// Auth provider errors must surface as hard failures, not be swallowed. #[tokio::test] async fn auth_provider_failure_surfaces_as_error() { #[derive(Debug)] @@ -2387,4 +2422,89 @@ mod tests { "error should include operation context: {err_msg}" ); } + + #[tokio::test] + async fn warm_up_auth_surfaces_initialize_error() { + #[derive(Debug)] + struct FailOnInit; + #[async_trait::async_trait] + impl crate::rest_auth::RestAuthProvider for FailOnInit { + async fn authenticate( + &self, + _ctx: &crate::rest_auth::RequestContext, + ) -> Result> { + Ok(HashMap::new()) + } + async fn initialize(&self) -> Result<()> { + Err(NamespaceError::Unauthenticated { + message: "synthetic-credential-chain-failure".to_string(), + } + .into()) + } + } + + let ns = RestNamespaceBuilder::new("http://127.0.0.1:1") + .auth_provider(std::sync::Arc::new(FailOnInit)) + .build() + .unwrap(); + let err = ns.warm_up_auth().await.unwrap_err(); + assert!( + err.to_string() + .contains("synthetic-credential-chain-failure"), + "warm_up_auth must propagate initialize() error: {err}" + ); + } + + #[tokio::test] + async fn auth_provider_setter_takes_precedence_over_properties() { + #[derive(Debug)] + struct MarkerAuth; + #[async_trait::async_trait] + impl crate::rest_auth::RestAuthProvider for MarkerAuth { + async fn authenticate( + &self, + _ctx: &crate::rest_auth::RequestContext, + ) -> Result> { + let mut h = HashMap::new(); + h.insert("x-marker".to_string(), "from-setter".to_string()); + Ok(h) + } + } + + let mock_server = MockServer::start().await; + Mock::given(method("GET")) + .and(path("/v1/namespace/test/list")) + .respond_with( + ResponseTemplate::new(200).set_body_json(serde_json::json!({"namespaces": []})), + ) + .mount(&mock_server) + .await; + + let mut props = HashMap::new(); + props.insert("uri".to_string(), mock_server.uri()); + props.insert("rest.auth.type".to_string(), "none".to_string()); + let ns = RestNamespaceBuilder::from_properties(props) + .unwrap() + .auth_provider(std::sync::Arc::new(MarkerAuth)) + .build() + .unwrap(); + + let req = ListNamespacesRequest { + id: Some(vec!["test".to_string()]), + ..Default::default() + }; + let _ = ns.list_namespaces(req).await.unwrap(); + + let matched = mock_server.received_requests().await.unwrap(); + assert!(!matched.is_empty()); + let marker = matched[0] + .headers + .get("x-marker") + .map(|v| v.to_str().unwrap()); + assert_eq!( + marker, + Some("from-setter"), + "setter auth_provider must override rest.auth.type property" + ); + } } diff --git a/rust/lance-namespace-impls/src/rest_auth.rs b/rust/lance-namespace-impls/src/rest_auth.rs index 0e3f2a9df4f..1decc6027c0 100644 --- a/rust/lance-namespace-impls/src/rest_auth.rs +++ b/rust/lance-namespace-impls/src/rest_auth.rs @@ -1,7 +1,7 @@ // SPDX-License-Identifier: Apache-2.0 // SPDX-FileCopyrightText: Copyright The Lance Authors -//! Authentication provider abstraction for REST Namespace HTTP requests. +//! Authentication providers for REST Namespace HTTP requests. use std::collections::HashMap; use std::sync::Arc; @@ -10,23 +10,32 @@ use async_trait::async_trait; use lance_core::Result; use lance_namespace::error::NamespaceError; +#[cfg(feature = "rest-auth-sigv4")] +pub mod sigv4; + pub const AUTH_TYPE_KEY: &str = "rest.auth.type"; pub const AUTH_PROPERTY_PREFIX: &str = "rest.auth."; pub const AUTH_TYPE_NONE: &str = "none"; +#[cfg(feature = "rest-auth-sigv4")] +pub const AUTH_TYPE_SIGV4: &str = "sigv4"; -/// Request snapshot handed to [`RestAuthProvider::authenticate`]. #[derive(Debug, Clone)] pub struct RequestContext { pub method: String, pub url: String, pub headers: HashMap, + /// `None` for streaming bodies. + pub body_sha256: Option, } -/// Per-request authentication provider. Implementations own their own -/// credential lifecycle (initial fetch, refresh, caching, expiry). #[async_trait] pub trait RestAuthProvider: Send + Sync + std::fmt::Debug { async fn authenticate(&self, ctx: &RequestContext) -> Result>; + + /// Connect-time init; default no-op. + async fn initialize(&self) -> Result<()> { + Ok(()) + } } #[derive(Debug, Default)] @@ -39,9 +48,6 @@ impl RestAuthProvider for NoopAuthProvider { } } -/// Dispatch on [`AUTH_TYPE_KEY`] to build a [`RestAuthProvider`]. Currently -/// only `"none"` (or missing key) is accepted; concrete providers extend this -/// behind feature flags. pub fn create_auth_provider( properties: &HashMap, ) -> Result> { @@ -51,27 +57,58 @@ pub fn create_auth_provider( .unwrap_or(AUTH_TYPE_NONE); match auth_type { AUTH_TYPE_NONE => Ok(Arc::new(NoopAuthProvider)), + #[cfg(feature = "rest-auth-sigv4")] + AUTH_TYPE_SIGV4 => Ok(Arc::new(sigv4::SigV4AuthProvider::from_properties( + properties, + )?)), other => Err(NamespaceError::InvalidInput { message: format!( - "unsupported {AUTH_TYPE_KEY} '{other}' (supported: {AUTH_TYPE_NONE})" + "unsupported {AUTH_TYPE_KEY} '{other}' (supported: {})", + supported_auth_types() ), } .into()), } } +fn supported_auth_types() -> &'static str { + #[cfg(feature = "rest-auth-sigv4")] + { + "none, sigv4" + } + #[cfg(not(feature = "rest-auth-sigv4"))] + { + "none" + } +} + #[cfg(test)] mod tests { use super::*; - #[tokio::test] - async fn noop_returns_empty_headers() { - let ctx = RequestContext { + fn empty_ctx() -> RequestContext { + RequestContext { method: "GET".to_string(), url: "http://example.com/v1/test".to_string(), headers: HashMap::new(), - }; - assert!(NoopAuthProvider.authenticate(&ctx).await.unwrap().is_empty()); + body_sha256: None, + } + } + + #[tokio::test] + async fn noop_returns_empty_headers() { + assert!( + NoopAuthProvider + .authenticate(&empty_ctx()) + .await + .unwrap() + .is_empty() + ); + } + + #[tokio::test] + async fn noop_initialize_is_ok() { + NoopAuthProvider.initialize().await.unwrap(); } #[test] diff --git a/rust/lance-namespace-impls/src/rest_auth/sigv4.rs b/rust/lance-namespace-impls/src/rest_auth/sigv4.rs new file mode 100644 index 00000000000..9452c2692d7 --- /dev/null +++ b/rust/lance-namespace-impls/src/rest_auth/sigv4.rs @@ -0,0 +1,449 @@ +// SPDX-License-Identifier: Apache-2.0 +// SPDX-FileCopyrightText: Copyright The Lance Authors + +//! AWS SigV4 authentication provider for REST Namespace. + +use std::collections::HashMap; +use std::sync::Arc; +use std::time::SystemTime; + +use async_trait::async_trait; +use aws_credential_types::Credentials; +use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider}; +use aws_sigv4::http_request::{ + SignableBody, SignableRequest, SigningParams, SigningSettings, sign, +}; +use aws_sigv4::sign::v4; +use lance_core::Result; +use lance_namespace::error::NamespaceError; +use tokio::sync::OnceCell; +use url::Url; + +pub const REGION_KEY: &str = "rest.auth.sigv4.region"; +pub const SERVICE_KEY: &str = "rest.auth.sigv4.service"; +const DEFAULT_SERVICE: &str = "execute-api"; + +/// Injectable time source; tests use a fixed clock. +pub trait Clock: Send + Sync + std::fmt::Debug { + fn now(&self) -> SystemTime; +} + +#[derive(Debug, Default)] +pub struct SystemClock; + +impl Clock for SystemClock { + fn now(&self) -> SystemTime { + SystemTime::now() + } +} + +pub struct SigV4AuthProvider { + region: String, + service: String, + credentials_provider: OnceCell, + clock: Arc, +} + +impl std::fmt::Debug for SigV4AuthProvider { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.debug_struct("SigV4AuthProvider") + .field("region", &self.region) + .field("service", &self.service) + .field( + "credentials_provider", + &self.credentials_provider.get().map(|_| "resolved"), + ) + .finish() + } +} + +impl SigV4AuthProvider { + pub fn from_properties(properties: &HashMap) -> Result { + let region = + properties + .get(REGION_KEY) + .cloned() + .ok_or_else(|| NamespaceError::InvalidInput { + message: format!("{REGION_KEY} is required for SigV4 authentication"), + })?; + let service = properties + .get(SERVICE_KEY) + .cloned() + .unwrap_or_else(|| DEFAULT_SERVICE.to_string()); + Ok(Self { + region, + service, + credentials_provider: OnceCell::new(), + clock: Arc::new(SystemClock), + }) + } + + pub fn with_clock(mut self, clock: Arc) -> Self { + self.clock = clock; + self + } + + pub fn with_credentials_provider(self, provider: SharedCredentialsProvider) -> Self { + let cell = OnceCell::new(); + cell.set(provider) + .expect("freshly constructed OnceCell never returns Err"); + Self { + credentials_provider: cell, + ..self + } + } + + async fn ensure_credentials_provider(&self) -> Result<&SharedCredentialsProvider> { + self.credentials_provider + .get_or_try_init(|| async { + // aws_config::load panics inside an existing tokio runtime. + let region = self.region.clone(); + let provider = tokio::task::spawn_blocking(move || { + let rt = tokio::runtime::Handle::current(); + rt.block_on(async { + aws_config::defaults(aws_config::BehaviorVersion::latest()) + .region(aws_config::Region::new(region)) + .load() + .await + }) + }) + .await + .map_err(|e| { + lance_core::Error::from(NamespaceError::Internal { + message: format!("failed to load AWS config: {e}"), + }) + })?; + provider.credentials_provider().ok_or_else(|| { + lance_core::Error::from(NamespaceError::Internal { + message: "AWS config did not yield a credentials provider".to_string(), + }) + }) + }) + .await + } + + async fn resolve_credentials(&self) -> Result { + let provider = self.ensure_credentials_provider().await?; + provider.provide_credentials().await.map_err(|e| { + NamespaceError::Unauthenticated { + message: format!("failed to resolve AWS credentials: {e}"), + } + .into() + }) + } +} + +#[async_trait] +impl super::RestAuthProvider for SigV4AuthProvider { + async fn authenticate(&self, ctx: &super::RequestContext) -> Result> { + let creds = self.resolve_credentials().await?; + let identity = creds.into(); + + let mut signing_settings = SigningSettings::default(); + signing_settings.payload_checksum_kind = + aws_sigv4::http_request::PayloadChecksumKind::XAmzSha256; + let v4_params = v4::SigningParams::builder() + .identity(&identity) + .region(&self.region) + .name(&self.service) + .time(self.clock.now()) + .settings(signing_settings) + .build() + .map_err(|e| NamespaceError::Internal { + message: format!("failed to build SigV4 signing params: {e}"), + })?; + let params: SigningParams = v4_params.into(); + + let parsed_url = Url::parse(&ctx.url).map_err(|_| NamespaceError::InvalidInput { + message: format!("SigV4 requires a valid URL: {}", ctx.url), + })?; + if parsed_url.host_str().is_none() { + return Err(NamespaceError::InvalidInput { + message: format!("SigV4 requires a URL with a host: {}", ctx.url), + } + .into()); + } + let host = parsed_url[url::Position::BeforeHost..url::Position::AfterPort].to_string(); + + let other_headers = ctx + .headers + .iter() + .filter(|(k, _)| !k.eq_ignore_ascii_case("host")); + let header_iter = std::iter::once(("host", host.as_str())) + .chain(other_headers.map(|(k, v)| (k.as_str(), v.as_str()))); + + let body = match ctx.body_sha256.as_deref() { + Some(hash) => SignableBody::Precomputed(hash.to_string()), + None => SignableBody::UnsignedPayload, + }; + + let signable = + SignableRequest::new(&ctx.method, &ctx.url, header_iter, body).map_err(|e| { + NamespaceError::Internal { + message: format!("failed to construct SigV4 signable request: {e}"), + } + })?; + + let (instructions, _signature) = sign(signable, ¶ms) + .map_err(|e| NamespaceError::Internal { + message: format!("SigV4 signing failed: {e}"), + })? + .into_parts(); + + Ok(instructions + .headers() + .map(|(name, value)| (name.to_string(), value.to_string())) + .collect()) + } + + async fn initialize(&self) -> Result<()> { + self.resolve_credentials().await.map(|_| ()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::rest_auth::{RequestContext, RestAuthProvider}; + use std::time::{Duration, UNIX_EPOCH}; + + // AWS SigV4 test vector credentials (botocore cross-verified). + const VECTOR_ACCESS_KEY: &str = "AKIDEXAMPLE"; + const VECTOR_SECRET_KEY: &str = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"; + const VECTOR_REGION: &str = "us-east-1"; + const VECTOR_SERVICE: &str = "service"; + const VECTOR_UNIX_SECS: u64 = 1_440_938_160; // 2015-08-30T12:36:00Z + const VECTOR_EXPECTED_AUTHORIZATION: &str = "AWS4-HMAC-SHA256 \ + Credential=AKIDEXAMPLE/20150830/us-east-1/service/aws4_request, \ + SignedHeaders=host;x-amz-content-sha256;x-amz-date, \ + Signature=726c5c4879a6b4ccbbd3b24edbd6b8826d34f87450fbbf4e85546fc7ba9c1642"; + + #[derive(Debug)] + struct FixedClock(SystemTime); + + impl Clock for FixedClock { + fn now(&self) -> SystemTime { + self.0 + } + } + + fn vector_provider() -> SigV4AuthProvider { + let creds = Credentials::new( + VECTOR_ACCESS_KEY, + VECTOR_SECRET_KEY, + None, + None, + "lance-sigv4-test", + ); + let mut props = HashMap::new(); + props.insert(REGION_KEY.to_string(), VECTOR_REGION.to_string()); + props.insert(SERVICE_KEY.to_string(), VECTOR_SERVICE.to_string()); + SigV4AuthProvider::from_properties(&props) + .unwrap() + .with_clock(Arc::new(FixedClock( + UNIX_EPOCH + Duration::from_secs(VECTOR_UNIX_SECS), + ))) + .with_credentials_provider(SharedCredentialsProvider::new(creds)) + } + + #[test] + fn from_properties_requires_region() { + let err = SigV4AuthProvider::from_properties(&HashMap::new()).unwrap_err(); + assert!(err.to_string().contains(REGION_KEY)); + } + + #[test] + fn from_properties_defaults_service_to_execute_api() { + let mut props = HashMap::new(); + props.insert(REGION_KEY.to_string(), "us-west-2".to_string()); + let provider = SigV4AuthProvider::from_properties(&props).unwrap(); + assert_eq!(provider.service, DEFAULT_SERVICE); + assert_eq!(provider.region, "us-west-2"); + } + + #[test] + fn from_properties_accepts_explicit_service() { + let mut props = HashMap::new(); + props.insert(REGION_KEY.to_string(), "us-east-1".to_string()); + props.insert(SERVICE_KEY.to_string(), "s3".to_string()); + let provider = SigV4AuthProvider::from_properties(&props).unwrap(); + assert_eq!(provider.service, "s3"); + } + + #[tokio::test] + async fn reproduces_aws_get_vanilla_reference_vector() { + let provider = vector_provider(); + let ctx = RequestContext { + method: "GET".to_string(), + url: "https://example.amazonaws.com/".to_string(), + headers: HashMap::new(), + body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()), + }; + let headers = provider.authenticate(&ctx).await.unwrap(); + let actual = headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("authorization")) + .map(|(_, v)| v.as_str()) + .expect("authorization header must be produced"); + assert_eq!(actual, VECTOR_EXPECTED_AUTHORIZATION); + } + + #[tokio::test] + async fn initialize_resolves_injected_credentials() { + vector_provider().initialize().await.unwrap(); + } + + #[tokio::test] + async fn authenticate_rejects_url_without_host() { + let provider = vector_provider(); + let ctx = RequestContext { + method: "GET".to_string(), + url: "file:///nowhere".to_string(), + headers: HashMap::new(), + body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()), + }; + let err = provider.authenticate(&ctx).await.unwrap_err(); + assert!(err.to_string().contains("host")); + } + + #[tokio::test] + async fn authenticate_overrides_preexisting_host_header() { + let provider = vector_provider(); + let mut headers = HashMap::new(); + headers.insert("Host".to_string(), "wrong.example.com".to_string()); + let ctx = RequestContext { + method: "GET".to_string(), + url: "https://example.amazonaws.com/".to_string(), + headers, + body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()), + }; + let result = provider.authenticate(&ctx).await.unwrap(); + let actual = result + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("authorization")) + .map(|(_, v)| v.as_str()) + .expect("authorization header must be produced"); + assert_eq!( + actual, VECTOR_EXPECTED_AUTHORIZATION, + "pre-existing Host header must be replaced by the URL-derived host" + ); + } + + /// AWS test vector: percent-encoded path (%3D → double-encoded %253D). + #[tokio::test] + async fn reproduces_aws_double_encode_path_vector() { + let creds = Credentials::new( + "ANOTREAL", + "notrealrnrELgWzOk3IfjzDKtFBhDby", + None, + None, + "lance-sigv4-test", + ); + let mut props = HashMap::new(); + props.insert(REGION_KEY.to_string(), "us-east-1".to_string()); + props.insert(SERVICE_KEY.to_string(), "service".to_string()); + let provider = SigV4AuthProvider::from_properties(&props) + .unwrap() + .with_clock(Arc::new(FixedClock( + UNIX_EPOCH + Duration::from_secs(VECTOR_UNIX_SECS), + ))) + .with_credentials_provider(SharedCredentialsProvider::new(creds)); + + let ctx = RequestContext { + method: "POST".to_string(), + url: "https://tj9n5r0m12.execute-api.us-east-1.amazonaws.com/test/@connections/JBDvjfGEIAMCERw%3D".to_string(), + headers: HashMap::new(), + body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()), + }; + let headers = provider.authenticate(&ctx).await.unwrap(); + let auth = headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("authorization")) + .map(|(_, v)| v.as_str()) + .expect("authorization header must be produced"); + assert_eq!( + auth, + "AWS4-HMAC-SHA256 Credential=ANOTREAL/20150830/us-east-1/service/aws4_request, \ + SignedHeaders=host;x-amz-content-sha256;x-amz-date, \ + Signature=ed434df8a348089a1188defcfcc1aa24049990a7e82021d0418cfa0eb05e4d99", + "double-encode-path: signature must match botocore cross-verification" + ); + } + + #[tokio::test] + async fn authenticate_with_unsigned_payload_still_produces_signature() { + let provider = vector_provider(); + let ctx = RequestContext { + method: "GET".to_string(), + url: "https://example.amazonaws.com/".to_string(), + headers: HashMap::new(), + body_sha256: None, + }; + let headers = provider.authenticate(&ctx).await.unwrap(); + let auth = headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("authorization")) + .map(|(_, v)| v.as_str()) + .expect("authorization header must be produced"); + assert!(auth.starts_with("AWS4-HMAC-SHA256 ")); + assert!(auth.contains("Credential=")); + assert!(auth.contains("SignedHeaders=")); + assert!(auth.contains("Signature=")); + } + + #[tokio::test] + async fn authenticate_with_session_token_produces_correct_signature() { + let creds = Credentials::new( + VECTOR_ACCESS_KEY, + VECTOR_SECRET_KEY, + Some("FakeSessionToken123".to_string()), + None, + "lance-sigv4-test", + ); + let mut props = HashMap::new(); + props.insert(REGION_KEY.to_string(), VECTOR_REGION.to_string()); + props.insert(SERVICE_KEY.to_string(), VECTOR_SERVICE.to_string()); + let provider = SigV4AuthProvider::from_properties(&props) + .unwrap() + .with_clock(Arc::new(FixedClock( + UNIX_EPOCH + Duration::from_secs(VECTOR_UNIX_SECS), + ))) + .with_credentials_provider(SharedCredentialsProvider::new(creds)); + + let ctx = RequestContext { + method: "GET".to_string(), + url: "https://example.amazonaws.com/".to_string(), + headers: HashMap::new(), + body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()), + }; + let headers = provider.authenticate(&ctx).await.unwrap(); + + let token_header = headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("x-amz-security-token")) + .map(|(_, v)| v.as_str()); + assert_eq!( + token_header, + Some("FakeSessionToken123"), + "session token must be included in output headers" + ); + + let auth = headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("authorization")) + .map(|(_, v)| v.as_str()) + .unwrap(); + assert!( + auth.contains("x-amz-security-token"), + "session token must be in SignedHeaders: {}", + auth + ); + assert_eq!( + auth, + "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20150830/us-east-1/service/aws4_request, \ + SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-security-token, \ + Signature=d690ca83bd782879e22797e35b2e25958c0d19696a92cfb479b73428e4d950f4", + "session token signature must match botocore cross-verification" + ); + } +} From fdae167500867fa37437bdfec99942e4b4c09d80 Mon Sep 17 00:00:00 2001 From: shiwk Date: Thu, 4 Jun 2026 20:52:53 +0800 Subject: [PATCH 3/5] feat(rest-auth): enable SigV4 in Python/Java bindings with e2e tests --- java/lance-jni/Cargo.toml | 2 +- java/lance-jni/src/namespace.rs | 7 +- .../org/lance/namespace/SigV4AuthTest.java | 153 ++++++++++++++++ python/Cargo.toml | 2 +- python/python/tests/test_namespace_rest.py | 163 ++++++++++++++++++ python/src/namespace.rs | 6 +- 6 files changed, 329 insertions(+), 4 deletions(-) create mode 100644 java/src/test/java/org/lance/namespace/SigV4AuthTest.java diff --git a/java/lance-jni/Cargo.toml b/java/lance-jni/Cargo.toml index 5eaa69f071b..3d0520e38b3 100644 --- a/java/lance-jni/Cargo.toml +++ b/java/lance-jni/Cargo.toml @@ -23,7 +23,7 @@ lance-linalg = { path = "../../rust/lance-linalg" } lance-index = { path = "../../rust/lance-index" } lance-io = { path = "../../rust/lance-io" } lance-namespace = { path = "../../rust/lance-namespace" } -lance-namespace-impls = { path = "../../rust/lance-namespace-impls", features = ["rest", "rest-adapter", "dir-goosefs"] } +lance-namespace-impls = { path = "../../rust/lance-namespace-impls", features = ["rest", "rest-adapter", "dir-goosefs", "rest-auth-sigv4"] } lance-core = { path = "../../rust/lance-core" } lance-file = { path = "../../rust/lance-file" } lance-table = { path = "../../rust/lance-table" } diff --git a/java/lance-jni/src/namespace.rs b/java/lance-jni/src/namespace.rs index f0da7ff79ae..5fc3e325dc6 100644 --- a/java/lance-jni/src/namespace.rs +++ b/java/lance-jni/src/namespace.rs @@ -2533,7 +2533,12 @@ fn create_rest_namespace_internal( builder = builder.context_provider(Arc::new(java_provider)); } - let namespace = builder.build(); + let namespace = builder.build().map_err(|e| { + Error::runtime_error(format!("Failed to build RestNamespace: {}", e)) + })?; + + RT.block_on(namespace.warm_up_auth()) + .map_err(|e| Error::runtime_error(format!("Auth initialization failed: {}", e)))?; let blocking_namespace = BlockingRestNamespace { inner: Arc::new(namespace), diff --git a/java/src/test/java/org/lance/namespace/SigV4AuthTest.java b/java/src/test/java/org/lance/namespace/SigV4AuthTest.java new file mode 100644 index 00000000000..159e3f5e236 --- /dev/null +++ b/java/src/test/java/org/lance/namespace/SigV4AuthTest.java @@ -0,0 +1,153 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.lance.namespace; + +import org.lance.namespace.model.CreateNamespaceRequest; +import org.lance.namespace.model.ListNamespacesRequest; + +import com.sun.net.httpserver.HttpServer; +import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.memory.RootAllocator; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; + +import java.io.IOException; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.*; + +public class SigV4AuthTest { + @TempDir Path tempDir; + + private BufferAllocator allocator; + + @BeforeEach + void setUp() { + allocator = new RootAllocator(Long.MAX_VALUE); + } + + @AfterEach + void tearDown() { + if (allocator != null) { + allocator.close(); + } + } + + @Test + void testSigV4ConnectAndOperate() { + Map backendConfig = new HashMap<>(); + backendConfig.put("root", tempDir.toString()); + + RestAdapter adapter = new RestAdapter("dir", backendConfig, "127.0.0.1", 0); + adapter.start(); + try { + Map clientConfig = new HashMap<>(); + clientConfig.put("uri", "http://127.0.0.1:" + adapter.getPort()); + clientConfig.put("rest.auth.type", "sigv4"); + clientConfig.put("rest.auth.sigv4.region", "us-east-1"); + clientConfig.put("rest.auth.sigv4.service", "execute-api"); + + RestNamespace ns = new RestNamespace(); + ns.initialize(clientConfig, allocator); + + ns.createNamespace(new CreateNamespaceRequest().id(Arrays.asList("sigv4test"))); + var resp = ns.listNamespaces(new ListNamespacesRequest()); + assertTrue(resp.getNamespaces().contains("sigv4test")); + + ns.close(); + } finally { + adapter.close(); + } + } + + @Test + void testSigV4MissingRegionFailsAtConnect() { + Map backendConfig = new HashMap<>(); + backendConfig.put("root", tempDir.toString()); + + RestAdapter adapter = new RestAdapter("dir", backendConfig, "127.0.0.1", 0); + adapter.start(); + try { + Map clientConfig = new HashMap<>(); + clientConfig.put("uri", "http://127.0.0.1:" + adapter.getPort()); + clientConfig.put("rest.auth.type", "sigv4"); + + RestNamespace ns = new RestNamespace(); + RuntimeException ex = + assertThrows(RuntimeException.class, () -> ns.initialize(clientConfig, allocator)); + assertTrue(ex.getMessage().contains("rest.auth.sigv4.region")); + } finally { + adapter.close(); + } + } + + // Signature correctness is verified at the Rust layer (AWS test vectors + botocore). + @Test + void testSigV4SignatureHeadersPresent() throws IOException { + List capturedAuth = new ArrayList<>(); + + HttpServer server = HttpServer.create(new InetSocketAddress("127.0.0.1", 0), 0); + server.createContext( + "/", + exchange -> { + String auth = exchange.getRequestHeaders().getFirst("Authorization"); + if (auth != null) { + capturedAuth.add(auth); + } + byte[] body = "{\"namespaces\":[]}".getBytes(); + exchange.sendResponseHeaders(200, body.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(body); + } + }); + server.start(); + int port = server.getAddress().getPort(); + + try { + Map clientConfig = new HashMap<>(); + clientConfig.put("uri", "http://127.0.0.1:" + port); + clientConfig.put("rest.auth.type", "sigv4"); + clientConfig.put("rest.auth.sigv4.region", "us-east-1"); + clientConfig.put("rest.auth.sigv4.service", "execute-api"); + + RestNamespace ns = new RestNamespace(); + ns.initialize(clientConfig, allocator); + + try { + ns.listNamespaces(new ListNamespacesRequest()); + } catch (Exception ignored) { + } + + ns.close(); + + assertFalse(capturedAuth.isEmpty(), "no Authorization header captured"); + String auth = capturedAuth.get(0); + assertTrue(auth.startsWith("AWS4-HMAC-SHA256"), "expected SigV4 header, got: " + auth); + assertTrue(auth.contains("Credential="), "missing Credential in: " + auth); + assertTrue(auth.contains("SignedHeaders="), "missing SignedHeaders in: " + auth); + assertTrue(auth.matches(".*Signature=[a-f0-9]{64}.*"), "missing Signature in: " + auth); + } finally { + server.stop(0); + } + } +} diff --git a/python/Cargo.toml b/python/Cargo.toml index 9c7800d3c83..6d3878c4c31 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -47,7 +47,7 @@ lance-index = { path = "../rust/lance-index", features = [ lance-io = { path = "../rust/lance-io" } lance-linalg = { path = "../rust/lance-linalg" } lance-namespace = { path = "../rust/lance-namespace" } -lance-namespace-impls = { path = "../rust/lance-namespace-impls", features = ["rest", "rest-adapter", "dir-goosefs"] } +lance-namespace-impls = { path = "../rust/lance-namespace-impls", features = ["rest", "rest-adapter", "dir-goosefs", "rest-auth-sigv4"] } lance-table = { path = "../rust/lance-table" } lance-datafusion = { path = "../rust/lance-datafusion" } libc = "0.2.176" diff --git a/python/python/tests/test_namespace_rest.py b/python/python/tests/test_namespace_rest.py index 140d9168c05..a482a8f3bb5 100644 --- a/python/python/tests/test_namespace_rest.py +++ b/python/python/tests/test_namespace_rest.py @@ -747,3 +747,166 @@ def provide_context(self, info): # Explicit provider should have been used assert explicit_called["called"] + + +class TestSigV4Auth: + + def test_sigv4_connects_and_signs_requests(self, monkeypatch): + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "AKIAIOSFODNN7EXAMPLE") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY") + monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1") + + with tempfile.TemporaryDirectory() as tmpdir: + backend_config = {"root": tmpdir} + + with lance.namespace.RestAdapter("dir", backend_config, port=0) as adapter: + client = connect( + "rest", + { + "uri": f"http://127.0.0.1:{adapter.port}", + "rest.auth.type": "sigv4", + "rest.auth.sigv4.region": "us-east-1", + "rest.auth.sigv4.service": "execute-api", + }, + ) + + create_req = CreateNamespaceRequest(id=["sigv4test"]) + client.create_namespace(create_req) + + list_req = ListNamespacesRequest(id=[]) + resp = client.list_namespaces(list_req) + assert "sigv4test" in resp.namespaces + + def test_sigv4_missing_region_fails_at_connect(self): + with tempfile.TemporaryDirectory() as tmpdir: + backend_config = {"root": tmpdir} + + with lance.namespace.RestAdapter("dir", backend_config, port=0) as adapter: + with pytest.raises(Exception, match="rest.auth.sigv4.region"): + connect( + "rest", + { + "uri": f"http://127.0.0.1:{adapter.port}", + "rest.auth.type": "sigv4", + # no region — should fail + }, + ) + + def test_sigv4_signature_correctness(self, monkeypatch): + import json + import re + import threading + from http.server import BaseHTTPRequestHandler, HTTPServer + + from botocore.auth import SigV4Auth + from botocore.awsrequest import AWSRequest + from botocore.credentials import Credentials + + ACCESS_KEY = "AKIAIOSFODNN7EXAMPLE" + SECRET_KEY = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY" + + monkeypatch.setenv("AWS_ACCESS_KEY_ID", ACCESS_KEY) + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", SECRET_KEY) + monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1") + + captured_requests = [] + + class Recorder(BaseHTTPRequestHandler): + def _capture_and_respond(self): + content_length = int(self.headers.get("Content-Length", 0)) + body = self.rfile.read(content_length) if content_length else b"" + captured_requests.append({ + "method": self.command, + "path": self.path, + "headers": {k.lower(): v for k, v in self.headers.items()}, + "body": body, + }) + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({"namespaces": []}).encode()) + + def do_GET(self): + self._capture_and_respond() + + def do_POST(self): + self._capture_and_respond() + + def log_message(self, *_args): + pass + + server = HTTPServer(("127.0.0.1", 0), Recorder) + port = server.server_address[1] + threading.Thread(target=server.serve_forever, daemon=True).start() + + try: + client = connect( + "rest", + { + "uri": f"http://127.0.0.1:{port}", + "rest.auth.type": "sigv4", + "rest.auth.sigv4.region": "us-east-1", + "rest.auth.sigv4.service": "execute-api", + }, + ) + + try: + client.list_namespaces(ListNamespacesRequest(id=[])) + except Exception: + pass + + try: + client.create_namespace(CreateNamespaceRequest(id=["verify"])) + except Exception: + pass + + assert len(captured_requests) >= 2, ( + f"expected at least 2 requests (GET+POST), got {len(captured_requests)}" + ) + methods_seen = {r["method"] for r in captured_requests} + assert "GET" in methods_seen, "expected at least one GET request" + assert "POST" in methods_seen, "expected at least one POST request" + + creds = Credentials(ACCESS_KEY, SECRET_KEY) + signer = SigV4Auth(creds, "execute-api", "us-east-1") + + for req in captured_requests: + rust_auth = req["headers"].get("authorization", "") + assert rust_auth.startswith("AWS4-HMAC-SHA256"), ( + f"{req['method']} {req['path']}: missing SigV4 header" + ) + + rust_sig = re.search(r"Signature=([a-f0-9]{64})", rust_auth).group(1) + amz_date = req["headers"]["x-amz-date"] + + url = f"http://127.0.0.1:{port}{req['path']}" + + signed_names = re.search( + r"SignedHeaders=([^,]+)", rust_auth + ).group(1).split(";") + headers_for_signing = {} + for name in signed_names: + if name in req["headers"]: + headers_for_signing[name] = req["headers"][name] + + aws_req = AWSRequest( + method=req["method"], + url=url, + headers=headers_for_signing, + data=req["body"], + ) + aws_req.context["timestamp"] = amz_date + + cr = signer.canonical_request(aws_req) + sts = signer.string_to_sign(aws_req, cr) + boto_sig = signer.signature(sts, aws_req) + + assert rust_sig == boto_sig, ( + f"{req['method']} {req['path']}: signature mismatch\n" + f" rust: {rust_sig}\n" + f" botocore:{boto_sig}\n" + f" rust_auth: {rust_auth}\n" + f" botocore canonical_request:\n{cr}" + ) + finally: + server.shutdown() diff --git a/python/src/namespace.rs b/python/src/namespace.rs index cf5f7c41b0f..11bf02d4eba 100644 --- a/python/src/namespace.rs +++ b/python/src/namespace.rs @@ -822,7 +822,11 @@ impl PyRestNamespace { builder = builder.context_provider(Arc::new(py_provider)); } - let namespace = builder.build(); + let namespace = builder.build().infer_error()?; + + crate::rt() + .block_on(None, namespace.warm_up_auth())? + .infer_error()?; Ok(Self { inner: Arc::new(namespace), From 9d64452c9480a36de10a91e37cadce0c78dcdaa2 Mon Sep 17 00:00:00 2001 From: shiwk Date: Thu, 4 Jun 2026 20:55:34 +0800 Subject: [PATCH 4/5] docs(rest-auth): add SigV4 authentication properties to API docs --- java/lance-jni/Cargo.lock | 5 +++ .../org/lance/namespace/RestNamespace.java | 13 +++++- python/Cargo.lock | 5 +++ python/src/namespace.rs | 10 +++-- rust/lance-namespace-impls/src/rest.rs | 45 ++++++++++++++++--- 5 files changed, 67 insertions(+), 11 deletions(-) diff --git a/java/lance-jni/Cargo.lock b/java/lance-jni/Cargo.lock index fa08fd758aa..f9e39b191a0 100644 --- a/java/lance-jni/Cargo.lock +++ b/java/lance-jni/Cargo.lock @@ -4242,9 +4242,13 @@ dependencies = [ "arrow-ipc", "arrow-schema", "async-trait", + "aws-config", + "aws-credential-types", + "aws-sigv4", "axum", "bytes", "futures", + "hex", "lance", "lance-core", "lance-index", @@ -4258,6 +4262,7 @@ dependencies = [ "reqwest 0.12.28", "serde", "serde_json", + "sha2 0.10.9", "tokio", "tower", "tower-http 0.5.2", diff --git a/java/src/main/java/org/lance/namespace/RestNamespace.java b/java/src/main/java/org/lance/namespace/RestNamespace.java index 9cbbc588660..e3f09a6d678 100644 --- a/java/src/main/java/org/lance/namespace/RestNamespace.java +++ b/java/src/main/java/org/lance/namespace/RestNamespace.java @@ -120,6 +120,15 @@ *
  • uri (required): REST API endpoint URL *
  • delimiter (optional): Namespace delimiter (default: "$") *
  • header.* (optional): HTTP headers (e.g., header.Authorization=Bearer token) + *
  • rest.auth.type (optional): Authentication type — "sigv4" or "none" (default: none) + *
  • rest.auth.sigv4.region (required if sigv4): AWS region + *
  • rest.auth.sigv4.service (optional): AWS service name (default: "execute-api") + * + * + *

    Note: {@code rest.auth.*} and {@code header.Authorization} are mutually exclusive. + * Setting both will throw an error at initialization time. + * + *

      *
    • tls.cert_file (optional): Path to client certificate file *
    • tls.key_file (optional): Path to client key file *
    • tls.ssl_ca_cert (optional): Path to CA certificate file @@ -131,8 +140,8 @@ *
      {@code
        * Map properties = new HashMap<>();
        * properties.put("uri", "https://api.example.com");
      - * properties.put("delimiter", ".");
      - * properties.put("header.Authorization", "Bearer my-token");
      + * properties.put("rest.auth.type", "sigv4");
      + * properties.put("rest.auth.sigv4.region", "us-east-1");
        *
        * RestNamespace namespace = new RestNamespace();
        * namespace.initialize(properties, allocator);
      diff --git a/python/Cargo.lock b/python/Cargo.lock
      index 7867ea71446..1634bf0d90e 100644
      --- a/python/Cargo.lock
      +++ b/python/Cargo.lock
      @@ -4574,9 +4574,13 @@ dependencies = [
        "arrow-ipc",
        "arrow-schema",
        "async-trait",
      + "aws-config",
      + "aws-credential-types",
      + "aws-sigv4",
        "axum",
        "bytes",
        "futures",
      + "hex",
        "lance",
        "lance-core",
        "lance-index",
      @@ -4590,6 +4594,7 @@ dependencies = [
        "reqwest 0.12.28",
        "serde",
        "serde_json",
      + "sha2 0.10.9",
        "tokio",
        "tower",
        "tower-http 0.5.2",
      diff --git a/python/src/namespace.rs b/python/src/namespace.rs
      index 11bf02d4eba..a153fa8b43f 100644
      --- a/python/src/namespace.rs
      +++ b/python/src/namespace.rs
      @@ -788,15 +788,17 @@ pub struct PyRestNamespace {
       
       #[pymethods]
       impl PyRestNamespace {
      -    /// Create a new RestNamespace from properties
      +    /// Create a new RestNamespace from properties.
           ///
           /// # Arguments
           ///
           /// * `context_provider` - Optional object with `provide_context(info: dict) -> dict` method
           ///   for providing dynamic per-request context. Context keys that start with `headers.`
      -    ///   are converted to HTTP headers by stripping the prefix. For example,
      -    ///   `{"headers.Authorization": "Bearer token"}` becomes the `Authorization` header.
      -    /// * `**properties` - Namespace configuration properties (uri, delimiter, header.*, etc.)
      +    ///   are converted to HTTP headers by stripping the prefix.
      +    /// * `**properties` - Namespace configuration properties (uri, delimiter, header.*,
      +    ///   rest.auth.type, rest.auth.sigv4.region, rest.auth.sigv4.service, etc.)
      +    ///
      +    /// `rest.auth.*` and `header.Authorization` are mutually exclusive.
           #[new]
           #[pyo3(signature = (context_provider = None, **properties))]
           fn new(
      diff --git a/rust/lance-namespace-impls/src/rest.rs b/rust/lance-namespace-impls/src/rest.rs
      index b4086f5e63d..16381c144b1 100644
      --- a/rust/lance-namespace-impls/src/rest.rs
      +++ b/rust/lance-namespace-impls/src/rest.rs
      @@ -219,15 +219,25 @@ impl RestClient {
       
       /// Builder for creating a RestNamespace.
       ///
      -/// This builder provides a fluent API for configuring and establishing
      -/// connections to REST-based Lance namespaces.
      +/// # Authentication
      +///
      +/// SigV4 authentication via properties:
      +/// - `rest.auth.type` — `"sigv4"` or `"none"` (default: none)
      +/// - `rest.auth.sigv4.region` — AWS region (required for sigv4)
      +/// - `rest.auth.sigv4.service` — AWS service name (default: `"execute-api"`)
      +///
      +/// Credentials are resolved via the standard AWS chain (env vars, profile,
      +/// IMDS). Alternatively, use [`auth_provider()`](Self::auth_provider) to
      +/// inject a custom provider (takes precedence over properties).
      +///
      +/// `rest.auth.*` and `header.Authorization` are mutually exclusive —
      +/// setting both will return an error at build time.
       ///
       /// # Examples
       ///
       /// ```no_run
       /// # use lance_namespace_impls::RestNamespaceBuilder;
       /// # fn example() -> Result<(), Box> {
      -/// // Create a REST namespace
       /// let namespace = RestNamespaceBuilder::new("http://localhost:8080")
       ///     .delimiter(".")
       ///     .header("Authorization", "Bearer token")
      @@ -527,9 +537,18 @@ impl RestNamespaceBuilder {
               self
           }
       
      -    /// Build the RestNamespace. Auth provider precedence:
      -    /// `auth_provider()` setter > `rest.auth.*` properties > none.
      +    /// Build the RestNamespace.
           pub fn build(self) -> Result {
      +        let has_auth = self.auth_provider.is_some()
      +            || self.auth_properties.contains_key(AUTH_TYPE_KEY);
      +        if has_auth && self.headers.keys().any(|k| k.eq_ignore_ascii_case("authorization")) {
      +            return Err(NamespaceError::InvalidInput {
      +                message: "cannot combine header.Authorization with rest.auth.* — \
      +                          use one authentication method"
      +                    .to_string(),
      +            }
      +            .into());
      +        }
               let auth = if let Some(p) = self.auth_provider.clone() {
                   Some(p)
               } else if self.auth_properties.contains_key(AUTH_TYPE_KEY) {
      @@ -2507,4 +2526,20 @@ mod tests {
                   "setter auth_provider must override rest.auth.type property"
               );
           }
      +
      +    #[test]
      +    fn build_rejects_header_authorization_combined_with_auth_type() {
      +        let mut props = HashMap::new();
      +        props.insert("uri".to_string(), "http://localhost:8080".to_string());
      +        props.insert("header.Authorization".to_string(), "Bearer token".to_string());
      +        props.insert("rest.auth.type".to_string(), "none".to_string());
      +        let err = RestNamespaceBuilder::from_properties(props)
      +            .unwrap()
      +            .build()
      +            .unwrap_err();
      +        assert!(
      +            err.to_string().contains("one authentication method"),
      +            "build must reject header.Authorization + rest.auth.*: {err}"
      +        );
      +    }
       }
      
      From b5f7f1f00f21533404aa6e9367d8595bfb19f9e6 Mon Sep 17 00:00:00 2001
      From: shiwk 
      Date: Mon, 8 Jun 2026 12:00:07 +0800
      Subject: [PATCH 5/5] feat(rest-auth): support explicit static credentials via
       properties
      
      ---
       java/lance-jni/Cargo.lock                     |   4 +-
       .../org/lance/namespace/RestNamespace.java    |   3 +
       .../org/lance/namespace/SigV4AuthTest.java    | 193 +++++++++++++++-
       python/python/tests/test_namespace_rest.py    | 207 ++++++++++++++++++
       python/src/namespace.rs                       |   4 +-
       rust/lance-namespace-impls/src/rest.rs        |  13 +-
       .../lance-namespace-impls/src/rest_adapter.rs |   6 +-
       .../src/rest_auth/sigv4.rs                    | 168 +++++++++++++-
       8 files changed, 587 insertions(+), 11 deletions(-)
      
      diff --git a/java/lance-jni/Cargo.lock b/java/lance-jni/Cargo.lock
      index f9e39b191a0..502a2a38997 100644
      --- a/java/lance-jni/Cargo.lock
      +++ b/java/lance-jni/Cargo.lock
      @@ -3910,6 +3910,7 @@ dependencies = [
        "itertools 0.13.0",
        "lance-arrow",
        "libc",
      + "libm",
        "log",
        "moka",
        "num_cpus",
      @@ -3925,6 +3926,7 @@ dependencies = [
        "tokio-stream",
        "tokio-util",
        "tracing",
      + "twox-hash",
        "url",
       ]
       
      @@ -4104,7 +4106,6 @@ dependencies = [
        "lance-select",
        "lance-table",
        "lance-tokenizer",
      - "libm",
        "libsais-rs",
        "log",
        "ndarray",
      @@ -4124,7 +4125,6 @@ dependencies = [
        "tempfile",
        "tokio",
        "tracing",
      - "twox-hash",
        "uuid",
       ]
       
      diff --git a/java/src/main/java/org/lance/namespace/RestNamespace.java b/java/src/main/java/org/lance/namespace/RestNamespace.java
      index e3f09a6d678..f4fb4059bb7 100644
      --- a/java/src/main/java/org/lance/namespace/RestNamespace.java
      +++ b/java/src/main/java/org/lance/namespace/RestNamespace.java
      @@ -123,6 +123,9 @@
        *   
    • rest.auth.type (optional): Authentication type — "sigv4" or "none" (default: none) *
    • rest.auth.sigv4.region (required if sigv4): AWS region *
    • rest.auth.sigv4.service (optional): AWS service name (default: "execute-api") + *
    • rest.auth.sigv4.access-key-id (optional): Explicit AWS access key ID + *
    • rest.auth.sigv4.secret-access-key (optional): Explicit AWS secret access key + *
    • rest.auth.sigv4.session-token (optional): STS session token *
    * *

    Note: {@code rest.auth.*} and {@code header.Authorization} are mutually exclusive. diff --git a/java/src/test/java/org/lance/namespace/SigV4AuthTest.java b/java/src/test/java/org/lance/namespace/SigV4AuthTest.java index 159e3f5e236..8b26c26e10a 100644 --- a/java/src/test/java/org/lance/namespace/SigV4AuthTest.java +++ b/java/src/test/java/org/lance/namespace/SigV4AuthTest.java @@ -66,6 +66,9 @@ void testSigV4ConnectAndOperate() { clientConfig.put("rest.auth.type", "sigv4"); clientConfig.put("rest.auth.sigv4.region", "us-east-1"); clientConfig.put("rest.auth.sigv4.service", "execute-api"); + clientConfig.put("rest.auth.sigv4.access-key-id", "AKIAIOSFODNN7EXAMPLE"); + clientConfig.put( + "rest.auth.sigv4.secret-access-key", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"); RestNamespace ns = new RestNamespace(); ns.initialize(clientConfig, allocator); @@ -101,6 +104,191 @@ void testSigV4MissingRegionFailsAtConnect() { } } + @Test + void testSigV4ExplicitCredentials() throws IOException { + List capturedAuth = new ArrayList<>(); + + HttpServer server = HttpServer.create(new InetSocketAddress("127.0.0.1", 0), 0); + server.createContext( + "/", + exchange -> { + String auth = exchange.getRequestHeaders().getFirst("Authorization"); + if (auth != null) { + capturedAuth.add(auth); + } + byte[] body = "{\"namespaces\":[]}".getBytes(); + exchange.sendResponseHeaders(200, body.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(body); + } + }); + server.start(); + int port = server.getAddress().getPort(); + + try { + Map clientConfig = new HashMap<>(); + clientConfig.put("uri", "http://127.0.0.1:" + port); + clientConfig.put("rest.auth.type", "sigv4"); + clientConfig.put("rest.auth.sigv4.region", "us-east-1"); + clientConfig.put("rest.auth.sigv4.service", "execute-api"); + clientConfig.put("rest.auth.sigv4.access-key-id", "AKIAIOSFODNN7EXAMPLE"); + clientConfig.put( + "rest.auth.sigv4.secret-access-key", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"); + + RestNamespace ns = new RestNamespace(); + ns.initialize(clientConfig, allocator); + + try { + ns.listNamespaces(new ListNamespacesRequest()); + } catch (Exception ignored) { + } + + ns.close(); + + assertFalse(capturedAuth.isEmpty(), "no Authorization header captured"); + String auth = capturedAuth.get(0); + assertTrue(auth.startsWith("AWS4-HMAC-SHA256"), "expected SigV4 header, got: " + auth); + assertTrue(auth.contains("Credential=AKIAIOSFODNN7EXAMPLE/"), "wrong access key in: " + auth); + } finally { + server.stop(0); + } + } + + @Test + void testSigV4ExplicitCredentialsWithSessionToken() throws IOException { + List capturedAuth = new ArrayList<>(); + List capturedToken = new ArrayList<>(); + + HttpServer server = HttpServer.create(new InetSocketAddress("127.0.0.1", 0), 0); + server.createContext( + "/", + exchange -> { + String auth = exchange.getRequestHeaders().getFirst("Authorization"); + if (auth != null) { + capturedAuth.add(auth); + } + String token = exchange.getRequestHeaders().getFirst("x-amz-security-token"); + if (token != null) { + capturedToken.add(token); + } + byte[] body = "{\"namespaces\":[]}".getBytes(); + exchange.sendResponseHeaders(200, body.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(body); + } + }); + server.start(); + int port = server.getAddress().getPort(); + + try { + Map clientConfig = new HashMap<>(); + clientConfig.put("uri", "http://127.0.0.1:" + port); + clientConfig.put("rest.auth.type", "sigv4"); + clientConfig.put("rest.auth.sigv4.region", "us-east-1"); + clientConfig.put("rest.auth.sigv4.service", "execute-api"); + clientConfig.put("rest.auth.sigv4.access-key-id", "AKIAIOSFODNN7EXAMPLE"); + clientConfig.put( + "rest.auth.sigv4.secret-access-key", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"); + clientConfig.put("rest.auth.sigv4.session-token", "FakeSessionToken123"); + + RestNamespace ns = new RestNamespace(); + ns.initialize(clientConfig, allocator); + + try { + ns.listNamespaces(new ListNamespacesRequest()); + } catch (Exception ignored) { + } + + ns.close(); + + assertFalse(capturedAuth.isEmpty(), "no Authorization header captured"); + String auth = capturedAuth.get(0); + assertTrue(auth.startsWith("AWS4-HMAC-SHA256"), "expected SigV4 header, got: " + auth); + + assertFalse(capturedToken.isEmpty(), "no x-amz-security-token header captured"); + assertEquals("FakeSessionToken123", capturedToken.get(0)); + } finally { + server.stop(0); + } + } + + // Precedence (properties > env) is verified by Python/Rust; JVM cannot mutate env at runtime. + @Test + void testSigV4ExplicitCredentialsUsedRegardlessOfEnv() throws IOException { + List capturedAuth = new ArrayList<>(); + + HttpServer server = HttpServer.create(new InetSocketAddress("127.0.0.1", 0), 0); + server.createContext( + "/", + exchange -> { + String auth = exchange.getRequestHeaders().getFirst("Authorization"); + if (auth != null) { + capturedAuth.add(auth); + } + byte[] body = "{\"namespaces\":[]}".getBytes(); + exchange.sendResponseHeaders(200, body.length); + try (OutputStream os = exchange.getResponseBody()) { + os.write(body); + } + }); + server.start(); + int port = server.getAddress().getPort(); + + try { + Map clientConfig = new HashMap<>(); + clientConfig.put("uri", "http://127.0.0.1:" + port); + clientConfig.put("rest.auth.type", "sigv4"); + clientConfig.put("rest.auth.sigv4.region", "us-east-1"); + clientConfig.put("rest.auth.sigv4.service", "execute-api"); + clientConfig.put("rest.auth.sigv4.access-key-id", "AKIAIOSFODNN7EXAMPLE"); + clientConfig.put( + "rest.auth.sigv4.secret-access-key", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"); + + RestNamespace ns = new RestNamespace(); + ns.initialize(clientConfig, allocator); + + try { + ns.listNamespaces(new ListNamespacesRequest()); + } catch (Exception ignored) { + } + + ns.close(); + + assertFalse(capturedAuth.isEmpty(), "no Authorization header captured"); + String auth = capturedAuth.get(0); + assertTrue( + auth.contains("Credential=AKIAIOSFODNN7EXAMPLE/"), + "properties credentials must be used, got: " + auth); + } finally { + server.stop(0); + } + } + + @Test + void testSigV4PartialCredentialsRejected() { + Map backendConfig = new HashMap<>(); + backendConfig.put("root", tempDir.toString()); + + RestAdapter adapter = new RestAdapter("dir", backendConfig, "127.0.0.1", 0); + adapter.start(); + try { + Map clientConfig = new HashMap<>(); + clientConfig.put("uri", "http://127.0.0.1:" + adapter.getPort()); + clientConfig.put("rest.auth.type", "sigv4"); + clientConfig.put("rest.auth.sigv4.region", "us-east-1"); + clientConfig.put("rest.auth.sigv4.access-key-id", "AKIAIOSFODNN7EXAMPLE"); + + RestNamespace ns = new RestNamespace(); + RuntimeException ex = + assertThrows(RuntimeException.class, () -> ns.initialize(clientConfig, allocator)); + assertTrue( + ex.getMessage().contains("rest.auth.sigv4.secret-access-key"), + "error must mention missing key: " + ex.getMessage()); + } finally { + adapter.close(); + } + } + // Signature correctness is verified at the Rust layer (AWS test vectors + botocore). @Test void testSigV4SignatureHeadersPresent() throws IOException { @@ -129,6 +317,9 @@ void testSigV4SignatureHeadersPresent() throws IOException { clientConfig.put("rest.auth.type", "sigv4"); clientConfig.put("rest.auth.sigv4.region", "us-east-1"); clientConfig.put("rest.auth.sigv4.service", "execute-api"); + clientConfig.put("rest.auth.sigv4.access-key-id", "AKIAIOSFODNN7EXAMPLE"); + clientConfig.put( + "rest.auth.sigv4.secret-access-key", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"); RestNamespace ns = new RestNamespace(); ns.initialize(clientConfig, allocator); @@ -143,7 +334,7 @@ void testSigV4SignatureHeadersPresent() throws IOException { assertFalse(capturedAuth.isEmpty(), "no Authorization header captured"); String auth = capturedAuth.get(0); assertTrue(auth.startsWith("AWS4-HMAC-SHA256"), "expected SigV4 header, got: " + auth); - assertTrue(auth.contains("Credential="), "missing Credential in: " + auth); + assertTrue(auth.contains("Credential=AKIAIOSFODNN7EXAMPLE/"), "wrong access key in: " + auth); assertTrue(auth.contains("SignedHeaders="), "missing SignedHeaders in: " + auth); assertTrue(auth.matches(".*Signature=[a-f0-9]{64}.*"), "missing Signature in: " + auth); } finally { diff --git a/python/python/tests/test_namespace_rest.py b/python/python/tests/test_namespace_rest.py index a482a8f3bb5..6191fea85bb 100644 --- a/python/python/tests/test_namespace_rest.py +++ b/python/python/tests/test_namespace_rest.py @@ -910,3 +910,210 @@ def log_message(self, *_args): ) finally: server.shutdown() + + def test_sigv4_explicit_credentials_take_precedence_over_env(self, monkeypatch): + import json + import threading + from http.server import BaseHTTPRequestHandler, HTTPServer + + monkeypatch.setenv("AWS_ACCESS_KEY_ID", "ENVAKID_SHOULD_NOT_APPEAR") + monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "envSecretShouldNotAppear") + + captured_headers = [] + + class Recorder(BaseHTTPRequestHandler): + def _capture_and_respond(self): + captured_headers.append( + {k.lower(): v for k, v in self.headers.items()} + ) + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({"namespaces": []}).encode()) + + def do_GET(self): + self._capture_and_respond() + + def do_POST(self): + self._capture_and_respond() + + def log_message(self, *_args): + pass + + server = HTTPServer(("127.0.0.1", 0), Recorder) + port = server.server_address[1] + threading.Thread(target=server.serve_forever, daemon=True).start() + + try: + client = connect( + "rest", + { + "uri": f"http://127.0.0.1:{port}", + "rest.auth.type": "sigv4", + "rest.auth.sigv4.region": "us-east-1", + "rest.auth.sigv4.service": "execute-api", + "rest.auth.sigv4.access-key-id": "AKIAIOSFODNN7EXAMPLE", + "rest.auth.sigv4.secret-access-key": ( + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY" + ), + }, + ) + + try: + client.list_namespaces(ListNamespacesRequest(id=[])) + except Exception: + pass + + assert len(captured_headers) >= 1 + auth = captured_headers[0].get("authorization", "") + assert "Credential=AKIAIOSFODNN7EXAMPLE/" in auth, ( + "properties credentials must take precedence over env" + ) + assert "ENVAKID_SHOULD_NOT_APPEAR" not in auth + finally: + server.shutdown() + + def test_sigv4_partial_credentials_rejected(self): + with tempfile.TemporaryDirectory() as tmpdir: + backend_config = {"root": tmpdir} + + with lance.namespace.RestAdapter("dir", backend_config, port=0) as adapter: + with pytest.raises(Exception, match="rest.auth.sigv4.secret-access-key"): + connect( + "rest", + { + "uri": f"http://127.0.0.1:{adapter.port}", + "rest.auth.type": "sigv4", + "rest.auth.sigv4.region": "us-east-1", + "rest.auth.sigv4.access-key-id": "AKIAIOSFODNN7EXAMPLE", + }, + ) + + def test_sigv4_explicit_credentials(self, monkeypatch): + import json + import threading + from http.server import BaseHTTPRequestHandler, HTTPServer + + monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False) + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) + monkeypatch.delenv("AWS_SESSION_TOKEN", raising=False) + + captured_headers = [] + + class Recorder(BaseHTTPRequestHandler): + def _capture_and_respond(self): + captured_headers.append( + {k.lower(): v for k, v in self.headers.items()} + ) + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({"namespaces": []}).encode()) + + def do_GET(self): + self._capture_and_respond() + + def do_POST(self): + self._capture_and_respond() + + def log_message(self, *_args): + pass + + server = HTTPServer(("127.0.0.1", 0), Recorder) + port = server.server_address[1] + threading.Thread(target=server.serve_forever, daemon=True).start() + + try: + client = connect( + "rest", + { + "uri": f"http://127.0.0.1:{port}", + "rest.auth.type": "sigv4", + "rest.auth.sigv4.region": "us-east-1", + "rest.auth.sigv4.service": "execute-api", + "rest.auth.sigv4.access-key-id": "AKIAIOSFODNN7EXAMPLE", + "rest.auth.sigv4.secret-access-key": ( + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY" + ), + }, + ) + + try: + client.list_namespaces(ListNamespacesRequest(id=[])) + except Exception: + pass + + assert len(captured_headers) >= 1 + auth = captured_headers[0].get("authorization", "") + assert auth.startswith("AWS4-HMAC-SHA256"), ( + f"expected SigV4 header, got: {auth}" + ) + assert "Credential=AKIAIOSFODNN7EXAMPLE/" in auth + finally: + server.shutdown() + + def test_sigv4_explicit_credentials_with_session_token(self, monkeypatch): + import json + import threading + from http.server import BaseHTTPRequestHandler, HTTPServer + + monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False) + monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False) + monkeypatch.delenv("AWS_SESSION_TOKEN", raising=False) + + captured_headers = [] + + class Recorder(BaseHTTPRequestHandler): + def _capture_and_respond(self): + captured_headers.append( + {k.lower(): v for k, v in self.headers.items()} + ) + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({"namespaces": []}).encode()) + + def do_GET(self): + self._capture_and_respond() + + def do_POST(self): + self._capture_and_respond() + + def log_message(self, *_args): + pass + + server = HTTPServer(("127.0.0.1", 0), Recorder) + port = server.server_address[1] + threading.Thread(target=server.serve_forever, daemon=True).start() + + try: + client = connect( + "rest", + { + "uri": f"http://127.0.0.1:{port}", + "rest.auth.type": "sigv4", + "rest.auth.sigv4.region": "us-east-1", + "rest.auth.sigv4.service": "execute-api", + "rest.auth.sigv4.access-key-id": "AKIAIOSFODNN7EXAMPLE", + "rest.auth.sigv4.secret-access-key": ( + "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY" + ), + "rest.auth.sigv4.session-token": "FakeSessionToken123", + }, + ) + + try: + client.list_namespaces(ListNamespacesRequest(id=[])) + except Exception: + pass + + assert len(captured_headers) >= 1 + auth = captured_headers[0].get("authorization", "") + assert auth.startswith("AWS4-HMAC-SHA256") + + token = captured_headers[0].get("x-amz-security-token", "") + assert token == "FakeSessionToken123", ( + f"expected session token in header, got: {token}" + ) + finally: + server.shutdown() diff --git a/python/src/namespace.rs b/python/src/namespace.rs index a153fa8b43f..d39b3860e91 100644 --- a/python/src/namespace.rs +++ b/python/src/namespace.rs @@ -796,7 +796,9 @@ impl PyRestNamespace { /// for providing dynamic per-request context. Context keys that start with `headers.` /// are converted to HTTP headers by stripping the prefix. /// * `**properties` - Namespace configuration properties (uri, delimiter, header.*, - /// rest.auth.type, rest.auth.sigv4.region, rest.auth.sigv4.service, etc.) + /// rest.auth.type, rest.auth.sigv4.region, rest.auth.sigv4.service, + /// rest.auth.sigv4.access-key-id, rest.auth.sigv4.secret-access-key, + /// rest.auth.sigv4.session-token, etc.) /// /// `rest.auth.*` and `header.Authorization` are mutually exclusive. #[new] diff --git a/rust/lance-namespace-impls/src/rest.rs b/rust/lance-namespace-impls/src/rest.rs index 16381c144b1..ae00437af29 100644 --- a/rust/lance-namespace-impls/src/rest.rs +++ b/rust/lance-namespace-impls/src/rest.rs @@ -225,10 +225,17 @@ impl RestClient { /// - `rest.auth.type` — `"sigv4"` or `"none"` (default: none) /// - `rest.auth.sigv4.region` — AWS region (required for sigv4) /// - `rest.auth.sigv4.service` — AWS service name (default: `"execute-api"`) +/// - `rest.auth.sigv4.access-key-id` — explicit AWS access key ID (optional) +/// - `rest.auth.sigv4.secret-access-key` — explicit AWS secret access key (optional) +/// - `rest.auth.sigv4.session-token` — STS session token (optional) /// -/// Credentials are resolved via the standard AWS chain (env vars, profile, -/// IMDS). Alternatively, use [`auth_provider()`](Self::auth_provider) to -/// inject a custom provider (takes precedence over properties). +/// When explicit `access-key-id` and `secret-access-key` are set, they +/// are used directly; otherwise credentials fall back to the AWS default +/// chain (env vars, profile, IMDS). The two keys must both be present or +/// both be absent. +/// +/// [`auth_provider()`](Self::auth_provider) overrides all property-based +/// auth — when set, `rest.auth.*` properties are ignored. /// /// `rest.auth.*` and `header.Authorization` are mutually exclusive — /// setting both will return an error at build time. diff --git a/rust/lance-namespace-impls/src/rest_adapter.rs b/rust/lance-namespace-impls/src/rest_adapter.rs index 85d98a14574..a155f989eb9 100644 --- a/rust/lance-namespace-impls/src/rest_adapter.rs +++ b/rust/lance-namespace-impls/src/rest_adapter.rs @@ -1483,7 +1483,8 @@ mod tests { let server_url = format!("http://127.0.0.1:{}", actual_port); let namespace = RestNamespaceBuilder::new(&server_url) .delimiter("$") - .build().unwrap(); + .build() + .unwrap(); Self { _temp_dir: temp_dir, @@ -3047,7 +3048,8 @@ mod tests { .delimiter("$") .header("X-Base-Header", "base-value") .context_provider(provider) - .build().unwrap(); + .build() + .unwrap(); // Create a namespace - should work with context provider let create_req = CreateNamespaceRequest { diff --git a/rust/lance-namespace-impls/src/rest_auth/sigv4.rs b/rust/lance-namespace-impls/src/rest_auth/sigv4.rs index 9452c2692d7..0db374b7055 100644 --- a/rust/lance-namespace-impls/src/rest_auth/sigv4.rs +++ b/rust/lance-namespace-impls/src/rest_auth/sigv4.rs @@ -21,6 +21,9 @@ use url::Url; pub const REGION_KEY: &str = "rest.auth.sigv4.region"; pub const SERVICE_KEY: &str = "rest.auth.sigv4.service"; +pub const ACCESS_KEY_ID_KEY: &str = "rest.auth.sigv4.access-key-id"; +pub const SECRET_ACCESS_KEY_KEY: &str = "rest.auth.sigv4.secret-access-key"; +pub const SESSION_TOKEN_KEY: &str = "rest.auth.sigv4.session-token"; const DEFAULT_SERVICE: &str = "execute-api"; /// Injectable time source; tests use a fixed clock. @@ -40,6 +43,7 @@ impl Clock for SystemClock { pub struct SigV4AuthProvider { region: String, service: String, + static_credentials: Option, credentials_provider: OnceCell, clock: Arc, } @@ -50,8 +54,14 @@ impl std::fmt::Debug for SigV4AuthProvider { .field("region", &self.region) .field("service", &self.service) .field( - "credentials_provider", - &self.credentials_provider.get().map(|_| "resolved"), + "credential_source", + &if self.static_credentials.is_some() { + "static" + } else if self.credentials_provider.get().is_some() { + "resolved" + } else { + "default-chain (pending)" + }, ) .finish() } @@ -70,9 +80,32 @@ impl SigV4AuthProvider { .get(SERVICE_KEY) .cloned() .unwrap_or_else(|| DEFAULT_SERVICE.to_string()); + + let ak = properties.get(ACCESS_KEY_ID_KEY); + let sk = properties.get(SECRET_ACCESS_KEY_KEY); + let static_credentials = match (ak, sk) { + (Some(ak), Some(sk)) => Some(Credentials::new( + ak.clone(), + sk.clone(), + properties.get(SESSION_TOKEN_KEY).cloned(), + None, + "lance-sigv4-static", + )), + (None, None) => None, + _ => { + return Err(NamespaceError::InvalidInput { + message: format!( + "{ACCESS_KEY_ID_KEY} and {SECRET_ACCESS_KEY_KEY} must both be set or both be omitted" + ), + } + .into()); + } + }; + Ok(Self { region, service, + static_credentials, credentials_provider: OnceCell::new(), clock: Arc::new(SystemClock), }) @@ -96,6 +129,9 @@ impl SigV4AuthProvider { async fn ensure_credentials_provider(&self) -> Result<&SharedCredentialsProvider> { self.credentials_provider .get_or_try_init(|| async { + if let Some(creds) = &self.static_credentials { + return Ok(SharedCredentialsProvider::new(creds.clone())); + } // aws_config::load panics inside an existing tokio runtime. let region = self.region.clone(); let provider = tokio::task::spawn_blocking(move || { @@ -446,4 +482,132 @@ mod tests { "session token signature must match botocore cross-verification" ); } + + #[tokio::test] + async fn explicit_credentials_via_properties_match_injected() { + let mut props = HashMap::new(); + props.insert(REGION_KEY.to_string(), VECTOR_REGION.to_string()); + props.insert(SERVICE_KEY.to_string(), VECTOR_SERVICE.to_string()); + props.insert(ACCESS_KEY_ID_KEY.to_string(), VECTOR_ACCESS_KEY.to_string()); + props.insert( + SECRET_ACCESS_KEY_KEY.to_string(), + VECTOR_SECRET_KEY.to_string(), + ); + let provider = SigV4AuthProvider::from_properties(&props) + .unwrap() + .with_clock(Arc::new(FixedClock( + UNIX_EPOCH + Duration::from_secs(VECTOR_UNIX_SECS), + ))); + + let ctx = RequestContext { + method: "GET".to_string(), + url: "https://example.amazonaws.com/".to_string(), + headers: HashMap::new(), + body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()), + }; + let headers = provider.authenticate(&ctx).await.unwrap(); + let auth = headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("authorization")) + .map(|(_, v)| v.as_str()) + .unwrap(); + assert_eq!(auth, VECTOR_EXPECTED_AUTHORIZATION); + } + + #[tokio::test] + async fn explicit_session_token_via_properties() { + let mut props = HashMap::new(); + props.insert(REGION_KEY.to_string(), VECTOR_REGION.to_string()); + props.insert(SERVICE_KEY.to_string(), VECTOR_SERVICE.to_string()); + props.insert(ACCESS_KEY_ID_KEY.to_string(), VECTOR_ACCESS_KEY.to_string()); + props.insert( + SECRET_ACCESS_KEY_KEY.to_string(), + VECTOR_SECRET_KEY.to_string(), + ); + props.insert( + SESSION_TOKEN_KEY.to_string(), + "FakeSessionToken123".to_string(), + ); + let provider = SigV4AuthProvider::from_properties(&props) + .unwrap() + .with_clock(Arc::new(FixedClock( + UNIX_EPOCH + Duration::from_secs(VECTOR_UNIX_SECS), + ))); + + let ctx = RequestContext { + method: "GET".to_string(), + url: "https://example.amazonaws.com/".to_string(), + headers: HashMap::new(), + body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()), + }; + let headers = provider.authenticate(&ctx).await.unwrap(); + + let token = headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("x-amz-security-token")) + .map(|(_, v)| v.as_str()); + assert_eq!(token, Some("FakeSessionToken123")); + + let auth = headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("authorization")) + .map(|(_, v)| v.as_str()) + .unwrap(); + assert_eq!( + auth, + "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20150830/us-east-1/service/aws4_request, \ + SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-security-token, \ + Signature=d690ca83bd782879e22797e35b2e25958c0d19696a92cfb479b73428e4d950f4", + "session-token signature mismatch" + ); + } + + #[tokio::test] + async fn injected_provider_takes_precedence_over_static_credentials() { + let injected_creds = Credentials::new( + VECTOR_ACCESS_KEY, + VECTOR_SECRET_KEY, + None, + None, + "injected", + ); + let mut props = HashMap::new(); + props.insert(REGION_KEY.to_string(), VECTOR_REGION.to_string()); + props.insert(SERVICE_KEY.to_string(), VECTOR_SERVICE.to_string()); + props.insert(ACCESS_KEY_ID_KEY.to_string(), "WRONG_AK".to_string()); + props.insert(SECRET_ACCESS_KEY_KEY.to_string(), "WRONG_SK".to_string()); + let provider = SigV4AuthProvider::from_properties(&props) + .unwrap() + .with_clock(Arc::new(FixedClock( + UNIX_EPOCH + Duration::from_secs(VECTOR_UNIX_SECS), + ))) + .with_credentials_provider(SharedCredentialsProvider::new(injected_creds)); + + let ctx = RequestContext { + method: "GET".to_string(), + url: "https://example.amazonaws.com/".to_string(), + headers: HashMap::new(), + body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()), + }; + let headers = provider.authenticate(&ctx).await.unwrap(); + let auth = headers + .iter() + .find(|(k, _)| k.eq_ignore_ascii_case("authorization")) + .map(|(_, v)| v.as_str()) + .unwrap(); + assert_eq!(auth, VECTOR_EXPECTED_AUTHORIZATION); + assert!(!auth.contains("WRONG_AK")); + } + + #[test] + fn from_properties_rejects_partial_credentials() { + let mut props = HashMap::new(); + props.insert(REGION_KEY.to_string(), "us-east-1".to_string()); + props.insert(ACCESS_KEY_ID_KEY.to_string(), "AKID".to_string()); + let err = SigV4AuthProvider::from_properties(&props).unwrap_err(); + assert!( + err.to_string().contains(SECRET_ACCESS_KEY_KEY), + "error must mention missing key: {err}" + ); + } }