Skip to content

Commit 5104bf2

Browse files
committed
More
1 parent 28788b3 commit 5104bf2

5 files changed

Lines changed: 976 additions & 23 deletions

File tree

paradedb/sqlalchemy/search.py

Lines changed: 60 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import json
44
from typing import Any
55

6-
from sqlalchemy import Text, func, literal
6+
from sqlalchemy import Text, func, literal, literal_column, or_
77
from sqlalchemy.dialects.postgresql import array
88
from sqlalchemy.sql import operators
99
from sqlalchemy.sql.elements import ClauseElement, ColumnElement
@@ -18,13 +18,15 @@
1818
)
1919

2020
# Relation names accepted by ``range_term``; any other value is rejected
# before SQL is generated.
_VALID_RANGE_RELATIONS: frozenset[str] = frozenset(
    {
        "Intersects",
        "Contains",
        "Within",
        "ContainsOrIntersects",
    }
)
# Range type names that ``range_term`` allows in an explicit ``::type`` cast.
# The cast is rendered into the SQL text, so only these fixed names are allowed.
_VALID_RANGE_TYPES: frozenset[str] = frozenset(
    {
        "int4range",
        "int8range",
        "numrange",
        "daterange",
        "tsrange",
        "tstzrange",
    }
)

# Custom SQL operators. All share one precedence; the first five are flagged
# as comparisons.
_MATCH_ALL = operators.custom_op("&&&", precedence=5, is_comparison=True)
_MATCH_ANY = operators.custom_op("|||", precedence=5, is_comparison=True)
_TERM = operators.custom_op("===", precedence=5, is_comparison=True)
_PHRASE = operators.custom_op("###", precedence=5, is_comparison=True)
_QUERY = operators.custom_op("@@@", precedence=5, is_comparison=True)
# Proximity operators are not flagged as comparisons: they chain proximity
# clauses together, and the chain is then used as the right-hand side of
# ``_QUERY``.
_NEAR = operators.custom_op("##", precedence=5)
_NEAR_ORDERED = operators.custom_op("##>", precedence=5)
2830

2931

3032
def _to_term_payload(*terms: str) -> ClauseElement:
@@ -100,8 +102,8 @@ class ProximityExpr:
100102
def __init__(self, expr: ClauseElement) -> None:
    """Store the clause this proximity expression wraps.

    Args:
        expr: The clause kept on ``self.expr`` and used as the left-hand
            side when further operands are chained via ``near``.
    """
    self.expr = expr
102104

103-
def near(self, other: str | ClauseElement | ProximityExpr, *, distance: int, ordered: bool = False) -> ProximityExpr:
    """Chain another operand onto this proximity expression.

    Args:
        other: The operand to place after this expression.
        distance: Distance value rendered between the two operands.
        ordered: When true, the ordered proximity operator (``##>``) is
            used instead of the unordered one (``##``).

    Returns:
        A new ``ProximityExpr`` wrapping the chained clause; ``self`` is
        left untouched.
    """
    chained = _near_chain(self.expr, other, distance=distance, ordered=ordered)
    return ProximityExpr(chained)
105107

106108

107109
def _to_proximity_operand(value: str | ClauseElement | ProximityExpr) -> ClauseElement:
@@ -119,10 +121,11 @@ def _to_proximity_clause(value: str | ClauseElement | ProximityExpr) -> ClauseEl
119121
return PDBCast(operand, "proximityclause")
120122

121123

122-
def _near_chain(
    left: str | ClauseElement | ProximityExpr,
    right: str | ClauseElement | ProximityExpr,
    *,
    distance: int,
    ordered: bool = False,
) -> ClauseElement:
    """Build the clause ``left <op> distance <op> right``.

    Args:
        left: Left operand; converted with ``_to_proximity_clause``.
        right: Right operand; converted with ``_to_proximity_clause``.
        distance: Distance value, emitted as a bound literal between the
            two operands.
        ordered: Selects ``_NEAR_ORDERED`` (``##>``) when true, otherwise
            ``_NEAR`` (``##``). The same operator is used on both sides of
            the distance.

    Returns:
        The chained proximity clause.
    """
    lhs = _to_proximity_clause(left)
    rhs = _to_proximity_clause(right)
    if ordered:
        proximity_op = _NEAR_ORDERED
    else:
        proximity_op = _NEAR
    with_distance = lhs.operate(proximity_op, literal(distance))
    return with_distance.operate(proximity_op, rhs)
126129

127130

128131
def parse(field: ColumnElement, query: str, *, lenient: bool = False, conjunction_mode: bool = False) -> ColumnElement[bool]:
@@ -150,8 +153,8 @@ def regex_phrase(
150153
return field.operate(_QUERY, func.pdb.regex_phrase(array(terms, type_=Text()), slop, max_expansions))
151154

152155

153-
def near(
    field: ColumnElement,
    left: str | ClauseElement,
    right: str | ClauseElement,
    *,
    distance: int,
    ordered: bool = False,
) -> ColumnElement[bool]:
    """Build a ``field @@@ <proximity clause>`` predicate.

    Args:
        field: Column the predicate is applied to.
        left: First proximity operand.
        right: Second proximity operand.
        distance: Distance value placed between the operands.
        ordered: When true, the ordered proximity operator is used.

    Returns:
        The boolean search predicate.
    """
    proximity = _near_chain(left, right, distance=distance, ordered=ordered)
    return field.operate(_QUERY, proximity)
155158

156159

157160
def prox_regex(pattern: str, max_expansions: int = 100) -> ProximityExpr:
@@ -175,6 +178,7 @@ def range_term(
175178
bounds: str,
176179
*,
177180
relation: str = "Intersects",
181+
range_type: str | None = None,
178182
) -> ColumnElement[bool]:
179183
"""Match rows where a range-typed field satisfies a range predicate.
180184
@@ -183,22 +187,37 @@ def range_term(
183187
bounds: A range literal string, e.g. ``"[3,9]"``, ``"(3,9]"``.
184188
relation: One of ``"Intersects"``, ``"Contains"``, ``"Within"``,
185189
``"ContainsOrIntersects"``. Defaults to ``"Intersects"``.
190+
range_type: Optional PostgreSQL range type for explicit casting, e.g.
191+
``"int4range"``, ``"int8range"``, ``"numrange"``,
192+
``"daterange"``, ``"tsrange"``, ``"tstzrange"``.
193+
When provided, generates ``'bounds'::range_type`` cast.
186194
187195
Generates::
188196
189197
field @@@ pdb.range_term('[3,9]', 'Contains')
198+
field @@@ pdb.range_term('[3,9]'::int4range, 'Contains')
190199
"""
191200
if relation not in _VALID_RANGE_RELATIONS:
192201
raise InvalidArgumentError(
193202
f"relation must be one of: {', '.join(sorted(_VALID_RANGE_RELATIONS))}"
194203
)
195-
return field.operate(_QUERY, func.pdb.range_term(bounds, relation))
204+
if range_type is not None:
205+
if range_type not in _VALID_RANGE_TYPES:
206+
raise InvalidArgumentError(
207+
f"range_type must be one of: {', '.join(sorted(_VALID_RANGE_TYPES))}"
208+
)
209+
escaped = bounds.replace("'", "''")
210+
bounds_arg: ClauseElement = literal_column(f"'{escaped}'::{range_type}")
211+
else:
212+
bounds_arg = literal(bounds)
213+
return field.operate(_QUERY, func.pdb.range_term(bounds_arg, relation))
196214

197215

198216
def more_like_this(
199217
field: ColumnElement,
200218
*,
201219
document_id: Any | None = None,
220+
document_ids: list[Any] | None = None,
202221
document: dict[str, Any] | str | None = None,
203222
fields: list[str] | None = None,
204223
min_term_frequency: int | None = None,
@@ -212,10 +231,13 @@ def more_like_this(
212231
) -> ColumnElement[bool]:
213232
error_cls = InvalidMoreLikeThisOptionsError
214233

215-
if (document_id is None) == (document is None):
216-
raise error_cls("exactly one of document_id or document must be provided")
234+
sources_provided = sum(x is not None for x in (document_id, document_ids, document))
235+
if sources_provided != 1:
236+
raise error_cls("exactly one of document_id, document_ids, or document must be provided")
237+
if document_ids is not None and len(document_ids) == 0:
238+
raise error_cls("document_ids must not be empty")
217239
if document is not None and fields is not None:
218-
raise error_cls("fields can only be used with document_id")
240+
raise error_cls("fields can only be used with document_id or document_ids")
219241

220242
if min_term_frequency is not None:
221243
require_non_negative(min_term_frequency, field_name="min_term_frequency", error_cls=error_cls)
@@ -264,17 +286,35 @@ def more_like_this(
264286
)
265287
)
266288

267-
args: list[Any] = []
268-
if document_id is not None:
269-
args.append(document_id)
289+
def _build_id_args(doc_id: Any) -> list[Any]:
290+
id_args: list[Any] = [doc_id]
270291
if fields is not None or options_provided:
271-
args.append(array(fields or [], type_=Text()))
272-
else:
273-
payload = document if isinstance(document, str) else json.dumps(document, separators=(",", ":"), sort_keys=True)
274-
args.append(payload)
292+
id_args.append(array(fields or [], type_=Text()))
293+
if options_provided:
294+
id_args.extend(
295+
[
296+
1 if min_term_frequency is None else min_term_frequency,
297+
25 if max_query_terms is None else max_query_terms,
298+
0 if min_doc_frequency is None else min_doc_frequency,
299+
1_000_000 if max_doc_frequency is None else max_doc_frequency,
300+
0 if min_word_length is None else min_word_length,
301+
1000 if max_word_length is None else max_word_length,
302+
0.0 if boost_factor is None else boost_factor,
303+
array(stopwords or [], type_=Text()),
304+
]
305+
)
306+
return id_args
307+
308+
if document_ids is not None:
309+
return or_(*[field.operate(_QUERY, func.pdb.more_like_this(*_build_id_args(doc_id))) for doc_id in document_ids])
275310

311+
if document_id is not None:
312+
return field.operate(_QUERY, func.pdb.more_like_this(*_build_id_args(document_id)))
313+
314+
payload = document if isinstance(document, str) else json.dumps(document, separators=(",", ":"), sort_keys=True)
315+
doc_args: list[Any] = [payload]
276316
if options_provided:
277-
args.extend(
317+
doc_args.extend(
278318
[
279319
1 if min_term_frequency is None else min_term_frequency,
280320
25 if max_query_terms is None else max_query_terms,
@@ -286,5 +326,4 @@ def more_like_this(
286326
array(stopwords or [], type_=Text()),
287327
]
288328
)
289-
290-
return field.operate(_QUERY, func.pdb.more_like_this(*args))
329+
return field.operate(_QUERY, func.pdb.more_like_this(*doc_args))

tests/integration/conftest.py

Lines changed: 44 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@
55
from typing import Any
66

77
import pytest
8-
from sqlalchemy import Integer, String, Text, text
8+
from sqlalchemy import Boolean, DateTime, Integer, String, Text, create_engine, text
99
from sqlalchemy.dialects import postgresql
10+
from sqlalchemy.dialects.postgresql import JSONB
1011
from sqlalchemy.engine import Engine
1112
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column
12-
from sqlalchemy import create_engine
1313

1414

1515
class Base(DeclarativeBase):
@@ -25,6 +25,20 @@ class Product(Base):
2525
rating: Mapped[int] = mapped_column(Integer, nullable=False)
2626

2727

28+
class MockItem(Base):
    """ORM mapping for the ``mock_items`` table.

    The table is created by ``paradedb.create_bm25_test_table``; this class
    declares the columns the tests read from it.
    """

    __tablename__ = "mock_items"

    # Primary key.
    id: Mapped[int] = mapped_column(Integer, primary_key=True)

    # Required text / scalar columns.
    description: Mapped[str] = mapped_column(Text, nullable=False)
    category: Mapped[str] = mapped_column(String(120), nullable=False)
    rating: Mapped[int] = mapped_column(Integer, nullable=False)
    in_stock: Mapped[bool] = mapped_column(Boolean, nullable=False)
    created_at: Mapped[Any] = mapped_column(DateTime, nullable=False)

    # The database column is named "metadata"; the Python attribute carries a
    # trailing underscore because ``metadata`` is reserved on declarative
    # classes.
    metadata_: Mapped[Any] = mapped_column("metadata", JSONB, nullable=True)
40+
41+
2842
@pytest.fixture(scope="session")
2943
def db_url() -> str:
3044
url = os.environ.get("PARADEDB_TEST_DSN") or os.environ.get("DATABASE_URL") or "postgresql://postgres:postgres@localhost:5443/postgres"
@@ -112,3 +126,31 @@ def assert_uses_paradedb_scan(session: Session, stmt, *, index_name: str = "prod
112126
]
113127
assert parade_nodes, f"Expected ParadeDB Custom Scan in plan, got: {plan}"
114128
assert any(node.get("Index") == index_name for node in parade_nodes), f"Expected index {index_name} in plan: {plan}"
129+
130+
131+
@pytest.fixture(scope="session")
def paradedb_ready(engine: Engine) -> None:
    """Recreate and index the ``mock_items`` table once per test session.

    All statements run in a single ``engine.begin()`` transaction, in order:
    enable ``pg_search``, drop any previous index and table, recreate the
    table through ``paradedb.create_bm25_test_table``, then build the BM25
    index keyed on ``id``.
    """
    setup_statements = (
        "CREATE EXTENSION IF NOT EXISTS pg_search",
        "DROP INDEX IF EXISTS mock_items_bm25_idx",
        "DROP TABLE IF EXISTS mock_items",
        "CALL paradedb.create_bm25_test_table(schema_name => 'public', table_name => 'mock_items')",
        (
            "CREATE INDEX mock_items_bm25_idx ON mock_items USING bm25 ("
            "id, description, category, rating, in_stock"
            ") WITH (key_field='id')"
        ),
    )
    with engine.begin() as conn:
        for statement in setup_statements:
            conn.execute(text(statement))
150+
151+
152+
@pytest.fixture()
def mock_session(engine: Engine, paradedb_ready: None) -> Iterator[Session]:
    """Yield a fresh ``Session`` for tests that query ``mock_items``.

    Depending on ``paradedb_ready`` makes sure the table and its index exist
    before the session is handed out. The session is closed when the test
    finishes, whether it passed or failed.
    """
    session = Session(engine)
    try:
        yield session
    finally:
        session.close()

tests/integration/test_advanced_search_integration.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,3 +88,50 @@ def test_more_like_this_rejects_invalid_numeric_options():
8888
def test_more_like_this_rejects_invalid_stopwords():
8989
with pytest.raises(InvalidArgumentError, match="stopwords entries must be non-empty strings"):
9090
search.more_like_this(Product.id, document_id=1, stopwords=["ok", ""])
91+
92+
93+
def test_more_like_this_by_document_ids(session):
    """document_ids ORs results from multiple individual MLT queries.

    The ``document_ids=[1, 3]`` query must return exactly the rows returned
    by the ``document_id=1`` query plus those returned by the
    ``document_id=3`` query.
    """

    def mlt_stmt(**source):
        # One statement shape for all three queries, so only the document
        # source differs between them.
        return (
            select(Product.id)
            .where(search.more_like_this(Product.id, fields=["description"], **source))
            .order_by(Product.id)
        )

    stmt_combined = mlt_stmt(document_ids=[1, 3])
    assert_uses_paradedb_scan(session, stmt_combined)

    ids_combined = set(session.scalars(stmt_combined))
    ids_1 = set(session.scalars(mlt_stmt(document_id=1)))
    ids_3 = set(session.scalars(mlt_stmt(document_id=3)))

    # Equality, not just two subset checks: a combined query that returned
    # extra unrelated rows would satisfy the subset checks.
    expected = ids_1 | ids_3
    assert ids_combined == expected, (
        f"combined MLT result {sorted(ids_combined)} is not the union "
        f"of the individual results {sorted(expected)}"
    )
117+
118+
119+
def test_near_ordered_predicate(session):
    """near() with ordered=True uses ##> and finds terms in sequence."""
    ordered_match = search.near(Product.description, "sleek", "shoes", distance=5, ordered=True)
    unordered_match = search.near(Product.description, "sleek", "shoes", distance=5)

    stmt_ordered = select(Product.id).where(ordered_match).order_by(Product.id)
    stmt_unordered = select(Product.id).where(unordered_match).order_by(Product.id)

    assert_uses_paradedb_scan(session, stmt_ordered)

    ids_ordered = set(session.scalars(stmt_ordered))
    ids_unordered = set(session.scalars(stmt_unordered))

    # Every ordered match must also be an unordered match.
    assert ids_ordered <= ids_unordered
    # "Sleek running shoes": "sleek" comes before "shoes", so row 1 must
    # match the ordered query.
    assert 1 in ids_ordered

0 commit comments

Comments
 (0)