|
| 1 | +from __future__ import annotations |
| 2 | + |
| 3 | +from sqlalchemy import func |
| 4 | +from sqlalchemy.engine import Connection |
| 5 | + |
1 | 6 | from sqlalchemy_declarative_extensions.dialects import postgresql |
2 | 7 | from sqlalchemy_declarative_extensions.dialects.mysql.query import ( |
3 | 8 | check_schema_exists_mysql, |
4 | | - check_table_exists_mysql, |
5 | 9 | get_views_mysql, |
6 | 10 | ) |
7 | 11 | from sqlalchemy_declarative_extensions.dialects.postgresql.query import ( |
8 | 12 | check_schema_exists_postgresql, |
9 | | - check_table_exists_postgresql, |
10 | 13 | get_default_grants_postgresql, |
11 | 14 | get_functions_postgresql, |
12 | 15 | get_grants_postgresql, |
|
19 | 22 | ) |
20 | 23 | from sqlalchemy_declarative_extensions.dialects.sqlite.query import ( |
21 | 24 | check_schema_exists_sqlite, |
22 | | - check_table_exists_sqlite, |
23 | 25 | get_views_sqlite, |
24 | 26 | ) |
25 | | -from sqlalchemy_declarative_extensions.sqlalchemy import dialect_dispatch |
| 27 | +from sqlalchemy_declarative_extensions.sqlalchemy import dialect_dispatch, select |
26 | 28 |
|
27 | 29 | get_schemas = dialect_dispatch( |
28 | 30 | postgresql=get_schemas_postgresql, |
|
34 | 36 | mysql=check_schema_exists_mysql, |
35 | 37 | ) |
36 | 38 |
|
37 | | -check_table_exists = dialect_dispatch( |
38 | | - postgresql=check_table_exists_postgresql, |
39 | | - sqlite=check_table_exists_sqlite, |
40 | | - mysql=check_table_exists_mysql, |
41 | | -) |
42 | | - |
43 | 39 | get_objects = dialect_dispatch( |
44 | 40 | postgresql=get_objects_postgresql, |
45 | 41 | ) |
|
85 | 81 | get_triggers = dialect_dispatch( |
86 | 82 | postgresql=get_triggers_postgresql, |
87 | 83 | ) |
| 84 | + |
| 85 | + |
| 86 | +def check_table_exists( |
| 87 | + connection: Connection, tablename: str, schema: str | None = None |
| 88 | +): |
| 89 | + return connection.dialect.has_table(connection, tablename, schema=schema) |
| 90 | + |
| 91 | + |
| 92 | +def get_current_schema(connection: Connection) -> str | None: |
| 93 | + if connection.dialect.name == "mysql": |
| 94 | + return None |
| 95 | + |
| 96 | + schema = connection.execute(select(func.current_schema())).scalar() |
| 97 | + |
| 98 | + default_schema = connection.dialect.default_schema_name |
| 99 | + if schema == default_schema: |
| 100 | + return None |
| 101 | + |
| 102 | + return schema.lower() |
0 commit comments