Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 34 additions & 13 deletions src/strands_tools/retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,12 +144,21 @@
"score": {
"type": "number",
"description": (
"Minimum relevance score threshold (0.0-1.0). Results below this score will be filtered out. "
"Default is 0.4."
"Minimum relevance score threshold. For similarity metrics (default), results below this "
"score are filtered out. For distance metrics, results above this score are filtered out. "
"Default is 0.4 for similarity, inf for distance."
),
"default": 0.4,
"minimum": 0.0,
"maximum": 1.0,
},
"score_metric": {
"type": "string",
"enum": ["similarity", "distance"],
"description": (
"How to interpret the score values. Use 'similarity' (default) when higher scores mean "
"more relevant results (e.g., cosine similarity). Use 'distance' when lower scores mean "
"more relevant results (e.g., cosine distance, pgvector <=>)."
),
"default": "similarity",
},
"profile_name": {
"type": "string",
Expand Down Expand Up @@ -187,21 +196,31 @@
}


def filter_results_by_score(
    results: List[Dict[str, Any]], min_score: float, score_metric: str = "similarity"
) -> List[Dict[str, Any]]:
    """
    Filter results based on score threshold, respecting the score metric type.

    This function takes the raw results from a knowledge base query and removes
    any items that don't meet the score threshold. The filtering direction depends
    on the score metric:
    - "similarity": higher scores = more relevant, keeps scores >= min_score
    - "distance": lower scores = more relevant, keeps scores <= min_score

    Args:
        results: List of retrieval results from Bedrock Knowledge Base
        min_score: Score threshold for filtering. For similarity, only results with
            scores >= this value are kept. For distance, only results with scores
            <= this value are kept.
        score_metric: How to interpret scores. "similarity" (default) means higher
            is better; "distance" means lower is better.

    Returns:
        List of filtered results that meet the score threshold

    Raises:
        ValueError: If score_metric is not "similarity" or "distance". Without this
            check, a typo (e.g. "dist") would silently flip the filter direction.
    """
    if score_metric not in ("similarity", "distance"):
        raise ValueError(f"score_metric must be 'similarity' or 'distance', got {score_metric!r}")
    if score_metric == "distance":
        # Lower is better; a missing score is treated as infinitely distant
        # so it is always filtered out.
        return [result for result in results if result.get("score", float("inf")) <= min_score]
    # Higher is better; a missing score defaults to 0.0 so it is filtered out
    # for any positive threshold.
    return [result for result in results if result.get("score", 0.0) >= min_score]


Expand Down Expand Up @@ -317,7 +336,9 @@ def retrieve(tool: ToolUse, **kwargs: Any) -> ToolResult:
number_of_results = tool_input.get("numberOfResults", 10)
kb_id = tool_input.get("knowledgeBaseId", default_knowledge_base_id)
region_name = tool_input.get("region", default_aws_region)
min_score = tool_input.get("score", default_min_score)
score_metric = tool_input.get("score_metric", "similarity")
default_score = default_min_score if score_metric == "similarity" else float("inf")
min_score = tool_input.get("score", default_score)
enable_metadata = tool_input.get("enableMetadata", default_enable_metadata)
retrieve_filter = tool_input.get("retrieveFilter")

Expand Down Expand Up @@ -353,7 +374,7 @@ def retrieve(tool: ToolUse, **kwargs: Any) -> ToolResult:

# Get and filter results
all_results = response.get("retrievalResults", [])
filtered_results = filter_results_by_score(all_results, min_score)
filtered_results = filter_results_by_score(all_results, min_score, score_metric)

# Format results for display with optional metadata
formatted_results = format_results_for_display(filtered_results, enable_metadata)
Expand All @@ -363,7 +384,7 @@ def retrieve(tool: ToolUse, **kwargs: Any) -> ToolResult:
"toolUseId": tool_use_id,
"status": "success",
"content": [
{"text": f"Retrieved {len(filtered_results)} results with score >= {min_score}:\n{formatted_results}"}
{"text": f"Retrieved {len(filtered_results)} results with score {'<=' if score_metric == 'distance' else '>='} {min_score}:\n{formatted_results}"}
],
}

Expand Down
120 changes: 120 additions & 0 deletions tests/test_retrieve.py
Original file line number Diff line number Diff line change
Expand Up @@ -677,3 +677,123 @@ def test_retrieve_via_agent_with_enable_metadata(agent, mock_boto3_client):
assert "results with score >=" in result_text
assert "Metadata:" not in result_text
assert "test-source" not in result_text


def test_filter_results_by_score_distance_metric():
    """Test filter_results_by_score with distance metric (lower = more relevant)."""
    # Distance-style scores: smaller values mean closer matches.
    candidates = [{"score": s} for s in (0.1, 0.3, 0.7, 1.2)]

    # Threshold 0.5 keeps only the two results with score <= 0.5.
    kept = retrieve.filter_results_by_score(candidates, 0.5, score_metric="distance")
    assert [r["score"] for r in kept] == [0.1, 0.3]

    # Threshold 0.8 additionally keeps the 0.7 result.
    kept = retrieve.filter_results_by_score(candidates, 0.8, score_metric="distance")
    assert [r["score"] for r in kept] == [0.1, 0.3, 0.7]


def test_filter_results_by_score_similarity_metric_explicit():
    """Test filter_results_by_score with explicit similarity metric (regression test)."""
    candidates = [{"score": s} for s in (0.9, 0.5, 0.3, 0.8)]

    # Passing score_metric="similarity" must match the default (>=) filtering.
    kept = retrieve.filter_results_by_score(candidates, 0.5, score_metric="similarity")
    assert [r["score"] for r in kept] == [0.9, 0.5, 0.8]


def test_filter_results_by_score_default_unchanged():
    """Test that default behavior (no score_metric) is unchanged for backward compatibility."""
    candidates = [{"score": 0.9}, {"score": 0.5}, {"score": 0.3}]

    # Omitting score_metric must keep the historical similarity (>=) filter.
    kept = retrieve.filter_results_by_score(candidates, 0.5)
    assert [r["score"] for r in kept] == [0.9, 0.5]


def test_filter_results_distance_default_score_for_missing():
    """Test that missing scores default to inf for distance metric (filtered out)."""
    candidates = [{"score": 0.1}, {}, {"score": 0.3}]

    # A result without a score is treated as infinitely distant and dropped.
    kept = retrieve.filter_results_by_score(candidates, 0.5, score_metric="distance")
    assert kept == [{"score": 0.1}, {"score": 0.3}]


def test_retrieve_tool_distance_metric(mock_boto3_client):
    """Test direct invocation of retrieve tool with distance metric."""

    # Helper keeps the mock payload compact; one entry per retrieval result.
    def make_result(text, doc_id, score):
        return {
            "content": {"text": text, "type": "TEXT"},
            "location": {"customDocumentLocation": {"id": doc_id}, "type": "CUSTOM"},
            "score": score,
        }

    # Override mock to return distance-style scores (low = good)
    mock_boto3_client.return_value.retrieve.return_value = {
        "retrievalResults": [
            make_result("Close match", "doc-001", 0.05),
            make_result("Medium match", "doc-002", 0.3),
            make_result("Far match", "doc-003", 0.9),
        ]
    }

    result = retrieve.retrieve(
        tool={
            "toolUseId": "test-tool-use-id",
            "input": {
                "text": "test query",
                "knowledgeBaseId": "test-kb-id",
                "score": 0.4,
                "score_metric": "distance",
            },
        }
    )

    assert result["status"] == "success"
    output_text = result["content"][0]["text"]
    # Should keep scores <= 0.4 (the close and medium matches)
    assert "Retrieved 2 results with score <= 0.4" in output_text
    assert "doc-001" in output_text
    assert "doc-002" in output_text
    # Far match (0.9) should be filtered out
    assert "doc-003" not in output_text


def test_retrieve_tool_distance_metric_default_score(mock_boto3_client):
    """Test that distance metric without explicit score defaults to inf (keeps all results)."""
    # No "score" in the input: the threshold should default to inf for distance.
    result = retrieve.retrieve(
        tool={
            "toolUseId": "test-tool-use-id",
            "input": {
                "text": "test query",
                "knowledgeBaseId": "test-kb-id",
                "score_metric": "distance",
            },
        }
    )

    assert result["status"] == "success"
    # All 3 mock results should be kept since default threshold is inf
    assert "Retrieved 3 results with score <= inf" in result["content"][0]["text"]
Loading