Skip to content

Commit 1bb8552

Browse files
authored
Merge pull request #45 from DanCardin/dc/agnostic-has-table
refactor: Use native sqlalchemy dialect function for detecting table …
2 parents 7f4d490 + ed1a8b3 commit 1bb8552

16 files changed

Lines changed: 471 additions & 468 deletions

File tree

docs/source/contributing/dialect-support.md

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,12 @@ Here's an excerpt for example:
3636
```python
3737
from sqlalchemy_declarative_extensions.dialects.mysql.query import (
3838
check_schema_exists_mysql,
39-
check_table_exists_mysql,
4039
)
4140
from sqlalchemy_declarative_extensions.dialects.postgresql.query import (
4241
check_schema_exists_postgresql,
43-
check_table_exists_postgresql,
4442
)
4543
from sqlalchemy_declarative_extensions.dialects.sqlite.query import (
4644
check_schema_exists_sqlite,
47-
check_table_exists_sqlite,
4845
)
4946
from sqlalchemy_declarative_extensions.sqlalchemy import dialect_dispatch
5047

@@ -53,12 +50,6 @@ check_schema_exists = dialect_dispatch(
5350
sqlite=check_schema_exists_sqlite,
5451
mysql=check_schema_exists_mysql,
5552
)
56-
57-
check_table_exists = dialect_dispatch(
58-
postgresql=check_table_exists_postgresql,
59-
sqlite=check_table_exists_sqlite,
60-
mysql=check_table_exists_mysql,
61-
)
6253
```
6354

6455
"Feature" implementations should only reference **these** functions, which will

poetry.lock

Lines changed: 403 additions & 373 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[tool.poetry]
22
name = "sqlalchemy-declarative-extensions"
3-
version = "0.6.8"
3+
version = "0.6.9"
44
authors = ["Dan Cardin <ddcardin@gmail.com>"]
55

66
description = "Library to declare additional kinds of objects not natively supported by SQLAlchemy/Alembic."
@@ -48,8 +48,8 @@ psycopg = "*"
4848
alembic-utils = "0.8.1"
4949
black = ">=22.3.0"
5050
coverage = ">=5"
51-
ruff = ">=0.0.165"
52-
mypy = ">=0.942"
51+
ruff = "0.0.259"
52+
mypy = "1.4.1"
5353
pytest = ">=7"
5454
pytest-xdist = "*"
5555
pytest-alembic = "*"

src/sqlalchemy_declarative_extensions/dialects/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
from sqlalchemy_declarative_extensions.dialects.query import (
33
check_schema_exists,
44
check_table_exists,
5+
get_current_schema,
56
get_default_grants,
67
get_function_cls,
78
get_functions,
@@ -18,6 +19,7 @@
1819
__all__ = [
1920
"check_schema_exists",
2021
"check_table_exists",
22+
"get_current_schema",
2123
"get_default_grants",
2224
"get_function_cls",
2325
"get_grants",

src/sqlalchemy_declarative_extensions/dialects/mysql/query.py

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22

33
from sqlalchemy_declarative_extensions.dialects.mysql.schema import (
44
schema_exists_query,
5-
table_exists_query,
65
views_query,
76
)
87
from sqlalchemy_declarative_extensions.view.base import View
@@ -23,10 +22,3 @@ def get_views_mysql(connection: Connection):
2322
def check_schema_exists_mysql(connection: Connection, name: str) -> bool:
2423
row = connection.execute(schema_exists_query, {"schema": name}).scalar()
2524
return not bool(row)
26-
27-
28-
def check_table_exists_mysql(connection: Connection, name: str, *, schema: str) -> bool:
29-
row = connection.execute(
30-
table_exists_query, {"name": name, "schema": schema}
31-
).scalar()
32-
return bool(row)

src/sqlalchemy_declarative_extensions/dialects/mysql/schema.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,12 +33,6 @@
3333
.where(views.c.table_schema.notin_(["sys"]))
3434
)
3535

36-
table_exists_query = (
37-
select(tables)
38-
.where(tables.c.table_schema == bindparam("schema"))
39-
.where(tables.c.table_name == bindparam("name"))
40-
)
41-
4236
schema_exists_query = select(schemata).where(
4337
schemata.c.schema_name == bindparam("schema")
4438
)

src/sqlalchemy_declarative_extensions/dialects/postgresql/query.py

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
roles_query,
1919
schema_exists_query,
2020
schemas_query,
21-
table_exists_query,
2221
triggers_query,
2322
view_query,
2423
views_query,
@@ -46,15 +45,6 @@ def check_schema_exists_postgresql(connection: Connection, name: str) -> bool:
4645
return not bool(row)
4746

4847

49-
def check_table_exists_postgresql(
50-
connection: Connection, name: str, *, schema: str
51-
) -> bool:
52-
row = connection.execute(
53-
table_exists_query, {"name": name, "schema": schema}
54-
).scalar()
55-
return bool(row)
56-
57-
5848
def get_objects_postgresql(connection: Connection):
5949
return sorted(
6050
[

src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -150,13 +150,6 @@ def _schema_not_pg(column=pg_namespace.c.nspname):
150150
)
151151

152152

153-
table_exists_query = (
154-
select(tables)
155-
.where(tables.c.table_schema == bindparam("schema"))
156-
.where(tables.c.table_name == bindparam("name"))
157-
)
158-
159-
160153
default_acl_query = select(
161154
pg_roles.c.rolname.label("role_name"),
162155
pg_namespace.c.nspname.label("schema_name"),

src/sqlalchemy_declarative_extensions/dialects/query.py

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
1+
from __future__ import annotations
2+
3+
from sqlalchemy import func
4+
from sqlalchemy.engine import Connection
5+
16
from sqlalchemy_declarative_extensions.dialects import postgresql
27
from sqlalchemy_declarative_extensions.dialects.mysql.query import (
38
check_schema_exists_mysql,
4-
check_table_exists_mysql,
59
get_views_mysql,
610
)
711
from sqlalchemy_declarative_extensions.dialects.postgresql.query import (
812
check_schema_exists_postgresql,
9-
check_table_exists_postgresql,
1013
get_default_grants_postgresql,
1114
get_functions_postgresql,
1215
get_grants_postgresql,
@@ -19,10 +22,9 @@
1922
)
2023
from sqlalchemy_declarative_extensions.dialects.sqlite.query import (
2124
check_schema_exists_sqlite,
22-
check_table_exists_sqlite,
2325
get_views_sqlite,
2426
)
25-
from sqlalchemy_declarative_extensions.sqlalchemy import dialect_dispatch
27+
from sqlalchemy_declarative_extensions.sqlalchemy import dialect_dispatch, select
2628

2729
get_schemas = dialect_dispatch(
2830
postgresql=get_schemas_postgresql,
@@ -34,12 +36,6 @@
3436
mysql=check_schema_exists_mysql,
3537
)
3638

37-
check_table_exists = dialect_dispatch(
38-
postgresql=check_table_exists_postgresql,
39-
sqlite=check_table_exists_sqlite,
40-
mysql=check_table_exists_mysql,
41-
)
42-
4339
get_objects = dialect_dispatch(
4440
postgresql=get_objects_postgresql,
4541
)
@@ -85,3 +81,22 @@
8581
get_triggers = dialect_dispatch(
8682
postgresql=get_triggers_postgresql,
8783
)
84+
85+
86+
def check_table_exists(
87+
connection: Connection, tablename: str, schema: str | None = None
88+
):
89+
return connection.dialect.has_table(connection, tablename, schema=schema)
90+
91+
92+
def get_current_schema(connection: Connection) -> str | None:
93+
if connection.dialect.name == "mysql":
94+
return None
95+
96+
schema = connection.execute(select(func.current_schema())).scalar()
97+
98+
default_schema = connection.dialect.default_schema_name
99+
if schema == default_schema:
100+
return None
101+
102+
return schema.lower()
Lines changed: 1 addition & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
from sqlalchemy import text
22
from sqlalchemy.engine import Connection
33

4-
from sqlalchemy_declarative_extensions.dialects.sqlite.schema import (
5-
table_exists_query,
6-
views_query,
7-
)
4+
from sqlalchemy_declarative_extensions.dialects.sqlite.schema import views_query
85
from sqlalchemy_declarative_extensions.view.base import View
96

107

@@ -26,13 +23,3 @@ def get_views_sqlite(connection: Connection):
2623
View(v.name, v.definition, schema=v.schema)
2724
for v in connection.execute(views_query()).fetchall()
2825
]
29-
30-
31-
def check_table_exists_sqlite(
32-
connection: Connection, name: str, *, schema: str
33-
) -> bool:
34-
row = connection.execute(
35-
table_exists_query(schema),
36-
{"name": name, "schema": schema},
37-
).scalar()
38-
return bool(row)

0 commit comments

Comments
 (0)