Skip to content

Commit 907873c

Browse files
committed
add tiled matmul example as test; islpy -> namedisl
1 parent aa8b612 commit 907873c

4 files changed

Lines changed: 152 additions & 103 deletions

File tree

examples/python/compute-examples/compute-tiled-matmul.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -43,11 +43,11 @@ def main(
4343
# guarantees for the instruction that stores into c
4444
knl = lp.fix_parameters(knl, M=M, N=N, K=K)
4545

46+
# shared memory tile-level splitting
4647
knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io")
4748
knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo")
4849
knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko")
4950

50-
# FIXME: Given the input is already tiled, we shouldn't have to supply compute bounds here.
5151
compute_map_a = nisl.make_map(f"""{{
5252
[is, ks] -> [ii_s, io, ki_s, ko] :
5353
is = io * {bm} + ii_s and
@@ -67,7 +67,8 @@ def main(
6767
compute_map=compute_map_a,
6868
storage_indices=["ii_s", "ki_s"],
6969
temporal_inames=["io", "ko", "jo"],
70-
temporary_address_space=lp.AddressSpace.LOCAL
70+
temporary_address_space=lp.AddressSpace.LOCAL,
71+
temporary_dtype=np.float64
7172
)
7273

7374
knl = compute(
@@ -76,7 +77,8 @@ def main(
7677
compute_map=compute_map_b,
7778
storage_indices=["ki_s", "ji_s"],
7879
temporal_inames=["io", "ko", "jo"],
79-
temporary_address_space=lp.AddressSpace.LOCAL
80+
temporary_address_space=lp.AddressSpace.LOCAL,
81+
temporary_dtype=np.float64
8082
)
8183

8284
if use_precompute:
@@ -116,7 +118,7 @@ def main(
116118
print(20*"=", "Tiled matmul report", 20*"=")
117119
print(f"Problem size: M = {M:-4}, N = {N:-4}, K = {K:-4}")
118120
print(f"Tile size : BM = {bm:-4}, BN = {bn:-4}, BK = {bk:-4}")
119-
print(f"Relative error = {la.norm((a @ b) - out) / la.norm(out)}")
121+
print(f"Relative error = {la.norm((a @ b) - out) / la.norm(a @ b)}")
120122
print((40 + len(" Tiled matmul report "))*"=")
121123

122124
if print_device_code:

examples/python/compute-examples/finite-difference-2-5D.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -76,11 +76,10 @@ def main(
7676
compute_map = nisl.make_map(
7777
f"""
7878
{{
79-
[is, js, ks] -> [io, ii_s, jo, ji_s, k_s] :
80-
0 <= ii_s < {bm} and 0 <= ji_s < {bn} and 0 <= k_s < {npts} and
81-
is = io * {bm} + ii_s and
82-
js = jo * {bn} + ji_s and
83-
ks = k_s
79+
[is, js, ks] -> [io, ii_s, jo, ji_s, k] :
80+
is = io * {bm} + ii_s - {r} and
81+
js = jo * {bn} + ji_s - {r} and
82+
ks = k
8483
}}
8584
"""
8685
)
@@ -89,8 +88,8 @@ def main(
8988
knl,
9089
"u_",
9190
compute_map=compute_map,
92-
storage_indices=["ii_s", "ji_s", "k_s"],
93-
temporal_inames=["io", "jo"],
91+
storage_indices=["ii_s", "ji_s"],
92+
temporal_inames=["io", "jo", "k"],
9493
temporary_name="u_compute",
9594
temporary_address_space=lp.AddressSpace.LOCAL,
9695
temporary_dtype=np.float32

loopy/transform/compute.py

Lines changed: 37 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -271,7 +271,7 @@ def compute(
271271
"""
272272
Inserts an instruction to compute an expression given by :arg:`substitution`
273273
and replaces all invocations of :arg:`substitution` with the result of the
274-
compute instruction.
274+
inserted compute instruction.
275275
276276
:arg substitution: The substitution rule for which the compute
277277
transform should be applied.
@@ -283,8 +283,6 @@ def compute(
283283
:arg storage_indices: An ordered sequence of names of storage indices. Used
284284
to create inames for the loops that cover the required set of compute points.
285285
"""
286-
# FIXME: use namedisl directly
287-
compute_map = compute_map._reconstruct_isl_object()
288286

289287
# {{{ construct necessary pieces; footprint, global usage map
290288

@@ -295,125 +293,74 @@ def compute(
295293
ctx, expander, kernel, substitution, None
296294
)
297295

296+
# add compute inames to domain / kernel
297+
domain_changer = DomainChanger(kernel, kernel.all_inames())
298+
named_domain = nisl.make_basic_set(domain_changer.domain)
299+
298300
_ = expr_gatherer.map_kernel(kernel)
299301
usage_exprs = expr_gatherer.usage_expressions
300302

301-
all_exprs = [
302-
expr
303-
for usage in usage_exprs
304-
for expr in usage
305-
]
303+
all_exprs = [expr for usage in usage_exprs for expr in usage]
304+
usage_inames = set.union(*(gather_vars(expr) for expr in all_exprs))
306305

307-
space = space_from_exprs(all_exprs)
306+
usage_domain = nisl.make_set(f"{{ [{",".join(iname for iname in usage_inames)}] }}")
307+
footprint = nisl.make_set(f"{{ [{",".join(idx for idx in storage_indices)}] }}")
308308

309-
footprint = isl.Set.empty(
310-
isl.Space.create_from_names(
311-
ctx=space.get_ctx(),
312-
set=list(storage_indices)
313-
)
309+
global_usage_map = nisl.make_map_from_domain_and_range(
310+
usage_domain,
311+
compute_map.domain()
314312
)
315313

316-
# add compute inames to domain / kernel
317-
domain_changer = DomainChanger(kernel, kernel.all_inames())
318-
domain = domain_changer.domain
319-
320-
range_space = isl.Space.create_from_names(
321-
ctx=space.get_ctx(),
322-
set=list(storage_indices)
323-
)
324-
map_space = space.map_from_domain_and_range(range_space)
325-
global_usage_map = isl.Map.empty(map_space)
314+
global_usage_map = nisl.make_map(isl.Map.empty(global_usage_map.get_space()))
326315

316+
usage_substs: Mapping[AccessTuple, nisl.Map] = {}
327317
for usage in usage_exprs:
328318

329319
# FIXME: package the conversion of a sequence of pymbolic exprs to a MultiPwAff as a function in loopy.symbolic
330-
local_usage_mpwaff = isl.MultiPwAff.zero(map_space)
320+
local_usage_mpwaff = isl.MultiPwAff.zero(global_usage_map.get_space())
331321

332322
for i in range(len(storage_indices)):
323+
local_space = local_usage_mpwaff.get_at(i).get_space().domain()
333324
local_usage_mpwaff = local_usage_mpwaff.set_pw_aff(
334325
i,
335-
pwaff_from_expr(space, usage[i])
326+
pwaff_from_expr(local_space, usage[i])
336327
)
337328

338-
local_usage_map = local_usage_mpwaff.as_map()
329+
local_usage_map = nisl.make_map(local_usage_mpwaff.as_map())
339330

340-
# FIXME: fix with namedisl
341-
# remove unnecessary names from domain and intersect with usage map
342-
usage_names = frozenset(
343-
local_usage_map.get_dim_name(isl.dim_type.in_, dim)
344-
for dim in range(local_usage_map.dim(isl.dim_type.in_))
345-
)
346-
347-
domain_names = frozenset(
348-
domain.get_dim_name(isl.dim_type.set, dim)
349-
for dim in range(domain.dim(isl.dim_type.set))
350-
)
351-
352-
domain_tmp = domain.project_out_except(
353-
usage_names & domain_names, [isl.dim_type.set]
354-
)
355-
356-
local_usage_map = align_map_domain_to_set(local_usage_map, domain_tmp)
357-
local_usage_map = local_usage_map.intersect_domain(domain_tmp)
331+
local_usage_map = local_usage_map.intersect_domain(named_domain)
358332
global_usage_map = global_usage_map | local_usage_map
359333

360-
# {{{ FIXME: this shouldn't need to be done here; will be handled by namedisl
334+
local_storage_map = local_usage_map.apply_range(compute_map)
335+
relevant_names = gather_vars(usage)
361336

362-
global_usage_map = global_usage_map.apply_range(compute_map)
363-
common_dims = {
364-
dim1 : dim2
365-
for dim1 in range(global_usage_map.dim(isl.dim_type.in_))
366-
for dim2 in range(global_usage_map.dim(isl.dim_type.out))
367-
if (
368-
global_usage_map.get_dim_name(isl.dim_type.in_, dim1)
369-
==
370-
global_usage_map.get_dim_name(isl.dim_type.out, dim2)
337+
local_storage_map = local_storage_map.project_out_except(
338+
(relevant_names - frozenset(temporal_inames)) | frozenset(storage_indices)
371339
)
372-
}
373340

374-
for pos1, pos2 in common_dims.items():
375-
global_usage_map = global_usage_map.equate(
376-
isl.dim_type.in_, pos1,
377-
isl.dim_type.out, pos2
378-
)
341+
usage_substs[tuple(usage)] = local_storage_map
379342

380-
# }}}
343+
global_usage_map = global_usage_map.apply_range(compute_map)
381344

382345
# }}}
383346

384347
# {{{ compute bounds and update kernel domain
385348

386349
footprint = global_usage_map.range()
387-
footprint_tmp, domain = isl.align_two(footprint, domain)
388-
domain = (domain & footprint_tmp).get_basic_sets()[0]
389-
390-
new_domains = domain_changer.get_domains_with(domain)
391-
kernel = kernel.copy(domains=new_domains)
392-
393-
# }}}
394-
395-
# {{{ compute index expressions
396-
397-
usage_substs: Mapping[AccessTuple, isl.Map] = {}
398-
for usage in usage_exprs:
399-
# find the relevant names
400-
relevant_names = gather_vars(usage)
401-
402-
# project out irrelevant names
403-
relevant_names = frozenset(relevant_names) - frozenset(temporal_inames)
350+
footprint = footprint.project_out_except(
351+
frozenset(temporal_inames) | frozenset(storage_indices)
352+
)
404353

405-
local_iname_to_storage = global_usage_map.project_out_except(
406-
relevant_names,
407-
[isl.dim_type.in_]
408-
)
354+
# FIXME: taking the convex hull may over-approximate the footprint; probably do not want this permanently here
355+
footprint = nisl.make_set(footprint._reconstruct_isl_object().convex_hull())
356+
named_domain = named_domain & footprint
409357

410-
local_iname_to_storage = local_iname_to_storage.project_out_except(
411-
storage_indices,
412-
[isl.dim_type.out]
413-
)
358+
# FIXME: support new domains that are not composed of a single basic set
359+
if len(named_domain.get_basic_sets()) != 1:
360+
raise ValueError("New domain should be composed of a single basic set")
414361

415-
# map usage -> resulting map
416-
usage_substs[tuple(usage)] = local_iname_to_storage
362+
new_domains = domain_changer.get_domains_with(named_domain.get_basic_sets()[0])
363+
kernel = kernel.copy(domains=new_domains)
417364

418365
# }}}
419366

@@ -487,9 +434,7 @@ def compute(
487434
loopy_type = to_loopy_type(temporary_dtype, allow_none=True)
488435

489436
# FIXME: fix with namedisl?
490-
shape_domain = footprint.project_out_except(storage_indices,
491-
[isl.dim_type.set])
492-
shape_domain = shape_domain.project_out_except("", [isl.dim_type.param])
437+
shape_domain = footprint.project_out_except(storage_indices)
493438

494439
temp_shape = tuple(
495440
pw_aff_to_expr(shape_domain.dim_max(dim)) + 1

test/test_transform.py

Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
THE SOFTWARE.
2121
"""
2222

23+
from collections.abc import Mapping
2324
import logging
2425

2526
import numpy as np
@@ -1745,6 +1746,108 @@ def test_duplicate_iname_not_read_only_nested(ctx_factory: cl.CtxFactory):
17451746
lp.auto_test_vs_ref(ref_t_unit, ctx, t_unit)
17461747

17471748

1749+
@pytest.mark.parametrize("case", (
    {"M": 128, "N": 128, "K": 128, "BM": 32, "BN": 32, "BK": 16},
    {"M": 200, "N": 200, "K": 200, "BM": 32, "BN": 32, "BK": 16},
))
def test_compute_simple_tiled_matmul(
    ctx_factory: cl.CtxFactory,
    case: Mapping[str, int]
):
    """Check a tiled matmul that stages tiles of ``a`` and ``b`` via ``compute``.

    The kernel ``c[i, j] = sum_k a[i, k] * b[k, j]`` is split into
    ``BM x BN x BK`` tiles, the ``a_`` and ``b_`` substitution rules are
    turned into LOCAL temporaries with
    :func:`loopy.transform.compute.compute`, and the result is compared
    against ``a @ b`` computed by NumPy.

    :arg ctx_factory: factory returning the OpenCL context to run on.
    :arg case: problem sizes ``M``, ``N``, ``K`` and tile sizes ``BM``,
        ``BN``, ``BK``. The second parametrization uses sizes (200) that
        are not multiples of the tile sizes (32 and 16).
    """
    import numpy.linalg as la

    import namedisl as nisl

    from loopy.transform.compute import compute

    M = case["M"]
    N = case["N"]
    K = case["K"]
    bm = case["BM"]
    bn = case["BN"]
    bk = case["BK"]

    knl = lp.make_kernel(
        "{ [i, j, k] : 0 <= i < M and 0 <= j < N and 0 <= k < K }",
        """
        a_(is, ks) := a[is, ks]
        b_(ks, js) := b[ks, js]
        c[i, j] = sum([k], a_(i, k) * b_(k, j))
        """,
        [
            lp.GlobalArg("a", shape=(M, K), dtype=np.float64),
            lp.GlobalArg("b", shape=(K, N), dtype=np.float64),
            lp.GlobalArg("c", shape=(M, N), dtype=np.float64,
                         is_output=True)
        ]
    )

    knl = lp.fix_parameters(knl, M=M, N=N, K=K)

    # shared memory tile-level splitting
    knl = lp.split_iname(knl, "i", bm, inner_iname="ii", outer_iname="io")
    knl = lp.split_iname(knl, "j", bn, inner_iname="ji", outer_iname="jo")
    knl = lp.split_iname(knl, "k", bk, inner_iname="ki", outer_iname="ko")

    # Relate the substitution-rule arguments to (tile, offset-in-tile) pairs.
    compute_map_a = nisl.make_map(f"""{{
        [is, ks] -> [ii_s, io, ki_s, ko] :
            is = io * {bm} + ii_s and
            ks = ko * {bk} + ki_s
    }}""")

    compute_map_b = nisl.make_map(f"""{{
        [ks, js] -> [ki_s, ko, ji_s, jo] :
            js = jo * {bn} + ji_s and
            ks = ko * {bk} + ki_s
    }}""")

    knl = compute(
        knl,
        "a_",
        compute_map=compute_map_a,
        storage_indices=["ii_s", "ki_s"],
        temporal_inames=["io", "ko", "jo"],
        temporary_address_space=lp.AddressSpace.LOCAL,
        temporary_dtype=np.float64
    )

    knl = compute(
        knl,
        "b_",
        compute_map=compute_map_b,
        storage_indices=["ki_s", "ji_s"],
        temporal_inames=["io", "ko", "jo"],
        temporary_address_space=lp.AddressSpace.LOCAL,
        temporary_dtype=np.float64
    )

    knl = lp.tag_inames(
        knl, {
            "io": "g.0",    # outer block loop over block rows
            "jo": "g.1",    # outer block loop over block cols

            "ii": "l.0",    # inner block loop over rows
            "ji": "l.1",    # inner block loop over cols

            "ii_s": "l.0",  # inner storage loop over a rows
            "ji_s": "l.0",  # inner storage loop over b cols
            "ki_s": "l.1"   # inner storage loop over a cols / b rows
        }
    )

    knl = lp.add_inames_for_unused_hw_axes(knl)

    ctx = ctx_factory()
    queue = cl.CommandQueue(ctx)

    # Seed the inputs so that a failure can be reproduced.
    rng = np.random.default_rng(seed=907873)
    a = rng.standard_normal((M, K))
    b = rng.standard_normal((K, N))

    ex = knl.executor(ctx)
    _, out = ex(queue, a=a, b=b)

    ref = a @ b
    rel_err = la.norm(ref - out) / la.norm(ref)

    # The kernel and NumPy may accumulate the K products in different
    # orders, so the two results can legitimately differ by rounding error
    # that grows with K. A fixed 1e-15 threshold is only a few ulps for
    # float64 and can fail on a correct kernel; scale the tolerance with K.
    tol = 10 * K * np.finfo(np.float64).eps
    assert rel_err < tol, f"relative error {rel_err} exceeds {tol}"
1849+
1850+
17481851
if __name__ == "__main__":
17491852
import sys
17501853
if len(sys.argv) > 1:

0 commit comments

Comments
 (0)