Skip to content

Commit af92ff0

Browse files
committed
Fix BM25 Alembic schema handling and tighten mypy checks
1 parent 08cf6f3 commit af92ff0

11 files changed

Lines changed: 988 additions & 102 deletions

File tree

paradedb/sqlalchemy/alembic.py

Lines changed: 124 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,14 @@
11
from __future__ import annotations
22

3+
import re
4+
35
from alembic.autogenerate import comparators, renderers
46
from alembic.operations import Operations
57
from alembic.operations.ops import MigrateOperation
68
from alembic.util import DispatchPriority, PriorityDispatchResult
9+
from sqlalchemy.dialects import postgresql
710
from sqlalchemy import text
11+
from sqlalchemy.sql.elements import ClauseElement
812

913

1014
def _quote_ident(name: str) -> str:
@@ -15,100 +19,150 @@ def _quote_literal(value: str) -> str:
1519
return "'" + value.replace("'", "''") + "'"
1620

1721

22+
def _quote_qualified(schema: str | None, name: str) -> str:
23+
if schema:
24+
return f"{_quote_ident(schema)}.{_quote_ident(name)}"
25+
return _quote_ident(name)
26+
27+
1828
@Operations.register_operation("create_bm25_index")
1929
class CreateBM25IndexOp(MigrateOperation):
20-
def __init__(self, index_name: str, table_name: str, fields: list[str], key_field: str) -> None:
30+
def __init__(
31+
self,
32+
index_name: str,
33+
table_name: str,
34+
expressions: list[str],
35+
key_field: str,
36+
*,
37+
table_schema: str | None = None,
38+
index_schema: str | None = None,
39+
) -> None:
2140
self.index_name = index_name
2241
self.table_name = table_name
23-
self.fields = fields
42+
self.expressions = expressions
2443
self.key_field = key_field
44+
self.table_schema = table_schema
45+
self.index_schema = index_schema
2546

2647
@classmethod
2748
def create_bm25_index(
2849
cls,
2950
operations: Operations,
3051
index_name: str,
3152
table_name: str,
32-
fields: list[str],
53+
expressions: list[str],
3354
*,
3455
key_field: str,
56+
table_schema: str | None = None,
57+
index_schema: str | None = None,
3558
) -> MigrateOperation:
36-
return operations.invoke(cls(index_name, table_name, fields, key_field))
59+
return operations.invoke(
60+
cls(
61+
index_name,
62+
table_name,
63+
expressions,
64+
key_field,
65+
table_schema=table_schema,
66+
index_schema=index_schema,
67+
)
68+
)
3769

3870

3971
@Operations.implementation_for(CreateBM25IndexOp)
4072
def _create_bm25_index_impl(operations: Operations, operation: CreateBM25IndexOp) -> None:
41-
fields_sql = ", ".join(_quote_ident(field) for field in operation.fields)
73+
expressions_sql = ", ".join(operation.expressions)
4274
sql = (
43-
f"CREATE INDEX {_quote_ident(operation.index_name)} ON {_quote_ident(operation.table_name)} "
44-
f"USING bm25 ({fields_sql}) WITH (key_field={_quote_literal(operation.key_field)})"
75+
f"CREATE INDEX {_quote_ident(operation.index_name)} "
76+
f"ON {_quote_qualified(operation.table_schema, operation.table_name)} "
77+
f"USING bm25 ({expressions_sql}) WITH (key_field={_quote_literal(operation.key_field)})"
4578
)
4679
operations.execute(sql)
4780

4881

4982
@renderers.dispatch_for(CreateBM25IndexOp)
5083
def _render_create_bm25_index_op(autogen_context, op: CreateBM25IndexOp) -> str:
51-
return (
52-
f"op.create_bm25_index({op.index_name!r}, {op.table_name!r}, {op.fields!r}, "
53-
f"key_field={op.key_field!r})"
54-
)
84+
parts = [
85+
repr(op.index_name),
86+
repr(op.table_name),
87+
repr(op.expressions),
88+
f"key_field={op.key_field!r}",
89+
]
90+
if op.table_schema is not None:
91+
parts.append(f"table_schema={op.table_schema!r}")
92+
if op.index_schema is not None:
93+
parts.append(f"index_schema={op.index_schema!r}")
94+
return f"op.create_bm25_index({', '.join(parts)})"
5595

5696

5797
@Operations.register_operation("drop_bm25_index")
5898
class DropBM25IndexOp(MigrateOperation):
59-
def __init__(self, index_name: str, if_exists: bool = True) -> None:
99+
def __init__(self, index_name: str, if_exists: bool = True, schema: str | None = None) -> None:
60100
self.index_name = index_name
61101
self.if_exists = if_exists
102+
self.schema = schema
62103

63104
@classmethod
64-
def drop_bm25_index(cls, operations: Operations, index_name: str, if_exists: bool = True) -> MigrateOperation:
65-
return operations.invoke(cls(index_name=index_name, if_exists=if_exists))
105+
def drop_bm25_index(
106+
cls, operations: Operations, index_name: str, if_exists: bool = True, schema: str | None = None
107+
) -> MigrateOperation:
108+
return operations.invoke(cls(index_name=index_name, if_exists=if_exists, schema=schema))
66109

67110

68111
@Operations.implementation_for(DropBM25IndexOp)
69112
def _drop_bm25_index_impl(operations: Operations, operation: DropBM25IndexOp) -> None:
70113
if_exists_sql = " IF EXISTS" if operation.if_exists else ""
71-
operations.execute(f"DROP INDEX{if_exists_sql} {_quote_ident(operation.index_name)}")
114+
operations.execute(f"DROP INDEX{if_exists_sql} {_quote_qualified(operation.schema, operation.index_name)}")
72115

73116

74117
@renderers.dispatch_for(DropBM25IndexOp)
75118
def _render_drop_bm25_index_op(autogen_context, op: DropBM25IndexOp) -> str:
76-
return f"op.drop_bm25_index({op.index_name!r}, if_exists={op.if_exists!r})"
119+
parts = [repr(op.index_name), f"if_exists={op.if_exists!r}"]
120+
if op.schema is not None:
121+
parts.append(f"schema={op.schema!r}")
122+
return f"op.drop_bm25_index({', '.join(parts)})"
77123

78124

79125
@Operations.register_operation("reindex_bm25")
80126
class ReindexBM25Op(MigrateOperation):
81-
def __init__(self, index_name: str, concurrently: bool = False) -> None:
127+
def __init__(self, index_name: str, concurrently: bool = False, schema: str | None = None) -> None:
82128
self.index_name = index_name
83129
self.concurrently = concurrently
130+
self.schema = schema
84131

85132
@classmethod
86-
def reindex_bm25(cls, operations: Operations, index_name: str, concurrently: bool = False) -> MigrateOperation:
87-
return operations.invoke(cls(index_name=index_name, concurrently=concurrently))
133+
def reindex_bm25(
134+
cls, operations: Operations, index_name: str, concurrently: bool = False, schema: str | None = None
135+
) -> MigrateOperation:
136+
return operations.invoke(cls(index_name=index_name, concurrently=concurrently, schema=schema))
88137

89138

90139
@Operations.implementation_for(ReindexBM25Op)
91140
def _reindex_bm25_impl(operations: Operations, operation: ReindexBM25Op) -> None:
92141
concurrently_sql = " CONCURRENTLY" if operation.concurrently else ""
93-
operations.execute(f"REINDEX INDEX{concurrently_sql} {_quote_ident(operation.index_name)}")
142+
operations.execute(f"REINDEX INDEX{concurrently_sql} {_quote_qualified(operation.schema, operation.index_name)}")
94143

95144

96145
@renderers.dispatch_for(ReindexBM25Op)
97146
def _render_reindex_bm25_op(autogen_context, op: ReindexBM25Op) -> str:
98-
return f"op.reindex_bm25({op.index_name!r}, concurrently={op.concurrently!r})"
147+
parts = [repr(op.index_name), f"concurrently={op.concurrently!r}"]
148+
if op.schema is not None:
149+
parts.append(f"schema={op.schema!r}")
150+
return f"op.reindex_bm25({', '.join(parts)})"
99151

100152

101153
# ---------------------------------------------------------------------------
102154
# Autogenerate comparator
103155
# ---------------------------------------------------------------------------
104156

105-
def _autogen_bm25_meta_indexes(metadata, effective_schemas: set[str]) -> dict[tuple[str, str], object]:
157+
def _autogen_bm25_meta_indexes(
158+
metadata, effective_schemas: set[str], *, default_schema: str
159+
) -> dict[tuple[str, str], object]:
106160
"""Return {(schema, index_name): Index} for all BM25 indexes in MetaData."""
107161
from .indexing import _is_bm25_index
108162

109163
result: dict[tuple[str, str], object] = {}
110164
for table in metadata.tables.values():
111-
schema = table.schema or next(iter(effective_schemas), "public")
165+
schema = table.schema or default_schema
112166
if schema not in effective_schemas:
113167
continue
114168
for index in table.indexes:
@@ -118,8 +172,8 @@ def _autogen_bm25_meta_indexes(metadata, effective_schemas: set[str]) -> dict[tu
118172

119173

120174
def _autogen_bm25_db_indexes(conn, effective_schemas: set[str]) -> dict[tuple[str, str], dict]:
121-
"""Return {(schema, index_name): {table_name, fields, key_field}} from pg_indexes."""
122-
from .indexing import _extract_bm25_field_list, _extract_field_name, _extract_key_field
175+
"""Return {(schema, index_name): {table_name, expressions, key_field}} from pg_indexes."""
176+
from .indexing import _extract_bm25_field_list, _extract_key_field
123177

124178
result: dict[tuple[str, str], dict] = {}
125179
for schema in effective_schemas:
@@ -137,15 +191,43 @@ def _autogen_bm25_db_indexes(conn, effective_schemas: set[str]) -> dict[tuple[st
137191
).fetchall()
138192
for row in rows:
139193
raw_fields = _extract_bm25_field_list(row.indexdef)
140-
fields = [f for f in (_extract_field_name(rf) for rf in raw_fields) if f is not None]
141194
result[(row.schemaname, row.indexname)] = {
142195
"table_name": row.tablename,
143-
"fields": fields,
196+
"expressions": raw_fields,
144197
"key_field": _extract_key_field(row.indexdef) or "",
145198
}
146199
return result
147200

148201

202+
def _render_bm25_expression(expr: ClauseElement) -> str:
203+
return str(expr.compile(dialect=postgresql.dialect(), compile_kwargs={"literal_binds": True})) # type: ignore[no-untyped-call]
204+
205+
206+
def _strip_relation_qualifiers(expr: str, table_name: str) -> str:
207+
# SQLAlchemy may render column refs as `table.col` in metadata compilation;
208+
# CREATE INDEX field lists should be table-local expressions.
209+
stripped = expr.replace(f'"{table_name}".', "")
210+
stripped = stripped.replace(f"{table_name}.", "")
211+
return stripped
212+
213+
214+
def _normalize_bm25_expression(expr: str) -> str:
215+
"""Normalize BM25 expression text to reduce false-positive autogen churn."""
216+
normalized = "".join(expr.split())
217+
normalized = normalized.replace('"', "")
218+
normalized = normalized.replace("::text", "")
219+
# Ignore schema/table qualification differences, but keep tokenizer namespaces like `pdb.simple`.
220+
previous = None
221+
while previous != normalized:
222+
previous = normalized
223+
normalized = re.sub(r"(?<![A-Za-z0-9_])(?!pdb\b)[A-Za-z_][A-Za-z0-9_]*\.", "", normalized)
224+
return normalized
225+
226+
227+
def _normalized_expression_list(expressions: list[str]) -> list[str]:
228+
return [_normalize_bm25_expression(expr) for expr in expressions]
229+
230+
149231
def _suppress_standard_bm25_ops(upgrade_ops, bm25_names: set[str]) -> None:
150232
"""Remove any standard Alembic CreateIndexOp/DropIndexOp for BM25 indexes."""
151233
from alembic.operations.ops import CreateIndexOp, DropIndexOp, ModifyTableOps
@@ -182,7 +264,7 @@ def _compare_bm25_indexes(autogen_context, upgrade_ops, schemas) -> PriorityDisp
182264
effective_schemas = {s if s is not None else default_schema for s in schemas}
183265

184266
db_bm25 = _autogen_bm25_db_indexes(conn, effective_schemas)
185-
meta_bm25 = _autogen_bm25_meta_indexes(metadata, effective_schemas)
267+
meta_bm25 = _autogen_bm25_meta_indexes(metadata, effective_schemas, default_schema=default_schema)
186268

187269
all_bm25_names = {k[1] for k in db_bm25} | {k[1] for k in meta_bm25}
188270
if not all_bm25_names:
@@ -196,37 +278,42 @@ def _compare_bm25_indexes(autogen_context, upgrade_ops, schemas) -> PriorityDisp
196278
# Emit drop ops for indexes present in DB but absent from MetaData.
197279
for key in db_bm25:
198280
if key not in meta_bm25:
199-
upgrade_ops.ops.append(DropBM25IndexOp(index_name=key[1], if_exists=True))
281+
upgrade_ops.ops.append(DropBM25IndexOp(index_name=key[1], if_exists=True, schema=key[0]))
200282

201283
# Emit create ops for indexes present in MetaData but absent from DB.
202-
# Also re-create indexes whose field list or key_field differs from the DB.
284+
# Also re-create indexes whose expression list or key_field differs from the DB.
203285
for key, index in meta_bm25.items():
204-
from .indexing import _bm25_field_name
205-
206286
with_opts = index.dialect_options["postgresql"].get("with") or {}
207287
key_field = with_opts.get("key_field", "")
208-
fields = [f for f in (_bm25_field_name(expr) for expr in index.expressions) if f is not None]
288+
expressions = [
289+
_strip_relation_qualifiers(_render_bm25_expression(expr), index.table.name)
290+
for expr in index.expressions
291+
]
209292

210293
if key not in db_bm25:
211294
upgrade_ops.ops.append(
212295
CreateBM25IndexOp(
213296
index_name=index.name,
214297
table_name=index.table.name,
215-
fields=fields,
298+
expressions=expressions,
216299
key_field=key_field,
300+
table_schema=key[0],
301+
index_schema=key[0],
217302
)
218303
)
219304
else:
220305
db = db_bm25[key]
221-
if db["fields"] != fields or db["key_field"] != key_field:
306+
if _normalized_expression_list(db["expressions"]) != _normalized_expression_list(expressions) or db["key_field"] != key_field:
222307
# Index configuration changed: drop the old one, create the new one.
223-
upgrade_ops.ops.append(DropBM25IndexOp(index_name=key[1], if_exists=True))
308+
upgrade_ops.ops.append(DropBM25IndexOp(index_name=key[1], if_exists=True, schema=key[0]))
224309
upgrade_ops.ops.append(
225310
CreateBM25IndexOp(
226311
index_name=index.name,
227312
table_name=index.table.name,
228-
fields=fields,
313+
expressions=expressions,
229314
key_field=key_field,
315+
table_schema=key[0],
316+
index_schema=key[0],
230317
)
231318
)
232319

paradedb/sqlalchemy/facets.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from dataclasses import dataclass
4+
from collections.abc import Sequence
45
from typing import Any
56

67
from sqlalchemy import Select
@@ -108,10 +109,9 @@ def extract(self, rows: list[object]) -> Any | None:
108109
mapping = getattr(first, "_mapping", None)
109110
if mapping is not None and self.label in mapping:
110111
return mapping[self.label]
111-
try:
112+
if isinstance(first, Sequence) and not isinstance(first, (str, bytes)):
112113
return first[-1]
113-
except Exception:
114-
return None
114+
return None
115115

116116

117117
def with_rows(

0 commit comments

Comments
 (0)