Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -251,7 +251,7 @@ async def retrieve(
result.activation = row["score"]
results.append(result)

if tags:
if tags or tags_match == "exact":
results = filter_results_by_tags(results, tags, match=tags_match)

if tag_groups:
Expand Down
52 changes: 34 additions & 18 deletions hindsight-api-slim/hindsight_api/engine/search/tags.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,14 @@
- "all": AND matching, includes untagged memories
- "any_strict": OR matching, excludes untagged memories
- "all_strict": AND matching, excludes untagged memories
- "exact": set-equality matching, excludes untagged memories
- "exact": set-equality matching; an empty tag set selects untagged memories

OR matching (any/any_strict): Memory matches if ANY of its tags overlap with request tags
AND matching (all/all_strict): Memory matches if ALL request tags are present in its tags
EXACT matching: Memory matches only if its tag set EQUALS the request tag set (order-
independent). Used for observation "scope" filtering, where each observation lives
under exactly one scope (its full tag set) and "scope [a]" must not match "[a, b]".
under exactly one scope (its full tag set), the empty set identifies the global scope,
and "scope [a]" must not match "[a, b]".
"""

from __future__ import annotations
Expand All @@ -32,7 +33,7 @@ def _parse_tags_match(match: TagsMatch) -> tuple[str, bool]:
Returns:
Tuple of (operator, include_untagged)
- operator: "&&" for any/any_strict, "@>" for all/all_strict
- include_untagged: True for any/all, False for any_strict/all_strict
- include_untagged: True for any/all, False for any_strict/all_strict/exact
"""
if match == "any":
return "&&", True
Expand All @@ -43,8 +44,7 @@ def _parse_tags_match(match: TagsMatch) -> tuple[str, bool]:
elif match == "all_strict":
return "@>", False
elif match == "exact":
# Set equality is handled by the callers via `@> AND <@`; the operator
# here is unused. Untagged rows never equal a non-empty scope.
# Set equality is handled by the callers; the operator here is unused.
return "@>", False
else:
# Default to "any" behavior
Expand All @@ -60,14 +60,16 @@ def build_tags_where_clause(
"""
Build a SQL WHERE clause for filtering by tags.

Supports four matching modes:
Supports five matching modes:
- "any" (default): OR matching, includes untagged memories
- "all": AND matching, includes untagged memories
- "any_strict": OR matching, excludes untagged memories
- "all_strict": AND matching, excludes untagged memories
- "exact": set equality; an empty tag set selects only untagged memories

Args:
tags: List of tags to filter by. If None or empty, returns empty clause (no filtering).
tags: List of tags to filter by. If None or empty, returns no filter except
in exact mode, where it selects the global/untagged scope.
param_offset: Starting parameter number for SQL placeholders (default 1).
table_alias: Optional table alias prefix (e.g., "mu." for "memory_units mu").
match: Matching mode. Defaults to "any".
Expand All @@ -82,17 +84,21 @@ def build_tags_where_clause(
>>> clause, params, next_offset = build_tags_where_clause(['user_a'], 3, 'mu.', 'any_strict')
>>> print(clause) # "AND mu.tags IS NOT NULL AND mu.tags != '{}' AND mu.tags && $3"
"""
if not tags:
return "", [], param_offset

column = f"{table_alias}tags" if table_alias else "tags"

if match == "exact":
if not tags:
# Exact equality with the empty set is the global observation scope.
# Handle both historical NULLs and the current empty-array storage.
return f"AND ({column} IS NULL OR {column} = '{{}}')", [], param_offset
# Set equality (order-independent): superset AND subset. Untagged rows
# (empty array) never satisfy `@>` of a non-empty scope, so they're excluded.
clause = f"AND ({column} @> ${param_offset} AND {column} <@ ${param_offset})"
return clause, [tags], param_offset + 1

if not tags:
return "", [], param_offset

operator, include_untagged = _parse_tags_match(match)

if include_untagged:
Expand All @@ -118,24 +124,27 @@ def build_tags_where_clause_simple(
assuming the caller will add the tags array to their params list.

Args:
tags: List of tags to filter by. If None or empty, returns empty string.
tags: List of tags to filter by. If None or empty, returns no filter except
in exact mode, where it selects the global/untagged scope.
param_num: Parameter number to use in the clause.
table_alias: Optional table alias prefix.
match: Matching mode. Defaults to "any".

Returns:
SQL clause string or empty string.
"""
if not tags:
return ""

column = f"{table_alias}tags" if table_alias else "tags"

if match == "exact":
if not tags:
return f"AND ({column} IS NULL OR {column} = '{{}}')"
# Set equality (order-independent): superset AND subset. Untagged rows
# (empty array) never satisfy `@>` of a non-empty scope, so they're excluded.
return f"AND ({column} @> ${param_num} AND {column} <@ ${param_num})"

if not tags:
return ""

operator, include_untagged = _parse_tags_match(match)

if include_untagged:
Expand All @@ -158,15 +167,19 @@ def filter_results_by_tags(

Args:
results: List of RetrievalResult objects with a 'tags' attribute.
tags: List of tags to filter by. If None or empty, returns all results.
tags: List of tags to filter by. If None or empty, returns all results except
in exact mode, where it selects the global/untagged scope.
match: Matching mode. Defaults to "any".

Returns:
Filtered list of results.
"""
if not tags:
if not tags and match != "exact":
return results

if not tags:
return [r for r in results if not (getattr(r, "tags", None) or [])]

_, include_untagged = _parse_tags_match(match)
is_any_match = match in ("any", "any_strict")

Expand Down Expand Up @@ -267,6 +280,8 @@ def _build_group_clause(
if isinstance(group, TagGroupLeaf):
column = f"{table_alias}tags" if table_alias else "tags"
if group.match == "exact":
if not group.tags:
return f"({column} IS NULL OR {column} = '{{}}')", [], param_offset
clause = f"({column} @> ${param_offset} AND {column} <@ ${param_offset})"
return clause, [group.tags], param_offset + 1
operator, include_untagged = _parse_tags_match(group.match)
Expand Down Expand Up @@ -373,12 +388,13 @@ def _match_group(result: object, group: TagGroup) -> bool:
is_any_match = group.match in ("any", "any_strict")
tags_set = set(group.tags)

if group.match == "exact":
return is_untagged if not tags_set else not is_untagged and set(result_tags) == tags_set

if is_untagged:
return include_untagged
else:
result_tags_set = set(result_tags)
if group.match == "exact":
return result_tags_set == tags_set
if is_any_match:
return bool(result_tags_set & tags_set)
else:
Expand Down
66 changes: 66 additions & 0 deletions hindsight-api-slim/tests/test_tags_visibility.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
TagGroupNot,
TagGroupOr,
build_tag_groups_where_clause,
build_tags_where_clause,
build_tags_where_clause_simple,
filter_results_by_tag_groups,
filter_results_by_tags,
Expand Down Expand Up @@ -136,6 +137,20 @@ def test_tags_match_exact_with_table_alias(self):
assert "@>" in result
assert "<@" in result

@pytest.mark.parametrize("tags", [None, []])
def test_tags_match_exact_empty_scope_selects_untagged(self, tags):
    """Exact mode with no tags must select untagged rows only.

    Both a missing tag list (None) and an empty one ([]) should yield a
    clause that accepts NULL and empty-array storage, and that clause must
    not mention the bind placeholder because no tag array accompanies it.
    """
    placeholder_number = 3
    clause = build_tags_where_clause_simple(
        tags,
        placeholder_number,
        table_alias="mu.",
        match="exact",
    )

    assert clause == "AND (mu.tags IS NULL OR mu.tags = '{}')"
    assert "$3" not in clause

def test_tags_match_exact_empty_scope_does_not_consume_parameter(self):
    """An empty exact scope binds nothing, so the next placeholder index is unchanged.

    Callers chain the returned offset into later clauses; if the empty-scope
    branch advanced it, every following bind index would be off by one.
    """
    outcome = build_tags_where_clause([], param_offset=4, match="exact")
    sql_fragment, bound_values, following_offset = outcome

    assert sql_fragment == "AND (tags IS NULL OR tags = '{}')"
    assert bound_values == []
    assert following_offset == 4

# ---- Test table alias with all modes ----

def test_tags_match_any_with_table_alias(self):
Expand Down Expand Up @@ -255,6 +270,13 @@ def test_exact_mode_excludes_untagged(self):
assert len(filtered) == 1
assert filtered[0].tags == ["a"]

@pytest.mark.parametrize("tags", [None, []])
def test_exact_empty_scope_matches_only_untagged(self, tags):
    """Exact filtering with no tags keeps NULL/empty-tag results and drops tagged ones."""
    candidates = [
        MockResult(["a"]),
        MockResult(None),
        MockResult([]),
    ]

    survivors = filter_results_by_tags(candidates, tags, match="exact")
    surviving_tags = [item.tags for item in survivors]

    # Order is preserved: the None-tagged result precedes the empty-list one.
    assert surviving_tags == [None, []]

def test_all_mode_includes_untagged(self):
"""'all' mode should include untagged results."""
results = [MockResult(["a", "b"]), MockResult(None), MockResult([])]
Expand Down Expand Up @@ -367,6 +389,14 @@ def test_single_leaf_any_includes_untagged(self):
assert params == [["user:alice"]]
assert next_offset == 2

def test_exact_empty_leaf_selects_untagged_without_parameter(self):
    """A compound filter with one empty exact leaf targets the global (untagged) scope.

    The generated SQL must accept NULL and empty-array tags, bind no values,
    and hand back the starting offset untouched.
    """
    global_scope_leaf = TagGroupLeaf(tags=[], match="exact")
    outcome = build_tag_groups_where_clause([global_scope_leaf], 3, table_alias="mu.")
    sql_fragment, bound_values, following_offset = outcome

    assert sql_fragment == "AND (mu.tags IS NULL OR mu.tags = '{}')"
    assert bound_values == []
    assert following_offset == 3

def test_and_of_two_leaves(self):
"""AND of two leaves generates AND-joined clause."""
groups = [
Expand Down Expand Up @@ -543,6 +573,13 @@ def test_single_leaf_all_strict_matches_superset(self):
filtered = filter_results_by_tag_groups(results, groups)
assert len(filtered) == 2

def test_exact_empty_leaf_matches_only_untagged(self):
    """In-memory compound filtering mirrors the SQL global-scope rule for an empty exact leaf."""
    global_scope = [TagGroupLeaf(tags=[], match="exact")]
    candidates = [
        MockResult(["project"]),
        MockResult(None),
        MockResult([]),
    ]

    survivors = filter_results_by_tag_groups(candidates, global_scope)
    surviving_tags = [item.tags for item in survivors]

    # Only the two untagged candidates remain, in their original order.
    assert surviving_tags == [None, []]

def test_and_both_conditions_must_match(self):
"""AND group: both leaf conditions must match."""
groups = [
Expand Down Expand Up @@ -904,6 +941,35 @@ async def test_recall_with_empty_tags_returns_all(api_client, test_bank_id):
assert any("Rachel" in t for t in texts), "Should find Rachel"


@pytest.mark.asyncio
async def test_recall_with_exact_empty_tags_returns_only_untagged(api_client, test_bank_id):
    """End-to-end check: recall with tags=[] and tags_match="exact" returns only untagged memories.

    Stores one untagged and one tagged memory answering the same question,
    then verifies the recall response contains the untagged one and not the
    tagged one.
    """
    # Seed the bank with one global-scope memory and one project-scoped memory.
    untagged_item = {"content": "The deployment region is Singapore."}
    tagged_item = {"content": "The deployment region is Frankfurt.", "tags": ["project:eu"]}
    store_response = await api_client.post(
        f"/v1/default/banks/{test_bank_id}/memories",
        json={"items": [untagged_item, tagged_item]},
    )
    assert store_response.status_code == 200

    # Recall with an empty exact scope.
    recall_request = {
        "query": "What is the deployment region?",
        "budget": "low",
        "tags": [],
        "tags_match": "exact",
    }
    recall_response = await api_client.post(
        f"/v1/default/banks/{test_bank_id}/memories/recall",
        json=recall_request,
    )
    assert recall_response.status_code == 200

    recalled_texts = [entry["text"] for entry in recall_response.json()["results"]]
    assert any("Singapore" in text for text in recalled_texts)
    assert all("Frankfurt" not in text for text in recalled_texts)


@pytest.mark.asyncio
async def test_multi_user_agent_visibility(api_client):
"""
Expand Down