1212"""
1313
1414from dataclasses import dataclass
15- from typing import Any , Dict , Optional , Union
15+ from typing import Any
1616
1717from sqlglot import exp
1818
1919from sql_metadata .exceptions import InvalidQueryDefinition
20- from sql_metadata .utils import UniqueList , _make_reverse_cte_map , last_segment
20+ from sql_metadata .utils import UniqueList , last_segment
2121
2222# ---------------------------------------------------------------------------
2323# Result dataclass
@@ -32,13 +32,13 @@ class ExtractionResult:
3232 """
3333
3434 columns : UniqueList
35- columns_dict : Dict [str , UniqueList ]
35+ columns_dict : dict [str , UniqueList ]
3636 alias_names : UniqueList
37- alias_dict : Optional [ Dict [ str , UniqueList ]]
38- alias_map : Dict [str , Union [ str , list ]]
37+ alias_dict : dict [ str , UniqueList ] | None
38+ alias_map : dict [str , str | list [ str ]]
3939 cte_names : UniqueList
4040 subquery_names : UniqueList
41- output_columns : list
41+ output_columns : list [ str ]
4242
4343
4444# ---------------------------------------------------------------------------
@@ -47,7 +47,7 @@ class ExtractionResult:
4747
4848
4949#: Simple key → clause-name lookup for most ``arg_types`` keys.
50- _CLAUSE_MAP : Dict [str , str ] = {
50+ _CLAUSE_MAP : dict [str , str ] = {
5151 "where" : "where" ,
5252 "group" : "group_by" ,
5353 "order" : "order_by" ,
@@ -154,17 +154,17 @@ class _Collector:
154154 "output_columns" ,
155155 )
156156
157- def __init__ (self , table_aliases : Dict [str , str ]):
157+ def __init__ (self , table_aliases : dict [str , str ]):
158158 self .ta = table_aliases
159159 self .columns = UniqueList ()
160- self .columns_dict : Dict [str , UniqueList ] = {}
160+ self .columns_dict : dict [str , UniqueList ] = {}
161161 self .alias_names = UniqueList ()
162- self .alias_dict : Dict [str , UniqueList ] = {}
163- self .alias_map : Dict [str , Union [ str , list ]] = {}
162+ self .alias_dict : dict [str , UniqueList ] = {}
163+ self .alias_map : dict [str , str | list [ str ]] = {}
164164 self .cte_names = UniqueList ()
165- self .cte_alias_names : set = set ()
166- self .subquery_items : list = []
167- self .output_columns : list = []
165+ self .cte_alias_names : set [ str ] = set ()
166+ self .subquery_items : list [ tuple [ int , str ]] = []
167+ self .output_columns : list [ str ] = []
168168
169169 def add_column (self , name : str , clause : str ) -> None :
170170 """Record a column name, filing it into the appropriate section."""
@@ -219,14 +219,14 @@ class ColumnExtractor:
219219 def __init__ (
220220 self ,
221221 ast : exp .Expression ,
222- table_aliases : Dict [str , str ],
223- cte_name_map : Optional [ Dict ] = None ,
222+ table_aliases : dict [str , str ],
223+ cte_name_map : dict [ str , str ] | None = None ,
224224 ):
225225 self ._ast = ast
226226 self ._table_aliases = table_aliases
227227 self ._cte_name_map = cte_name_map or {}
228228 self ._collector = _Collector (table_aliases )
229- self ._reverse_cte_map = self ._build_reverse_cte_map ()
229+ self ._reverse_cte_map = self ._cte_name_map
230230
231231 # -------------------------------------------------------------------
232232 # Public API
@@ -283,19 +283,6 @@ def extract(self) -> ExtractionResult:
283283 # Setup helpers
284284 # -------------------------------------------------------------------
285285
286- def _build_reverse_cte_map (self ) -> Dict [str , str ]:
287- """Build a reverse mapping from placeholder CTE names to originals.
288-
289- During SQL preprocessing, :class:`SqlCleaner` may rewrite
290- qualified CTE names (e.g. ``schema.cte``) into simple
291- placeholders. This method inverts that mapping so the final
292- extraction result uses the original qualified names.
293-
294- :returns: A dict mapping placeholder names back to their
295- original qualified form.
296- """
297- return _make_reverse_cte_map (self ._cte_name_map )
298-
299286 def _seed_cte_names (self ) -> None :
300287 """Pre-populate CTE names in the collector before the main walk.
301288
@@ -493,7 +480,9 @@ def _recurse_child(self, child: Any, clause: str, depth: int) -> None:
493480 # Node handlers
494481 # -------------------------------------------------------------------
495482
496- def _handle_select_exprs (self , exprs : list , clause : str , depth : int ) -> None :
483+ def _handle_select_exprs (
484+ self , exprs : list [exp .Expression ], clause : str , depth : int
485+ ) -> None :
497486 """Process the expression list of a SELECT clause.
498487
499488 Iterates each expression in the SELECT list, dispatching to
@@ -920,7 +909,7 @@ def _is_self_alias(
920909
921910 @staticmethod
922911 def _is_standalone_star (
923- child : exp .Star , seen_stars : set
912+ child : exp .Star , seen_stars : set [ int ]
924913 ) -> bool :
925914 """Check whether a star node is standalone (not consumed by a Column).
926915
@@ -993,7 +982,7 @@ def _is_cte_with_query_body(
993982 # Flat column extraction
994983 # -------------------------------------------------------------------
995984
996- def _flat_columns_select_only (self , select : exp .Select ) -> list :
985+ def _flat_columns_select_only (self , select : exp .Select ) -> list [ str ] :
997986 """Extract column/alias names from a SELECT's immediate expressions.
998987
999988 Unlike :meth:`_flat_columns`, this does not recurse into the
@@ -1026,7 +1015,7 @@ def _flat_columns_select_only(self, select: exp.Select) -> list:
10261015 cols .append (col_name )
10271016 return cols
10281017
1029- def _flat_columns (self , node : exp .Expression ) -> list :
1018+ def _flat_columns (self , node : exp .Expression ) -> list [ str ] :
10301019 """Extract all column names from an expression subtree via DFS.
10311020
10321021 Performs a full depth-first traversal of *node* using
@@ -1052,8 +1041,8 @@ def _flat_columns(self, node: exp.Expression) -> list:
10521041 return cols
10531042
10541043 def _collect_column_from_node (
1055- self , child : exp .Expression , seen_stars : set
1056- ) -> Union [ str , None ] :
1044+ self , child : exp .Expression , seen_stars : set [ int ]
1045+ ) -> str | None :
10571046 """Extract a column name from a single DFS-visited node.
10581047
10591048 Called by :meth:`_flat_columns` for each node in the traversal.
0 commit comments