diff --git a/hindsight-api-slim/hindsight_api/engine/search/link_expansion_retrieval.py b/hindsight-api-slim/hindsight_api/engine/search/link_expansion_retrieval.py index 6e903d95e..bd844182f 100644 --- a/hindsight-api-slim/hindsight_api/engine/search/link_expansion_retrieval.py +++ b/hindsight-api-slim/hindsight_api/engine/search/link_expansion_retrieval.py @@ -251,7 +251,7 @@ async def retrieve( result.activation = row["score"] results.append(result) - if tags: + if tags or tags_match == "exact": results = filter_results_by_tags(results, tags, match=tags_match) if tag_groups: diff --git a/hindsight-api-slim/hindsight_api/engine/search/tags.py b/hindsight-api-slim/hindsight_api/engine/search/tags.py index a14032fc5..eef10ee10 100644 --- a/hindsight-api-slim/hindsight_api/engine/search/tags.py +++ b/hindsight-api-slim/hindsight_api/engine/search/tags.py @@ -7,13 +7,14 @@ - "all": AND matching, includes untagged memories - "any_strict": OR matching, excludes untagged memories - "all_strict": AND matching, excludes untagged memories -- "exact": set-equality matching, excludes untagged memories +- "exact": set-equality matching; an empty tag set selects untagged memories OR matching (any/any_strict): Memory matches if ANY of its tags overlap with request tags AND matching (all/all_strict): Memory matches if ALL request tags are present in its tags EXACT matching: Memory matches only if its tag set EQUALS the request tag set (order- independent). Used for observation "scope" filtering, where each observation lives - under exactly one scope (its full tag set) and "scope [a]" must not match "[a, b]". + under exactly one scope (its full tag set), the empty set identifies the global scope, + and "scope [a]" must not match "[a, b]". """ from __future__ import annotations @@ -32,7 +33,7 @@ def _parse_tags_match(match: TagsMatch) -> tuple[str, bool]: Returns: Tuple of (operator, include_untagged) - operator: "&&" for any/any_strict, "@>" for all/all_strict - - include_untagged: True for any/all, False for any_strict/all_strict + - include_untagged: True for any/all, False for any_strict/all_strict/exact """ if match == "any": return "&&", True @@ -43,8 +44,7 @@ def _parse_tags_match(match: TagsMatch) -> tuple[str, bool]: elif match == "all_strict": return "@>", False elif match == "exact": - # Set equality is handled by the callers via `@> AND <@`; the operator - # here is unused. Untagged rows never equal a non-empty scope. + # Set equality is handled by the callers; the operator here is unused. return "@>", False else: # Default to "any" behavior @@ -60,14 +60,16 @@ def build_tags_where_clause( """ Build a SQL WHERE clause for filtering by tags. - Supports four matching modes: + Supports five matching modes: - "any" (default): OR matching, includes untagged memories - "all": AND matching, includes untagged memories - "any_strict": OR matching, excludes untagged memories - "all_strict": AND matching, excludes untagged memories + - "exact": set equality; an empty tag set selects only untagged memories Args: - tags: List of tags to filter by. If None or empty, returns empty clause (no filtering). + tags: List of tags to filter by. If None or empty, returns no filter except + in exact mode, where it selects the global/untagged scope. param_offset: Starting parameter number for SQL placeholders (default 1). table_alias: Optional table alias prefix (e.g., "mu." for "memory_units mu"). match: Matching mode. Defaults to "any". @@ -82,17 +84,21 @@ def build_tags_where_clause( >>> clause, params, next_offset = build_tags_where_clause(['user_a'], 3, 'mu.', 'any_strict') >>> print(clause) # "AND mu.tags IS NOT NULL AND mu.tags != '{}' AND mu.tags && $3" """ - if not tags: - return "", [], param_offset - column = f"{table_alias}tags" if table_alias else "tags" if match == "exact": + if not tags: + # Exact equality with the empty set is the global observation scope. + # Handle both historical NULLs and the current empty-array storage. + return f"AND ({column} IS NULL OR {column} = '{{}}')", [], param_offset # Set equality (order-independent): superset AND subset. Untagged rows # (empty array) never satisfy `@>` of a non-empty scope, so they're excluded. clause = f"AND ({column} @> ${param_offset} AND {column} <@ ${param_offset})" return clause, [tags], param_offset + 1 + if not tags: + return "", [], param_offset + operator, include_untagged = _parse_tags_match(match) if include_untagged: @@ -118,7 +124,8 @@ def build_tags_where_clause_simple( assuming the caller will add the tags array to their params list. Args: - tags: List of tags to filter by. If None or empty, returns empty string. + tags: List of tags to filter by. If None or empty, returns no filter except + in exact mode, where it selects the global/untagged scope. param_num: Parameter number to use in the clause. table_alias: Optional table alias prefix. match: Matching mode. Defaults to "any". @@ -126,16 +133,18 @@ def build_tags_where_clause_simple( Returns: SQL clause string or empty string. """ - if not tags: - return "" - column = f"{table_alias}tags" if table_alias else "tags" if match == "exact": + if not tags: + return f"AND ({column} IS NULL OR {column} = '{{}}')" # Set equality (order-independent): superset AND subset. Untagged rows # (empty array) never satisfy `@>` of a non-empty scope, so they're excluded. return f"AND ({column} @> ${param_num} AND {column} <@ ${param_num})" + if not tags: + return "" + operator, include_untagged = _parse_tags_match(match) if include_untagged: @@ -158,15 +167,19 @@ def filter_results_by_tags( Args: results: List of RetrievalResult objects with a 'tags' attribute. - tags: List of tags to filter by. If None or empty, returns all results. + tags: List of tags to filter by. If None or empty, returns all results except + in exact mode, where it selects the global/untagged scope. match: Matching mode. Defaults to "any". Returns: Filtered list of results. """ - if not tags: + if not tags and match != "exact": return results + if not tags: + return [r for r in results if not (getattr(r, "tags", None) or [])] + _, include_untagged = _parse_tags_match(match) is_any_match = match in ("any", "any_strict") @@ -267,6 +280,8 @@ def _build_group_clause( if isinstance(group, TagGroupLeaf): column = f"{table_alias}tags" if table_alias else "tags" if group.match == "exact": + if not group.tags: + return f"({column} IS NULL OR {column} = '{{}}')", [], param_offset clause = f"({column} @> ${param_offset} AND {column} <@ ${param_offset})" return clause, [group.tags], param_offset + 1 operator, include_untagged = _parse_tags_match(group.match) @@ -373,12 +388,13 @@ def _match_group(result: object, group: TagGroup) -> bool: is_any_match = group.match in ("any", "any_strict") tags_set = set(group.tags) + if group.match == "exact": + return is_untagged if not tags_set else not is_untagged and set(result_tags) == tags_set + if is_untagged: return include_untagged else: result_tags_set = set(result_tags) - if group.match == "exact": - return result_tags_set == tags_set if is_any_match: return bool(result_tags_set & tags_set) else: diff --git a/hindsight-api-slim/tests/test_tags_visibility.py b/hindsight-api-slim/tests/test_tags_visibility.py index 138ea16f1..20620dce1 100644 --- a/hindsight-api-slim/tests/test_tags_visibility.py +++ b/hindsight-api-slim/tests/test_tags_visibility.py @@ -23,6 +23,7 @@ TagGroupNot, TagGroupOr, build_tag_groups_where_clause, + build_tags_where_clause, build_tags_where_clause_simple, filter_results_by_tag_groups, filter_results_by_tags, @@ -136,6 +137,20 @@ def test_tags_match_exact_with_table_alias(self): assert "@>" in result assert "<@" in result + @pytest.mark.parametrize("tags", [None, []]) + def test_tags_match_exact_empty_scope_selects_untagged(self, tags): + """An empty exact scope matches both NULL and empty-array tag storage.""" + result = build_tags_where_clause_simple(tags, 3, table_alias="mu.", match="exact") + assert result == "AND (mu.tags IS NULL OR mu.tags = '{}')" + assert "$3" not in result + + def test_tags_match_exact_empty_scope_does_not_consume_parameter(self): + """The parameterized builder must keep following bind indexes aligned.""" + clause, params, next_offset = build_tags_where_clause([], param_offset=4, match="exact") + assert clause == "AND (tags IS NULL OR tags = '{}')" + assert params == [] + assert next_offset == 4 + # ---- Test table alias with all modes ---- def test_tags_match_any_with_table_alias(self): @@ -255,6 +270,13 @@ def test_exact_mode_excludes_untagged(self): assert len(filtered) == 1 assert filtered[0].tags == ["a"] + @pytest.mark.parametrize("tags", [None, []]) + def test_exact_empty_scope_matches_only_untagged(self, tags): + """An empty exact scope includes NULL/empty tags and excludes tagged rows.""" + results = [MockResult(["a"]), MockResult(None), MockResult([])] + filtered = filter_results_by_tags(results, tags, match="exact") + assert [r.tags for r in filtered] == [None, []] + def test_all_mode_includes_untagged(self): """'all' mode should include untagged results.""" results = [MockResult(["a", "b"]), MockResult(None), MockResult([])] @@ -367,6 +389,14 @@ def test_single_leaf_any_includes_untagged(self): assert params == [["user:alice"]] assert next_offset == 2 + def test_exact_empty_leaf_selects_untagged_without_parameter(self): + """An exact empty compound leaf represents the global scope.""" + groups = [TagGroupLeaf(tags=[], match="exact")] + clause, params, next_offset = build_tag_groups_where_clause(groups, 3, table_alias="mu.") + assert clause == "AND (mu.tags IS NULL OR mu.tags = '{}')" + assert params == [] + assert next_offset == 3 + def test_and_of_two_leaves(self): """AND of two leaves generates AND-joined clause.""" groups = [ @@ -543,6 +573,13 @@ def test_single_leaf_all_strict_matches_superset(self): filtered = filter_results_by_tag_groups(results, groups) assert len(filtered) == 2 + def test_exact_empty_leaf_matches_only_untagged(self): + """Python compound filtering uses the same global-scope semantics as SQL.""" + groups = [TagGroupLeaf(tags=[], match="exact")] + results = [MockResult(["project"]), MockResult(None), MockResult([])] + filtered = filter_results_by_tag_groups(results, groups) + assert [r.tags for r in filtered] == [None, []] + def test_and_both_conditions_must_match(self): """AND group: both leaf conditions must match.""" groups = [ @@ -904,6 +941,35 @@ async def test_recall_with_empty_tags_returns_all(api_client, test_bank_id): assert any("Rachel" in t for t in texts), "Should find Rachel" +@pytest.mark.asyncio +async def test_recall_with_exact_empty_tags_returns_only_untagged(api_client, test_bank_id): + """Exact matching with an empty tag set recalls only the global scope.""" + response = await api_client.post( + f"/v1/default/banks/{test_bank_id}/memories", + json={ + "items": [ + {"content": "The deployment region is Singapore."}, + {"content": "The deployment region is Frankfurt.", "tags": ["project:eu"]}, + ] + }, + ) + assert response.status_code == 200 + + response = await api_client.post( + f"/v1/default/banks/{test_bank_id}/memories/recall", + json={ + "query": "What is the deployment region?", + "budget": "low", + "tags": [], + "tags_match": "exact", + }, + ) + assert response.status_code == 200 + texts = [result["text"] for result in response.json()["results"]] + assert any("Singapore" in text for text in texts) + assert not any("Frankfurt" in text for text in texts) + + @pytest.mark.asyncio async def test_multi_user_agent_visibility(api_client): """