diff --git a/.gitignore b/.gitignore index d71fb6d4..53bb56dd 100644 --- a/.gitignore +++ b/.gitignore @@ -78,3 +78,6 @@ uv.lock # VS Code .vscode/ + +# evals +evals/ \ No newline at end of file diff --git a/README.md b/README.md index 730d894b..43aacae7 100644 --- a/README.md +++ b/README.md @@ -201,6 +201,7 @@ insert_documents(documents, config=my_config) > [!TIP] > 📝 Documents can include metadata by passing keyword arguments to `Document.from_text()` or `Document.from_path()`. This metadata can later be used for filtering during retrieval. +> For list values, metadata is stored as-is (e.g. `domain=["open", "music"]`). You may also want to expand the document metadata before insertion: @@ -308,6 +309,14 @@ chunk_ids_hybrid, _ = hybrid_search( user_prompt, num_results=20, metadata_filter={"topic": "physics"}, config=my_config ) # Filter results to only include chunks from documents with topic="physics" (works with any search method) +# Multi-value filter in one field uses OR semantics: +chunk_ids_or, _ = hybrid_search( + user_prompt, + num_results=20, + metadata_filter={"domain": ["open", "music"]}, + config=my_config, +) # Returns chunks where domain includes "open" OR "music". + # Retrieve chunks from raglite import retrieve_chunks diff --git a/src/raglite/_database.py b/src/raglite/_database.py index 7ea5cc62..b6827f0a 100644 --- a/src/raglite/_database.py +++ b/src/raglite/_database.py @@ -48,7 +48,7 @@ MetadataJSON = JSON().with_variant(JSONB(), "postgresql") -def _adapt_metadata(metadata: Any) -> dict[str, MetadataValue | list[MetadataValue]]: +def _adapt_metadata(metadata: Any) -> dict[str, list[MetadataValue]]: """Adapt metadata to the format expected by the database.""" if not metadata: return {} diff --git a/src/raglite/_delete.py b/src/raglite/_delete.py index 2120854d..896217a2 100644 --- a/src/raglite/_delete.py +++ b/src/raglite/_delete.py @@ -1,13 +1,11 @@ """Delete documents from the database.""" -import json from contextlib import nullcontext from pathlib import Path from typing import Any, Literal from filelock import FileLock -from sqlalchemy import delete, func, text, update -from sqlalchemy.dialects.postgresql import JSONB +from sqlalchemy import delete, text, update from sqlalchemy.engine import make_url from sqlalchemy.orm import load_only from sqlalchemy.orm.attributes import flag_modified @@ -21,10 +19,10 @@ Eval, IndexMetadata, Metadata, - _adapt_metadata, create_database_engine, ) from raglite._insert import _aggregate_metadata_from_documents +from raglite._metadata_filter import build_metadata_filter_condition from raglite._typing import DocumentId @@ -49,17 +47,14 @@ def _get_documents_with_metadata( metadata_filter: dict[str, Any], session: Session ) -> list[DocumentId]: """Get document IDs matching a metadata filter.""" - metadata_filter = _adapt_metadata(metadata_filter) - - # Determine the filter condition based on the database engine - if session.get_bind().dialect.name == "postgresql": - condition = col(Document.metadata_).cast(JSONB).op("@>")(metadata_filter) # type: ignore[attr-defined] - else: - condition = func.json_contains( - col(Document.metadata_), func.json(json.dumps(metadata_filter)) - ) - - statement = select(Document.id).where(condition) + condition = build_metadata_filter_condition( + Document.metadata_, + metadata_filter, + dialect=session.get_bind().dialect.name, + ) + statement = select(Document.id) + if condition is not None: + statement = statement.where(condition) return list(session.exec(statement).all()) diff --git a/src/raglite/_embed.py b/src/raglite/_embed.py index f8a4266e..2e3185dd 100644 --- a/src/raglite/_embed.py +++ b/src/raglite/_embed.py @@ -1,6 +1,7 @@ """String embedder.""" from functools import partial +from threading import Lock from typing import Literal import numpy as np @@ -12,6 +13,8 @@ from raglite._litellm import LlamaCppPythonLLM from raglite._typing import FloatMatrix, IntVector +LLAMA_EMBED_LOCK = Lock() + def embed_strings_with_late_chunking( # noqa: C901,PLR0915 sentences: list[str], *, config: RAGLiteConfig | None = None @@ -116,7 +119,8 @@ def _create_segment( # Get the token embeddings of the entire segment, including preamble and content. segment_start_index, content_start_index, segment_end_index = segment segment_sentences = sentences[segment_start_index:segment_end_index] - segment_embedding = np.asarray(embedder.embed("".join(segment_sentences))) + with LLAMA_EMBED_LOCK: + segment_embedding = np.asarray(embedder.embed("".join(segment_sentences))) # Split the segment embeddings into embedding matrices per sentence using the largest # remainder method. segment_tokens = num_tokens[segment_start_index:segment_end_index] @@ -151,7 +155,8 @@ def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> Fl embedder = LlamaCppPythonLLM.llm( config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE ) - embeddings = np.asarray([np.mean(row, axis=0) for row in embedder.embed(string_batch)]) + with LLAMA_EMBED_LOCK: + embeddings = np.asarray([np.mean(row, axis=0) for row in embedder.embed(string_batch)]) else: # Use LiteLLM's API to embed the batch of strings. response = embedding(config.embedder, string_batch) diff --git a/src/raglite/_litellm.py b/src/raglite/_litellm.py index ab9efec3..f93c5c70 100644 --- a/src/raglite/_litellm.py +++ b/src/raglite/_litellm.py @@ -8,6 +8,7 @@ from collections.abc import AsyncIterator, Callable, Iterator from functools import cache from io import StringIO +from threading import Lock from typing import Any, ClassVar, cast import httpx @@ -35,6 +36,7 @@ # Reduce the logging level for LiteLLM, flashrank, and httpx. litellm.suppress_debug_info = True +litellm.drop_params = True # Drop unsupported parameters for models like GPT-5 os.environ["LITELLM_LOG"] = "WARNING" logging.getLogger("LiteLLM").setLevel(logging.WARNING) logging.getLogger("flashrank").setLevel(logging.WARNING) @@ -62,8 +64,9 @@ class LlamaCppPythonLLM(CustomLLM): ``` """ - # Create a lock to prevent concurrent access to llama-cpp-python models. + # Create locks to prevent concurrent access to llama-cpp-python models. streaming_lock: ClassVar[asyncio.Lock] = asyncio.Lock() + completion_lock: ClassVar[Lock] = Lock() # The set of supported OpenAI parameters is the intersection of [1] and [2]. Not included: # max_completion_tokens, stream_options, n, user, logprobs, top_logprobs, extra_headers. @@ -198,10 +201,11 @@ def completion( # noqa: PLR0913 llm = self.llm(model) llama_cpp_python_params = self._translate_openai_params(optional_params) llama_cpp_python_params = self._add_recommended_model_params(model, llama_cpp_python_params) - response = cast( - "llama_types.CreateChatCompletionResponse", - llm.create_chat_completion(messages=messages, **llama_cpp_python_params), - ) + with LlamaCppPythonLLM.completion_lock: + response = cast( + "llama_types.CreateChatCompletionResponse", + llm.create_chat_completion(messages=messages, **llama_cpp_python_params), + ) litellm_model_response: ModelResponse = convert_to_model_response_object( response_object=response, model_response_object=model_response, diff --git a/src/raglite/_metadata_filter.py b/src/raglite/_metadata_filter.py new file mode 100644 index 00000000..3ecc81a9 --- /dev/null +++ b/src/raglite/_metadata_filter.py @@ -0,0 +1,100 @@ +"""Helpers to build metadata filter conditions with consistent semantics.""" + +import json +from collections.abc import Mapping +from typing import Any + +from sqlalchemy import and_, false, or_ +from sqlalchemy.dialects.postgresql import JSONB +from sqlmodel import col, func + +from raglite._database import _adapt_metadata +from raglite._typing import MetadataFilter, MetadataValue + + +def build_metadata_filter_condition( + metadata_column: Any, + metadata_filter: MetadataFilter | None, + *, + dialect: str, +) -> Any: + """Build a SQLAlchemy condition for metadata filtering. + + A list of values within the same field uses OR semantics. + Different fields are combined with AND semantics. + """ + normalized_metadata_filter = _adapt_metadata(metadata_filter) + if not normalized_metadata_filter: + return None + + field_conditions: list[Any] = [] + for metadata_name, metadata_values in normalized_metadata_filter.items(): + if not metadata_values: + return false() # empty filters are considered unsatisfiable + + value_conditions: list[Any] = [] + for metadata_value in metadata_values: + single_value_filter = {metadata_name: [metadata_value]} + if dialect == "postgresql": + value_conditions.append( + col(metadata_column).cast(JSONB).op("@>")(single_value_filter) # type: ignore[attr-defined] + ) + elif dialect == "duckdb": + value_conditions.append( + func.json_contains( + col(metadata_column), func.json(json.dumps(single_value_filter)) + ) + ) + else: + error_message = f"Unsupported dialect: {dialect}." + raise ValueError(error_message) + field_conditions.append(or_(*value_conditions)) # combine values for the same field with OR + return and_(*field_conditions) # combine different fields with AND + + +def build_metadata_filter_sql( + metadata_filter: Mapping[str, list[MetadataValue] | MetadataValue] | None, + *, + dialect: str, +) -> tuple[str, dict[str, str]]: + """Build SQL fragment and bound parameters for metadata filtering. + + A list of values within the same field uses OR semantics. + Different fields are combined with AND semantics. + + Returns + ------- + sql_fragment : str + A SQL fragment to be included in the WHERE clause, with placeholders for parameters. + parameters : dict + A dictionary of parameter names and their corresponding JSON string values to be used in + the query execution + """ + normalized_metadata_filter = _adapt_metadata(metadata_filter) + if not normalized_metadata_filter: + return "", {} + + field_sql_conditions: list[str] = [] + parameters: dict[str, str] = {} + parameter_index = 0 + + for metadata_name, metadata_values in normalized_metadata_filter.items(): + if not metadata_values: + return " AND 1=0", {} + + value_sql_conditions: list[str] = [] + for metadata_value in metadata_values: + parameter_name = f"metadata_filter_{parameter_index}" + parameter_index += 1 + single_value_filter = json.dumps({metadata_name: [metadata_value]}) + parameters[parameter_name] = single_value_filter + + if dialect == "postgresql": + value_sql_conditions.append(f"metadata::jsonb @> :{parameter_name}") + elif dialect == "duckdb": + value_sql_conditions.append(f"json_contains(metadata, JSON(:{parameter_name}))") + else: + error_message = f"Unsupported dialect: {dialect}." + raise ValueError(error_message) + field_sql_conditions.append(f"({' OR '.join(value_sql_conditions)})") + return f" AND {' AND '.join(field_sql_conditions)}", parameters diff --git a/src/raglite/_search.py b/src/raglite/_search.py index c3306cfe..e1158d60 100644 --- a/src/raglite/_search.py +++ b/src/raglite/_search.py @@ -1,7 +1,6 @@ """Search and retrieve chunks.""" import contextlib -import json import logging import re import string @@ -12,7 +11,6 @@ import numpy as np from langdetect import LangDetectException, detect from pydantic import BaseModel, Field, create_model -from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import joinedload from sqlmodel import Session, and_, col, func, or_, select, text @@ -28,6 +26,7 @@ from raglite._embed import embed_strings from raglite._extract import extract_with_llm from raglite._insert import _get_database_metadata +from raglite._metadata_filter import build_metadata_filter_condition, build_metadata_filter_sql from raglite._typing import BasicSearchMethod, ChunkId, FloatVector, MetadataFilter, MetadataValue logger = logging.getLogger(__name__) @@ -80,18 +79,12 @@ def vector_search( else: def _apply_metadata_filter(query_builder: Any) -> Any: - if (dialect := session.get_bind().dialect.name) == "postgresql": - # Always cast to JSONB before using @> operator. - return query_builder.where( - col(Chunk.metadata_).cast(JSONB).op("@>")(metadata_filter) # type: ignore[attr-defined] - ) - if dialect == "duckdb": - return query_builder.where( - func.json_contains( - col(Chunk.metadata_), func.json(json.dumps(metadata_filter)) - ) - ) - return query_builder + condition = build_metadata_filter_condition( + Chunk.metadata_, + metadata_filter, + dialect=session.get_bind().dialect.name, + ) + return query_builder.where(condition) if condition is not None else query_builder # Count how many results match the given metadata filter. metadata_count_query = _apply_metadata_filter( @@ -188,9 +181,11 @@ def keyword_search( params = {"tsv_query": tsv_query, "limit": num_results} if metadata_filter: - # Always cast to JSONB before using @> operator - base_sql += " AND metadata::jsonb @> :metadata_filter" - params["metadata_filter"] = json.dumps(metadata_filter) + metadata_filter_sql, metadata_filter_params = build_metadata_filter_sql( + metadata_filter, dialect="postgresql" + ) + base_sql += metadata_filter_sql + params.update(metadata_filter_params) base_sql += """ ORDER BY score DESC @@ -211,8 +206,11 @@ def keyword_search( params = {"query": query, "limit": num_results} if metadata_filter: - base_sql += " AND json_contains(metadata, JSON(:metadata_filter))" - params["metadata_filter"] = json.dumps(metadata_filter) + metadata_filter_sql, metadata_filter_params = build_metadata_filter_sql( + metadata_filter, dialect="duckdb" + ) + base_sql += metadata_filter_sql + params.update(metadata_filter_params) base_sql += """ ) sq @@ -434,23 +432,32 @@ def search_and_rerank_chunk_spans( # noqa: PLR0913 SELF_QUERY_PROMPT = """ -You are an assistant that extracts metadata filters from user queries to help search a knowledge base. +You are an expert assistant that extracts metadata filters from user queries to help search a knowledge base. Instructions: -1. For each metadata field, only populate it if the query explicitly and unambiguously mentions a specific allowed value. -2. If the query is general, ambiguous, or does not mention a field, set it to None. -3. Do NOT infer values from common knowledge or context. -4. For each field, return ONLY the numeric ID(s) from the allowed options below. Do NOT return labels or text. +1. For each metadata field, populate it only if the query can be reasonably mapped to one or more allowed values. +2. If a field clearly matches multiple allowed values, return all corresponding numeric IDs for that field. +3. If the query is general, ambiguous, or does not clearly map to any allowed value for a field, return None for that field. +4. For each populated field, return only the numeric ID(s) defined in the allowed options. Do not return text labels. Do not infer or invent IDs. 5. Output your answer as a JSON object with field names as keys and lists of IDs or None as values. -Example: +Examples: + Allowed options: - category: {0: "Technology", 1: "Health", 2: "Finance"} - region: {0: "Europe", 1: "Asia", 2: "Americas"} -Query: "Show me the latest news in Technology from Asia." -Output: -{"category": [0], "region": [1]} +Query: "Show me the latest news in Technology from Asia and Europe." +Reasoning: The query explicitly mentions "Technology", which matches category ID 0. It also explicitly mentions "Asia" (region ID 1) and "Europe" (region ID 0). Both fields have clear matches. +Output: {"category": [0], "region": [1, 0]} + +Query: "Show me Health articles." +Reasoning: The query explicitly mentions "Health", which matches category ID 1. The query does not mention any region, so region is None. +Output: {"category": [1], "region": null} + +Query: "What is the price of a Bugatti Chiron?" +Reasoning: The query does not mention any category ("Technology", "Health", or "Finance") or any region ("Europe", "Asia", or "Americas"). No fields match. +Output: {"category": null, "region": null} """.strip() @@ -495,19 +502,24 @@ def _self_query( return_type=metadata_filter_model, user_prompt=query, config=config, - temperature=0, + temperature=0.0, # Deterministic output if the model allows ) except ValueError as e: logger.debug("Failed to extract metadata filter: %s", e) return {} else: - metadata_filter = result.model_dump(exclude_none=True) - # Convert from field IDs to actual metadata values. - for field, value_ids in metadata_filter.items(): - if field in field_ids_mapping: - metadata_filter[field] = [ - field_ids_mapping[field].get(value_id) + # Convert the extracted metadata filter from IDs back to actual metadata values. + metadata_filter_by_id = result.model_dump(exclude_none=True) + metadata_filter: dict[str, list[MetadataValue] | MetadataValue] = {} + for field, value_ids in metadata_filter_by_id.items(): + value_mapping = field_ids_mapping.get(field, {}) + metadata_values = list( + dict.fromkeys( + value_mapping[value_id] for value_id in value_ids - if value_id in field_ids_mapping[field] - ] + if value_id in value_mapping # handle potential out-of-range IDs gracefully + ) + ) + if metadata_values: + metadata_filter[field] = metadata_values return metadata_filter diff --git a/tests/conftest.py b/tests/conftest.py index a2d4b3e4..6af7c74b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -131,3 +131,13 @@ def raglite_test_config(database: str, llm: str, embedder: str) -> RAGLiteConfig metadata: dict[str, Any] = {"type": "Paper", "topic": "Physics", "author": "Albert Einstein"} insert_documents([Document.from_path(doc_path, **metadata)], config=db_config) return db_config + + +@pytest.fixture +def isolated_raglite_test_config(raglite_test_config: RAGLiteConfig) -> RAGLiteConfig: + """Create an isolated in-memory config for tests that insert custom datasets.""" + return RAGLiteConfig( + db_url="duckdb:///:memory:", + llm=raglite_test_config.llm, + embedder=raglite_test_config.embedder, + ) diff --git a/tests/test_delete.py b/tests/test_delete.py index f3ce122a..cb5256ea 100644 --- a/tests/test_delete.py +++ b/tests/test_delete.py @@ -2,6 +2,7 @@ import dataclasses from typing import Any +from uuid import uuid4 import numpy as np from sqlmodel import Session, SQLModel @@ -77,6 +78,61 @@ def test_delete_by_metadata(raglite_test_config: RAGLiteConfig) -> None: assert state_after.keys() == state_before.keys() +def test_delete_by_metadata_multiple_values_match_any( + isolated_raglite_test_config: RAGLiteConfig, +) -> None: + """Delete documents when a multi-value metadata filter matches multiple values.""" + unique_topic = f"delete-topic-{uuid4().hex}" + document_open = Document.from_text( + "Open domain document for delete test.", + id=f"{unique_topic}-open", + domain=["open"], + topic=unique_topic, + ) + document_sports = Document.from_text( + "Sports domain document for delete test.", + id=f"{unique_topic}-sports", + domain=["sports"], + topic=unique_topic, + ) + document_open_and_movie = Document.from_text( + "Open and movie domain document for delete test.", + id=f"{unique_topic}-open-movie", + domain=["open", "movie"], # test that multiple values are handled correctly + topic=unique_topic, + ) + document_movie = Document.from_text( + "Movie domain document for delete test.", + id=f"{unique_topic}-movie", + domain=["movie"], # test that single values in a list are handled correctly + topic=unique_topic, + ) + insert_documents( + [document_open, document_sports, document_open_and_movie, document_movie], + config=isolated_raglite_test_config, + ) + + deleted_count = delete_documents_by_metadata( + {"topic": unique_topic, "domain": ["sports", "movie"]}, + config=isolated_raglite_test_config, + ) + assert ( + deleted_count == 3 # noqa: PLR2004 + ), "Expected OR metadata filter to delete the documents with domain='sports' or domain='movie'." + + with Session(create_database_engine(isolated_raglite_test_config)) as session: + remaining_document = session.get(Document, document_open.id) + assert remaining_document is not None + deleted_document = session.get(Document, document_sports.id) + assert deleted_document is None + deleted_document = session.get(Document, document_open_and_movie.id) + assert deleted_document is None + deleted_document = session.get(Document, document_movie.id) + assert deleted_document is None + + delete_documents([document_open.id], config=isolated_raglite_test_config) + + def test_delete_with_multivector_disabled(raglite_test_config: RAGLiteConfig) -> None: """Test document deletion with vector_search_multivector disabled.""" modified_test_config = dataclasses.replace(raglite_test_config, vector_search_multivector=False) diff --git a/tests/test_insert.py b/tests/test_insert.py index 21245d49..a6628cbf 100644 --- a/tests/test_insert.py +++ b/tests/test_insert.py @@ -59,37 +59,27 @@ def test_insert(raglite_test_config: RAGLiteConfig) -> None: def test_insert_reuse_document_instance( - raglite_test_config: RAGLiteConfig, + isolated_raglite_test_config: RAGLiteConfig, ) -> None: """Reuse a document instance across calls without errors.""" - isolated_config = RAGLiteConfig( - db_url="duckdb:///:memory:", - llm=raglite_test_config.llm, - embedder=raglite_test_config.embedder, - ) doc = Document.from_text( content="Reuse instance test content.", url="http://example.com/reuse", filename="reuse_instance.html", id="reuse-instance-test", ) - insert_documents([doc], config=isolated_config) - insert_documents([doc], config=isolated_config) + insert_documents([doc], config=isolated_raglite_test_config) + insert_documents([doc], config=isolated_raglite_test_config) - with Session(create_database_engine(isolated_config)) as session: + with Session(create_database_engine(isolated_raglite_test_config)) as session: documents = session.exec(select(Document).where(Document.id == "reuse-instance-test")).all() assert len(documents) == 1 def test_insert_duplicate_documents_with_same_id( - raglite_test_config: RAGLiteConfig, + isolated_raglite_test_config: RAGLiteConfig, ) -> None: """De-duplicate incoming documents that share the same id.""" - isolated_config = RAGLiteConfig( - db_url="duckdb:///:memory:", - llm=raglite_test_config.llm, - embedder=raglite_test_config.embedder, - ) doc1 = Document.from_text( content="Duplicate id test content.", url="http://example.com/duplicate", @@ -102,8 +92,8 @@ def test_insert_duplicate_documents_with_same_id( filename="duplicate.html", id="duplicate-id-test", ) - insert_documents([doc1, doc2], config=isolated_config) + insert_documents([doc1, doc2], config=isolated_raglite_test_config) - with Session(create_database_engine(isolated_config)) as session: + with Session(create_database_engine(isolated_raglite_test_config)) as session: documents = session.exec(select(Document).where(Document.id == "duplicate-id-test")).all() assert len(documents) == 1 diff --git a/tests/test_metadata_filter.py b/tests/test_metadata_filter.py new file mode 100644 index 00000000..e0f154e2 --- /dev/null +++ b/tests/test_metadata_filter.py @@ -0,0 +1,54 @@ +"""Test metadata filter SQL builders.""" + +import json + +from sqlalchemy import column +from sqlalchemy.dialects import postgresql + +from raglite._metadata_filter import build_metadata_filter_condition, build_metadata_filter_sql + + +def test_build_metadata_filter_condition_postgresql_multiple_values() -> None: + """Build PostgreSQL conditions with OR inside a field and AND across fields.""" + condition = build_metadata_filter_condition( + column("metadata"), + {"domain": ["open", "music"], "topic": "physics"}, + dialect="postgresql", + ) + assert condition is not None + + compiled_condition = condition.compile(dialect=postgresql.dialect()) # type: ignore[no-untyped-call] + compiled_sql = str(compiled_condition) + assert " OR " in compiled_sql + assert " AND " in compiled_sql + assert "CAST(metadata AS JSONB) @>" in compiled_sql + + compiled_parameters = list(compiled_condition.params.values()) + assert {"domain": ["open"]} in compiled_parameters + assert {"domain": ["music"]} in compiled_parameters + assert {"topic": ["physics"]} in compiled_parameters + + +def test_build_metadata_filter_sql_postgresql_multiple_values() -> None: + """Build PostgreSQL raw SQL fragments with OR semantics for list values.""" + sql_fragment, parameters = build_metadata_filter_sql( + {"domain": ["open", "music"], "topic": "physics"}, + dialect="postgresql", + ) + + assert "metadata::jsonb @> :metadata_filter_0" in sql_fragment + assert "metadata::jsonb @> :metadata_filter_1" in sql_fragment + assert "metadata::jsonb @> :metadata_filter_2" in sql_fragment + assert " OR " in sql_fragment + assert " AND " in sql_fragment + + assert json.loads(parameters["metadata_filter_0"]) == {"domain": ["open"]} + assert json.loads(parameters["metadata_filter_1"]) == {"domain": ["music"]} + assert json.loads(parameters["metadata_filter_2"]) == {"topic": ["physics"]} + + +def test_build_metadata_filter_sql_postgresql_empty_value_is_unsatisfiable() -> None: + """Treat empty value lists as unsatisfiable for PostgreSQL SQL generation.""" + sql_fragment, parameters = build_metadata_filter_sql({"domain": []}, dialect="postgresql") + assert sql_fragment == " AND 1=0" + assert parameters == {} diff --git a/tests/test_search.py b/tests/test_search.py index cadb88bd..10d5b029 100644 --- a/tests/test_search.py +++ b/tests/test_search.py @@ -1,13 +1,16 @@ """Test RAGLite's search functionality.""" from typing import Any +from uuid import uuid4 import pytest from raglite import ( Document, RAGLiteConfig, + delete_documents, hybrid_search, + insert_documents, keyword_search, retrieve_chunk_spans, retrieve_chunks, @@ -127,6 +130,130 @@ def test_search_metadata_filter( ) +def test_search_metadata_filter_multiple_values_match_any( + isolated_raglite_test_config: RAGLiteConfig, search_method: BasicSearchMethod +) -> None: + """Match any value when a metadata field filter contains multiple values.""" + topic = f"or-filter-topic-{uuid4().hex}" + document_ids = [f"{topic}-open", f"{topic}-music", f"{topic}-sports"] + documents = [ + Document.from_text( + "A short note on piano and orchestra for open-domain retrieval.", + id=document_ids[0], + domain="open", + topic=topic, + ), + Document.from_text( + "A short note on piano and orchestra for music-domain retrieval.", + id=document_ids[1], + domain="music", + topic=topic, + ), + Document.from_text( + "A short note on football and basketball for sports-domain retrieval.", + id=document_ids[2], + domain="sports", + topic=topic, + ), + ] + insert_documents(documents, config=isolated_raglite_test_config) + + try: + query = "piano orchestra retrieval" + metadata_filter: MetadataFilter = {"topic": topic, "domain": ["open", "music"]} + chunk_ids, _ = search_method( + query, + num_results=5, + metadata_filter=metadata_filter, + config=isolated_raglite_test_config, + ) + assert chunk_ids, ( + "Expected OR metadata filter to match documents with domain='open' and 'music'." + ) + + chunks = retrieve_chunks(chunk_ids, config=isolated_raglite_test_config) + for chunk in chunks: + assert chunk.metadata_.get("topic") == [topic] + assert any( + domain in {"open", "music"} for domain in chunk.metadata_.get("domain", []) + ), f"Expected OR match on domain values, got {chunk.metadata_.get('domain')}" + finally: + delete_documents(document_ids, config=isolated_raglite_test_config) + + +def test_search_metadata_filter_matches_documents_with_list_metadata_values( + isolated_raglite_test_config: RAGLiteConfig, search_method: BasicSearchMethod +) -> None: + """Match documents when metadata values are stored as lists with multiple elements.""" + topic = f"list-domain-topic-{uuid4().hex}" + document_ids = [f"{topic}-open-news", f"{topic}-music-arts", f"{topic}-sports-health"] + documents = [ + Document.from_text( + "Open and news coverage about orchestras and pianos in retrieval systems.", + id=document_ids[0], + domain=["open", "news"], + topic=topic, + ), + Document.from_text( + "Music and arts commentary about orchestras and pianos in retrieval systems.", + id=document_ids[1], + domain=["music", "arts"], + topic=topic, + ), + Document.from_text( + "Sports and health analysis about football training and recovery metrics.", + id=document_ids[2], + domain=["sports", "health"], + topic=topic, + ), + ] + insert_documents(documents, config=isolated_raglite_test_config) + + try: + query = "piano orchestra retrieval systems" + metadata_filter: MetadataFilter = {"topic": topic, "domain": ["open", "music"]} + chunk_ids, _ = search_method( + query, + num_results=10, + metadata_filter=metadata_filter, + config=isolated_raglite_test_config, + ) + assert chunk_ids, "Expected results for list metadata values in domain." + + chunks = retrieve_chunks(chunk_ids, config=isolated_raglite_test_config) + matched_document_ids = {chunk.document.id for chunk in chunks} + assert matched_document_ids.issubset(set(document_ids[:2])), ( + "Expected only documents whose domain list overlaps ['open', 'music'], " + f"got {matched_document_ids}." + ) + assert matched_document_ids == set(document_ids[:2]), ( + f"Expected both list-based domain matches to be returned, got {matched_document_ids}." + ) + finally: + delete_documents(document_ids, config=isolated_raglite_test_config) + + +def test_self_query_deduplicates_and_keeps_multiple_values( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Keep all relevant values in one metadata field for self-query output.""" + from raglite._database import Metadata + + class _Result: + def model_dump(self, *, exclude_none: bool) -> dict[str, list[int]]: + assert exclude_none + return {"domain": [0, 1, 0, 99]} + + monkeypatch.setattr( + "raglite._search._get_database_metadata", + lambda **_: [Metadata(name="domain", values=["open", "music", "sports"])], + ) + monkeypatch.setattr("raglite._search.extract_with_llm", lambda **_: _Result()) + + metadata_filter = _self_query("Find open and music results.") + assert metadata_filter == {"domain": ["open", "music"]} + + def test_self_query(raglite_test_config: RAGLiteConfig) -> None: """Test self-query functionality that extracts metadata filters from queries.""" # Test 1: Query that should extract "Physics" from topic field