diff --git a/soda-core/src/soda_core/common/sql_ast.py b/soda-core/src/soda_core/common/sql_ast.py index afa972b8b..d5e7c80df 100644 --- a/soda-core/src/soda_core/common/sql_ast.py +++ b/soda-core/src/soda_core/common/sql_ast.py @@ -53,17 +53,22 @@ def check_context(self, context_node: type[BaseSqlExpression]) -> bool: @dataclass class SELECT(BaseSqlExpression): fields: SqlExpression | str | list[SqlExpression | str] + distinct: bool = False def __post_init__(self): super().__post_init__() self.handle_parent_node_update(self.fields) - # Check that the select contains a distinct and has multiple fields -> give a warning + # A DISTINCT expression inside SELECT.fields renders as DISTINCT(...) with + # parens, which is only correct inside an aggregate (e.g. COUNT(DISTINCT x)). + # For the set-quantifier form (SELECT DISTINCT a, b FROM t) use distinct=True. if isinstance(self.fields, list) and len(self.fields) > 1: if any(isinstance(field, DISTINCT) for field in self.fields): logger.warning( - """Found DISTINCT in a SELECT statement with multiple fields. - This might have unintended consequences.""" + "DISTINCT expression inside SELECT fields. " + "For set-quantifier deduplication, use SELECT(..., distinct=True). " + "The DISTINCT expression node is intended for aggregate-level use " + "(e.g. COUNT(DISTINCT x))." ) diff --git a/soda-core/src/soda_core/common/sql_dialect.py b/soda-core/src/soda_core/common/sql_dialect.py index 6cfaf8a2d..dc18261d1 100644 --- a/soda-core/src/soda_core/common/sql_dialect.py +++ b/soda-core/src/soda_core/common/sql_dialect.py @@ -688,9 +688,11 @@ def build_select_sql(self, select_elements: list, add_semicolon: Optional[bool] return "\n".join(statement_lines) + (";" if add_semicolon else "") def _build_select_sql_lines(self, select_elements: list) -> list[str]: + distinct = False select_field_sqls: list[str] = [] for select_element in select_elements: if isinstance(select_element, SELECT): + distinct = select_element.distinct if isinstance(select_element.fields, str) or isinstance(select_element.fields, SqlExpression): select_element.fields = [select_element.fields] for select_field in select_element.fields: @@ -705,12 +707,15 @@ def _build_select_sql_lines(self, select_elements: list) -> list[str]: # return "SELECT " + (", ".join(select_fields_sql)) # For now, we opt for SELECT statement readability... + select_keyword = "SELECT DISTINCT" if distinct else "SELECT" + continuation_indent = " " * (len(select_keyword) + 1) + select_sql_lines: list[str] = [] for i in range(0, len(select_field_sqls)): if i == 0: - sql_line = f"SELECT {select_field_sqls[0]}" + sql_line = f"{select_keyword} {select_field_sqls[0]}" else: - sql_line = f" {select_field_sqls[i]}" + sql_line = f"{continuation_indent}{select_field_sqls[i]}" # Append comma all lines except the last one if i < len(select_field_sqls) - 1: sql_line += "," diff --git a/soda-postgres/tests/unit/test_postgres_dialect.py b/soda-postgres/tests/unit/test_postgres_dialect.py index 06332f29f..b963054f7 100644 --- a/soda-postgres/tests/unit/test_postgres_dialect.py +++ b/soda-postgres/tests/unit/test_postgres_dialect.py @@ -1,5 +1,20 @@ +import logging + import pytest -from soda_core.common.sql_dialect import FROM, RANDOM, SELECT, STAR, SamplerType +from soda_core.common.sql_dialect import ( + COUNT, + DISTINCT, + EQ, + FROM, + LIMIT, + OFFSET, + ORDER_BY_ASC, + RANDOM, + SELECT, + STAR, + WHERE, + SamplerType, +) from soda_postgres.common.data_sources.postgres_data_source import PostgresSqlDialect @@ -51,3 +66,78 @@ def test_random(): sql_dialect: PostgresSqlDialect = PostgresSqlDialect() sql = sql_dialect.build_select_sql([SELECT(RANDOM()), FROM("a")]) assert sql == 'SELECT RANDOM()\nFROM "a";' + + +@pytest.mark.parametrize( + "sql_ast, expected_sql", + [ + pytest.param( + [SELECT(fields=["a"], distinct=True), FROM("t")], + 'SELECT DISTINCT "a"\nFROM "t";', + id="select_distinct_single_column", + ), + pytest.param( + [SELECT(fields=["a", "b", "c"], distinct=True), FROM("t")], + 'SELECT DISTINCT "a",\n "b",\n "c"\nFROM "t";', + id="select_distinct_multiple_columns_no_parens", + ), + pytest.param( + [SELECT(STAR(), distinct=True), FROM("t")], + 'SELECT DISTINCT *\nFROM "t";', + id="select_distinct_star_expression", + ), + pytest.param( + [SELECT(fields=["a", "b"]), FROM("t")], + 'SELECT "a",\n "b"\nFROM "t";', + id="select_without_distinct_unchanged", + ), + pytest.param( + [ + SELECT(fields=["a", "b"], distinct=True), + FROM("t"), + WHERE(EQ("status", 1)), + ORDER_BY_ASC("a"), + LIMIT(100), + OFFSET(200), + ], + ( + 'SELECT DISTINCT "a",\n' + ' "b"\n' + 'FROM "t"\n' + 'WHERE "status" = 1\n' + 'ORDER BY "a" ASC\n' + "LIMIT 100\n" + "OFFSET 200;" + ), + id="select_distinct_paginated_full_shape", + ), + pytest.param( + [SELECT(fields=COUNT(DISTINCT(expression="x"))), FROM("t")], + 'SELECT COUNT(DISTINCT("x"))\nFROM "t";', + id="aggregate_level_distinct_preserves_parens", + ), + ], +) +def test_select_distinct(sql_ast, expected_sql): + sql_dialect: PostgresSqlDialect = PostgresSqlDialect() + assert sql_dialect.build_select_sql(sql_ast) == expected_sql + + +def test_select_distinct_default_is_false(): + # Backwards compatibility: omitting distinct must render SELECT (not SELECT DISTINCT). + sql_dialect: PostgresSqlDialect = PostgresSqlDialect() + sql = sql_dialect.build_select_sql([SELECT(["a"]), FROM("t")]) + assert sql == 'SELECT "a"\nFROM "t";' + + +def test_distinct_expression_inside_select_fields_warns(caplog): + # Nesting DISTINCT inside SELECT.fields with multiple fields should warn the + # caller to use the set-quantifier flag instead. + with caplog.at_level(logging.WARNING, logger="soda"): + SELECT(fields=[DISTINCT(expression="a"), "b"]) + + assert any( + "use SELECT(..., distinct=True)" in record.getMessage() + for record in caplog.records + if record.levelno == logging.WARNING + )