Skip to content

Commit 56af4fe

Browse files
committed
invocation replacement; dependencies still need handling
1 parent 7265536 commit 56af4fe

File tree

1 file changed

+248
-43
lines changed

1 file changed

+248
-43
lines changed

loopy/transform/compute.py

Lines changed: 248 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -1,30 +1,35 @@
1+
from collections.abc import Sequence, Set
2+
from dataclasses import dataclass
3+
from typing import override
14
import loopy as lp
25
from loopy.kernel.tools import DomainChanger
36
import namedisl as nisl
47

58
from loopy.kernel import LoopKernel
6-
from loopy.kernel.data import AddressSpace
7-
from loopy.match import parse_stack_match
9+
from loopy.kernel.data import AddressSpace, SubstitutionRule
10+
from loopy.match import StackMatch, parse_stack_match
811
from loopy.symbolic import (
12+
ExpansionState,
913
RuleAwareIdentityMapper,
1014
RuleAwareSubstitutionMapper,
15+
SubstitutionRuleExpander,
1116
SubstitutionRuleMappingContext,
17+
get_dependencies,
1218
pw_aff_to_expr,
1319
pwaff_from_expr
1420
)
1521
from loopy.transform.precompute import (
16-
RuleInvocationGatherer,
1722
contains_a_subst_rule_invocation
1823
)
1924
from loopy.translation_unit import for_each_kernel
20-
from pymbolic import var
25+
from pymbolic import ArithmeticExpression, var
2126
from pymbolic.mapper.substitutor import make_subst_func
2227

2328
import islpy as isl
2429
import pymbolic.primitives as p
2530
from pymbolic.mapper.dependency import DependencyMapper
26-
27-
from pymbolic.mapper import IdentityMapper
31+
from pymbolic.typing import Expression
32+
from pytools.tag import Tag
2833

2934

3035
def gather_vars(expr) -> set[str]:
@@ -35,6 +40,7 @@ def gather_vars(expr) -> set[str]:
3540
if isinstance(dep, p.Variable)
3641
}
3742

43+
3844
def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT):
3945
names = sorted(set().union(*(gather_vars(expr) for expr in exprs)))
4046
set_names = [name for name in names]
@@ -44,12 +50,181 @@ def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT):
4450
set=set_names
4551
)
4652

53+
54+
@dataclass(frozen=True)
55+
class UsageDescriptor:
56+
usage: Sequence[Expression]
57+
global_map: isl.Map
58+
local_map: isl.Map
59+
60+
@override
61+
def __str__(self):
62+
return (
63+
f"USAGE = {self.usage}\n" +
64+
f"GLOBAL MAP = {self.global_map}\n" +
65+
f"LOCAL MAP = {self.local_map}"
66+
)
67+
68+
69+
class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]):
70+
"""
71+
Gathers all expressions used as inputs to a particular substitution rule,
72+
identified by name.
73+
"""
74+
def __init__(
75+
self,
76+
rule_mapping_ctx: SubstitutionRuleMappingContext,
77+
subst_expander: SubstitutionRuleExpander,
78+
kernel: LoopKernel,
79+
subst_name: str,
80+
subst_tag: Set[Tag] | Tag | None = None
81+
) -> None:
82+
83+
super().__init__(rule_mapping_ctx)
84+
85+
self.subst_expander: SubstitutionRuleExpander = subst_expander
86+
self.kernel: LoopKernel = kernel
87+
self.subst_name: str = subst_name
88+
self.subst_tag: Set[Tag] | None = (
89+
{subst_tag} if isinstance(subst_tag, Tag) else subst_tag
90+
)
91+
92+
self.usage_expressions: list[Sequence[Expression]] = []
93+
94+
95+
@override
96+
def map_subst_rule(
97+
self,
98+
name: str,
99+
tags: Set[Tag] | None,
100+
arguments: Sequence[Expression],
101+
expn_state: ExpansionState,
102+
) -> Expression:
103+
104+
if name != self.subst_name:
105+
return super().map_subst_rule(
106+
name, tags, arguments, expn_state
107+
)
108+
109+
if self.subst_tag is not None and self.subst_tag != tags:
110+
return super().map_subst_rule(
111+
name, tags, arguments, expn_state
112+
)
113+
114+
rule = self.rule_mapping_context.old_subst_rules[name]
115+
arg_ctx = self.make_new_arg_context(
116+
name, rule.arguments, arguments, expn_state.arg_context
117+
)
118+
119+
self.usage_expressions.append([
120+
arg_ctx[arg_name] for arg_name in rule.arguments
121+
])
122+
123+
return 0
124+
125+
126+
class RuleInvocationReplacer(RuleAwareIdentityMapper[[]]):
127+
def __init__(
128+
self,
129+
ctx: SubstitutionRuleMappingContext,
130+
subst_name: str,
131+
subst_tag: Sequence[Tag] | None,
132+
usage_descriptors: Sequence[UsageDescriptor],
133+
storage_indices: Sequence[str],
134+
temporary_name: str,
135+
compute_insn_id: str,
136+
compute_map: isl.Map
137+
) -> None:
138+
139+
super().__init__(ctx)
140+
141+
self.subst_name: str = subst_name
142+
self.subst_tag: Sequence[Tag] | None = subst_tag
143+
144+
self.usage_descriptors: Sequence[UsageDescriptor] = usage_descriptors
145+
self.storage_indices: Sequence[str] = storage_indices
146+
147+
self.temporary_name: str = temporary_name
148+
self.compute_insn_id: str = compute_insn_id
149+
150+
151+
@override
152+
def map_subst_rule(
153+
self,
154+
name: str,
155+
tags: Set[Tag] | None,
156+
arguments: Sequence[Expression],
157+
expn_state: ExpansionState
158+
) -> Expression:
159+
160+
if not name == self.subst_name:
161+
return super().map_subst_rule(name, tags, arguments, expn_state)
162+
163+
rule = self.rule_mapping_context.old_subst_rules[name]
164+
arg_ctx = self.make_new_arg_context(
165+
name, rule.arguments, arguments, expn_state.arg_context
166+
)
167+
args = [arg_ctx[arg_name] for arg_name in rule.arguments]
168+
169+
# FIXME: footprint check? likely required if user supplies bounds on
170+
# storage indices because we are not guaranteed to capture the footprint
171+
# of all usage sites
172+
173+
if not len(arguments) == len(rule.arguments):
174+
raise ValueError("Number of arguments passed to rule {name} ",
175+
"does not match the signature of {name}.")
176+
177+
index_exprs: Sequence[Expression] = []
178+
for usage_descr in self.usage_descriptors:
179+
if args == usage_descr.usage:
180+
local_pwmaff = usage_descr.local_map.as_pw_multi_aff()
181+
182+
for dim in range(local_pwmaff.dim(isl.dim_type.out)):
183+
index_exprs.append(pw_aff_to_expr(local_pwmaff.get_at(dim)))
184+
185+
break
186+
187+
new_expression = var(self.temporary_name)[tuple(index_exprs)]
188+
189+
return new_expression
190+
191+
192+
@override
193+
def map_kernel(
194+
self,
195+
kernel: LoopKernel,
196+
within: StackMatch = lambda knl, insn, stack: True,
197+
map_args: bool = True,
198+
map_tvs: bool = True
199+
) -> LoopKernel:
200+
201+
new_insns = []
202+
for insn in kernel.instructions:
203+
if (isinstance(insn, lp.MultiAssignmentBase) and not
204+
contains_a_subst_rule_invocation(kernel, insn)):
205+
new_insns.append(insn)
206+
continue
207+
208+
insn = insn.with_transformed_expressions(
209+
lambda expr: self(expr, kernel, insn)
210+
)
211+
212+
new_insns.append(insn)
213+
214+
return kernel.copy(instructions=new_insns)
215+
216+
47217
@for_each_kernel
48218
def compute(
49219
kernel: LoopKernel,
50220
substitution: str,
51221
compute_map: nisl.Map,
52-
storage_indices: frozenset[str],
222+
storage_indices: Sequence[str],
223+
224+
# NOTE: how can we deduce this?
225+
temporal_inames: Sequence[str],
226+
227+
temporary_name: str | None = None,
53228
temporary_address_space: AddressSpace | None = None
54229
) -> LoopKernel:
55230
"""
@@ -64,65 +239,76 @@ def compute(
64239
substitution rule indices and tuples `(a, l)`, where `a` is a vector of
65240
storage indices and `l` is a vector of "timestamps".
66241
67-
:arg storage_indices: A :class:`frozenset` of names of storage indices. Used
68-
to create inames for the loops that cover the required footprint.
242+
:arg storage_indices: An ordered sequence of names of storage indices. Used
243+
to create inames for the loops that cover the required set of compute points.
69244
"""
70245
compute_map = compute_map._reconstruct_isl_object()
71246

72247
# construct union of usage footprints to determine bounds on compute inames
73248
ctx = SubstitutionRuleMappingContext(
74249
kernel.substitutions, kernel.get_var_name_generator())
75-
inv_gatherer = RuleInvocationGatherer(
76-
ctx, kernel, substitution, None, parse_stack_match(None)
250+
expander = SubstitutionRuleExpander(kernel.substitutions)
251+
expr_gatherer = UsageSiteExpressionGatherer(
252+
ctx, expander, kernel, substitution, None
77253
)
78254

79-
for insn in kernel.instructions:
80-
if (isinstance(insn, lp.MultiAssignmentBase) and
81-
contains_a_subst_rule_invocation(kernel, insn)):
82-
for assignee in insn.assignees:
83-
_ = inv_gatherer(assignee, kernel, insn)
84-
_ = inv_gatherer(insn.expression, kernel, insn)
255+
_ = expr_gatherer.map_kernel(kernel)
256+
usage_exprs = expr_gatherer.usage_expressions
85257

86-
access_descriptors = inv_gatherer.access_descriptors
87-
88-
acc_desc_exprs = [
89-
arg
90-
for ad in access_descriptors
91-
if ad.args is not None
92-
for arg in ad.args
258+
all_exprs = [
259+
expr
260+
for usage in usage_exprs
261+
for expr in usage
93262
]
94263

95-
space = space_from_exprs(acc_desc_exprs)
264+
space = space_from_exprs(all_exprs)
96265

97-
footprint = isl.Set.empty(isl.Space.create_from_names(
98-
ctx=space.get_ctx(),
99-
set=list(storage_indices)
100-
))
101-
for ad in access_descriptors:
102-
if not ad.args:
103-
continue
266+
footprint = isl.Set.empty(
267+
isl.Space.create_from_names(
268+
ctx=space.get_ctx(),
269+
set=list(storage_indices)
270+
)
271+
)
104272

105-
nout = len(ad.args)
273+
usage_descrs: Sequence[UsageDescriptor] = []
274+
for usage in usage_exprs:
106275

107-
range_space = isl.Space.alloc(space.get_ctx(), 0, nout, 0).domain()
276+
range_space = isl.Space.create_from_names(
277+
ctx=space.get_ctx(),
278+
set=list(storage_indices)
279+
)
108280
map_space = space.map_from_domain_and_range(range_space)
281+
109282
pw_multi_aff = isl.MultiPwAff.zero(map_space)
110283

111-
for i, arg in enumerate(ad.args):
112-
if arg is not None:
113-
pw_multi_aff = pw_multi_aff.set_pw_aff(
114-
i,
115-
pwaff_from_expr(space, arg)
116-
)
284+
for i, arg in enumerate(usage):
285+
pw_multi_aff = pw_multi_aff.set_pw_aff(
286+
i,
287+
pwaff_from_expr(space, arg)
288+
)
117289

118290
usage_map = pw_multi_aff.as_map()
119-
iname_to_timespace = usage_map.apply_range(compute_map).coalesce()
291+
292+
iname_to_timespace = usage_map.apply_range(compute_map)
120293
iname_to_storage = iname_to_timespace.project_out_except(
121294
storage_indices, [isl.dim_type.out]
122295
)
123296

297+
local_map = iname_to_storage.project_out_except(
298+
kernel.all_inames() - frozenset(temporal_inames),
299+
[isl.dim_type.in_]
300+
)
301+
124302
footprint = footprint | iname_to_storage.range()
125303

304+
usage_descrs.append(
305+
UsageDescriptor(
306+
usage,
307+
iname_to_storage,
308+
local_map
309+
)
310+
)
311+
126312
# add compute inames to domain / kernel
127313
domain_changer = DomainChanger(kernel, kernel.all_inames())
128314
domain = domain_changer.domain
@@ -138,7 +324,7 @@ def compute(
138324
storage_ax_to_global_expr = {
139325
compute_pw_aff.get_dim_name(isl.dim_type.out, dim) :
140326
pw_aff_to_expr(compute_pw_aff.get_at(dim))
141-
for dim in range(compute_pw_aff.dim(isl.dim_type.out))
327+
for dim in range(compute_pw_aff.dim(isl.dim_type.out))
142328
}
143329

144330
expr_subst_map = RuleAwareSubstitutionMapper(
@@ -150,7 +336,9 @@ def compute(
150336
subst_expr = kernel.substitutions[substitution].expression
151337
compute_expression = expr_subst_map(subst_expr, kernel, None)
152338

153-
temporary_name = substitution + "_temp"
339+
if not temporary_name:
340+
temporary_name = substitution + "_temp"
341+
154342
assignee = var(temporary_name)[tuple(
155343
var(iname) for iname in storage_indices
156344
)]
@@ -172,5 +360,22 @@ def compute(
172360
new_insns.append(compute_insn)
173361
kernel = kernel.copy(instructions=new_insns)
174362

363+
ctx = SubstitutionRuleMappingContext(
364+
kernel.substitutions, kernel.get_var_name_generator()
365+
)
366+
367+
replacer = RuleInvocationReplacer(
368+
ctx,
369+
substitution,
370+
None,
371+
usage_descrs,
372+
storage_indices,
373+
temporary_name,
374+
compute_insn_id,
375+
compute_map
376+
)
377+
378+
kernel = replacer.map_kernel(kernel)
379+
175380
print(kernel)
176381
return kernel

0 commit comments

Comments
 (0)