Skip to content

Commit b3744ac

Browse files
committed
fix remaining mypy errors in untyped code
1 parent 9ce3ab3 commit b3744ac

10 files changed

Lines changed: 38 additions & 32 deletions

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ max-complexity = 8
4242
python_version = "3.10"
4343
warn_return_any = true
4444
warn_unused_configs = true
45+
check_untyped_defs = true
46+
disallow_untyped_defs = true
4547
ignore_missing_imports = true
4648

4749
[tool.coverage.run]

sql_metadata/ast_parser.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def ast(self) -> Optional[exp.Expression]:
5050
return self._ast
5151

5252
@property
53-
def dialect(self):
53+
def dialect(self) -> object:
5454
"""The sqlglot dialect that produced the current AST.
5555
5656
Set as a side-effect of :attr:`ast` access. May be ``None``

sql_metadata/column_extractor.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
"""
1313

1414
from dataclasses import dataclass
15-
from typing import Dict, Optional, Union
15+
from typing import Any, Dict, Optional, Union
1616

1717
from sqlglot import exp
1818

@@ -88,7 +88,7 @@ def _classify_clause(key: str, parent_type: type) -> str:
8888
# ---------------------------------------------------------------------------
8989

9090

91-
def _dfs(node: exp.Expression):
91+
def _dfs(node: exp.Expression) -> Any:
9292
"""Yield *node* and all its descendants in depth-first order.
9393
9494
:param node: Root expression node.
@@ -168,7 +168,7 @@ def add_column(self, name: str, clause: str) -> None:
168168
if clause:
169169
self.columns_dict.setdefault(clause, UniqueList()).append(name)
170170

171-
def add_alias(self, name: str, target, clause: str) -> None:
171+
def add_alias(self, name: str, target: Any, clause: str) -> None:
172172
"""Record a column alias and its target expression."""
173173
self.alias_names.append(name)
174174
if clause:
@@ -335,7 +335,9 @@ def _is_star_inside_function(star: exp.Star) -> bool:
335335
# DFS walk
336336
# -------------------------------------------------------------------
337337

338-
def _walk(self, node, clause: str = "", depth: int = 0) -> None:
338+
def _walk(
339+
self, node: Optional[exp.Expression], clause: str = "", depth: int = 0
340+
) -> None:
339341
"""Depth-first walk of the AST in ``arg_types`` key order."""
340342
if node is None:
341343
return
@@ -346,7 +348,7 @@ def _walk(self, node, clause: str = "", depth: int = 0) -> None:
346348
if hasattr(node, "arg_types"):
347349
self._walk_children(node, clause, depth)
348350

349-
def _walk_children(self, node, clause: str, depth: int) -> None:
351+
def _walk_children(self, node: exp.Expression, clause: str, depth: int) -> None:
350352
"""Recurse into children of *node* in ``arg_types`` key order."""
351353
for key in node.arg_types:
352354
if key in _SKIP_KEYS:
@@ -360,7 +362,7 @@ def _walk_children(self, node, clause: str, depth: int) -> None:
360362
if not self._process_child_key(node, key, child, new_clause, depth):
361363
self._recurse_child(child, new_clause, depth)
362364

363-
def _dispatch_leaf(self, node, clause: str, depth: int) -> bool:
365+
def _dispatch_leaf(self, node: exp.Expression, clause: str, depth: int) -> bool:
364366
"""Dispatch leaf-like AST nodes to their specialised handlers.
365367
366368
Returns ``True`` if handled (stop recursion), ``False`` to continue.
@@ -384,7 +386,7 @@ def _dispatch_leaf(self, node, clause: str, depth: int) -> bool:
384386
return False
385387

386388
def _process_child_key(
387-
self, node, key: str, child, clause: str, depth: int
389+
self, node: exp.Expression, key: str, child: Any, clause: str, depth: int
388390
) -> bool:
389391
"""Handle special cases for SELECT expressions, INSERT schema, JOIN USING.
390392
@@ -401,7 +403,7 @@ def _process_child_key(
401403
return True
402404
return False
403405

404-
def _recurse_child(self, child, clause: str, depth: int) -> None:
406+
def _recurse_child(self, child: Any, clause: str, depth: int) -> None:
405407
"""Recursively walk a child value (single expression or list)."""
406408
if isinstance(child, list):
407409
for item in child:
@@ -437,7 +439,7 @@ def _handle_insert_schema(self, node: exp.Insert) -> None:
437439
name = col_id.name if hasattr(col_id, "name") else str(col_id)
438440
self._collector.add_column(name, "insert")
439441

440-
def _handle_join_using(self, child) -> None:
442+
def _handle_join_using(self, child: Any) -> None:
441443
"""Extract column identifiers from a JOIN USING clause."""
442444
if isinstance(child, list):
443445
for item in child:
@@ -473,7 +475,7 @@ def _handle_column(self, col: exp.Column, clause: str) -> None:
473475

474476
c.add_column(full, clause)
475477

476-
def _handle_select_exprs(self, exprs, clause: str, depth: int) -> None:
478+
def _handle_select_exprs(self, exprs: Any, clause: str, depth: int) -> None:
477479
"""Handle the expressions list of a SELECT clause."""
478480
if not isinstance(exprs, list):
479481
return

sql_metadata/comments.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from sqlglot.tokens import Tokenizer
2525

2626

27-
def _choose_tokenizer(sql: str):
27+
def _choose_tokenizer(sql: str) -> Tokenizer:
2828
"""Select the appropriate sqlglot tokenizer for *sql*.
2929
3030
The default sqlglot tokenizer does **not** treat ``#`` as a comment

sql_metadata/dialect_parser.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ class so that callers only need to call :meth:`DialectParser.parse`.
66
"""
77

88
import logging
9-
from typing import Optional
9+
from typing import Any, Optional
1010

1111
import sqlglot
1212
from sqlglot import Dialect, exp
@@ -144,7 +144,7 @@ def _try_dialects(
144144
raise ValueError("This query is wrong")
145145

146146
@staticmethod
147-
def _parse_with_dialect(clean_sql: str, dialect) -> Optional[exp.Expression]:
147+
def _parse_with_dialect(clean_sql: str, dialect: Any) -> Optional[exp.Expression]:
148148
"""Parse *clean_sql* with a single dialect, suppressing warnings."""
149149
logger = logging.getLogger("sqlglot")
150150
old_level = logger.level

sql_metadata/nested_resolver.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -42,25 +42,25 @@ class _PreservingGenerator(Generator):
4242
),
4343
}
4444

45-
def coalesce_sql(self, expression):
45+
def coalesce_sql(self, expression: exp.Expression) -> str:
4646
args = [expression.this] + expression.expressions
4747
if len(args) == 2:
4848
return f"IFNULL({self.sql(args[0])}, {self.sql(args[1])})"
49-
return super().coalesce_sql(expression)
49+
return super().coalesce_sql(expression) # type: ignore[misc, no-any-return]
5050

51-
def dateadd_sql(self, expression):
51+
def dateadd_sql(self, expression: exp.Expression) -> str:
5252
return (
5353
f"DATE_ADD({self.sql(expression, 'this')}, "
5454
f"{self.sql(expression, 'expression')})"
5555
)
5656

57-
def datesub_sql(self, expression):
57+
def datesub_sql(self, expression: exp.Expression) -> str:
5858
return (
5959
f"DATE_SUB({self.sql(expression, 'this')}, "
6060
f"{self.sql(expression, 'expression')})"
6161
)
6262

63-
def tsordsadd_sql(self, expression):
63+
def tsordsadd_sql(self, expression: exp.Expression) -> str:
6464
this = self.sql(expression, "this")
6565
expr_node = expression.expression
6666
if isinstance(expr_node, exp.Mul):
@@ -74,13 +74,13 @@ def tsordsadd_sql(self, expression):
7474
return f"DATE_SUB({this}, {left})"
7575
return f"DATE_ADD({this}, {self.sql(expression, 'expression')})"
7676

77-
def not_sql(self, expression):
77+
def not_sql(self, expression: exp.Expression) -> str:
7878
child = expression.this
7979
if isinstance(child, exp.Is) and isinstance(child.expression, exp.Null):
8080
return f"{self.sql(child, 'this')} IS NOT NULL"
8181
if isinstance(child, exp.In):
8282
return f"{self.sql(child, 'this')} NOT IN ({self.expressions(child)})"
83-
return super().not_sql(expression)
83+
return super().not_sql(expression) # type: ignore[arg-type, no-any-return]
8484

8585

8686
_GENERATOR = _PreservingGenerator()
@@ -287,7 +287,7 @@ def resolve(
287287
return columns, columns_dict, self._columns_aliases
288288

289289
def _resolve_and_filter(
290-
self, columns, drop_bare_aliases: bool = True
290+
self, columns: "UniqueList", drop_bare_aliases: bool = True
291291
) -> "UniqueList":
292292
"""Apply subquery/CTE resolution and bare-alias handling."""
293293
resolved = UniqueList()

sql_metadata/parser.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@
1313

1414
import logging
1515
import re
16-
from typing import Dict, List, Optional, Tuple, Union
16+
from typing import Any, Dict, List, Optional, Tuple, Union
17+
18+
from sqlglot import exp
1719

1820
from sql_metadata.ast_parser import ASTParser
1921
from sql_metadata.column_extractor import ColumnExtractor
@@ -99,10 +101,10 @@ def _preprocess_query(self) -> str:
99101
if self._raw_query == "":
100102
return ""
101103

102-
def replace_quotes_in_string(match):
104+
def replace_quotes_in_string(match: re.Match[str]) -> str:
103105
return re.sub('"', "<!!__QUOTE__!!>", match.group())
104106

105-
def replace_back_quotes_in_string(match):
107+
def replace_back_quotes_in_string(match: re.Match[str]) -> str:
106108
return re.sub("<!!__QUOTE__!!>", '"', match.group())
107109

108110
query = re.sub(r"'.*?'", replace_quotes_in_string, self._raw_query)
@@ -334,7 +336,7 @@ def subqueries_names(self) -> List[str]:
334336
# -------------------------------------------------------------------
335337

336338
@staticmethod
337-
def _extract_int_from_node(node) -> Optional[int]:
339+
def _extract_int_from_node(node: Any) -> Optional[int]:
338340
"""Safely extract an integer value from a Limit or Offset node."""
339341
if not node:
340342
return None
@@ -440,7 +442,7 @@ def _extract_values(self) -> List:
440442
return values
441443

442444
@staticmethod
443-
def _convert_value(val) -> Union[int, float, str]:
445+
def _convert_value(val: exp.Expression) -> Union[int, float, str]:
444446
"""Convert a sqlglot literal AST node to a Python type."""
445447
from sqlglot import exp
446448

sql_metadata/sql_cleaner.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def _strip_outer_parens(sql: str) -> str:
3030
"""
3131
s = sql.strip()
3232

33-
def _is_wrapped(text):
33+
def _is_wrapped(text: str) -> bool:
3434
if len(text) < 2 or text[0] != "(" or text[-1] != ")":
3535
return False
3636
inner = text[1:-1]
@@ -66,7 +66,7 @@ def _normalize_cte_names(sql: str) -> tuple:
6666
re.IGNORECASE,
6767
)
6868

69-
def replacer(match):
69+
def replacer(match: re.Match[str]) -> str:
7070
prefix = match.group(1)
7171
qualified_name = match.group(2)
7272
suffix = match.group(3)

sql_metadata/table_extractor.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -212,7 +212,7 @@ def _first_position(self, name: str) -> int:
212212
return pos if pos >= 0 else len(self._raw_sql)
213213

214214
@staticmethod
215-
def _word_pattern(name_upper: str):
215+
def _word_pattern(name_upper: str) -> re.Pattern[str]:
216216
"""Build a regex matching *name_upper* as a whole word."""
217217
escaped = re.escape(name_upper)
218218
return re.compile(r"(?<![A-Za-z0-9_])" + escaped + r"(?![A-Za-z0-9_])")

sql_metadata/utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ class UniqueList(list):
1717
an internal ``set`` for O(1) membership checks.
1818
"""
1919

20-
def __init__(self, *args, **kwargs):
20+
def __init__(self, *args: Any, **kwargs: Any) -> None:
2121
super().__init__(*args, **kwargs)
2222
self._seen: set = set(self)
2323

@@ -32,7 +32,7 @@ def extend(self, items: Iterable[Any]) -> None: # type: ignore[override]
3232
for item in items:
3333
self.append(item)
3434

35-
def __sub__(self, other) -> List:
35+
def __sub__(self, other: Any) -> List:
3636
"""Return a plain list of elements in *self* that are not in *other*."""
3737
other_set = set(other)
3838
return [x for x in self if x not in other_set]

0 commit comments

Comments
 (0)