Skip to content

Commit a52c0f0

Browse files
committed
Fix TypeInferenceMapper.map_floor_div with integral dtypes.
1 parent aaeb190 commit a52c0f0

2 files changed

Lines changed: 49 additions & 0 deletions

File tree

loopy/type_inference.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,41 @@ def map_quotient(self, expr: p.Quotient):
429429
else:
430430
return self.combine([n_dtype_set, d_dtype_set])
431431

432+
def _map_int_div_modulo(self, expr: p.FloorDiv | p.Remainder):
    """Infer the result type of a floor-division or remainder expression.

    Both operands are inferred recursively. If either inference comes back
    empty, an empty list is returned. Otherwise the operator the node stands
    for (``%`` for :class:`p.Remainder`, ``//`` otherwise) is applied to
    zero-length numpy arrays carrying the operand dtypes, and the dtype of
    the outcome is reported. Operands that are integer literals are handed
    to numpy as they are instead of as arrays of their inferred dtype.

    When both operands are integer literals, the expression is folded to a
    single constant and the type of that constant is inferred instead.
    """
    n_dtype_set = self.rec(expr.numerator)
    d_dtype_set = self.rec(expr.denominator)

    if not (n_dtype_set and d_dtype_set):
        return cast("list[NumpyType]", [])

    # Only the first candidate dtype of each operand is consulted.
    n_dtype = cast("NumpyType", n_dtype_set[0]).dtype
    d_dtype = cast("NumpyType", d_dtype_set[0]).dtype

    num = (
        expr.numerator
        if is_integer(expr.numerator)
        else np.empty(0, dtype=n_dtype)
    )
    denom = (
        expr.denominator
        if is_integer(expr.denominator)
        else np.empty(0, dtype=d_dtype)
    )

    # Use the operator matching the node: folding a remainder with ``//``
    # yields the wrong constant (e.g. ``-1 % 2**40`` is not ``-1``), and
    # the constant's value is what gets passed on for type inference.
    is_remainder = isinstance(expr, p.Remainder)

    if is_integer(num) and is_integer(denom):
        # NOTE(review): a literal zero denominator is not guarded against
        # here -- confirm such expressions cannot reach type inference.
        folded = num % denom if is_remainder else num // denom
        return self.rec(folded)

    result_np = num % denom if is_remainder else num // denom
    assert isinstance(result_np, np.ndarray)

    return [NumpyType(result_np.dtype)]
458+
459+
@override
def map_floor_div(self, expr: p.FloorDiv):
    """Infer the type of a floor division by delegating to
    ``_map_int_div_modulo``.
    """
    inferred = self._map_int_div_modulo(expr)
    return inferred
462+
463+
@override
def map_remainder(self, expr: p.Remainder):
    """Infer the type of a remainder expression by delegating to
    ``_map_int_div_modulo``.
    """
    inferred = self._map_int_div_modulo(expr)
    return inferred
466+
432467
@override
433468
def map_constant(self, expr: object):
434469
if isinstance(expr, np.generic):

test/test_loopy.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3733,6 +3733,20 @@ def test_type_cast_parse_stringify_roundtrip():
37333733
assert expr == parsed
37343734

37353735

3736+
def test_floor_div_modulo_with_uint_index():
    """Generate code for a kernel that applies ``//`` and ``%`` to entries of
    a ``uint64`` index array.

    See <https://github.com/inducer/loopy/issues/999>
    """
    domain = "{[i]: 0<=i<10}"
    instruction = "a[map[i] // 2, map[i] % 35] = i"
    kernel_args = [
        lp.GlobalArg("map", dtype=np.uint64, shape=lp.auto),
        lp.GlobalArg("a", dtype=np.float64, shape=(10, 4)),
    ]
    knl = lp.make_kernel(domain, instruction, kernel_args)

    # The kernel is never executed; the test passes as long as code
    # generation completes without raising.
    generated = lp.generate_code_v2(knl)
    generated.device_code()
3748+
3749+
37363750
if __name__ == "__main__":
37373751
import sys
37383752
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)