diff --git a/Cargo.lock b/Cargo.lock index d9d7588827e..f02849aecda 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -5039,12 +5039,15 @@ dependencies = [ "arrow-schema", "async-trait", "aws-config", + "aws-credential-types", "aws-sdk-sts", + "aws-sigv4", "axum", "base64 0.22.1", "bytes", "chrono", "futures", + "hex", "hmac 0.12.1", "lance", "lance-arrow", diff --git a/java/lance-jni/Cargo.lock b/java/lance-jni/Cargo.lock index fa08fd758aa..502a2a38997 100644 --- a/java/lance-jni/Cargo.lock +++ b/java/lance-jni/Cargo.lock @@ -3910,6 +3910,7 @@ dependencies = [ "itertools 0.13.0", "lance-arrow", "libc", + "libm", "log", "moka", "num_cpus", @@ -3925,6 +3926,7 @@ dependencies = [ "tokio-stream", "tokio-util", "tracing", + "twox-hash", "url", ] @@ -4104,7 +4106,6 @@ dependencies = [ "lance-select", "lance-table", "lance-tokenizer", - "libm", "libsais-rs", "log", "ndarray", @@ -4124,7 +4125,6 @@ dependencies = [ "tempfile", "tokio", "tracing", - "twox-hash", "uuid", ] @@ -4242,9 +4242,13 @@ dependencies = [ "arrow-ipc", "arrow-schema", "async-trait", + "aws-config", + "aws-credential-types", + "aws-sigv4", "axum", "bytes", "futures", + "hex", "lance", "lance-core", "lance-index", @@ -4258,6 +4262,7 @@ dependencies = [ "reqwest 0.12.28", "serde", "serde_json", + "sha2 0.10.9", "tokio", "tower", "tower-http 0.5.2", diff --git a/java/lance-jni/Cargo.toml b/java/lance-jni/Cargo.toml index 5eaa69f071b..3d0520e38b3 100644 --- a/java/lance-jni/Cargo.toml +++ b/java/lance-jni/Cargo.toml @@ -23,7 +23,7 @@ lance-linalg = { path = "../../rust/lance-linalg" } lance-index = { path = "../../rust/lance-index" } lance-io = { path = "../../rust/lance-io" } lance-namespace = { path = "../../rust/lance-namespace" } -lance-namespace-impls = { path = "../../rust/lance-namespace-impls", features = ["rest", "rest-adapter", "dir-goosefs"] } +lance-namespace-impls = { path = "../../rust/lance-namespace-impls", features = ["rest", "rest-adapter", "dir-goosefs", "rest-auth-sigv4"] } lance-core = { 
path = "../../rust/lance-core" } lance-file = { path = "../../rust/lance-file" } lance-table = { path = "../../rust/lance-table" } diff --git a/java/lance-jni/src/namespace.rs b/java/lance-jni/src/namespace.rs index f0da7ff79ae..5fc3e325dc6 100644 --- a/java/lance-jni/src/namespace.rs +++ b/java/lance-jni/src/namespace.rs @@ -2533,7 +2533,12 @@ fn create_rest_namespace_internal( builder = builder.context_provider(Arc::new(java_provider)); } - let namespace = builder.build(); + let namespace = builder.build().map_err(|e| { + Error::runtime_error(format!("Failed to build RestNamespace: {}", e)) + })?; + + RT.block_on(namespace.warm_up_auth()) + .map_err(|e| Error::runtime_error(format!("Auth initialization failed: {}", e)))?; let blocking_namespace = BlockingRestNamespace { inner: Arc::new(namespace), diff --git a/java/src/main/java/org/lance/namespace/RestNamespace.java b/java/src/main/java/org/lance/namespace/RestNamespace.java index 9cbbc588660..f4fb4059bb7 100644 --- a/java/src/main/java/org/lance/namespace/RestNamespace.java +++ b/java/src/main/java/org/lance/namespace/RestNamespace.java @@ -120,6 +120,18 @@ *
Note: {@code rest.auth.*} and {@code header.Authorization} are mutually exclusive. + * Setting both will throw an error at initialization time. + * + *
{@code
* Map<String, String> properties = new HashMap<>();
* properties.put("uri", "https://api.example.com");
- * properties.put("delimiter", ".");
- * properties.put("header.Authorization", "Bearer my-token");
+ * properties.put("rest.auth.type", "sigv4");
+ * properties.put("rest.auth.sigv4.region", "us-east-1");
*
* RestNamespace namespace = new RestNamespace();
* namespace.initialize(properties, allocator);
diff --git a/java/src/test/java/org/lance/namespace/SigV4AuthTest.java b/java/src/test/java/org/lance/namespace/SigV4AuthTest.java
new file mode 100644
index 00000000000..8b26c26e10a
--- /dev/null
+++ b/java/src/test/java/org/lance/namespace/SigV4AuthTest.java
@@ -0,0 +1,248 @@
+/*
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.lance.namespace;
+
+import org.lance.namespace.model.CreateNamespaceRequest;
+import org.lance.namespace.model.ListNamespacesRequest;
+
+import com.sun.net.httpserver.HttpServer;
+import org.apache.arrow.memory.BufferAllocator;
+import org.apache.arrow.memory.RootAllocator;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.io.TempDir;
+
+import java.io.IOException;
+import java.io.OutputStream;
+import java.net.InetSocketAddress;
+import java.nio.charset.StandardCharsets;
+import java.nio.file.Path;
+import java.util.ArrayList;
+import java.util.Arrays;
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+public class SigV4AuthTest {
+  private static final String ACCESS_KEY_ID = "AKIAIOSFODNN7EXAMPLE";
+  private static final String SECRET_ACCESS_KEY = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY";
+  private static final String CREDENTIAL_PREFIX = "Credential=" + ACCESS_KEY_ID + "/";
+  private static final String AUTHORIZATION = "Authorization";
+  private static final String SECURITY_TOKEN = "x-amz-security-token";
+
+  @TempDir Path tempDir;
+
+  private BufferAllocator allocator;
+
+  @BeforeEach
+  void setUp() {
+    allocator = new RootAllocator(Long.MAX_VALUE);
+  }
+
+  @AfterEach
+  void tearDown() {
+    if (allocator != null) {
+      allocator.close();
+    }
+  }
+
+  /** Client properties that select SigV4 auth for {@code uri} and nothing else. */
+  private static Map<String, String> sigV4Config(String uri) {
+    Map<String, String> config = new HashMap<>();
+    config.put("uri", uri);
+    config.put("rest.auth.type", "sigv4");
+    return config;
+  }
+
+  /** SigV4 client properties for {@code uri} with region, service and static credentials. */
+  private static Map<String, String> explicitCredentialsConfig(String uri) {
+    Map<String, String> config = sigV4Config(uri);
+    config.put("rest.auth.sigv4.region", "us-east-1");
+    config.put("rest.auth.sigv4.service", "execute-api");
+    config.put("rest.auth.sigv4.access-key-id", ACCESS_KEY_ID);
+    config.put("rest.auth.sigv4.secret-access-key", SECRET_ACCESS_KEY);
+    return config;
+  }
+
+  /** Starts a REST adapter using the {@code dir} backend rooted at {@link #tempDir}. */
+  private RestAdapter startDirAdapter() {
+    Map<String, String> backendConfig = new HashMap<>();
+    backendConfig.put("root", tempDir.toString());
+    RestAdapter adapter = new RestAdapter("dir", backendConfig, "127.0.0.1", 0);
+    adapter.start();
+    return adapter;
+  }
+
+  /**
+   * Issues one {@code listNamespaces} call against a local stub server using explicit SigV4
+   * credentials and returns, per requested header name, the values the server received.
+   *
+   * @param extraProperties client properties layered on top of the explicit-credentials config
+   * @param headerNames request headers to record
+   */
+  private Map<String, List<String>> captureRequestHeaders(
+      Map<String, String> extraProperties, String... headerNames) throws IOException {
+    Map<String, List<String>> captured = new HashMap<>();
+    for (String name : headerNames) {
+      // Written by the server's handler thread and read by the test thread.
+      captured.put(name, Collections.synchronizedList(new ArrayList<>()));
+    }
+
+    HttpServer server = HttpServer.create(new InetSocketAddress("127.0.0.1", 0), 0);
+    server.createContext(
+        "/",
+        exchange -> {
+          for (String name : headerNames) {
+            String value = exchange.getRequestHeaders().getFirst(name);
+            if (value != null) {
+              captured.get(name).add(value);
+            }
+          }
+          byte[] body = "{\"namespaces\":[]}".getBytes(StandardCharsets.UTF_8);
+          exchange.sendResponseHeaders(200, body.length);
+          try (OutputStream os = exchange.getResponseBody()) {
+            os.write(body);
+          }
+        });
+    server.start();
+    try {
+      Map<String, String> clientConfig =
+          explicitCredentialsConfig("http://127.0.0.1:" + server.getAddress().getPort());
+      clientConfig.putAll(extraProperties);
+
+      RestNamespace ns = new RestNamespace();
+      ns.initialize(clientConfig, allocator);
+      try {
+        ns.listNamespaces(new ListNamespacesRequest());
+      } catch (Exception ignored) {
+        // Best effort: the stub reply need not satisfy the client; only the headers matter.
+      } finally {
+        ns.close();
+      }
+    } finally {
+      server.stop(0);
+    }
+    return captured;
+  }
+
+  /** Returns the first recorded value of {@code headerName}, failing if none was captured. */
+  private static String firstCaptured(Map<String, List<String>> captured, String headerName) {
+    List<String> values = captured.get(headerName);
+    assertFalse(values.isEmpty(), "no " + headerName + " header captured");
+    return values.get(0);
+  }
+
+  /** Returns the first {@code Authorization} header sent when using explicit credentials. */
+  private String captureAuthorization() throws IOException {
+    Map<String, List<String>> captured =
+        captureRequestHeaders(Collections.emptyMap(), AUTHORIZATION);
+    return firstCaptured(captured, AUTHORIZATION);
+  }
+
+  @Test
+  void testSigV4ConnectAndOperate() {
+    RestAdapter adapter = startDirAdapter();
+    try {
+      Map<String, String> clientConfig =
+          explicitCredentialsConfig("http://127.0.0.1:" + adapter.getPort());
+
+      RestNamespace ns = new RestNamespace();
+      ns.initialize(clientConfig, allocator);
+      try {
+        ns.createNamespace(new CreateNamespaceRequest().id(Arrays.asList("sigv4test")));
+        var resp = ns.listNamespaces(new ListNamespacesRequest());
+        assertTrue(resp.getNamespaces().contains("sigv4test"));
+      } finally {
+        ns.close();
+      }
+    } finally {
+      adapter.close();
+    }
+  }
+
+  @Test
+  void testSigV4MissingRegionFailsAtConnect() {
+    RestAdapter adapter = startDirAdapter();
+    try {
+      Map<String, String> clientConfig = sigV4Config("http://127.0.0.1:" + adapter.getPort());
+
+      RestNamespace ns = new RestNamespace();
+      RuntimeException ex =
+          assertThrows(RuntimeException.class, () -> ns.initialize(clientConfig, allocator));
+      assertTrue(ex.getMessage().contains("rest.auth.sigv4.region"));
+    } finally {
+      adapter.close();
+    }
+  }
+
+  @Test
+  void testSigV4ExplicitCredentials() throws IOException {
+    String auth = captureAuthorization();
+    assertTrue(auth.startsWith("AWS4-HMAC-SHA256"), "expected SigV4 header, got: " + auth);
+    assertTrue(auth.contains(CREDENTIAL_PREFIX), "wrong access key in: " + auth);
+  }
+
+  @Test
+  void testSigV4ExplicitCredentialsWithSessionToken() throws IOException {
+    Map<String, String> extra = new HashMap<>();
+    extra.put("rest.auth.sigv4.session-token", "FakeSessionToken123");
+    Map<String, List<String>> captured =
+        captureRequestHeaders(extra, AUTHORIZATION, SECURITY_TOKEN);
+
+    String auth = firstCaptured(captured, AUTHORIZATION);
+    assertTrue(auth.startsWith("AWS4-HMAC-SHA256"), "expected SigV4 header, got: " + auth);
+    assertEquals("FakeSessionToken123", firstCaptured(captured, SECURITY_TOKEN));
+  }
+
+  // Precedence (properties > env) is verified by Python/Rust; JVM cannot mutate env at runtime.
+  @Test
+  void testSigV4ExplicitCredentialsUsedRegardlessOfEnv() throws IOException {
+    String auth = captureAuthorization();
+    assertTrue(
+        auth.contains(CREDENTIAL_PREFIX), "properties credentials must be used, got: " + auth);
+  }
+
+  @Test
+  void testSigV4PartialCredentialsRejected() {
+    RestAdapter adapter = startDirAdapter();
+    try {
+      Map<String, String> clientConfig = sigV4Config("http://127.0.0.1:" + adapter.getPort());
+      clientConfig.put("rest.auth.sigv4.region", "us-east-1");
+      clientConfig.put("rest.auth.sigv4.access-key-id", ACCESS_KEY_ID);
+
+      RestNamespace ns = new RestNamespace();
+      RuntimeException ex =
+          assertThrows(RuntimeException.class, () -> ns.initialize(clientConfig, allocator));
+      assertTrue(
+          ex.getMessage().contains("rest.auth.sigv4.secret-access-key"),
+          "error must mention missing key: " + ex.getMessage());
+    } finally {
+      adapter.close();
+    }
+  }
+
+  // Signature correctness is verified at the Rust layer (AWS test vectors + botocore).
+  @Test
+  void testSigV4SignatureHeadersPresent() throws IOException {
+    String auth = captureAuthorization();
+    assertTrue(auth.startsWith("AWS4-HMAC-SHA256"), "expected SigV4 header, got: " + auth);
+    assertTrue(auth.contains(CREDENTIAL_PREFIX), "wrong access key in: " + auth);
+    assertTrue(auth.contains("SignedHeaders="), "missing SignedHeaders in: " + auth);
+    assertTrue(auth.matches(".*Signature=[a-f0-9]{64}.*"), "missing Signature in: " + auth);
+  }
+}
diff --git a/python/Cargo.lock b/python/Cargo.lock
index 7867ea71446..1634bf0d90e 100644
--- a/python/Cargo.lock
+++ b/python/Cargo.lock
@@ -4574,9 +4574,13 @@ dependencies = [
"arrow-ipc",
"arrow-schema",
"async-trait",
+ "aws-config",
+ "aws-credential-types",
+ "aws-sigv4",
"axum",
"bytes",
"futures",
+ "hex",
"lance",
"lance-core",
"lance-index",
@@ -4590,6 +4594,7 @@ dependencies = [
"reqwest 0.12.28",
"serde",
"serde_json",
+ "sha2 0.10.9",
"tokio",
"tower",
"tower-http 0.5.2",
diff --git a/python/Cargo.toml b/python/Cargo.toml
index 9c7800d3c83..6d3878c4c31 100644
--- a/python/Cargo.toml
+++ b/python/Cargo.toml
@@ -47,7 +47,7 @@ lance-index = { path = "../rust/lance-index", features = [
lance-io = { path = "../rust/lance-io" }
lance-linalg = { path = "../rust/lance-linalg" }
lance-namespace = { path = "../rust/lance-namespace" }
-lance-namespace-impls = { path = "../rust/lance-namespace-impls", features = ["rest", "rest-adapter", "dir-goosefs"] }
+lance-namespace-impls = { path = "../rust/lance-namespace-impls", features = ["rest", "rest-adapter", "dir-goosefs", "rest-auth-sigv4"] }
lance-table = { path = "../rust/lance-table" }
lance-datafusion = { path = "../rust/lance-datafusion" }
libc = "0.2.176"
diff --git a/python/python/tests/test_namespace_rest.py b/python/python/tests/test_namespace_rest.py
index 140d9168c05..6191fea85bb 100644
--- a/python/python/tests/test_namespace_rest.py
+++ b/python/python/tests/test_namespace_rest.py
@@ -747,3 +747,373 @@ def provide_context(self, info):
# Explicit provider should have been used
assert explicit_called["called"]
+
+
+class TestSigV4Auth:
+
+ def test_sigv4_connects_and_signs_requests(self, monkeypatch):
+ monkeypatch.setenv("AWS_ACCESS_KEY_ID", "AKIAIOSFODNN7EXAMPLE")
+ monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY")
+ monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1")
+
+ with tempfile.TemporaryDirectory() as tmpdir:
+ backend_config = {"root": tmpdir}
+
+ with lance.namespace.RestAdapter("dir", backend_config, port=0) as adapter:
+ client = connect(
+ "rest",
+ {
+ "uri": f"http://127.0.0.1:{adapter.port}",
+ "rest.auth.type": "sigv4",
+ "rest.auth.sigv4.region": "us-east-1",
+ "rest.auth.sigv4.service": "execute-api",
+ },
+ )
+
+ create_req = CreateNamespaceRequest(id=["sigv4test"])
+ client.create_namespace(create_req)
+
+ list_req = ListNamespacesRequest(id=[])
+ resp = client.list_namespaces(list_req)
+ assert "sigv4test" in resp.namespaces
+
+ def test_sigv4_missing_region_fails_at_connect(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ backend_config = {"root": tmpdir}
+
+ with lance.namespace.RestAdapter("dir", backend_config, port=0) as adapter:
+ with pytest.raises(Exception, match="rest.auth.sigv4.region"):
+ connect(
+ "rest",
+ {
+ "uri": f"http://127.0.0.1:{adapter.port}",
+ "rest.auth.type": "sigv4",
+ # no region — should fail
+ },
+ )
+
+ def test_sigv4_signature_correctness(self, monkeypatch):
+ import json
+ import re
+ import threading
+ from http.server import BaseHTTPRequestHandler, HTTPServer
+
+ from botocore.auth import SigV4Auth
+ from botocore.awsrequest import AWSRequest
+ from botocore.credentials import Credentials
+
+ ACCESS_KEY = "AKIAIOSFODNN7EXAMPLE"
+ SECRET_KEY = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"
+
+ monkeypatch.setenv("AWS_ACCESS_KEY_ID", ACCESS_KEY)
+ monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", SECRET_KEY)
+ monkeypatch.setenv("AWS_DEFAULT_REGION", "us-east-1")
+
+ captured_requests = []
+
+ class Recorder(BaseHTTPRequestHandler):
+ def _capture_and_respond(self):
+ content_length = int(self.headers.get("Content-Length", 0))
+ body = self.rfile.read(content_length) if content_length else b""
+ captured_requests.append({
+ "method": self.command,
+ "path": self.path,
+ "headers": {k.lower(): v for k, v in self.headers.items()},
+ "body": body,
+ })
+ self.send_response(200)
+ self.send_header("Content-Type", "application/json")
+ self.end_headers()
+ self.wfile.write(json.dumps({"namespaces": []}).encode())
+
+ def do_GET(self):
+ self._capture_and_respond()
+
+ def do_POST(self):
+ self._capture_and_respond()
+
+ def log_message(self, *_args):
+ pass
+
+ server = HTTPServer(("127.0.0.1", 0), Recorder)
+ port = server.server_address[1]
+ threading.Thread(target=server.serve_forever, daemon=True).start()
+
+ try:
+ client = connect(
+ "rest",
+ {
+ "uri": f"http://127.0.0.1:{port}",
+ "rest.auth.type": "sigv4",
+ "rest.auth.sigv4.region": "us-east-1",
+ "rest.auth.sigv4.service": "execute-api",
+ },
+ )
+
+ try:
+ client.list_namespaces(ListNamespacesRequest(id=[]))
+ except Exception:
+ pass
+
+ try:
+ client.create_namespace(CreateNamespaceRequest(id=["verify"]))
+ except Exception:
+ pass
+
+ assert len(captured_requests) >= 2, (
+ f"expected at least 2 requests (GET+POST), got {len(captured_requests)}"
+ )
+ methods_seen = {r["method"] for r in captured_requests}
+ assert "GET" in methods_seen, "expected at least one GET request"
+ assert "POST" in methods_seen, "expected at least one POST request"
+
+ creds = Credentials(ACCESS_KEY, SECRET_KEY)
+ signer = SigV4Auth(creds, "execute-api", "us-east-1")
+
+ for req in captured_requests:
+ rust_auth = req["headers"].get("authorization", "")
+ assert rust_auth.startswith("AWS4-HMAC-SHA256"), (
+ f"{req['method']} {req['path']}: missing SigV4 header"
+ )
+
+ rust_sig = re.search(r"Signature=([a-f0-9]{64})", rust_auth).group(1)
+ amz_date = req["headers"]["x-amz-date"]
+
+ url = f"http://127.0.0.1:{port}{req['path']}"
+
+ signed_names = re.search(
+ r"SignedHeaders=([^,]+)", rust_auth
+ ).group(1).split(";")
+ headers_for_signing = {}
+ for name in signed_names:
+ if name in req["headers"]:
+ headers_for_signing[name] = req["headers"][name]
+
+ aws_req = AWSRequest(
+ method=req["method"],
+ url=url,
+ headers=headers_for_signing,
+ data=req["body"],
+ )
+ aws_req.context["timestamp"] = amz_date
+
+ cr = signer.canonical_request(aws_req)
+ sts = signer.string_to_sign(aws_req, cr)
+ boto_sig = signer.signature(sts, aws_req)
+
+ assert rust_sig == boto_sig, (
+ f"{req['method']} {req['path']}: signature mismatch\n"
+ f" rust: {rust_sig}\n"
+ f" botocore:{boto_sig}\n"
+ f" rust_auth: {rust_auth}\n"
+ f" botocore canonical_request:\n{cr}"
+ )
+ finally:
+ server.shutdown()
+
+ def test_sigv4_explicit_credentials_take_precedence_over_env(self, monkeypatch):
+ import json
+ import threading
+ from http.server import BaseHTTPRequestHandler, HTTPServer
+
+ monkeypatch.setenv("AWS_ACCESS_KEY_ID", "ENVAKID_SHOULD_NOT_APPEAR")
+ monkeypatch.setenv("AWS_SECRET_ACCESS_KEY", "envSecretShouldNotAppear")
+
+ captured_headers = []
+
+ class Recorder(BaseHTTPRequestHandler):
+ def _capture_and_respond(self):
+ captured_headers.append(
+ {k.lower(): v for k, v in self.headers.items()}
+ )
+ self.send_response(200)
+ self.send_header("Content-Type", "application/json")
+ self.end_headers()
+ self.wfile.write(json.dumps({"namespaces": []}).encode())
+
+ def do_GET(self):
+ self._capture_and_respond()
+
+ def do_POST(self):
+ self._capture_and_respond()
+
+ def log_message(self, *_args):
+ pass
+
+ server = HTTPServer(("127.0.0.1", 0), Recorder)
+ port = server.server_address[1]
+ threading.Thread(target=server.serve_forever, daemon=True).start()
+
+ try:
+ client = connect(
+ "rest",
+ {
+ "uri": f"http://127.0.0.1:{port}",
+ "rest.auth.type": "sigv4",
+ "rest.auth.sigv4.region": "us-east-1",
+ "rest.auth.sigv4.service": "execute-api",
+ "rest.auth.sigv4.access-key-id": "AKIAIOSFODNN7EXAMPLE",
+ "rest.auth.sigv4.secret-access-key": (
+ "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"
+ ),
+ },
+ )
+
+ try:
+ client.list_namespaces(ListNamespacesRequest(id=[]))
+ except Exception:
+ pass
+
+ assert len(captured_headers) >= 1
+ auth = captured_headers[0].get("authorization", "")
+ assert "Credential=AKIAIOSFODNN7EXAMPLE/" in auth, (
+ "properties credentials must take precedence over env"
+ )
+ assert "ENVAKID_SHOULD_NOT_APPEAR" not in auth
+ finally:
+ server.shutdown()
+
+ def test_sigv4_partial_credentials_rejected(self):
+ with tempfile.TemporaryDirectory() as tmpdir:
+ backend_config = {"root": tmpdir}
+
+ with lance.namespace.RestAdapter("dir", backend_config, port=0) as adapter:
+ with pytest.raises(Exception, match="rest.auth.sigv4.secret-access-key"):
+ connect(
+ "rest",
+ {
+ "uri": f"http://127.0.0.1:{adapter.port}",
+ "rest.auth.type": "sigv4",
+ "rest.auth.sigv4.region": "us-east-1",
+ "rest.auth.sigv4.access-key-id": "AKIAIOSFODNN7EXAMPLE",
+ },
+ )
+
+ def test_sigv4_explicit_credentials(self, monkeypatch):
+ import json
+ import threading
+ from http.server import BaseHTTPRequestHandler, HTTPServer
+
+ monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
+ monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
+ monkeypatch.delenv("AWS_SESSION_TOKEN", raising=False)
+
+ captured_headers = []
+
+ class Recorder(BaseHTTPRequestHandler):
+ def _capture_and_respond(self):
+ captured_headers.append(
+ {k.lower(): v for k, v in self.headers.items()}
+ )
+ self.send_response(200)
+ self.send_header("Content-Type", "application/json")
+ self.end_headers()
+ self.wfile.write(json.dumps({"namespaces": []}).encode())
+
+ def do_GET(self):
+ self._capture_and_respond()
+
+ def do_POST(self):
+ self._capture_and_respond()
+
+ def log_message(self, *_args):
+ pass
+
+ server = HTTPServer(("127.0.0.1", 0), Recorder)
+ port = server.server_address[1]
+ threading.Thread(target=server.serve_forever, daemon=True).start()
+
+ try:
+ client = connect(
+ "rest",
+ {
+ "uri": f"http://127.0.0.1:{port}",
+ "rest.auth.type": "sigv4",
+ "rest.auth.sigv4.region": "us-east-1",
+ "rest.auth.sigv4.service": "execute-api",
+ "rest.auth.sigv4.access-key-id": "AKIAIOSFODNN7EXAMPLE",
+ "rest.auth.sigv4.secret-access-key": (
+ "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"
+ ),
+ },
+ )
+
+ try:
+ client.list_namespaces(ListNamespacesRequest(id=[]))
+ except Exception:
+ pass
+
+ assert len(captured_headers) >= 1
+ auth = captured_headers[0].get("authorization", "")
+ assert auth.startswith("AWS4-HMAC-SHA256"), (
+ f"expected SigV4 header, got: {auth}"
+ )
+ assert "Credential=AKIAIOSFODNN7EXAMPLE/" in auth
+ finally:
+ server.shutdown()
+
+ def test_sigv4_explicit_credentials_with_session_token(self, monkeypatch):
+ import json
+ import threading
+ from http.server import BaseHTTPRequestHandler, HTTPServer
+
+ monkeypatch.delenv("AWS_ACCESS_KEY_ID", raising=False)
+ monkeypatch.delenv("AWS_SECRET_ACCESS_KEY", raising=False)
+ monkeypatch.delenv("AWS_SESSION_TOKEN", raising=False)
+
+ captured_headers = []
+
+ class Recorder(BaseHTTPRequestHandler):
+ def _capture_and_respond(self):
+ captured_headers.append(
+ {k.lower(): v for k, v in self.headers.items()}
+ )
+ self.send_response(200)
+ self.send_header("Content-Type", "application/json")
+ self.end_headers()
+ self.wfile.write(json.dumps({"namespaces": []}).encode())
+
+ def do_GET(self):
+ self._capture_and_respond()
+
+ def do_POST(self):
+ self._capture_and_respond()
+
+ def log_message(self, *_args):
+ pass
+
+ server = HTTPServer(("127.0.0.1", 0), Recorder)
+ port = server.server_address[1]
+ threading.Thread(target=server.serve_forever, daemon=True).start()
+
+ try:
+ client = connect(
+ "rest",
+ {
+ "uri": f"http://127.0.0.1:{port}",
+ "rest.auth.type": "sigv4",
+ "rest.auth.sigv4.region": "us-east-1",
+ "rest.auth.sigv4.service": "execute-api",
+ "rest.auth.sigv4.access-key-id": "AKIAIOSFODNN7EXAMPLE",
+ "rest.auth.sigv4.secret-access-key": (
+ "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY"
+ ),
+ "rest.auth.sigv4.session-token": "FakeSessionToken123",
+ },
+ )
+
+ try:
+ client.list_namespaces(ListNamespacesRequest(id=[]))
+ except Exception:
+ pass
+
+ assert len(captured_headers) >= 1
+ auth = captured_headers[0].get("authorization", "")
+ assert auth.startswith("AWS4-HMAC-SHA256")
+
+ token = captured_headers[0].get("x-amz-security-token", "")
+ assert token == "FakeSessionToken123", (
+ f"expected session token in header, got: {token}"
+ )
+ finally:
+ server.shutdown()
diff --git a/python/src/namespace.rs b/python/src/namespace.rs
index cf5f7c41b0f..d39b3860e91 100644
--- a/python/src/namespace.rs
+++ b/python/src/namespace.rs
@@ -788,15 +788,19 @@ pub struct PyRestNamespace {
#[pymethods]
impl PyRestNamespace {
- /// Create a new RestNamespace from properties
+ /// Create a new RestNamespace from properties.
///
/// # Arguments
///
/// * `context_provider` - Optional object with `provide_context(info: dict) -> dict` method
/// for providing dynamic per-request context. Context keys that start with `headers.`
- /// are converted to HTTP headers by stripping the prefix. For example,
- /// `{"headers.Authorization": "Bearer token"}` becomes the `Authorization` header.
- /// * `**properties` - Namespace configuration properties (uri, delimiter, header.*, etc.)
+ /// are converted to HTTP headers by stripping the prefix.
+ /// * `**properties` - Namespace configuration properties (uri, delimiter, header.*,
+ /// rest.auth.type, rest.auth.sigv4.region, rest.auth.sigv4.service,
+ /// rest.auth.sigv4.access-key-id, rest.auth.sigv4.secret-access-key,
+ /// rest.auth.sigv4.session-token, etc.)
+ ///
+ /// `rest.auth.*` and `header.Authorization` are mutually exclusive.
#[new]
#[pyo3(signature = (context_provider = None, **properties))]
fn new(
@@ -822,7 +826,11 @@ impl PyRestNamespace {
builder = builder.context_provider(Arc::new(py_provider));
}
- let namespace = builder.build();
+ let namespace = builder.build().infer_error()?;
+
+ crate::rt()
+ .block_on(None, namespace.warm_up_auth())?
+ .infer_error()?;
Ok(Self {
inner: Arc::new(namespace),
diff --git a/rust/lance-namespace-impls/Cargo.toml b/rust/lance-namespace-impls/Cargo.toml
index 53ff79fb333..2858476d157 100644
--- a/rust/lance-namespace-impls/Cargo.toml
+++ b/rust/lance-namespace-impls/Cargo.toml
@@ -13,8 +13,9 @@ rust-version.workspace = true
[features]
default = ["dir-aws", "dir-azure", "dir-gcp", "dir-oss", "dir-huggingface"]
-rest = ["dep:reqwest", "dep:serde"]
+rest = ["dep:reqwest", "dep:serde", "dep:sha2", "dep:hex"]
rest-adapter = ["dep:axum", "dep:tower", "dep:tower-http", "dep:serde"]
+rest-auth-sigv4 = ["rest", "dep:aws-sigv4", "dep:aws-credential-types", "dep:aws-config"]
# Cloud storage features for directory implementation - align with lance-io
dir-gcp = ["lance-io/gcp", "lance/gcp"]
dir-aws = ["lance-io/aws", "lance/aws"]
@@ -69,11 +70,14 @@ rand.workspace = true
# Shared credential vending dependencies
sha2 = { version = "0.10", optional = true }
+hex = { version = "0.4", optional = true }
base64 = { version = "0.22", optional = true }
# AWS credential vending dependencies (optional, enabled by "credential-vendor-aws" feature)
aws-sdk-sts = { version = "1.38.0", optional = true, default-features = false, features = ["default-https-client", "rt-tokio"] }
aws-config = { workspace = true, optional = true }
+aws-sigv4 = { version = "1", optional = true }
+aws-credential-types = { version = "1", optional = true }
# GCP credential vending dependencies (optional, enabled by "credential-vendor-gcp" feature)
ring = { version = "0.17", optional = true }
diff --git a/rust/lance-namespace-impls/src/connect.rs b/rust/lance-namespace-impls/src/connect.rs
index c44eb2de219..d7d52c83da3 100644
--- a/rust/lance-namespace-impls/src/connect.rs
+++ b/rust/lance-namespace-impls/src/connect.rs
@@ -184,7 +184,9 @@ impl ConnectBuilder {
if let Some(provider) = self.context_provider {
builder = builder.context_provider(provider);
}
- Ok(Arc::new(builder.build()) as Arc)
+ let ns = builder.build()?;
+ ns.warm_up_auth().await?;
+ Ok(Arc::new(ns) as Arc)
}
#[cfg(not(feature = "rest"))]
"rest" => Err(NamespaceError::Unsupported {
diff --git a/rust/lance-namespace-impls/src/lib.rs b/rust/lance-namespace-impls/src/lib.rs
index 58e29aca5ef..436aad1e4bb 100644
--- a/rust/lance-namespace-impls/src/lib.rs
+++ b/rust/lance-namespace-impls/src/lib.rs
@@ -75,6 +75,8 @@ pub mod connect;
pub mod context;
pub mod credentials;
pub mod dir;
+#[cfg(feature = "rest")]
+pub mod rest_auth;
#[cfg(feature = "rest")]
pub mod rest;
@@ -89,6 +91,9 @@ pub use dir::{
DirectoryNamespace, DirectoryNamespaceBuilder, OpsMetrics, manifest::ManifestNamespace,
};
+#[cfg(feature = "rest")]
+pub use rest_auth::{NoopAuthProvider, RequestContext, RestAuthProvider};
+
// Re-export credential vending
pub use credentials::{
CredentialVendor, DEFAULT_CREDENTIAL_DURATION_MILLIS, VendedCredentials,
diff --git a/rust/lance-namespace-impls/src/rest.rs b/rust/lance-namespace-impls/src/rest.rs
index 27a563d2807..ae00437af29 100644
--- a/rust/lance-namespace-impls/src/rest.rs
+++ b/rust/lance-namespace-impls/src/rest.rs
@@ -12,6 +12,11 @@ use crate::OpsMetrics;
use async_trait::async_trait;
use bytes::Bytes;
use reqwest::header::{HeaderName, HeaderValue};
+use sha2::{Digest, Sha256};
+
+use crate::rest_auth::{
+ AUTH_PROPERTY_PREFIX, AUTH_TYPE_KEY, RequestContext, RestAuthProvider, create_auth_provider,
+};
use crate::context::{DynamicContextProvider, OperationInfo};
@@ -66,6 +71,7 @@ struct RestClient {
base_path: String,
base_headers: HashMap<String, String>,
context_provider: Option<Arc<dyn DynamicContextProvider>>,
+ auth_provider: Option<Arc<dyn RestAuthProvider>>,
}
impl std::fmt::Debug for RestClient {
@@ -77,58 +83,127 @@ impl std::fmt::Debug for RestClient {
"context_provider",
&self.context_provider.as_ref().map(|_| "Some(...)"),
)
+ .field(
+ "auth_provider",
+ &self.auth_provider.as_ref().map(|_| "Some(...)"),
+ )
.finish()
}
}
-impl RestClient {
- /// Apply base headers and dynamic context headers to a request.
- ///
- /// This method mutates the request's headers directly, which is more efficient
- /// than creating a new client with default_headers for each request.
- fn apply_headers(&self, request: &mut reqwest::Request, operation: &str, object_id: &str) {
- let request_headers = request.headers_mut();
-
- // First apply base headers
- for (key, value) in &self.base_headers {
- if let (Ok(header_name), Ok(header_value)) =
- (HeaderName::from_str(key), HeaderValue::from_str(value))
- {
- request_headers.insert(header_name, header_value);
+fn reqwest_to_lance_error(e: reqwest::Error) -> lance_core::Error {
+ let message = format!("Failed to execute request: {e:?}");
+ if e.is_timeout() || e.is_connect() {
+ NamespaceError::ServiceUnavailable { message }.into()
+ } else {
+ NamespaceError::Internal { message }.into()
+ }
+}
+
+fn apply_string_headers<I, K, V>(headers: &mut reqwest::header::HeaderMap, pairs: I)
+where
+ I: IntoIterator<Item = (K, V)>,
+ K: AsRef<str>,
+ V: AsRef<str>,
+{
+ for (k, v) in pairs {
+ match (
+ HeaderName::from_str(k.as_ref()),
+ HeaderValue::from_str(v.as_ref()),
+ ) {
+ (Ok(name), Ok(val)) => {
+ headers.insert(name, val);
+ }
+ _ => {
+ log::warn!("dropping invalid header: {:?}: {:?}", k.as_ref(), v.as_ref());
}
}
+ }
+}
+
+pub(crate) const EMPTY_BODY_SHA256: &str =
+ "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
+
+/// `None` for streaming bodies (currently unreachable).
+fn body_sha256_hex(request: &reqwest::Request) -> Option<String> {
+ match request.body() {
+ None => Some(EMPTY_BODY_SHA256.to_string()),
+ Some(body) => match body.as_bytes() {
+ None => None,
+ Some([]) => Some(EMPTY_BODY_SHA256.to_string()),
+ Some(bytes) => Some(hex::encode(Sha256::digest(bytes))),
+ },
+ }
+}
+
+impl RestClient {
+ fn build_auth_context(request: &reqwest::Request) -> RequestContext {
+ // Lossy decode keeps non-ASCII headers in the signer's view.
+ let headers = request
+ .headers()
+ .iter()
+ .map(|(k, v)| {
+ (
+ k.as_str().to_string(),
+ String::from_utf8_lossy(v.as_bytes()).into_owned(),
+ )
+ })
+ .collect();
+ RequestContext {
+ method: request.method().to_string(),
+ url: request.url().to_string(),
+ headers,
+ body_sha256: body_sha256_hex(request),
+ }
+ }
+
+ /// Apply headers: base → auth (signed) → context (unsigned).
+ async fn apply_headers(
+ &self,
+ request: &mut reqwest::Request,
+ operation: &str,
+ object_id: &str,
+ ) -> Result<()> {
+ apply_string_headers(request.headers_mut(), &self.base_headers);
+
+ if let Some(auth) = &self.auth_provider {
+ let ctx = Self::build_auth_context(request);
+ let auth_headers =
+ auth.authenticate(&ctx)
+ .await
+ .map_err(|e| NamespaceError::Unauthenticated {
+ message: format!(
+ "auth provider failed for operation '{operation}' on '{object_id}': {e}"
+ ),
+ })?;
+ apply_string_headers(request.headers_mut(), auth_headers);
+ }
- // Then apply context headers (override base headers if conflict)
if let Some(provider) = &self.context_provider {
let info = OperationInfo::new(operation, object_id);
- let context = provider.provide_context(&info);
-
const HEADERS_PREFIX: &str = "headers.";
- for (key, value) in context {
- if let Some(header_name) = key.strip_prefix(HEADERS_PREFIX)
- && let (Ok(header_name), Ok(header_value)) = (
- HeaderName::from_str(header_name),
- HeaderValue::from_str(&value),
- )
- {
- request_headers.insert(header_name, header_value);
- }
- }
+ let context_headers = provider
+ .provide_context(&info)
+ .into_iter()
+ .filter_map(|(k, v)| k.strip_prefix(HEADERS_PREFIX).map(|n| (n.to_string(), v)));
+ apply_string_headers(request.headers_mut(), context_headers);
}
+ Ok(())
}
- /// Execute a request with dynamic headers applied.
- ///
- /// This method builds the request, applies headers, and executes it.
async fn execute(
&self,
req_builder: reqwest::RequestBuilder,
operation: &str,
object_id: &str,
- ) -> std::result::Result<reqwest::Response, reqwest::Error> {
- let mut request = req_builder.build()?;
- self.apply_headers(&mut request, operation, object_id);
- self.client.execute(request).await
+ ) -> Result<reqwest::Response> {
+ let mut request = req_builder.build().map_err(reqwest_to_lance_error)?;
+ self.apply_headers(&mut request, operation, object_id)
+ .await?;
+ self.client
+ .execute(request)
+ .await
+ .map_err(reqwest_to_lance_error)
}
/// Get the base path URL
@@ -144,19 +219,36 @@ impl RestClient {
/// Builder for creating a RestNamespace.
///
-/// This builder provides a fluent API for configuring and establishing
-/// connections to REST-based Lance namespaces.
+/// # Authentication
+///
+/// SigV4 authentication via properties:
+/// - `rest.auth.type` — `"sigv4"` or `"none"` (default: none)
+/// - `rest.auth.sigv4.region` — AWS region (required for sigv4)
+/// - `rest.auth.sigv4.service` — AWS service name (default: `"execute-api"`)
+/// - `rest.auth.sigv4.access-key-id` — explicit AWS access key ID (optional)
+/// - `rest.auth.sigv4.secret-access-key` — explicit AWS secret access key (optional)
+/// - `rest.auth.sigv4.session-token` — STS session token (optional)
+///
+/// When explicit `access-key-id` and `secret-access-key` are set, they
+/// are used directly; otherwise credentials fall back to the AWS default
+/// chain (env vars, profile, IMDS). The two keys must both be present or
+/// both be absent.
+///
+/// [`auth_provider()`](Self::auth_provider) overrides all property-based
+/// auth — when set, `rest.auth.*` properties are ignored.
+///
+/// `rest.auth.*` and `header.Authorization` are mutually exclusive —
+/// setting both will return an error at build time.
///
/// # Examples
///
/// ```no_run
/// # use lance_namespace_impls::RestNamespaceBuilder;
/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
-/// // Create a REST namespace
/// let namespace = RestNamespaceBuilder::new("http://localhost:8080")
/// .delimiter(".")
/// .header("Authorization", "Bearer token")
-/// .build();
+/// .build()?;
/// # Ok(())
/// # }
/// ```
@@ -172,6 +264,8 @@ pub struct RestNamespaceBuilder {
context_provider: Option<Arc<dyn DynamicContextProvider>>,
/// When true, tracks operation metrics. Default: false.
ops_metrics_enabled: bool,
+ auth_provider: Option<Arc<dyn RestAuthProvider>>,
+ auth_properties: HashMap<String, String>,
}
impl std::fmt::Debug for RestNamespaceBuilder {
@@ -189,6 +283,10 @@ impl std::fmt::Debug for RestNamespaceBuilder {
&self.context_provider.as_ref().map(|_| "Some(...)"),
)
.field("ops_metrics_enabled", &self.ops_metrics_enabled)
+ .field(
+ "auth_provider",
+ &self.auth_provider.as_ref().map(|_| "Some(...)"),
+ )
.finish()
}
}
@@ -213,6 +311,8 @@ impl RestNamespaceBuilder {
assert_hostname: true,
context_provider: None,
ops_metrics_enabled: false,
+ auth_provider: None,
+ auth_properties: HashMap::new(),
}
}
@@ -252,7 +352,7 @@ impl RestNamespaceBuilder {
/// properties.insert("header.Authorization".to_string(), "Bearer token".to_string());
///
/// let namespace = RestNamespaceBuilder::from_properties(properties)?
- /// .build();
+ /// .build()?;
/// # Ok(())
/// # }
/// ```
@@ -296,6 +396,12 @@ impl RestNamespaceBuilder {
.and_then(|v| v.parse::<bool>().ok())
.unwrap_or(false);
+ let auth_properties: HashMap<String, String> = properties
+ .iter()
+ .filter(|(k, _)| k.starts_with(AUTH_PROPERTY_PREFIX))
+ .map(|(k, v)| (k.clone(), v.clone()))
+ .collect();
+
Ok(Self {
uri,
delimiter,
@@ -306,6 +412,8 @@ impl RestNamespaceBuilder {
assert_hostname,
context_provider: None,
ops_metrics_enabled,
+ auth_provider: None,
+ auth_properties,
})
}
@@ -411,13 +519,19 @@ impl RestNamespaceBuilder {
///
/// let namespace = RestNamespaceBuilder::new("http://localhost:8080")
/// .context_provider(Arc::new(MyProvider))
- /// .build();
+ /// .build()?;
/// ```
pub fn context_provider(mut self, provider: Arc<dyn DynamicContextProvider>) -> Self {
self.context_provider = Some(provider);
self
}
+ /// Set the auth provider directly. Takes precedence over `rest.auth.*` properties.
+ pub fn auth_provider(mut self, provider: Arc<dyn RestAuthProvider>) -> Self {
+ self.auth_provider = Some(provider);
+ self
+ }
+
/// Enable or disable operation metrics tracking.
///
/// When enabled, the namespace will track how many times each API operation
@@ -431,12 +545,25 @@ impl RestNamespaceBuilder {
}
/// Build the RestNamespace.
- ///
- /// # Returns
- ///
- /// Returns a `RestNamespace` instance.
- pub fn build(self) -> RestNamespace {
- RestNamespace::from_builder(self)
+ pub fn build(self) -> Result<RestNamespace> {
+ let has_auth = self.auth_provider.is_some()
+ || self.auth_properties.contains_key(AUTH_TYPE_KEY);
+ if has_auth && self.headers.keys().any(|k| k.eq_ignore_ascii_case("authorization")) {
+ return Err(NamespaceError::InvalidInput {
+ message: "cannot combine header.Authorization with rest.auth.* — \
+ use one authentication method"
+ .to_string(),
+ }
+ .into());
+ }
+ let auth = if let Some(p) = self.auth_provider.clone() {
+ Some(p)
+ } else if self.auth_properties.contains_key(AUTH_TYPE_KEY) {
+ Some(create_auth_provider(&self.auth_properties)?)
+ } else {
+ None
+ };
+ Ok(RestNamespace::from_builder(self, auth))
}
}
@@ -461,7 +588,7 @@ fn object_id_str(id: &Option<Vec<String>>, delimiter: &str) -> Result<String> {
/// # fn example() -> Result<(), Box<dyn std::error::Error>> {
/// // Use the builder to create a namespace
/// let namespace = RestNamespaceBuilder::new("http://localhost:8080")
-/// .build();
+/// .build()?;
/// # Ok(())
/// # }
/// ```
@@ -487,8 +614,11 @@ impl std::fmt::Display for RestNamespace {
}
impl RestNamespace {
- /// Create a new REST namespace from builder
- pub(crate) fn from_builder(builder: RestNamespaceBuilder) -> Self {
+ /// Create a new REST namespace from builder + resolved auth provider.
+ pub(crate) fn from_builder(
+ builder: RestNamespaceBuilder,
+ auth_provider: Option<Arc<dyn RestAuthProvider>>,
+ ) -> Self {
// Build reqwest client WITHOUT default headers - we'll apply headers per-request
let mut client_builder = reqwest::Client::builder();
@@ -521,6 +651,7 @@ impl RestNamespace {
base_path: builder.uri,
base_headers: builder.headers,
context_provider: builder.context_provider,
+ auth_provider,
};
let ops_metrics = if builder.ops_metrics_enabled {
@@ -536,17 +667,11 @@ impl RestNamespace {
}
}
- /// Map a reqwest::Error to the appropriate NamespaceError variant.
- ///
- /// Timeout and connection errors are mapped to `ServiceUnavailable`,
- /// while other errors are mapped to `Internal`.
- fn request_error(e: reqwest::Error) -> lance_core::Error {
- let message = format!("Failed to execute request: {:?}", e);
- if e.is_timeout() || e.is_connect() {
- NamespaceError::ServiceUnavailable { message }.into()
- } else {
- NamespaceError::Internal { message }.into()
+ pub async fn warm_up_auth(&self) -> Result<()> {
+ if let Some(auth) = &self.rest_client.auth_provider {
+ auth.initialize().await?;
}
+ Ok(())
}
/// Parse an error response body and return the appropriate NamespaceError.
@@ -585,8 +710,7 @@ impl RestNamespace {
let resp = self
.rest_client
.execute(req_builder, operation, object_id)
- .await
- .map_err(Self::request_error)?;
+ .await?;
let status = resp.status();
let content = resp.text().await.map_err(|e| {
@@ -622,8 +746,7 @@ impl RestNamespace {
let resp = self
.rest_client
.execute(req_builder, operation, object_id)
- .await
- .map_err(Self::request_error)?;
+ .await?;
let status = resp.status();
let content = resp.text().await.map_err(|e| {
@@ -659,8 +782,7 @@ impl RestNamespace {
let resp = self
.rest_client
.execute(req_builder, operation, object_id)
- .await
- .map_err(Self::request_error)?;
+ .await?;
let status = resp.status();
if status.is_success() {
@@ -690,8 +812,7 @@ impl RestNamespace {
let resp = self
.rest_client
.execute(req_builder, operation, object_id)
- .await
- .map_err(Self::request_error)?;
+ .await?;
let status = resp.status();
let content = resp.text().await.map_err(|e| {
@@ -1107,8 +1228,7 @@ impl LanceNamespace for RestNamespace {
let resp = self
.rest_client
.execute(req_builder, operation, &id)
- .await
- .map_err(Self::request_error)?;
+ .await?;
let status = resp.status();
if status.is_success() {
@@ -1569,6 +1689,11 @@ mod tests {
use wiremock::matchers::{method, path};
use wiremock::{Mock, MockServer, ResponseTemplate};
+ #[test]
+ fn empty_body_sha256_const_matches_computed() {
+ assert_eq!(EMPTY_BODY_SHA256, hex::encode(Sha256::digest(b"")));
+ }
+
#[test]
fn test_rest_namespace_creation() {
let mut properties = HashMap::new();
@@ -1582,7 +1707,8 @@ mod tests {
let _namespace = RestNamespaceBuilder::from_properties(properties)
.expect("Failed to create namespace builder")
- .build();
+ .build()
+ .unwrap();
// Successfully created the namespace - test passes if no panic
}
@@ -1599,7 +1725,8 @@ mod tests {
let _namespace = RestNamespaceBuilder::from_properties(properties)
.expect("Failed to create namespace builder")
- .build();
+ .build()
+ .unwrap();
}
#[tokio::test]
@@ -1638,7 +1765,8 @@ mod tests {
let namespace = RestNamespaceBuilder::from_properties(properties)
.expect("Failed to create namespace builder")
- .build();
+ .build()
+ .unwrap();
let request = ListNamespacesRequest {
id: Some(vec!["test".to_string()]),
@@ -1657,7 +1785,8 @@ mod tests {
properties.insert("uri".to_string(), "http://localhost:8080".to_string());
let _namespace = RestNamespaceBuilder::from_properties(properties)
.expect("Failed to create namespace builder")
- .build();
+ .build()
+ .unwrap();
// The default delimiter should be "$" - test passes if no panic
}
@@ -1669,7 +1798,8 @@ mod tests {
let _namespace = RestNamespaceBuilder::from_properties(properties)
.expect("Failed to create namespace builder")
- .build();
+ .build()
+ .unwrap();
// Test passes if no panic
}
@@ -1736,7 +1866,8 @@ mod tests {
// Should not panic even with nonexistent files (they're just ignored)
let _namespace = RestNamespaceBuilder::from_properties(properties)
.expect("Failed to create namespace builder")
- .build();
+ .build()
+ .unwrap();
}
#[tokio::test]
@@ -1757,7 +1888,9 @@ mod tests {
.await;
// Create namespace with mock server URL
- let namespace = RestNamespaceBuilder::new(mock_server.uri()).build();
+ let namespace = RestNamespaceBuilder::new(mock_server.uri())
+ .build()
+ .unwrap();
let request = ListNamespacesRequest {
id: Some(vec!["test".to_string()]),
@@ -1793,7 +1926,9 @@ mod tests {
.await;
// Create namespace with mock server URL
- let namespace = RestNamespaceBuilder::new(mock_server.uri()).build();
+ let namespace = RestNamespaceBuilder::new(mock_server.uri())
+ .build()
+ .unwrap();
let request = ListNamespacesRequest {
id: Some(vec!["test".to_string()]),
@@ -1826,7 +1961,9 @@ mod tests {
.await;
// Create namespace with mock server URL
- let namespace = RestNamespaceBuilder::new(mock_server.uri()).build();
+ let namespace = RestNamespaceBuilder::new(mock_server.uri())
+ .build()
+ .unwrap();
let request = CreateNamespaceRequest {
id: Some(vec!["test".to_string(), "newnamespace".to_string()]),
@@ -1859,7 +1996,9 @@ mod tests {
.await;
// Create namespace with mock server URL
- let namespace = RestNamespaceBuilder::new(mock_server.uri()).build();
+ let namespace = RestNamespaceBuilder::new(mock_server.uri())
+ .build()
+ .unwrap();
let request = CreateTableRequest {
id: Some(vec![
@@ -1898,7 +2037,9 @@ mod tests {
.mount(&mock_server)
.await;
- let namespace = RestNamespaceBuilder::new(mock_server.uri()).build();
+ let namespace = RestNamespaceBuilder::new(mock_server.uri())
+ .build()
+ .unwrap();
let request = CreateTableRequest {
id: Some(vec![
@@ -1963,7 +2104,9 @@ mod tests {
.await;
// Create namespace with mock server URL
- let namespace = RestNamespaceBuilder::new(mock_server.uri()).build();
+ let namespace = RestNamespaceBuilder::new(mock_server.uri())
+ .build()
+ .unwrap();
let request = InsertIntoTableRequest {
id: Some(vec![
@@ -2026,7 +2169,8 @@ mod tests {
let namespace = RestNamespaceBuilder::new(mock_server.uri())
.context_provider(provider)
- .build();
+ .build()
+ .unwrap();
let request = ListNamespacesRequest {
id: Some(vec!["test".to_string()]),
@@ -2072,7 +2216,8 @@ mod tests {
let namespace = RestNamespaceBuilder::new(mock_server.uri())
.header("Authorization", "Bearer base-token")
.context_provider(provider)
- .build();
+ .build()
+ .unwrap();
let request = ListNamespacesRequest {
id: Some(vec!["test".to_string()]),
@@ -2114,7 +2259,8 @@ mod tests {
let namespace = RestNamespaceBuilder::new(mock_server.uri())
.header("Authorization", "Bearer base-token")
.context_provider(provider)
- .build();
+ .build()
+ .unwrap();
let request = ListNamespacesRequest {
id: Some(vec!["test".to_string()]),
@@ -2145,7 +2291,8 @@ mod tests {
// Create namespace WITHOUT context provider, only base headers
let namespace = RestNamespaceBuilder::new(mock_server.uri())
.header("Authorization", "Bearer base-only")
- .build();
+ .build()
+ .unwrap();
let request = ListNamespacesRequest {
id: Some(vec!["test".to_string()]),
@@ -2155,4 +2302,251 @@ mod tests {
let result = namespace.list_namespaces(request).await;
assert!(result.is_ok(), "Failed: {:?}", result.err());
}
+
+ #[tokio::test]
+ async fn rest_auth_type_none_outbound_headers_identical_to_no_config() {
+ let mock_server = MockServer::start().await;
+ Mock::given(method("GET"))
+ .and(path("/v1/namespace/ns/list"))
+ .respond_with(
+ ResponseTemplate::new(200).set_body_json(serde_json::json!({ "namespaces": [] })),
+ )
+ .mount(&mock_server)
+ .await;
+
+ let ns_no_auth = RestNamespaceBuilder::new(mock_server.uri())
+ .build()
+ .unwrap();
+ let req = ListNamespacesRequest {
+ id: Some(vec!["ns".to_string()]),
+ ..Default::default()
+ };
+ ns_no_auth.list_namespaces(req.clone()).await.unwrap();
+
+ let mut props = HashMap::new();
+ props.insert("uri".to_string(), mock_server.uri());
+ props.insert("rest.auth.type".to_string(), "none".to_string());
+ let ns_none = RestNamespaceBuilder::from_properties(props)
+ .unwrap()
+ .build()
+ .unwrap();
+ ns_none.list_namespaces(req).await.unwrap();
+
+ let requests = mock_server.received_requests().await.unwrap();
+ assert_eq!(requests.len(), 2);
+ let h0: HashMap<_, _> = requests[0]
+ .headers
+ .iter()
+ .map(|(k, v)| (k.as_str().to_lowercase(), v.to_str().unwrap().to_string()))
+ .collect();
+ let h1: HashMap<_, _> = requests[1]
+ .headers
+ .iter()
+ .map(|(k, v)| (k.as_str().to_lowercase(), v.to_str().unwrap().to_string()))
+ .collect();
+ assert_eq!(
+ h0, h1,
+ "rest.auth.type=none should produce identical headers to no config"
+ );
+ }
+
+ #[tokio::test]
+ async fn legacy_header_authorization_unchanged_with_auth_framework() {
+ let mock_server = MockServer::start().await;
+ Mock::given(method("GET"))
+ .and(path("/v1/namespace/ns/list"))
+ .and(wiremock::matchers::header(
+ "Authorization",
+ "Bearer legacy-static-token",
+ ))
+ .respond_with(
+ ResponseTemplate::new(200).set_body_json(serde_json::json!({ "namespaces": [] })),
+ )
+ .mount(&mock_server)
+ .await;
+
+ let ns = RestNamespaceBuilder::new(mock_server.uri())
+ .header("Authorization", "Bearer legacy-static-token")
+ .build()
+ .unwrap();
+ let req = ListNamespacesRequest {
+ id: Some(vec!["ns".to_string()]),
+ ..Default::default()
+ };
+ ns.list_namespaces(req).await.unwrap();
+ }
+
+ #[test]
+ fn unknown_rest_auth_type_returns_error_at_build_time() {
+ let mut props = HashMap::new();
+ props.insert("uri".to_string(), "http://127.0.0.1:1".to_string());
+ props.insert(
+ "rest.auth.type".to_string(),
+ "definitely-not-a-real-scheme".to_string(),
+ );
+ let result = RestNamespaceBuilder::from_properties(props)
+ .unwrap()
+ .build();
+ assert!(
+ result.is_err(),
+ "expected build() to fail for unknown auth type"
+ );
+ let err_str = result.err().unwrap().to_string();
+ assert!(
+ err_str.contains("definitely-not-a-real-scheme"),
+ "error should mention the offending type, got: {err_str}"
+ );
+ assert!(
+ err_str.contains("none"),
+ "error should list supported types, got: {err_str}"
+ );
+ }
+
+ #[tokio::test]
+ async fn auth_provider_failure_surfaces_as_error() {
+ #[derive(Debug)]
+ struct AlwaysFailAuth;
+ #[async_trait::async_trait]
+ impl crate::rest_auth::RestAuthProvider for AlwaysFailAuth {
+ async fn authenticate(
+ &self,
+ _ctx: &crate::rest_auth::RequestContext,
+ ) -> Result<HashMap<String, String>> {
+ Err(NamespaceError::Unauthenticated {
+ message: "synthetic-token-expired".to_string(),
+ }
+ .into())
+ }
+ }
+
+ let mock_server = MockServer::start().await;
+ Mock::given(method("GET"))
+ .and(path("/v1/namespace/test/list"))
+ .respond_with(
+ ResponseTemplate::new(200).set_body_json(serde_json::json!({"namespaces": []})),
+ )
+ .mount(&mock_server)
+ .await;
+
+ let ns = RestNamespaceBuilder::new(mock_server.uri())
+ .auth_provider(std::sync::Arc::new(AlwaysFailAuth))
+ .build()
+ .unwrap();
+ let req = ListNamespacesRequest {
+ id: Some(vec!["test".to_string()]),
+ ..Default::default()
+ };
+ let result = ns.list_namespaces(req).await;
+ assert!(result.is_err(), "auth failure should bubble up");
+ let err_msg = result.err().unwrap().to_string();
+ assert!(
+ err_msg.contains("synthetic-token-expired"),
+ "underlying error must be preserved: {err_msg}"
+ );
+ assert!(
+ err_msg.contains("list_namespaces"),
+ "error should include operation context: {err_msg}"
+ );
+ }
+
+ #[tokio::test]
+ async fn warm_up_auth_surfaces_initialize_error() {
+ #[derive(Debug)]
+ struct FailOnInit;
+ #[async_trait::async_trait]
+ impl crate::rest_auth::RestAuthProvider for FailOnInit {
+ async fn authenticate(
+ &self,
+ _ctx: &crate::rest_auth::RequestContext,
+ ) -> Result<HashMap<String, String>> {
+ Ok(HashMap::new())
+ }
+ async fn initialize(&self) -> Result<()> {
+ Err(NamespaceError::Unauthenticated {
+ message: "synthetic-credential-chain-failure".to_string(),
+ }
+ .into())
+ }
+ }
+
+ let ns = RestNamespaceBuilder::new("http://127.0.0.1:1")
+ .auth_provider(std::sync::Arc::new(FailOnInit))
+ .build()
+ .unwrap();
+ let err = ns.warm_up_auth().await.unwrap_err();
+ assert!(
+ err.to_string()
+ .contains("synthetic-credential-chain-failure"),
+ "warm_up_auth must propagate initialize() error: {err}"
+ );
+ }
+
+ #[tokio::test]
+ async fn auth_provider_setter_takes_precedence_over_properties() {
+ #[derive(Debug)]
+ struct MarkerAuth;
+ #[async_trait::async_trait]
+ impl crate::rest_auth::RestAuthProvider for MarkerAuth {
+ async fn authenticate(
+ &self,
+ _ctx: &crate::rest_auth::RequestContext,
+ ) -> Result<HashMap<String, String>> {
+ let mut h = HashMap::new();
+ h.insert("x-marker".to_string(), "from-setter".to_string());
+ Ok(h)
+ }
+ }
+
+ let mock_server = MockServer::start().await;
+ Mock::given(method("GET"))
+ .and(path("/v1/namespace/test/list"))
+ .respond_with(
+ ResponseTemplate::new(200).set_body_json(serde_json::json!({"namespaces": []})),
+ )
+ .mount(&mock_server)
+ .await;
+
+ let mut props = HashMap::new();
+ props.insert("uri".to_string(), mock_server.uri());
+ props.insert("rest.auth.type".to_string(), "none".to_string());
+ let ns = RestNamespaceBuilder::from_properties(props)
+ .unwrap()
+ .auth_provider(std::sync::Arc::new(MarkerAuth))
+ .build()
+ .unwrap();
+
+ let req = ListNamespacesRequest {
+ id: Some(vec!["test".to_string()]),
+ ..Default::default()
+ };
+ let _ = ns.list_namespaces(req).await.unwrap();
+
+ let matched = mock_server.received_requests().await.unwrap();
+ assert!(!matched.is_empty());
+ let marker = matched[0]
+ .headers
+ .get("x-marker")
+ .map(|v| v.to_str().unwrap());
+ assert_eq!(
+ marker,
+ Some("from-setter"),
+ "setter auth_provider must override rest.auth.type property"
+ );
+ }
+
+ #[test]
+ fn build_rejects_header_authorization_combined_with_auth_type() {
+ let mut props = HashMap::new();
+ props.insert("uri".to_string(), "http://localhost:8080".to_string());
+ props.insert("header.Authorization".to_string(), "Bearer token".to_string());
+ props.insert("rest.auth.type".to_string(), "none".to_string());
+ let err = RestNamespaceBuilder::from_properties(props)
+ .unwrap()
+ .build()
+ .unwrap_err();
+ assert!(
+ err.to_string().contains("one authentication method"),
+ "build must reject header.Authorization + rest.auth.*: {err}"
+ );
+ }
}
diff --git a/rust/lance-namespace-impls/src/rest_adapter.rs b/rust/lance-namespace-impls/src/rest_adapter.rs
index 6a3875ebf29..a155f989eb9 100644
--- a/rust/lance-namespace-impls/src/rest_adapter.rs
+++ b/rust/lance-namespace-impls/src/rest_adapter.rs
@@ -1483,7 +1483,8 @@ mod tests {
let server_url = format!("http://127.0.0.1:{}", actual_port);
let namespace = RestNamespaceBuilder::new(&server_url)
.delimiter("$")
- .build();
+ .build()
+ .unwrap();
Self {
_temp_dir: temp_dir,
@@ -3047,7 +3048,8 @@ mod tests {
.delimiter("$")
.header("X-Base-Header", "base-value")
.context_provider(provider)
- .build();
+ .build()
+ .unwrap();
// Create a namespace - should work with context provider
let create_req = CreateNamespaceRequest {
diff --git a/rust/lance-namespace-impls/src/rest_auth.rs b/rust/lance-namespace-impls/src/rest_auth.rs
new file mode 100644
index 00000000000..1decc6027c0
--- /dev/null
+++ b/rust/lance-namespace-impls/src/rest_auth.rs
@@ -0,0 +1,134 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+//! Authentication providers for REST Namespace HTTP requests.
+
+use std::collections::HashMap;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+use lance_core::Result;
+use lance_namespace::error::NamespaceError;
+
+#[cfg(feature = "rest-auth-sigv4")]
+pub mod sigv4;
+
+pub const AUTH_TYPE_KEY: &str = "rest.auth.type";
+pub const AUTH_PROPERTY_PREFIX: &str = "rest.auth.";
+pub const AUTH_TYPE_NONE: &str = "none";
+#[cfg(feature = "rest-auth-sigv4")]
+pub const AUTH_TYPE_SIGV4: &str = "sigv4";
+
+#[derive(Debug, Clone)]
+pub struct RequestContext {
+ pub method: String,
+ pub url: String,
+ pub headers: HashMap<String, String>,
+ /// `None` for streaming bodies.
+ pub body_sha256: Option<String>,
+}
+
+#[async_trait]
+pub trait RestAuthProvider: Send + Sync + std::fmt::Debug {
+ async fn authenticate(&self, ctx: &RequestContext) -> Result<HashMap<String, String>>;
+
+ /// Connect-time init; default no-op.
+ async fn initialize(&self) -> Result<()> {
+ Ok(())
+ }
+}
+
+#[derive(Debug, Default)]
+pub struct NoopAuthProvider;
+
+#[async_trait]
+impl RestAuthProvider for NoopAuthProvider {
+ async fn authenticate(&self, _ctx: &RequestContext) -> Result<HashMap<String, String>> {
+ Ok(HashMap::new())
+ }
+}
+
+pub fn create_auth_provider(
+ properties: &HashMap<String, String>,
+) -> Result<Arc<dyn RestAuthProvider>> {
+ let auth_type = properties
+ .get(AUTH_TYPE_KEY)
+ .map(|s| s.as_str())
+ .unwrap_or(AUTH_TYPE_NONE);
+ match auth_type {
+ AUTH_TYPE_NONE => Ok(Arc::new(NoopAuthProvider)),
+ #[cfg(feature = "rest-auth-sigv4")]
+ AUTH_TYPE_SIGV4 => Ok(Arc::new(sigv4::SigV4AuthProvider::from_properties(
+ properties,
+ )?)),
+ other => Err(NamespaceError::InvalidInput {
+ message: format!(
+ "unsupported {AUTH_TYPE_KEY} '{other}' (supported: {})",
+ supported_auth_types()
+ ),
+ }
+ .into()),
+ }
+}
+
+fn supported_auth_types() -> &'static str {
+ #[cfg(feature = "rest-auth-sigv4")]
+ {
+ "none, sigv4"
+ }
+ #[cfg(not(feature = "rest-auth-sigv4"))]
+ {
+ "none"
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+
+ fn empty_ctx() -> RequestContext {
+ RequestContext {
+ method: "GET".to_string(),
+ url: "http://example.com/v1/test".to_string(),
+ headers: HashMap::new(),
+ body_sha256: None,
+ }
+ }
+
+ #[tokio::test]
+ async fn noop_returns_empty_headers() {
+ assert!(
+ NoopAuthProvider
+ .authenticate(&empty_ctx())
+ .await
+ .unwrap()
+ .is_empty()
+ );
+ }
+
+ #[tokio::test]
+ async fn noop_initialize_is_ok() {
+ NoopAuthProvider.initialize().await.unwrap();
+ }
+
+ #[test]
+ fn factory_accepts_missing_auth_type() {
+ assert!(create_auth_provider(&HashMap::new()).is_ok());
+ }
+
+ #[test]
+ fn factory_accepts_explicit_none() {
+ let mut props = HashMap::new();
+ props.insert(AUTH_TYPE_KEY.to_string(), AUTH_TYPE_NONE.to_string());
+ assert!(create_auth_provider(&props).is_ok());
+ }
+
+ #[test]
+ fn factory_rejects_unknown_with_helpful_error() {
+ let mut props = HashMap::new();
+ props.insert(AUTH_TYPE_KEY.to_string(), "sigv4-typo".to_string());
+ let msg = create_auth_provider(&props).unwrap_err().to_string();
+ assert!(msg.contains("sigv4-typo"));
+ assert!(msg.contains(AUTH_TYPE_NONE));
+ }
+}
diff --git a/rust/lance-namespace-impls/src/rest_auth/sigv4.rs b/rust/lance-namespace-impls/src/rest_auth/sigv4.rs
new file mode 100644
index 00000000000..0db374b7055
--- /dev/null
+++ b/rust/lance-namespace-impls/src/rest_auth/sigv4.rs
@@ -0,0 +1,613 @@
+// SPDX-License-Identifier: Apache-2.0
+// SPDX-FileCopyrightText: Copyright The Lance Authors
+
+//! AWS SigV4 authentication provider for REST Namespace.
+
+use std::collections::HashMap;
+use std::sync::Arc;
+use std::time::SystemTime;
+
+use async_trait::async_trait;
+use aws_credential_types::Credentials;
+use aws_credential_types::provider::{ProvideCredentials, SharedCredentialsProvider};
+use aws_sigv4::http_request::{
+ SignableBody, SignableRequest, SigningParams, SigningSettings, sign,
+};
+use aws_sigv4::sign::v4;
+use lance_core::Result;
+use lance_namespace::error::NamespaceError;
+use tokio::sync::OnceCell;
+use url::Url;
+
+pub const REGION_KEY: &str = "rest.auth.sigv4.region";
+pub const SERVICE_KEY: &str = "rest.auth.sigv4.service";
+pub const ACCESS_KEY_ID_KEY: &str = "rest.auth.sigv4.access-key-id";
+pub const SECRET_ACCESS_KEY_KEY: &str = "rest.auth.sigv4.secret-access-key";
+pub const SESSION_TOKEN_KEY: &str = "rest.auth.sigv4.session-token";
+const DEFAULT_SERVICE: &str = "execute-api";
+
+/// Injectable time source for request signing; tests use a fixed clock to get reproducible signatures.
+pub trait Clock: Send + Sync + std::fmt::Debug {
+ fn now(&self) -> SystemTime; // the instant handed to the signer as the signing time
+}
+
+#[derive(Debug, Default)]
+pub struct SystemClock; // the `Clock` installed by `SigV4AuthProvider::from_properties`
+
+impl Clock for SystemClock {
+ fn now(&self) -> SystemTime {
+ SystemTime::now() // current wall-clock time
+ }
+}
+
+pub struct SigV4AuthProvider {
+    region: String,
+    service: String,
+    static_credentials: Option<Credentials>, // set only when both keys are configured
+    credentials_provider: OnceCell<SharedCredentialsProvider>, // built lazily or injected
+    clock: Arc<dyn Clock>, // time source used for the signing timestamp
+}
+
+// Hand-written `Debug`: shows region, service and where credentials come from, never the keys.
+impl std::fmt::Debug for SigV4AuthProvider {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        let credential_source = match (
+            self.static_credentials.is_some(),
+            self.credentials_provider.get().is_some(),
+        ) {
+            (true, _) => "static",
+            (false, true) => "resolved",
+            (false, false) => "default-chain (pending)",
+        };
+        f.debug_struct("SigV4AuthProvider")
+            .field("region", &self.region)
+            .field("service", &self.service)
+            .field("credential_source", &credential_source)
+            .finish()
+    }
+}
+
+impl SigV4AuthProvider {
+    /// Builds a provider from namespace properties: `REGION_KEY` is required, `SERVICE_KEY`
+    /// defaults to `DEFAULT_SERVICE`, and the access/secret keys must be set together.
+    /// Without static keys, credentials are looked up lazily on first use.
+    /// Errors with `NamespaceError::InvalidInput` when the region or one key is missing.
+    pub fn from_properties(properties: &HashMap<String, String>) -> Result<Self> {
+        let region = properties.get(REGION_KEY).cloned().ok_or_else(|| {
+            NamespaceError::InvalidInput {
+                message: format!("{REGION_KEY} is required for SigV4 authentication"),
+            }
+        })?;
+        let service = properties
+            .get(SERVICE_KEY)
+            .cloned()
+            .unwrap_or_else(|| DEFAULT_SERVICE.to_string());
+        let ak = properties.get(ACCESS_KEY_ID_KEY);
+        let sk = properties.get(SECRET_ACCESS_KEY_KEY);
+        let static_credentials = match (ak, sk) {
+            (Some(ak), Some(sk)) => Some(Credentials::new(
+                ak.clone(),
+                sk.clone(),
+                properties.get(SESSION_TOKEN_KEY).cloned(),
+                None,
+                "lance-sigv4-static",
+            )),
+            (None, None) => None,
+            _ => {
+                return Err(NamespaceError::InvalidInput {
+                    message: format!(
+                        "{ACCESS_KEY_ID_KEY} and {SECRET_ACCESS_KEY_KEY} must both be set or both be omitted"
+                    ),
+                }
+                .into());
+            }
+        };
+        Ok(Self {
+            region,
+            service,
+            static_credentials,
+            credentials_provider: OnceCell::new(),
+            clock: Arc::new(SystemClock),
+        })
+    }
+
+    pub fn with_clock(mut self, clock: Arc<dyn Clock>) -> Self {
+        self.clock = clock; // replaces the default `SystemClock`; tests pass a fixed clock
+        self
+    }
+
+    /// Injects a ready-made credentials provider, bypassing static keys and lazy lookup.
+    pub fn with_credentials_provider(self, provider: SharedCredentialsProvider) -> Self {
+        // Pre-filling the cell means `get_or_try_init` never runs its initializer,
+        // so the injected provider wins over any static credentials.
+        Self {
+            credentials_provider: OnceCell::new_with(Some(provider)),
+            ..self
+        }
+    }
+
+    /// Returns the cached credentials provider, creating it on first call.
+    async fn ensure_credentials_provider(&self) -> Result<&SharedCredentialsProvider> {
+        self.credentials_provider
+            .get_or_try_init(|| async {
+                if let Some(creds) = &self.static_credentials {
+                    return Ok(SharedCredentialsProvider::new(creds.clone()));
+                }
+                // Default chain: load on a blocking thread (per the original note,
+                // aws_config::load panics inside an existing tokio runtime).
+                let region = self.region.clone();
+                let sdk_config = tokio::task::spawn_blocking(move || {
+                    tokio::runtime::Handle::current().block_on(async {
+                        aws_config::defaults(aws_config::BehaviorVersion::latest())
+                            .region(aws_config::Region::new(region))
+                            .load()
+                            .await
+                    })
+                })
+                .await
+                .map_err(|e| {
+                    lance_core::Error::from(NamespaceError::Internal {
+                        message: format!("failed to load AWS config: {e}"),
+                    })
+                })?;
+                sdk_config.credentials_provider().ok_or_else(|| {
+                    lance_core::Error::from(NamespaceError::Internal {
+                        message: "AWS config did not yield a credentials provider".to_string(),
+                    })
+                })
+            })
+            .await
+    }
+
+    /// Resolves credentials for a single signing call; provider failures are
+    /// surfaced as `NamespaceError::Unauthenticated`.
+    async fn resolve_credentials(&self) -> Result<Credentials> {
+        let provider = self.ensure_credentials_provider().await?;
+        provider.provide_credentials().await.map_err(|e| {
+            let message = format!("failed to resolve AWS credentials: {e}");
+            NamespaceError::Unauthenticated { message }.into()
+        })
+    }
+}
+
+#[async_trait]
+impl super::RestAuthProvider for SigV4AuthProvider {
+    /// Signs the request described by `ctx` and returns the headers produced by the signer.
+    // NOTE(review): the return type's generic arguments are missing from this patch text
+    // (`Vec<(String, String)>` or `HashMap<String, String>`?) - confirm against the trait.
+    async fn authenticate(&self, ctx: &super::RequestContext) -> Result> {
+        let creds = self.resolve_credentials().await?;
+        let identity = creds.into();
+        let mut signing_settings = SigningSettings::default();
+        signing_settings.payload_checksum_kind =
+            aws_sigv4::http_request::PayloadChecksumKind::XAmzSha256;
+        let v4_params = v4::SigningParams::builder()
+            .identity(&identity)
+            .region(&self.region)
+            .name(&self.service)
+            .time(self.clock.now())
+            .settings(signing_settings)
+            .build()
+            .map_err(|e| NamespaceError::Internal {
+                message: format!("failed to build SigV4 signing params: {e}"),
+            })?;
+        let params: SigningParams = v4_params.into();
+
+        let parsed_url = Url::parse(&ctx.url).map_err(|_| NamespaceError::InvalidInput {
+            message: format!("SigV4 requires a valid URL: {}", ctx.url),
+        })?;
+        if parsed_url.host_str().is_none() {
+            return Err(NamespaceError::InvalidInput {
+                message: format!("SigV4 requires a URL with a host: {}", ctx.url),
+            }
+            .into());
+        }
+        // Slice `host[:port]` out of the parsed URL.
+        let host = parsed_url[url::Position::BeforeHost..url::Position::AfterPort].to_string();
+
+        // Sign the URL-derived host; any caller-supplied Host header is dropped.
+        let other_headers = ctx
+            .headers
+            .iter()
+            .filter(|(k, _)| !k.eq_ignore_ascii_case("host"));
+        let header_iter = std::iter::once(("host", host.as_str()))
+            .chain(other_headers.map(|(k, v)| (k.as_str(), v.as_str())));
+
+        // Without a precomputed body hash the request is signed with an unsigned payload.
+        let body = match ctx.body_sha256.as_deref() {
+            Some(hash) => SignableBody::Precomputed(hash.to_string()),
+            None => SignableBody::UnsignedPayload,
+        };
+        let signable = SignableRequest::new(&ctx.method, &ctx.url, header_iter, body)
+            .map_err(|e| NamespaceError::Internal {
+                message: format!("failed to construct SigV4 signable request: {e}"),
+            })?;
+        let (instructions, _signature) = sign(signable, &params)
+            .map_err(|e| NamespaceError::Internal {
+                message: format!("SigV4 signing failed: {e}"),
+            })?
+            .into_parts();
+        Ok(instructions
+            .headers()
+            .map(|(name, value)| (name.to_string(), value.to_string()))
+            .collect())
+    }
+
+    async fn initialize(&self) -> Result<()> {
+        self.resolve_credentials().await.map(|_| ()) // resolve once; only success/failure is kept
+    }
+}
+
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::rest_auth::{RequestContext, RestAuthProvider};
+ use std::time::{Duration, UNIX_EPOCH};
+
+ // AWS SigV4 test vector credentials (botocore cross-verified).
+ const VECTOR_ACCESS_KEY: &str = "AKIDEXAMPLE";
+ const VECTOR_SECRET_KEY: &str = "wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY";
+ const VECTOR_REGION: &str = "us-east-1";
+ const VECTOR_SERVICE: &str = "service";
+ const VECTOR_UNIX_SECS: u64 = 1_440_938_160; // 2015-08-30T12:36:00Z
+ const VECTOR_EXPECTED_AUTHORIZATION: &str = "AWS4-HMAC-SHA256 \
+ Credential=AKIDEXAMPLE/20150830/us-east-1/service/aws4_request, \
+ SignedHeaders=host;x-amz-content-sha256;x-amz-date, \
+ Signature=726c5c4879a6b4ccbbd3b24edbd6b8826d34f87450fbbf4e85546fc7ba9c1642";
+
+ #[derive(Debug)]
+ struct FixedClock(SystemTime); // always reports the wrapped instant
+
+ impl Clock for FixedClock {
+ fn now(&self) -> SystemTime {
+ self.0
+ }
+ }
+
+ fn vector_provider() -> SigV4AuthProvider { // vector region/service, fixed clock, injected vector credentials
+ let creds = Credentials::new(
+ VECTOR_ACCESS_KEY,
+ VECTOR_SECRET_KEY,
+ None,
+ None,
+ "lance-sigv4-test",
+ );
+ let mut props = HashMap::new();
+ props.insert(REGION_KEY.to_string(), VECTOR_REGION.to_string());
+ props.insert(SERVICE_KEY.to_string(), VECTOR_SERVICE.to_string());
+ SigV4AuthProvider::from_properties(&props)
+ .unwrap()
+ .with_clock(Arc::new(FixedClock(
+ UNIX_EPOCH + Duration::from_secs(VECTOR_UNIX_SECS),
+ )))
+ .with_credentials_provider(SharedCredentialsProvider::new(creds))
+ }
+
+ #[test]
+ fn from_properties_requires_region() {
+ let err = SigV4AuthProvider::from_properties(&HashMap::new()).unwrap_err();
+ assert!(err.to_string().contains(REGION_KEY));
+ }
+
+ #[test]
+ fn from_properties_defaults_service_to_execute_api() {
+ let mut props = HashMap::new();
+ props.insert(REGION_KEY.to_string(), "us-west-2".to_string());
+ let provider = SigV4AuthProvider::from_properties(&props).unwrap();
+ assert_eq!(provider.service, DEFAULT_SERVICE);
+ assert_eq!(provider.region, "us-west-2");
+ }
+
+ #[test]
+ fn from_properties_accepts_explicit_service() {
+ let mut props = HashMap::new();
+ props.insert(REGION_KEY.to_string(), "us-east-1".to_string());
+ props.insert(SERVICE_KEY.to_string(), "s3".to_string());
+ let provider = SigV4AuthProvider::from_properties(&props).unwrap();
+ assert_eq!(provider.service, "s3");
+ }
+
+ #[tokio::test]
+ async fn reproduces_aws_get_vanilla_reference_vector() {
+ let provider = vector_provider();
+ let ctx = RequestContext {
+ method: "GET".to_string(),
+ url: "https://example.amazonaws.com/".to_string(),
+ headers: HashMap::new(),
+ body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()),
+ };
+ let headers = provider.authenticate(&ctx).await.unwrap();
+ let actual = headers
+ .iter()
+ .find(|(k, _)| k.eq_ignore_ascii_case("authorization"))
+ .map(|(_, v)| v.as_str())
+ .expect("authorization header must be produced");
+ assert_eq!(actual, VECTOR_EXPECTED_AUTHORIZATION);
+ }
+
+ #[tokio::test]
+ async fn initialize_resolves_injected_credentials() {
+ vector_provider().initialize().await.unwrap();
+ }
+
+ #[tokio::test]
+ async fn authenticate_rejects_url_without_host() {
+ let provider = vector_provider();
+ let ctx = RequestContext {
+ method: "GET".to_string(),
+ url: "file:///nowhere".to_string(), // parses as a URL but carries no host
+ headers: HashMap::new(),
+ body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()),
+ };
+ let err = provider.authenticate(&ctx).await.unwrap_err();
+ assert!(err.to_string().contains("host"));
+ }
+
+ #[tokio::test]
+ async fn authenticate_overrides_preexisting_host_header() {
+ let provider = vector_provider();
+ let mut headers = HashMap::new();
+ headers.insert("Host".to_string(), "wrong.example.com".to_string()); // must not influence the signature
+ let ctx = RequestContext {
+ method: "GET".to_string(),
+ url: "https://example.amazonaws.com/".to_string(),
+ headers,
+ body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()),
+ };
+ let result = provider.authenticate(&ctx).await.unwrap();
+ let actual = result
+ .iter()
+ .find(|(k, _)| k.eq_ignore_ascii_case("authorization"))
+ .map(|(_, v)| v.as_str())
+ .expect("authorization header must be produced");
+ assert_eq!(
+ actual, VECTOR_EXPECTED_AUTHORIZATION,
+ "pre-existing Host header must be replaced by the URL-derived host"
+ );
+ }
+
+ /// AWS test vector: percent-encoded path (%3D → double-encoded %253D).
+ #[tokio::test]
+ async fn reproduces_aws_double_encode_path_vector() {
+ let creds = Credentials::new(
+ "ANOTREAL",
+ "notrealrnrELgWzOk3IfjzDKtFBhDby",
+ None,
+ None,
+ "lance-sigv4-test",
+ );
+ let mut props = HashMap::new();
+ props.insert(REGION_KEY.to_string(), "us-east-1".to_string());
+ props.insert(SERVICE_KEY.to_string(), "service".to_string());
+ let provider = SigV4AuthProvider::from_properties(&props)
+ .unwrap()
+ .with_clock(Arc::new(FixedClock(
+ UNIX_EPOCH + Duration::from_secs(VECTOR_UNIX_SECS),
+ )))
+ .with_credentials_provider(SharedCredentialsProvider::new(creds));
+
+ let ctx = RequestContext {
+ method: "POST".to_string(),
+ url: "https://tj9n5r0m12.execute-api.us-east-1.amazonaws.com/test/@connections/JBDvjfGEIAMCERw%3D".to_string(),
+ headers: HashMap::new(),
+ body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()),
+ };
+ let headers = provider.authenticate(&ctx).await.unwrap();
+ let auth = headers
+ .iter()
+ .find(|(k, _)| k.eq_ignore_ascii_case("authorization"))
+ .map(|(_, v)| v.as_str())
+ .expect("authorization header must be produced");
+ assert_eq!(
+ auth,
+ "AWS4-HMAC-SHA256 Credential=ANOTREAL/20150830/us-east-1/service/aws4_request, \
+ SignedHeaders=host;x-amz-content-sha256;x-amz-date, \
+ Signature=ed434df8a348089a1188defcfcc1aa24049990a7e82021d0418cfa0eb05e4d99",
+ "double-encode-path: signature must match botocore cross-verification"
+ );
+ }
+
+ #[tokio::test]
+ async fn authenticate_with_unsigned_payload_still_produces_signature() {
+ let provider = vector_provider();
+ let ctx = RequestContext {
+ method: "GET".to_string(),
+ url: "https://example.amazonaws.com/".to_string(),
+ headers: HashMap::new(),
+ body_sha256: None, // exercises the unsigned-payload branch
+ };
+ let headers = provider.authenticate(&ctx).await.unwrap();
+ let auth = headers
+ .iter()
+ .find(|(k, _)| k.eq_ignore_ascii_case("authorization"))
+ .map(|(_, v)| v.as_str())
+ .expect("authorization header must be produced");
+ assert!(auth.starts_with("AWS4-HMAC-SHA256 "));
+ assert!(auth.contains("Credential="));
+ assert!(auth.contains("SignedHeaders="));
+ assert!(auth.contains("Signature="));
+ }
+
+ #[tokio::test]
+ async fn authenticate_with_session_token_produces_correct_signature() {
+ let creds = Credentials::new(
+ VECTOR_ACCESS_KEY,
+ VECTOR_SECRET_KEY,
+ Some("FakeSessionToken123".to_string()),
+ None,
+ "lance-sigv4-test",
+ );
+ let mut props = HashMap::new();
+ props.insert(REGION_KEY.to_string(), VECTOR_REGION.to_string());
+ props.insert(SERVICE_KEY.to_string(), VECTOR_SERVICE.to_string());
+ let provider = SigV4AuthProvider::from_properties(&props)
+ .unwrap()
+ .with_clock(Arc::new(FixedClock(
+ UNIX_EPOCH + Duration::from_secs(VECTOR_UNIX_SECS),
+ )))
+ .with_credentials_provider(SharedCredentialsProvider::new(creds));
+
+ let ctx = RequestContext {
+ method: "GET".to_string(),
+ url: "https://example.amazonaws.com/".to_string(),
+ headers: HashMap::new(),
+ body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()),
+ };
+ let headers = provider.authenticate(&ctx).await.unwrap();
+
+ let token_header = headers
+ .iter()
+ .find(|(k, _)| k.eq_ignore_ascii_case("x-amz-security-token"))
+ .map(|(_, v)| v.as_str());
+ assert_eq!(
+ token_header,
+ Some("FakeSessionToken123"),
+ "session token must be included in output headers"
+ );
+
+ let auth = headers
+ .iter()
+ .find(|(k, _)| k.eq_ignore_ascii_case("authorization"))
+ .map(|(_, v)| v.as_str())
+ .unwrap();
+ assert!(
+ auth.contains("x-amz-security-token"),
+ "session token must be in SignedHeaders: {}",
+ auth
+ );
+ assert_eq!(
+ auth,
+ "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20150830/us-east-1/service/aws4_request, \
+ SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-security-token, \
+ Signature=d690ca83bd782879e22797e35b2e25958c0d19696a92cfb479b73428e4d950f4",
+ "session token signature must match botocore cross-verification"
+ );
+ }
+
+ #[tokio::test]
+ async fn explicit_credentials_via_properties_match_injected() {
+ let mut props = HashMap::new();
+ props.insert(REGION_KEY.to_string(), VECTOR_REGION.to_string());
+ props.insert(SERVICE_KEY.to_string(), VECTOR_SERVICE.to_string());
+ props.insert(ACCESS_KEY_ID_KEY.to_string(), VECTOR_ACCESS_KEY.to_string());
+ props.insert(
+ SECRET_ACCESS_KEY_KEY.to_string(),
+ VECTOR_SECRET_KEY.to_string(),
+ );
+ let provider = SigV4AuthProvider::from_properties(&props) // no injected provider: static keys are used
+ .unwrap()
+ .with_clock(Arc::new(FixedClock(
+ UNIX_EPOCH + Duration::from_secs(VECTOR_UNIX_SECS),
+ )));
+
+ let ctx = RequestContext {
+ method: "GET".to_string(),
+ url: "https://example.amazonaws.com/".to_string(),
+ headers: HashMap::new(),
+ body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()),
+ };
+ let headers = provider.authenticate(&ctx).await.unwrap();
+ let auth = headers
+ .iter()
+ .find(|(k, _)| k.eq_ignore_ascii_case("authorization"))
+ .map(|(_, v)| v.as_str())
+ .unwrap();
+ assert_eq!(auth, VECTOR_EXPECTED_AUTHORIZATION);
+ }
+
+ #[tokio::test]
+ async fn explicit_session_token_via_properties() {
+ let mut props = HashMap::new();
+ props.insert(REGION_KEY.to_string(), VECTOR_REGION.to_string());
+ props.insert(SERVICE_KEY.to_string(), VECTOR_SERVICE.to_string());
+ props.insert(ACCESS_KEY_ID_KEY.to_string(), VECTOR_ACCESS_KEY.to_string());
+ props.insert(
+ SECRET_ACCESS_KEY_KEY.to_string(),
+ VECTOR_SECRET_KEY.to_string(),
+ );
+ props.insert(
+ SESSION_TOKEN_KEY.to_string(),
+ "FakeSessionToken123".to_string(),
+ );
+ let provider = SigV4AuthProvider::from_properties(&props)
+ .unwrap()
+ .with_clock(Arc::new(FixedClock(
+ UNIX_EPOCH + Duration::from_secs(VECTOR_UNIX_SECS),
+ )));
+
+ let ctx = RequestContext {
+ method: "GET".to_string(),
+ url: "https://example.amazonaws.com/".to_string(),
+ headers: HashMap::new(),
+ body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()),
+ };
+ let headers = provider.authenticate(&ctx).await.unwrap();
+
+ let token = headers
+ .iter()
+ .find(|(k, _)| k.eq_ignore_ascii_case("x-amz-security-token"))
+ .map(|(_, v)| v.as_str());
+ assert_eq!(token, Some("FakeSessionToken123"));
+
+ let auth = headers
+ .iter()
+ .find(|(k, _)| k.eq_ignore_ascii_case("authorization"))
+ .map(|(_, v)| v.as_str())
+ .unwrap();
+ assert_eq!(
+ auth,
+ "AWS4-HMAC-SHA256 Credential=AKIDEXAMPLE/20150830/us-east-1/service/aws4_request, \
+ SignedHeaders=host;x-amz-content-sha256;x-amz-date;x-amz-security-token, \
+ Signature=d690ca83bd782879e22797e35b2e25958c0d19696a92cfb479b73428e4d950f4",
+ "session-token signature mismatch"
+ );
+ }
+
+ #[tokio::test]
+ async fn injected_provider_takes_precedence_over_static_credentials() {
+ let injected_creds = Credentials::new(
+ VECTOR_ACCESS_KEY,
+ VECTOR_SECRET_KEY,
+ None,
+ None,
+ "injected",
+ );
+ let mut props = HashMap::new();
+ props.insert(REGION_KEY.to_string(), VECTOR_REGION.to_string());
+ props.insert(SERVICE_KEY.to_string(), VECTOR_SERVICE.to_string());
+ props.insert(ACCESS_KEY_ID_KEY.to_string(), "WRONG_AK".to_string()); // would change the signature if used
+ props.insert(SECRET_ACCESS_KEY_KEY.to_string(), "WRONG_SK".to_string());
+ let provider = SigV4AuthProvider::from_properties(&props)
+ .unwrap()
+ .with_clock(Arc::new(FixedClock(
+ UNIX_EPOCH + Duration::from_secs(VECTOR_UNIX_SECS),
+ )))
+ .with_credentials_provider(SharedCredentialsProvider::new(injected_creds));
+
+ let ctx = RequestContext {
+ method: "GET".to_string(),
+ url: "https://example.amazonaws.com/".to_string(),
+ headers: HashMap::new(),
+ body_sha256: Some(crate::rest::EMPTY_BODY_SHA256.to_string()),
+ };
+ let headers = provider.authenticate(&ctx).await.unwrap();
+ let auth = headers
+ .iter()
+ .find(|(k, _)| k.eq_ignore_ascii_case("authorization"))
+ .map(|(_, v)| v.as_str())
+ .unwrap();
+ assert_eq!(auth, VECTOR_EXPECTED_AUTHORIZATION);
+ assert!(!auth.contains("WRONG_AK"));
+ }
+
+ #[test]
+ fn from_properties_rejects_partial_credentials() {
+ let mut props = HashMap::new();
+ props.insert(REGION_KEY.to_string(), "us-east-1".to_string());
+ props.insert(ACCESS_KEY_ID_KEY.to_string(), "AKID".to_string()); // secret key deliberately omitted
+ let err = SigV4AuthProvider::from_properties(&props).unwrap_err();
+ assert!(
+ err.to_string().contains(SECRET_ACCESS_KEY_KEY),
+ "error must mention missing key: {err}"
+ );
+ }
+}