Skip to content

Commit 2156779

Browse files
committed
Add support for edge_ngram
1 parent 854bbc0 commit 2156779

2 files changed

Lines changed: 33 additions & 1 deletion

File tree

paradedb/sqlalchemy/tokenizer.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,6 +265,38 @@ def ngram(
265265
)
266266

267267

268+
def edge_ngram(
    *,
    alias: str | None = None,
    min_gram: int | None = None,
    max_gram: int | None = None,
    args: Sequence[Any] | None = None,
    named_args: Mapping[str, Any] | None = None,
    filters: Sequence[str] | None = None,
    stemmer: str | None = None,
) -> Tokenizer:
    """Build an ``edge_ngram`` tokenizer spec.

    If both ``min_gram`` and ``max_gram`` are supplied and no explicit
    positional ``args`` are given, the bounds are emitted positionally
    (``edge_ngram(min, max)``); otherwise any supplied bound is passed
    through as a named argument instead.
    """
    positionals: list[Any] = [] if args is None else list(args)
    # Bounds ride positionally only when the caller hasn't already
    # claimed the positional slots.
    bounds_go_positional = (
        not positionals and min_gram is not None and max_gram is not None
    )
    if bounds_go_positional:
        positionals += [min_gram, max_gram]

    # Keys are normalized to str so arbitrary Mapping key types are safe.
    keyword_args: dict[str, Any] = (
        {} if named_args is None else {str(k): v for k, v in named_args.items()}
    )
    if not bounds_go_positional:
        if min_gram is not None:
            keyword_args["min_gram"] = min_gram
        if max_gram is not None:
            keyword_args["max_gram"] = max_gram

    return _build_spec(
        "edge_ngram",
        alias=alias,
        args=positionals,
        named_args=keyword_args,
        filters=filters,
        stemmer=stemmer,
    )
298+
299+
268300
def lindera(
269301
dictionary: str | None = None,
270302
*,

tests/integration/test_query_interface_integration.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ def test_agg_function_projection(session):
152152
("pdb.literal_normalized", tokenizer.literal_normalized()),
153153
("pdb.ngram(3,3)", tokenizer.ngram(args=(3, 3))),
154154
("pdb.ngram(3,3,'positions=true')", tokenizer.ngram(args=(3, 3), named_args={"positions": "true"})),
155-
# ("pdb.edge_ngram(3, 3)", tokenizer.edge), TODO add support
155+
("pdb.edge_ngram(2,5)", tokenizer.edge_ngram(args=(2, 5))),
156156
("pdb.simple", tokenizer.simple()),
157157
("pdb.regex_pattern('.*')", tokenizer.regex_pattern(".*")),
158158
("pdb.chinese_compatible", tokenizer.chinese_compatible()),

0 commit comments

Comments
 (0)