1+ from collections .abc import Sequence , Set
2+ from dataclasses import dataclass
3+ from typing import override
14import loopy as lp
25from loopy .kernel .tools import DomainChanger
36import namedisl as nisl
47
58from loopy .kernel import LoopKernel
6- from loopy .kernel .data import AddressSpace
7- from loopy .match import parse_stack_match
9+ from loopy .kernel .data import AddressSpace , SubstitutionRule
10+ from loopy .match import StackMatch , parse_stack_match
811from loopy .symbolic import (
12+ ExpansionState ,
913 RuleAwareIdentityMapper ,
1014 RuleAwareSubstitutionMapper ,
15+ SubstitutionRuleExpander ,
1116 SubstitutionRuleMappingContext ,
17+ get_dependencies ,
1218 pw_aff_to_expr ,
1319 pwaff_from_expr
1420)
1521from loopy .transform .precompute import (
16- RuleInvocationGatherer ,
1722 contains_a_subst_rule_invocation
1823)
1924from loopy .translation_unit import for_each_kernel
20- from pymbolic import var
25+ from pymbolic import ArithmeticExpression , var
2126from pymbolic .mapper .substitutor import make_subst_func
2227
2328import islpy as isl
2429import pymbolic .primitives as p
2530from pymbolic .mapper .dependency import DependencyMapper
26-
27- from pymbolic . mapper import IdentityMapper
31+ from pymbolic . typing import Expression
32+ from pytools . tag import Tag
2833
2934
3035def gather_vars (expr ) -> set [str ]:
@@ -35,6 +40,7 @@ def gather_vars(expr) -> set[str]:
3540 if isinstance (dep , p .Variable )
3641 }
3742
43+
3844def space_from_exprs (exprs , ctx = isl .DEFAULT_CONTEXT ):
3945 names = sorted (set ().union (* (gather_vars (expr ) for expr in exprs )))
4046 set_names = [name for name in names ]
@@ -44,12 +50,181 @@ def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT):
4450 set = set_names
4551 )
4652
53+
54+ @dataclass (frozen = True )
55+ class UsageDescriptor :
56+ usage : Sequence [Expression ]
57+ global_map : isl .Map
58+ local_map : isl .Map
59+
60+ @override
61+ def __str__ (self ):
62+ return (
63+ f"USAGE = { self .usage } \n " +
64+ f"GLOBAL MAP = { self .global_map } \n " +
65+ f"LOCAL MAP = { self .local_map } "
66+ )
67+
68+
class UsageSiteExpressionGatherer(RuleAwareIdentityMapper[[]]):
    """
    Records the argument expressions of every invocation of one substitution
    rule. The rule is selected by name and, optionally, by its tags.

    After the mapper has been applied, :attr:`usage_expressions` holds one
    entry per matching invocation; each entry lists the invocation's
    arguments in the order of the rule's formal arguments.
    """
    def __init__(
                self,
                rule_mapping_ctx: SubstitutionRuleMappingContext,
                subst_expander: SubstitutionRuleExpander,
                kernel: LoopKernel,
                subst_name: str,
                subst_tag: Set[Tag] | Tag | None = None
            ) -> None:

        super().__init__(rule_mapping_ctx)

        # A lone tag is wrapped in a set so that it can be compared against
        # the tag sets handed to map_subst_rule.
        normalized_tag: Set[Tag] | None
        if isinstance(subst_tag, Tag):
            normalized_tag = {subst_tag}
        else:
            normalized_tag = subst_tag

        self.subst_expander: SubstitutionRuleExpander = subst_expander
        self.kernel: LoopKernel = kernel
        self.subst_name: str = subst_name
        self.subst_tag: Set[Tag] | None = normalized_tag

        # Filled by map_subst_rule, one entry per matching invocation.
        self.usage_expressions: list[Sequence[Expression]] = []

    @override
    def map_subst_rule(
                self,
                name: str,
                tags: Set[Tag] | None,
                arguments: Sequence[Expression],
                expn_state: ExpansionState,
            ) -> Expression:
        """
        Record the arguments of a matching invocation; defer every other
        invocation to the base class.

        :returns: ``0`` for a recorded invocation, otherwise whatever the
            base class returns.
        """
        tag_mismatch = self.subst_tag is not None and self.subst_tag != tags
        if name != self.subst_name or tag_mismatch:
            return super().map_subst_rule(name, tags, arguments, expn_state)

        target_rule = self.rule_mapping_context.old_subst_rules[name]
        formal_names = target_rule.arguments
        bound_args = self.make_new_arg_context(
            name, formal_names, arguments, expn_state.arg_context
        )

        recorded = [bound_args[formal] for formal in formal_names]
        self.usage_expressions.append(recorded)

        # The mapped expression is a placeholder; only the recorded
        # arguments are of interest.
        return 0
124+
125+
class RuleInvocationReplacer(RuleAwareIdentityMapper[[]]):
    """
    Rewrites invocations of one substitution rule into subscripts of a
    temporary variable.

    For each invocation of *subst_name* whose arguments equal the ``usage``
    of one of *usage_descriptors*, the invocation is replaced by
    ``temporary_name[idx_0, ..., idx_n]``, where the index expressions are
    read off the output dimensions of that descriptor's ``local_map``.

    Invocations of other rules, invocations filtered out by *subst_tag*, and
    invocations with no matching usage descriptor are handed to the base
    class unchanged.

    NOTE(review): ``compute_insn_id`` is stored but not used by this class;
    rewritten instructions do not receive a dependency on it. ``compute_map``
    is accepted but not stored. Confirm whether either should be wired in.
    """
    def __init__(
                self,
                ctx: SubstitutionRuleMappingContext,
                subst_name: str,
                subst_tag: Sequence[Tag] | None,
                usage_descriptors: Sequence[UsageDescriptor],
                storage_indices: Sequence[str],
                temporary_name: str,
                compute_insn_id: str,
                compute_map: isl.Map
            ) -> None:

        super().__init__(ctx)

        self.subst_name: str = subst_name
        self.subst_tag: Sequence[Tag] | None = subst_tag

        self.usage_descriptors: Sequence[UsageDescriptor] = usage_descriptors
        self.storage_indices: Sequence[str] = storage_indices

        self.temporary_name: str = temporary_name
        self.compute_insn_id: str = compute_insn_id

    @override
    def map_subst_rule(
                self,
                name: str,
                tags: Set[Tag] | None,
                arguments: Sequence[Expression],
                expn_state: ExpansionState
            ) -> Expression:
        """
        Replace a matching invocation by a subscript of the temporary.

        :raises ValueError: if the number of passed arguments differs from
            the number of formal arguments of the rule.
        """
        if name != self.subst_name:
            return super().map_subst_rule(name, tags, arguments, expn_state)

        # Honor the tag filter: an invocation is only eligible when its tags
        # equal the requested ones. With no filter every invocation of the
        # rule is eligible.
        if self.subst_tag is not None and (
                tags is None
                or frozenset(self.subst_tag) != frozenset(tags)):
            return super().map_subst_rule(name, tags, arguments, expn_state)

        rule = self.rule_mapping_context.old_subst_rules[name]
        arg_ctx = self.make_new_arg_context(
            name, rule.arguments, arguments, expn_state.arg_context
        )
        args = [arg_ctx[arg_name] for arg_name in rule.arguments]

        # FIXME: footprint check? likely required if user supplies bounds on
        # storage indices because we are not guaranteed to capture the
        # footprint of all usage sites

        if len(arguments) != len(rule.arguments):
            # A single interpolated message; passing two un-prefixed strings
            # separated by a comma would leave "{name}" verbatim.
            raise ValueError(
                f"Number of arguments passed to rule {name} "
                f"does not match the signature of {name}."
            )

        for usage_descr in self.usage_descriptors:
            # Normalise to a list so that a usage stored as a tuple still
            # compares equal to the freshly built argument list.
            if args != list(usage_descr.usage):
                continue

            local_pwmaff = usage_descr.local_map.as_pw_multi_aff()
            index_exprs: Sequence[Expression] = [
                pw_aff_to_expr(local_pwmaff.get_at(dim))
                for dim in range(local_pwmaff.dim(isl.dim_type.out))
            ]
            return var(self.temporary_name)[tuple(index_exprs)]

        # No recorded usage matches this invocation: leave it untouched
        # instead of emitting a subscript with an empty index tuple.
        return super().map_subst_rule(name, tags, arguments, expn_state)

    @override
    def map_kernel(
                self,
                kernel: LoopKernel,
                within: StackMatch = lambda knl, insn, stack: True,
                map_args: bool = True,
                map_tvs: bool = True
            ) -> LoopKernel:
        """
        Apply the replacement to the instructions of *kernel* and return a
        copy of the kernel with the rewritten instruction list.

        *within*, *map_args* and *map_tvs* are accepted for signature
        compatibility with the base class; this implementation does not
        consult them.
        """
        new_insns = []
        for insn in kernel.instructions:
            # Assignments without any substitution rule invocation are
            # carried over as-is.
            if (isinstance(insn, lp.MultiAssignmentBase) and not
                    contains_a_subst_rule_invocation(kernel, insn)):
                new_insns.append(insn)
                continue

            # The lambda is invoked before `insn` is rebound, so it sees the
            # untransformed instruction.
            insn = insn.with_transformed_expressions(
                lambda expr: self(expr, kernel, insn)
            )

            new_insns.append(insn)

        return kernel.copy(instructions=new_insns)
215+
216+
47217@for_each_kernel
48218def compute (
49219 kernel : LoopKernel ,
50220 substitution : str ,
51221 compute_map : nisl .Map ,
52- storage_indices : frozenset [str ],
222+ storage_indices : Sequence [str ],
223+
224+ # NOTE: how can we deduce this?
225+ temporal_inames : Sequence [str ],
226+
227+ temporary_name : str | None = None ,
53228 temporary_address_space : AddressSpace | None = None
54229 ) -> LoopKernel :
55230 """
@@ -64,65 +239,76 @@ def compute(
64239 substitution rule indices and tuples `(a, l)`, where `a` is a vector of
65240 storage indices and `l` is a vector of "timestamps".
66241
67- :arg storage_indices: A :class:`frozenset` of names of storage indices. Used
68- to create inames for the loops that cover the required footprint .
242+ :arg storage_indices: An ordered sequence of names of storage indices. Used
243+ to create inames for the loops that cover the required set of compute points .
69244 """
70245 compute_map = compute_map ._reconstruct_isl_object ()
71246
72247 # construct union of usage footprints to determine bounds on compute inames
73248 ctx = SubstitutionRuleMappingContext (
74249 kernel .substitutions , kernel .get_var_name_generator ())
75- inv_gatherer = RuleInvocationGatherer (
76- ctx , kernel , substitution , None , parse_stack_match (None )
250+ expander = SubstitutionRuleExpander (kernel .substitutions )
251+ expr_gatherer = UsageSiteExpressionGatherer (
252+ ctx , expander , kernel , substitution , None
77253 )
78254
79- for insn in kernel .instructions :
80- if (isinstance (insn , lp .MultiAssignmentBase ) and
81- contains_a_subst_rule_invocation (kernel , insn )):
82- for assignee in insn .assignees :
83- _ = inv_gatherer (assignee , kernel , insn )
84- _ = inv_gatherer (insn .expression , kernel , insn )
255+ _ = expr_gatherer .map_kernel (kernel )
256+ usage_exprs = expr_gatherer .usage_expressions
85257
86- access_descriptors = inv_gatherer .access_descriptors
87-
88- acc_desc_exprs = [
89- arg
90- for ad in access_descriptors
91- if ad .args is not None
92- for arg in ad .args
258+ all_exprs = [
259+ expr
260+ for usage in usage_exprs
261+ for expr in usage
93262 ]
94263
95- space = space_from_exprs (acc_desc_exprs )
264+ space = space_from_exprs (all_exprs )
96265
97- footprint = isl .Set .empty (isl .Space .create_from_names (
98- ctx = space .get_ctx (),
99- set = list (storage_indices )
100- ))
101- for ad in access_descriptors :
102- if not ad .args :
103- continue
266+ footprint = isl .Set .empty (
267+ isl .Space .create_from_names (
268+ ctx = space .get_ctx (),
269+ set = list (storage_indices )
270+ )
271+ )
104272
105- nout = len (ad .args )
273+ usage_descrs : Sequence [UsageDescriptor ] = []
274+ for usage in usage_exprs :
106275
107- range_space = isl .Space .alloc (space .get_ctx (), 0 , nout , 0 ).domain ()
276+ range_space = isl .Space .create_from_names (
277+ ctx = space .get_ctx (),
278+ set = list (storage_indices )
279+ )
108280 map_space = space .map_from_domain_and_range (range_space )
281+
109282 pw_multi_aff = isl .MultiPwAff .zero (map_space )
110283
111- for i , arg in enumerate (ad .args ):
112- if arg is not None :
113- pw_multi_aff = pw_multi_aff .set_pw_aff (
114- i ,
115- pwaff_from_expr (space , arg )
116- )
284+ for i , arg in enumerate (usage ):
285+ pw_multi_aff = pw_multi_aff .set_pw_aff (
286+ i ,
287+ pwaff_from_expr (space , arg )
288+ )
117289
118290 usage_map = pw_multi_aff .as_map ()
119- iname_to_timespace = usage_map .apply_range (compute_map ).coalesce ()
291+
292+ iname_to_timespace = usage_map .apply_range (compute_map )
120293 iname_to_storage = iname_to_timespace .project_out_except (
121294 storage_indices , [isl .dim_type .out ]
122295 )
123296
297+ local_map = iname_to_storage .project_out_except (
298+ kernel .all_inames () - frozenset (temporal_inames ),
299+ [isl .dim_type .in_ ]
300+ )
301+
124302 footprint = footprint | iname_to_storage .range ()
125303
304+ usage_descrs .append (
305+ UsageDescriptor (
306+ usage ,
307+ iname_to_storage ,
308+ local_map
309+ )
310+ )
311+
126312 # add compute inames to domain / kernel
127313 domain_changer = DomainChanger (kernel , kernel .all_inames ())
128314 domain = domain_changer .domain
@@ -138,7 +324,7 @@ def compute(
138324 storage_ax_to_global_expr = {
139325 compute_pw_aff .get_dim_name (isl .dim_type .out , dim ) :
140326 pw_aff_to_expr (compute_pw_aff .get_at (dim ))
141- for dim in range (compute_pw_aff .dim (isl .dim_type .out ))
327+ for dim in range (compute_pw_aff .dim (isl .dim_type .out ))
142328 }
143329
144330 expr_subst_map = RuleAwareSubstitutionMapper (
@@ -150,7 +336,9 @@ def compute(
150336 subst_expr = kernel .substitutions [substitution ].expression
151337 compute_expression = expr_subst_map (subst_expr , kernel , None )
152338
153- temporary_name = substitution + "_temp"
339+ if not temporary_name :
340+ temporary_name = substitution + "_temp"
341+
154342 assignee = var (temporary_name )[tuple (
155343 var (iname ) for iname in storage_indices
156344 )]
@@ -172,5 +360,22 @@ def compute(
172360 new_insns .append (compute_insn )
173361 kernel = kernel .copy (instructions = new_insns )
174362
363+ ctx = SubstitutionRuleMappingContext (
364+ kernel .substitutions , kernel .get_var_name_generator ()
365+ )
366+
367+ replacer = RuleInvocationReplacer (
368+ ctx ,
369+ substitution ,
370+ None ,
371+ usage_descrs ,
372+ storage_indices ,
373+ temporary_name ,
374+ compute_insn_id ,
375+ compute_map
376+ )
377+
378+ kernel = replacer .map_kernel (kernel )
379+
175380 print (kernel )
176381 return kernel
0 commit comments