Skip to content

Commit 9828bfe

Browse files
committed
further fixes and duplication cleanup
1 parent b3744ac commit 9828bfe

6 files changed

Lines changed: 97 additions & 56 deletions

File tree

sql_metadata/column_extractor.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from sqlglot import exp
1818

19-
from sql_metadata.utils import UniqueList, _make_reverse_cte_map
19+
from sql_metadata.utils import UniqueList, _make_reverse_cte_map, last_segment
2020

2121
# ---------------------------------------------------------------------------
2222
# Result dataclass
@@ -516,10 +516,10 @@ def _handle_alias(self, alias_node: exp.Alias, clause: str, depth: int) -> None:
516516
for col in inner_cols:
517517
c.add_column(col, clause)
518518

519-
unique_inner = list(dict.fromkeys(inner_cols))
519+
unique_inner = UniqueList(inner_cols)
520520
is_self_alias = len(unique_inner) == 1 and (
521521
unique_inner[0] == alias_name
522-
or unique_inner[0].split(".")[-1] == alias_name
522+
or last_segment(unique_inner[0]) == alias_name
523523
)
524524
is_direct = isinstance(inner, exp.Column)
525525

sql_metadata/nested_resolver.py

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,13 @@
1717
from sqlglot import exp
1818
from sqlglot.generator import Generator
1919

20-
from sql_metadata.utils import UniqueList, _make_reverse_cte_map, flatten_list
20+
from sql_metadata.utils import (
21+
DOT_PLACEHOLDER,
22+
UniqueList,
23+
_make_reverse_cte_map,
24+
flatten_list,
25+
last_segment,
26+
)
2127

2228
# ---------------------------------------------------------------------------
2329
# Custom SQL generator — preserves function signatures
@@ -119,6 +125,7 @@ def __init__(
119125
self._subqueries_parsers: Dict = {}
120126
self._with_parsers: Dict = {}
121127
self._columns_aliases: Dict = {}
128+
self._cached_cte_nodes: Optional[list] = None
122129

123130
# Set by resolve() caller
124131
self._subqueries_names: List[str] = []
@@ -130,21 +137,29 @@ def __init__(
130137
# Name extraction (CTE and subquery names from the AST)
131138
# -------------------------------------------------------------------
132139

133-
@staticmethod
140+
def _cte_nodes(self) -> list:
141+
"""Return all ``exp.CTE`` nodes from the AST (cached)."""
142+
if self._cached_cte_nodes is None:
143+
if self._ast is None:
144+
self._cached_cte_nodes = []
145+
else:
146+
self._cached_cte_nodes = list(self._ast.find_all(exp.CTE))
147+
return self._cached_cte_nodes
148+
134149
def extract_cte_names(
135-
ast: Optional[exp.Expression],
150+
self,
136151
cte_name_map: Optional[Dict] = None,
137152
) -> List[str]:
138153
"""Extract CTE names from the AST.
139154
140155
Called by :attr:`Parser.with_names`.
141156
"""
142-
if ast is None:
157+
if self._ast is None:
143158
return []
144159
cte_name_map = cte_name_map or {}
145160
reverse_map = _make_reverse_cte_map(cte_name_map)
146161
names = UniqueList()
147-
for cte in ast.find_all(exp.CTE):
162+
for cte in self._cte_nodes():
148163
alias = cte.alias
149164
if alias:
150165
names.append(reverse_map.get(alias, alias))
@@ -196,13 +211,13 @@ def extract_cte_bodies(
196211

197212
alias_to_name: Dict[str, str] = {}
198213
for name in cte_names:
199-
placeholder = name.replace(".", "__DOT__")
214+
placeholder = name.replace(".", DOT_PLACEHOLDER)
200215
alias_to_name[placeholder.upper()] = name
201216
alias_to_name[name.upper()] = name
202-
alias_to_name[name.split(".")[-1].upper()] = name
217+
alias_to_name[last_segment(name).upper()] = name
203218

204219
results: Dict[str, str] = {}
205-
for cte in self._ast.find_all(exp.CTE):
220+
for cte in self._cte_nodes():
206221
alias = cte.alias
207222
if alias.upper() in alias_to_name:
208223
original_name = alias_to_name[alias.upper()]
@@ -312,42 +327,34 @@ def _resolve_and_filter(
312327
final.append(col)
313328
return final
314329

330+
def _nested_sources(self) -> list:
331+
"""Return the (names, defs, cache) tuples for subqueries then CTEs."""
332+
return [
333+
(self._subqueries_names, self._subqueries, self._subqueries_parsers),
334+
(self._with_names, self._with_queries, self._with_parsers),
335+
]
336+
315337
def _resolve_sub_queries(self, column: str) -> Union[str, List[str]]:
316338
"""Resolve a ``subquery.column`` reference to actual column(s)."""
317-
result = self._resolve_nested_query(
318-
subquery_alias=column,
319-
nested_queries_names=self._subqueries_names,
320-
nested_queries=self._subqueries,
321-
already_parsed=self._subqueries_parsers,
322-
)
323-
if isinstance(result, str):
324-
result = self._resolve_nested_query(
325-
subquery_alias=result,
326-
nested_queries_names=self._with_names,
327-
nested_queries=self._with_queries,
328-
already_parsed=self._with_parsers,
329-
)
339+
result: Union[str, List[str]] = column
340+
for names, defs, cache in self._nested_sources():
341+
if isinstance(result, str):
342+
result = self._resolve_nested_query(
343+
subquery_alias=result,
344+
nested_queries_names=names,
345+
nested_queries=defs,
346+
already_parsed=cache,
347+
)
330348
return result if isinstance(result, list) else [result]
331349

332350
def _resolve_bare_through_nested(self, col_name: str) -> Union[str, List[str]]:
333351
"""Resolve a bare column name through subquery/CTE alias definitions."""
334-
result = self._lookup_alias_in_nested(
335-
col_name,
336-
self._subqueries_names,
337-
self._subqueries,
338-
self._subqueries_parsers,
339-
check_columns=True,
340-
)
341-
if result is not None:
342-
return result
343-
result = self._lookup_alias_in_nested(
344-
col_name,
345-
self._with_names,
346-
self._with_queries,
347-
self._with_parsers,
348-
)
349-
if result is not None:
350-
return result
352+
for i, (names, defs, cache) in enumerate(self._nested_sources()):
353+
result = self._lookup_alias_in_nested(
354+
col_name, names, defs, cache, check_columns=(i == 0)
355+
)
356+
if result is not None:
357+
return result
351358
return col_name
352359

353360
def _lookup_alias_in_nested(
@@ -447,7 +454,7 @@ def _find_column_fallback(
447454
) -> Union[str, List[str]]:
448455
"""Find a column by name in the subparser with wildcard fallbacks."""
449456
try:
450-
idx = [x.split(".")[-1] for x in subparser.columns].index(column_name)
457+
idx = [last_segment(x) for x in subparser.columns].index(column_name)
451458
except ValueError:
452459
if "*" in subparser.columns:
453460
return column_name

sql_metadata/parser.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,8 +298,9 @@ def with_names(self) -> List[str]:
298298
"""Return the CTE (Common Table Expression) names from the query."""
299299
if self._with_names is not None:
300300
return self._with_names
301-
self._with_names = NestedResolver.extract_cte_names(
302-
self._ast_parser.ast, self._ast_parser.cte_name_map
301+
resolver = self._get_resolver()
302+
self._with_names = resolver.extract_cte_names(
303+
self._ast_parser.cte_name_map
303304
)
304305
return self._with_names
305306

sql_metadata/sql_cleaner.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from typing import NamedTuple, Optional
1212

1313
from sql_metadata.comments import strip_comments_for_parsing as _strip_comments
14+
from sql_metadata.utils import DOT_PLACEHOLDER
1415

1516

1617
class CleanResult(NamedTuple):
@@ -70,7 +71,7 @@ def replacer(match: re.Match[str]) -> str:
7071
prefix = match.group(1)
7172
qualified_name = match.group(2)
7273
suffix = match.group(3)
73-
placeholder = qualified_name.replace(".", "__DOT__")
74+
placeholder = qualified_name.replace(".", DOT_PLACEHOLDER)
7475
name_map[placeholder] = qualified_name
7576
return f"{prefix}{placeholder}{suffix}"
7677

sql_metadata/table_extractor.py

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313

1414
from sqlglot import exp
1515

16-
from sql_metadata.utils import UniqueList
16+
from sql_metadata.utils import UniqueList, last_segment
1717

1818
# ---------------------------------------------------------------------------
1919
# Pure static helpers (no instance state needed)
@@ -127,6 +127,7 @@ def __init__(
127127
self._bracket_mode = isinstance(dialect, type) and issubclass(
128128
dialect, BracketedTableDialect
129129
)
130+
self._cached_table_nodes: Optional[List[exp.Table]] = None
130131

131132
# -------------------------------------------------------------------
132133
# Public API
@@ -152,6 +153,13 @@ def extract(self) -> List[str]:
152153
collected_sorted = sorted(collected, key=lambda t: self._first_position(t))
153154
return self._place_tables_in_order(create_target, collected_sorted)
154155

156+
def _table_nodes(self) -> List[exp.Table]:
157+
"""Return all ``exp.Table`` nodes from the AST (cached)."""
158+
if self._cached_table_nodes is None:
159+
assert self._ast is not None
160+
self._cached_table_nodes = list(self._ast.find_all(exp.Table))
161+
return self._cached_table_nodes
162+
155163
def extract_aliases(self, tables: List[str]) -> Dict[str, str]:
156164
"""Extract table alias mappings from the AST.
157165
@@ -162,7 +170,7 @@ def extract_aliases(self, tables: List[str]) -> Dict[str, str]:
162170
return {}
163171

164172
aliases = {}
165-
for table in self._ast.find_all(exp.Table):
173+
for table in self._table_nodes():
166174
alias = table.alias
167175
if not alias:
168176
continue
@@ -203,19 +211,25 @@ def _first_position(self, name: str) -> int:
203211
if pos >= 0:
204212
return pos
205213

206-
last_part = name_upper.split(".")[-1]
214+
last_part = last_segment(name_upper)
207215
pos = self._find_word_in_table_context(last_part)
208216
if pos >= 0:
209217
return pos
210218

211219
pos = self._find_word(name_upper)
212220
return pos if pos >= 0 else len(self._raw_sql)
213221

222+
_pattern_cache: Dict[str, re.Pattern[str]] = {}
223+
214224
@staticmethod
215225
def _word_pattern(name_upper: str) -> re.Pattern[str]:
216-
"""Build a regex matching *name_upper* as a whole word."""
217-
escaped = re.escape(name_upper)
218-
return re.compile(r"(?<![A-Za-z0-9_])" + escaped + r"(?![A-Za-z0-9_])")
226+
"""Build a regex matching *name_upper* as a whole word (cached)."""
227+
pat = TableExtractor._pattern_cache.get(name_upper)
228+
if pat is None:
229+
escaped = re.escape(name_upper)
230+
pat = re.compile(r"(?<![A-Za-z0-9_])" + escaped + r"(?![A-Za-z0-9_])")
231+
TableExtractor._pattern_cache[name_upper] = pat
232+
return pat
219233

220234
def _find_word(self, name_upper: str, start: int = 0) -> int:
221235
"""Find *name_upper* as a whole word in the upper-cased SQL."""
@@ -271,7 +285,7 @@ def _collect_all(self) -> UniqueList:
271285
"""Collect table names from Table and Lateral AST nodes."""
272286
assert self._ast is not None
273287
collected = UniqueList()
274-
for table in self._ast.find_all(exp.Table):
288+
for table in self._table_nodes():
275289
full_name = self._table_full_name(table)
276290
if full_name and full_name not in self._cte_names:
277291
collected.append(full_name)

sql_metadata/utils.py

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,10 @@
77

88
from typing import Any, Dict, Iterable, List
99

10+
#: Placeholder used to encode dots in qualified CTE names so that sqlglot
11+
#: does not misinterpret ``db.cte_name`` as a table reference.
12+
DOT_PLACEHOLDER = "__DOT__"
13+
1014

1115
class UniqueList(list):
1216
"""A list subclass that silently rejects duplicate items.
@@ -17,9 +21,14 @@ class UniqueList(list):
1721
an internal ``set`` for O(1) membership checks.
1822
"""
1923

20-
def __init__(self, *args: Any, **kwargs: Any) -> None:
21-
super().__init__(*args, **kwargs)
22-
self._seen: set = set(self)
24+
def __init__(self, iterable: Any = None, **kwargs: Any) -> None:
25+
self._seen: set = set()
26+
if iterable is not None:
27+
super().__init__(**kwargs)
28+
self.extend(iterable)
29+
else:
30+
super().__init__(**kwargs)
31+
self._seen = set(self)
2332

2433
def append(self, item: Any) -> None:
2534
"""Append *item* only if it is not already present (O(1) check)."""
@@ -32,6 +41,10 @@ def extend(self, items: Iterable[Any]) -> None: # type: ignore[override]
3241
for item in items:
3342
self.append(item)
3443

44+
def __contains__(self, item: Any) -> bool:
45+
"""O(1) membership check using the internal set."""
46+
return item in self._seen
47+
3548
def __sub__(self, other: Any) -> List:
3649
"""Return a plain list of elements in *self* that are not in *other*."""
3750
other_set = set(other)
@@ -40,11 +53,16 @@ def __sub__(self, other: Any) -> List:
4053

4154
def _make_reverse_cte_map(cte_name_map: Dict) -> Dict[str, str]:
4255
"""Build reverse mapping from placeholder CTE names to originals."""
43-
reverse = {v.replace(".", "__DOT__"): v for v in cte_name_map.values()}
56+
reverse = {v.replace(".", DOT_PLACEHOLDER): v for v in cte_name_map.values()}
4457
reverse.update(cte_name_map)
4558
return reverse
4659

4760

61+
def last_segment(name: str) -> str:
62+
"""Return the last dot-separated segment of a qualified name."""
63+
return name.rsplit(".", 1)[-1]
64+
65+
4866
def flatten_list(input_list: List) -> List[str]:
4967
"""Recursively flatten a list that may contain nested lists.
5068

0 commit comments

Comments
 (0)