Skip to content

Commit 05365c5

Browse files
committed
migration generators
1 parent fa65b97 commit 05365c5

3 files changed

Lines changed: 389 additions & 3 deletions

File tree

paradedb/sqlalchemy/alembic.py

Lines changed: 138 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
from __future__ import annotations
22

3-
from alembic.autogenerate import renderers
3+
from alembic.autogenerate import comparators, renderers
44
from alembic.operations import Operations
55
from alembic.operations.ops import MigrateOperation
6+
from alembic.util import DispatchPriority, PriorityDispatchResult
7+
from sqlalchemy import text
68

79

810
def _quote_ident(name: str) -> str:
@@ -94,3 +96,138 @@ def _reindex_bm25_impl(operations: Operations, operation: ReindexBM25Op) -> None
9496
@renderers.dispatch_for(ReindexBM25Op)
9597
def _render_reindex_bm25_op(autogen_context, op: ReindexBM25Op) -> str:
9698
return f"op.reindex_bm25({op.index_name!r}, concurrently={op.concurrently!r})"
99+
100+
101+
# ---------------------------------------------------------------------------
102+
# Autogenerate comparator
103+
# ---------------------------------------------------------------------------
104+
105+
def _autogen_bm25_meta_indexes(metadata, effective_schemas: set[str]) -> dict[tuple[str, str], object]:
106+
"""Return {(schema, index_name): Index} for all BM25 indexes in MetaData."""
107+
from .indexing import _is_bm25_index
108+
109+
result: dict[tuple[str, str], object] = {}
110+
for table in metadata.tables.values():
111+
schema = table.schema or next(iter(effective_schemas), "public")
112+
if schema not in effective_schemas:
113+
continue
114+
for index in table.indexes:
115+
if _is_bm25_index(index):
116+
result[(schema, index.name)] = index
117+
return result
118+
119+
120+
def _autogen_bm25_db_indexes(conn, effective_schemas: set[str]) -> dict[tuple[str, str], dict]:
121+
"""Return {(schema, index_name): {table_name, fields, key_field}} from pg_indexes."""
122+
from .indexing import _extract_bm25_field_list, _extract_field_name, _extract_key_field
123+
124+
result: dict[tuple[str, str], dict] = {}
125+
for schema in effective_schemas:
126+
rows = conn.execute(
127+
text(
128+
"""
129+
SELECT schemaname, tablename, indexname, indexdef
130+
FROM pg_indexes
131+
WHERE schemaname = :schema
132+
AND indexdef ILIKE '%USING bm25%'
133+
ORDER BY indexname
134+
"""
135+
),
136+
{"schema": schema},
137+
).fetchall()
138+
for row in rows:
139+
raw_fields = _extract_bm25_field_list(row.indexdef)
140+
fields = [f for f in (_extract_field_name(rf) for rf in raw_fields) if f is not None]
141+
result[(row.schemaname, row.indexname)] = {
142+
"table_name": row.tablename,
143+
"fields": fields,
144+
"key_field": _extract_key_field(row.indexdef) or "",
145+
}
146+
return result
147+
148+
149+
def _suppress_standard_bm25_ops(upgrade_ops, bm25_names: set[str]) -> None:
150+
"""Remove any standard Alembic CreateIndexOp/DropIndexOp for BM25 indexes."""
151+
from alembic.operations.ops import CreateIndexOp, DropIndexOp, ModifyTableOps
152+
153+
# Filter top-level (rare, but defensive)
154+
upgrade_ops.ops[:] = [
155+
op
156+
for op in upgrade_ops.ops
157+
if not (isinstance(op, (CreateIndexOp, DropIndexOp)) and op.index_name in bm25_names)
158+
]
159+
# Filter inside ModifyTableOps (the normal location for index ops)
160+
for op in upgrade_ops.ops:
161+
if isinstance(op, ModifyTableOps):
162+
op.ops[:] = [
163+
sub_op
164+
for sub_op in op.ops
165+
if not (
166+
isinstance(sub_op, (CreateIndexOp, DropIndexOp))
167+
and sub_op.index_name in bm25_names
168+
)
169+
]
170+
171+
172+
@comparators.dispatch_for("schema", priority=DispatchPriority.LAST)
173+
def _compare_bm25_indexes(autogen_context, upgrade_ops, schemas) -> PriorityDispatchResult:
174+
"""Autogenerate comparator: emit BM25 create/drop ops and suppress incorrect standard ops."""
175+
conn = autogen_context.connection
176+
metadata = autogen_context.metadata
177+
178+
if conn is None or metadata is None:
179+
return PriorityDispatchResult.CONTINUE
180+
181+
default_schema: str = conn.dialect.default_schema_name or "public"
182+
effective_schemas = {s if s is not None else default_schema for s in schemas}
183+
184+
db_bm25 = _autogen_bm25_db_indexes(conn, effective_schemas)
185+
meta_bm25 = _autogen_bm25_meta_indexes(metadata, effective_schemas)
186+
187+
all_bm25_names = {k[1] for k in db_bm25} | {k[1] for k in meta_bm25}
188+
if not all_bm25_names:
189+
return PriorityDispatchResult.CONTINUE
190+
191+
# Remove any standard CreateIndexOp/DropIndexOp for BM25 indexes since
192+
# those would render incorrect DDL (BM25Field expressions can't be
193+
# round-tripped through the standard Inspector → Python code path).
194+
_suppress_standard_bm25_ops(upgrade_ops, all_bm25_names)
195+
196+
# Emit drop ops for indexes present in DB but absent from MetaData.
197+
for key in db_bm25:
198+
if key not in meta_bm25:
199+
upgrade_ops.ops.append(DropBM25IndexOp(index_name=key[1], if_exists=True))
200+
201+
# Emit create ops for indexes present in MetaData but absent from DB.
202+
# Also re-create indexes whose field list or key_field differs from the DB.
203+
for key, index in meta_bm25.items():
204+
from .indexing import _bm25_field_name
205+
206+
with_opts = index.dialect_options["postgresql"].get("with") or {}
207+
key_field = with_opts.get("key_field", "")
208+
fields = [f for f in (_bm25_field_name(expr) for expr in index.expressions) if f is not None]
209+
210+
if key not in db_bm25:
211+
upgrade_ops.ops.append(
212+
CreateBM25IndexOp(
213+
index_name=index.name,
214+
table_name=index.table.name,
215+
fields=fields,
216+
key_field=key_field,
217+
)
218+
)
219+
else:
220+
db = db_bm25[key]
221+
if db["fields"] != fields or db["key_field"] != key_field:
222+
# Index configuration changed: drop the old one, create the new one.
223+
upgrade_ops.ops.append(DropBM25IndexOp(index_name=key[1], if_exists=True))
224+
upgrade_ops.ops.append(
225+
CreateBM25IndexOp(
226+
index_name=index.name,
227+
table_name=index.table.name,
228+
fields=fields,
229+
key_field=key_field,
230+
)
231+
)
232+
233+
return PriorityDispatchResult.CONTINUE

tests/integration/test_alembic_integration.py

Lines changed: 144 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,32 @@
33
import pytest
44
from alembic.migration import MigrationContext
55
from alembic.operations import Operations
6-
from sqlalchemy import text
6+
from alembic.operations.ops import UpgradeOps
7+
from sqlalchemy import Column, Integer, MetaData, Table, Text, text
8+
from unittest.mock import MagicMock
79

8-
import paradedb.sqlalchemy.alembic # noqa: F401 Ensure op registration
10+
import paradedb.sqlalchemy.alembic as pdb_alembic # noqa: F401 Ensure op registration
11+
from paradedb.sqlalchemy.indexing import BM25Field
912

1013

1114
pytestmark = pytest.mark.integration
1215

1316

17+
# ---------------------------------------------------------------------------
18+
# Helper: run the BM25 autogenerate comparator against a real DB connection
19+
# ---------------------------------------------------------------------------
20+
21+
def _run_comparator(engine, metadata):
22+
"""Return the UpgradeOps produced by the BM25 autogenerate comparator."""
23+
with engine.connect() as conn:
24+
ctx = MagicMock()
25+
ctx.connection = conn
26+
ctx.metadata = metadata
27+
upgrade_ops = UpgradeOps([])
28+
pdb_alembic._compare_bm25_indexes(ctx, upgrade_ops, {None})
29+
return upgrade_ops
30+
31+
1432
def test_alembic_create_reindex_drop_with_quoted_identifiers(engine):
1533
table_name = 'alembic quoted products'
1634
index_name = 'alembic quoted idx'
@@ -57,3 +75,127 @@ def test_alembic_create_reindex_drop_with_quoted_identifiers(engine):
5775

5876
with engine.begin() as conn:
5977
conn.execute(text(f'DROP TABLE IF EXISTS "{table_name}"'))
78+
79+
80+
# ---------------------------------------------------------------------------
81+
# Autogenerate comparator integration tests
82+
# ---------------------------------------------------------------------------
83+
84+
_AG_TABLE = "autogen_test"
85+
_AG_IDX = "autogen_test_bm25_idx"
86+
87+
88+
def _setup_autogen_table(engine, *, with_index: bool = False):
89+
"""Create a clean autogen_test table (and optionally a BM25 index) in the DB."""
90+
with engine.begin() as conn:
91+
conn.execute(text(f'DROP INDEX IF EXISTS "{_AG_IDX}"'))
92+
conn.execute(text(f'DROP TABLE IF EXISTS "{_AG_TABLE}" CASCADE'))
93+
conn.execute(text(f'CREATE TABLE "{_AG_TABLE}" (id int primary key, description text not null)'))
94+
if with_index:
95+
conn.execute(
96+
text(
97+
f'CREATE INDEX "{_AG_IDX}" ON "{_AG_TABLE}" '
98+
f"USING bm25 (id, description) WITH (key_field='id')"
99+
)
100+
)
101+
102+
103+
def _teardown_autogen_table(engine):
104+
with engine.begin() as conn:
105+
conn.execute(text(f'DROP INDEX IF EXISTS "{_AG_IDX}"'))
106+
conn.execute(text(f'DROP TABLE IF EXISTS "{_AG_TABLE}" CASCADE'))
107+
108+
109+
def _metadata_with_bm25() -> MetaData:
110+
"""MetaData that defines autogen_test with a BM25 index."""
111+
m = MetaData()
112+
t = Table(_AG_TABLE, m, Column("id", Integer, primary_key=True), Column("description", Text))
113+
from sqlalchemy.schema import Index
114+
Index(
115+
_AG_IDX,
116+
BM25Field(t.c.id),
117+
BM25Field(t.c.description),
118+
postgresql_using="bm25",
119+
postgresql_with={"key_field": "id"},
120+
)
121+
return m
122+
123+
124+
def _metadata_without_bm25() -> MetaData:
125+
"""MetaData that defines autogen_test WITHOUT any BM25 index."""
126+
m = MetaData()
127+
Table(_AG_TABLE, m, Column("id", Integer, primary_key=True), Column("description", Text))
128+
return m
129+
130+
131+
def test_autogenerate_detects_missing_index(engine):
132+
"""MetaData has BM25 index but DB does not → CreateBM25IndexOp emitted."""
133+
_setup_autogen_table(engine, with_index=False)
134+
try:
135+
upgrade_ops = _run_comparator(engine, _metadata_with_bm25())
136+
137+
create_ops = [op for op in upgrade_ops.ops if isinstance(op, pdb_alembic.CreateBM25IndexOp)]
138+
assert len(create_ops) == 1
139+
op = create_ops[0]
140+
assert op.index_name == _AG_IDX
141+
assert op.table_name == _AG_TABLE
142+
assert op.key_field == "id"
143+
assert "id" in op.fields
144+
assert "description" in op.fields
145+
finally:
146+
_teardown_autogen_table(engine)
147+
148+
149+
def test_autogenerate_detects_extra_index(engine):
150+
"""DB has BM25 index but MetaData does not → DropBM25IndexOp emitted."""
151+
_setup_autogen_table(engine, with_index=True)
152+
try:
153+
upgrade_ops = _run_comparator(engine, _metadata_without_bm25())
154+
155+
drop_ops = [op for op in upgrade_ops.ops if isinstance(op, pdb_alembic.DropBM25IndexOp)]
156+
assert any(op.index_name == _AG_IDX for op in drop_ops)
157+
finally:
158+
_teardown_autogen_table(engine)
159+
160+
161+
def test_autogenerate_no_op_when_indexes_match(engine):
162+
"""DB and MetaData have identical BM25 index → no create/drop ops for that index."""
163+
_setup_autogen_table(engine, with_index=True)
164+
try:
165+
upgrade_ops = _run_comparator(engine, _metadata_with_bm25())
166+
167+
# Filter to only ops for our specific test index; the shared engine fixture's
168+
# products_bm25_idx may appear as "extra" since our MetaData only knows autogen_test.
169+
create_ops = [
170+
op for op in upgrade_ops.ops
171+
if isinstance(op, pdb_alembic.CreateBM25IndexOp) and op.index_name == _AG_IDX
172+
]
173+
drop_ops = [
174+
op for op in upgrade_ops.ops
175+
if isinstance(op, pdb_alembic.DropBM25IndexOp) and op.index_name == _AG_IDX
176+
]
177+
assert not create_ops
178+
assert not drop_ops
179+
finally:
180+
_teardown_autogen_table(engine)
181+
182+
183+
def test_autogenerate_detects_changed_fields(engine):
184+
"""BM25 index in DB has different fields vs MetaData → Drop + Create emitted."""
185+
_setup_autogen_table(engine, with_index=False)
186+
try:
187+
# DB index only covers 'id'
188+
with engine.begin() as conn:
189+
conn.execute(
190+
text(f'CREATE INDEX "{_AG_IDX}" ON "{_AG_TABLE}" USING bm25 (id) WITH (key_field=\'id\')')
191+
)
192+
193+
# MetaData index covers 'id' and 'description'
194+
upgrade_ops = _run_comparator(engine, _metadata_with_bm25())
195+
196+
drop_ops = [op for op in upgrade_ops.ops if isinstance(op, pdb_alembic.DropBM25IndexOp)]
197+
create_ops = [op for op in upgrade_ops.ops if isinstance(op, pdb_alembic.CreateBM25IndexOp)]
198+
assert any(op.index_name == _AG_IDX for op in drop_ops), "Expected DropBM25IndexOp"
199+
assert any(op.index_name == _AG_IDX for op in create_ops), "Expected CreateBM25IndexOp"
200+
finally:
201+
_teardown_autogen_table(engine)

0 commit comments

Comments
 (0)