1- from collections .abc import Sequence , Set
1+ from collections .abc import Mapping , Sequence , Set
22from dataclasses import dataclass
33from typing import override
44import loopy as lp
55from loopy .kernel .tools import DomainChanger
6+ from loopy .types import to_loopy_type
67import namedisl as nisl
78
89from loopy .kernel import LoopKernel
@@ -50,22 +51,6 @@ def space_from_exprs(exprs, ctx=isl.DEFAULT_CONTEXT):
5051 set = set_names
5152 )
5253
53-
54- @dataclass (frozen = True )
55- class UsageDescriptor :
56- usage : Sequence [Expression ]
57- global_map : isl .Map
58- local_map : isl .Map
59-
60- @override
61- def __str__ (self ):
62- return (
63- f"USAGE = { self .usage } \n " +
64- f"GLOBAL MAP = { self .global_map } \n " +
65- f"LOCAL MAP = { self .local_map } "
66- )
67-
68-
6954class UsageSiteExpressionGatherer (RuleAwareIdentityMapper [[]]):
7055 """
7156 Gathers all expressions used as inputs to a particular substitution rule,
@@ -129,7 +114,7 @@ def __init__(
129114 ctx : SubstitutionRuleMappingContext ,
130115 subst_name : str ,
131116 subst_tag : Sequence [Tag ] | None ,
132- usage_descriptors : Sequence [ UsageDescriptor ],
117+ usage_descriptors : Mapping [ tuple [ Expression , ...], isl . Map ],
133118 storage_indices : Sequence [str ],
134119 temporary_name : str ,
135120 compute_insn_id : str ,
@@ -141,12 +126,19 @@ def __init__(
141126 self .subst_name : str = subst_name
142127 self .subst_tag : Sequence [Tag ] | None = subst_tag
143128
144- self .usage_descriptors : Sequence [UsageDescriptor ] = usage_descriptors
129+ self .usage_descriptors : Mapping [tuple [Expression , ...], isl .Map ] = \
130+ usage_descriptors
145131 self .storage_indices : Sequence [str ] = storage_indices
146132
147133 self .temporary_name : str = temporary_name
148134 self .compute_insn_id : str = compute_insn_id
149135
136+ # FIXME: may not always be the case (i.e. global barrier between
137+ # compute insn and uses)
138+ self .compute_dep_id : str = compute_insn_id
139+
140+ self .replaced_something : bool = False
141+
150142
151143 @override
152144 def map_subst_rule (
@@ -175,17 +167,17 @@ def map_subst_rule(
175167 "does not match the signature of {name}." )
176168
177169 index_exprs : Sequence [Expression ] = []
178- for usage_descr in self .usage_descriptors :
179- if args == usage_descr .usage :
180- local_pwmaff = usage_descr .local_map .as_pw_multi_aff ()
181170
182- for dim in range ( local_pwmaff . dim ( isl . dim_type . out )):
183- index_exprs . append ( pw_aff_to_expr ( local_pwmaff . get_at ( dim )) )
171+ # FIXME: make self.usage_descriptors a constantdict
172+ local_pwmaff = self . usage_descriptors [ tuple ( args )]. as_pw_multi_aff ( )
184173
185- break
174+ for dim in range (local_pwmaff .dim (isl .dim_type .out )):
175+ index_exprs .append (pw_aff_to_expr (local_pwmaff .get_at (dim )))
186176
187177 new_expression = var (self .temporary_name )[tuple (index_exprs )]
188178
179+ self .replaced_something = True
180+
189181 return new_expression
190182
191183
@@ -198,8 +190,10 @@ def map_kernel(
198190 map_tvs : bool = True
199191 ) -> LoopKernel :
200192
201- new_insns = []
193+ new_insns : Sequence [ lp . InstructionBase ] = []
202194 for insn in kernel .instructions :
195+ self .replaced_something = False
196+
203197 if (isinstance (insn , lp .MultiAssignmentBase ) and not
204198 contains_a_subst_rule_invocation (kernel , insn )):
205199 new_insns .append (insn )
@@ -209,6 +203,15 @@ def map_kernel(
209203 lambda expr : self (expr , kernel , insn )
210204 )
211205
206+ if self .replaced_something :
207+ insn = insn .copy (
208+ depends_on = (
209+ insn .depends_on | frozenset ([self .compute_insn_id ])
210+ )
211+ )
212+
213+ # FIXME: determine compute insn dependencies
214+
212215 new_insns .append (insn )
213216
214217 return kernel .copy (instructions = new_insns )
@@ -270,7 +273,7 @@ def compute(
270273 )
271274 )
272275
273- usage_descrs : Sequence [ UsageDescriptor ] = []
276+ usage_descrs : Mapping [ tuple [ Expression , ...], isl . Map ] = {}
274277 for usage in usage_exprs :
275278
276279 range_space = isl .Space .create_from_names (
@@ -301,20 +304,14 @@ def compute(
301304
302305 footprint = footprint | iname_to_storage .range ()
303306
304- usage_descrs .append (
305- UsageDescriptor (
306- usage ,
307- iname_to_storage ,
308- local_map
309- )
310- )
307+ usage_descrs [tuple (usage )] = local_map
311308
312309 # add compute inames to domain / kernel
313310 domain_changer = DomainChanger (kernel , kernel .all_inames ())
314311 domain = domain_changer .domain
315312
316- footprint , domain = isl .align_two (footprint , domain )
317- domain = domain & footprint
313+ footprint_tmp , domain = isl .align_two (footprint , domain )
314+ domain = domain & footprint_tmp
318315
319316 new_domains = domain_changer .get_domains_with (domain )
320317 kernel = kernel .copy (domains = new_domains )
@@ -377,5 +374,34 @@ def compute(
377374
378375 kernel = replacer .map_kernel (kernel )
379376
380- print (kernel )
377+ # FIXME: accept dtype as an argument
378+ import numpy as np
379+ loopy_type = to_loopy_type (np .float64 , allow_none = True )
380+
381+ # WARNING: this can result in symbolic shapes, is that allowed?
382+ temp_shape = tuple (
383+ pw_aff_to_expr (footprint .dim_max (dim )) + 1
384+ for dim in range (footprint .dim (isl .dim_type .out ))
385+ )
386+
387+ new_temp_vars = dict (kernel .temporary_variables )
388+
389+ # FIXME: temp_var might already exist, handle the case where it does
390+ temp_var = lp .TemporaryVariable (
391+ name = temporary_name ,
392+ dtype = loopy_type ,
393+ base_indices = (0 ,)* len (temp_shape ),
394+ shape = temp_shape ,
395+ address_space = temporary_address_space ,
396+ dim_names = tuple (storage_indices )
397+ )
398+
399+ new_temp_vars [temporary_name ] = temp_var
400+
401+ kernel = kernel .copy (
402+ temporary_variables = new_temp_vars
403+ )
404+
405+ # FIXME: handle iname tagging
406+
381407 return kernel
0 commit comments