1717from sqlglot import exp
1818from sqlglot .generator import Generator
1919
20- from sql_metadata .utils import UniqueList , _make_reverse_cte_map , flatten_list
20+ from sql_metadata .utils import (
21+ DOT_PLACEHOLDER ,
22+ UniqueList ,
23+ _make_reverse_cte_map ,
24+ flatten_list ,
25+ last_segment ,
26+ )
2127
2228# ---------------------------------------------------------------------------
2329# Custom SQL generator — preserves function signatures
@@ -119,6 +125,7 @@ def __init__(
119125 self ._subqueries_parsers : Dict = {}
120126 self ._with_parsers : Dict = {}
121127 self ._columns_aliases : Dict = {}
128+ self ._cached_cte_nodes : Optional [list ] = None
122129
123130 # Set by resolve() caller
124131 self ._subqueries_names : List [str ] = []
@@ -130,21 +137,29 @@ def __init__(
130137 # Name extraction (CTE and subquery names from the AST)
131138 # -------------------------------------------------------------------
132139
133- @staticmethod
def _cte_nodes(self) -> list:
    """Return every ``exp.CTE`` node found in the AST, memoised per instance.

    The first call walks ``self._ast`` with ``find_all(exp.CTE)`` and stores
    the resulting list in ``self._cached_cte_nodes``; later calls hand back
    that same list object.  When ``self._ast`` is ``None`` an empty list is
    cached and returned.
    """
    cached = self._cached_cte_nodes
    if cached is not None:
        # Already computed (possibly an empty list) — reuse it as-is.
        return cached

    root = self._ast
    cached = [] if root is None else list(root.find_all(exp.CTE))
    self._cached_cte_nodes = cached
    return cached
148+
134149 def extract_cte_names (
135- ast : Optional [ exp . Expression ] ,
150+ self ,
136151 cte_name_map : Optional [Dict ] = None ,
137152 ) -> List [str ]:
138153 """Extract CTE names from the AST.
139154
140155 Called by :attr:`Parser.with_names`.
141156 """
142- if ast is None :
157+ if self . _ast is None :
143158 return []
144159 cte_name_map = cte_name_map or {}
145160 reverse_map = _make_reverse_cte_map (cte_name_map )
146161 names = UniqueList ()
147- for cte in ast . find_all ( exp . CTE ):
162+ for cte in self . _cte_nodes ( ):
148163 alias = cte .alias
149164 if alias :
150165 names .append (reverse_map .get (alias , alias ))
@@ -196,13 +211,13 @@ def extract_cte_bodies(
196211
197212 alias_to_name : Dict [str , str ] = {}
198213 for name in cte_names :
199- placeholder = name .replace ("." , "__DOT__" )
214+ placeholder = name .replace ("." , DOT_PLACEHOLDER )
200215 alias_to_name [placeholder .upper ()] = name
201216 alias_to_name [name .upper ()] = name
202- alias_to_name [name . split ( "." )[ - 1 ] .upper ()] = name
217+ alias_to_name [last_segment ( name ) .upper ()] = name
203218
204219 results : Dict [str , str ] = {}
205- for cte in self ._ast . find_all ( exp . CTE ):
220+ for cte in self ._cte_nodes ( ):
206221 alias = cte .alias
207222 if alias .upper () in alias_to_name :
208223 original_name = alias_to_name [alias .upper ()]
@@ -312,42 +327,34 @@ def _resolve_and_filter(
312327 final .append (col )
313328 return final
314329
def _nested_sources(self) -> list:
    """Return the ``(names, defs, cache)`` tuples to search, in order.

    The first tuple describes subqueries, the second describes CTEs.  The
    tuples hold references to the instance attributes themselves, not
    copies.
    """
    subquery_source = (
        self._subqueries_names,
        self._subqueries,
        self._subqueries_parsers,
    )
    cte_source = (
        self._with_names,
        self._with_queries,
        self._with_parsers,
    )
    return [subquery_source, cte_source]
336+
def _resolve_sub_queries(self, column: str) -> Union[str, List[str]]:
    """Resolve a ``subquery.column`` reference to actual column(s).

    The reference is passed through each source from
    ``self._nested_sources()`` in turn (subqueries first, then CTEs).  As
    soon as a lookup yields something other than a string, the remaining
    sources are skipped.  The final value is returned as-is when it is a
    list and wrapped in a one-element list otherwise.
    """
    resolved: Union[str, List[str]] = column
    for names, defs, cache in self._nested_sources():
        if not isinstance(resolved, str):
            # Nothing further can change a non-string result, so stop here.
            break
        resolved = self._resolve_nested_query(
            subquery_alias=resolved,
            nested_queries_names=names,
            nested_queries=defs,
            already_parsed=cache,
        )
    if isinstance(resolved, list):
        return resolved
    return [resolved]
331349
def _resolve_bare_through_nested(self, col_name: str) -> Union[str, List[str]]:
    """Resolve a bare column name through subquery/CTE alias definitions.

    Each source from ``self._nested_sources()`` is tried in order via
    ``self._lookup_alias_in_nested``.  The first lookup that does not return
    ``None`` wins.  If every lookup returns ``None`` the name is returned
    unchanged.
    """
    for position, source in enumerate(self._nested_sources()):
        names, defs, cache = source
        # Only the first source (subqueries) is searched with check_columns on.
        is_first_source = position == 0
        found = self._lookup_alias_in_nested(
            col_name, names, defs, cache, check_columns=is_first_source
        )
        if found is None:
            continue
        return found
    return col_name
352359
353360 def _lookup_alias_in_nested (
@@ -447,7 +454,7 @@ def _find_column_fallback(
447454 ) -> Union [str , List [str ]]:
448455 """Find a column by name in the subparser with wildcard fallbacks."""
449456 try :
450- idx = [x . split ( "." )[ - 1 ] for x in subparser .columns ].index (column_name )
457+ idx = [last_segment ( x ) for x in subparser .columns ].index (column_name )
451458 except ValueError :
452459 if "*" in subparser .columns :
453460 return column_name
0 commit comments