Skip to content

Commit c9e45fb

Browse files
authored
Merge pull request #130 from DanCardin/dc/unnamed-parameters
fix: Handle unnamed parameters during normalization and CREATE statements.
2 parents 3ca08f0 + 7ae6f85 commit c9e45fb

File tree

5 files changed

+974
-917
lines changed

5 files changed

+974
-917
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## 0.16
44

5+
### 0.16.4
6+
7+
- fix: Handle unnamed parameters during normalization and CREATE statements.
8+
59
### 0.16.3
610

711
- fix: Ensure existing functions are normalized in all cases.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "sqlalchemy-declarative-extensions"
3-
version = "0.16.3"
3+
version = "0.16.4"
44
authors = [
55
{name = "Dan Cardin", email = "ddcardin@gmail.com"},
66
]

src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py

Lines changed: 20 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -35,28 +35,6 @@ def from_provolatile(cls, provolatile: str) -> FunctionVolatility:
3535
raise ValueError(f"Invalid volatility: {provolatile}")
3636

3737

38-
# def normalize_arg(arg: str) -> str:
39-
# parts = arg.strip().split(maxsplit=1)
40-
# if len(parts) == 2:
41-
# name, type_str = parts
42-
# norm_type = type_map.get(type_str.lower(), type_str.lower())
43-
# # Handle array types
44-
# if norm_type.endswith("[]"):
45-
# base_type = norm_type[:-2]
46-
# norm_base_type = type_map.get(base_type, base_type)
47-
# norm_type = f"{norm_base_type}[]"
48-
#
49-
# return f"{name} {norm_type}"
50-
# # Handle case where it might just be the type (e.g., from DROP FUNCTION)
51-
# type_str = arg.strip()
52-
# norm_type = type_map.get(type_str.lower(), type_str.lower())
53-
# if norm_type.endswith("[]"):
54-
# base_type = norm_type[:-2]
55-
# norm_base_type = type_map.get(base_type, base_type)
56-
# norm_type = f"{norm_base_type}[]"
57-
# return norm_type
58-
59-
6038
@dataclass
6139
class Function(base.Function):
6240
"""Describes a PostgreSQL function.
@@ -143,7 +121,7 @@ def normalize(self) -> Function:
143121

144122
@dataclass
145123
class FunctionParam:
146-
name: str
124+
name: str | None
147125
type: str
148126
default: Any | None = None
149127
mode: Literal["i", "o", "b", "v", "t"] | None = None
@@ -185,31 +163,40 @@ def from_unknown(
185163
if isinstance(source_param, tuple):
186164
return cls(*source_param)
187165

188-
name, type = source_param.strip().split(maxsplit=1)
166+
try:
167+
name, type = source_param.strip().split(maxsplit=1)
168+
except ValueError:
169+
name = None
170+
type = source_param.strip()
171+
189172
return cls(name, type)
190173

191174
def normalize(self) -> FunctionParam:
192175
type = self.type.lower()
193176
return replace(
194177
self,
195-
name=self.name.lower(),
178+
name=self.name.lower() if self.name is not None else None,
196179
mode=self.mode or "i",
197180
type=type_map.get(type, type),
198181
default=str(self.default) if self.default is not None else None,
199182
)
200183

201184
def to_sql_create(self) -> str:
202-
result = ""
185+
segments = []
203186
if self.mode:
204-
result += {"o": "OUT ", "b": "INOUT ", "v": "VARIADIC ", "t": "TABLE "}.get(
205-
self.mode, ""
206-
)
187+
modes = {"o": "OUT ", "b": "INOUT ", "v": "VARIADIC ", "t": "TABLE "}
188+
mode = modes.get(self.mode)
189+
if mode:
190+
segments.append(mode)
191+
192+
if self.name:
193+
segments.append(self.name)
207194

208-
result += f"{self.name} {self.type}"
195+
segments.append(self.type)
209196

210197
if self.default is not None:
211-
result += f" DEFAULT {self.default}"
212-
return result
198+
segments.append(f"DEFAULT {self.default}")
199+
return " ".join(segments)
213200

214201
def to_sql_drop(self) -> str:
215202
return self.type
@@ -245,7 +232,7 @@ def from_unknown(
245232
returns_lower = source.lower().strip()
246233
if returns_lower.startswith("table("):
247234
table_return_params = [
248-
(p.name, p.type) for p in (parameters or []) if p.mode == "t"
235+
(p.name, p.type) for p in (parameters or []) if p.name and p.mode == "t"
249236
]
250237

251238
if not table_return_params:
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
from pytest_mock_resources import create_postgres_fixture
2+
from sqlalchemy import text
3+
4+
from sqlalchemy_declarative_extensions.dialects.postgresql.function import Function
5+
from sqlalchemy_declarative_extensions.function.base import Functions
6+
from sqlalchemy_declarative_extensions.function.compare import (
7+
DropFunctionOp,
8+
compare_functions,
9+
)
10+
11+
pg = create_postgres_fixture(engine_kwargs={"echo": True}, session=True)
12+
13+
14+
def test_existing_function(pg):
15+
pg.execute(
16+
text(
17+
"""
18+
CREATE OR REPLACE FUNCTION add_numbers(integer, integer)
19+
RETURNS integer AS $$
20+
BEGIN
21+
RETURN $1 + $2;
22+
END;
23+
$$ LANGUAGE plpgsql;
24+
"""
25+
)
26+
)
27+
28+
connection = pg.connection()
29+
diff = compare_functions(connection, Functions())
30+
31+
assert len(diff) == 1
32+
assert isinstance(diff[0], DropFunctionOp)
33+
assert "DROP FUNCTION" in diff[0].to_sql()[0]
34+
assert "CREATE FUNCTION" in diff[0].reverse().to_sql()[0]
35+
36+
37+
def test_creates_function(pg):
38+
connection = pg.connection()
39+
diff = compare_functions(
40+
connection,
41+
Functions(
42+
functions=[
43+
Function(
44+
"add_numbers",
45+
"""
46+
BEGIN
47+
RETURN $1 + $2;
48+
END;
49+
""",
50+
language="plpgsql",
51+
returns="integer",
52+
parameters=["integer", "integer"],
53+
)
54+
]
55+
),
56+
)
57+
58+
assert len(diff) == 1
59+
assert "CREATE FUNCTION" in diff[0].to_sql()[0]
60+
pg.execute(text(diff[0].to_sql()[0]))
61+
pg.commit()
62+
63+
result = pg.execute(text("SELECT add_numbers(2, 3)")).one()[0]
64+
assert result == 5

0 commit comments

Comments
 (0)