Skip to content

Commit 669ee45

Browse files
authored
Merge pull request #134 from DanCardin/dc/include-exclude
feat: Add include support for trigger/function and exclude support fo…
2 parents 203792f + 7b6fc1f commit 669ee45

File tree

10 files changed

+429
-19
lines changed

10 files changed

+429
-19
lines changed

CHANGELOG.md

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,11 @@
44

55
## 0.16
66

7+
### 0.16.7
8+
9+
- feat: Add missing support for `exclude` on triggers.
10+
- feat: Add support for `include` on triggers and functions.
11+
712
### 0.16.6
813

914
- fix: postgresql parsing of existing function defaults.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "sqlalchemy-declarative-extensions"
3-
version = "0.16.6"
3+
version = "0.16.7"
44
authors = [{ name = "Dan Cardin", email = "ddcardin@gmail.com" }]
55
description = "Library to declare additional kinds of objects not natively supported by SQLAlchemy/Alembic."
66
license = { file = "LICENSE" }

src/sqlalchemy_declarative_extensions/function/base.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,7 @@ class Functions:
9191

9292
functions: list[Function] = field(default_factory=list)
9393

94+
include: list[str] | None = None
9495
ignore: list[str] = field(default_factory=list)
9596
ignore_unspecified: bool = False
9697

@@ -130,10 +131,18 @@ def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | No
130131
)
131132

132133
functions = [s for instance in instances for s in instance.functions]
134+
# Preserve None if all instances have include=None, otherwise combine all non-None includes
135+
include_values = [
136+
instance.include for instance in instances if instance.include is not None
137+
]
138+
include = [s for inc in include_values for s in inc] if include_values else None
133139
ignore = [s for instance in instances for s in instance.ignore]
134140
ignore_unspecified = instances[0].ignore_unspecified
135141
return cls(
136-
functions=functions, ignore_unspecified=ignore_unspecified, ignore=ignore
142+
functions=functions,
143+
ignore_unspecified=ignore_unspecified,
144+
ignore=ignore,
145+
include=include,
137146
)
138147

139148
def append(self, function: Function):

src/sqlalchemy_declarative_extensions/function/compare.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ def compare_functions(connection: Connection, functions: Functions) -> list[Oper
5555
expected_function_names = set(functions_by_name)
5656

5757
raw_existing_functions = get_functions(connection)
58-
existing_functions = filter_functions(raw_existing_functions, functions.ignore)
58+
existing_functions = filter_functions(
59+
raw_existing_functions, exclude=functions.ignore, include=functions.include
60+
)
5961
existing_functions_by_name = {
6062
f.qualified_name: f.normalize() for f in existing_functions
6163
}
@@ -93,12 +95,18 @@ def compare_functions(connection: Connection, functions: Functions) -> list[Oper
9395

9496

9597
def filter_functions(
96-
functions: Sequence[Function], exclude: list[str]
98+
functions: Sequence[Function], *, exclude: list[str], include: list[str] | None
9799
) -> list[Function]:
98100
return [
99101
f
100102
for f in functions
101-
if not any(
103+
if (
104+
include is None
105+
or any(
106+
fnmatch.fnmatch(f.qualified_name, inclusion) for inclusion in include
107+
)
108+
)
109+
and not any(
102110
fnmatch.fnmatch(f.qualified_name, exclusion) for exclusion in exclude
103111
)
104112
]

src/sqlalchemy_declarative_extensions/trigger/base.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ def to_sql_drop(self):
3838
class Triggers:
3939
triggers: list[Trigger] = field(default_factory=list)
4040

41+
include: list[str] | None = None
42+
ignore: list[str] = field(default_factory=list)
4143
ignore_unspecified: bool = False
4244

4345
@classmethod
@@ -76,8 +78,19 @@ def extract(cls, metadata: MetaData | list[MetaData | None] | None) -> Self | No
7678
)
7779

7880
triggers = [s for instance in instances for s in instance.triggers]
81+
# Preserve None if all instances have include=None, otherwise combine all non-None includes
82+
include_values = [
83+
instance.include for instance in instances if instance.include is not None
84+
]
85+
include = [s for inc in include_values for s in inc] if include_values else None
86+
ignore = [s for instance in instances for s in instance.ignore]
7987
ignore_unspecified = instances[0].ignore_unspecified
80-
return cls(triggers=triggers, ignore_unspecified=ignore_unspecified)
88+
return cls(
89+
triggers=triggers,
90+
ignore_unspecified=ignore_unspecified,
91+
ignore=ignore,
92+
include=include,
93+
)
8194

8295
def append(self, trigger: Trigger):
8396
self.triggers.append(trigger)

src/sqlalchemy_declarative_extensions/trigger/compare.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
from __future__ import annotations
22

3+
import fnmatch
34
from dataclasses import dataclass
4-
from typing import Union
5+
from typing import Sequence, Union
56

67
from sqlalchemy.engine import Connection
78

@@ -53,7 +54,11 @@ def compare_triggers(connection: Connection, triggers: Triggers) -> list[Operati
5354
triggers_by_name = {r.name: r for r in triggers.triggers}
5455
expected_trigger_names = set(triggers_by_name)
5556

56-
existing_triggers = get_triggers(connection)
57+
raw_existing_triggers = get_triggers(connection)
58+
existing_triggers = filter_triggers(
59+
raw_existing_triggers, exclude=triggers.ignore, include=triggers.include
60+
)
61+
5762
existing_triggers_by_name = {r.name: r for r in existing_triggers}
5863
existing_trigger_names = set(existing_triggers_by_name)
5964

@@ -77,3 +82,17 @@ def compare_triggers(connection: Connection, triggers: Triggers) -> list[Operati
7782
result.append(DropTriggerOp(trigger))
7883

7984
return result
85+
86+
87+
def filter_triggers(
88+
triggers: Sequence[Trigger], *, exclude: list[str], include: list[str] | None
89+
) -> list[Trigger]:
90+
return [
91+
t
92+
for t in triggers
93+
if (
94+
include is None
95+
or any(fnmatch.fnmatch(t.name, inclusion) for inclusion in include)
96+
)
97+
and not any(fnmatch.fnmatch(t.name, exclusion) for exclusion in exclude)
98+
]

tests/dialect/postgresql/test_function_defaults.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import pytest
2-
from pytest_mock_resources import PostgresConfig, create_postgres_fixture
2+
from pytest_mock_resources import create_postgres_fixture
33
from sqlalchemy import text
44

55
from sqlalchemy_declarative_extensions import Functions
@@ -13,8 +13,6 @@
1313
pg = create_postgres_fixture(scope="function", engine_kwargs={"echo": True})
1414

1515

16-
17-
1816
@pytest.mark.parametrize(
1917
("default_a", "default_b", "default_c"),
2018
[(None, None, "''::text"), (1, 0, "'m'::text"), (None, 2, "'ft'::text")],
Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,177 @@
1+
import pytest
2+
from pytest_mock_resources import create_postgres_fixture
3+
from sqlalchemy import text
4+
from sqlalchemy.exc import ProgrammingError
5+
6+
from sqlalchemy_declarative_extensions import (
7+
Functions,
8+
declarative_database,
9+
register_sqlalchemy_events,
10+
)
11+
from sqlalchemy_declarative_extensions.sqlalchemy import declarative_base
12+
13+
_Base = declarative_base()
14+
15+
16+
@declarative_database
17+
class BaseIncludeOnly(_Base): # type: ignore
18+
__abstract__ = True
19+
20+
functions = Functions(include=["test_*"])
21+
22+
23+
@declarative_database
24+
class BaseExcludeOnly(_Base): # type: ignore
25+
__abstract__ = True
26+
27+
functions = Functions(ignore=["ignore_*"])
28+
29+
30+
@declarative_database
31+
class BaseIncludeAndExclude(_Base): # type: ignore
32+
__abstract__ = True
33+
34+
functions = Functions(include=["test_*", "keep_*"], ignore=["*_ignore"])
35+
36+
37+
register_sqlalchemy_events(BaseIncludeOnly.metadata, functions=True)
38+
register_sqlalchemy_events(BaseExcludeOnly.metadata, functions=True)
39+
register_sqlalchemy_events(BaseIncludeAndExclude.metadata, functions=True)
40+
41+
pg_include = create_postgres_fixture(engine_kwargs={"echo": True}, session=True)
42+
pg_exclude = create_postgres_fixture(engine_kwargs={"echo": True}, session=True)
43+
pg_both = create_postgres_fixture(engine_kwargs={"echo": True}, session=True)
44+
45+
46+
def test_include_only(pg_include):
47+
# Matches the include pattern, thus dropped because it's not declared.
48+
pg_include.execute(
49+
text(
50+
"CREATE FUNCTION test_func() RETURNS INTEGER language sql as $$ select 1 $$;"
51+
)
52+
)
53+
# Doesn't match the include pattern, thus kept because it's unmanaged.
54+
pg_include.execute(
55+
text(
56+
"CREATE FUNCTION other_func() RETURNS INTEGER language sql as $$ select 2 $$;"
57+
)
58+
)
59+
pg_include.commit()
60+
61+
BaseIncludeOnly.metadata.create_all(bind=pg_include.connection())
62+
pg_include.commit()
63+
64+
with pytest.raises(ProgrammingError):
65+
pg_include.execute(text("SELECT test_func()")).scalar()
66+
pg_include.rollback()
67+
68+
result = pg_include.execute(text("SELECT other_func()")).scalar()
69+
assert result == 2
70+
71+
72+
def test_exclude_only(pg_exclude):
73+
# Matches the exclude pattern, thus kept because it's being ignored.
74+
pg_exclude.execute(
75+
text(
76+
"CREATE FUNCTION ignore_this() RETURNS INTEGER language sql as $$ select 1 $$;"
77+
)
78+
)
79+
# Doesn't match the exclude pattern, thus dropped because it's not being ignored.
80+
pg_exclude.execute(
81+
text(
82+
"CREATE FUNCTION manage_this() RETURNS INTEGER language sql as $$ select 2 $$;"
83+
)
84+
)
85+
pg_exclude.commit()
86+
87+
BaseExcludeOnly.metadata.create_all(bind=pg_exclude.connection())
88+
pg_exclude.commit()
89+
90+
result = pg_exclude.execute(text("SELECT ignore_this()")).scalar()
91+
assert result == 1
92+
93+
with pytest.raises(ProgrammingError):
94+
pg_exclude.execute(text("SELECT manage_this()")).scalar()
95+
96+
97+
def test_include_and_exclude_interaction(pg_both):
98+
"""Test the interaction between include and exclude.
99+
100+
A function that matches include becomes managed, but can become unmanaged if also matching the
101+
exclude.
102+
"""
103+
pg_both.execute(
104+
text(
105+
"CREATE FUNCTION test_func() RETURNS INTEGER language sql as $$ select 1 $$;"
106+
)
107+
)
108+
pg_both.execute(
109+
text(
110+
"CREATE FUNCTION test_ignore() RETURNS INTEGER language sql as $$ select 2 $$;"
111+
)
112+
)
113+
pg_both.execute(
114+
text(
115+
"CREATE FUNCTION keep_this() RETURNS INTEGER language sql as $$ select 3 $$;"
116+
)
117+
)
118+
pg_both.execute(
119+
text(
120+
"CREATE FUNCTION other_func() RETURNS INTEGER language sql as $$ select 4 $$;"
121+
)
122+
)
123+
124+
pg_both.commit()
125+
126+
BaseIncludeAndExclude.metadata.create_all(bind=pg_both.connection())
127+
pg_both.commit()
128+
129+
with pytest.raises(ProgrammingError):
130+
pg_both.execute(text("SELECT test_func()")).scalar()
131+
pg_both.rollback()
132+
133+
result = pg_both.execute(text("SELECT test_ignore()")).scalar()
134+
assert result == 2
135+
136+
with pytest.raises(ProgrammingError):
137+
pg_both.execute(text("SELECT keep_this()")).scalar()
138+
pg_both.rollback()
139+
140+
result = pg_both.execute(text("SELECT other_func()")).scalar()
141+
assert result == 4
142+
143+
144+
def test_include_with_schema_patterns(pg_include):
145+
pg_include.execute(text("CREATE SCHEMA foo"))
146+
pg_include.execute(text("CREATE SCHEMA bar"))
147+
148+
pg_include.execute(
149+
text(
150+
"CREATE FUNCTION test_one() RETURNS INTEGER language sql as $$ select 1 $$;"
151+
)
152+
)
153+
pg_include.execute(
154+
text(
155+
"CREATE FUNCTION foo.test_two() RETURNS INTEGER language sql as $$ select 2 $$;"
156+
)
157+
)
158+
pg_include.execute(
159+
text(
160+
"CREATE FUNCTION bar.other() RETURNS INTEGER language sql as $$ select 3 $$;"
161+
)
162+
)
163+
164+
pg_include.commit()
165+
166+
BaseIncludeOnly.metadata.create_all(bind=pg_include.connection())
167+
pg_include.commit()
168+
169+
with pytest.raises(ProgrammingError):
170+
pg_include.execute(text("SELECT test_one()")).scalar()
171+
pg_include.rollback()
172+
173+
result = pg_include.execute(text("SELECT foo.test_two()")).scalar()
174+
assert result == 2
175+
176+
result = pg_include.execute(text("SELECT bar.other()")).scalar()
177+
assert result == 3

0 commit comments

Comments
 (0)