Skip to content

Commit 84ffd56

Browse files
committed
Fix PR #3 integration regressions in query helpers
1 parent e667204 commit 84ffd56

6 files changed

Lines changed: 71 additions & 55 deletions

File tree

paradedb/sqlalchemy/pdb.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,12 @@ def snippet(
2525
if (start_tag is None) != (end_tag is None):
2626
raise InvalidArgumentError("start_tag and end_tag must be provided together")
2727

28+
if max_num_chars is not None and start_tag is None and end_tag is None:
29+
# ParadeDB versions in CI don't support pdb.snippet(field, max_num_chars)
30+
# directly. Supplying default tags targets the supported 4-arg form.
31+
start_tag = "<b>"
32+
end_tag = "</b>"
33+
2834
args: list[Any] = [field]
2935
if start_tag is not None and end_tag is not None:
3036
args.extend([start_tag, end_tag])
@@ -36,12 +42,21 @@ def snippet(
3642
def snippets(
3743
field: ColumnElement,
3844
*,
45+
start_tag: str | None = None,
46+
end_tag: str | None = None,
3947
max_num_chars: int | None = None,
4048
limit: int | None = None,
4149
offset: int | None = None,
4250
sort_by: str | None = None,
4351
) -> ClauseElement:
52+
if (start_tag is None) != (end_tag is None):
53+
raise InvalidArgumentError("start_tag and end_tag must be provided together")
54+
4455
named_args: list[tuple[str, Any]] = []
56+
if start_tag is not None:
57+
named_args.append(("start_tag", start_tag))
58+
if end_tag is not None:
59+
named_args.append(("end_tag", end_tag))
4560
if max_num_chars is not None:
4661
named_args.append(("max_num_chars", max_num_chars))
4762
if limit is not None:

paradedb/sqlalchemy/search.py

Lines changed: 31 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
from sqlalchemy.sql import operators
99
from sqlalchemy.sql.elements import ClauseElement, ColumnElement
1010

11+
from ._functions import PDBFunctionWithNamedArgs
1112
from ._pdb_cast import PDBCast
1213
from .errors import InvalidArgumentError, InvalidMoreLikeThisOptionsError
1314
from .validation import (
@@ -238,6 +239,8 @@ def range_term(
238239
raise InvalidArgumentError(
239240
f"relation must be one of: {', '.join(sorted(_VALID_RANGE_RELATIONS))}"
240241
)
242+
escaped_relation = relation.replace("'", "''")
243+
relation_arg: ClauseElement = literal_column(f"'{escaped_relation}'")
241244
if range_type is not None:
242245
if range_type not in _VALID_RANGE_TYPES:
243246
raise InvalidArgumentError(
@@ -247,7 +250,7 @@ def range_term(
247250
bounds_arg: ClauseElement = literal_column(f"'{escaped}'::{range_type}")
248251
else:
249252
bounds_arg = literal(bounds)
250-
return field.operate(_QUERY, func.pdb.range_term(bounds_arg, relation))
253+
return field.operate(_QUERY, func.pdb.range_term(bounds_arg, relation_arg))
251254

252255

253256
def more_like_this(
@@ -309,58 +312,36 @@ def more_like_this(
309312
if stopwords is not None:
310313
require_non_empty_strings(stopwords, field_name="stopwords", error_cls=error_cls)
311314

312-
options_provided = any(
313-
option is not None
314-
for option in (
315-
min_term_frequency,
316-
max_query_terms,
317-
min_doc_frequency,
318-
max_doc_frequency,
319-
min_word_length,
320-
max_word_length,
321-
boost_factor,
322-
stopwords,
323-
)
324-
)
325-
326-
def _build_id_args(doc_id: Any) -> list[Any]:
327-
id_args: list[Any] = [doc_id]
328-
if fields is not None or options_provided:
329-
id_args.append(array(fields or [], type_=Text()))
330-
if options_provided:
331-
id_args.extend(
332-
[
333-
1 if min_term_frequency is None else min_term_frequency,
334-
25 if max_query_terms is None else max_query_terms,
335-
0 if min_doc_frequency is None else min_doc_frequency,
336-
1_000_000 if max_doc_frequency is None else max_doc_frequency,
337-
0 if min_word_length is None else min_word_length,
338-
1000 if max_word_length is None else max_word_length,
339-
0.0 if boost_factor is None else boost_factor,
340-
array(stopwords or [], type_=Text()),
341-
]
342-
)
343-
return id_args
315+
named_options: list[tuple[str, Any]] = []
316+
if min_term_frequency is not None:
317+
named_options.append(("min_term_frequency", min_term_frequency))
318+
if max_query_terms is not None:
319+
named_options.append(("max_query_terms", max_query_terms))
320+
if min_doc_frequency is not None:
321+
named_options.append(("min_doc_frequency", min_doc_frequency))
322+
if max_doc_frequency is not None:
323+
named_options.append(("max_doc_frequency", max_doc_frequency))
324+
if min_word_length is not None:
325+
named_options.append(("min_word_length", min_word_length))
326+
if max_word_length is not None:
327+
named_options.append(("max_word_length", max_word_length))
328+
if boost_factor is not None:
329+
named_options.append(("boost_factor", boost_factor))
330+
if stopwords is not None:
331+
named_options.append(("stopwords", array(stopwords, type_=Text())))
332+
333+
def _build_mlt_call(source_arg: ClauseElement, *, include_fields: bool) -> ClauseElement:
334+
positional_args: list[ClauseElement] = [source_arg]
335+
if include_fields and fields is not None:
336+
positional_args.append(array(fields, type_=Text()))
337+
return PDBFunctionWithNamedArgs("more_like_this", positional_args, named_options)
344338

345339
if document_ids is not None:
346-
return or_(*[field.operate(_QUERY, func.pdb.more_like_this(*_build_id_args(doc_id))) for doc_id in document_ids])
340+
clauses = [field.operate(_QUERY, _build_mlt_call(literal(doc_id), include_fields=True)) for doc_id in document_ids]
341+
return or_(*clauses)
347342

348343
if document_id is not None:
349-
return field.operate(_QUERY, func.pdb.more_like_this(*_build_id_args(document_id)))
344+
return field.operate(_QUERY, _build_mlt_call(literal(document_id), include_fields=True))
350345

351346
payload = document if isinstance(document, str) else json.dumps(document, separators=(",", ":"), sort_keys=True)
352-
doc_args: list[Any] = [payload]
353-
if options_provided:
354-
doc_args.extend(
355-
[
356-
1 if min_term_frequency is None else min_term_frequency,
357-
25 if max_query_terms is None else max_query_terms,
358-
0 if min_doc_frequency is None else min_doc_frequency,
359-
1_000_000 if max_doc_frequency is None else max_doc_frequency,
360-
0 if min_word_length is None else min_word_length,
361-
1000 if max_word_length is None else max_word_length,
362-
0.0 if boost_factor is None else boost_factor,
363-
array(stopwords or [], type_=Text()),
364-
]
365-
)
366-
return field.operate(_QUERY, func.pdb.more_like_this(*doc_args))
347+
return field.operate(_QUERY, _build_mlt_call(literal(payload), include_fields=False))

paradedb/sqlalchemy/select_with.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ def snippets(
4242
field: ColumnElement,
4343
*,
4444
label: str = "snippets",
45+
start_tag: str | None = None,
46+
end_tag: str | None = None,
4547
max_num_chars: int | None = None,
4648
limit: int | None = None,
4749
offset: int | None = None,
@@ -51,6 +53,8 @@ def snippets(
5153
return stmt.add_columns(
5254
pdb.snippets(
5355
field,
56+
start_tag=start_tag,
57+
end_tag=end_tag,
5458
max_num_chars=max_num_chars,
5559
limit=limit,
5660
offset=offset,

tests/integration/conftest.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,9 @@ def explain_plan_json(session: Session, stmt) -> dict[str, Any]:
103103
compile_kwargs={"literal_binds": True},
104104
)
105105
)
106-
explain_result = session.execute(text(f"EXPLAIN (FORMAT JSON) {sql}")).scalar_one()
106+
# Use driver-level execution so SQLAlchemy doesn't treat JSON fragments like
107+
# `...:2...` inside string literals as bind placeholders.
108+
explain_result = session.connection().exec_driver_sql(f"EXPLAIN (FORMAT JSON) {sql}").scalar_one()
107109
return _normalize_explain_plan(explain_result)
108110

109111

tests/integration/test_indexing_integration.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -182,8 +182,9 @@ def test_bm25_index_json_keys_when_supported(engine):
182182
).one()
183183

184184
assert "->>" in row.indexdef
185-
assert "alias=metadata_color" in row.indexdef
186-
assert "alias=metadata_location" in row.indexdef
185+
assert "'color'" in row.indexdef
186+
assert "'location'" in row.indexdef
187+
assert row.indexdef.count("pdb.literal(") >= 2
187188

188189
_drop_table_and_index(engine, table_name, index_name)
189190

tests/unit/test_sql_compilation_unit.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,14 +61,24 @@ def test_pdb_helpers_compile():
6161
stmt = select(
6262
pdb.score(products.c.id).label("score"),
6363
pdb.snippet(products.c.description, start_tag="<mark>", end_tag="</mark>", max_num_chars=100).label("snippet"),
64-
pdb.snippets(products.c.description, max_num_chars=15, limit=1, offset=0, sort_by="position").label("snippets"),
64+
pdb.snippets(
65+
products.c.description,
66+
start_tag="[",
67+
end_tag="]",
68+
max_num_chars=15,
69+
limit=1,
70+
offset=0,
71+
sort_by="position",
72+
).label("snippets"),
6573
pdb.snippet_positions(products.c.description).label("positions"),
6674
)
6775
sql = _sql(stmt)
6876

6977
assert "pdb.score(products.id) AS score" in sql
7078
assert "pdb.snippet(products.description, '<mark>', '</mark>', 100) AS snippet" in sql
7179
assert "pdb.snippets(products.description" in sql
80+
assert "start_tag => '['" in sql
81+
assert "end_tag => ']'" in sql
7282
assert "max_num_chars => 15" in sql
7383
assert '"limit" => 1' in sql
7484
assert '"offset" => 0' in sql
@@ -167,7 +177,10 @@ def test_more_like_this_compile():
167177

168178
assert "id @@@ pdb.more_like_this(3, ARRAY['description'])" in by_id_sql
169179
assert "id @@@ pdb.more_like_this('{\"description\":\"wireless earbuds\"}')" in by_doc_sql
170-
assert "id @@@ pdb.more_like_this(3, ARRAY['description'], 2, 10" in with_opts_sql
180+
assert "id @@@ pdb.more_like_this(3, ARRAY['description']" in with_opts_sql
181+
assert "min_term_frequency => 2" in with_opts_sql
182+
assert "max_query_terms => 10" in with_opts_sql
183+
assert "stopwords => ARRAY['the', 'a']" in with_opts_sql
171184
assert "ARRAY['the', 'a']" in with_opts_sql
172185

173186

0 commit comments

Comments
 (0)