Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions soda-core/src/soda_core/common/sql_ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,22 @@ def check_context(self, context_node: type[BaseSqlExpression]) -> bool:
@dataclass
class SELECT(BaseSqlExpression):
    """SELECT clause of a SQL statement.

    Attributes:
        fields: A single field or a list of fields to select. Each field is
            either a SQL expression node or a plain string.
        distinct: When True, the clause is rendered with the set quantifier
            (``SELECT DISTINCT a, b``) instead of a plain ``SELECT``.
            Defaults to False so existing callers are unaffected.
    """

    fields: SqlExpression | str | list[SqlExpression | str]
    distinct: bool = False

    def __post_init__(self):
        """Run the base initialization, hand the fields to
        ``handle_parent_node_update`` and warn about misplaced DISTINCT nodes."""
        super().__post_init__()
        self.handle_parent_node_update(self.fields)

        # A DISTINCT expression inside SELECT.fields renders as DISTINCT(...) with
        # parens, which is only correct inside an aggregate (e.g. COUNT(DISTINCT x)).
        # For the set-quantifier form (SELECT DISTINCT a, b FROM t) use distinct=True.
        # The warning is only emitted when the SELECT lists more than one field.
        if isinstance(self.fields, list) and len(self.fields) > 1:
            if any(isinstance(select_field, DISTINCT) for select_field in self.fields):
                logger.warning(
                    "DISTINCT expression inside SELECT fields. "
                    "For set-quantifier deduplication, use SELECT(..., distinct=True). "
                    "The DISTINCT expression node is intended for aggregate-level use "
                    "(e.g. COUNT(DISTINCT x))."
                )


Expand Down
9 changes: 7 additions & 2 deletions soda-core/src/soda_core/common/sql_dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -688,9 +688,11 @@ def build_select_sql(self, select_elements: list, add_semicolon: Optional[bool]
return "\n".join(statement_lines) + (";" if add_semicolon else "")

def _build_select_sql_lines(self, select_elements: list) -> list[str]:
distinct = False
select_field_sqls: list[str] = []
for select_element in select_elements:
if isinstance(select_element, SELECT):
distinct = select_element.distinct
if isinstance(select_element.fields, str) or isinstance(select_element.fields, SqlExpression):
select_element.fields = [select_element.fields]
for select_field in select_element.fields:
Expand All @@ -705,12 +707,15 @@ def _build_select_sql_lines(self, select_elements: list) -> list[str]:
# return "SELECT " + (", ".join(select_fields_sql))
# For now, we opt for SELECT statement readability...

select_keyword = "SELECT DISTINCT" if distinct else "SELECT"
continuation_indent = " " * (len(select_keyword) + 1)

select_sql_lines: list[str] = []
for i in range(0, len(select_field_sqls)):
if i == 0:
sql_line = f"SELECT {select_field_sqls[0]}"
sql_line = f"{select_keyword} {select_field_sqls[0]}"
else:
sql_line = f" {select_field_sqls[i]}"
sql_line = f"{continuation_indent}{select_field_sqls[i]}"
# Append comma all lines except the last one
if i < len(select_field_sqls) - 1:
sql_line += ","
Expand Down
92 changes: 91 additions & 1 deletion soda-postgres/tests/unit/test_postgres_dialect.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,20 @@
import logging

import pytest
from soda_core.common.sql_dialect import FROM, RANDOM, SELECT, STAR, SamplerType
from soda_core.common.sql_dialect import (
COUNT,
DISTINCT,
EQ,
FROM,
LIMIT,
OFFSET,
ORDER_BY_ASC,
RANDOM,
SELECT,
STAR,
WHERE,
SamplerType,
)
from soda_postgres.common.data_sources.postgres_data_source import PostgresSqlDialect


Expand Down Expand Up @@ -51,3 +66,78 @@ def test_random():
sql_dialect: PostgresSqlDialect = PostgresSqlDialect()
sql = sql_dialect.build_select_sql([SELECT(RANDOM()), FROM("a")])
assert sql == 'SELECT RANDOM()\nFROM "a";'


# Expected SQL note: the dialect aligns continuation fields under the first
# field, i.e. it indents them by len(select keyword) + 1 spaces:
# 16 spaces after "SELECT DISTINCT", 7 spaces after a plain "SELECT".
@pytest.mark.parametrize(
    "sql_ast, expected_sql",
    [
        pytest.param(
            [SELECT(fields=["a"], distinct=True), FROM("t")],
            'SELECT DISTINCT "a"\nFROM "t";',
            id="select_distinct_single_column",
        ),
        pytest.param(
            [SELECT(fields=["a", "b", "c"], distinct=True), FROM("t")],
            (
                'SELECT DISTINCT "a",\n'
                '                "b",\n'
                '                "c"\n'
                'FROM "t";'
            ),
            id="select_distinct_multiple_columns_no_parens",
        ),
        pytest.param(
            [SELECT(STAR(), distinct=True), FROM("t")],
            'SELECT DISTINCT *\nFROM "t";',
            id="select_distinct_star_expression",
        ),
        pytest.param(
            [SELECT(fields=["a", "b"]), FROM("t")],
            (
                'SELECT "a",\n'
                '       "b"\n'
                'FROM "t";'
            ),
            id="select_without_distinct_unchanged",
        ),
        pytest.param(
            [
                SELECT(fields=["a", "b"], distinct=True),
                FROM("t"),
                WHERE(EQ("status", 1)),
                ORDER_BY_ASC("a"),
                LIMIT(100),
                OFFSET(200),
            ],
            (
                'SELECT DISTINCT "a",\n'
                '                "b"\n'
                'FROM "t"\n'
                'WHERE "status" = 1\n'
                'ORDER BY "a" ASC\n'
                "LIMIT 100\n"
                "OFFSET 200;"
            ),
            id="select_distinct_paginated_full_shape",
        ),
        pytest.param(
            # Aggregate-level DISTINCT is an expression node and keeps its parens.
            [SELECT(fields=COUNT(DISTINCT(expression="x"))), FROM("t")],
            'SELECT COUNT(DISTINCT("x"))\nFROM "t";',
            id="aggregate_level_distinct_preserves_parens",
        ),
    ],
)
def test_select_distinct(sql_ast, expected_sql):
    """Render each AST with the Postgres dialect and compare to the exact SQL text."""
    sql_dialect: PostgresSqlDialect = PostgresSqlDialect()
    assert sql_dialect.build_select_sql(sql_ast) == expected_sql


def test_select_distinct_default_is_false():
    """Backwards compatibility: a SELECT built without ``distinct`` must
    render as plain SELECT, never as SELECT DISTINCT."""
    select_ast = [SELECT(["a"]), FROM("t")]
    rendered_sql = PostgresSqlDialect().build_select_sql(select_ast)
    assert rendered_sql == 'SELECT "a"\nFROM "t";'


def test_distinct_expression_inside_select_fields_warns(caplog):
    """A DISTINCT node nested in a multi-field SELECT should log a warning
    that points the caller to the set-quantifier flag."""
    with caplog.at_level(logging.WARNING, logger="soda"):
        SELECT(fields=[DISTINCT(expression="a"), "b"])

    # Only WARNING-level records are relevant for this check.
    warning_messages = [
        log_record.getMessage()
        for log_record in caplog.records
        if log_record.levelno == logging.WARNING
    ]
    assert any(
        "use SELECT(..., distinct=True)" in message
        for message in warning_messages
    )
Loading