Skip to content

Commit 3ad4df2

Browse files
committed
fix: Ensure existing functions are normalized in all cases.
1 parent 48cee2e commit 3ad4df2

File tree

9 files changed

+77
-21
lines changed

9 files changed

+77
-21
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## 0.16
44

5+
### 0.16.3
6+
7+
- fix: Ensure existing functions are normalized in all cases.
8+
59
### 0.16.2
610

711
- fix: Cast pg_proc char columns to Text explicitly

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "sqlalchemy-declarative-extensions"
3-
version = "0.16.2"
3+
version = "0.16.3"
44
authors = [
55
{name = "Dan Cardin", email = "ddcardin@gmail.com"},
66
]

src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -255,12 +255,6 @@ def from_unknown(
255255
)
256256

257257
return cls(table=table_return_params)
258-
# # Basic normalization: lowercase and remove extra spaces
259-
# # This might need refinement for complex TABLE definitions
260-
# inner_content = returns_lower[len("table(") : -1].strip()
261-
# cols = [normalize_arg(c) for c in inner_content.split(",")]
262-
# normalized_returns = f"table({', '.join(cols)})"
263-
# return cls()
264258

265259
# Normalize base return type (including array types)
266260
norm_type = type_map.get(returns_lower, returns_lower)

src/sqlalchemy_declarative_extensions/function/compare.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,9 @@ def compare_functions(connection: Connection, functions: Functions) -> list[Oper
5656

5757
raw_existing_functions = get_functions(connection)
5858
existing_functions = filter_functions(raw_existing_functions, functions.ignore)
59-
existing_functions_by_name = {r.qualified_name: r for r in existing_functions}
59+
existing_functions_by_name = {
60+
f.qualified_name: f.normalize() for f in existing_functions
61+
}
6062
existing_function_names = set(existing_functions_by_name)
6163

6264
new_function_names = expected_function_names - existing_function_names
@@ -75,8 +77,7 @@ def compare_functions(connection: Connection, functions: Functions) -> list[Oper
7577
result.append(CreateFunctionOp(normalized_function))
7678
else:
7779
existing_function = existing_functions_by_name[function_name]
78-
79-
normalized_existing_function = existing_function.normalize()
80+
normalized_existing_function = existing_function
8081

8182
if normalized_existing_function != normalized_function:
8283
result.append(
Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
from pytest_mock_resources import create_postgres_fixture
2+
from sqlalchemy import text
3+
4+
from sqlalchemy_declarative_extensions import (
5+
declarative_database,
6+
register_sqlalchemy_events,
7+
)
8+
from sqlalchemy_declarative_extensions.function.compare import (
9+
DropFunctionOp,
10+
compare_functions,
11+
)
12+
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base
13+
14+
_Base = declarative_base()
15+
16+
17+
@declarative_database
18+
class Base(_Base): # type: ignore
19+
__abstract__ = True
20+
21+
functions: list = []
22+
23+
24+
register_sqlalchemy_events(Base.metadata, functions=True)
25+
26+
pg = create_postgres_fixture(engine_kwargs={"echo": True}, session=True)
27+
28+
29+
def test_existing_function_normalized(pg):
30+
Base.metadata.create_all(bind=pg.connection())
31+
pg.commit()
32+
33+
pg.execute(
34+
text(
35+
"""
36+
CREATE OR REPLACE FUNCTION echo_any_element(
37+
input_value ANYELEMENT
38+
)
39+
RETURNS ANYELEMENT
40+
LANGUAGE sql
41+
AS $$
42+
-- A generic function using the ANYELEMENT polymorphic type
43+
SELECT input_value;
44+
$$;
45+
"""
46+
)
47+
)
48+
49+
connection = pg.connection()
50+
diff = compare_functions(connection, Base.metadata.info["functions"])
51+
52+
assert len(diff) == 1
53+
assert isinstance(diff[0], DropFunctionOp)
54+
assert "DROP FUNCTION" in diff[0].to_sql()[0]
55+
assert "CREATE FUNCTION" in diff[0].reverse().to_sql()[0]

tests/schema/test_snowflake.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
1-
from sqlalchemy import Column, text, types
1+
from sqlalchemy import text
22
from sqlalchemy.engine import Engine, create_engine
33

44
from sqlalchemy_declarative_extensions import (
55
declarative_database,
66
register_sqlalchemy_events,
7-
view,
87
)
98
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base
109

@@ -31,7 +30,7 @@ def test_create_schemas_filtered_to_database(snowflake: Engine):
3130
"""
3231
Base.metadata.create_all(bind=snowflake)
3332

34-
engine = create_engine('snowflake://user:password@account/db/schema')
33+
engine = create_engine("snowflake://user:password@account/db/schema")
3534
with engine.connect() as conn:
3635
Base.metadata.create_all(engine)
3736

tests/trigger/test_drop_postgres.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ class Foo(Base):
2626

2727

2828
class TableWithSpecialName(Base):
29-
__tablename__ = "user" # This name will trip up unquoted identifiers
29+
__tablename__ = "user" # This name will trip up unquoted identifiers
3030

3131
id = Column(types.Integer(), primary_key=True)
3232

tests/view/test_escaped_bindparam.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,8 @@
66
register_sqlalchemy_events,
77
view,
88
)
9-
from sqlalchemy_declarative_extensions.alembic.view import UpdateViewOp
10-
from sqlalchemy_declarative_extensions.dialects import postgresql
119
from sqlalchemy_declarative_extensions.dialects.postgresql import View
1210
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base
13-
from sqlalchemy_declarative_extensions.view.compare import compare_views
1411

1512
_Base = declarative_base()
1613

@@ -48,6 +45,12 @@ def test_escape_bindparam_postgres(pg):
4845

4946
# Make sure that bindparams escaping doesn't create unnecessary escapes
5047
# for the literal casts that appear after view definition round-tripping
51-
rendered = View("simple_select", "SELECT 'a' as col1").render_definition(pg.connection())
52-
assert "::" in rendered, "Literals in the view definition are expected to get explicit type casts"
53-
assert "\\:\\:" not in rendered, "Bind parameters escaping should leave type casts unescaped"
48+
rendered = View("simple_select", "SELECT 'a' as col1").render_definition(
49+
pg.connection()
50+
)
51+
assert (
52+
"::" in rendered
53+
), "Literals in the view definition are expected to get explicit type casts"
54+
assert (
55+
"\\:\\:" not in rendered
56+
), "Bind parameters escaping should leave type casts unescaped"

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)