33import json
44from typing import Any
55
6- from sqlalchemy import Text , func , literal
6+ from sqlalchemy import Text , func , literal , literal_column , or_
77from sqlalchemy .dialects .postgresql import array
88from sqlalchemy .sql import operators
99from sqlalchemy .sql .elements import ClauseElement , ColumnElement
1818)
1919
2020_VALID_RANGE_RELATIONS : frozenset [str ] = frozenset ({"Intersects" , "Contains" , "Within" , "ContainsOrIntersects" })
21+ _VALID_RANGE_TYPES : frozenset [str ] = frozenset ({"int4range" , "int8range" , "numrange" , "daterange" , "tsrange" , "tstzrange" })
2122
2223_MATCH_ALL = operators .custom_op ("&&&" , precedence = 5 , is_comparison = True )
2324_MATCH_ANY = operators .custom_op ("|||" , precedence = 5 , is_comparison = True )
2425_TERM = operators .custom_op ("===" , precedence = 5 , is_comparison = True )
2526_PHRASE = operators .custom_op ("###" , precedence = 5 , is_comparison = True )
2627_QUERY = operators .custom_op ("@@@" , precedence = 5 , is_comparison = True )
2728_NEAR = operators .custom_op ("##" , precedence = 5 )
29+ _NEAR_ORDERED = operators .custom_op ("##>" , precedence = 5 )
2830
2931
3032def _to_term_payload (* terms : str ) -> ClauseElement :
@@ -100,8 +102,8 @@ class ProximityExpr:
100102 def __init__ (self , expr : ClauseElement ) -> None :
101103 self .expr = expr
102104
103- def near (self , other : str | ClauseElement | ProximityExpr , * , distance : int ) -> ProximityExpr :
104- return ProximityExpr (_near_chain (self .expr , other , distance = distance ))
105+ def near (self , other : str | ClauseElement | ProximityExpr , * , distance : int , ordered : bool = False ) -> ProximityExpr :
106+ return ProximityExpr (_near_chain (self .expr , other , distance = distance , ordered = ordered ))
105107
106108
107109def _to_proximity_operand (value : str | ClauseElement | ProximityExpr ) -> ClauseElement :
@@ -119,10 +121,11 @@ def _to_proximity_clause(value: str | ClauseElement | ProximityExpr) -> ClauseEl
119121 return PDBCast (operand , "proximityclause" )
120122
121123
122- def _near_chain (left : str | ClauseElement | ProximityExpr , right : str | ClauseElement | ProximityExpr , * , distance : int ) -> ClauseElement :
124+ def _near_chain (left : str | ClauseElement | ProximityExpr , right : str | ClauseElement | ProximityExpr , * , distance : int , ordered : bool = False ) -> ClauseElement :
123125 left_expr = _to_proximity_clause (left )
124126 right_expr = _to_proximity_clause (right )
125- return left_expr .operate (_NEAR , literal (distance )).operate (_NEAR , right_expr )
127+ op = _NEAR_ORDERED if ordered else _NEAR
128+ return left_expr .operate (op , literal (distance )).operate (op , right_expr )
126129
127130
128131def parse (field : ColumnElement , query : str , * , lenient : bool = False , conjunction_mode : bool = False ) -> ColumnElement [bool ]:
@@ -150,8 +153,8 @@ def regex_phrase(
150153 return field .operate (_QUERY , func .pdb .regex_phrase (array (terms , type_ = Text ()), slop , max_expansions ))
151154
152155
153- def near (field : ColumnElement , left : str | ClauseElement , right : str | ClauseElement , * , distance : int ) -> ColumnElement [bool ]:
154- return field .operate (_QUERY , _near_chain (left , right , distance = distance ))
156+ def near (field : ColumnElement , left : str | ClauseElement , right : str | ClauseElement , * , distance : int , ordered : bool = False ) -> ColumnElement [bool ]:
157+ return field .operate (_QUERY , _near_chain (left , right , distance = distance , ordered = ordered ))
155158
156159
157160def prox_regex (pattern : str , max_expansions : int = 100 ) -> ProximityExpr :
@@ -175,6 +178,7 @@ def range_term(
175178 bounds : str ,
176179 * ,
177180 relation : str = "Intersects" ,
181+ range_type : str | None = None ,
178182) -> ColumnElement [bool ]:
179183 """Match rows where a range-typed field satisfies a range predicate.
180184
@@ -183,22 +187,37 @@ def range_term(
183187 bounds: A range literal string, e.g. ``"[3,9]"``, ``"(3,9]"``.
184188 relation: One of ``"Intersects"``, ``"Contains"``, ``"Within"``,
185189 ``"ContainsOrIntersects"``. Defaults to ``"Intersects"``.
190+ range_type: Optional PostgreSQL range type for explicit casting, e.g.
191+ ``"int4range"``, ``"int8range"``, ``"numrange"``,
192+ ``"daterange"``, ``"tsrange"``, ``"tstzrange"``.
193+ When provided, generates ``'bounds'::range_type`` cast.
186194
187195 Generates::
188196
189197 field @@@ pdb.range_term('[3,9]', 'Contains')
198+ field @@@ pdb.range_term('[3,9]'::int4range, 'Contains')
190199 """
191200 if relation not in _VALID_RANGE_RELATIONS :
192201 raise InvalidArgumentError (
193202 f"relation must be one of: { ', ' .join (sorted (_VALID_RANGE_RELATIONS ))} "
194203 )
195- return field .operate (_QUERY , func .pdb .range_term (bounds , relation ))
204+ if range_type is not None :
205+ if range_type not in _VALID_RANGE_TYPES :
206+ raise InvalidArgumentError (
207+ f"range_type must be one of: { ', ' .join (sorted (_VALID_RANGE_TYPES ))} "
208+ )
209+ escaped = bounds .replace ("'" , "''" )
210+ bounds_arg : ClauseElement = literal_column (f"'{ escaped } '::{ range_type } " )
211+ else :
212+ bounds_arg = literal (bounds )
213+ return field .operate (_QUERY , func .pdb .range_term (bounds_arg , relation ))
196214
197215
198216def more_like_this (
199217 field : ColumnElement ,
200218 * ,
201219 document_id : Any | None = None ,
220+ document_ids : list [Any ] | None = None ,
202221 document : dict [str , Any ] | str | None = None ,
203222 fields : list [str ] | None = None ,
204223 min_term_frequency : int | None = None ,
@@ -212,10 +231,13 @@ def more_like_this(
212231) -> ColumnElement [bool ]:
213232 error_cls = InvalidMoreLikeThisOptionsError
214233
215- if (document_id is None ) == (document is None ):
216- raise error_cls ("exactly one of document_id or document must be provided" )
234+ sources_provided = sum (x is not None for x in (document_id , document_ids , document ))
235+ if sources_provided != 1 :
236+ raise error_cls ("exactly one of document_id, document_ids, or document must be provided" )
237+ if document_ids is not None and len (document_ids ) == 0 :
238+ raise error_cls ("document_ids must not be empty" )
217239 if document is not None and fields is not None :
218- raise error_cls ("fields can only be used with document_id" )
240+ raise error_cls ("fields can only be used with document_id or document_ids " )
219241
220242 if min_term_frequency is not None :
221243 require_non_negative (min_term_frequency , field_name = "min_term_frequency" , error_cls = error_cls )
@@ -264,17 +286,35 @@ def more_like_this(
264286 )
265287 )
266288
267- args : list [Any ] = []
268- if document_id is not None :
269- args .append (document_id )
289+ def _build_id_args (doc_id : Any ) -> list [Any ]:
290+ id_args : list [Any ] = [doc_id ]
270291 if fields is not None or options_provided :
271- args .append (array (fields or [], type_ = Text ()))
272- else :
273- payload = document if isinstance (document , str ) else json .dumps (document , separators = ("," , ":" ), sort_keys = True )
274- args .append (payload )
292+ id_args .append (array (fields or [], type_ = Text ()))
293+ if options_provided :
294+ id_args .extend (
295+ [
296+ 1 if min_term_frequency is None else min_term_frequency ,
297+ 25 if max_query_terms is None else max_query_terms ,
298+ 0 if min_doc_frequency is None else min_doc_frequency ,
299+ 1_000_000 if max_doc_frequency is None else max_doc_frequency ,
300+ 0 if min_word_length is None else min_word_length ,
301+ 1000 if max_word_length is None else max_word_length ,
302+ 0.0 if boost_factor is None else boost_factor ,
303+ array (stopwords or [], type_ = Text ()),
304+ ]
305+ )
306+ return id_args
307+
308+ if document_ids is not None :
309+ return or_ (* [field .operate (_QUERY , func .pdb .more_like_this (* _build_id_args (doc_id ))) for doc_id in document_ids ])
275310
311+ if document_id is not None :
312+ return field .operate (_QUERY , func .pdb .more_like_this (* _build_id_args (document_id )))
313+
314+ payload = document if isinstance (document , str ) else json .dumps (document , separators = ("," , ":" ), sort_keys = True )
315+ doc_args : list [Any ] = [payload ]
276316 if options_provided :
277- args .extend (
317+ doc_args .extend (
278318 [
279319 1 if min_term_frequency is None else min_term_frequency ,
280320 25 if max_query_terms is None else max_query_terms ,
@@ -286,5 +326,4 @@ def more_like_this(
286326 array (stopwords or [], type_ = Text ()),
287327 ]
288328 )
289-
290- return field .operate (_QUERY , func .pdb .more_like_this (* args ))
329+ return field .operate (_QUERY , func .pdb .more_like_this (* doc_args ))
0 commit comments