Skip to content

Commit 86a5adc

Browse files
committed
handle redshift append clause with custom dialect, clean up table extractor and add more descriptive docstrings
1 parent 0b26278 commit 86a5adc

6 files changed

Lines changed: 342 additions & 144 deletions

File tree

sql_metadata/ast_parser.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from typing import Optional
1111

1212
from sqlglot import exp
13+
from sqlglot.dialects.dialect import DialectType
1314

1415
from sql_metadata.dialect_parser import DialectParser
1516
from sql_metadata.sql_cleaner import SqlCleaner
@@ -30,7 +31,7 @@ class ASTParser:
3031
def __init__(self, sql: str) -> None:
3132
self._raw_sql = sql
3233
self._ast: Optional[exp.Expression] = None
33-
self._dialect: object = None
34+
self._dialect: DialectType = None
3435
self._parsed = False
3536
self._is_replace = False
3637
self._cte_name_map: dict[str, str] = {}
@@ -50,7 +51,7 @@ def ast(self) -> Optional[exp.Expression]:
5051
return self._ast
5152

5253
@property
53-
def dialect(self) -> object:
54+
def dialect(self) -> DialectType:
5455
"""The sqlglot dialect that produced the current AST.
5556
5657
Set as a side-effect of :attr:`ast` access. May be ``None``

sql_metadata/dialect_parser.py

Lines changed: 33 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,12 @@ class so that callers only need to call :meth:`DialectParser.parse`.
99
from typing import Any, Optional
1010

1111
import sqlglot
12-
from sqlglot import Dialect, exp
12+
from sqlglot import exp
13+
from sqlglot.dialects.dialect import Dialect, DialectType
14+
from sqlglot.dialects.redshift import Redshift
1315
from sqlglot.dialects.tsql import TSQL
1416
from sqlglot.errors import ParseError, TokenError
17+
from sqlglot.parsers.redshift import RedshiftParser
1518
from sqlglot.tokens import Tokenizer as BaseTokenizer
1619

1720
from sql_metadata.comments import _has_hash_variables
@@ -47,6 +50,31 @@ class Tokenizer(BaseTokenizer):
4750
VAR_SINGLE_TOKENS = {*BaseTokenizer.VAR_SINGLE_TOKENS, "#"}
4851

4952

53+
class _RedshiftAppendParser(RedshiftParser):
    """Redshift parser that also handles ``ALTER TABLE ... APPEND FROM``.

    Adds an ``APPEND`` entry to the ``ALTER_PARSERS`` mapping inherited
    from ``RedshiftParser``.
    """

    def _parse_alter_table_append(self) -> "exp.Expr | None":
        """Parse what follows ``APPEND`` and return the parsed table.

        Tries to match the ``FROM`` keyword (the outcome is not checked)
        and then delegates to ``_parse_table`` for the table reference.
        """
        self._match_text_seq("FROM")
        source_table = self._parse_table()
        return source_table

    # Build a new mapping instead of mutating the inherited one, so the
    # base ``RedshiftParser.ALTER_PARSERS`` is left untouched.
    ALTER_PARSERS = dict(
        RedshiftParser.ALTER_PARSERS,
        APPEND=lambda parser: parser._parse_alter_table_append(),
    )
64+
65+
66+
class RedshiftAppendDialect(Redshift):
    """Redshift dialect that can parse ``ALTER TABLE ... APPEND FROM``.

    sqlglot does not natively support Redshift's ``APPEND FROM`` syntax,
    so such a statement degrades to ``exp.Command``.  This dialect uses a
    parser whose ``ALTER_PARSERS`` has an ``APPEND`` entry, so the
    statement is parsed as a proper ``exp.Alter`` with ``exp.Table``
    nodes.
    """

    # Parser class for this dialect: the Redshift parser subclass that
    # registers the ``APPEND`` alter action.
    Parser = _RedshiftAppendParser
76+
77+
5078
class BracketedTableDialect(TSQL):
5179
"""TSQL dialect for queries containing ``[bracketed]`` identifiers.
5280
@@ -65,7 +93,7 @@ class BracketedTableDialect(TSQL):
6593
class DialectParser:
6694
"""Detect the appropriate sqlglot dialect and parse SQL into an AST."""
6795

68-
def parse(self, clean_sql: str) -> tuple[exp.Expression, object]:
96+
def parse(self, clean_sql: str) -> tuple[exp.Expression, DialectType]:
6997
"""Parse *clean_sql*, returning ``(ast, dialect)``.
7098
7199
Detects candidate dialects via heuristics, tries each in order,
@@ -110,13 +138,15 @@ def _detect_dialects(sql: str) -> list:
110138
return [BracketedTableDialect, None, "mysql"]
111139
if " UNIQUE " in upper:
112140
return [None, "mysql", "oracle"]
141+
if "APPEND FROM" in upper:
142+
return [RedshiftAppendDialect, None, "mysql"]
113143
return [None, "mysql"]
114144

115145
# -- parsing ------------------------------------------------------------
116146

117147
def _try_dialects(
118148
self, clean_sql: str, dialects: list
119-
) -> tuple[exp.Expression, object]:
149+
) -> tuple[exp.Expression, DialectType]:
120150
"""Try parsing *clean_sql* with each dialect, returning the best.
121151
122152
:returns: 2-tuple of ``(ast_node, winning_dialect)``.

sql_metadata/parser.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -287,11 +287,13 @@ def tables(self) -> List[str]:
287287
if self._tables is not None:
288288
return self._tables
289289
_ = self.query_type
290+
ast = self._ast_parser.ast
291+
assert ast is not None # guaranteed by query_type raising on None
290292
cte_names = set(self.with_names)
291293
for placeholder in self._ast_parser.cte_name_map:
292294
cte_names.add(placeholder)
293295
extractor = TableExtractor(
294-
self._ast_parser.ast,
296+
ast,
295297
self._raw_query,
296298
cte_names,
297299
dialect=self._ast_parser.dialect,
@@ -304,7 +306,9 @@ def tables_aliases(self) -> Dict[str, str]:
304306
"""Return the table alias mapping for this query."""
305307
if self._table_aliases is not None:
306308
return self._table_aliases
307-
extractor = TableExtractor(self._ast_parser.ast)
309+
ast = self._ast_parser.ast
310+
assert ast is not None # guaranteed by prior tables/query_type access
311+
extractor = TableExtractor(ast)
308312
self._table_aliases = extractor.extract_aliases(self.tables)
309313
return self._table_aliases
310314

sql_metadata/query_type_extractor.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -89,10 +89,15 @@ def _unwrap_parens(ast: exp.Expression) -> exp.Expression:
8989

9090
@staticmethod
9191
def _resolve_command_type(root: exp.Expression) -> Optional[QueryType]:
92-
"""Determine query type for an opaque Command node."""
92+
"""Determine query type for an opaque ``exp.Command`` node.
93+
94+
Hive ``CREATE FUNCTION ... USING JAR ... WITH SERDEPROPERTIES``
95+
is not supported by any sqlglot dialect and degrades to
96+
``exp.Command(this='CREATE', ...)``. This fallback extracts
97+
the query type from the command text so callers still get
98+
``QueryType.CREATE``.
99+
"""
93100
expression_text = str(root.this).upper() if root.this else ""
94-
if expression_text == "ALTER":
95-
return QueryType.ALTER
96101
if expression_text == "CREATE":
97102
return QueryType.CREATE
98103
return None

0 commit comments

Comments
 (0)