Skip to content

Commit 66ff703

Browse files
authored
Merge pull request #122 from enobayram/narrow-down-sqlalchemy-parameter-escapes
Use bind params regex from sqlalchemy while escaping
2 parents 4ea6f90 + 8d990b6 commit 66ff703

File tree

2 files changed

+16
-1
lines changed

2 files changed

+16
-1
lines changed

src/sqlalchemy_declarative_extensions/sqlalchemy.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
from __future__ import annotations
22

3+
import re
34
from typing import Callable, TypeVar
45

56
import sqlalchemy
@@ -50,8 +51,12 @@ def dispatch(connection: Connection, *args: P.args, **kwargs: P.kwargs) -> T:
5051
return dispatch
5152

5253

54+
# https://github.com/sqlalchemy/sqlalchemy/blob/2e9902a34fafff0ac6d6c521a86c7dea3d96a392/lib/sqlalchemy/sql/elements.py#L2334
55+
_sqlalchemy_bind_params_regex = re.compile(r"(?<![:\w\x5c]):(\w+)(?!:)", re.UNICODE)
56+
57+
5358
def escape_params(query: str) -> str:
54-
return query.replace(":", r"\:")
59+
return _sqlalchemy_bind_params_regex.sub(r"\\:\1", query)
5560

5661

5762
if version.startswith("1.3"):

tests/view/test_escaped_bindparam.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,11 @@
66
register_sqlalchemy_events,
77
view,
88
)
9+
from sqlalchemy_declarative_extensions.alembic.view import UpdateViewOp
10+
from sqlalchemy_declarative_extensions.dialects import postgresql
11+
from sqlalchemy_declarative_extensions.dialects.postgresql import View
912
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base
13+
from sqlalchemy_declarative_extensions.view.compare import compare_views
1014

1115
_Base = declarative_base()
1216

@@ -41,3 +45,9 @@ def test_escape_bindparam_postgres(pg):
4145

4246
result = pg.execute(text("select * from bar")).fetchall()
4347
assert result == []
48+
49+
# Make sure that bindparams escaping doesn't create unnecessary escapes
50+
# for the literal casts that appear after view definition round-tripping
51+
rendered = View("simple_select", "SELECT 'a' as col1").render_definition(pg.connection())
52+
assert "::" in rendered, "Literals in the view definition are expected to get explicit type casts"
53+
assert "\\:\\:" not in rendered, "Bind parameters escaping should leave type casts unescaped"

0 commit comments

Comments
 (0)