Skip to content
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,6 @@ uv.lock

# VS Code
.vscode/

# evals
evals/
9 changes: 9 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,7 @@ insert_documents(documents, config=my_config)

> [!TIP]
> 📝 Documents can include metadata by passing keyword arguments to `Document.from_text()` or `Document.from_path()`. This metadata can later be used for filtering during retrieval.
> For list values, metadata is stored as-is (e.g. `domain=["open", "music"]`).

You may also want to expand the document metadata before insertion:

Expand Down Expand Up @@ -308,6 +309,14 @@ chunk_ids_hybrid, _ = hybrid_search(
user_prompt, num_results=20, metadata_filter={"topic": "physics"}, config=my_config
) # Filter results to only include chunks from documents with topic="physics" (works with any search method)

# Multi-value filter in one field uses OR semantics:
chunk_ids_or, _ = hybrid_search(
user_prompt,
num_results=20,
metadata_filter={"domain": ["open", "music"]},
config=my_config,
) # Returns chunks where domain includes "open" OR "music".

# Retrieve chunks
from raglite import retrieve_chunks

Expand Down
2 changes: 1 addition & 1 deletion src/raglite/_database.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@
MetadataJSON = JSON().with_variant(JSONB(), "postgresql")


def _adapt_metadata(metadata: Any) -> dict[str, MetadataValue | list[MetadataValue]]:
def _adapt_metadata(metadata: Any) -> dict[str, list[MetadataValue]]:
"""Adapt metadata to the format expected by the database."""
if not metadata:
return {}
Expand Down
25 changes: 10 additions & 15 deletions src/raglite/_delete.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
"""Delete documents from the database."""

import json
from contextlib import nullcontext
from pathlib import Path
from typing import Any, Literal

from filelock import FileLock
from sqlalchemy import delete, func, text, update
from sqlalchemy.dialects.postgresql import JSONB
from sqlalchemy import delete, text, update
from sqlalchemy.engine import make_url
from sqlalchemy.orm import load_only
from sqlalchemy.orm.attributes import flag_modified
Expand All @@ -21,10 +19,10 @@
Eval,
IndexMetadata,
Metadata,
_adapt_metadata,
create_database_engine,
)
from raglite._insert import _aggregate_metadata_from_documents
from raglite._metadata_filter import build_metadata_filter_condition
from raglite._typing import DocumentId


Expand All @@ -49,17 +47,14 @@ def _get_documents_with_metadata(
metadata_filter: dict[str, Any], session: Session
) -> list[DocumentId]:
"""Get document IDs matching a metadata filter."""
metadata_filter = _adapt_metadata(metadata_filter)

# Determine the filter condition based on the database engine
if session.get_bind().dialect.name == "postgresql":
condition = col(Document.metadata_).cast(JSONB).op("@>")(metadata_filter) # type: ignore[attr-defined]
else:
condition = func.json_contains(
col(Document.metadata_), func.json(json.dumps(metadata_filter))
)

statement = select(Document.id).where(condition)
condition = build_metadata_filter_condition(
Document.metadata_,
metadata_filter,
dialect=session.get_bind().dialect.name,
)
statement = select(Document.id)
if condition is not None:
statement = statement.where(condition)

return list(session.exec(statement).all())

Expand Down
9 changes: 7 additions & 2 deletions src/raglite/_embed.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""String embedder."""

from functools import partial
from threading import Lock
from typing import Literal

import numpy as np
Expand All @@ -12,6 +13,8 @@
from raglite._litellm import LlamaCppPythonLLM
from raglite._typing import FloatMatrix, IntVector

LLAMA_EMBED_LOCK = Lock()


def embed_strings_with_late_chunking( # noqa: C901,PLR0915
sentences: list[str], *, config: RAGLiteConfig | None = None
Expand Down Expand Up @@ -116,7 +119,8 @@ def _create_segment(
# Get the token embeddings of the entire segment, including preamble and content.
segment_start_index, content_start_index, segment_end_index = segment
segment_sentences = sentences[segment_start_index:segment_end_index]
segment_embedding = np.asarray(embedder.embed("".join(segment_sentences)))
with LLAMA_EMBED_LOCK:
segment_embedding = np.asarray(embedder.embed("".join(segment_sentences)))
# Split the segment embeddings into embedding matrices per sentence using the largest
# remainder method.
segment_tokens = num_tokens[segment_start_index:segment_end_index]
Expand Down Expand Up @@ -151,7 +155,8 @@ def _embed_string_batch(string_batch: list[str], *, config: RAGLiteConfig) -> Fl
embedder = LlamaCppPythonLLM.llm(
config.embedder, embedding=True, pooling_type=LLAMA_POOLING_TYPE_NONE
)
embeddings = np.asarray([np.mean(row, axis=0) for row in embedder.embed(string_batch)])
with LLAMA_EMBED_LOCK:
embeddings = np.asarray([np.mean(row, axis=0) for row in embedder.embed(string_batch)])
else:
# Use LiteLLM's API to embed the batch of strings.
response = embedding(config.embedder, string_batch)
Expand Down
14 changes: 9 additions & 5 deletions src/raglite/_litellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from collections.abc import AsyncIterator, Callable, Iterator
from functools import cache
from io import StringIO
from threading import Lock
from typing import Any, ClassVar, cast

import httpx
Expand Down Expand Up @@ -35,6 +36,7 @@

# Reduce the logging level for LiteLLM, flashrank, and httpx.
litellm.suppress_debug_info = True
litellm.drop_params = True # Drop unsupported parameters for models like GPT-5
os.environ["LITELLM_LOG"] = "WARNING"
logging.getLogger("LiteLLM").setLevel(logging.WARNING)
logging.getLogger("flashrank").setLevel(logging.WARNING)
Expand Down Expand Up @@ -62,8 +64,9 @@ class LlamaCppPythonLLM(CustomLLM):
```
"""

# Create a lock to prevent concurrent access to llama-cpp-python models.
# Create locks to prevent concurrent access to llama-cpp-python models.
streaming_lock: ClassVar[asyncio.Lock] = asyncio.Lock()
completion_lock: ClassVar[Lock] = Lock()

# The set of supported OpenAI parameters is the intersection of [1] and [2]. Not included:
# max_completion_tokens, stream_options, n, user, logprobs, top_logprobs, extra_headers.
Expand Down Expand Up @@ -198,10 +201,11 @@ def completion( # noqa: PLR0913
llm = self.llm(model)
llama_cpp_python_params = self._translate_openai_params(optional_params)
llama_cpp_python_params = self._add_recommended_model_params(model, llama_cpp_python_params)
response = cast(
"llama_types.CreateChatCompletionResponse",
llm.create_chat_completion(messages=messages, **llama_cpp_python_params),
)
with LlamaCppPythonLLM.completion_lock:
response = cast(
"llama_types.CreateChatCompletionResponse",
llm.create_chat_completion(messages=messages, **llama_cpp_python_params),
)
litellm_model_response: ModelResponse = convert_to_model_response_object(
response_object=response,
model_response_object=model_response,
Expand Down
100 changes: 100 additions & 0 deletions src/raglite/_metadata_filter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""Helpers to build metadata filter conditions with consistent semantics."""

import json
from collections.abc import Mapping
from typing import Any

from sqlalchemy import and_, false, or_
from sqlalchemy.dialects.postgresql import JSONB
from sqlmodel import col, func

from raglite._database import _adapt_metadata
from raglite._typing import MetadataFilter, MetadataValue


def build_metadata_filter_condition(
metadata_column: Any,
metadata_filter: MetadataFilter | None,
*,
dialect: str,
) -> Any:
"""Build a SQLAlchemy condition for metadata filtering.

A list of values within the same field uses OR semantics.
Different fields are combined with AND semantics.
"""
normalized_metadata_filter = _adapt_metadata(metadata_filter)
if not normalized_metadata_filter:
return None

field_conditions: list[Any] = []
for metadata_name, metadata_values in normalized_metadata_filter.items():
if not metadata_values:
return false() # empty filters are considered unsatisfiable

value_conditions: list[Any] = []
for metadata_value in metadata_values:
single_value_filter = {metadata_name: [metadata_value]}
if dialect == "postgresql":
value_conditions.append(
col(metadata_column).cast(JSONB).op("@>")(single_value_filter) # type: ignore[attr-defined]
)
elif dialect == "duckdb":
value_conditions.append(
func.json_contains(
col(metadata_column), func.json(json.dumps(single_value_filter))
)
)
else:
error_message = f"Unsupported dialect: {dialect}."
raise ValueError(error_message)
field_conditions.append(or_(*value_conditions)) # combine values for the same field with OR
return and_(*field_conditions) # combine different fields with AND


def build_metadata_filter_sql(
metadata_filter: Mapping[str, list[MetadataValue] | MetadataValue] | None,
*,
dialect: str,
) -> tuple[str, dict[str, str]]:
"""Build SQL fragment and bound parameters for metadata filtering.

A list of values within the same field uses OR semantics.
Different fields are combined with AND semantics.

Returns
-------
sql_fragment : str
A SQL fragment to be included in the WHERE clause, with placeholders for parameters.
parameters : dict
A dictionary of parameter names and their corresponding JSON string values to be used in
the query execution
"""
normalized_metadata_filter = _adapt_metadata(metadata_filter)
if not normalized_metadata_filter:
return "", {}

field_sql_conditions: list[str] = []
parameters: dict[str, str] = {}
parameter_index = 0

for metadata_name, metadata_values in normalized_metadata_filter.items():
if not metadata_values:
return " AND 1=0", {}

value_sql_conditions: list[str] = []
for metadata_value in metadata_values:
parameter_name = f"metadata_filter_{parameter_index}"
parameter_index += 1
single_value_filter = json.dumps({metadata_name: [metadata_value]})
parameters[parameter_name] = single_value_filter

if dialect == "postgresql":
value_sql_conditions.append(f"metadata::jsonb @> :{parameter_name}")
elif dialect == "duckdb":
value_sql_conditions.append(f"json_contains(metadata, JSON(:{parameter_name}))")
else:
error_message = f"Unsupported dialect: {dialect}."
raise ValueError(error_message)
field_sql_conditions.append(f"({' OR '.join(value_sql_conditions)})")
return f" AND {' AND '.join(field_sql_conditions)}", parameters
Loading