From 82dac31fe61b740ec1162ece15f1981128371c8e Mon Sep 17 00:00:00 2001 From: MattiaMolon Date: Tue, 17 Feb 2026 14:08:35 +0100 Subject: [PATCH 01/16] feat: allow agentic behavior with iterable tool calling --- src/raglite/_rag.py | 61 +++++++++++++++++++++++++++++++++++---------- 1 file changed, 48 insertions(+), 13 deletions(-) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 10bb578c..bac6ebf0 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -111,13 +111,17 @@ def _cutoff_idx(token_counts: list[int], max_tokens: int, *, reverse: bool = Fal def _get_token_counts(items: Sequence[str | ChunkSpan | Mapping[str, str]]) -> list[int]: """Compute token counts for a list of items.""" return [ - _count_tokens(item.to_xml()) - if isinstance(item, ChunkSpan) - else _count_tokens(json.dumps(item, ensure_ascii=False)) - if isinstance(item, dict) - else _count_tokens(item) - if isinstance(item, str) - else 0 + ( + _count_tokens(item.to_xml()) + if isinstance(item, ChunkSpan) + else ( + _count_tokens(json.dumps(item, ensure_ascii=False)) + if isinstance(item, dict) + else _count_tokens(item) + if isinstance(item, str) + else 0 + ) + ) for item in items ] @@ -254,9 +258,12 @@ def _get_tools( "function": { "name": "search_knowledge_base", "description": ( - "Search the knowledge base.\n" + "Search the knowledge base using single-faceted questions.\n" + "For multi-faceted questions (comparison, sets, ..), call this function once for each facet.\n " + "Example original query: Which artist has the most number-one albums on the Billboard 200: X or Y?\n" + "Facet 1: How many albums on the Billboard 200 does X have? \n" + "Facet 2: How many albums on the Billboard 200 does Y have? \n" "IMPORTANT: You MAY NOT use this function if the question can be answered with common knowledge or straightforward reasoning.\n" - "For multi-faceted questions, call this function once for each facet." 
), "parameters": { "type": "object", @@ -362,6 +369,7 @@ def rag( messages: list[dict[str, str]], *, on_retrieval: Callable[[list[ChunkSpan]], None] | None = None, + allowed_iterations: int = 2, config: RAGLiteConfig, ) -> Iterator[str]: # If the final message does not contain RAG context, get a tool to search the knowledge base. @@ -380,23 +388,50 @@ def rag( chunks.append(chunk) if isinstance(token := chunk.choices[0].delta.content, str): yield token - # Check if there are tools to be called. response = stream_chunk_builder(chunks, messages) tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] - if tool_calls: + # Check if there are tools to be called. + iterations = 0 + while iterations < allowed_iterations and tool_calls is not None: # Add the tool call request to the message array. messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] # Run the tool calls to retrieve the RAG context and append the output to the message array. + # TODO: should merge results between tools if same document is extracted by multiple tools + # to avoid duplicates in context, and to optimize token usage. + # TODO: Also, make sure that the questions are different from call tools + # to avoid redundant calls. + # TODO: To avoid same question, should we construct a message ourselves with the + # responses and add context function? messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages)) + + # check if we've reached the maximum number of allowed iterations and append a stop message + if iterations == allowed_iterations - 1: + stop_message = { + "role": "system", + "content": "You have reached the maximum number of retrieval iterations for this query. " + "Answer the question based on the retrieved context without making additional tool calls.", + } + messages.extend([stop_message]) + # Stream the assistant response. 
chunks = [] - stream = completion(model=config.llm, messages=_clip(messages, max_tokens), stream=True) + stream = completion( + model=config.llm, + messages=_clip(messages, max_tokens), + tools=tools, + tool_choice=tool_choice, + stream=True, + ) for chunk in stream: chunks.append(chunk) if isinstance(token := chunk.choices[0].delta.content, str): yield token + + # Check if there are additional tool calls for another iteration. + response = stream_chunk_builder(chunks, messages) + tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] + iterations += 1 # Append the assistant response to the message array. - response = stream_chunk_builder(chunks, messages) messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] From 9eda9c8c7077d3691fc1b19bf2f3ee4f74b59419 Mon Sep 17 00:00:00 2001 From: MattiaMolon Date: Tue, 24 Feb 2026 15:08:14 +0100 Subject: [PATCH 02/16] feat: added system message injection to allow for smoother agentic rag --- src/raglite/_rag.py | 139 +++++++++++++++++++++++++++++++++----------- 1 file changed, 105 insertions(+), 34 deletions(-) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index bac6ebf0..c5a7a67d 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -2,7 +2,7 @@ import json import logging -from collections.abc import AsyncIterator, Callable, Iterator, Mapping, Sequence +from collections.abc import AsyncIterator, Callable, Generator, Iterator, Mapping, Sequence from concurrent.futures import ThreadPoolExecutor, as_completed from typing import Any @@ -39,6 +39,65 @@ {user_prompt} """.strip() +SEARCH_AGENT_PROMPT = """ +You are an expert research assistant that answers the user's question using a search tool over a knowledge base (RAG). +You may perform up to {allowed_iterations} iterations. In each iteration, you may issue up to {max_questions_per_iteration} search tool queries. 
+ +Your job is to: +1) If the question can be answered via common knowledge or straightforward reasoning, answer directly without using the search tool. +2) If not, decide what information is required to answer the question. +3) Call the search tool with precise, single-faceted questions to retrieve that information from the knowledge base. +4) Produce a final answer grounded in retrieved evidence. + +## Workflow (repeat until done or iterations exhausted) +At the start of each iteration: +- Briefly state what is currently missing (the minimum unknowns that block a confident answer). +- Generate search queries that directly target those unknowns. + +After each retrieval: +- Extract the relevant facts (and ignore irrelevant text). +- Assess sufficiency: + - SUFFICIENT if you can answer every part of the user question with direct support from the retrieved info. + - INSUFFICIENT if any required fact is missing, ambiguous, or unsupported. +- If sufficient, stop searching and answer the question. +- If insufficient, call the tool again with queries that target ONLY what's missing. + +## Query guidelines (tool calls) +- Use the tool ONLY when the answer is not common knowledge or requires knowledge-base-specific facts. +- Each tool call must be a single, precise question (single facet). Split multi-facet needs into separate calls. +- Resolve pronouns and vague references into explicit nouns/entities. +- Avoid redundancy: + - Do not ask the same question twice. + - Do not ask semantically overlapping questions unless you are disambiguating conflicting info. +- Prefer queries that request: + - Definitions / canonical records first (IDs, names, dates). + - Then relationships / comparisons. + - Then edge cases / exceptions if needed. + +## Termination and fallback +- Stop early if sufficient. +- If you hit iteration limits and still lack key facts: + - Provide the best partial answer supported by evidence. + - List missing information clearly. 
+ +## Example +Original user question: "Which city has a larger population, City A or City B?" +Reasoning: We need the population of both cities. This is not common knowledge, so we will use the search tool to find this information. +Iteration 1: + - Tool Call 1: "Population of City A" + - Tool Call 2: "Population of City B" +Retrieved information: + - "The population of City A is 1,000,000." + - "The population of the urban area of City B is 1,200,000" +Assessment: INSUFFICIENT (urban area population includes city population and surroundings, so we cannot confidently compare) +Iteration 2: + - Tool Call 1: "Population of City B (city proper)" +Retrieved information: + - "The population of the city proper of City B is 900,000." +Assessment: SUFFICIENT (now we have comparable population figures for both cities) +Final answer: "City A has a larger population than City B. City A has a population of 1,000,000, while City B has a population of 900,000." +""".strip() + def retrieve_context( query: str, @@ -369,68 +428,80 @@ def rag( messages: list[dict[str, str]], *, on_retrieval: Callable[[list[ChunkSpan]], None] | None = None, - allowed_iterations: int = 2, + allowed_iterations: int = 20, config: RAGLiteConfig, ) -> Iterator[str]: + """Run retrieval-augmented generation with the given messages and config.""" + assert allowed_iterations >= 1, "allowed_iterations must be at least 1" + # If the final message does not contain RAG context, get a tool to search the knowledge base. max_tokens = get_context_size(config) tools, tool_choice = _get_tools(messages, config) - # Stream the LLM response, which is either a tool call request or an assistant response. 
- stream = completion( - model=config.llm, - messages=_clip(messages, max_tokens), - tools=tools, - tool_choice=tool_choice, - stream=True, - ) - chunks = [] - for chunk in stream: - chunks.append(chunk) - if isinstance(token := chunk.choices[0].delta.content, str): - yield token + + if tools: + # inject a system prompt to guide the LLM to use the tool for iterative retrieval if + # no context is provided + system_prompt = { + "role": "system", + "content": SEARCH_AGENT_PROMPT.format( + allowed_iterations=allowed_iterations, + max_questions_per_iteration=3, # This can be made configurable if needed + ), + } + messages.insert(0, system_prompt) + + def _stream_rag_response() -> Generator[str, None, list[Any]]: + """Stream the RAG response, which may include tool calls for retrieval.""" + local_chunks: list[Any] = [] + stream = completion( + model=config.llm, + messages=_clip(messages, max_tokens), + tools=tools, + tool_choice=tool_choice, + stream=True, + ) + for chunk in stream: + local_chunks.append(chunk) + if isinstance(token := chunk.choices[0].delta.content, str): # type: ignore[union-attr] + yield token + return local_chunks + + chunks = yield from _stream_rag_response() response = stream_chunk_builder(chunks, messages) tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] + # Check if there are tools to be called. iterations = 0 while iterations < allowed_iterations and tool_calls is not None: # Add the tool call request to the message array. messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] # Run the tool calls to retrieve the RAG context and append the output to the message array. - # TODO: should merge results between tools if same document is extracted by multiple tools - # to avoid duplicates in context, and to optimize token usage. - # TODO: Also, make sure that the questions are different from call tools - # to avoid redundant calls. 
- # TODO: To avoid same question, should we construct a message ourselves with the - # responses and add context function? messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages)) # check if we've reached the maximum number of allowed iterations and append a stop message if iterations == allowed_iterations - 1: stop_message = { "role": "system", - "content": "You have reached the maximum number of retrieval iterations for this query. " + "content": "You have reached the maximum number of retrieval iterations for this user query. " "Answer the question based on the retrieved context without making additional tool calls.", } messages.extend([stop_message]) # Stream the assistant response. - chunks = [] - stream = completion( - model=config.llm, - messages=_clip(messages, max_tokens), - tools=tools, - tool_choice=tool_choice, - stream=True, - ) - for chunk in stream: - chunks.append(chunk) - if isinstance(token := chunk.choices[0].delta.content, str): - yield token + chunks = yield from _stream_rag_response() # Check if there are additional tool calls for another iteration. response = stream_chunk_builder(chunks, messages) tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] iterations += 1 + + # remove last system calls + if tools: + messages.pop(0) # remove the system prompt we injected at the start of the function + system_idx = _get_last_message_idx(messages, "system") + if system_idx is not None: + messages.pop(system_idx) + # Append the assistant response to the message array. 
messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] From 7c1b08acf9ecec06ed4f09132e61cf8f357e25ae Mon Sep 17 00:00:00 2001 From: MattiaMolon Date: Thu, 26 Feb 2026 13:13:00 +0100 Subject: [PATCH 03/16] feat: spawn subagent for search queries --- .gitignore | 3 + src/raglite/_config.py | 1 + src/raglite/_rag.py | 231 +++++++++++++++++++++-------------------- 3 files changed, 122 insertions(+), 113 deletions(-) diff --git a/.gitignore b/.gitignore index d71fb6d4..53bb56dd 100644 --- a/.gitignore +++ b/.gitignore @@ -78,3 +78,6 @@ uv.lock # VS Code .vscode/ + +# evals +evals/ \ No newline at end of file diff --git a/src/raglite/_config.py b/src/raglite/_config.py index da0471d3..32d2c262 100644 --- a/src/raglite/_config.py +++ b/src/raglite/_config.py @@ -103,3 +103,4 @@ class RAGLiteConfig: # list[Chunk], or list[ChunkSpan]. search_method: SearchMethod = field(default=_vector_search, compare=False) self_query: bool = False + allowed_iterations: int = 3 diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index c5a7a67d..ace3c790 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -40,45 +40,24 @@ """.strip() SEARCH_AGENT_PROMPT = """ -You are an expert research assistant that answers the user's question using a search tool over a knowledge base (RAG). -You may perform up to {allowed_iterations} iterations. In each iteration, you may issue up to {max_questions_per_iteration} search tool queries. +You are an expert research assistant that helps retrieve the necessary information to answer a user's question. +You have access to a search tool that can query a knowledge base of documents. Each time you call the tool, you will receive a set of relevant document chunks as context. +You may perform up to {allowed_iterations} iterations. In each iteration, you may issue up to {max_questions_per_iteration} search tool queries in parallel. 
Your job is to: -1) If the question can be answered via common knowledge or straightforward reasoning, answer directly without using the search tool. -2) If not, decide what information is required to answer the question. -3) Call the search tool with precise, single-faceted questions to retrieve that information from the knowledge base. -4) Produce a final answer grounded in retrieved evidence. - -## Workflow (repeat until done or iterations exhausted) -At the start of each iteration: -- Briefly state what is currently missing (the minimum unknowns that block a confident answer). -- Generate search queries that directly target those unknowns. - -After each retrieval: -- Extract the relevant facts (and ignore irrelevant text). -- Assess sufficiency: - - SUFFICIENT if you can answer every part of the user question with direct support from the retrieved info. - - INSUFFICIENT if any required fact is missing, ambiguous, or unsupported. -- If sufficient, stop searching and answer the question. -- If insufficient, call the tool again with queries that target ONLY what's missing. +1) Evaluate if the context provided is sufficient to answer the user's question with confidence. +2) If sufficient, respond with "Context is sufficient" and stop iterating. +3) If insufficient, reason on which information is still required to answer the question. +4) Call the search tool with precise, single-faceted questions to retrieve that information from the knowledge base. +5) Repeat from step 1 +*IMPORTANT*: Your goal is to retrieve the necessary information with AS FEW iterations and tool calls AS POSSIBLE. Be strategic and efficient in your retrieval process. ## Query guidelines (tool calls) -- Use the tool ONLY when the answer is not common knowledge or requires knowledge-base-specific facts. -- Each tool call must be a single, precise question (single facet). Split multi-facet needs into separate calls. +- Each tool call must be a single, precise question (single facet). 
Split multi-facets into separate calls. - Resolve pronouns and vague references into explicit nouns/entities. -- Avoid redundancy: - - Do not ask the same question twice. - - Do not ask semantically overlapping questions unless you are disambiguating conflicting info. -- Prefer queries that request: - - Definitions / canonical records first (IDs, names, dates). - - Then relationships / comparisons. - - Then edge cases / exceptions if needed. - -## Termination and fallback -- Stop early if sufficient. -- If you hit iteration limits and still lack key facts: - - Provide the best partial answer supported by evidence. - - List missing information clearly. +Compared to prior iterations: + - DO NOT ask the same question twice. + - DO NOT ask semantically overlapping questions. ## Example Original user question: "Which city has a larger population, City A or City B?" @@ -95,7 +74,6 @@ Retrieved information: - "The population of the city proper of City B is 900,000." Assessment: SUFFICIENT (now we have comparable population figures for both cities) -Final answer: "City A has a larger population than City B. City A has a population of 1,000,000, while City B has a population of 900,000." """.strip() @@ -317,22 +295,16 @@ def _get_tools( "function": { "name": "search_knowledge_base", "description": ( - "Search the knowledge base using single-faceted questions.\n" - "For multi-faceted questions (comparison, sets, ..), call this function once for each facet.\n " - "Example original query: Which artist has the most number-one albums on the Billboard 200: X or Y?\n" - "Facet 1: How many albums on the Billboard 200 does X have? \n" - "Facet 2: How many albums on the Billboard 200 does Y have? 
\n" - "IMPORTANT: You MAY NOT use this function if the question can be answered with common knowledge or straightforward reasoning.\n" + "Search the knowledge base.\n" + "IMPORTANT: You MAY not use this function if the question can be answered with common knowledge or straightforward reasoning.\n" + "Reformulate the question to ensure clarity, precision, and specificity." ), "parameters": { "type": "object", "properties": { "query": { "type": "string", - "description": ( - "The `query` string MUST be a precise single-faceted question in the user's language.\n" - "The `query` string MUST resolve all pronouns to explicit nouns." - ), + "description": "The search question.", }, }, "required": ["query"], @@ -358,6 +330,79 @@ def _run_tool( Returns the tool_id and the raw chunk_spans (before formatting/limiting). """ if tool_call.function.name == "search_knowledge_base": + query = json.loads(tool_call.function.arguments)["query"] + messages = [ + { + "role": "system", + "content": SEARCH_AGENT_PROMPT.format( + allowed_iterations=config.allowed_iterations, max_questions_per_iteration=3 + ), + }, + { + "role": "user", + "content": query, + }, + ] + tool = { + "type": "function", + "function": { + "name": "query_knowledge_base", + "description": ( + "Search the knowledge base with a single faceted question. " + "Each question must be precise and focused on a specific piece of information. " + "Multi-faceted questions that can be split into separate queries are not allowed. 
\n" + "Example of a bad question: 'What is the population of City A and the GDP of Country B?'\n" + "Example of good questions: 'What is the population of City A?', 'What is the GDP of Country B?'\n" + ), + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "A precise, single-faceted question to search for in the knowledge base.", + }, + }, + "required": ["query"], + "additionalProperties": False, + }, + }, + } + + # start iterating + chunk_spans = [] + iterations = 0 + while True: + response = completion( + model=config.llm, + messages=messages, + tools=[tool], + tool_choice="auto", + ) + messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] + + # check if the tool call is valid + if tool_calls is not None: + messages.extend( + _run_tools( + tool_calls, + lambda spans: chunk_spans.extend(spans), + config, + messages=messages, + ) + ) + else: + break + + # check if we've reached the maximum number of allowed iterations + iterations += 1 + if iterations >= config.allowed_iterations: + break + + # Return ID and data so the main function can aggregate and limit them + return tool_call.id, chunk_spans + + if tool_call.function.name == "query_knowledge_base": kwargs = json.loads(tool_call.function.arguments) kwargs["config"] = config chunk_spans = retrieve_context(**kwargs) @@ -424,86 +469,46 @@ def _run_tools( return tool_messages +def _stream_rag_response( + messages: list[dict[str, str]], config: RAGLiteConfig +) -> Generator[str, None, list[Any]]: + """Stream the RAG response, which may include tool calls for retrieval.""" + max_tokens = get_context_size(config) + tools, tool_choice = _get_tools(messages, config) + local_chunks: list[Any] = [] + stream = completion( + model=config.llm, + messages=_clip(messages, max_tokens), + tools=tools, + tool_choice=tool_choice, + stream=True, + ) + for chunk 
in stream: + local_chunks.append(chunk) + if isinstance(token := chunk.choices[0].delta.content, str): # type: ignore[union-attr] + yield token + return local_chunks + + def rag( messages: list[dict[str, str]], *, on_retrieval: Callable[[list[ChunkSpan]], None] | None = None, - allowed_iterations: int = 20, config: RAGLiteConfig, ) -> Iterator[str]: """Run retrieval-augmented generation with the given messages and config.""" - assert allowed_iterations >= 1, "allowed_iterations must be at least 1" - - # If the final message does not contain RAG context, get a tool to search the knowledge base. - max_tokens = get_context_size(config) - tools, tool_choice = _get_tools(messages, config) - - if tools: - # inject a system prompt to guide the LLM to use the tool for iterative retrieval if - # no context is provided - system_prompt = { - "role": "system", - "content": SEARCH_AGENT_PROMPT.format( - allowed_iterations=allowed_iterations, - max_questions_per_iteration=3, # This can be made configurable if needed - ), - } - messages.insert(0, system_prompt) - - def _stream_rag_response() -> Generator[str, None, list[Any]]: - """Stream the RAG response, which may include tool calls for retrieval.""" - local_chunks: list[Any] = [] - stream = completion( - model=config.llm, - messages=_clip(messages, max_tokens), - tools=tools, - tool_choice=tool_choice, - stream=True, - ) - for chunk in stream: - local_chunks.append(chunk) - if isinstance(token := chunk.choices[0].delta.content, str): # type: ignore[union-attr] - yield token - return local_chunks - - chunks = yield from _stream_rag_response() + chunks = yield from _stream_rag_response(messages, config) response = stream_chunk_builder(chunks, messages) + messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] - # Check if there are tools to be called. 
- iterations = 0 - while iterations < allowed_iterations and tool_calls is not None: - # Add the tool call request to the message array. - messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] - # Run the tool calls to retrieve the RAG context and append the output to the message array. + if tool_calls: messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages)) - - # check if we've reached the maximum number of allowed iterations and append a stop message - if iterations == allowed_iterations - 1: - stop_message = { - "role": "system", - "content": "You have reached the maximum number of retrieval iterations for this user query. " - "Answer the question based on the retrieved context without making additional tool calls.", - } - messages.extend([stop_message]) - - # Stream the assistant response. - chunks = yield from _stream_rag_response() - - # Check if there are additional tool calls for another iteration. + chunks = yield from _stream_rag_response(messages, config) response = stream_chunk_builder(chunks, messages) - tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] - iterations += 1 - - # remove last system calls - if tools: - messages.pop(0) # remove the system prompt we injected at the start of the function - system_idx = _get_last_message_idx(messages, "system") - if system_idx is not None: - messages.pop(system_idx) - - # Append the assistant response to the message array. 
- messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + else: + messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] async def async_rag( From a3edec072879f7e183e513008167c0751e597ca7 Mon Sep 17 00:00:00 2001 From: MattiaMolon Date: Mon, 2 Mar 2026 14:29:23 +0100 Subject: [PATCH 04/16] feat: first implementation of sub-agent search --- src/raglite/_rag.py | 90 ++++++++++++++++++++++----------------------- 1 file changed, 43 insertions(+), 47 deletions(-) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index ace3c790..e47192e5 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -41,39 +41,38 @@ SEARCH_AGENT_PROMPT = """ You are an expert research assistant that helps retrieve the necessary information to answer a user's question. -You have access to a search tool that can query a knowledge base of documents. Each time you call the tool, you will receive a set of relevant document chunks as context. -You may perform up to {allowed_iterations} iterations. In each iteration, you may issue up to {max_questions_per_iteration} search tool queries in parallel. - -Your job is to: -1) Evaluate if the context provided is sufficient to answer the user's question with confidence. -2) If sufficient, respond with "Context is sufficient" and stop iterating. -3) If insufficient, reason on which information is still required to answer the question. -4) Call the search tool with precise, single-faceted questions to retrieve that information from the knowledge base. -5) Repeat from step 1 -*IMPORTANT*: Your goal is to retrieve the necessary information with AS FEW iterations and tool calls AS POSSIBLE. Be strategic and efficient in your retrieval process. +You need to use a search tool that queries a knowledge base of documents. 
Each time you call the tool, you will receive a set of relevant document chunks as context. +You can use this context to iteratively refine your search and gather more information until you have enough to answer the user's question. +Once you do, respond with "Context is sufficient" and stop iterating. +*IMPORTANT*: You MUST iterate AS FEW TIMES AS POSSIBLE. Be strategic and efficient in your retrieval process. ## Query guidelines (tool calls) -- Each tool call must be a single, precise question (single facet). Split multi-facets into separate calls. -- Resolve pronouns and vague references into explicit nouns/entities. -Compared to prior iterations: +- Each query must be a short, simple, precise question (single facet). +- Optimize questions for document retrieval: use keywords, explicit nouns/entities names, dates ... +- When a tool call does not return any relevant information, pivot your line of questioning. +Always consider prior asked questions before asking a new one: - DO NOT ask the same question twice. - DO NOT ask semantically overlapping questions. -## Example -Original user question: "Which city has a larger population, City A or City B?" -Reasoning: We need the population of both cities. This is not common knowledge, so we will use the search tool to find this information. -Iteration 1: - - Tool Call 1: "Population of City A" - - Tool Call 2: "Population of City B" -Retrieved information: - - "The population of City A is 1,000,000." - - "The population of the urban area of City B is 1,200,000" -Assessment: INSUFFICIENT (urban area population includes city population and surroundings, so we cannot confidently compare) -Iteration 2: - - Tool Call 1: "Population of City B (city proper)" -Retrieved information: - - "The population of the city proper of City B is 900,000." -Assessment: SUFFICIENT (now we have comparable population figures for both cities) +## Example of bad questions: +- "What is the population of City A and City B?" 
(multi-faceted, not precise) +- "What is the population of City A?" followed by "What about City B?" (vague question, not optimized for retrieval) +- "What is the population of City A?" followed by "What is the population of City A?" (same question twice, not strategic) +- "When did David Gilmour join Pink Floyd and when did Syd Barrett leave? give months/years and reason" (multi-faceted, too complex) +- "Timeline of The Offspring band lineup changes drummers bassists guitarists with years (James Lilja, Ron Welty, Atom Willard, Pete Parada, Josh Freese, Brandon Pertzborn, Greg K., Todd Morse, Noodles)" (multi-faceted, too complex) + +## Example of good questions: +- "When did David Gilmour join Pink Floyd?" (single-faceted, precise) +- "Timeline of the Offspring" (optimized for retrieval, can be followed by more specific questions if needed) +- "What is the population of City A?" in parallel with "What is the population of City B?" (single-faceted, precise, non-repetitive) +- "When was X born?" instead of "How old is X?" (optimized for retrieval, as age depends on current date) +""".strip() + +NO_TOOLS_FOLLOW_UP_PROMPT = """ +Tools are unavailable for this step. +Do not call or reference any tool/function. +Try to answer the question to the best of your ability using only the context provided and your general knowledge. +If that is not possible, unknowledge it. """.strip() @@ -295,16 +294,16 @@ def _get_tools( "function": { "name": "search_knowledge_base", "description": ( - "Search the knowledge base.\n" + "Search the knowledge base for contextual information needed to answer the user question.\n" + "Use the exact user question as the query to the knowledge base. Only rephrase if necessary for clarity.\n" "IMPORTANT: You MAY not use this function if the question can be answered with common knowledge or straightforward reasoning.\n" - "Reformulate the question to ensure clarity, precision, and specificity." 
), "parameters": { "type": "object", "properties": { "query": { "type": "string", - "description": "The search question.", + "description": "The exact user question. Only add current date information.", }, }, "required": ["query"], @@ -349,8 +348,7 @@ def _run_tool( "name": "query_knowledge_base", "description": ( "Search the knowledge base with a single faceted question. " - "Each question must be precise and focused on a specific piece of information. " - "Multi-faceted questions that can be split into separate queries are not allowed. \n" + "Multi-faceted questions are not allowed and should be broken down into multiple calls. \n" "Example of a bad question: 'What is the population of City A and the GDP of Country B?'\n" "Example of good questions: 'What is the population of City A?', 'What is the GDP of Country B?'\n" ), @@ -359,7 +357,7 @@ def _run_tool( "properties": { "query": { "type": "string", - "description": "A precise, single-faceted question to search for in the knowledge base.", + "description": "A short, precise, single-faceted question.", }, }, "required": ["query"], @@ -372,30 +370,26 @@ def _run_tool( chunk_spans = [] iterations = 0 while True: + iterations += 1 response = completion( model=config.llm, messages=messages, tools=[tool], - tool_choice="auto", + tool_choice="required" if iterations == 1 else "auto", ) messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] # check if the tool call is valid if tool_calls is not None: - messages.extend( - _run_tools( - tool_calls, - lambda spans: chunk_spans.extend(spans), - config, - messages=messages, - ) - ) + new_spans: list[ChunkSpan] = [] + messages.extend(_run_tools(tool_calls, new_spans.extend, config, messages=messages)) + # check new chunks and extend chunk_spans without duplicates + chunk_spans.extend([span for span in new_spans if span not in chunk_spans]) else: break # check if 
we've reached the maximum number of allowed iterations - iterations += 1 if iterations >= config.allowed_iterations: break @@ -470,11 +464,11 @@ def _run_tools( def _stream_rag_response( - messages: list[dict[str, str]], config: RAGLiteConfig + messages: list[dict[str, str]], config: RAGLiteConfig, *, use_tools: bool = True ) -> Generator[str, None, list[Any]]: """Stream the RAG response, which may include tool calls for retrieval.""" max_tokens = get_context_size(config) - tools, tool_choice = _get_tools(messages, config) + tools, tool_choice = _get_tools(messages, config) if use_tools else (None, None) local_chunks: list[Any] = [] stream = completion( model=config.llm, @@ -504,7 +498,8 @@ def rag( if tool_calls: messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages)) - chunks = yield from _stream_rag_response(messages, config) + messages.append({"role": "system", "content": NO_TOOLS_FOLLOW_UP_PROMPT}) + chunks = yield from _stream_rag_response(messages, config, use_tools=False) response = stream_chunk_builder(chunks, messages) messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] else: @@ -542,6 +537,7 @@ async def async_rag( # Run the tool calls to retrieve the RAG context and append the output to the message array. # TODO: Make this async. messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages)) + messages.append({"role": "system", "content": NO_TOOLS_FOLLOW_UP_PROMPT}) # Asynchronously stream the assistant response. 
chunks = [] async_stream = await acompletion( From 313b675ccaec1e5415f314cf386e7c0a4149fec7 Mon Sep 17 00:00:00 2001 From: r-dh Date: Wed, 25 Feb 2026 14:48:25 +0100 Subject: [PATCH 05/16] fix: make auto-retrieval assertion robust to parallel tool calls (cherry picked from commit 2fc5d7dc8c9b73856f0341a1b532a391c8d3c398) --- tests/test_rag.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tests/test_rag.py b/tests/test_rag.py index 7277f23e..dca54758 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -40,8 +40,13 @@ def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None: answer += update assert "event" in answer.lower() # Verify that RAG context was retrieved automatically. - assert [message["role"] for message in messages] == ["user", "assistant", "tool", "assistant"] - assert json.loads(messages[-2]["content"]) + roles = [message["role"] for message in messages] + assert roles[0] == "user" + assert roles[-1] == "assistant" + assert "tool" in roles # At least one retrieval happened. + # Verify the last tool message contains valid JSON. 
+ last_tool_idx = len(roles) - 1 - roles[::-1].index("tool") + assert json.loads(messages[last_tool_idx]["content"]) if not raglite_test_config.llm.startswith("llama-cpp-python"): assert chunk_spans assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans) From 025a28559b44bfff3f8da8011335a3ad67fca0fc Mon Sep 17 00:00:00 2001 From: r-dh Date: Thu, 26 Feb 2026 16:14:34 +0100 Subject: [PATCH 06/16] fix: drop unsupported params in self-query LLM call for GPT-5 compatibility (cherry picked from commit d46fcf6c84700df1007d0280128fb1f6e895a523) --- src/raglite/_litellm.py | 1 - src/raglite/_search.py | 4 +++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/src/raglite/_litellm.py b/src/raglite/_litellm.py index f93c5c70..93739e1a 100644 --- a/src/raglite/_litellm.py +++ b/src/raglite/_litellm.py @@ -36,7 +36,6 @@ # Reduce the logging level for LiteLLM, flashrank, and httpx. litellm.suppress_debug_info = True -litellm.drop_params = True # Drop unsupported parameters for models like GPT-5 os.environ["LITELLM_LOG"] = "WARNING" logging.getLogger("LiteLLM").setLevel(logging.WARNING) logging.getLogger("flashrank").setLevel(logging.WARNING) diff --git a/src/raglite/_search.py b/src/raglite/_search.py index e1158d60..265901cc 100644 --- a/src/raglite/_search.py +++ b/src/raglite/_search.py @@ -56,7 +56,8 @@ def vector_search( # Apply the query adapter to the query embedding. 
if ( config.vector_search_query_adapter - and (Q := IndexMetadata.get("default", config=config).get("query_adapter")) is not None # noqa: N806 + and (Q := IndexMetadata.get("default", config=config).get("query_adapter")) # noqa: N806 + is not None ): query_embedding = (Q @ query_embedding).astype(query_embedding.dtype) # Rank the chunks by relevance according to the L∞ norm of the similarities of the multi-vector @@ -503,6 +504,7 @@ def _self_query( user_prompt=query, config=config, temperature=0.0, # Deterministic output if the model allows + drop_params=True, ) except ValueError as e: logger.debug("Failed to extract metadata filter: %s", e) From abb9e741fc4c7b3861d96c14126c3e0d6f37f5a0 Mon Sep 17 00:00:00 2001 From: MattiaMolon Date: Mon, 2 Mar 2026 16:00:52 +0100 Subject: [PATCH 07/16] fix: fixed token budget miscalculation --- src/raglite/_rag.py | 39 +++++++++++++++++---------------------- 1 file changed, 17 insertions(+), 22 deletions(-) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index e47192e5..68c51e14 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -115,22 +115,14 @@ def _get_last_message_idx(messages: list[dict[str, str]], role: str) -> int | No def _calculate_buffer_tokens( messages: list[dict[str, str]] | None, - roles: list[str], user_prompt: str | None, template: str, ) -> int: - """Calculate the number of tokens used by other messages.""" - # Calculate already used tokens (buffer) - buffer = 0 - # Triggered when using tool calls + """Calculate the number of tokens used by existing messages.""" + # Triggered when using tool calls: count all messages. 
if messages: - # Count used tokens by the last message of each role - for role in roles: - idx = _get_last_message_idx(messages, role) - if idx is not None: - buffer += _count_tokens(json.dumps(messages[idx])) - return buffer - # Triggered when using add_context + return sum(_count_tokens(json.dumps(m, ensure_ascii=False)) for m in messages) + # Triggered when using add_context: count template overhead. if user_prompt: return _count_tokens(template.format(context="", user_prompt=user_prompt)) return 0 @@ -172,11 +164,10 @@ def _limit_chunkspans( ) -> dict[str, list[ChunkSpan]]: """Limit chunk spans to fit within the context window.""" # Calculate already used tokens (buffer) - buffer = _calculate_buffer_tokens( - messages, ["user", "system", "assistant"], user_prompt, template - ) - # Determine max tokens available for context - max_tokens = get_context_size(config) - buffer + buffer = _calculate_buffer_tokens(messages, user_prompt, template) + # Determine max tokens available for context, reserving space for the LLM's response. + max_output_tokens = min(2048, get_context_size(config) // 4) + max_tokens = get_context_size(config) - buffer - max_output_tokens # Compute token counts for all chunk spans per tool tool_tokens_list: dict[str, list[int]] = {} tool_total_tokens: dict[str, int] = {} @@ -258,7 +249,8 @@ def _clip(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str max_tokens, ) # Try to include both last system and user messages if they fit together. - # If not, include just user if it fits, else return empty. + # If not, always preserve at least the last user message — the token estimate + # is approximate, and dropping all messages guarantees a crash. 
idx_system = _get_last_message_idx(messages, "system") if ( idx_user is not None @@ -267,9 +259,9 @@ def _clip(messages: list[dict[str, str]], max_tokens: int) -> list[dict[str, str and token_counts[idx_user] + token_counts[idx_system] <= max_tokens ): return [messages[idx_system], messages[idx_user]] - if idx_user is not None and token_counts[idx_user] <= max_tokens: + if idx_user is not None: return [messages[idx_user]] - return [] + return messages[-1:] return messages[cutoff_idx:] @@ -467,15 +459,18 @@ def _stream_rag_response( messages: list[dict[str, str]], config: RAGLiteConfig, *, use_tools: bool = True ) -> Generator[str, None, list[Any]]: """Stream the RAG response, which may include tool calls for retrieval.""" - max_tokens = get_context_size(config) + context_size = get_context_size(config) + max_output_tokens = min(2048, context_size // 4) + max_input_tokens = context_size - max_output_tokens tools, tool_choice = _get_tools(messages, config) if use_tools else (None, None) local_chunks: list[Any] = [] stream = completion( model=config.llm, - messages=_clip(messages, max_tokens), + messages=_clip(messages, max_input_tokens), tools=tools, tool_choice=tool_choice, stream=True, + max_tokens=max_output_tokens, ) for chunk in stream: local_chunks.append(chunk) From f61747eef09970b8f69ef0f4ebe782b6b034d200 Mon Sep 17 00:00:00 2001 From: MattiaMolon Date: Mon, 2 Mar 2026 16:42:38 +0100 Subject: [PATCH 08/16] fix: pass metadata_filter to run_tools --- src/raglite/_rag.py | 45 +++++++++++++++--- tests/test_rag.py | 110 +++++++++++++++++++++++++++++++++++++++++++- 2 files changed, 148 insertions(+), 7 deletions(-) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 68c51e14..3afabd18 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -314,6 +314,8 @@ def _get_tools( def _run_tool( tool_call: ChatCompletionMessageToolCall, config: RAGLiteConfig, + *, + metadata_filter: MetadataFilter | None = None, ) -> tuple[str, list[ChunkSpan]]: """ Run a 
single tool to search the knowledge base. @@ -375,7 +377,15 @@ def _run_tool( # check if the tool call is valid if tool_calls is not None: new_spans: list[ChunkSpan] = [] - messages.extend(_run_tools(tool_calls, new_spans.extend, config, messages=messages)) + messages.extend( + _run_tools( + tool_calls, + new_spans.extend, + config, + messages=messages, + metadata_filter=metadata_filter, + ) + ) # check new chunks and extend chunk_spans without duplicates chunk_spans.extend([span for span in new_spans if span not in chunk_spans]) else: @@ -391,6 +401,8 @@ def _run_tool( if tool_call.function.name == "query_knowledge_base": kwargs = json.loads(tool_call.function.arguments) kwargs["config"] = config + if metadata_filter is not None: + kwargs["metadata_filter"] = metadata_filter chunk_spans = retrieve_context(**kwargs) # Return ID and data so the main function can aggregate and limit them return tool_call.id, chunk_spans @@ -404,15 +416,18 @@ def _run_tools( config: RAGLiteConfig, *, messages: list[dict[str, str]] | None, - max_workers: int | None = None, + metadata_filter: MetadataFilter | None = None, ) -> list[dict[str, Any]]: """Run tools in parallel, limit the total context, then format messages.""" tool_chunk_spans: dict[str, list[ChunkSpan]] = {} # 1. 
Parallel Execution # We use the _run_tool helper to fetch data concurrently - with ThreadPoolExecutor(max_workers=max_workers) as executor: - futures = [executor.submit(_run_tool, tool_call, config) for tool_call in tool_calls] + with ThreadPoolExecutor() as executor: + futures = [ + executor.submit(_run_tool, tool_call, config, metadata_filter=metadata_filter) + for tool_call in tool_calls + ] # Collect results as they finish try: @@ -483,6 +498,7 @@ def rag( messages: list[dict[str, str]], *, on_retrieval: Callable[[list[ChunkSpan]], None] | None = None, + metadata_filter: MetadataFilter | None = None, config: RAGLiteConfig, ) -> Iterator[str]: """Run retrieval-augmented generation with the given messages and config.""" @@ -492,7 +508,15 @@ def rag( tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] if tool_calls: - messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages)) + messages.extend( + _run_tools( + tool_calls, + on_retrieval, + config, + messages=messages, + metadata_filter=metadata_filter, + ) + ) messages.append({"role": "system", "content": NO_TOOLS_FOLLOW_UP_PROMPT}) chunks = yield from _stream_rag_response(messages, config, use_tools=False) response = stream_chunk_builder(chunks, messages) @@ -505,6 +529,7 @@ async def async_rag( messages: list[dict[str, str]], *, on_retrieval: Callable[[list[ChunkSpan]], None] | None = None, + metadata_filter: MetadataFilter | None = None, config: RAGLiteConfig, ) -> AsyncIterator[str]: # If the final message does not contain RAG context, get a tool to search the knowledge base. @@ -531,7 +556,15 @@ async def async_rag( messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] # Run the tool calls to retrieve the RAG context and append the output to the message array. # TODO: Make this async. 
- messages.extend(_run_tools(tool_calls, on_retrieval, config, messages=messages)) + messages.extend( + _run_tools( + tool_calls, + on_retrieval, + config, + messages=messages, + metadata_filter=metadata_filter, + ) + ) messages.append({"role": "system", "content": NO_TOOLS_FOLLOW_UP_PROMPT}) # Asynchronously stream the assistant response. chunks = [] diff --git a/tests/test_rag.py b/tests/test_rag.py index dca54758..646161a8 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -1,6 +1,10 @@ """Test RAGLite's RAG functionality.""" import json +from types import SimpleNamespace +from typing import TYPE_CHECKING, Any + +import pytest from raglite import ( RAGLiteConfig, @@ -8,7 +12,10 @@ retrieve_context, ) from raglite._database import ChunkSpan -from raglite._rag import rag +from raglite._rag import _run_tool, rag + +if TYPE_CHECKING: + from raglite._typing import MetadataFilter def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None: @@ -83,3 +90,104 @@ def test_retrieve_context_self_query(raglite_test_config: RAGLiteConfig) -> None assert chunk_span.document.metadata_.get("author") == ["Albert Einstein"], ( f"Expected author='Albert Einstein', got {chunk_span.document.metadata_.get('author')}" ) + + +def test_agentic_search_threads_metadata_filter_to_nested_tool_calls( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Pass metadata filters from search_knowledge_base down to nested tool calls.""" + config = RAGLiteConfig( + llm="gpt-5-mini", + embedder="text-embedding-3-small", + db_url="duckdb:///:memory:", + ) + metadata_filter: MetadataFilter = {"topic": ["Physics"]} + nested_tool_call = SimpleNamespace( + function=SimpleNamespace( + name="query_knowledge_base", + arguments=json.dumps({"query": "What is time dilation?"}), + ), + id="query_call_id", + ) + search_tool_call = SimpleNamespace( + function=SimpleNamespace( + name="search_knowledge_base", + arguments=json.dumps({"query": "Explain Einstein's time dilation."}), + ), + id="search_call_id", 
+ ) + + def _make_response( + tool_calls: list[Any] | None, + ) -> Any: + message = SimpleNamespace( + tool_calls=tool_calls, + to_dict=lambda: {"role": "assistant", "content": ""}, + ) + return SimpleNamespace(choices=[SimpleNamespace(message=message)]) + + completion_responses = [ + _make_response([nested_tool_call]), + _make_response(None), + ] + + def fake_completion(**_: Any) -> Any: + return completion_responses.pop(0) + + observed_metadata_filters: list[MetadataFilter | None] = [] + + def fake_run_tools(*args: Any, **kwargs: Any) -> list[dict[str, Any]]: + tool_calls = args[0] + assert len(tool_calls) == 1 + assert tool_calls[0].function.name == "query_knowledge_base" + observed_metadata_filters.append(kwargs.get("metadata_filter")) + return [] + + monkeypatch.setattr("raglite._rag.completion", fake_completion) + monkeypatch.setattr("raglite._rag._run_tools", fake_run_tools) + + tool_id, chunk_spans = _run_tool( + search_tool_call, + config, + metadata_filter=metadata_filter, + ) + + assert tool_id == "search_call_id" + assert chunk_spans == [] + assert observed_metadata_filters == [metadata_filter] + + +def test_query_tool_call_passes_metadata_filter_to_retrieve_context( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Pass metadata_filter to retrieve_context when running query_knowledge_base.""" + config = RAGLiteConfig( + llm="gpt-5-mini", + embedder="text-embedding-3-small", + db_url="duckdb:///:memory:", + ) + metadata_filter: MetadataFilter = {"type": ["Paper"], "author": ["Albert Einstein"]} + tool_call = SimpleNamespace( + function=SimpleNamespace( + name="query_knowledge_base", + arguments=json.dumps({"query": "How is simultaneity defined?"}), + ), + id="query_call_id", + ) + observed_kwargs: dict[str, Any] = {} + + def fake_retrieve_context(**kwargs: Any) -> list[ChunkSpan]: + observed_kwargs.update(kwargs) + return [] + + monkeypatch.setattr("raglite._rag.retrieve_context", fake_retrieve_context) + + tool_id, chunk_spans = _run_tool( + 
tool_call, + config, + metadata_filter=metadata_filter, + ) + + assert tool_id == "query_call_id" + assert chunk_spans == [] + assert observed_kwargs["metadata_filter"] == metadata_filter From ef3ebf511d2e0c1016ee52480950382be91602af Mon Sep 17 00:00:00 2001 From: MattiaMolon Date: Mon, 2 Mar 2026 17:07:28 +0100 Subject: [PATCH 09/16] fix: deduplicate chunk spans based on chunk ids --- src/raglite/_rag.py | 20 +++++++++---- tests/test_rag.py | 68 +++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 82 insertions(+), 6 deletions(-) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 3afabd18..273db96d 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -360,8 +360,9 @@ def _run_tool( }, } - # start iterating - chunk_spans = [] + # Start iterating and keep only chunk spans that introduce at least one new chunk ID. + chunk_spans: list[ChunkSpan] = [] + seen_chunk_ids: set[str] = set() iterations = 0 while True: iterations += 1 @@ -376,18 +377,25 @@ def _run_tool( # check if the tool call is valid if tool_calls is not None: - new_spans: list[ChunkSpan] = [] + retrieved_chunk_spans: list[ChunkSpan] = [] messages.extend( _run_tools( tool_calls, - new_spans.extend, + retrieved_chunk_spans.extend, config, messages=messages, metadata_filter=metadata_filter, ) ) - # check new chunks and extend chunk_spans without duplicates - chunk_spans.extend([span for span in new_spans if span not in chunk_spans]) + # Keep a span if it contains at least one chunk we have not seen before. 
+ novel_chunk_spans = [ + chunk_span + for chunk_span in retrieved_chunk_spans + if any(chunk.id not in seen_chunk_ids for chunk in chunk_span.chunks) + ] + chunk_spans.extend(novel_chunk_spans) + for chunk_span in novel_chunk_spans: + seen_chunk_ids.update(chunk.id for chunk in chunk_span.chunks) else: break diff --git a/tests/test_rag.py b/tests/test_rag.py index 646161a8..66b0fd49 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -191,3 +191,71 @@ def fake_retrieve_context(**kwargs: Any) -> list[ChunkSpan]: assert tool_id == "query_call_id" assert chunk_spans == [] assert observed_kwargs["metadata_filter"] == metadata_filter + + +def test_sub_agent_deduplicates_chunk_spans_by_chunk_id( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Drop fully redundant spans and keep partially novel spans in sub-agent search.""" + config = RAGLiteConfig( + llm="gpt-5-mini", + embedder="text-embedding-3-small", + db_url="duckdb:///:memory:", + ) + search_tool_call = SimpleNamespace( + function=SimpleNamespace( + name="search_knowledge_base", + arguments=json.dumps({"query": "Explain relativity."}), + ), + id="search_call_id", + ) + + def _make_response(tool_calls: list[Any] | None) -> Any: + message = SimpleNamespace( + tool_calls=tool_calls, + to_dict=lambda: {"role": "assistant", "content": ""}, + ) + return SimpleNamespace(choices=[SimpleNamespace(message=message)]) + + nested_tool_call = SimpleNamespace( + function=SimpleNamespace( + name="query_knowledge_base", + arguments=json.dumps({"query": "Q"}), + ), + id="query_call_id", + ) + completion_responses = [ + _make_response([nested_tool_call]), + _make_response([nested_tool_call]), + _make_response(None), + ] + + def fake_completion(**_: Any) -> Any: + return completion_responses.pop(0) + + def make_chunk_span(*chunk_ids: str) -> Any: + return SimpleNamespace(chunks=[SimpleNamespace(id=chunk_id) for chunk_id in chunk_ids]) + + first_iteration_spans = [ + make_chunk_span("A", "B"), + ] + second_iteration_spans = 
[ + make_chunk_span("A", "B"), + make_chunk_span("B", "C"), + ] + tool_results_by_iteration = [first_iteration_spans, second_iteration_spans] + + def fake_run_tools(*args: Any, **_: Any) -> list[dict[str, Any]]: + on_retrieval = args[1] + on_retrieval(tool_results_by_iteration.pop(0)) + return [] + + monkeypatch.setattr("raglite._rag.completion", fake_completion) + monkeypatch.setattr("raglite._rag._run_tools", fake_run_tools) + + _, chunk_spans = _run_tool(search_tool_call, config) + + actual_chunk_id_sequences = [ + [chunk.id for chunk in chunk_span.chunks] for chunk_span in chunk_spans + ] + assert actual_chunk_id_sequences == [["A", "B"], ["B", "C"]] From fa7c163449fdeceecdb9fc28618bcedb251e3fa0 Mon Sep 17 00:00:00 2001 From: MattiaMolon Date: Mon, 2 Mar 2026 17:29:35 +0100 Subject: [PATCH 10/16] fix: separate live messages from working messages object during sub calls --- src/raglite/_rag.py | 48 +++++++++++++++++++++++---------------------- tests/test_rag.py | 23 ++++++++++++++++++++++ 2 files changed, 48 insertions(+), 23 deletions(-) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 273db96d..88f16e58 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -510,27 +510,28 @@ def rag( config: RAGLiteConfig, ) -> Iterator[str]: """Run retrieval-augmented generation with the given messages and config.""" - chunks = yield from _stream_rag_response(messages, config) - response = stream_chunk_builder(chunks, messages) - messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + working = list(messages) + chunks = yield from _stream_rag_response(working, config) + response = stream_chunk_builder(chunks, working) + working.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] if tool_calls: - messages.extend( + working.extend( _run_tools( tool_calls, on_retrieval, config, - messages=messages, + 
messages=working, metadata_filter=metadata_filter, ) ) - messages.append({"role": "system", "content": NO_TOOLS_FOLLOW_UP_PROMPT}) - chunks = yield from _stream_rag_response(messages, config, use_tools=False) - response = stream_chunk_builder(chunks, messages) - messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] - else: - messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + follow_up_messages = [*working, {"role": "system", "content": NO_TOOLS_FOLLOW_UP_PROMPT}] + chunks = yield from _stream_rag_response(follow_up_messages, config, use_tools=False) + response = stream_chunk_builder(chunks, follow_up_messages) + working.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + + messages.extend(working[len(messages) :]) async def async_rag( @@ -542,11 +543,12 @@ async def async_rag( ) -> AsyncIterator[str]: # If the final message does not contain RAG context, get a tool to search the knowledge base. max_tokens = get_context_size(config) - tools, tool_choice = _get_tools(messages, config) + working = list(messages) + tools, tool_choice = _get_tools(working, config) # Asynchronously stream the LLM response, which is either a tool call or an assistant response. async_stream = await acompletion( model=config.llm, - messages=_clip(messages, max_tokens), + messages=_clip(working, max_tokens), tools=tools, tool_choice=tool_choice, stream=True, @@ -557,32 +559,32 @@ async def async_rag( if isinstance(token := chunk.choices[0].delta.content, str): yield token # Check if there are tools to be called. - response = stream_chunk_builder(chunks, messages) + response = stream_chunk_builder(chunks, working) + working.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] if tool_calls: - # Add the tool call requests to the message array. 
- messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] # Run the tool calls to retrieve the RAG context and append the output to the message array. # TODO: Make this async. - messages.extend( + working.extend( _run_tools( tool_calls, on_retrieval, config, - messages=messages, + messages=working, metadata_filter=metadata_filter, ) ) - messages.append({"role": "system", "content": NO_TOOLS_FOLLOW_UP_PROMPT}) + follow_up_messages = [*working, {"role": "system", "content": NO_TOOLS_FOLLOW_UP_PROMPT}] # Asynchronously stream the assistant response. chunks = [] async_stream = await acompletion( - model=config.llm, messages=_clip(messages, max_tokens), stream=True + model=config.llm, messages=_clip(follow_up_messages, max_tokens), stream=True ) async for chunk in async_stream: chunks.append(chunk) if isinstance(token := chunk.choices[0].delta.content, str): yield token - # Append the assistant response to the message array. - response = stream_chunk_builder(chunks, messages) - messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + response = stream_chunk_builder(chunks, follow_up_messages) + working.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] + + messages.extend(working[len(messages) :]) diff --git a/tests/test_rag.py b/tests/test_rag.py index 66b0fd49..302b8138 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -259,3 +259,26 @@ def fake_run_tools(*args: Any, **_: Any) -> list[dict[str, Any]]: [chunk.id for chunk in chunk_span.chunks] for chunk_span in chunk_spans ] assert actual_chunk_id_sequences == [["A", "B"], ["B", "C"]] + + +def test_rag_does_not_mutate_caller_messages_on_stream_error( + monkeypatch: pytest.MonkeyPatch, +) -> None: + """Keep caller messages unchanged when an exception happens mid-rag.""" + config = RAGLiteConfig( + llm="gpt-5-mini", + embedder="text-embedding-3-small", + db_url="duckdb:///:memory:", + ) + messages = [{"role": 
"user", "content": "Hello"}] + original_messages = list(messages) + + def fake_stream(*_: Any, **__: Any) -> Any: + error_message = "stream failure" + raise RuntimeError(error_message) + + monkeypatch.setattr("raglite._rag._stream_rag_response", fake_stream) + + with pytest.raises(RuntimeError, match="stream failure"): + list(rag(messages, config=config)) + assert messages == original_messages From cff150186ec1bb5050012c80ffcce76b73deff74 Mon Sep 17 00:00:00 2001 From: MattiaMolon Date: Mon, 2 Mar 2026 17:45:53 +0100 Subject: [PATCH 11/16] feat: updated async_rag to match rag --- src/raglite/_rag.py | 62 +++++++++++++++++++++++++-------------------- 1 file changed, 35 insertions(+), 27 deletions(-) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 88f16e58..44ca2f8a 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -502,6 +502,32 @@ def _stream_rag_response( return local_chunks +async def _async_stream_rag_response( + messages: list[dict[str, str]], + config: RAGLiteConfig, + response_chunks: list[Any], + *, + use_tools: bool = True, +) -> AsyncIterator[str]: + """Async version of _stream_rag_response.""" + context_size = get_context_size(config) + max_output_tokens = min(2048, context_size // 4) + max_input_tokens = context_size - max_output_tokens + tools, tool_choice = _get_tools(messages, config) if use_tools else (None, None) + async_stream = await acompletion( + model=config.llm, + messages=_clip(messages, max_input_tokens), + tools=tools, + tool_choice=tool_choice, + stream=True, + max_tokens=max_output_tokens, + ) + async for chunk in async_stream: + response_chunks.append(chunk) + if isinstance(token := chunk.choices[0].delta.content, str): + yield token + + def rag( messages: list[dict[str, str]], *, @@ -541,30 +567,16 @@ async def async_rag( metadata_filter: MetadataFilter | None = None, config: RAGLiteConfig, ) -> AsyncIterator[str]: - # If the final message does not contain RAG context, get a tool to search the knowledge base. 
- max_tokens = get_context_size(config) + """Run retrieval-augmented generation with the given messages and config.""" working = list(messages) - tools, tool_choice = _get_tools(working, config) - # Asynchronously stream the LLM response, which is either a tool call or an assistant response. - async_stream = await acompletion( - model=config.llm, - messages=_clip(working, max_tokens), - tools=tools, - tool_choice=tool_choice, - stream=True, - ) - chunks = [] - async for chunk in async_stream: - chunks.append(chunk) - if isinstance(token := chunk.choices[0].delta.content, str): - yield token - # Check if there are tools to be called. + chunks: list[Any] = [] + async for token in _async_stream_rag_response(working, config, chunks): + yield token response = stream_chunk_builder(chunks, working) working.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] + if tool_calls: - # Run the tool calls to retrieve the RAG context and append the output to the message array. - # TODO: Make this async. working.extend( _run_tools( tool_calls, @@ -575,15 +587,11 @@ async def async_rag( ) ) follow_up_messages = [*working, {"role": "system", "content": NO_TOOLS_FOLLOW_UP_PROMPT}] - # Asynchronously stream the assistant response. 
chunks = [] - async_stream = await acompletion( - model=config.llm, messages=_clip(follow_up_messages, max_tokens), stream=True - ) - async for chunk in async_stream: - chunks.append(chunk) - if isinstance(token := chunk.choices[0].delta.content, str): - yield token + async for token in _async_stream_rag_response( + follow_up_messages, config, chunks, use_tools=False + ): + yield token response = stream_chunk_builder(chunks, follow_up_messages) working.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] From 120192a5a4a1c92573675a15ccf9ad1d67a71b02 Mon Sep 17 00:00:00 2001 From: MattiaMolon Date: Tue, 3 Mar 2026 14:42:52 +0100 Subject: [PATCH 12/16] feat: code clean up --- src/raglite/_rag.py | 59 +++++++++++++++++---------------------------- 1 file changed, 22 insertions(+), 37 deletions(-) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 44ca2f8a..e1ce7bf2 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -90,14 +90,13 @@ def retrieve_context( query, num_results=num_chunks, metadata_filter=metadata_filter, config=config ) # Convert results to chunk spans. 
- chunk_spans = [] if isinstance(results, tuple): - chunk_spans = retrieve_chunk_spans(results[0], config=config) - elif all(isinstance(result, Chunk) for result in results): - chunk_spans = retrieve_chunk_spans(results, config=config) # type: ignore[arg-type] - elif all(isinstance(result, ChunkSpan) for result in results): - chunk_spans = results # type: ignore[assignment] - return chunk_spans + return retrieve_chunk_spans(results[0], config=config) + if all(isinstance(result, Chunk) for result in results): + return retrieve_chunk_spans(results, config=config) # type: ignore[arg-type] + if all(isinstance(result, ChunkSpan) for result in results): + return list(results) # type: ignore[arg-type] + return [] def _count_tokens(item: str) -> int: @@ -138,20 +137,15 @@ def _cutoff_idx(token_counts: list[int], max_tokens: int, *, reverse: bool = Fal def _get_token_counts(items: Sequence[str | ChunkSpan | Mapping[str, str]]) -> list[int]: """Compute token counts for a list of items.""" - return [ - ( - _count_tokens(item.to_xml()) - if isinstance(item, ChunkSpan) - else ( - _count_tokens(json.dumps(item, ensure_ascii=False)) - if isinstance(item, dict) - else _count_tokens(item) - if isinstance(item, str) - else 0 - ) - ) - for item in items - ] + token_counts: list[int] = [] + for item in items: + if isinstance(item, ChunkSpan): + token_counts.append(_count_tokens(item.to_xml())) + elif isinstance(item, Mapping): + token_counts.append(_count_tokens(json.dumps(item, ensure_ascii=False))) + else: + token_counts.append(_count_tokens(item)) + return token_counts def _limit_chunkspans( @@ -327,9 +321,7 @@ def _run_tool( messages = [ { "role": "system", - "content": SEARCH_AGENT_PROMPT.format( - allowed_iterations=config.allowed_iterations, max_questions_per_iteration=3 - ), + "content": SEARCH_AGENT_PROMPT, }, { "role": "user", @@ -363,14 +355,12 @@ def _run_tool( # Start iterating and keep only chunk spans that introduce at least one new chunk ID. 
chunk_spans: list[ChunkSpan] = [] seen_chunk_ids: set[str] = set() - iterations = 0 - while True: - iterations += 1 + for iteration_index in range(max(1, config.allowed_iterations)): response = completion( model=config.llm, messages=messages, tools=[tool], - tool_choice="required" if iterations == 1 else "auto", + tool_choice="required" if iteration_index == 0 else "auto", ) messages.append(response.choices[0].message.to_dict()) # type: ignore[arg-type,union-attr] tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] @@ -399,10 +389,6 @@ def _run_tool( else: break - # check if we've reached the maximum number of allowed iterations - if iterations >= config.allowed_iterations: - break - # Return ID and data so the main function can aggregate and limit them return tool_call.id, chunk_spans @@ -459,14 +445,13 @@ def _run_tools( chunk_spans = tool_chunk_spans.get(tool_id, []) # Create the final message structure + documents = ", ".join( + chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans) + ) tool_messages.append( { "role": "tool", - "content": '{{"documents": [{elements}]}}'.format( - elements=", ".join( - chunk_span.to_json(index=i + 1) for i, chunk_span in enumerate(chunk_spans) - ) - ), + "content": f'{{"documents": [{documents}]}}', "tool_call_id": tool_id, } ) From 96aa82fbe98b18b9c29230890676bb088714c1c7 Mon Sep 17 00:00:00 2001 From: MattiaMolon Date: Tue, 3 Mar 2026 15:03:34 +0100 Subject: [PATCH 13/16] fix: small robustness fixes for edge cases --- src/raglite/_config.py | 2 +- src/raglite/_rag.py | 11 ++++++----- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/raglite/_config.py b/src/raglite/_config.py index 32d2c262..f90377e9 100644 --- a/src/raglite/_config.py +++ b/src/raglite/_config.py @@ -103,4 +103,4 @@ class RAGLiteConfig: # list[Chunk], or list[ChunkSpan]. 
search_method: SearchMethod = field(default=_vector_search, compare=False) self_query: bool = False - allowed_iterations: int = 3 + agentic_iterations: int = 3 diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index e1ce7bf2..861f23e9 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -60,6 +60,7 @@ - "What is the population of City A?" followed by "What is the population of City A?" (same question twice, not strategic) - "When did David Gilmour join Pink Floyd and when did Syd Barrett leave? give months/years and reason" (multi-faceted, too complex) - "Timeline of The Offspring band lineup changes drummers bassists guitarists with years (James Lilja, Ron Welty, Atom Willard, Pete Parada, Josh Freese, Brandon Pertzborn, Greg K., Todd Morse, Noodles)" (multi-faceted, too complex) +- "Has X ever had consecutive Billboard Hot 100 number-one singles? List any runs of consecutive Hot 100 #1 singles (song titles and dates) and the length of her longest such streak (as of August 26, 2024)." (multi-faceted, not precise, not optimized for retrieval) ## Example of good questions: - "When did David Gilmour join Pink Floyd?" (single-faceted, precise) @@ -72,7 +73,7 @@ Tools are unavailable for this step. Do not call or reference any tool/function. Try to answer the question to the best of your ability using only the context provided and your general knowledge. -If that is not possible, unknowledge it. +If that is not possible, acknowledge it. 
""".strip() @@ -175,7 +176,7 @@ def _limit_chunkspans( total_tokens += tool_total total_chunk_spans += len(chunk_spans) # Early exit if we're already under the limit - if total_tokens <= max_tokens: + if total_tokens == 0 or total_tokens <= max_tokens: return tool_chunk_spans # Allocate tokens proportionally and truncate new_total_chunk_spans = 0 @@ -264,7 +265,7 @@ def _get_tools( ) -> tuple[list[dict[str, Any]] | None, dict[str, Any] | str | None]: """Get tools to search the knowledge base if no RAG context is provided in the messages.""" # Check if messages already contain RAG context or if the LLM supports tool use. - final_message = messages[-1].get("content", "") + final_message = messages[-1].get("content") or "" messages_contain_rag_context = any( s in final_message for s in ("", "", "from_chunk_id") ) @@ -289,7 +290,7 @@ def _get_tools( "properties": { "query": { "type": "string", - "description": "The exact user question. Only add current date information.", + "description": "The exact user question. Add current date information if relevant.", }, }, "required": ["query"], @@ -355,7 +356,7 @@ def _run_tool( # Start iterating and keep only chunk spans that introduce at least one new chunk ID. 
chunk_spans: list[ChunkSpan] = [] seen_chunk_ids: set[str] = set() - for iteration_index in range(max(1, config.allowed_iterations)): + for iteration_index in range(max(1, config.agentic_iterations)): response = completion( model=config.llm, messages=messages, From 9f405739f807cdd1a5f48be91dc7f8ace452e15a Mon Sep 17 00:00:00 2001 From: MattiaMolon Date: Tue, 3 Mar 2026 17:10:00 +0000 Subject: [PATCH 14/16] fix: fix required tool call in litellm --- src/raglite/_chatml_function_calling.py | 4 +++- tests/test_rag.py | 9 +++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/src/raglite/_chatml_function_calling.py b/src/raglite/_chatml_function_calling.py index 9b0b240a..59fe49c9 100644 --- a/src/raglite/_chatml_function_calling.py +++ b/src/raglite/_chatml_function_calling.py @@ -400,7 +400,9 @@ def chatml_function_calling_with_streaming( # Case 2: Automatic or fixed tool choice # Case 2 step 1: Determine whether to respond with a message or a tool call - assert (isinstance(tool_choice, str) and tool_choice == "auto") or isinstance(tool_choice, dict) + assert (isinstance(tool_choice, str) and tool_choice in ("auto", "required")) or isinstance( + tool_choice, dict + ) if isinstance(tool_choice, dict): tools = [t for t in tools if t["function"]["name"] == tool_choice["function"]["name"]] assert tools diff --git a/tests/test_rag.py b/tests/test_rag.py index 302b8138..ed550629 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -50,11 +50,12 @@ def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None: roles = [message["role"] for message in messages] assert roles[0] == "user" assert roles[-1] == "assistant" - assert "tool" in roles # At least one retrieval happened. - # Verify the last tool message contains valid JSON. - last_tool_idx = len(roles) - 1 - roles[::-1].index("tool") - assert json.loads(messages[last_tool_idx]["content"]) + if "tool" in roles: + # Verify the last tool message contains valid JSON. 
+ last_tool_idx = len(roles) - 1 - roles[::-1].index("tool") + assert json.loads(messages[last_tool_idx]["content"]) if not raglite_test_config.llm.startswith("llama-cpp-python"): + assert "tool" in roles # At least one retrieval happened. assert chunk_spans assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans) From 76d5f848941b467516a71f019132d0f8ba22642b Mon Sep 17 00:00:00 2001 From: MattiaMolon Date: Thu, 5 Mar 2026 10:50:06 +0100 Subject: [PATCH 15/16] fix: adapted tool description to make it more robust to tests --- src/raglite/_rag.py | 7 +++---- tests/test_rag.py | 7 ++----- 2 files changed, 5 insertions(+), 9 deletions(-) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index 861f23e9..c34b8a8a 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -281,16 +281,15 @@ def _get_tools( "function": { "name": "search_knowledge_base", "description": ( - "Search the knowledge base for contextual information needed to answer the user question.\n" - "Use the exact user question as the query to the knowledge base. Only rephrase if necessary for clarity.\n" - "IMPORTANT: You MAY not use this function if the question can be answered with common knowledge or straightforward reasoning.\n" + "Search the knowledge base for contextual information needed to answer a user question.\n" + "IMPORTANT: You MAY NOT use this function if the question can be answered with common knowledge or straightforward reasoning.\n" ), "parameters": { "type": "object", "properties": { "query": { "type": "string", - "description": "The exact user question. Add current date information if relevant.", + "description": "The exact user question, only rephrase if necessary for clarity. 
Add current date information if relevant.", }, }, "required": ["query"], diff --git a/tests/test_rag.py b/tests/test_rag.py index ed550629..42f9f9d1 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -37,7 +37,7 @@ def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None: def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None: """Test Retrieval-Augmented Generation with automatic retrieval.""" # Answer a question that requires RAG. - user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper?" + user_prompt = "How does Einstein define 'simultaneous events' in his special relativity paper? do not guess and provide me with proof via retrieval" messages = [{"role": "user", "content": user_prompt}] chunk_spans: list[ChunkSpan] = [] stream = rag(messages, on_retrieval=chunk_spans.extend, config=raglite_test_config) @@ -50,10 +50,7 @@ def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None: roles = [message["role"] for message in messages] assert roles[0] == "user" assert roles[-1] == "assistant" - if "tool" in roles: - # Verify the last tool message contains valid JSON. - last_tool_idx = len(roles) - 1 - roles[::-1].index("tool") - assert json.loads(messages[last_tool_idx]["content"]) + assert "tool" in roles # At least one retrieval happened. if not raglite_test_config.llm.startswith("llama-cpp-python"): assert "tool" in roles # At least one retrieval happened. 
assert chunk_spans From ac5a3a147e2af799b39cf63f7e6760189fd17773 Mon Sep 17 00:00:00 2001 From: MattiaMolon Date: Mon, 9 Mar 2026 16:01:29 +0100 Subject: [PATCH 16/16] fix: resolved pr comments - import MetadataFilter at runtime - remove redundant assert tool - added clip to subagent messages - fix bug that allowed empty tool call list to run LLM completion --- src/raglite/_rag.py | 13 ++++++++++--- tests/test_rag.py | 7 ++----- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/raglite/_rag.py b/src/raglite/_rag.py index c34b8a8a..2ee7f450 100644 --- a/src/raglite/_rag.py +++ b/src/raglite/_rag.py @@ -317,7 +317,11 @@ def _run_tool( Returns the tool_id and the raw chunk_spans (before formatting/limiting). """ if tool_call.function.name == "search_knowledge_base": - query = json.loads(tool_call.function.arguments)["query"] + try: + query = json.loads(tool_call.function.arguments)["query"] + except (json.JSONDecodeError, KeyError, TypeError) as exc: + msg = f"Invalid arguments for 'search_knowledge_base': {exc}" + raise ValueError(msg) from exc messages = [ { "role": "system", @@ -355,10 +359,13 @@ def _run_tool( # Start iterating and keep only chunk spans that introduce at least one new chunk ID. 
chunk_spans: list[ChunkSpan] = [] seen_chunk_ids: set[str] = set() + context_size = get_context_size(config) + max_output_tokens = min(2048, context_size // 4) + max_input_tokens = context_size - max_output_tokens for iteration_index in range(max(1, config.agentic_iterations)): response = completion( model=config.llm, - messages=messages, + messages=_clip(messages, max_input_tokens), tools=[tool], tool_choice="required" if iteration_index == 0 else "auto", ) @@ -366,7 +373,7 @@ def _run_tool( tool_calls = response.choices[0].message.tool_calls # type: ignore[union-attr] # check if the tool call is valid - if tool_calls is not None: + if tool_calls: retrieved_chunk_spans: list[ChunkSpan] = [] messages.extend( _run_tools( diff --git a/tests/test_rag.py b/tests/test_rag.py index 42f9f9d1..a8ded950 100644 --- a/tests/test_rag.py +++ b/tests/test_rag.py @@ -2,7 +2,7 @@ import json from types import SimpleNamespace -from typing import TYPE_CHECKING, Any +from typing import Any import pytest @@ -13,9 +13,7 @@ ) from raglite._database import ChunkSpan from raglite._rag import _run_tool, rag - -if TYPE_CHECKING: - from raglite._typing import MetadataFilter +from raglite._typing import MetadataFilter # noqa: TC001 def test_rag_manual(raglite_test_config: RAGLiteConfig) -> None: @@ -52,7 +50,6 @@ def test_rag_auto_with_retrieval(raglite_test_config: RAGLiteConfig) -> None: assert roles[-1] == "assistant" assert "tool" in roles # At least one retrieval happened. if not raglite_test_config.llm.startswith("llama-cpp-python"): - assert "tool" in roles # At least one retrieval happened. assert chunk_spans assert all(isinstance(chunk_span, ChunkSpan) for chunk_span in chunk_spans)