Skip to content

Commit 2dd4685

Browse files
committed
Fix typing to use Python 3.10 syntax (PEP 604 unions, built-in generics) instead of deprecated `typing` constructs. Add support for nested CTEs, clean up the nested resolver and simplify its code. Remove unnecessary guards.
1 parent 86a5adc commit 2dd4685

12 files changed

Lines changed: 585 additions & 367 deletions

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ warn_return_any = true
4444
warn_unused_configs = true
4545
check_untyped_defs = true
4646
disallow_untyped_defs = true
47+
disallow_any_generics = true
4748
ignore_missing_imports = true
4849

4950
[tool.coverage.run]

sql_metadata/ast_parser.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,6 @@
77
``ValueError``).
88
"""
99

10-
from typing import Optional
11-
1210
from sqlglot import exp
1311
from sqlglot.dialects.dialect import DialectType
1412

@@ -30,14 +28,14 @@ class ASTParser:
3028

3129
def __init__(self, sql: str) -> None:
3230
self._raw_sql = sql
33-
self._ast: Optional[exp.Expression] = None
31+
self._ast: exp.Expression | None = None
3432
self._dialect: DialectType = None
3533
self._parsed = False
3634
self._is_replace = False
3735
self._cte_name_map: dict[str, str] = {}
3836

3937
@property
40-
def ast(self) -> Optional[exp.Expression]:
38+
def ast(self) -> exp.Expression | None:
4139
"""The sqlglot AST for the query, lazily parsed on first access.
4240
4341
:returns: Root AST node, or ``None`` for empty/comment-only queries.
@@ -74,7 +72,7 @@ def is_replace(self) -> bool:
7472
return self._is_replace
7573

7674
@property
77-
def cte_name_map(self) -> dict:
75+
def cte_name_map(self) -> dict[str, str]:
7876
"""Map of placeholder CTE names back to their original qualified form.
7977
8078
Keys are underscore-separated placeholders (``db__DOT__name``),
@@ -83,7 +81,7 @@ def cte_name_map(self) -> dict:
8381
_ = self.ast
8482
return self._cte_name_map
8583

86-
def _parse(self, sql: str) -> Optional[exp.Expression]:
84+
def _parse(self, sql: str) -> exp.Expression | None:
8785
"""Parse *sql* into a sqlglot AST.
8886
8987
Delegates preprocessing to :class:`SqlCleaner` and dialect
@@ -92,7 +90,7 @@ def _parse(self, sql: str) -> Optional[exp.Expression]:
9290
:param sql: Raw SQL string (may include comments).
9391
:type sql: str
9492
:returns: Root AST node, or ``None`` for empty input.
95-
:rtype: Optional[exp.Expression]
93+
:rtype: exp.Expression | None
9694
:raises ValueError: If the SQL is malformed.
9795
"""
9896
if not sql or not sql.strip():

sql_metadata/column_extractor.py

Lines changed: 25 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -12,12 +12,12 @@
1212
"""
1313

1414
from dataclasses import dataclass
15-
from typing import Any, Dict, Optional, Union
15+
from typing import Any
1616

1717
from sqlglot import exp
1818

1919
from sql_metadata.exceptions import InvalidQueryDefinition
20-
from sql_metadata.utils import UniqueList, _make_reverse_cte_map, last_segment
20+
from sql_metadata.utils import UniqueList, last_segment
2121

2222
# ---------------------------------------------------------------------------
2323
# Result dataclass
@@ -32,13 +32,13 @@ class ExtractionResult:
3232
"""
3333

3434
columns: UniqueList
35-
columns_dict: Dict[str, UniqueList]
35+
columns_dict: dict[str, UniqueList]
3636
alias_names: UniqueList
37-
alias_dict: Optional[Dict[str, UniqueList]]
38-
alias_map: Dict[str, Union[str, list]]
37+
alias_dict: dict[str, UniqueList] | None
38+
alias_map: dict[str, str | list[str]]
3939
cte_names: UniqueList
4040
subquery_names: UniqueList
41-
output_columns: list
41+
output_columns: list[str]
4242

4343

4444
# ---------------------------------------------------------------------------
@@ -47,7 +47,7 @@ class ExtractionResult:
4747

4848

4949
#: Simple key → clause-name lookup for most ``arg_types`` keys.
50-
_CLAUSE_MAP: Dict[str, str] = {
50+
_CLAUSE_MAP: dict[str, str] = {
5151
"where": "where",
5252
"group": "group_by",
5353
"order": "order_by",
@@ -154,17 +154,17 @@ class _Collector:
154154
"output_columns",
155155
)
156156

157-
def __init__(self, table_aliases: Dict[str, str]):
157+
def __init__(self, table_aliases: dict[str, str]):
158158
self.ta = table_aliases
159159
self.columns = UniqueList()
160-
self.columns_dict: Dict[str, UniqueList] = {}
160+
self.columns_dict: dict[str, UniqueList] = {}
161161
self.alias_names = UniqueList()
162-
self.alias_dict: Dict[str, UniqueList] = {}
163-
self.alias_map: Dict[str, Union[str, list]] = {}
162+
self.alias_dict: dict[str, UniqueList] = {}
163+
self.alias_map: dict[str, str | list[str]] = {}
164164
self.cte_names = UniqueList()
165-
self.cte_alias_names: set = set()
166-
self.subquery_items: list = []
167-
self.output_columns: list = []
165+
self.cte_alias_names: set[str] = set()
166+
self.subquery_items: list[tuple[int, str]] = []
167+
self.output_columns: list[str] = []
168168

169169
def add_column(self, name: str, clause: str) -> None:
170170
"""Record a column name, filing it into the appropriate section."""
@@ -219,14 +219,14 @@ class ColumnExtractor:
219219
def __init__(
220220
self,
221221
ast: exp.Expression,
222-
table_aliases: Dict[str, str],
223-
cte_name_map: Optional[Dict] = None,
222+
table_aliases: dict[str, str],
223+
cte_name_map: dict[str, str] | None = None,
224224
):
225225
self._ast = ast
226226
self._table_aliases = table_aliases
227227
self._cte_name_map = cte_name_map or {}
228228
self._collector = _Collector(table_aliases)
229-
self._reverse_cte_map = self._build_reverse_cte_map()
229+
self._reverse_cte_map = self._cte_name_map
230230

231231
# -------------------------------------------------------------------
232232
# Public API
@@ -283,19 +283,6 @@ def extract(self) -> ExtractionResult:
283283
# Setup helpers
284284
# -------------------------------------------------------------------
285285

286-
def _build_reverse_cte_map(self) -> Dict[str, str]:
287-
"""Build a reverse mapping from placeholder CTE names to originals.
288-
289-
During SQL preprocessing, :class:`SqlCleaner` may rewrite
290-
qualified CTE names (e.g. ``schema.cte``) into simple
291-
placeholders. This method inverts that mapping so the final
292-
extraction result uses the original qualified names.
293-
294-
:returns: A dict mapping placeholder names back to their
295-
original qualified form.
296-
"""
297-
return _make_reverse_cte_map(self._cte_name_map)
298-
299286
def _seed_cte_names(self) -> None:
300287
"""Pre-populate CTE names in the collector before the main walk.
301288
@@ -493,7 +480,9 @@ def _recurse_child(self, child: Any, clause: str, depth: int) -> None:
493480
# Node handlers
494481
# -------------------------------------------------------------------
495482

496-
def _handle_select_exprs(self, exprs: list, clause: str, depth: int) -> None:
483+
def _handle_select_exprs(
484+
self, exprs: list[exp.Expression], clause: str, depth: int
485+
) -> None:
497486
"""Process the expression list of a SELECT clause.
498487
499488
Iterates each expression in the SELECT list, dispatching to
@@ -920,7 +909,7 @@ def _is_self_alias(
920909

921910
@staticmethod
922911
def _is_standalone_star(
923-
child: exp.Star, seen_stars: set
912+
child: exp.Star, seen_stars: set[int]
924913
) -> bool:
925914
"""Check whether a star node is standalone (not consumed by a Column).
926915
@@ -993,7 +982,7 @@ def _is_cte_with_query_body(
993982
# Flat column extraction
994983
# -------------------------------------------------------------------
995984

996-
def _flat_columns_select_only(self, select: exp.Select) -> list:
985+
def _flat_columns_select_only(self, select: exp.Select) -> list[str]:
997986
"""Extract column/alias names from a SELECT's immediate expressions.
998987
999988
Unlike :meth:`_flat_columns`, this does not recurse into the
@@ -1026,7 +1015,7 @@ def _flat_columns_select_only(self, select: exp.Select) -> list:
10261015
cols.append(col_name)
10271016
return cols
10281017

1029-
def _flat_columns(self, node: exp.Expression) -> list:
1018+
def _flat_columns(self, node: exp.Expression) -> list[str]:
10301019
"""Extract all column names from an expression subtree via DFS.
10311020
10321021
Performs a full depth-first traversal of *node* using
@@ -1052,8 +1041,8 @@ def _flat_columns(self, node: exp.Expression) -> list:
10521041
return cols
10531042

10541043
def _collect_column_from_node(
1055-
self, child: exp.Expression, seen_stars: set
1056-
) -> Union[str, None]:
1044+
self, child: exp.Expression, seen_stars: set[int]
1045+
) -> str | None:
10571046
"""Extract a column name from a single DFS-visited node.
10581047
10591048
Called by :meth:`_flat_columns` for each node in the traversal.

sql_metadata/comments.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"""
2020

2121
import re
22-
from typing import List
22+
from typing import Any
2323

2424
from sqlglot.tokens import Tokenizer
2525

@@ -71,7 +71,7 @@ def _has_hash_variables(sql: str) -> bool:
7171
return False
7272

7373

74-
def extract_comments(sql: str) -> List[str]:
74+
def extract_comments(sql: str) -> list[str]:
7575
"""Return all comments found in *sql*, with delimiters preserved.
7676
7777
Tokenizes the SQL, then scans every gap between consecutive token
@@ -107,7 +107,7 @@ def extract_comments(sql: str) -> List[str]:
107107
_COMMENT_RE = re.compile(r"/\*.*?\*/|/\*.*$|--[^\n]*\n?|#[^\n]*\n?", re.DOTALL)
108108

109109

110-
def _scan_gap(sql: str, start: int, end: int, out: list) -> None:
110+
def _scan_gap(sql: str, start: int, end: int, out: list[str]) -> None:
111111
"""Scan a slice of *sql* for comment delimiters and append matches.
112112
113113
:param sql: The full SQL string (not just the gap).
@@ -118,7 +118,7 @@ def _scan_gap(sql: str, start: int, end: int, out: list) -> None:
118118
out.extend(_COMMENT_RE.findall(sql[start:end]))
119119

120120

121-
def _reconstruct_from_tokens(sql: str, tokens: list) -> str:
121+
def _reconstruct_from_tokens(sql: str, tokens: list[Any]) -> str:
122122
"""Rebuild SQL from token spans, collapsing gaps to single spaces."""
123123
if not tokens:
124124
return ""

sql_metadata/dialect_parser.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ class so that callers only need to call :meth:`DialectParser.parse`.
66
"""
77

88
import logging
9-
from typing import Any, Optional
9+
from typing import Any
1010

1111
import sqlglot
1212
from sqlglot import exp
@@ -111,7 +111,7 @@ def parse(self, clean_sql: str) -> tuple[exp.Expression, DialectType]:
111111
# -- dialect detection --------------------------------------------------
112112

113113
@staticmethod
114-
def _detect_dialects(sql: str) -> list:
114+
def _detect_dialects(sql: str) -> list[Any]:
115115
"""Choose an ordered list of sqlglot dialects to try for *sql*.
116116
117117
Heuristics:
@@ -145,7 +145,7 @@ def _detect_dialects(sql: str) -> list:
145145
# -- parsing ------------------------------------------------------------
146146

147147
def _try_dialects(
148-
self, clean_sql: str, dialects: list
148+
self, clean_sql: str, dialects: list[Any]
149149
) -> tuple[exp.Expression, DialectType]:
150150
"""Try parsing *clean_sql* with each dialect, returning the best.
151151
@@ -180,7 +180,7 @@ def _try_dialects(
180180
)
181181

182182
@staticmethod
183-
def _parse_with_dialect(clean_sql: str, dialect: Any) -> Optional[exp.Expression]:
183+
def _parse_with_dialect(clean_sql: str, dialect: Any) -> exp.Expression | None:
184184
"""Parse *clean_sql* with a single dialect, suppressing warnings."""
185185
logger = logging.getLogger("sqlglot")
186186
old_level = logger.level

0 commit comments

Comments
 (0)