From f72df8da1f4222e9fdbb228a993d2e5d496dbe54 Mon Sep 17 00:00:00 2001 From: Paul Teehan Date: Tue, 28 Apr 2026 00:58:49 +0200 Subject: [PATCH 1/2] feat: add first-class SELECT DISTINCT support to SQL AST Add `distinct: bool = False` to the SELECT AST node so callers can write `SELECT(fields=[...], distinct=True)` and get `SELECT DISTINCT col1, col2 FROM ...` without string-replace hacks or abusing the DISTINCT expression node (which remains the correct model for aggregate-level use like `COUNT(DISTINCT x)`). The two SQL features share a keyword but are different grammar productions: - set quantifier on SELECT: deduplicates the whole result set - aggregate-level modifier: per-aggregate input dedup The DISTINCT expression node renders as `DISTINCT(...)` with parens, which is correct inside aggregates but rejected as a SELECT set quantifier by DB2, Athena/Presto/Trino, and BigQuery. Modeling them separately fixes that and lets callers drop fragile workarounds like `sql.replace("SELECT", "SELECT DISTINCT", 1)`. Co-Authored-By: Claude Opus 4.7 (1M context) --- soda-core/src/soda_core/common/sql_ast.py | 11 ++- soda-core/src/soda_core/common/sql_dialect.py | 7 +- .../tests/unit/test_postgres_dialect.py | 72 ++++++++++++++++++- 3 files changed, 85 insertions(+), 5 deletions(-) diff --git a/soda-core/src/soda_core/common/sql_ast.py b/soda-core/src/soda_core/common/sql_ast.py index afa972b8b..d5e7c80df 100644 --- a/soda-core/src/soda_core/common/sql_ast.py +++ b/soda-core/src/soda_core/common/sql_ast.py @@ -53,17 +53,22 @@ def check_context(self, context_node: type[BaseSqlExpression]) -> bool: @dataclass class SELECT(BaseSqlExpression): fields: SqlExpression | str | list[SqlExpression | str] + distinct: bool = False def __post_init__(self): super().__post_init__() self.handle_parent_node_update(self.fields) - # Check that the select contains a distinct and has multiple fields -> give a warning + # A DISTINCT expression inside SELECT.fields renders as DISTINCT(...) with + # parens, which is only correct inside an aggregate (e.g. COUNT(DISTINCT x)). + # For the set-quantifier form (SELECT DISTINCT a, b FROM t) use distinct=True. if isinstance(self.fields, list) and len(self.fields) > 1: if any(isinstance(field, DISTINCT) for field in self.fields): logger.warning( - """Found DISTINCT in a SELECT statement with multiple fields. - This might have unintended consequences.""" + "DISTINCT expression inside SELECT fields. " + "For set-quantifier deduplication, use SELECT(..., distinct=True). " + "The DISTINCT expression node is intended for aggregate-level use " + "(e.g. COUNT(DISTINCT x))." ) diff --git a/soda-core/src/soda_core/common/sql_dialect.py b/soda-core/src/soda_core/common/sql_dialect.py index 6cfaf8a2d..0fe35e3c5 100644 --- a/soda-core/src/soda_core/common/sql_dialect.py +++ b/soda-core/src/soda_core/common/sql_dialect.py @@ -688,9 +688,12 @@ def build_select_sql(self, select_elements: list, add_semicolon: Optional[bool] return "\n".join(statement_lines) + (";" if add_semicolon else "") def _build_select_sql_lines(self, select_elements: list) -> list[str]: + distinct = False select_field_sqls: list[str] = [] for select_element in select_elements: if isinstance(select_element, SELECT): + if select_element.distinct: + distinct = True if isinstance(select_element.fields, str) or isinstance(select_element.fields, SqlExpression): select_element.fields = [select_element.fields] for select_field in select_element.fields: @@ -705,10 +708,12 @@ def _build_select_sql_lines(self, select_elements: list) -> list[str]: # return "SELECT " + (", ".join(select_fields_sql)) # For now, we opt for SELECT statement readability... + select_keyword = "SELECT DISTINCT" if distinct else "SELECT" + select_sql_lines: list[str] = [] for i in range(0, len(select_field_sqls)): if i == 0: - sql_line = f"SELECT {select_field_sqls[0]}" + sql_line = f"{select_keyword} {select_field_sqls[0]}" else: sql_line = f" {select_field_sqls[i]}" # Append comma all lines except the last one diff --git a/soda-postgres/tests/unit/test_postgres_dialect.py b/soda-postgres/tests/unit/test_postgres_dialect.py index 06332f29f..4dd6c0028 100644 --- a/soda-postgres/tests/unit/test_postgres_dialect.py +++ b/soda-postgres/tests/unit/test_postgres_dialect.py @@ -1,5 +1,18 @@ import pytest -from soda_core.common.sql_dialect import FROM, RANDOM, SELECT, STAR, SamplerType +from soda_core.common.sql_dialect import ( + COUNT, + DISTINCT, + EQ, + FROM, + LIMIT, + OFFSET, + ORDER_BY_ASC, + RANDOM, + SELECT, + STAR, + WHERE, + SamplerType, +) from soda_postgres.common.data_sources.postgres_data_source import PostgresSqlDialect @@ -51,3 +64,60 @@ def test_random(): sql_dialect: PostgresSqlDialect = PostgresSqlDialect() sql = sql_dialect.build_select_sql([SELECT(RANDOM()), FROM("a")]) assert sql == 'SELECT RANDOM()\nFROM "a";' + + +@pytest.mark.parametrize( + "sql_ast, expected_sql", + [ + pytest.param( + [SELECT(fields=["a"], distinct=True), FROM("t")], + 'SELECT DISTINCT "a"\nFROM "t";', + id="select_distinct_single_column", + ), + pytest.param( + [SELECT(fields=["a", "b", "c"], distinct=True), FROM("t")], + 'SELECT DISTINCT "a",\n "b",\n "c"\nFROM "t";', + id="select_distinct_multiple_columns_no_parens", + ), + pytest.param( + [SELECT(fields=["a", "b"]), FROM("t")], + 'SELECT "a",\n "b"\nFROM "t";', + id="select_without_distinct_unchanged", + ), + pytest.param( + [ + SELECT(fields=["a", "b"], distinct=True), + FROM("t"), + WHERE(EQ("status", 1)), + ORDER_BY_ASC("a"), + LIMIT(100), + OFFSET(200), + ], + ( + 'SELECT DISTINCT "a",\n' + ' "b"\n' + 'FROM "t"\n' + 'WHERE "status" = 1\n' + 'ORDER BY "a" ASC\n' + "LIMIT 100\n" + "OFFSET 200;" + ), + id="select_distinct_paginated_full_shape", + ), + pytest.param( + [SELECT(fields=COUNT(DISTINCT(expression="x"))), FROM("t")], + 'SELECT COUNT(DISTINCT("x"))\nFROM "t";', + id="aggregate_level_distinct_preserves_parens", + ), + ], +) +def test_select_distinct(sql_ast, expected_sql): + sql_dialect: PostgresSqlDialect = PostgresSqlDialect() + assert sql_dialect.build_select_sql(sql_ast) == expected_sql + + +def test_select_distinct_default_is_false(): + # Backwards compatibility: omitting distinct must render SELECT (not SELECT DISTINCT). + sql_dialect: PostgresSqlDialect = PostgresSqlDialect() + sql = sql_dialect.build_select_sql([SELECT(["a"]), FROM("t")]) + assert sql == 'SELECT "a"\nFROM "t";' From 83442cf37eeafe6fda5f5ea6f70c2ed54128a99e Mon Sep 17 00:00:00 2001 From: Paul Teehan Date: Wed, 29 Apr 2026 13:32:35 +0200 Subject: [PATCH 2/2] Address review feedback on SELECT DISTINCT support - Align continuation-line indent to keyword width so multi-column SELECT DISTINCT output stays under the first field (16-space indent for SELECT DISTINCT, 7-space indent unchanged for SELECT) - Replace OR-aggregation of distinct across SELECT elements with a plain assignment; removes the "any flips all" footgun while preserving behaviour for the typical single-SELECT case - Add unit test for SELECT(STAR(), distinct=True) covering the SqlExpression-not-list branch - Add unit test asserting the updated warning fires when DISTINCT is nested inside SELECT.fields Co-Authored-By: Claude Opus 4.7 (1M context) --- soda-core/src/soda_core/common/sql_dialect.py | 6 ++--- .../tests/unit/test_postgres_dialect.py | 24 +++++++++++++++++-- 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/soda-core/src/soda_core/common/sql_dialect.py b/soda-core/src/soda_core/common/sql_dialect.py index 0fe35e3c5..dc18261d1 100644 --- a/soda-core/src/soda_core/common/sql_dialect.py +++ b/soda-core/src/soda_core/common/sql_dialect.py @@ -692,8 +692,7 @@ def _build_select_sql_lines(self, select_elements: list) -> list[str]: select_field_sqls: list[str] = [] for select_element in select_elements: if isinstance(select_element, SELECT): - if select_element.distinct: - distinct = True + distinct = select_element.distinct if isinstance(select_element.fields, str) or isinstance(select_element.fields, SqlExpression): select_element.fields = [select_element.fields] for select_field in select_element.fields: @@ -709,13 +708,14 @@ def _build_select_sql_lines(self, select_elements: list) -> list[str]: # For now, we opt for SELECT statement readability... select_keyword = "SELECT DISTINCT" if distinct else "SELECT" + continuation_indent = " " * (len(select_keyword) + 1) select_sql_lines: list[str] = [] for i in range(0, len(select_field_sqls)): if i == 0: sql_line = f"{select_keyword} {select_field_sqls[0]}" else: - sql_line = f" {select_field_sqls[i]}" + sql_line = f"{continuation_indent}{select_field_sqls[i]}" # Append comma all lines except the last one if i < len(select_field_sqls) - 1: sql_line += "," diff --git a/soda-postgres/tests/unit/test_postgres_dialect.py b/soda-postgres/tests/unit/test_postgres_dialect.py index 4dd6c0028..b963054f7 100644 --- a/soda-postgres/tests/unit/test_postgres_dialect.py +++ b/soda-postgres/tests/unit/test_postgres_dialect.py @@ -1,3 +1,5 @@ +import logging + import pytest from soda_core.common.sql_dialect import ( COUNT, @@ -76,9 +78,14 @@ def test_random(): ), pytest.param( [SELECT(fields=["a", "b", "c"], distinct=True), FROM("t")], - 'SELECT DISTINCT "a",\n "b",\n "c"\nFROM "t";', + 'SELECT DISTINCT "a",\n "b",\n "c"\nFROM "t";', id="select_distinct_multiple_columns_no_parens", ), + pytest.param( + [SELECT(STAR(), distinct=True), FROM("t")], + 'SELECT DISTINCT *\nFROM "t";', + id="select_distinct_star_expression", + ), pytest.param( [SELECT(fields=["a", "b"]), FROM("t")], 'SELECT "a",\n "b"\nFROM "t";', @@ -95,7 +102,7 @@ def test_random(): ], ( 'SELECT DISTINCT "a",\n' - ' "b"\n' + ' "b"\n' 'FROM "t"\n' 'WHERE "status" = 1\n' 'ORDER BY "a" ASC\n' @@ -121,3 +128,16 @@ def test_select_distinct_default_is_false(): sql_dialect: PostgresSqlDialect = PostgresSqlDialect() sql = sql_dialect.build_select_sql([SELECT(["a"]), FROM("t")]) assert sql == 'SELECT "a"\nFROM "t";' + + +def test_distinct_expression_inside_select_fields_warns(caplog): + # Nesting DISTINCT inside SELECT.fields with multiple fields should warn the + # caller to use the set-quantifier flag instead. + with caplog.at_level(logging.WARNING, logger="soda"): + SELECT(fields=[DISTINCT(expression="a"), "b"]) + + assert any( + "use SELECT(..., distinct=True)" in record.getMessage() + for record in caplog.records + if record.levelno == logging.WARNING + )