diff --git a/.gitignore b/.gitignore index 8771e40..2722ea9 100644 --- a/.gitignore +++ b/.gitignore @@ -54,6 +54,7 @@ Thumbs.db # Docs docs/_build/ +docs/superpowers/ claude_notes.md claude.md diff --git a/README.md b/README.md index 90a0e4c..e0e1512 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,8 @@ This package follows [LangChain's official integration guidelines](https://pytho - **LangChain & LangGraph Integration**: First-class support for modern LLM frameworks - **Vector Store Agnostic**: Compatible with Pinecone, FAISS, Weaviate, Chroma, and more -- **Post-Filter Authorization**: Filters retrieved documents based on SpiceDB permissions +- **Post-Filter Authorization**: Retrieve semantically, then filter by SpiceDB permissions +- **Pre-Filter Authorization**: Fetch authorized resource IDs via LookupResources first, then run a filtered vector store search — ideal when users have access to a small fraction of a large corpus - **Efficient Bulk Permissions**: Uses SpiceDB's native bulk API for optimal performance - **Observable**: Returns detailed metrics about authorization decisions - **Type-Safe**: Full type hints for better IDE support @@ -19,24 +20,30 @@ This package follows [LangChain's official integration guidelines](https://pytho Most RAG pipelines retrieve documents without considering user permissions. This package solves that by: 1. **Post-retrieval filtering**: Retrieve best semantic matches first, then filter by permissions -2. **Deterministic authorization**: Every document is checked against SpiceDB before being used -3. **Framework integration**: Native LangChain and LangGraph components for seamless integration -4. **Vector store agnostic**: Not tied to any specific vector database +2. **Pre-retrieval filtering**: Fetch all resource IDs the user can access via SpiceDB's `LookupResources` API, then run a filtered vector store search — no unauthorized documents are retrieved +3. **Deterministic authorization**: Every document is checked against SpiceDB before being used +4. **Framework integration**: Native LangChain and LangGraph components for seamless integration +5. **Vector store agnostic**: Not tied to any specific vector database ## Which Component Should I Use? Choose the right component based on your use case: -| Component | Use Case | Best For | -|-----------|----------|----------| -| **SpiceDBRetriever** | Simple RAG pipelines | Drop-in replacement for any retriever. Wraps your existing retriever with authorization. | -| **SpiceDBAuthFilter** | LangChain chains with middleware | Filtering documents in the middle of a chain. Reusable across different users via `config`. | -| **create_auth_node** | LangGraph workflows | Complex multi-step workflows with state management. Provides authorization metrics in state. | -| **SpiceDBPermissionTool** | Agentic workflows | Give agents the ability to check permissions before taking actions. | -| **SpiceDBBulkPermissionTool** | Agentic workflows (batch) | Same as above but for checking multiple resources at once. | +| Component | Pattern | Use Case | +|-----------|---------|----------| +| **SpiceDBRetriever** | Post-filter | Simple RAG pipelines. Drop-in replacement for any retriever. Retrieves semantically then filters by permission. Best when users have broad access. | +| **SpiceDBPreFilterRetriever** | Pre-filter | Use when users can only access a small fraction of a large corpus. Fetches authorized IDs from SpiceDB first, then runs a filtered vector search. Requires a `filter_factory` matching your vector store's filter syntax. | +| **SpiceDBAuthFilter** | Post-filter | LangChain chains with middleware. Filtering documents in the middle of a chain. Reusable across different users via `config`. | +| **create_auth_node** | Post-filter | LangGraph workflows. Complex multi-step workflows with state management. Provides authorization metrics in state. | +| **SpiceDBPermissionTool** | Check | Agentic workflows. Give agents the ability to check a single permission before taking actions. | +| **SpiceDBBulkPermissionTool** | Check | Agentic workflows (batch). Same as above but for checking multiple resources at once. | ### Quick Decision Guide +**Pre-filter vs Post-filter:** +- Use **post-filter** (`SpiceDBRetriever`, `SpiceDBAuthFilter`) when users have access to most documents. Semantic search quality is highest because all documents are candidates. +- Use **pre-filter** (`SpiceDBPreFilterRetriever`) when users have access to a small subset of a large corpus. Avoids retrieving unauthorized content entirely. Requires knowing your vector store's filter syntax. + **Use SpiceDBRetriever if:** - You have a simple RAG pipeline - You always use the same user per retriever instance and you don't need to reuse the retriever across different users @@ -91,6 +98,20 @@ agent = create_agent(llm, tools, system_prompt="You are a helpful assistant.") # Agent can check "Can user alice delete document 123?" and explain the result ``` +**Pattern 5: SpiceDBPreFilterRetriever (pre-filter)** +```python +retriever = SpiceDBPreFilterRetriever( + vector_store=vector_store, + filter_factory=lambda ids: {"filter": {"article_id": {"$in": ids}}}, + subject_id="tim", + resource_type="article", + permission="view", + spicedb_endpoint="localhost:50051", + spicedb_token="sometoken", +) +chain = retriever | prompt | llm +``` + ## Installation ```bash diff --git a/examples/pre_filter_example.py b/examples/pre_filter_example.py new file mode 100644 index 0000000..85220a0 --- /dev/null +++ b/examples/pre_filter_example.py @@ -0,0 +1,215 @@ +""" +SpiceDBPreFilterRetriever Example - Pre-Filter Authorization RAG Pipeline + +This example demonstrates how to use SpiceDBPreFilterRetriever to pre-filter +vector store searches using SpiceDB's LookupResources API. + +Unlike SpiceDBRetriever (post-filter), this approach: +1. Calls SpiceDB first to get all resource IDs the user can access +2. Passes those IDs as a filter into the vector store search +3. Only retrieves documents the user is authorized to see + +Use this pattern when users have access to a small fraction of a large corpus. +""" + +import asyncio +import os +from typing import List +from dotenv import load_dotenv +from langchain_core.documents import Document +from langchain_core.prompts import ChatPromptTemplate +from langchain_core.output_parsers import StrOutputParser +from langchain_core.runnables import RunnableParallel, RunnablePassthrough +from langchain_openai import ChatOpenAI + +from langchain_spicedb import SpiceDBPreFilterRetriever + +load_dotenv() + + +class MockVectorStore: + """ + Mock vector store simulating Pinecone with metadata filter support. + + In a real application, replace this with: + from langchain_pinecone import PineconeVectorStore + knowledge = PineconeVectorStore.from_existing_index( + index_name="my-index", + embedding=OpenAIEmbeddings(...), + ) + """ + + async def asimilarity_search( + self, query: str, k: int = 4, filter: dict = None + ) -> List[Document]: + """Return mock documents, filtered by article_id if filter is provided.""" + all_docs = [ + Document( + page_content="Python is a high-level programming language known for simplicity.", + metadata={"article_id": "123", "title": "Python Basics"}, + ), + Document( + page_content="JavaScript is the language of the web.", + metadata={"article_id": "456", "title": "JavaScript Guide"}, + ), + Document( + page_content="Machine learning models can be trained on large datasets.", + metadata={"article_id": "789", "title": "ML Introduction"}, + ), + Document( + page_content="SpiceDB is a database for fine-grained authorization.", + metadata={"article_id": "101", "title": "SpiceDB Overview"}, + ), + ] + + if filter and "article_id" in filter: + authorized = filter["article_id"].get("$in", []) + return [d for d in all_docs if d.metadata["article_id"] in authorized][:k] + + return all_docs[:k] + + +async def main(): + print("=" * 80) + print("SpiceDBPreFilterRetriever Example - Pre-Filter Authorization RAG") + print("=" * 80) + print() + + spicedb_endpoint = os.getenv("SPICEDB_ENDPOINT", "localhost:50051") + spicedb_token = os.getenv("SPICEDB_TOKEN", "somerandomkeyhere") + subject_id = os.getenv("SUBJECT_ID", "tim") + + print("Configuration:") + print(f" SpiceDB Endpoint: {spicedb_endpoint}") + print(f" Subject (User): {subject_id}") + print(" Resource Type: article") + print(" Permission: view") + print() + print("Pattern: LookupResources → authorized IDs → vector store filter → docs") + print() + + vector_store = MockVectorStore() + + # SpiceDBPreFilterRetriever: + # 1. Calls LookupResources(subject=tim, permission=view, resource_type=article) + # 2. Gets back e.g. ["123", "101"] (articles tim can view) + # 3. Calls filter_factory(["123", "101"]) → {"filter": {"article_id": {"$in": ["123", "101"]}}} + # 4. Calls vector_store.asimilarity_search(query, k=4, filter=...) + # 5. Returns only authorized + semantically relevant documents + retriever = SpiceDBPreFilterRetriever( + vector_store=vector_store, + filter_factory=lambda ids: {"filter": {"article_id": {"$in": ids}}}, + subject_id=subject_id, + resource_type="article", + permission="view", + spicedb_endpoint=spicedb_endpoint, + spicedb_token=spicedb_token, + k=4, + ) + + llm = ChatOpenAI(api_key=os.getenv("OPENAI_API_KEY"), model="gpt-4o-mini", temperature=0) + + prompt = ChatPromptTemplate.from_messages([ + ( + "system", + "Answer questions based only on the provided context. " + "If the context doesn't contain enough information, say so.", + ), + ("human", "Question: {question}\n\nContext:\n{context}"), + ]) + + def format_docs(docs): + if not docs: + return "No authorized documents found." + return "\n\n".join( + f"Document {i + 1}:\n{doc.page_content}" for i, doc in enumerate(docs) + ) + + rag_chain = ( + RunnableParallel({ + "context": retriever | format_docs, + "question": RunnablePassthrough(), + }) + | prompt + | llm + | StrOutputParser() + ) + + query = "Tell me about SpiceDB" + print(f"Query: {query}") + print("-" * 40) + + print(f"\nDocuments after pre-filter (user: {subject_id}):") + authorized_docs = await retriever.ainvoke(query) + if authorized_docs: + for doc in authorized_docs: + print(f" ✓ {doc.metadata['title']} (ID: {doc.metadata['article_id']})") + else: + print(" ✗ No authorized documents") + + print("\nLLM Answer:") + answer = await rag_chain.ainvoke(query) + print(answer) + print() + print("=" * 80) + + +async def demo_without_openai(): + """Demo showing document pre-filtering without requiring an LLM.""" + print("=" * 80) + print("SpiceDBPreFilterRetriever Demo - Pre-Filter Only") + print("=" * 80) + print() + + spicedb_endpoint = os.getenv("SPICEDB_ENDPOINT", "localhost:50051") + spicedb_token = os.getenv("SPICEDB_TOKEN", "somerandomkeyhere") + subject_id = os.getenv("SUBJECT_ID", "tim") + + print(f"Looking up authorized articles for user: {subject_id}") + print() + + vector_store = MockVectorStore() + + retriever = SpiceDBPreFilterRetriever( + vector_store=vector_store, + filter_factory=lambda ids: {"filter": {"article_id": {"$in": ids}}}, + subject_id=subject_id, + resource_type="article", + permission="view", + spicedb_endpoint=spicedb_endpoint, + spicedb_token=spicedb_token, + ) + + query = "programming languages" + docs = await retriever.ainvoke(query) + + print(f"Documents returned for query '{query}':") + if docs: + for doc in docs: + print(f" ✓ {doc.metadata['title']} (ID: {doc.metadata['article_id']})") + else: + print(" ✗ No authorized documents found") + print() + + +if __name__ == "__main__": + print() + print("Prerequisites:") + print("1. SpiceDB running on localhost:50051 (or set SPICEDB_ENDPOINT)") + print("2. Set SPICEDB_TOKEN environment variable") + print("3. SpiceDB schema with 'article' resource type and 'view' permission") + print("4. Create relationships: zed relationship create article:123 viewer user:tim") + print() + print("Optional:") + print("5. Set OPENAI_API_KEY for full RAG demo") + print("6. Set SUBJECT_ID to test different users (default: tim)") + print() + print("=" * 80) + print() + + if os.getenv("OPENAI_API_KEY"): + asyncio.run(main()) + else: + print("OpenAI API key not found. Running pre-filter demo without LLM...") + print() + asyncio.run(demo_without_openai()) diff --git a/langchain_spicedb/__init__.py b/langchain_spicedb/__init__.py index 4a58bb3..8283290 100644 --- a/langchain_spicedb/__init__.py +++ b/langchain_spicedb/__init__.py @@ -46,7 +46,7 @@ # Import LangChain standard components (retrievers, tools) try: - from .retrievers import SpiceDBRetriever # noqa: F401 + from .retrievers import SpiceDBRetriever, SpiceDBPreFilterRetriever # noqa: F401 _has_retrievers = True except ImportError: @@ -74,7 +74,7 @@ __all__.extend(["SpiceDBAuthFilter", "SpiceDBAuthLambda"]) if _has_retrievers: - __all__.extend(["SpiceDBRetriever"]) + __all__.extend(["SpiceDBRetriever", "SpiceDBPreFilterRetriever"]) if _has_tools: __all__.extend(["SpiceDBPermissionTool", "SpiceDBBulkPermissionTool"]) diff --git a/langchain_spicedb/core.py b/langchain_spicedb/core.py index e421bd3..c854f01 100644 --- a/langchain_spicedb/core.py +++ b/langchain_spicedb/core.py @@ -15,6 +15,7 @@ CheckPermissionResponse, CheckBulkPermissionsRequest, CheckBulkPermissionsRequestItem, + LookupResourcesRequest, ObjectReference, SubjectReference, ) @@ -78,7 +79,6 @@ def __init__( subject_type: str = "user", permission: str = "view", resource_id_key: str = "resource_id", - fail_open: bool = False, use_tls: bool = False, ): """ @@ -91,7 +91,6 @@ def __init__( subject_type: SpiceDB subject type (e.g., "user", "service") permission: Permission to check (e.g., "view", "edit") resource_id_key: Key in document metadata containing resource ID - fail_open: If True, allow access on errors; if False, deny on errors use_tls: Whether to use TLS for SpiceDB connection """ self.spicedb_endpoint = spicedb_endpoint @@ -100,13 +99,10 @@ def __init__( self.subject_type = subject_type self.permission = permission self.resource_id_key = resource_id_key - self.fail_open = fail_open self.use_tls = use_tls - # Initialize SpiceDB client if use_tls: from grpcutil import bearer_token_credentials - credentials = bearer_token_credentials(spicedb_token) else: credentials = insecure_bearer_token_credentials(spicedb_token) @@ -221,30 +217,72 @@ async def check_permission( resource_type = resource_type or self.resource_type permission = permission or self.permission - try: - # Await the async gRPC call - response = await self.client.CheckPermission( - CheckPermissionRequest( - resource=ObjectReference( - object_type=resource_type, - object_id=str(resource_id), - ), - permission=permission, - subject=SubjectReference( - object=ObjectReference( - object_type=subject_type, - object_id=subject_id, - ), + response = await self.client.CheckPermission( + CheckPermissionRequest( + resource=ObjectReference( + object_type=resource_type, + object_id=str(resource_id), + ), + permission=permission, + subject=SubjectReference( + object=ObjectReference( + object_type=subject_type, + object_id=subject_id, ), - ) + ), ) + ) + + return response.permissionship == CheckPermissionResponse.PERMISSIONSHIP_HAS_PERMISSION + + async def lookup_resources( + self, + subject_id: str, + subject_type: Optional[str] = None, + resource_type: Optional[str] = None, + permission: Optional[str] = None, + ) -> List[str]: + """ + Look up all resources a subject has permission to access. - return response.permissionship == CheckPermissionResponse.PERMISSIONSHIP_HAS_PERMISSION + Uses SpiceDB's LookupResources streaming API to return all resource + IDs the subject is authorized to access. Results can be used to + pre-filter a vector store search. - except Exception: - if self.fail_open: - return True - return False + Args: + subject_id: ID of the subject (user) + subject_type: Override default subject type + resource_type: Override default resource type + permission: Override default permission + + Returns: + List of authorized resource IDs. Empty list means no access. + + Raises: + Exception: Propagates any SpiceDB communication errors to the caller. + """ + subject_type = subject_type or self.subject_type + resource_type = resource_type or self.resource_type + permission = permission or self.permission + + resp = self.client.LookupResources( + LookupResourcesRequest( + subject=SubjectReference( + object=ObjectReference( + object_type=subject_type, + object_id=subject_id, + ) + ), + permission=permission, + resource_object_type=resource_type, + ) + ) + + authorized_ids = [] + async for response in resp: + authorized_ids.append(response.resource_object_id) + + return authorized_ids async def _batch_check_permissions( self, @@ -274,47 +312,36 @@ async def _batch_check_permissions( if not resource_ids: return [] - try: - # Create bulk permission check items - items = [ - CheckBulkPermissionsRequestItem( - resource=ObjectReference( - object_type=resource_type, - object_id=str(resource_id), + items = [ + CheckBulkPermissionsRequestItem( + resource=ObjectReference( + object_type=resource_type, + object_id=str(resource_id), + ), + permission=permission, + subject=SubjectReference( + object=ObjectReference( + object_type=subject_type, + object_id=subject_id, ), - permission=permission, - subject=SubjectReference( - object=ObjectReference( - object_type=subject_type, - object_id=subject_id, - ), - ), - ) - for resource_id in resource_ids - ] - - # Make single bulk permission check request (await async gRPC call) - response = await self.client.CheckBulkPermissions( - CheckBulkPermissionsRequest(items=items) + ), ) + for resource_id in resource_ids + ] - # Extract authorized resource IDs from response - authorized_ids = [] - for i, pair in enumerate(response.pairs): - if ( - pair.item.permissionship - == CheckPermissionResponse.PERMISSIONSHIP_HAS_PERMISSION - ): - authorized_ids.append(resource_ids[i]) - - return authorized_ids - - except Exception: - # Fail-open: return all IDs if configured to do so - if self.fail_open: - return resource_ids - # Fail-closed: return empty list on error - return [] + response = await self.client.CheckBulkPermissions( + CheckBulkPermissionsRequest(items=items) + ) + + authorized_ids = [] + for i, pair in enumerate(response.pairs): + if ( + pair.item.permissionship + == CheckPermissionResponse.PERMISSIONSHIP_HAS_PERMISSION + ): + authorized_ids.append(resource_ids[i]) + + return authorized_ids def _get_resource_id(self, doc: Any) -> Optional[str]: """ diff --git a/langchain_spicedb/langchain_runnable.py b/langchain_spicedb/langchain_runnable.py index 60d2b7b..12fb43d 100644 --- a/langchain_spicedb/langchain_runnable.py +++ b/langchain_spicedb/langchain_runnable.py @@ -47,7 +47,6 @@ def __init__( subject_type: str = "user", permission: str = "view", resource_id_key: str = "resource_id", - fail_open: bool = False, use_tls: bool = False, subject_id: Optional[str] = None, return_metrics: bool = False, @@ -62,7 +61,6 @@ def __init__( subject_type: SpiceDB subject type (e.g., "user") permission: Permission to check (e.g., "view", "edit") resource_id_key: Key in document metadata containing resource ID - fail_open: If True, allow access on errors use_tls: Whether to use TLS for SpiceDB connection subject_id: Default subject ID (can be overridden in config) return_metrics: If True, return AuthorizationResult instead of just docs @@ -75,7 +73,6 @@ def __init__( subject_type=subject_type, permission=permission, resource_id_key=resource_id_key, - fail_open=fail_open, use_tls=use_tls, ) self.default_subject_id = subject_id @@ -163,7 +160,6 @@ def with_config(self, subject_id: Optional[str] = None, **kwargs) -> "SpiceDBAut subject_type=self.authorizer.subject_type, permission=self.authorizer.permission, resource_id_key=self.authorizer.resource_id_key, - fail_open=self.authorizer.fail_open, use_tls=self.authorizer.use_tls, subject_id=subject_id or self.default_subject_id, return_metrics=self.return_metrics, @@ -205,7 +201,6 @@ def __init__( permission: str = "view", resource_id_key: str = "resource_id", subject_id: str = None, - fail_open: bool = False, ): """Initialize the authorization lambda.""" self.authorizer = SpiceDBAuthorizer( @@ -215,7 +210,6 @@ def __init__( subject_type=subject_type, permission=permission, resource_id_key=resource_id_key, - fail_open=fail_open, ) self.subject_id = subject_id diff --git a/langchain_spicedb/langgraph_node.py b/langchain_spicedb/langgraph_node.py index debde1c..ae3b8d6 100644 --- a/langchain_spicedb/langgraph_node.py +++ b/langchain_spicedb/langgraph_node.py @@ -52,7 +52,6 @@ def create_auth_node( subject_type: str = "user", permission: str = "view", resource_id_key: str = "resource_id", - fail_open: bool = False, use_tls: bool = False, ): """ @@ -68,7 +67,6 @@ def create_auth_node( subject_type: SpiceDB subject type permission: Permission to check resource_id_key: Key in document metadata containing resource ID - fail_open: If True, allow access on errors use_tls: Whether to use TLS for SpiceDB connection Returns: @@ -94,7 +92,6 @@ def create_auth_node( subject_type=subject_type, permission=permission, resource_id_key=resource_id_key, - fail_open=fail_open, use_tls=use_tls, ) @@ -174,7 +171,6 @@ def __init__( subject_type: str = "user", permission: str = "view", resource_id_key: str = "resource_id", - fail_open: bool = False, use_tls: bool = False, state_keys: Optional[Dict[str, str]] = None, ): @@ -188,7 +184,6 @@ def __init__( subject_type: SpiceDB subject type permission: Permission to check resource_id_key: Key in document metadata containing resource ID - fail_open: If True, allow access on errors use_tls: Whether to use TLS for SpiceDB connection state_keys: Custom state key mappings (e.g., {"documents": "docs"}) """ @@ -199,7 +194,6 @@ def __init__( subject_type=subject_type, permission=permission, resource_id_key=resource_id_key, - fail_open=fail_open, use_tls=use_tls, ) diff --git a/langchain_spicedb/retrievers.py b/langchain_spicedb/retrievers.py index 1049687..a133855 100644 --- a/langchain_spicedb/retrievers.py +++ b/langchain_spicedb/retrievers.py @@ -5,7 +5,8 @@ existing retrievers with SpiceDB authorization. """ -from typing import List, Optional, Any +from typing import List, Optional, Any, Callable, Dict +from pydantic import ConfigDict from langchain_core.retrievers import BaseRetriever from langchain_core.documents import Document from langchain_core.callbacks import CallbackManagerForRetrieverRun @@ -72,9 +73,6 @@ class SpiceDBRetriever(BaseRetriever): resource_id_key: str = "resource_id" """Key in document metadata containing resource ID.""" - fail_open: bool = False - """If True, allow access on errors; if False, deny on errors.""" - use_tls: bool = False """Whether to use TLS for SpiceDB connection.""" @@ -91,7 +89,6 @@ def __init__( subject_type: str = "user", permission: str = "view", resource_id_key: str = "resource_id", - fail_open: bool = False, use_tls: bool = False, **kwargs: Any, ): @@ -107,7 +104,6 @@ def __init__( subject_type: SpiceDB subject type permission: Permission to check resource_id_key: Key in document metadata containing resource ID - fail_open: If True, allow access on errors use_tls: Whether to use TLS for SpiceDB connection **kwargs: Additional arguments passed to BaseRetriever """ @@ -121,7 +117,6 @@ def __init__( subject_type=subject_type, permission=permission, resource_id_key=resource_id_key, - fail_open=fail_open, use_tls=use_tls, **kwargs, ) @@ -134,7 +129,6 @@ def __init__( subject_type=self.subject_type, permission=self.permission, resource_id_key=self.resource_id_key, - fail_open=self.fail_open, use_tls=self.use_tls, ) @@ -212,3 +206,185 @@ def with_config( updates = {"subject_id": subject_id or self.subject_id} updates.update(kwargs) return self.model_copy(update=updates) + + +class SpiceDBPreFilterRetriever(BaseRetriever): + """ + LangChain retriever that pre-filters using SpiceDB's LookupResources API. + + This retriever follows the pre-filter authorization pattern: + 1. Call SpiceDB LookupResources to get all resource IDs the user can access + 2. Pass those IDs through filter_factory to build vector store search kwargs + 3. Run similarity_search with the filter applied + 4. Return only semantically relevant documents the user is authorized to see + + Use this over SpiceDBRetriever when users have access to a small fraction + of a large corpus and you want to avoid retrieving unauthorized content. + + Example: + >>> from langchain_spicedb import SpiceDBPreFilterRetriever + >>> + >>> retriever = SpiceDBPreFilterRetriever( + ... vector_store=knowledge, + ... filter_factory=lambda ids: {"filter": {"article_id": {"$in": ids}}}, + ... subject_id="tim", + ... resource_type="article", + ... permission="view", + ... spicedb_endpoint="localhost:50051", + ... spicedb_token="sometoken", + ... ) + >>> + >>> chain = retriever | prompt | llm + >>> answer = await chain.ainvoke("What is SpiceDB?") + """ + + model_config = ConfigDict(arbitrary_types_allowed=True) + + vector_store: Any + """The vector store to search after pre-filtering by authorized IDs.""" + + filter_factory: Callable[[List[str]], Dict[str, Any]] + """ + Required. Converts a list of authorized resource IDs into search_kwargs + for the vector store's similarity_search call. + + Example for Pinecone: + lambda ids: {"filter": {"article_id": {"$in": ids}}} + + Example for Chroma: + lambda ids: {"where": {"article_id": {"$in": ids}}} + """ + + subject_id: str + """User ID to look up authorized resources for.""" + + spicedb_endpoint: str = "localhost:50051" + """SpiceDB server address.""" + + spicedb_token: str = "sometoken" + """Pre-shared key for SpiceDB authentication.""" + + resource_type: str = "document" + """SpiceDB resource type (e.g., 'document', 'article').""" + + subject_type: str = "user" + """SpiceDB subject type (e.g., 'user').""" + + permission: str = "view" + """Permission to check (e.g., 'view', 'edit').""" + + use_tls: bool = False + """Whether to use TLS for SpiceDB connection.""" + + k: int = 4 + """Number of documents to retrieve from the vector store.""" + + _authorizer: Optional[SpiceDBAuthorizer] = None + """Internal SpiceDB authorizer instance.""" + + def __init__( + self, + vector_store: Any, + filter_factory: Callable[[List[str]], Dict[str, Any]], + subject_id: str, + spicedb_endpoint: str = "localhost:50051", + spicedb_token: str = "sometoken", + resource_type: str = "document", + subject_type: str = "user", + permission: str = "view", + use_tls: bool = False, + k: int = 4, + **kwargs: Any, + ): + """ + Initialize SpiceDBPreFilterRetriever. + + Args: + vector_store: The vector store to search (must support asimilarity_search) + filter_factory: Converts authorized IDs to vector store search_kwargs + subject_id: User ID to look up authorized resources for + spicedb_endpoint: SpiceDB server address + spicedb_token: Pre-shared key for SpiceDB authentication + resource_type: SpiceDB resource type + subject_type: SpiceDB subject type + permission: Permission to check + use_tls: Whether to use TLS for SpiceDB connection + k: Number of documents to return from vector store + """ + super().__init__( + vector_store=vector_store, + filter_factory=filter_factory, + subject_id=subject_id, + spicedb_endpoint=spicedb_endpoint, + spicedb_token=spicedb_token, + resource_type=resource_type, + subject_type=subject_type, + permission=permission, + use_tls=use_tls, + k=k, + **kwargs, + ) + + self._authorizer = SpiceDBAuthorizer( + spicedb_endpoint=self.spicedb_endpoint, + spicedb_token=self.spicedb_token, + resource_type=self.resource_type, + subject_type=self.subject_type, + permission=self.permission, + use_tls=self.use_tls, + ) + + def _get_relevant_documents( + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, + ) -> List[Document]: + """Synchronous retrieval — delegates to async implementation.""" + import asyncio + return asyncio.run(self._aget_relevant_documents(query, run_manager=run_manager)) + + async def _aget_relevant_documents( + self, + query: str, + *, + run_manager: CallbackManagerForRetrieverRun, + ) -> List[Document]: + """ + Pre-filter then retrieve documents. + + 1. LookupResources → authorized_ids + 2. filter_factory(authorized_ids) → search_kwargs + 3. vector_store.asimilarity_search(query, k, **search_kwargs) → docs + """ + authorized_ids = await self._authorizer.lookup_resources( + subject_id=self.subject_id, + ) + + if not authorized_ids: + return [] + + search_kwargs = self.filter_factory(authorized_ids) + docs = await self.vector_store.asimilarity_search( + query, k=self.k, **search_kwargs + ) + return docs + + def with_config( + self, + subject_id: Optional[str] = None, + **kwargs: Any, + ) -> "SpiceDBPreFilterRetriever": + """ + Create a new retriever with an updated subject_id. + + Args: + subject_id: New subject ID to use + **kwargs: Additional fields to update + + Returns: + New SpiceDBPreFilterRetriever instance + """ + updates = {"subject_id": subject_id if subject_id is not None else self.subject_id} + updates.update(kwargs) + return self.model_copy(update=updates) diff --git a/langchain_spicedb/tools.py b/langchain_spicedb/tools.py index a4eb97c..4387fd5 100644 --- a/langchain_spicedb/tools.py +++ b/langchain_spicedb/tools.py @@ -25,7 +25,6 @@ class SpiceDBAuthTool(BaseTool): ) resource_type: str = Field(default="document", description="SpiceDB resource type") subject_type: str = Field(default="user", description="SpiceDB subject type") - fail_open: bool = Field(default=False, description="If True, allow access on errors") use_tls: bool = Field(default=False, description="Whether to use TLS for SpiceDB connection") _authorizer: Optional[SpiceDBAuthorizer] = None @@ -36,7 +35,6 @@ def __init__( spicedb_token: str = "sometoken", resource_type: str = "document", subject_type: str = "user", - fail_open: bool = False, use_tls: bool = False, **kwargs: Any, ): @@ -48,7 +46,6 @@ def __init__( spicedb_token: Pre-shared key for SpiceDB authentication resource_type: SpiceDB resource type (e.g., 'document', 'article') subject_type: SpiceDB subject type (e.g., 'user') - fail_open: If True, allow access on errors use_tls: Whether to use TLS for SpiceDB connection **kwargs: Additional arguments passed to BaseTool """ @@ -58,7 +55,6 @@ def __init__( spicedb_token=spicedb_token, resource_type=resource_type, subject_type=subject_type, - fail_open=fail_open, use_tls=use_tls, **kwargs, ) @@ -70,7 +66,6 @@ def __init__( resource_type=self.resource_type, subject_type=self.subject_type, permission="view", # Default, can be overridden per call - fail_open=self.fail_open, use_tls=self.use_tls, ) diff --git a/tests/integration_tests/test_retrievers.py b/tests/integration_tests/test_retrievers.py index 91dfe1b..3dc746b 100644 --- a/tests/integration_tests/test_retrievers.py +++ b/tests/integration_tests/test_retrievers.py @@ -219,23 +219,3 @@ def test_retriever_with_tls(self, base_retriever, spicedb_config): docs = retriever.invoke("test query") assert isinstance(docs, list) - @pytest.mark.skipif( - not os.getenv("SPICEDB_ENDPOINT"), - reason="SPICEDB_ENDPOINT not set - skipping integration test", - ) - def test_retriever_fail_open_behavior(self, base_retriever, spicedb_config): - """Test retriever with fail_open=True.""" - retriever = SpiceDBRetriever( - base_retriever=base_retriever, - subject_id="tim", - subject_type="user", - resource_type="article", - resource_id_key="article_id", - permission="view", - fail_open=True, # Allow documents on SpiceDB errors - **spicedb_config, - ) - - # Should not raise errors even if SpiceDB has issues - docs = retriever.invoke("test query") - assert isinstance(docs, list) diff --git a/tests/integration_tests/test_tools.py b/tests/integration_tests/test_tools.py index 1308b53..b133b86 100644 --- a/tests/integration_tests/test_tools.py +++ b/tests/integration_tests/test_tools.py @@ -166,23 +166,6 @@ def test_tool_with_tls(self, spicedb_config): result = tool._run(subject_id="tim", resource_id="123", permission="view") assert result in ["true", "false"] - @pytest.mark.skipif( - not os.getenv("SPICEDB_ENDPOINT"), - reason="SPICEDB_ENDPOINT not set - skipping integration test", - ) - def test_tool_with_fail_open(self, spicedb_config): - """Test tool with fail_open=True.""" - tool = SpiceDBPermissionTool( - subject_type="user", - resource_type="article", - fail_open=True, - **spicedb_config, - ) - - # Should not raise errors even if SpiceDB has issues - result = tool._run(subject_id="tim", resource_id="123", permission="view") - assert result in ["true", "false"] - class TestSpiceDBBulkPermissionToolIntegration: """Integration tests for SpiceDBBulkPermissionTool with real SpiceDB.""" @@ -348,20 +331,3 @@ def test_bulk_tool_with_tls(self, spicedb_config): result = tool._run(subject_id="tim", resource_ids="123,456", permission="view") assert isinstance(result, str) - - @pytest.mark.skipif( - not os.getenv("SPICEDB_ENDPOINT"), - reason="SPICEDB_ENDPOINT not set - skipping integration test", - ) - def test_bulk_tool_with_fail_open(self, spicedb_config): - """Test bulk tool with fail_open=True.""" - tool = SpiceDBBulkPermissionTool( - subject_type="user", - resource_type="article", - fail_open=True, - **spicedb_config, - ) - - # Should not raise errors even if SpiceDB has issues - result = tool._run(subject_id="tim", resource_ids="123,456", permission="view") - assert isinstance(result, str) diff --git a/tests/unit_tests/test_retrievers.py b/tests/unit_tests/test_retrievers.py index 3aeaacc..73afdc3 100644 --- a/tests/unit_tests/test_retrievers.py +++ b/tests/unit_tests/test_retrievers.py @@ -9,7 +9,8 @@ from langchain_core.documents import Document from langchain_core.retrievers import BaseRetriever -from langchain_spicedb import SpiceDBRetriever +from langchain_spicedb import SpiceDBRetriever, SpiceDBPreFilterRetriever +from langchain_spicedb.core import SpiceDBAuthorizer class MockRetriever(BaseRetriever): @@ -243,16 +244,12 @@ def test_retriever_parameters_passed_to_authorizer(self, mock_base_retriever, mo resource_type="document", resource_id_key="doc_id", permission="edit", - fail_open=True, use_tls=True, ) - # Verify authorizer was initialized mock_authorizer.assert_called_once() - # Verify key parameters were passed (handle both positional and keyword args) call_args = mock_authorizer.call_args - # call_args is a tuple of (args, kwargs) or call object if hasattr(call_args, "kwargs"): call_kwargs = call_args.kwargs else: @@ -263,3 +260,265 @@ def test_retriever_parameters_passed_to_authorizer(self, mock_base_retriever, mo assert call_kwargs.get("subject_type") == "user" assert call_kwargs.get("resource_type") == "document" assert call_kwargs.get("permission") == "edit" + + +class TestSpiceDBRetrieverErrorHandling: + """Tests that SpiceDB errors propagate as exceptions.""" + + @pytest.fixture + def mock_base_retriever(self): + return MockRetriever() + + @pytest.mark.asyncio + async def test_spicedb_error_raises_exception(self, mock_base_retriever): + """SpiceDB authorization failure must raise, not silently pass.""" + with patch("langchain_spicedb.retrievers.SpiceDBAuthorizer") as mock_auth_class: + mock_instance = AsyncMock() + mock_instance.filter_documents = AsyncMock( + side_effect=Exception("SpiceDB connection refused") + ) + mock_auth_class.return_value = mock_instance + + retriever = SpiceDBRetriever( + base_retriever=mock_base_retriever, + spicedb_endpoint="localhost:50051", + spicedb_token="test_token", + subject_id="alice", + subject_type="user", + resource_type="article", + resource_id_key="article_id", + permission="view", + ) + + with pytest.raises(Exception, match="SpiceDB connection refused"): + await retriever.ainvoke("test query") + + +class TestSpiceDBAuthorizerLookupResources: + """Unit tests for SpiceDBAuthorizer.lookup_resources.""" + + @pytest.mark.asyncio + async def test_lookup_resources_returns_authorized_ids(self): + """lookup_resources streams responses and returns resource IDs.""" + with patch("langchain_spicedb.core.Client") as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + async def mock_stream(): + responses = [ + Mock(resource_object_id="123"), + Mock(resource_object_id="456"), + ] + for r in responses: + yield r + + mock_client.LookupResources = Mock(return_value=mock_stream()) + + authorizer = SpiceDBAuthorizer( + spicedb_endpoint="localhost:50051", + spicedb_token="test_token", + resource_type="article", + subject_type="user", + permission="view", + ) + + result = await authorizer.lookup_resources(subject_id="tim") + + assert result == ["123", "456"] + mock_client.LookupResources.assert_called_once() + call_args = mock_client.LookupResources.call_args + request = call_args[0][0] # first positional arg + assert request.permission == "view" + assert request.resource_object_type == "article" + assert request.subject.object.object_type == "user" + assert request.subject.object.object_id == "tim" + + @pytest.mark.asyncio + async def test_lookup_resources_returns_empty_when_no_access(self): + """lookup_resources returns [] when the user has no authorized resources.""" + with patch("langchain_spicedb.core.Client") as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + async def mock_stream(): + return + yield # makes this an async generator; never reached + + mock_client.LookupResources = Mock(return_value=mock_stream()) + + authorizer = SpiceDBAuthorizer( + spicedb_endpoint="localhost:50051", + spicedb_token="test_token", + resource_type="article", + ) + + result = await authorizer.lookup_resources(subject_id="tim") + + assert result == [] + mock_client.LookupResources.assert_called_once() + + @pytest.mark.asyncio + async def test_lookup_resources_propagates_error(self): + """lookup_resources raises when SpiceDB call fails.""" + with patch("langchain_spicedb.core.Client") as mock_client_class: + mock_client = Mock() + mock_client_class.return_value = mock_client + + async def mock_stream_error(): + raise Exception("SpiceDB timeout") + yield # makes this an async generator; never reached + + mock_client.LookupResources = Mock(return_value=mock_stream_error()) + + authorizer = SpiceDBAuthorizer( + spicedb_endpoint="localhost:50051", + spicedb_token="test_token", + resource_type="article", + ) + + with pytest.raises(Exception, match="SpiceDB timeout"): + await authorizer.lookup_resources(subject_id="tim") + + +class TestSpiceDBPreFilterRetrieverUnit: + """Unit tests for SpiceDBPreFilterRetriever.""" + + @pytest.fixture + def mock_vector_store(self): + """Mock vector store that returns one document.""" + mock = AsyncMock() + mock.asimilarity_search = AsyncMock(return_value=[ + Document(page_content="Doc 1", metadata={"article_id": "123"}), + ]) + return mock + + @pytest.fixture + def mock_authorizer(self): + """Mock authorizer returning two authorized IDs.""" + with patch("langchain_spicedb.retrievers.SpiceDBAuthorizer") as mock: + mock_instance = AsyncMock() + mock_instance.lookup_resources = AsyncMock(return_value=["123", "456"]) + mock.return_value = mock_instance + yield mock + + def test_pre_filter_retriever_initialization(self, mock_vector_store, mock_authorizer): + """SpiceDBPreFilterRetriever stores all config correctly.""" + retriever = SpiceDBPreFilterRetriever( + vector_store=mock_vector_store, + filter_factory=lambda ids: {"filter": {"article_id": {"$in": ids}}}, + subject_id="tim", + resource_type="article", + permission="view", + spicedb_endpoint="localhost:50051", + spicedb_token="test_token", + ) + assert retriever.subject_id == "tim" + assert retriever.resource_type == "article" + assert retriever.permission == "view" + assert retriever.k == 4 + + def test_pre_filter_retriever_is_base_retriever(self): + """SpiceDBPreFilterRetriever is a LangChain BaseRetriever.""" + from langchain_core.retrievers import BaseRetriever + assert issubclass(SpiceDBPreFilterRetriever, BaseRetriever) + + @pytest.mark.asyncio + async def test_lookup_called_then_vector_store_searched(self, mock_vector_store, mock_authorizer): + """Retriever calls lookup_resources first, then similarity_search with filter.""" + filter_factory = lambda ids: {"filter": {"article_id": {"$in": ids}}} + retriever = SpiceDBPreFilterRetriever( + vector_store=mock_vector_store, + filter_factory=filter_factory, + subject_id="tim", + resource_type="article", + permission="view", + spicedb_endpoint="localhost:50051", + spicedb_token="test_token", + ) + + docs = await retriever.ainvoke("test query") + + mock_authorizer.return_value.lookup_resources.assert_called_once_with(subject_id="tim") + mock_vector_store.asimilarity_search.assert_called_once_with( + "test query", + k=4, + filter={"article_id": {"$in": ["123", "456"]}}, + ) + assert len(docs) == 1 + assert docs[0].metadata["article_id"] == "123" + + @pytest.mark.asyncio + async def test_empty_authorized_ids_skips_vector_store(self, mock_vector_store, mock_authorizer): + """When no resources are authorized, returns [] without querying the vector store.""" + mock_authorizer.return_value.lookup_resources = AsyncMock(return_value=[]) + retriever = SpiceDBPreFilterRetriever( + vector_store=mock_vector_store, + filter_factory=lambda ids: {"filter": {"article_id": {"$in": ids}}}, + subject_id="tim", + resource_type="article", + permission="view", + spicedb_endpoint="localhost:50051", + spicedb_token="test_token", + ) + + docs = await retriever.ainvoke("test query") + + assert docs == [] + mock_vector_store.asimilarity_search.assert_not_called() + + @pytest.mark.asyncio + async def test_spicedb_error_propagates(self, mock_vector_store, mock_authorizer): + """SpiceDB errors are raised to the caller, never swallowed.""" + mock_authorizer.return_value.lookup_resources = AsyncMock( + side_effect=Exception("SpiceDB unavailable") + ) + retriever = SpiceDBPreFilterRetriever( + vector_store=mock_vector_store, + filter_factory=lambda ids: {"filter": {"article_id": {"$in": ids}}}, + subject_id="tim", + resource_type="article", + permission="view", + spicedb_endpoint="localhost:50051", + spicedb_token="test_token", + ) + + with pytest.raises(Exception, match="SpiceDB unavailable"): + await retriever.ainvoke("test query") + + @pytest.mark.asyncio + async def test_custom_k_forwarded_to_similarity_search(self, mock_vector_store, mock_authorizer): + """k parameter is forwarded to asimilarity_search.""" + retriever = SpiceDBPreFilterRetriever( + vector_store=mock_vector_store, + filter_factory=lambda ids: {"filter": {"article_id": {"$in": ids}}}, + subject_id="tim", + resource_type="article", + permission="view", + spicedb_endpoint="localhost:50051", + spicedb_token="test_token", + k=10, + ) + assert retriever.k == 10 + + await retriever.ainvoke("test query") + + mock_vector_store.asimilarity_search.assert_called_once_with( + "test query", + k=10, + filter={"article_id": {"$in": ["123", "456"]}}, + ) + + def test_with_config_returns_new_instance_with_updated_subject(self, mock_vector_store, mock_authorizer): + """with_config creates a new retriever with updated subject_id.""" + retriever = SpiceDBPreFilterRetriever( + vector_store=mock_vector_store, + filter_factory=lambda ids: {"filter": {"article_id": {"$in": ids}}}, + subject_id="tim", + resource_type="article", + permission="view", + spicedb_endpoint="localhost:50051", + spicedb_token="test_token", + ) + new_retriever = retriever.with_config(subject_id="alice") + assert new_retriever.subject_id == "alice" + assert retriever.subject_id == "tim" # original unchanged diff --git a/tests/unit_tests/test_tools.py b/tests/unit_tests/test_tools.py index 2e1f3c0..f0b9426 100644 --- a/tests/unit_tests/test_tools.py +++ b/tests/unit_tests/test_tools.py @@ -160,7 +160,6 @@ def test_tool_parameters_passed_to_authorizer(self, mock_authorizer): spicedb_token="custom_token", subject_type="service", resource_type="document", - fail_open=True, use_tls=True, )