88from sqlalchemy .sql import operators
99from sqlalchemy .sql .elements import ClauseElement , ColumnElement
1010
11+ from ._functions import PDBFunctionWithNamedArgs
1112from ._pdb_cast import PDBCast
1213from .errors import InvalidArgumentError , InvalidMoreLikeThisOptionsError
1314from .validation import (
@@ -238,6 +239,8 @@ def range_term(
238239 raise InvalidArgumentError (
239240 f"relation must be one of: { ', ' .join (sorted (_VALID_RANGE_RELATIONS ))} "
240241 )
242+ escaped_relation = relation .replace ("'" , "''" )
243+ relation_arg : ClauseElement = literal_column (f"'{ escaped_relation } '" )
241244 if range_type is not None :
242245 if range_type not in _VALID_RANGE_TYPES :
243246 raise InvalidArgumentError (
@@ -247,7 +250,7 @@ def range_term(
247250 bounds_arg : ClauseElement = literal_column (f"'{ escaped } '::{ range_type } " )
248251 else :
249252 bounds_arg = literal (bounds )
250- return field .operate (_QUERY , func .pdb .range_term (bounds_arg , relation ))
253+ return field .operate (_QUERY , func .pdb .range_term (bounds_arg , relation_arg ))
251254
252255
253256def more_like_this (
@@ -309,58 +312,36 @@ def more_like_this(
309312 if stopwords is not None :
310313 require_non_empty_strings (stopwords , field_name = "stopwords" , error_cls = error_cls )
311314
312- options_provided = any (
313- option is not None
314- for option in (
315- min_term_frequency ,
316- max_query_terms ,
317- min_doc_frequency ,
318- max_doc_frequency ,
319- min_word_length ,
320- max_word_length ,
321- boost_factor ,
322- stopwords ,
323- )
324- )
325-
326- def _build_id_args (doc_id : Any ) -> list [Any ]:
327- id_args : list [Any ] = [doc_id ]
328- if fields is not None or options_provided :
329- id_args .append (array (fields or [], type_ = Text ()))
330- if options_provided :
331- id_args .extend (
332- [
333- 1 if min_term_frequency is None else min_term_frequency ,
334- 25 if max_query_terms is None else max_query_terms ,
335- 0 if min_doc_frequency is None else min_doc_frequency ,
336- 1_000_000 if max_doc_frequency is None else max_doc_frequency ,
337- 0 if min_word_length is None else min_word_length ,
338- 1000 if max_word_length is None else max_word_length ,
339- 0.0 if boost_factor is None else boost_factor ,
340- array (stopwords or [], type_ = Text ()),
341- ]
342- )
343- return id_args
315+ named_options : list [tuple [str , Any ]] = []
316+ if min_term_frequency is not None :
317+ named_options .append (("min_term_frequency" , min_term_frequency ))
318+ if max_query_terms is not None :
319+ named_options .append (("max_query_terms" , max_query_terms ))
320+ if min_doc_frequency is not None :
321+ named_options .append (("min_doc_frequency" , min_doc_frequency ))
322+ if max_doc_frequency is not None :
323+ named_options .append (("max_doc_frequency" , max_doc_frequency ))
324+ if min_word_length is not None :
325+ named_options .append (("min_word_length" , min_word_length ))
326+ if max_word_length is not None :
327+ named_options .append (("max_word_length" , max_word_length ))
328+ if boost_factor is not None :
329+ named_options .append (("boost_factor" , boost_factor ))
330+ if stopwords is not None :
331+ named_options .append (("stopwords" , array (stopwords , type_ = Text ())))
332+
333+ def _build_mlt_call (source_arg : ClauseElement , * , include_fields : bool ) -> ClauseElement :
334+ positional_args : list [ClauseElement ] = [source_arg ]
335+ if include_fields and fields is not None :
336+ positional_args .append (array (fields , type_ = Text ()))
337+ return PDBFunctionWithNamedArgs ("more_like_this" , positional_args , named_options )
344338
345339 if document_ids is not None :
346- return or_ (* [field .operate (_QUERY , func .pdb .more_like_this (* _build_id_args (doc_id ))) for doc_id in document_ids ])
340+ clauses = [field .operate (_QUERY , _build_mlt_call (literal (doc_id ), include_fields = True )) for doc_id in document_ids ]
341+ return or_ (* clauses )
347342
348343 if document_id is not None :
349- return field .operate (_QUERY , func . pdb . more_like_this ( * _build_id_args (document_id )))
344+ return field .operate (_QUERY , _build_mlt_call ( literal (document_id ), include_fields = True ))
350345
351346 payload = document if isinstance (document , str ) else json .dumps (document , separators = ("," , ":" ), sort_keys = True )
352- doc_args : list [Any ] = [payload ]
353- if options_provided :
354- doc_args .extend (
355- [
356- 1 if min_term_frequency is None else min_term_frequency ,
357- 25 if max_query_terms is None else max_query_terms ,
358- 0 if min_doc_frequency is None else min_doc_frequency ,
359- 1_000_000 if max_doc_frequency is None else max_doc_frequency ,
360- 0 if min_word_length is None else min_word_length ,
361- 1000 if max_word_length is None else max_word_length ,
362- 0.0 if boost_factor is None else boost_factor ,
363- array (stopwords or [], type_ = Text ()),
364- ]
365- )
366- return field .operate (_QUERY , func .pdb .more_like_this (* doc_args ))
347+ return field .operate (_QUERY , _build_mlt_call (literal (payload ), include_fields = False ))
0 commit comments