@@ -102,8 +102,17 @@ class ProximityExpr:
102102 def __init__ (self , expr : ClauseElement ) -> None :
103103 self .expr = expr
104104
105- def near (self , other : str | ClauseElement | ProximityExpr , * , distance : int , ordered : bool = False ) -> ProximityExpr :
106- return ProximityExpr (_near_chain (self .expr , other , distance = distance , ordered = ordered ))
105+ def near (
106+ self ,
107+ other : str | ClauseElement | ProximityExpr | None = None ,
108+ * ,
109+ distance : int ,
110+ ordered : bool = False ,
111+ right_pattern : str | None = None ,
112+ max_expansions : int = 100 ,
113+ ) -> ProximityExpr :
114+ right = _resolve_near_operand (other , right_pattern = right_pattern , max_expansions = max_expansions )
115+ return ProximityExpr (_near_chain (self .expr , right , distance = distance , ordered = ordered ))
107116
108117
109118def _to_proximity_operand (value : str | ClauseElement | ProximityExpr ) -> ClauseElement :
@@ -122,12 +131,29 @@ def _to_proximity_clause(value: str | ClauseElement | ProximityExpr) -> ClauseEl
122131
123132
124133def _near_chain (left : str | ClauseElement | ProximityExpr , right : str | ClauseElement | ProximityExpr , * , distance : int , ordered : bool = False ) -> ClauseElement :
134+ require_non_negative (distance , field_name = "distance" )
125135 left_expr = _to_proximity_clause (left )
126136 right_expr = _to_proximity_clause (right )
127137 op = _NEAR_ORDERED if ordered else _NEAR
128138 return left_expr .operate (op , literal (distance )).operate (op , right_expr )
129139
130140
141+ def _resolve_near_operand (
142+ right : str | ClauseElement | ProximityExpr | None ,
143+ * ,
144+ right_pattern : str | None ,
145+ max_expansions : int ,
146+ ) -> str | ClauseElement | ProximityExpr :
147+ if right_pattern is not None :
148+ if right is not None :
149+ raise InvalidArgumentError ("right and right_pattern cannot be used together" )
150+ require_non_negative (max_expansions , field_name = "max_expansions" )
151+ return prox_regex (right_pattern , max_expansions )
152+ if right is None :
153+ raise InvalidArgumentError ("right is required unless right_pattern is provided" )
154+ return right
155+
156+
131157def parse (field : ColumnElement , query : str , * , lenient : bool = False , conjunction_mode : bool = False ) -> ColumnElement [bool ]:
132158 return field .operate (_QUERY , func .pdb .parse (query , lenient , conjunction_mode ))
133159
@@ -153,11 +179,22 @@ def regex_phrase(
153179 return field .operate (_QUERY , func .pdb .regex_phrase (array (terms , type_ = Text ()), slop , max_expansions ))
154180
155181
156- def near (field : ColumnElement , left : str | ClauseElement , right : str | ClauseElement , * , distance : int , ordered : bool = False ) -> ColumnElement [bool ]:
157- return field .operate (_QUERY , _near_chain (left , right , distance = distance , ordered = ordered ))
182+ def near (
183+ field : ColumnElement ,
184+ left : str | ClauseElement | ProximityExpr ,
185+ right : str | ClauseElement | ProximityExpr | None = None ,
186+ * ,
187+ distance : int ,
188+ ordered : bool = False ,
189+ right_pattern : str | None = None ,
190+ max_expansions : int = 100 ,
191+ ) -> ColumnElement [bool ]:
192+ right_operand = _resolve_near_operand (right , right_pattern = right_pattern , max_expansions = max_expansions )
193+ return field .operate (_QUERY , _near_chain (left , right_operand , distance = distance , ordered = ordered ))
158194
159195
160196def prox_regex (pattern : str , max_expansions : int = 100 ) -> ProximityExpr :
197+ require_non_negative (max_expansions , field_name = "max_expansions" )
161198 return ProximityExpr (func .pdb .prox_regex (pattern , max_expansions ))
162199
163200
0 commit comments