
Commit 0e7a69e ("wip")
Parent: 05365c5
6 files changed: 398 additions & 3 deletions


paradedb/sqlalchemy/errors.py

Lines changed: 4 additions & 0 deletions
@@ -55,3 +55,7 @@ class FacetRequiresLimitError(FacetRuntimeError):
 
 class FacetRequiresParadeDBPredicateError(FacetRuntimeError):
     """Raised when rows+facets helper is used without ParadeDB predicate/sentinel."""
+
+
+class FieldNotIndexedError(ParadeDBError):
+    """Raised when a column is not covered by any BM25 index on its table."""

paradedb/sqlalchemy/indexing.py

Lines changed: 94 additions & 2 deletions
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import re
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from typing import Any
 
 from sqlalchemy import Index, event, text
@@ -13,6 +13,7 @@
 
 from .errors import (
     DuplicateTokenizerAliasError,
+    FieldNotIndexedError,
     InvalidArgumentError,
     InvalidBM25FieldError,
     InvalidKeyFieldError,
@@ -35,7 +36,7 @@ def render(self) -> str:
             raise InvalidArgumentError("tokenizer name is required unless raw_sql is provided")
 
         if not self.options:
-            return f"pdb.{self.name}()"
+            return f"pdb.{self.name}"
 
         rendered_options = ",".join(f"{key}={_format_option_value(value)}" for key, value in self.options)
         escaped = rendered_options.replace("'", "''")
@@ -179,12 +180,15 @@ class IndexMeta:
     key_field: str | None
     fields: tuple[str, ...]
     aliases: dict[str, str]
+    tokenizers: dict[str, tuple[str, ...]] = field(default_factory=dict)
+    """Maps field name to the tokenizer names used in this index, e.g. ``{"description": ("unicode_words",)}``."""
 
 
 _KEY_FIELD_RE = re.compile(r"key_field\s*=\s*'?\"?([^'\",)\s]+)\"?'?", re.IGNORECASE)
 _ALIAS_RE = re.compile(r"alias\s*=\s*([A-Za-z_][A-Za-z0-9_]*)", re.IGNORECASE)
 _CAST_FIELD_RE = re.compile(r"^\(*\"?([A-Za-z_][A-Za-z0-9_]*)\"?\)*\s*::\s*pdb\.", re.IGNORECASE)
 _PLAIN_FIELD_RE = re.compile(r'^\(*"?([A-Za-z_][A-Za-z0-9_]*)"?\)*$')
+_TOKENIZER_NAME_RE = re.compile(r"::pdb\.([A-Za-z_][A-Za-z0-9_]*)", re.IGNORECASE)
 
 
 def _split_top_level_csv(expr: str) -> list[str]:
@@ -274,6 +278,13 @@ def _extract_alias(index_expr: str) -> str | None:
     return None
 
 
+def _extract_tokenizer_name(field_expr: str) -> str | None:
+    """Return the bare tokenizer name from a field expression, e.g. ``unicode_words`` from
+    ``(description::pdb.unicode_words('lowercase=true'))``. Returns ``None`` for plain fields."""
+    match = _TOKENIZER_NAME_RE.search(field_expr)
+    return match.group(1) if match else None
+
+
 def describe(engine: Engine, table) -> list[IndexMeta]:
     query = text(
         """
@@ -293,6 +304,7 @@ def describe(engine: Engine, table) -> list[IndexMeta]:
         key_field = _extract_key_field(indexdef)
         raw_fields = _extract_bm25_field_list(indexdef)
         aliases: dict[str, str] = {}
+        tokenizer_map: dict[str, list[str]] = {}
         fields_ordered: list[str] = []
         for raw in raw_fields:
             field_name = _extract_field_name(raw)
@@ -303,13 +315,93 @@ def describe(engine: Engine, table) -> list[IndexMeta]:
             alias = _extract_alias(raw)
             if alias is not None:
                 aliases[alias] = field_name
+            tok = _extract_tokenizer_name(raw)
+            if tok is not None:
+                tokenizer_map.setdefault(field_name, []).append(tok)
 
         output.append(
             IndexMeta(
                 index_name=row.indexname,
                 key_field=key_field,
                 fields=tuple(fields_ordered),
                 aliases=aliases,
+                tokenizers={k: tuple(v) for k, v in tokenizer_map.items()},
             )
         )
     return output
+
+
+def assert_indexed(
+    engine: Engine,
+    column: Any,
+    *,
+    tokenizer: str | None = None,
+) -> None:
+    """Raise :exc:`FieldNotIndexedError` if *column* is not covered by any BM25 index.
+
+    Args:
+        engine: SQLAlchemy engine connected to the ParadeDB database.
+        column: A table-bound column expression (e.g. ``Product.description``).
+        tokenizer: Optional tokenizer name to verify, e.g. ``"literal"`` or
+            ``"unicode_words"``. When given, raises if the column is not
+            indexed with that specific tokenizer.
+
+    Example::
+
+        assert_indexed(engine, Product.category, tokenizer="literal")
+    """
+    table = getattr(column, "table", None)
+    if table is None:
+        raise InvalidArgumentError("column must be a table-bound column expression")
+    col_name: str | None = getattr(column, "name", None)
+    if col_name is None:
+        raise InvalidArgumentError("column must have a name attribute")
+
+    for idx_meta in describe(engine, table):
+        if col_name not in idx_meta.fields:
+            continue
+        if tokenizer is None:
+            return  # field is indexed; no tokenizer constraint
+        if tokenizer in idx_meta.tokenizers.get(col_name, ()):
+            return  # field is indexed with the requested tokenizer
+
+    msg = f"'{col_name}' is not indexed in any BM25 index on '{table.name}'"
+    if tokenizer:
+        msg += f" with tokenizer '{tokenizer}'"
+    raise FieldNotIndexedError(msg)
+
+
+def validate_pushdown(stmt: Any) -> list[str]:
+    """Inspect *stmt* for patterns that will not push down to ParadeDB.
+
+    Performs **static AST analysis only**; no database connection is required.
+    Returns a (possibly empty) list of human-readable warning strings.
+
+    Example::
+
+        issues = validate_pushdown(stmt)
+        for w in issues:
+            print("Warning:", w)
+    """
+    from . import inspect as _inspect
+
+    warnings: list[str] = []
+
+    whereclause = getattr(stmt, "whereclause", None)
+    if whereclause is None:
+        warnings.append(
+            "No WHERE clause found; query will perform a full table scan without ParadeDB"
+        )
+    elif not _inspect.has_paradedb_predicate(whereclause):
+        warnings.append(
+            "No ParadeDB predicate found in WHERE clause; query will not use a BM25 index"
+        )
+
+    order_by = getattr(stmt, "_order_by_clauses", None) or ()
+    limit = getattr(stmt, "_limit_clause", None)
+    if order_by and limit is None:
+        warnings.append(
+            "ORDER BY is present without LIMIT; top-N pushdown to ParadeDB requires both"
+        )
+
+    return warnings
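
A minimal sketch of wiring the two new helpers into an application-level startup check; the DSN, the Product ORM model, and the sample statement are illustrative assumptions, not part of this commit:

    from sqlalchemy import create_engine, select

    from paradedb.sqlalchemy.indexing import assert_indexed, validate_pushdown

    engine = create_engine("postgresql+psycopg://localhost/paradedb_demo")  # assumed DSN

    def run_startup_checks(Product) -> list[str]:
        # Fail fast if the searched columns are not covered by a BM25 index.
        assert_indexed(engine, Product.description)                    # any tokenizer is acceptable
        assert_indexed(engine, Product.category, tokenizer="literal")  # must use the exact-match tokenizer

        # Static inspection only; validate_pushdown never touches the database.
        stmt = select(Product).order_by(Product.id)  # ORDER BY without WHERE or LIMIT
        return validate_pushdown(stmt)  # yields two warnings for this statement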

paradedb/sqlalchemy/search.py

Lines changed: 27 additions & 0 deletions
@@ -17,6 +17,8 @@
     require_positive,
 )
 
+_VALID_RANGE_RELATIONS: frozenset[str] = frozenset({"Intersects", "Contains", "Within", "ContainsOrIntersects"})
+
 _MATCH_ALL = operators.custom_op("&&&", precedence=5, is_comparison=True)
 _MATCH_ANY = operators.custom_op("|||", precedence=5, is_comparison=True)
 _TERM = operators.custom_op("===", precedence=5, is_comparison=True)
@@ -168,6 +170,31 @@ def proximity(field: ColumnElement, prox: ProximityExpr | ClauseElement) -> ColumnElement[bool]:
     return field.operate(_QUERY, prox_expr)
 
 
+def range_term(
+    field: ColumnElement,
+    bounds: str,
+    *,
+    relation: str = "Intersects",
+) -> ColumnElement[bool]:
+    """Match rows where a range-typed field satisfies a range predicate.
+
+    Args:
+        field: A range-typed column (int4range, daterange, tstzrange, etc.).
+        bounds: A range literal string, e.g. ``"[3,9]"``, ``"(3,9]"``.
+        relation: One of ``"Intersects"``, ``"Contains"``, ``"Within"``,
+            ``"ContainsOrIntersects"``. Defaults to ``"Intersects"``.
+
+    Generates::
+
+        field @@@ pdb.range_term('[3,9]', 'Contains')
+    """
+    if relation not in _VALID_RANGE_RELATIONS:
+        raise InvalidArgumentError(
+            f"relation must be one of: {', '.join(sorted(_VALID_RANGE_RELATIONS))}"
+        )
+    return field.operate(_QUERY, func.pdb.range_term(bounds, relation))
+
+
 def more_like_this(
     field: ColumnElement,
     *,
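
A minimal usage sketch for range_term; the products table, its int4range column, and the statement shown are illustrative assumptions, not part of this commit:

    from sqlalchemy import Column, Integer, MetaData, Table, select
    from sqlalchemy.dialects.postgresql import INT4RANGE

    from paradedb.sqlalchemy.search import range_term

    metadata = MetaData()
    products = Table(  # illustrative table; any range-typed column works
        "products",
        metadata,
        Column("id", Integer, primary_key=True),
        Column("availability", INT4RANGE),
    )

    # Compare the availability range against the literal [3,9] using the 'Contains' relation;
    # an unknown relation name raises InvalidArgumentError before any SQL is emitted.
    stmt = (
        select(products)
        .where(range_term(products.c.availability, "[3,9]", relation="Contains"))
        .limit(10)
    )
    # Renders roughly as: ... WHERE availability @@@ pdb.range_term('[3,9]', 'Contains') ... LIMIT 10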

tests/integration/test_indexing_integration.py

Lines changed: 78 additions & 1 deletion
@@ -3,7 +3,8 @@
 import pytest
 from sqlalchemy import Column, Index, Integer, MetaData, String, Table, Text, text
 
-from paradedb.sqlalchemy.indexing import BM25Field, describe, tokenize
+from paradedb.sqlalchemy.indexing import BM25Field, assert_indexed, describe, tokenize
+from paradedb.sqlalchemy.errors import FieldNotIndexedError
 
 
 pytestmark = pytest.mark.integration
@@ -260,3 +261,79 @@ def test_describe_returns_fields_and_aliases(engine):
     assert meta.aliases == {"category_exact": "category"}
 
     _drop_table_and_index(engine, table_name, index_name)
+
+
+def test_describe_includes_tokenizers(engine):
+    """describe() populates IndexMeta.tokenizers from the index definition."""
+    if not _tokenizer_cast_supported(engine):
+        pytest.skip("ParadeDB instance does not support tokenizer cast index syntax yet")
+
+    table_name = "describe_tokenizers_products"
+    index_name = "describe_tokenizers_bm25_idx"
+    _drop_table_and_index(engine, table_name, index_name)
+
+    metadata = MetaData()
+    products = Table(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("description", Text, nullable=False),
+        Column("category", String(120), nullable=False),
+    )
+    metadata.create_all(engine)
+
+    idx = Index(
+        index_name,
+        BM25Field(products.c.id),
+        BM25Field(products.c.description, tokenizer=tokenize.unicode(lowercase=True)),
+        BM25Field(products.c.category, tokenizer=tokenize.literal()),
+        postgresql_using="bm25",
+        postgresql_with={"key_field": "id"},
+    )
+    idx.create(engine)
+
+    metas = describe(engine, products)
+    meta = next(m for m in metas if m.index_name == index_name)
+
+    assert "unicode_words" in meta.tokenizers.get("description", ())
+    assert "literal" in meta.tokenizers.get("category", ())
+    assert "id" not in meta.tokenizers  # no tokenizer for plain key field
+
+    _drop_table_and_index(engine, table_name, index_name)
+
+
+def test_assert_indexed_passes_and_raises(engine):
+    """assert_indexed passes for an indexed column and raises for an unindexed one."""
+    table_name = "assert_indexed_products"
+    index_name = "assert_indexed_bm25_idx"
+    _drop_table_and_index(engine, table_name, index_name)
+
+    metadata = MetaData()
+    tbl = Table(
+        table_name,
+        metadata,
+        Column("id", Integer, primary_key=True),
+        Column("description", Text, nullable=False),
+        Column("extra", Text, nullable=False),
+    )
+    metadata.create_all(engine)
+
+    idx = Index(
+        index_name,
+        BM25Field(tbl.c.id),
+        BM25Field(tbl.c.description),
+        postgresql_using="bm25",
+        postgresql_with={"key_field": "id"},
+    )
+    idx.create(engine)
+
+    # 'description' is indexed → no error
+    assert_indexed(engine, tbl.c.description)
+
+    # 'extra' is not indexed → FieldNotIndexedError
+    with pytest.raises(FieldNotIndexedError, match="'extra'"):
+        assert_indexed(engine, tbl.c.extra)
+
+    _drop_table_and_index(engine, table_name, index_name)
+
