Skip to content

Commit d6dff0a

Browse files
kaushikcfd authored and inducer committed
Fix TypeInferenceMapper.map_floor_div with integral dtypes.
Avoid div by 0 in TypeInferenceMapper.map_floor_div.
basedpyright: add a cast.
1 parent aaeb190 commit d6dff0a

File tree

2 files changed

+61
-0
lines changed

2 files changed

+61
-0
lines changed

loopy/type_inference.py

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -429,6 +429,53 @@ def map_quotient(self, expr: p.Quotient):
429429
else:
430430
return self.combine([n_dtype_set, d_dtype_set])
431431

432+
def _map_int_div_modulo(self, expr: p.FloorDiv | p.Remainder):
    """Infer the result dtype set of a floor division or remainder.

    Shared implementation behind :meth:`map_floor_div` and
    :meth:`map_remainder`.

    Both operands are first run through ``self.rec``. If either comes back
    empty, an empty list is returned. Otherwise, every operand that is not
    a literal integer (per ``is_integer``) is replaced by a zero-length
    numpy array carrying the first dtype of its inferred set, and numpy's
    ``//`` is applied to see which dtype comes out. When both operands are
    literal integers, the dtype of the resulting quotient value is inferred
    by recursing on it instead.
    """
    # This approach is admittedly ugly, but there do not appear to be
    # alternatives. See https://github.com/inducer/loopy/pull/1000 for
    # some discussion.
    # For array // array, numpy is very eager to infer float dtypes
    # (for example for u64/i32), which does not work here: integers
    # should stay integers so that they remain usable as array indices.

    numer_types = self.rec(expr.numerator)
    denom_types = self.rec(expr.denominator)

    if not numer_types or not denom_types:
        return cast("list[NumpyType]", [])

    # Only the first entry of each inferred dtype set is consulted.
    numer_np_dtype = numer_types[0].numpy_dtype
    denom_np_dtype = denom_types[0].numpy_dtype

    # Literal integers are kept as values; anything else is represented by
    # an empty array of the inferred dtype.
    if is_integer(expr.numerator):
        lhs = expr.numerator
    else:
        lhs = np.empty(0, dtype=numer_np_dtype)

    if is_integer(expr.denominator):
        rhs = expr.denominator
    else:
        rhs = np.empty(0, dtype=denom_np_dtype)

    # Avoid dividing by a literal zero below.
    if is_integer(rhs) and rhs == 0:
        rhs = cast("int | np.integer", rhs + 1)

    if is_integer(lhs) and is_integer(rhs):
        return self.rec(lhs // rhs)

    quotient = lhs // rhs
    assert isinstance(quotient, np.ndarray)

    return [NumpyType(quotient.dtype)]
470+
471+
@override
def map_floor_div(self, expr: p.FloorDiv):
    """Infer the dtype set of a floor division.

    Delegates to :meth:`_map_int_div_modulo`.
    """
    inferred = self._map_int_div_modulo(expr)
    return inferred
474+
475+
@override
def map_remainder(self, expr: p.Remainder):
    """Infer the dtype set of a remainder (modulo) expression.

    Delegates to :meth:`_map_int_div_modulo`.
    """
    inferred = self._map_int_div_modulo(expr)
    return inferred
478+
432479
@override
433480
def map_constant(self, expr: object):
434481
if isinstance(expr, np.generic):

test/test_loopy.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3733,6 +3733,20 @@ def test_type_cast_parse_stringify_roundtrip():
37333733
assert expr == parsed
37343734

37353735

3736+
def test_floor_div_modulo_with_uint_index():
    """Generate code for ``//`` and ``%`` applied to a ``uint64`` index array.

    See <https://github.com/inducer/loopy/issues/999>.
    """
    kernel_args = [
        lp.GlobalArg("map", dtype=np.uint64, shape=lp.auto),
        lp.GlobalArg("a", dtype=np.float64, shape=(10, 4)),
    ]
    knl = lp.make_kernel(
        "{[i]: 0<=i<10}",
        "a[map[i] // 2, map[i] % 35] = i",
        kernel_args,
    )

    # Only successful code generation is checked here.
    codegen_result = lp.generate_code_v2(knl)
    codegen_result.device_code()
3748+
3749+
37363750
if __name__ == "__main__":
37373751
import sys
37383752
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)