@@ -429,6 +429,53 @@ def map_quotient(self, expr: p.Quotient):
429429 else :
430430 return self .combine ([n_dtype_set , d_dtype_set ])
431431
432+ def _map_int_div_modulo (self , expr : p .FloorDiv | p .Remainder ):
433+ # This is pretty gross, but generally appears to lack alternatives.
434+ # See https://github.com/inducer/loopy/pull/1000 for some discussion.
435+ # In general, for array // array, numpy is very eager to infer
436+ # float dtypes (for example for u64/i32), which doesn't work for us:
437+ # integers should stay integers to stay usable as array indices.
438+
439+ n_dtype_set = self .rec (expr .numerator )
440+ d_dtype_set = self .rec (expr .denominator )
441+
442+ if not (n_dtype_set and d_dtype_set ):
443+ return cast ("list[NumpyType]" , [])
444+
445+ n_dtype = n_dtype_set [0 ].numpy_dtype
446+ d_dtype = d_dtype_set [0 ].numpy_dtype
447+ num = (
448+ np .empty (0 , dtype = n_dtype )
449+ if not is_integer (expr .numerator )
450+ else expr .numerator
451+ )
452+ denom = (
453+ np .empty (0 , dtype = d_dtype )
454+ if not is_integer (expr .denominator )
455+ else expr .denominator
456+ )
457+ denom = (
458+ cast ("int | np.integer" , denom + 1 )
459+ if is_integer (denom ) and denom == 0
460+ else denom
461+ ) # avoid divide by zero.
462+
463+ if is_integer (num ) and is_integer (denom ):
464+ return self .rec (num // denom )
465+
466+ floor_div_np = num // denom
467+ assert isinstance (floor_div_np , np .ndarray )
468+
469+ return [NumpyType (floor_div_np .dtype )]
470+
471+ @override
472+ def map_floor_div (self , expr : p .FloorDiv ):
473+ return self ._map_int_div_modulo (expr )
474+
475+ @override
476+ def map_remainder (self , expr : p .Remainder ):
477+ return self ._map_int_div_modulo (expr )
478+
432479 @override
433480 def map_constant (self , expr : object ):
434481 if isinstance (expr , np .generic ):
0 commit comments