Skip to content

Commit 9d12183

Browse files
committed
rough sketch of compute transform; inames not schedulable because of duplicates
1 parent 56af4fe commit 9d12183

File tree

1 file changed

+63
-37
lines changed

1 file changed

+63
-37
lines changed

loopy/transform/compute.py

Lines changed: 63 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
1-
from collections.abc import Sequence, Set
1+
from collections.abc import Mapping, Sequence, Set
22
from dataclasses import dataclass
33
from typing import override
44
import loopy as lp
55
from loopy.kernel.tools import DomainChanger
6+
from loopy.types import to_loopy_type
67
import namedisl as nisl
78

89
from loopy.kernel import LoopKernel
@@ -50,22 +51,6 @@ def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT):
5051
set=set_names
5152
)
5253

53-
54-
@dataclass(frozen=True)
55-
class UsageDescriptor:
56-
usage: Sequence[Expression]
57-
global_map: isl.Map
58-
local_map: isl.Map
59-
60-
@override
61-
def __str__(self):
62-
return (
63-
f"USAGE = {self.usage}\n" +
64-
f"GLOBAL MAP = {self.global_map}\n" +
65-
f"LOCAL MAP = {self.local_map}"
66-
)
67-
68-
6954
class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]):
7055
"""
7156
Gathers all expressions used as inputs to a particular substitution rule,
@@ -129,7 +114,7 @@ def __init__(
129114
ctx: SubstitutionRuleMappingContext,
130115
subst_name: str,
131116
subst_tag: Sequence[Tag] | None,
132-
usage_descriptors: Sequence[UsageDescriptor],
117+
usage_descriptors: Mapping[tuple[Expression, ...], isl.Map],
133118
storage_indices: Sequence[str],
134119
temporary_name: str,
135120
compute_insn_id: str,
@@ -141,12 +126,19 @@ def __init__(
141126
self.subst_name: str = subst_name
142127
self.subst_tag: Sequence[Tag] | None = subst_tag
143128

144-
self.usage_descriptors: Sequence[UsageDescriptor] = usage_descriptors
129+
self.usage_descriptors: Mapping[tuple[Expression, ...], isl.Map] = \
130+
usage_descriptors
145131
self.storage_indices: Sequence[str] = storage_indices
146132

147133
self.temporary_name: str = temporary_name
148134
self.compute_insn_id: str = compute_insn_id
149135

136+
# FIXME: may not always be the case (e.g. global barrier between
137+
# compute insn and uses)
138+
self.compute_dep_id: str = compute_insn_id
139+
140+
self.replaced_something: bool = False
141+
150142

151143
@override
152144
def map_subst_rule(
@@ -175,17 +167,17 @@ def map_subst_rule(
175167
"does not match the signature of {name}.")
176168

177169
index_exprs: Sequence[Expression] = []
178-
for usage_descr in self.usage_descriptors:
179-
if args == usage_descr.usage:
180-
local_pwmaff = usage_descr.local_map.as_pw_multi_aff()
181170

182-
for dim in range(local_pwmaff.dim(isl.dim_type.out)):
183-
index_exprs.append(pw_aff_to_expr(local_pwmaff.get_at(dim)))
171+
# FIXME: make self.usage_descriptors a constantdict
172+
local_pwmaff = self.usage_descriptors[tuple(args)].as_pw_multi_aff()
184173

185-
break
174+
for dim in range(local_pwmaff.dim(isl.dim_type.out)):
175+
index_exprs.append(pw_aff_to_expr(local_pwmaff.get_at(dim)))
186176

187177
new_expression = var(self.temporary_name)[tuple(index_exprs)]
188178

179+
self.replaced_something = True
180+
189181
return new_expression
190182

191183

@@ -198,8 +190,10 @@ def map_kernel(
198190
map_tvs: bool = True
199191
) -> LoopKernel:
200192

201-
new_insns = []
193+
new_insns: Sequence[lp.InstructionBase] = []
202194
for insn in kernel.instructions:
195+
self.replaced_something = False
196+
203197
if (isinstance(insn, lp.MultiAssignmentBase) and not
204198
contains_a_subst_rule_invocation(kernel, insn)):
205199
new_insns.append(insn)
@@ -209,6 +203,15 @@ def map_kernel(
209203
lambda expr: self(expr, kernel, insn)
210204
)
211205

206+
if self.replaced_something:
207+
insn = insn.copy(
208+
depends_on=(
209+
insn.depends_on | frozenset([self.compute_insn_id])
210+
)
211+
)
212+
213+
# FIXME: determine compute insn dependencies
214+
212215
new_insns.append(insn)
213216

214217
return kernel.copy(instructions=new_insns)
@@ -270,7 +273,7 @@ def compute(
270273
)
271274
)
272275

273-
usage_descrs: Sequence[UsageDescriptor] = []
276+
usage_descrs: Mapping[tuple[Expression, ...], isl.Map] = {}
274277
for usage in usage_exprs:
275278

276279
range_space = isl.Space.create_from_names(
@@ -301,20 +304,14 @@ def compute(
301304

302305
footprint = footprint | iname_to_storage.range()
303306

304-
usage_descrs.append(
305-
UsageDescriptor(
306-
usage,
307-
iname_to_storage,
308-
local_map
309-
)
310-
)
307+
usage_descrs[tuple(usage)] = local_map
311308

312309
# add compute inames to domain / kernel
313310
domain_changer = DomainChanger(kernel, kernel.all_inames())
314311
domain = domain_changer.domain
315312

316-
footprint, domain = isl.align_two(footprint, domain)
317-
domain = domain & footprint
313+
footprint_tmp, domain = isl.align_two(footprint, domain)
314+
domain = domain & footprint_tmp
318315

319316
new_domains = domain_changer.get_domains_with(domain)
320317
kernel = kernel.copy(domains=new_domains)
@@ -377,5 +374,34 @@ def compute(
377374

378375
kernel = replacer.map_kernel(kernel)
379376

380-
print(kernel)
377+
# FIXME: accept dtype as an argument
378+
import numpy as np
379+
loopy_type = to_loopy_type(np.float64, allow_none=True)
380+
381+
# WARNING: this can result in symbolic shapes; is that allowed?
382+
temp_shape = tuple(
383+
pw_aff_to_expr(footprint.dim_max(dim)) + 1
384+
for dim in range(footprint.dim(isl.dim_type.out))
385+
)
386+
387+
new_temp_vars = dict(kernel.temporary_variables)
388+
389+
# FIXME: temp_var might already exist, handle the case where it does
390+
temp_var = lp.TemporaryVariable(
391+
name=temporary_name,
392+
dtype=loopy_type,
393+
base_indices=(0,)*len(temp_shape),
394+
shape=temp_shape,
395+
address_space=temporary_address_space,
396+
dim_names=tuple(storage_indices)
397+
)
398+
399+
new_temp_vars[temporary_name] = temp_var
400+
401+
kernel = kernel.copy(
402+
temporary_variables=new_temp_vars
403+
)
404+
405+
# FIXME: handle iname tagging
406+
381407
return kernel

0 commit comments

Comments
 (0)