diff --git a/src/raglite/_chatml_function_calling.py b/src/raglite/_chatml_function_calling.py index 9b0b240a..59fe49c9 100644 --- a/src/raglite/_chatml_function_calling.py +++ b/src/raglite/_chatml_function_calling.py @@ -400,7 +400,9 @@ def chatml_function_calling_with_streaming( # Case 2: Automatic or fixed tool choice # Case 2 step 1: Determine whether to respond with a message or a tool call - assert (isinstance(tool_choice, str) and tool_choice == "auto") or isinstance(tool_choice, dict) + assert (isinstance(tool_choice, str) and tool_choice in ("auto", "required")) or isinstance( + tool_choice, dict + ) if isinstance(tool_choice, dict): tools = [t for t in tools if t["function"]["name"] == tool_choice["function"]["name"]] assert tools diff --git a/src/raglite/_config.py b/src/raglite/_config.py index da0471d3..f90377e9 100644 --- a/src/raglite/_config.py +++ b/src/raglite/_config.py @@ -103,3 +103,4 @@ class RAGLiteConfig: # list[Chunk], or list[ChunkSpan]. search_method: SearchMethod = field(default=_vector_search, compare=False) self_query: bool = False + agentic_iterations: int = 3 diff --git a/src/raglite/_litellm.py b/src/raglite/_litellm.py index f93c5c70..93739e1a 100644 --- a/src/raglite/_litellm.py +++ b/src/raglite/_litellm.py @@ -36,7 +36,6 @@ # Reduce the logging level for LiteLLM, flashrank, and httpx. litellm.suppress_debug_info = True -litellm.drop_params = True # Drop unsupported parameters for models like GPT-5 os.environ["LITELLM_LOG"] = "WARNING" logging.getLogger("LiteLLM").setLevel(logging.WARNING) logging.getLogger("flashrank").setLevel(logging.WARNING) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 10bb578c..2ee7f450 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -2,7 +2,7 @@ import json import logging -from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence +from collections.abc import AsyncIterator, Callable, Generator, Iterator, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any @@ -39,6 +39,43 @@ {user_prompt} """.strip() +SEARCH_AGENT_PROMPT = """ +You are an expert research assistant that helps retrieve the necessary information to answer a user's question. +You need to use a search tool that queries a knowledge base of documents. Each time you call the tool, you will receive a set of relevant document chunks as context. +You can use this context to iteratively refine your search and gather more information until you have enough to answer the user's question. +Once you do, respond with "Context is sufficient" and stop iterating. +*IMPORTANT*: You MUST iterate AS FEW TIMES AS POSSIBLE. Be strategic and efficient in your retrieval process. + +## Query guidelines (tool calls) +- Each query must be a short, simple, precise question (single facet). +- Optimize questions for document retrieval: use keywords, explicit nouns/entities names, dates ... +- When a tool call does not return any relevant information, pivot your line of questioning. +Always consider prior asked questions before asking a new one: + - DO NOT ask the same question twice. + - DO NOT ask semantically overlapping questions. + +## Example of bad questions: +- "What is the population of City A and City B?" (multi-faceted, not precise) +- "What is the population of City A?" followed by "What about City B?" (vague question, not optimized for retrieval) +- "What is the population of City A?" followed by "What is the population of City A?" (same question twice, not strategic) +- "When did David Gilmour join Pink Floyd and when did Syd Barrett leave? give months/years and reason" (multi-faceted, too complex) +- "Timeline of The Offspring band lineup changes drummers bassists guitarists with years (James Lilja, Ron Welty, Atom Willard, Pete Parada, Josh Freese, Brandon Pertzborn, Greg K., Todd Morse, Noodles)" (multi-faceted, too complex) +- "Has X ever had consecutive Billboard Hot 100 number-one singles? List any runs of consecutive Hot 100 #1 singles (song titles and dates) and the length of her longest such streak (as of August 26, 2024)." (multi-faceted, not precise, not optimized for retrieval) + +## Example of good questions: +- "When did David Gilmour join Pink Floyd?" (single-faceted, precise) +- "Timeline of the Offspring" (optimized for retrieval, can be followed by more specific questions if needed) +- "What is the population of City A?" in parallel with "What is the population of City B?" (single-faceted, precise, non-repetitive) +- "When was X born?" instead of "How old is X?" (optimized for retrieval, as age depends on current date) +""".strip() + +NO_TOOLS_FOLLOW_UP_PROMPT = """ +Tools are unavailable for this step. +Do not call or reference any tool/function. +Try to answer the question to the best of your ability using only the context provided and your general knowledge. +If that is not possible, acknowledge it. +""".strip() + def retrieve_context( query: str, @@ -54,14 +91,13 @@ def retrieve_context( query, num_results=num_chunks, metadata_filter=metadata_filter, config=config ) # Convert results to chunk spans. - chunk_spans = [] if isinstance(results, tuple): - chunk_spans = retrieve_chunk_spans(results[0], config=config) - elif all(isinstance(result, Chunk) for result in results): - chunk_spans = retrieve_chunk_spans(results, config=config) # type: ignore[arg-type] - elif all(isinstance(result, ChunkSpan) for result in results): - chunk_spans = results # type: ignore[assignment] - return chunk_spans + return retrieve_chunk_spans(results[0], config=config) + if all(isinstance(result, Chunk) for result in results): + return retrieve_chunk_spans(results, config=config) # type: ignore[arg-type] + if all(isinstance(result, ChunkSpan) for result in results): + return list(results) # type: ignore[arg-type] + return [] def _count_tokens(item: str) -> int: @@ -79,22 +115,14 @@ def _get_last_message_idx(messages: list[dict[str, str]], role: str) -> int | No def _calculate_buffer_tokens( messages: list[dict[str, str]] | None, - roles: list[str], user_prompt: str | None, template: str, ) -> int: - """Calculate the number of tokens used by other messages.""" - # Calculate already used tokens (buffer) - buffer = 0 - # Triggered when using tool calls + """Calculate the number of tokens used by existing messages.""" + # Triggered when using tool calls: count all messages. if messages: - # Count used tokens by the last message of each role - for role in roles: - idx = _get_last_message_idx(messages, role) - if idx is not None: - buffer += _count_tokens(json.dumps(messages[idx])) - return buffer - # Triggered when using add_context + return sum(_count_tokens(json.dumps(m, ensure_ascii=False)) for m in messages) + # Triggered when using add_context: count template overhead. if user_prompt: return _count_tokens(template.format(context="", user_prompt=user_prompt)) return 0 @@ -110,16 +138,15 @@ def _cutoff_idx(token_counts: list[int], max_tokens: int, *, reverse: bool = Fal def _get_token_counts(items: Sequence[str | ChunkSpan | Mapping[str, str]]) -> list[int]: """Compute token counts for a list of items.""" - return [ - _count_tokens(item.to_xml()) - if isinstance(item, ChunkSpan) - else _count_tokens(json.dumps(item, ensure_ascii=False)) - if isinstance(item, dict) - else _count_tokens(item) - if isinstance(item, str) - else 0 - for item in items - ] + token_counts: list[int] = [] + for item in items: + if isinstance(item, ChunkSpan): + token_counts.append(_count_tokens(item.to_xml())) + elif isinstance(item, Mapping): + token_counts.append(_count_tokens(json.dumps(item, ensure_ascii=False))) + else: + token_counts.append(_count_tokens(item)) + return token_counts def _limit_chunkspans( @@ -132,11 +159,10 @@ def _limit_chunkspans( ) -> dict[str, list[ChunkSpan]]: """Limit chunk spans to fit within the context window.""" # Calculate already used tokens (buffer) - buffer = _calculate_buffer_tokens( - messages, ["user", "system", "assistant"], user_prompt, template - ) - # Determine max tokens available for context - max_tokens = get_context_size(config) - buffer + buffer = _calculate_buffer_tokens(messages, user_prompt, template) + # Determine max tokens available for context, reserving space for the LLM's response. + max_output_tokens = min(2048, get_context_size(config) // 4) + max_tokens = get_context_size(config) - buffer - max_output_tokens # Compute token counts for all chunk spans per tool tool_tokens_list: dict[str, list[int]] = {} tool_total_tokens: dict[str, int] = {} @@ -150,7 +176,7 @@ def _limit_chunkspans( total_tokens += tool_total total_chunk_spans += len(chunk_spans) # Early exit if we're already under the limit - if total_tokens <= max_tokens: + if total_tokens == 0 or total_tokens <= max_tokens: return tool_chunk_spans # Allocate tokens proportionally and truncate new_total_chunk_spans = 0 @@ -218,7 +244,8 @@ def _clip(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str max_tokens, ) # Try to include both last system and user messages if they fit together. - # If not, include just user if it fits, else return empty. + # If not, always preserve at least the last user message — the token estimate + # is approximate, and dropping all messages guarantees a crash. idx_system = _get_last_message_idx(messages, "system") if ( idx_user is not None @@ -227,9 +254,9 @@ def _clip(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str and token_counts[idx_user] + token_counts[idx_system] <= max_tokens ): return [messages[idx_system], messages[idx_user]] - if idx_user is not None and token_counts[idx_user] <= max_tokens: + if idx_user is not None: return [messages[idx_user]] - return [] + return messages[-1:] return messages[cutoff_idx:] @@ -238,7 +265,7 @@ def _get_tools( ) -> tuple[list[dict[str, Any]] | None, dict[str, Any] | str | None]: """Get tools to search the knowledge base if no RAG context is provided in the messages.""" # Check if messages already contain RAG context or if the LLM supports tool use. - final_message = messages[-1].get("content", "") + final_message = messages[-1].get("content") or "" messages_contain_rag_context = any( s in final_message for s in ("", "", "from_chunk_id") ) @@ -254,19 +281,15 @@ def _get_tools( "function": { "name": "search_knowledge_base", "description": ( - "Search the knowledge base.\n" + "Search the knowledge base for contextual information needed to answer a user question.\n" "IMPORTANT: You MAY NOT use this function if the question can be answered with common knowledge or straightforward reasoning.\n" - "For multi-faceted questions, call this function once for each facet." ), "parameters": { "type": "object", "properties": { "query": { "type": "string", - "description": ( - "The `query` string MUST be a precise single-faceted question in the user's language.\n" - "The `query` string MUST resolve all pronouns to explicit nouns." - ), + "description": "The exact user question, only rephrase if necessary for clarity. Add current date information if relevant.", }, }, "required": ["query"], @@ -285,6 +308,8 @@ def _get_tools( def _run_tool( tool_call: ChatCompletionMessageToolCall, config: RAGLiteConfig, + *, + metadata_filter: MetadataFilter | None = None, ) -> tuple[str, list[ChunkSpan]]: """ Run a single tool to search the knowledge base. @@ -292,8 +317,93 @@ def _run_tool( Returns the tool_id and the raw chunk_spans (before formatting/limiting). """ if tool_call.function.name == "search_knowledge_base": + try: + query = json.loads(tool_call.function.arguments)["query"] + except (json.JSONDecodeError, KeyError, TypeError) as exc: + msg = f"Invalid arguments for 'search_knowledge_base': {exc}" + raise ValueError(msg) from exc + messages = [ + { + "role": "system", + "content": SEARCH_AGENT_PROMPT, + }, + { + "role": "user", + "content": query, + }, + ] + tool = { + "type": "function", + "function": { + "name": "query_knowledge_base", + "description": ( + "Search the knowledge base with a single faceted question. " + "Multi-faceted questions are not allowed and should be broken down into multiple calls. \n" + "Example of a bad question: 'What is the population of City A and the GDP of Country B?'\n" + "Example of good questions: 'What is the population of City A?', 'What is the GDP of Country B?'\n" + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "A short, precise, single-faceted question.", + }, + }, + "required": ["query"], + "additionalProperties": False, + }, + }, + } + + # Start iterating and keep only chunk spans that introduce at least one new chunk ID. + chunk_spans: list[ChunkSpan] = [] + seen_chunk_ids: set[str] = set() + context_size = get_context_size(config) + max_output_tokens = min(2048, context_size // 4) + max_input_tokens = context_size - max_output_tokens + for iteration_index in range(max(1, config.agentic_iterations)): + response = completion( + model=config.llm, + messages=_clip(messages, max_input_tokens), + tools=[tool], + tool_choice="required" if iteration_index == 0 else "auto", + ) + messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] + + # check if the tool call is valid + if tool_calls: + retrieved_chunk_spans: list[ChunkSpan] = [] + messages.extend( + _run_tools( + tool_calls, + retrieved_chunk_spans.extend, + config, + messages=messages, + metadata_filter=metadata_filter, + ) + ) + # Keep a span if it contains at least one chunk we have not seen before. + novel_chunk_spans = [ + chunk_span + for chunk_span in retrieved_chunk_spans + if any(chunk.id not in seen_chunk_ids for chunk in chunk_span.chunks) + ] + chunk_spans.extend(novel_chunk_spans) + for chunk_span in novel_chunk_spans: + seen_chunk_ids.update(chunk.id for chunk in chunk_span.chunks) + else: + break + + # Return ID and data so the main function can aggregate and limit them + return tool_call.id, chunk_spans + + if tool_call.function.name == "query_knowledge_base": kwargs = json.loads(tool_call.function.arguments) kwargs["config"] = config + if metadata_filter is not None: + kwargs["metadata_filter"] = metadata_filter chunk_spans = retrieve_context(**kwargs) # Return ID and data so the main function can aggregate and limit them return tool_call.id, chunk_spans @@ -307,15 +417,18 @@ def _run_tools( config: RAGLiteConfig, *, messages: list[dict[str, str]] | None, - max_workers: int | None = None, + metadata_filter: MetadataFilter | None = None, ) -> list[dict[str, Any]]: """Run tools in parallel, limit the total context, then format messages.""" tool_chunk_spans: dict[str, list[ChunkSpan]] = {} # 1. Parallel Execution # We use the _run_tool helper to fetch data concurrently - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [executor.submit(_run_tool, tool_call, config) for tool_call in tool_calls] + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit(_run_tool, tool_call, config, metadata_filter=metadata_filter) + for tool_call in tool_calls + ] # Collect results as they finish try: @@ -339,14 +452,13 @@ def _run_tools( chunk_spans = tool_chunk_spans.get(tool_id, []) # Create the final message structure + documents = ", ".join( + chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans) + ) tool_messages.append( { "role": "tool", - "content": '{{"documents": [{elements}]}}'.format( - elements=", ".join( - chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans) - ) - ), + "content": f'{{"documents": [{documents}]}}', "tool_call_id": tool_id, } ) @@ -358,88 +470,121 @@ def _run_tools( return tool_messages -def rag( - messages: list[dict[str, str]], - *, - on_retrieval: Callable[[list[ChunkSpan]], None] | None = None, - config: RAGLiteConfig, -) -> Iterator[str]: - # If the final message does not contain RAG context, get a tool to search the knowledge base. - max_tokens = get_context_size(config) - tools, tool_choice = _get_tools(messages, config) - # Stream the LLM response, which is either a tool call request or an assistant response. +def _stream_rag_response( + messages: list[dict[str, str]], config: RAGLiteConfig, *, use_tools: bool = True +) -> Generator[str, None, list[Any]]: + """Stream the RAG response, which may include tool calls for retrieval.""" + context_size = get_context_size(config) + max_output_tokens = min(2048, context_size // 4) + max_input_tokens = context_size - max_output_tokens + tools, tool_choice = _get_tools(messages, config) if use_tools else (None, None) + local_chunks: list[Any] = [] stream = completion( model=config.llm, - messages=_clip(messages, max_tokens), + messages=_clip(messages, max_input_tokens), tools=tools, tool_choice=tool_choice, stream=True, + max_tokens=max_output_tokens, ) - chunks = [] for chunk in stream: - chunks.append(chunk) - if isinstance(token := chunk.choices[0].delta.content, str): + local_chunks.append(chunk) + if isinstance(token := chunk.choices[0].delta.content, str): # type: ignore[union-attr] yield token - # Check if there are tools to be called. - response = stream_chunk_builder(chunks, messages) - tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] - if tool_calls: - # Add the tool call request to the message array. - messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] - # Run the tool calls to retrieve the RAG context and append the output to the message array. - messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages)) - # Stream the assistant response. - chunks = [] - stream = completion(model=config.llm, messages=_clip(messages, max_tokens), stream=True) - for chunk in stream: - chunks.append(chunk) - if isinstance(token := chunk.choices[0].delta.content, str): - yield token - # Append the assistant response to the message array. - response = stream_chunk_builder(chunks, messages) - messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + return local_chunks -async def async_rag( +async def _async_stream_rag_response( messages: list[dict[str, str]], - *, - on_retrieval: Callable[[list[ChunkSpan]], None] | None = None, config: RAGLiteConfig, + response_chunks: list[Any], + *, + use_tools: bool = True, ) -> AsyncIterator[str]: - # If the final message does not contain RAG context, get a tool to search the knowledge base. - max_tokens = get_context_size(config) - tools, tool_choice = _get_tools(messages, config) - # Asynchronously stream the LLM response, which is either a tool call or an assistant response. + """Async version of _stream_rag_response.""" + context_size = get_context_size(config) + max_output_tokens = min(2048, context_size // 4) + max_input_tokens = context_size - max_output_tokens + tools, tool_choice = _get_tools(messages, config) if use_tools else (None, None) async_stream = await acompletion( model=config.llm, - messages=_clip(messages, max_tokens), + messages=_clip(messages, max_input_tokens), tools=tools, tool_choice=tool_choice, stream=True, + max_tokens=max_output_tokens, ) - chunks = [] async for chunk in async_stream: - chunks.append(chunk) + response_chunks.append(chunk) if isinstance(token := chunk.choices[0].delta.content, str): yield token - # Check if there are tools to be called. - response = stream_chunk_builder(chunks, messages) + + +def rag( + messages: list[dict[str, str]], + *, + on_retrieval: Callable[[list[ChunkSpan]], None] | None = None, + metadata_filter: MetadataFilter | None = None, + config: RAGLiteConfig, +) -> Iterator[str]: + """Run retrieval-augmented generation with the given messages and config.""" + working = list(messages) + chunks = yield from _stream_rag_response(working, config) + response = stream_chunk_builder(chunks, working) + working.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] + if tool_calls: - # Add the tool call requests to the message array. - messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] - # Run the tool calls to retrieve the RAG context and append the output to the message array. - # TODO: Make this async. - messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages)) - # Asynchronously stream the assistant response. - chunks = [] - async_stream = await acompletion( - model=config.llm, messages=_clip(messages, max_tokens), stream=True + working.extend( + _run_tools( + tool_calls, + on_retrieval, + config, + messages=working, + metadata_filter=metadata_filter, + ) + ) + follow_up_messages = [*working, {"role": "system", "content": NO_TOOLS_FOLLOW_UP_PROMPT}] + chunks = yield from _stream_rag_response(follow_up_messages, config, use_tools=False) + response = stream_chunk_builder(chunks, follow_up_messages) + working.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + + messages.extend(working[len(messages) :]) + + +async def async_rag( + messages: list[dict[str, str]], + *, + on_retrieval: Callable[[list[ChunkSpan]], None] | None = None, + metadata_filter: MetadataFilter | None = None, + config: RAGLiteConfig, +) -> AsyncIterator[str]: + """Run retrieval-augmented generation with the given messages and config.""" + working = list(messages) + chunks: list[Any] = [] + async for token in _async_stream_rag_response(working, config, chunks): + yield token + response = stream_chunk_builder(chunks, working) + working.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] + + if tool_calls: + working.extend( + _run_tools( + tool_calls, + on_retrieval, + config, + messages=working, + metadata_filter=metadata_filter, + ) ) - async for chunk in async_stream: - chunks.append(chunk) - if isinstance(token := chunk.choices[0].delta.content, str): - yield token - # Append the assistant response to the message array. - response = stream_chunk_builder(chunks, messages) - messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + follow_up_messages = [*working, {"role": "system", "content": NO_TOOLS_FOLLOW_UP_PROMPT}] + chunks = [] + async for token in _async_stream_rag_response( + follow_up_messages, config, chunks, use_tools=False + ): + yield token + response = stream_chunk_builder(chunks, follow_up_messages) + working.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + + messages.extend(working[len(messages) :]) diff --git a/src/raglite/_search.py b/src/raglite/_search.py index e1158d60..265901cc 100644 --- a/src/raglite/_search.py +++ b/src/raglite/_search.py @@ -56,7 +56,8 @@ def vector_search( # Apply the query adapter to the query embedding. if ( config.vector_search_query_adapter - and (Q := IndexMetadata.get("default", config=config).get("query_adapter")) is not None # noqa: N806 + and (Q := IndexMetadata.get("default", config=config).get("query_adapter")) # noqa: N806 + is not None ): query_embedding = (Q @ query_embedding).astype(query_embedding.dtype) # Rank the chunks by relevance according to the Lāˆž norm of the similarities of the multi-vector @@ -503,6 +504,7 @@ def _self_query( user_prompt=query, config=config, temperature=0.0, # Deterministic output if the model allows + drop_params=True, ) except ValueError as e: logger.debug("Failed to extract metadata filter: %s", e) diff --git a/tests/test_rag.py b/tests/test_rag.py index 7277f23e..a8ded950 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -1,6 +1,10 @@ """Test RAGLite's RAG functionality.""" import json +from types import SimpleNamespace +from typing import Any + +import pytest from raglite import ( RAGLiteConfig, @@ -8,7 +12,8 @@ retrieve_context, ) from raglite._database import ChunkSpan -from raglite._rag import rag +from raglite._rag import _run_tool, rag +from raglite._typing import MetadataFilter # noqa: TC001 def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None: @@ -30,7 +35,7 @@ def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None: def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None: """Test Retrieval-Augmented Generation with automatic retrieval.""" # Answer a question that requires RAG. - user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?" + user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper? do not guess and provide me with proof via retrieval" messages = [{"role": "user", "content": user_prompt}] chunk_spans: list[ChunkSpan] = [] stream = rag(messages, on_retrieval=chunk_spans.extend, config=raglite_test_config) @@ -40,8 +45,10 @@ def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None: answer += update assert "event" in answer.lower() # Verify that RAG context was retrieved automatically. - assert [message["role"] for message in messages] == ["user", "assistant", "tool", "assistant"] - assert json.loads(messages[-2]["content"]) + roles = [message["role"] for message in messages] + assert roles[0] == "user" + assert roles[-1] == "assistant" + assert "tool" in roles # At least one retrieval happened. if not raglite_test_config.llm.startswith("llama-cpp-python"): assert chunk_spans assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans) @@ -78,3 +85,195 @@ def test_retrieve_context_self_query(raglite_test_config: RAGLiteConfig) -> None assert chunk_span.document.metadata_.get("author") == ["Albert Einstein"], ( f"Expected author='Albert Einstein', got {chunk_span.document.metadata_.get('author')}" ) + + +def test_agentic_search_threads_metadata_filter_to_nested_tool_calls( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Pass metadata filters from search_knowledge_base down to nested tool calls.""" + config = RAGLiteConfig( + llm="gpt-5-mini", + embedder="text-embedding-3-small", + db_url="duckdb:///:memory:", + ) + metadata_filter: MetadataFilter = {"topic": ["Physics"]} + nested_tool_call = SimpleNamespace( + function=SimpleNamespace( + name="query_knowledge_base", + arguments=json.dumps({"query": "What is time dilation?"}), + ), + id="query_call_id", + ) + search_tool_call = SimpleNamespace( + function=SimpleNamespace( + name="search_knowledge_base", + arguments=json.dumps({"query": "Explain Einstein's time dilation."}), + ), + id="search_call_id", + ) + + def _make_response( + tool_calls: list[Any] | None, + ) -> Any: + message = SimpleNamespace( + tool_calls=tool_calls, + to_dict=lambda: {"role": "assistant", "content": ""}, + ) + return SimpleNamespace(choices=[SimpleNamespace(message=message)]) + + completion_responses = [ + _make_response([nested_tool_call]), + _make_response(None), + ] + + def fake_completion(**_: Any) -> Any: + return completion_responses.pop(0) + + observed_metadata_filters: list[MetadataFilter | None] = [] + + def fake_run_tools(*args: Any, **kwargs: Any) -> list[dict[str, Any]]: + tool_calls = args[0] + assert len(tool_calls) == 1 + assert tool_calls[0].function.name == "query_knowledge_base" + observed_metadata_filters.append(kwargs.get("metadata_filter")) + return [] + + monkeypatch.setattr("raglite._rag.completion", fake_completion) + monkeypatch.setattr("raglite._rag._run_tools", fake_run_tools) + + tool_id, chunk_spans = _run_tool( + search_tool_call, + config, + metadata_filter=metadata_filter, + ) + + assert tool_id == "search_call_id" + assert chunk_spans == [] + assert observed_metadata_filters == [metadata_filter] + + +def test_query_tool_call_passes_metadata_filter_to_retrieve_context( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Pass metadata_filter to retrieve_context when running query_knowledge_base.""" + config = RAGLiteConfig( + llm="gpt-5-mini", + embedder="text-embedding-3-small", + db_url="duckdb:///:memory:", + ) + metadata_filter: MetadataFilter = {"type": ["Paper"], "author": ["Albert Einstein"]} + tool_call = SimpleNamespace( + function=SimpleNamespace( + name="query_knowledge_base", + arguments=json.dumps({"query": "How is simultaneity defined?"}), + ), + id="query_call_id", + ) + observed_kwargs: dict[str, Any] = {} + + def fake_retrieve_context(**kwargs: Any) -> list[ChunkSpan]: + observed_kwargs.update(kwargs) + return [] + + monkeypatch.setattr("raglite._rag.retrieve_context", fake_retrieve_context) + + tool_id, chunk_spans = _run_tool( + tool_call, + config, + metadata_filter=metadata_filter, + ) + + assert tool_id == "query_call_id" + assert chunk_spans == [] + assert observed_kwargs["metadata_filter"] == metadata_filter + + +def test_sub_agent_deduplicates_chunk_spans_by_chunk_id( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Drop fully redundant spans and keep partially novel spans in sub-agent search.""" + config = RAGLiteConfig( + llm="gpt-5-mini", + embedder="text-embedding-3-small", + db_url="duckdb:///:memory:", + ) + search_tool_call = SimpleNamespace( + function=SimpleNamespace( + name="search_knowledge_base", + arguments=json.dumps({"query": "Explain relativity."}), + ), + id="search_call_id", + ) + + def _make_response(tool_calls: list[Any] | None) -> Any: + message = SimpleNamespace( + tool_calls=tool_calls, + to_dict=lambda: {"role": "assistant", "content": ""}, + ) + return SimpleNamespace(choices=[SimpleNamespace(message=message)]) + + nested_tool_call = SimpleNamespace( + function=SimpleNamespace( + name="query_knowledge_base", + arguments=json.dumps({"query": "Q"}), + ), + id="query_call_id", + ) + completion_responses = [ + _make_response([nested_tool_call]), + _make_response([nested_tool_call]), + _make_response(None), + ] + + def fake_completion(**_: Any) -> Any: + return completion_responses.pop(0) + + def make_chunk_span(*chunk_ids: str) -> Any: + return SimpleNamespace(chunks=[SimpleNamespace(id=chunk_id) for chunk_id in chunk_ids]) + + first_iteration_spans = [ + make_chunk_span("A", "B"), + ] + second_iteration_spans = [ + make_chunk_span("A", "B"), + make_chunk_span("B", "C"), + ] + tool_results_by_iteration = [first_iteration_spans, second_iteration_spans] + + def fake_run_tools(*args: Any, **_: Any) -> list[dict[str, Any]]: + on_retrieval = args[1] + on_retrieval(tool_results_by_iteration.pop(0)) + return [] + + monkeypatch.setattr("raglite._rag.completion", fake_completion) + monkeypatch.setattr("raglite._rag._run_tools", fake_run_tools) + + _, chunk_spans = _run_tool(search_tool_call, config) + + actual_chunk_id_sequences = [ + [chunk.id for chunk in chunk_span.chunks] for chunk_span in chunk_spans + ] + assert actual_chunk_id_sequences == [["A", "B"], ["B", "C"]] + + +def test_rag_does_not_mutate_caller_messages_on_stream_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Keep caller messages unchanged when an exception happens mid-rag.""" + config = RAGLiteConfig( + llm="gpt-5-mini", + embedder="text-embedding-3-small", + db_url="duckdb:///:memory:", + ) + messages = [{"role": "user", "content": "Hello"}] + original_messages = list(messages) + + def fake_stream(*_: Any, **__: Any) -> Any: + error_message = "stream failure" + raise RuntimeError(error_message) + + monkeypatch.setattr("raglite._rag._stream_rag_response", fake_stream) + + with pytest.raises(RuntimeError, match="stream failure"): + list(rag(messages, config=config)) + assert messages == original_messages