Skip to content

Commit d8ab535

Browse files
authored
Merge pull request #131 from Dan-Knott/dk/sql-body
Add support for Postgres sqlbody functions
2 parents c9e45fb + 9bca5fc commit d8ab535

File tree

15 files changed

+476
-68
lines changed

15 files changed

+476
-68
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
# Changelog
22

3+
### 0.17.0
4+
5+
- feat: Add support for Postgres sqlbody functions.
6+
37
## 0.16
48

59
### 0.16.4

pyproject.toml

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "sqlalchemy-declarative-extensions"
3-
version = "0.16.4"
3+
version = "0.16.5"
44
authors = [
55
{name = "Dan Cardin", email = "ddcardin@gmail.com"},
66
]
@@ -45,9 +45,8 @@ changelog = "https://github.com/DanCardin/sqlalchemy-declarative-extensions/blob
4545
alembic = ["alembic >= 1.0"]
4646
parse = ["sqlglot"]
4747

48-
[tool.uv]
49-
environments = ["python_version < '3.9'", "python_version >= '3.9' and python_version < '4'"]
50-
dev-dependencies = [
48+
[dependency-groups]
49+
dev = [
5150
"alembic-utils >= 0.8.1",
5251
"coverage >= 5",
5352
"mypy == 1.8.0",
@@ -58,7 +57,7 @@ dev-dependencies = [
5857
"pytest-xdist",
5958
"ruff >= 0.5.0",
6059
"sqlalchemy[mypy] >= 1.4",
61-
"psycopg",
60+
"psycopg[binary]",
6261
"psycopg2-binary",
6362

6463
# snowflake
@@ -67,6 +66,9 @@ dev-dependencies = [
6766
"snowflake-sqlalchemy >= 1.6.0; python_version >= '3.9'",
6867
]
6968

69+
[tool.uv]
70+
environments = ["python_version < '3.9'", "python_version >= '3.9' and python_version < '4'"]
71+
7072
[tool.mypy]
7173
strict_optional = true
7274
ignore_missing_imports = true

src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import enum
4+
import re
45
import textwrap
56
from dataclasses import dataclass, replace
67
from typing import Any, List, Literal, Sequence, Tuple, cast
@@ -10,6 +11,22 @@
1011
from sqlalchemy_declarative_extensions.function import base
1112
from sqlalchemy_declarative_extensions.sql import quote_name
1213

14+
# Detects a SQL-standard function body (``sql_body``) at the start of a
# definition.
#
# The body of a LANGUAGE SQL function can either be a single statement
#
#     RETURN expression
#
# or a block
#
#     BEGIN ATOMIC
#         statement;
#         statement;
#         ...
#         statement;
#     END
#
# ``BEGIN`` and ``ATOMIC`` are separate keywords, so any run of whitespace
# (spaces, tabs, newlines) is accepted between them. The trailing ``\W`` keeps
# longer identifiers such as ``RETURNING`` from being mistaken for ``RETURN``.
_sqlbody_regex = re.compile(
    r"\W*(BEGIN\s+ATOMIC|RETURN)\W", re.IGNORECASE | re.MULTILINE
)
29+
1330

1431
@enum.unique
1532
class FunctionSecurity(enum.Enum):
@@ -47,6 +64,12 @@ class Function(base.Function):
4764
parameters: Sequence[FunctionParam | str] | None = None # type: ignore
4865
volatility: FunctionVolatility = FunctionVolatility.VOLATILE
4966

67+
@property
def _has_sqlbody(self) -> bool:
    """Whether this is a ``LANGUAGE sql`` function whose definition starts with a sqlbody.

    The language comparison is case-insensitive; the definition is tested
    against the module-level ``_sqlbody_regex``.
    """
    if self.language.lower() != "sql":
        return False
    return _sqlbody_regex.match(self.definition) is not None
72+
5073
def to_sql_create(self, replace=False) -> list[str]:
5174
components = ["CREATE"]
5275

@@ -72,7 +95,10 @@ def to_sql_create(self, replace=False) -> list[str]:
7295
components.append(self.volatility.value)
7396

7497
components.append(f"LANGUAGE {self.language}")
75-
components.append(f"AS $${self.definition}$$")
98+
if self._has_sqlbody:
99+
components.append(self.definition)
100+
else:
101+
components.append(f"AS $${self.definition}$$")
76102

77103
return [" ".join(components) + ";"]
78104

@@ -96,6 +122,8 @@ def with_security_definer(self):
96122

97123
def normalize(self) -> Function:
98124
definition = textwrap.dedent(self.definition)
125+
if self._has_sqlbody:
126+
definition = definition.strip()
99127

100128
# Normalize parameter types
101129
parameters = []

src/sqlalchemy_declarative_extensions/dialects/postgresql/procedure.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import enum
4+
import re
45
import textwrap
56
from dataclasses import dataclass, replace
67

@@ -9,6 +10,18 @@
910
from sqlalchemy_declarative_extensions.procedure import base
1011
from sqlalchemy_declarative_extensions.sql import quote_name
1112

13+
# Detects a SQL-standard procedure body (``sql_body``) at the start of a
# definition.
#
# The body of a LANGUAGE SQL procedure should be a block
#
#     BEGIN ATOMIC
#         statement;
#         statement;
#         ...
#         statement;
#     END
#
# Any amount of leading non-word characters is skipped (``\W*``) so that an
# indented, not-yet-dedented definition such as "\n    BEGIN ATOMIC ..." is
# still recognised. ``BEGIN`` and ``ATOMIC`` are separate keywords, so any run
# of whitespace is accepted between them.
_sqlbody_regex = re.compile(r"\W*(BEGIN\s+ATOMIC)\W", re.IGNORECASE | re.MULTILINE)
24+
1225

1326
@enum.unique
1427
class ProcedureSecurity(enum.Enum):
@@ -27,6 +40,12 @@ class Procedure(base.Procedure):
2740

2841
security: ProcedureSecurity = ProcedureSecurity.invoker
2942

43+
@property
def _has_sqlbody(self) -> bool:
    """Whether this is a ``LANGUAGE sql`` procedure whose definition starts with a sqlbody.

    The language comparison is case-insensitive; the definition is tested
    against the module-level ``_sqlbody_regex``.
    """
    if self.language.lower() != "sql":
        return False
    return _sqlbody_regex.match(self.definition) is not None
48+
3049
def to_sql_create(self, replace=False) -> list[str]:
3150
components = ["CREATE"]
3251

@@ -40,7 +59,10 @@ def to_sql_create(self, replace=False) -> list[str]:
4059
components.append("SECURITY DEFINER")
4160

4261
components.append(f"LANGUAGE {self.language}")
43-
components.append(f"AS $${self.definition}$$")
62+
if self._has_sqlbody:
63+
components.append(self.definition)
64+
else:
65+
components.append(f"AS $${self.definition}$$")
4466

4567
return [" ".join(components) + ";"]
4668

@@ -49,6 +71,8 @@ def to_sql_update(self) -> list[str]:
4971

5072
def normalize(self) -> Self:
5173
definition = textwrap.dedent(self.definition)
74+
if self._has_sqlbody:
75+
definition = definition.strip()
5276
return replace(self, definition=definition)
5377

5478
def with_security(self, security: ProcedureSecurity):

src/sqlalchemy_declarative_extensions/dialects/postgresql/query.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,10 +27,10 @@
2727
databases_query,
2828
default_acl_query,
2929
extensions_query,
30-
functions_query,
30+
get_functions_query,
31+
get_procedures_query,
3132
object_acl_query,
3233
objects_query,
33-
procedures_query,
3434
roles_query,
3535
schema_exists_query,
3636
schemas_query,
@@ -195,6 +195,7 @@ def get_view_postgresql(connection: Connection, name: str, schema: str = "public
195195

196196
def get_procedures_postgresql(connection: Connection) -> Sequence[BaseProcedure]:
197197
procedures = []
198+
procedures_query = get_procedures_query(connection.dialect.server_version_info)
198199
for f in connection.execute(procedures_query).fetchall():
199200
name = f.name
200201
definition = f.source
@@ -225,6 +226,7 @@ def get_functions_postgresql(connection: Connection) -> Sequence[BaseFunction]:
225226
)
226227

227228
functions = []
229+
functions_query = get_functions_query(connection.dialect.server_version_info)
228230

229231
for f in connection.execute(functions_query).fetchall():
230232
name = f.name

src/sqlalchemy_declarative_extensions/dialects/postgresql/schema.py

Lines changed: 82 additions & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
union,
1818
)
1919
from sqlalchemy.dialects.postgresql import ARRAY, CHAR, REGCLASS, aggregate_order_by
20+
from sqlalchemy.sql.functions import coalesce
2021

2122
from sqlalchemy_declarative_extensions.sqlalchemy import select
2223

@@ -317,65 +318,90 @@ def get_types(arg_type_oids):
317318
)
318319

319320

320-
procedures_query = (
321-
select(
322-
pg_proc.c.proname.label("name"),
323-
pg_namespace.c.nspname.label("schema"),
324-
pg_language.c.lanname.label("language"),
325-
pg_type.c.typname.label("return_type"),
326-
pg_proc.c.prosrc.label("source"),
327-
pg_proc.c.prosecdef.label("security_definer"),
328-
pg_proc.c.prokind.label("kind"),
329-
pg_proc.c.proargnames.label("arg_names"),
330-
pg_proc.c.proargmodes.label("arg_modes"),
331-
func.coalesce(
332-
get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes)
333-
).label("arg_types"),
334-
func.pg_get_expr(
335-
pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS)
336-
).label("arg_defaults"),
337-
)
338-
.select_from(
339-
pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid)
340-
.join(pg_language, pg_proc.c.prolang == pg_language.c.oid)
341-
.join(pg_type, pg_proc.c.prorettype == pg_type.c.oid)
321+
def get_procedures_query(version_info):
322+
source = get_source_column(version_info)
323+
return (
324+
select(
325+
pg_proc.c.proname.label("name"),
326+
pg_namespace.c.nspname.label("schema"),
327+
pg_language.c.lanname.label("language"),
328+
pg_type.c.typname.label("return_type"),
329+
source.label("source"),
330+
pg_proc.c.prosecdef.label("security_definer"),
331+
pg_proc.c.prokind.label("kind"),
332+
pg_proc.c.proargnames.label("arg_names"),
333+
pg_proc.c.proargmodes.label("arg_modes"),
334+
func.coalesce(
335+
get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes)
336+
).label("arg_types"),
337+
func.pg_get_expr(
338+
pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS)
339+
).label("arg_defaults"),
340+
)
341+
.select_from(
342+
pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid)
343+
.join(pg_language, pg_proc.c.prolang == pg_language.c.oid)
344+
.join(pg_type, pg_proc.c.prorettype == pg_type.c.oid)
345+
)
346+
.where(pg_namespace.c.nspname.notin_(["pg_catalog", "information_schema"]))
347+
.where(pg_proc.c.prokind == "p")
348+
.where(_schema_not_from_extension())
349+
.where(_not_from_extension(pg_proc.c.oid, "pg_proc"))
342350
)
343-
.where(pg_namespace.c.nspname.notin_(["pg_catalog", "information_schema"]))
344-
.where(pg_proc.c.prokind == "p")
345-
.where(_schema_not_from_extension())
346-
.where(_not_from_extension(pg_proc.c.oid, "pg_proc"))
347-
)
348351

349-
functions_query = (
350-
select(
351-
pg_proc.c.proname.label("name"),
352-
pg_namespace.c.nspname.label("schema"),
353-
pg_language.c.lanname.label("language"),
354-
pg_type.c.typname.label("base_return_type"),
355-
pg_proc.c.prosrc.label("source"),
356-
pg_proc.c.prosecdef.label("security_definer"),
357-
cast(pg_proc.c.prokind, Text).label("kind"),
358-
func.pg_get_function_arguments(pg_proc.c.oid).label("parameters"),
359-
cast(pg_proc.c.provolatile, Text).label("volatility"),
360-
func.pg_get_function_result(pg_proc.c.oid).label("return_type_string"),
361-
pg_proc.c.proargnames.label("arg_names"),
362-
pg_proc.c.proargmodes.label("arg_modes"),
363-
func.coalesce(
364-
get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes)
365-
).label("arg_types"),
366-
func.pg_get_expr(
367-
pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS)
368-
).label("arg_defaults"),
369-
)
370-
.select_from(
371-
pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid)
372-
.join(pg_language, pg_proc.c.prolang == pg_language.c.oid)
373-
.join(pg_type, pg_proc.c.prorettype == pg_type.c.oid)
352+
353+
def get_functions_query(version_info):
354+
source = get_source_column(version_info)
355+
return (
356+
select(
357+
pg_proc.c.proname.label("name"),
358+
pg_namespace.c.nspname.label("schema"),
359+
pg_language.c.lanname.label("language"),
360+
pg_type.c.typname.label("base_return_type"),
361+
source.label("source"),
362+
pg_proc.c.prosecdef.label("security_definer"),
363+
cast(pg_proc.c.prokind, Text).label("kind"),
364+
func.pg_get_function_arguments(pg_proc.c.oid).label("parameters"),
365+
cast(pg_proc.c.provolatile, Text).label("volatility"),
366+
func.pg_get_function_result(pg_proc.c.oid).label("return_type_string"),
367+
pg_proc.c.proargnames.label("arg_names"),
368+
pg_proc.c.proargmodes.label("arg_modes"),
369+
func.coalesce(
370+
get_types(pg_proc.c.proallargtypes), get_types(pg_proc.c.proargtypes)
371+
).label("arg_types"),
372+
func.pg_get_expr(
373+
pg_proc.c.proargdefaults, func.cast(literal("pg_proc"), REGCLASS)
374+
).label("arg_defaults"),
375+
)
376+
.select_from(
377+
pg_proc.join(pg_namespace, pg_proc.c.pronamespace == pg_namespace.c.oid)
378+
.join(pg_language, pg_proc.c.prolang == pg_language.c.oid)
379+
.join(pg_type, pg_proc.c.prorettype == pg_type.c.oid)
380+
)
381+
.where(pg_namespace.c.nspname.notin_(["pg_catalog", "information_schema"]))
382+
.where(pg_proc.c.prokind != "p")
383+
.where(_not_from_extension(pg_proc.c.oid, "pg_proc"))
374384
)
375-
.where(pg_namespace.c.nspname.notin_(["pg_catalog", "information_schema"]))
376-
.where(pg_proc.c.prokind != "p")
377-
.where(_not_from_extension(pg_proc.c.oid, "pg_proc"))
378-
)
385+
386+
387+
def get_source_column(version_info):
    """Return the column expression that yields a function/procedure's source.

    Postgres 14 introduced SQL-standard function and procedure bodies.

    When writing a function or procedure in SQL-standard syntax, the body is parsed
    immediately and stored as a parse tree. This allows better tracking of function
    dependencies, and can have security benefits.

    For these sqlbody functions, the pg_proc.prosrc column is an empty string.
    The pre-parsed SQL function body is stored in pg_proc.prosqlbody as pg_node_tree.
    The text representation can be returned along with the full ddl with
    pg_get_functiondef.
    Alternatively pg_get_function_sqlbody(pg_proc.oid) can be called to just get the
    body. This function is not documented, see source:
    https://doxygen.postgresql.org/ruleutils_8c.html#a99a3f975518b6b1707a3159c5f80427e

    Args:
        version_info: The server version as a tuple of ints (callers in this
            package pass ``connection.dialect.server_version_info``), or
            ``None`` when the version is unknown.

    Returns:
        On Postgres 14 and later, ``coalesce(pg_get_function_sqlbody(oid), prosrc)``;
        otherwise (including an unknown version) the plain ``pg_proc.prosrc`` column.
    """
    # Compare against ``(14,)`` rather than ``(14, 0)``: a one-element version
    # tuple such as ``(14,)`` sorts *below* ``(14, 0)`` and would wrongly be
    # treated as pre-14. An unknown version falls back to ``prosrc`` instead of
    # raising ``TypeError`` on the comparison.
    if version_info is not None and version_info >= (14,):
        return coalesce(func.pg_get_function_sqlbody(pg_proc.c.oid), pg_proc.c.prosrc)
    return pg_proc.c.prosrc
379405

380406

381407
rel_nsp = pg_namespace.alias("rel_nsp")
Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
import pytest
2+
from pytest_mock_resources import PostgresConfig, create_postgres_fixture
3+
from sqlalchemy import text
4+
5+
from sqlalchemy_declarative_extensions import Functions
6+
from sqlalchemy_declarative_extensions.dialects.postgresql import (
7+
Function,
8+
FunctionVolatility,
9+
)
10+
from sqlalchemy_declarative_extensions.function.compare import compare_functions
11+
12+
pg = create_postgres_fixture(scope="function", engine_kwargs={"echo": True})
13+
14+
15+
@pytest.fixture
def pmr_postgres_config():
    """Configure the Postgres test container to use the ``postgres:14`` image."""
    config = PostgresConfig(image="postgres:14", port=None, ci_port=None)
    return config
18+
19+
20+
def test_functions(pg):
    """Create a ``RETURN``-style sqlbody function and compare it to its declaration."""
    declared = Function(
        name="add_stable",
        definition="RETURN (i + 1)",
        parameters=["i integer"],
        returns="INTEGER",
        volatility=FunctionVolatility.STABLE,
    ).normalize()
    statements = declared.to_sql_create()
    registry = Functions([declared])

    with pg.connect() as connection:
        connection.execute(text("\n".join(statements)))
        # Every operation the comparison reports must pair identical
        # functions, i.e. the declared and reflected definitions agree.
        for op in compare_functions(connection, registry):
            assert op.from_function == op.function

0 commit comments

Comments
 (0)