diff --git a/src/strands_tools/retrieve.py b/src/strands_tools/retrieve.py index f5882d79..63c6b072 100644 --- a/src/strands_tools/retrieve.py +++ b/src/strands_tools/retrieve.py @@ -144,12 +144,21 @@ "score": { "type": "number", "description": ( - "Minimum relevance score threshold (0.0-1.0). Results below this score will be filtered out. " - "Default is 0.4." + "Minimum relevance score threshold. For similarity metrics (default), results below this " + "score are filtered out. For distance metrics, results above this score are filtered out. " + "Default is 0.4 for similarity, inf for distance." ), "default": 0.4, - "minimum": 0.0, - "maximum": 1.0, + }, + "score_metric": { + "type": "string", + "enum": ["similarity", "distance"], + "description": ( + "How to interpret the score values. Use 'similarity' (default) when higher scores mean " + "more relevant results (e.g., cosine similarity). Use 'distance' when lower scores mean " + "more relevant results (e.g., cosine distance, pgvector <=>)." + ), + "default": "similarity", }, "profile_name": { "type": "string", @@ -187,21 +196,31 @@ } -def filter_results_by_score(results: List[Dict[str, Any]], min_score: float) -> List[Dict[str, Any]]: +def filter_results_by_score( + results: List[Dict[str, Any]], min_score: float, score_metric: str = "similarity" +) -> List[Dict[str, Any]]: """ - Filter results based on minimum score threshold. + Filter results based on score threshold, respecting the score metric type. This function takes the raw results from a knowledge base query and removes - any items that don't meet the minimum relevance score threshold. + any items that don't meet the score threshold. The filtering direction depends + on the score metric: + - "similarity": higher scores = more relevant, keeps scores >= min_score + - "distance": lower scores = more relevant, keeps scores <= min_score Args: results: List of retrieval results from Bedrock Knowledge Base - min_score: Minimum score threshold (0.0-1.0). Only results with scores - greater than or equal to this value will be returned. + min_score: Score threshold for filtering. For similarity, only results with + scores >= this value are kept. For distance, only results with scores + <= this value are kept. + score_metric: How to interpret scores. "similarity" (default) means higher + is better; "distance" means lower is better. Returns: - List of filtered results that meet or exceed the score threshold + List of filtered results that meet the score threshold """ + if score_metric == "distance": + return [result for result in results if result.get("score", float("inf")) <= min_score] return [result for result in results if result.get("score", 0.0) >= min_score] @@ -317,7 +336,9 @@ def retrieve(tool: ToolUse, **kwargs: Any) -> ToolResult: number_of_results = tool_input.get("numberOfResults", 10) kb_id = tool_input.get("knowledgeBaseId", default_knowledge_base_id) region_name = tool_input.get("region", default_aws_region) - min_score = tool_input.get("score", default_min_score) + score_metric = tool_input.get("score_metric", "similarity") + default_score = default_min_score if score_metric == "similarity" else float("inf") + min_score = tool_input.get("score", default_score) enable_metadata = tool_input.get("enableMetadata", default_enable_metadata) retrieve_filter = tool_input.get("retrieveFilter") @@ -353,7 +374,7 @@ def retrieve(tool: ToolUse, **kwargs: Any) -> ToolResult: # Get and filter results all_results = response.get("retrievalResults", []) - filtered_results = filter_results_by_score(all_results, min_score) + filtered_results = filter_results_by_score(all_results, min_score, score_metric) # Format results for display with optional metadata formatted_results = format_results_for_display(filtered_results, enable_metadata) @@ -363,7 +384,7 @@ def retrieve(tool: ToolUse, **kwargs: Any) -> ToolResult: "toolUseId": tool_use_id, "status": "success", "content": [ - {"text": f"Retrieved {len(filtered_results)} results with score >= {min_score}:\n{formatted_results}"} + {"text": f"Retrieved {len(filtered_results)} results with score {'<=' if score_metric == 'distance' else '>='} {min_score}:\n{formatted_results}"} ], } diff --git a/tests/test_retrieve.py b/tests/test_retrieve.py index 97add221..15dfdd2c 100644 --- a/tests/test_retrieve.py +++ b/tests/test_retrieve.py @@ -677,3 +677,123 @@ def test_retrieve_via_agent_with_enable_metadata(agent, mock_boto3_client): assert "results with score >=" in result_text assert "Metadata:" not in result_text assert "test-source" not in result_text + + +def test_filter_results_by_score_distance_metric(): + """Test filter_results_by_score with distance metric (lower = more relevant).""" + test_results = [ + {"score": 0.1}, # Very close match (distance) + {"score": 0.3}, # Close match + {"score": 0.7}, # Far match + {"score": 1.2}, # Very far match + ] + + # Filter with distance threshold 0.5 — should keep scores <= 0.5 + filtered = retrieve.filter_results_by_score(test_results, 0.5, score_metric="distance") + assert len(filtered) == 2 + assert filtered[0]["score"] == 0.1 + assert filtered[1]["score"] == 0.3 + + # Filter with distance threshold 0.8 — should keep scores <= 0.8 + filtered = retrieve.filter_results_by_score(test_results, 0.8, score_metric="distance") + assert len(filtered) == 3 + assert filtered[0]["score"] == 0.1 + assert filtered[1]["score"] == 0.3 + assert filtered[2]["score"] == 0.7 + + +def test_filter_results_by_score_similarity_metric_explicit(): + """Test filter_results_by_score with explicit similarity metric (regression test).""" + test_results = [{"score": 0.9}, {"score": 0.5}, {"score": 0.3}, {"score": 0.8}] + + # Explicit similarity should behave same as default + filtered = retrieve.filter_results_by_score(test_results, 0.5, score_metric="similarity") + assert len(filtered) == 3 + assert filtered[0]["score"] == 0.9 + assert filtered[1]["score"] == 0.5 + assert filtered[2]["score"] == 0.8 + + +def test_filter_results_by_score_default_unchanged(): + """Test that default behavior (no score_metric) is unchanged for backward compatibility.""" + test_results = [{"score": 0.9}, {"score": 0.5}, {"score": 0.3}] + + # Calling without score_metric should default to similarity (>= filter) + filtered = retrieve.filter_results_by_score(test_results, 0.5) + assert len(filtered) == 2 + assert filtered[0]["score"] == 0.9 + assert filtered[1]["score"] == 0.5 + + +def test_filter_results_distance_default_score_for_missing(): + """Test that missing scores default to inf for distance metric (filtered out).""" + test_results = [{"score": 0.1}, {}, {"score": 0.3}] + + # With distance metric, missing scores default to inf, so they should be filtered out + filtered = retrieve.filter_results_by_score(test_results, 0.5, score_metric="distance") + assert len(filtered) == 2 + assert filtered[0]["score"] == 0.1 + assert filtered[1]["score"] == 0.3 + + +def test_retrieve_tool_distance_metric(mock_boto3_client): + """Test direct invocation of retrieve tool with distance metric.""" + # Override mock to return distance-style scores (low = good) + mock_boto3_client.return_value.retrieve.return_value = { + "retrievalResults": [ + { + "content": {"text": "Close match", "type": "TEXT"}, + "location": {"customDocumentLocation": {"id": "doc-001"}, "type": "CUSTOM"}, + "score": 0.05, + }, + { + "content": {"text": "Medium match", "type": "TEXT"}, + "location": {"customDocumentLocation": {"id": "doc-002"}, "type": "CUSTOM"}, + "score": 0.3, + }, + { + "content": {"text": "Far match", "type": "TEXT"}, + "location": {"customDocumentLocation": {"id": "doc-003"}, "type": "CUSTOM"}, + "score": 0.9, + }, + ] + } + + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "text": "test query", + "knowledgeBaseId": "test-kb-id", + "score": 0.4, + "score_metric": "distance", + }, + } + + result = retrieve.retrieve(tool=tool_use) + + assert result["status"] == "success" + # Should keep scores <= 0.4 (the close and medium matches) + assert "Retrieved 2 results with score <= 0.4" in result["content"][0]["text"] + assert "doc-001" in result["content"][0]["text"] + assert "doc-002" in result["content"][0]["text"] + # Far match (0.9) should be filtered out + assert "doc-003" not in result["content"][0]["text"] + + +def test_retrieve_tool_distance_metric_default_score(mock_boto3_client): + """Test that distance metric without explicit score defaults to inf (keeps all results).""" + tool_use = { + "toolUseId": "test-tool-use-id", + "input": { + "text": "test query", + "knowledgeBaseId": "test-kb-id", + "score_metric": "distance", + # No score parameter — should default to inf, keeping all results + }, + } + + result = retrieve.retrieve(tool=tool_use) + + assert result["status"] == "success" + # All 3 mock results should be kept since default threshold is inf + assert "Retrieved 3 results with score <= inf" in result["content"][0]["text"]