Skip to content

Commit f036d86

Browse files
committed
Fix postgresql parsing of existing function defaults
The column pg_proc.proargdefaults when converted with pg_get_expr returns a comma separated string of default values. This list of values then needs to be aligned to the right most input argument. To avoid this complex rearangement, and avoid any potential edge cases where ',' might be in the default value, the function pg_get_function_arg_default is used to check for a default for each argument.
1 parent d8ab535 commit f036d86

2 files changed

Lines changed: 68 additions & 6 deletions

File tree

src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py

Lines changed: 20 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,7 @@ def char_literals(*literals: str) -> Collection[BindParameter]:
117117
column("prosecdef"),
118118
column("prokind"),
119119
column("provolatile"),
120+
column("pronargs"),
120121
column("proargnames"),
121122
column("proargmodes"),
122123
column("proargtypes"),
@@ -318,6 +319,23 @@ def get_types(arg_type_oids):
318319
)
319320

320321

322+
def get_defaults():
323+
arg_num = (
324+
func.generate_series(1, pg_proc.c.pronargs)
325+
.table_valued("arg_num", with_ordinality="ordinality")
326+
.alias("arg_num")
327+
)
328+
return (
329+
select(
330+
func.array_agg(
331+
func.pg_get_function_arg_default(pg_proc.c.oid, arg_num.c.arg_num)
332+
)
333+
)
334+
.select_from(arg_num)
335+
.scalar_subquery()
336+
)
337+
338+
321339
def get_procedures_query(version_info):
322340
source = get_source_column(version_info)
323341
return (
@@ -334,9 +352,7 @@ def get_procedures_query(version_info):
334352
func.coalesce(
335353
get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes)
336354
).label("arg_types"),
337-
func.pg_get_expr(
338-
pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS)
339-
).label("arg_defaults"),
355+
get_defaults().label("arg_defaults"),
340356
)
341357
.select_from(
342358
pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid)
@@ -369,9 +385,7 @@ def get_functions_query(version_info):
369385
func.coalesce(
370386
get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes)
371387
).label("arg_types"),
372-
func.pg_get_expr(
373-
pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS)
374-
).label("arg_defaults"),
388+
get_defaults().label("arg_defaults"),
375389
)
376390
.select_from(
377391
pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid)
Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,48 @@
1+
import pytest
2+
from pytest_mock_resources import PostgresConfig, create_postgres_fixture
3+
from sqlalchemy import text
4+
5+
from sqlalchemy_declarative_extensions import Functions
6+
from sqlalchemy_declarative_extensions.dialects.postgresql import (
7+
Function,
8+
FunctionParam,
9+
FunctionVolatility,
10+
)
11+
from sqlalchemy_declarative_extensions.function.compare import compare_functions
12+
13+
pg = create_postgres_fixture(scope="function", engine_kwargs={"echo": True})
14+
15+
16+
@pytest.fixture
17+
def pmr_postgres_config():
18+
return PostgresConfig(image="postgres:11", port=None, ci_port=None)
19+
20+
21+
@pytest.mark.parametrize(
22+
("default_a", "default_b", "default_c"),
23+
[(None, None, "''::text"), (1, 0, "'m'::text"), (None, 2, "'ft'::text")],
24+
)
25+
def test_function_argument_defaults(pg, default_a, default_b, default_c):
26+
add_label_function = Function(
27+
name="add_label",
28+
definition="""
29+
BEGIN
30+
RETURN (((a + b))::text || c);
31+
END;
32+
""",
33+
parameters=[
34+
FunctionParam("a", "integer", default=default_a),
35+
FunctionParam("b", "integer", default=default_b),
36+
FunctionParam("c", "text", default=default_c),
37+
],
38+
returns="TEXT",
39+
volatility=FunctionVolatility.STABLE,
40+
language="plpgsql",
41+
).normalize()
42+
create_function = add_label_function.to_sql_create()
43+
functions = Functions([add_label_function])
44+
with pg.connect() as connection:
45+
connection.execute(text("\n".join(create_function)))
46+
diff = compare_functions(connection, functions)
47+
for op in diff:
48+
assert op.from_function == op.function

0 commit comments

Comments
 (0)