Skip to content

Commit b7c0d8c

Browse files
authored
Merge pull request #116 from DanCardin/dc/quote-audit-columns
fix: Quote audit column names.
2 parents 2421814 + b0a8d33 commit b7c0d8c

File tree

13 files changed

+138
-19
lines changed

13 files changed

+138
-19
lines changed

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@
22

33
## 0.15
44

5+
### 0.15.11
6+
- fix: Exclude extension-created objects from comparisons.
7+
- fix: Quote audit column names.
8+
59
### 0.15.10
610

711
- feat: Support alembic check.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "sqlalchemy-declarative-extensions"
3-
version = "0.15.10"
3+
version = "0.15.11"
44
authors = [
55
{name = "Dan Cardin", email = "ddcardin@gmail.com"},
66
]

src/sqlalchemy_declarative_extensions/audit.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
register_trigger,
1111
)
1212
from sqlalchemy_declarative_extensions.dialects.postgresql.trigger import Trigger
13+
from sqlalchemy_declarative_extensions.sql import quote_name
1314

1415
default_primary_key = Column(
1516
"audit_pk", types.Integer(), primary_key=True, autoincrement=True
@@ -192,7 +193,7 @@ def create_audit_functions(
192193
if column.name == AUDIT_PK:
193194
continue
194195

195-
audit_columns.append(column.name)
196+
audit_columns.append(quote_name(column.name))
196197

197198
if column.name in {
198199
AUDIT_PK,
@@ -211,8 +212,8 @@ def create_audit_functions(
211212
old_elements.append(value)
212213
new_elements.append(value)
213214
else:
214-
old_elements.append(f"OLD.{column.name}")
215-
new_elements.append(f"NEW.{column.name}")
215+
old_elements.append(f'OLD."{column.name}"')
216+
new_elements.append(f'NEW."{column.name}"')
216217

217218
audit_columns_str = ", ".join(audit_columns)
218219
old_elements_str = ", ".join(old_elements)
@@ -234,7 +235,7 @@ def create_audit_functions(
234235
"_".join([function_name, op.lower()]),
235236
f"""
236237
BEGIN
237-
INSERT INTO {audit_table.fullname} ({audit_columns_str})
238+
INSERT INTO {quote_name(audit_table.fullname)} ({audit_columns_str})
238239
SELECT '{op_key}', now(), current_user, {elements};
239240
RETURN NULL;
240241
END

src/sqlalchemy_declarative_extensions/dialects/postgresql/function.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from dataclasses import dataclass, replace
66

77
from sqlalchemy_declarative_extensions.function import base
8+
from sqlalchemy_declarative_extensions.sql import quote_name
89

910

1011
@enum.unique
@@ -31,7 +32,7 @@ def to_sql_create(self, replace=False) -> list[str]:
3132
components.append("OR REPLACE")
3233

3334
components.append("FUNCTION")
34-
components.append(self.qualified_name + "()")
35+
components.append(quote_name(self.qualified_name) + "()")
3536
components.append(f"RETURNS {self.returns}")
3637

3738
if self.security == FunctionSecurity.definer:

src/sqlalchemy_declarative_extensions/dialects/postgresql/procedure.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing_extensions import Self
88

99
from sqlalchemy_declarative_extensions.procedure import base
10+
from sqlalchemy_declarative_extensions.sql import quote_name
1011

1112

1213
@enum.unique
@@ -33,7 +34,7 @@ def to_sql_create(self, replace=False) -> list[str]:
3334
components.append("OR REPLACE")
3435

3536
components.append("PROCEDURE")
36-
components.append(self.qualified_name + "()")
37+
components.append(quote_name(self.qualified_name) + "()")
3738

3839
if self.security == ProcedureSecurity.definer:
3940
components.append("SECURITY DEFINER")

src/sqlalchemy_declarative_extensions/dialects/postgresql/trigger.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from sqlalchemy.engine import Connection
66

77
from sqlalchemy_declarative_extensions.dialects.from_string import FromStrings
8+
from sqlalchemy_declarative_extensions.sql import quote_name
89
from sqlalchemy_declarative_extensions.trigger import base
910

1011

@@ -155,14 +156,12 @@ def to_sql_create(self, replace=False):
155156
components.append("OR REPLACE")
156157

157158
components.append("TRIGGER")
158-
components.append(f'"{self.name}"')
159+
components.append(quote_name(self.name))
159160
components.append(self.time.value)
160161
components.append(" OR ".join([e.value for e in self.events]))
161162
components.append("ON")
162163

163-
on_components = [f'"{component}"' for component in self.on.split(".")]
164-
on = ".".join(on_components)
165-
components.append(on)
164+
components.append(quote_name(self.on))
166165

167166
components.append("FOR EACH")
168167
components.append(self.for_each.value)
@@ -175,7 +174,7 @@ def to_sql_create(self, replace=False):
175174
args_quoted = (
176175
tuple(f"'{arg}'" for arg in self.arguments) if self.arguments else ()
177176
)
178-
components.append(self.execute + f"({','.join(args_quoted)})")
177+
components.append(quote_name(self.execute) + f"({','.join(args_quoted)})")
179178
return " ".join(components) + ";"
180179

181180
def to_sql_update(self, connection: Connection | None = None):

src/sqlalchemy_declarative_extensions/function/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sqlalchemy import MetaData
77
from typing_extensions import Self
88

9-
from sqlalchemy_declarative_extensions.sql import qualify_name
9+
from sqlalchemy_declarative_extensions.sql import qualify_name, quote_name
1010
from sqlalchemy_declarative_extensions.sqlalchemy import HasMetaData
1111

1212

@@ -54,7 +54,7 @@ def to_sql_update(self) -> list[str]:
5454
]
5555

5656
def to_sql_drop(self) -> list[str]:
57-
return [f"DROP FUNCTION {self.qualified_name}();"]
57+
return [f"DROP FUNCTION {quote_name(self.qualified_name)}();"]
5858

5959
def with_name(self, name: str):
6060
return replace(self, name=name)

src/sqlalchemy_declarative_extensions/procedure/base.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from sqlalchemy import MetaData
77
from typing_extensions import Self
88

9-
from sqlalchemy_declarative_extensions.sql import qualify_name
9+
from sqlalchemy_declarative_extensions.sql import qualify_name, quote_name
1010
from sqlalchemy_declarative_extensions.sqlalchemy import HasMetaData
1111

1212

@@ -52,7 +52,7 @@ def to_sql_update(self) -> list[str]:
5252
]
5353

5454
def to_sql_drop(self) -> list[str]:
55-
return [f"DROP PROCEDURE {self.qualified_name}();"]
55+
return [f"DROP PROCEDURE {quote_name(self.qualified_name)}();"]
5656

5757
def with_name(self, name: str):
5858
return replace(self, name=name)

src/sqlalchemy_declarative_extensions/sql.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -12,9 +12,10 @@ def qualify_name(schema: str | None, name: str, quote=False) -> str:
1212
return f'"{name}"'
1313
return name
1414

15+
result = f"{schema}.{name}"
1516
if quote:
16-
return f'"{schema}"."{name}"'
17-
return f"{schema}.{name}"
17+
return quote_name(result)
18+
return result
1819

1920

2021
def split_schema(
@@ -46,3 +47,8 @@ def coerce_name(name: str | HasName):
4647
if isinstance(name, HasName):
4748
return name.name
4849
return name
50+
51+
52+
def quote_name(name: str) -> str:
53+
components = [f'"{component}"' for component in name.split(".")]
54+
return ".".join(components)

tests/audit/test_quoting.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ class Something(Base):
2121
__tablename__ = "Something"
2222

2323
id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True)
24+
end: Mapped[int] = mapped_column()
2425

2526

2627
register_sqlalchemy_events(Base.metadata, functions=True, triggers=True)
@@ -31,3 +32,6 @@ class Something(Base):
3132
def test_quoted_name(pg):
3233
Base.metadata.create_all(bind=pg.connection())
3334
pg.commit()
35+
36+
pg.add(Something(end=2))
37+
pg.commit()

0 commit comments

Comments
 (0)