@@ -271,7 +271,7 @@ def compute(
271271 """
272272 Inserts an instruction to compute an expression given by :arg:`substitution`
273273 and replaces all invocations of :arg:`substitution` with the result of the
274- compute instruction.
274+ inserted compute instruction.
275275
276276 :arg substitution: The substitution rule for which the compute
277277 transform should be applied.
@@ -283,8 +283,6 @@ def compute(
283283 :arg storage_indices: An ordered sequence of names of storage indices. Used
284284 to create inames for the loops that cover the required set of compute points.
285285 """
286- # FIXME: use namedisl directly
287- compute_map = compute_map ._reconstruct_isl_object ()
288286
289287 # {{{ construct necessary pieces; footprint, global usage map
290288
@@ -295,125 +293,74 @@ def compute(
295293 ctx , expander , kernel , substitution , None
296294 )
297295
296+ # add compute inames to domain / kernel
297+ domain_changer = DomainChanger (kernel , kernel .all_inames ())
298+ named_domain = nisl .make_basic_set (domain_changer .domain )
299+
298300 _ = expr_gatherer .map_kernel (kernel )
299301 usage_exprs = expr_gatherer .usage_expressions
300302
301- all_exprs = [
302- expr
303- for usage in usage_exprs
304- for expr in usage
305- ]
303+ all_exprs = [expr for usage in usage_exprs for expr in usage ]
304+ usage_inames = set .union (* (gather_vars (expr ) for expr in all_exprs ))
306305
307- space = space_from_exprs (all_exprs )
306+ usage_domain = nisl .make_set (f"{{ [{ "," .join (iname for iname in usage_inames )} ] }}" )
307+ footprint = nisl .make_set (f"{{ [{ "," .join (idx for idx in storage_indices )} ] }}" )
308308
309- footprint = isl .Set .empty (
310- isl .Space .create_from_names (
311- ctx = space .get_ctx (),
312- set = list (storage_indices )
313- )
309+ global_usage_map = nisl .make_map_from_domain_and_range (
310+ usage_domain ,
311+ compute_map .domain ()
314312 )
315313
316- # add compute inames to domain / kernel
317- domain_changer = DomainChanger (kernel , kernel .all_inames ())
318- domain = domain_changer .domain
319-
320- range_space = isl .Space .create_from_names (
321- ctx = space .get_ctx (),
322- set = list (storage_indices )
323- )
324- map_space = space .map_from_domain_and_range (range_space )
325- global_usage_map = isl .Map .empty (map_space )
314+ global_usage_map = nisl .make_map (isl .Map .empty (global_usage_map .get_space ()))
326315
316+ usage_substs : Mapping [AccessTuple , nisl .Map ] = {}
327317 for usage in usage_exprs :
328318
329319 # FIXME package sequence of pymbolic exprs -> multipwaff up as a function in loopy.symbolic
330- local_usage_mpwaff = isl .MultiPwAff .zero (map_space )
320+ local_usage_mpwaff = isl .MultiPwAff .zero (global_usage_map . get_space () )
331321
332322 for i in range (len (storage_indices )):
323+ local_space = local_usage_mpwaff .get_at (i ).get_space ().domain ()
333324 local_usage_mpwaff = local_usage_mpwaff .set_pw_aff (
334325 i ,
335- pwaff_from_expr (space , usage [i ])
326+ pwaff_from_expr (local_space , usage [i ])
336327 )
337328
338- local_usage_map = local_usage_mpwaff .as_map ()
329+ local_usage_map = nisl . make_map ( local_usage_mpwaff .as_map () )
339330
340- # FIXME: fix with namedisl
341- # remove unnecessary names from domain and intersect with usage map
342- usage_names = frozenset (
343- local_usage_map .get_dim_name (isl .dim_type .in_ , dim )
344- for dim in range (local_usage_map .dim (isl .dim_type .in_ ))
345- )
346-
347- domain_names = frozenset (
348- domain .get_dim_name (isl .dim_type .set , dim )
349- for dim in range (domain .dim (isl .dim_type .set ))
350- )
351-
352- domain_tmp = domain .project_out_except (
353- usage_names & domain_names , [isl .dim_type .set ]
354- )
355-
356- local_usage_map = align_map_domain_to_set (local_usage_map , domain_tmp )
357- local_usage_map = local_usage_map .intersect_domain (domain_tmp )
331+ local_usage_map = local_usage_map .intersect_domain (named_domain )
358332 global_usage_map = global_usage_map | local_usage_map
359333
360- # {{{ FIXME: this shouldn't need to be done here; will be handled by namedisl
334+ local_storage_map = local_usage_map .apply_range (compute_map )
335+ relevant_names = gather_vars (usage )
361336
362- global_usage_map = global_usage_map .apply_range (compute_map )
363- common_dims = {
364- dim1 : dim2
365- for dim1 in range (global_usage_map .dim (isl .dim_type .in_ ))
366- for dim2 in range (global_usage_map .dim (isl .dim_type .out ))
367- if (
368- global_usage_map .get_dim_name (isl .dim_type .in_ , dim1 )
369- ==
370- global_usage_map .get_dim_name (isl .dim_type .out , dim2 )
337+ local_storage_map = local_storage_map .project_out_except (
338+ (relevant_names - frozenset (temporal_inames )) | frozenset (storage_indices )
371339 )
372- }
373340
374- for pos1 , pos2 in common_dims .items ():
375- global_usage_map = global_usage_map .equate (
376- isl .dim_type .in_ , pos1 ,
377- isl .dim_type .out , pos2
378- )
341+ usage_substs [tuple (usage )] = local_storage_map
379342
380- # }}}
343+ global_usage_map = global_usage_map . apply_range ( compute_map )
381344
382345 # }}}
383346
384347 # {{{ compute bounds and update kernel domain
385348
386349 footprint = global_usage_map .range ()
387- footprint_tmp , domain = isl .align_two (footprint , domain )
388- domain = (domain & footprint_tmp ).get_basic_sets ()[0 ]
389-
390- new_domains = domain_changer .get_domains_with (domain )
391- kernel = kernel .copy (domains = new_domains )
392-
393- # }}}
394-
395- # {{{ compute index expressions
396-
397- usage_substs : Mapping [AccessTuple , isl .Map ] = {}
398- for usage in usage_exprs :
399- # find the relevant names
400- relevant_names = gather_vars (usage )
401-
402- # project out irrelevant names
403- relevant_names = frozenset (relevant_names ) - frozenset (temporal_inames )
350+ footprint = footprint .project_out_except (
351+ frozenset (temporal_inames ) | frozenset (storage_indices )
352+ )
404353
405- local_iname_to_storage = global_usage_map .project_out_except (
406- relevant_names ,
407- [isl .dim_type .in_ ]
408- )
354+ # FIXME: probably do not want this permanently here
355+ footprint = nisl .make_set (footprint ._reconstruct_isl_object ().convex_hull ())
356+ named_domain = named_domain & footprint
409357
410- local_iname_to_storage = local_iname_to_storage .project_out_except (
411- storage_indices ,
412- [isl .dim_type .out ]
413- )
358+ # FIXME:
359+ if len (named_domain .get_basic_sets ()) != 1 :
360+ raise ValueError ("New domain should be composed of a single basic set" )
414361
415- # map usage -> resulting map
416- usage_substs [ tuple ( usage )] = local_iname_to_storage
362+ new_domains = domain_changer . get_domains_with ( named_domain . get_basic_sets ()[ 0 ])
363+ kernel = kernel . copy ( domains = new_domains )
417364
418365 # }}}
419366
@@ -487,9 +434,7 @@ def compute(
487434 loopy_type = to_loopy_type (temporary_dtype , allow_none = True )
488435
489436 # FIXME: fix with namedisl?
490- shape_domain = footprint .project_out_except (storage_indices ,
491- [isl .dim_type .set ])
492- shape_domain = shape_domain .project_out_except ("" , [isl .dim_type .param ])
437+ shape_domain = footprint .project_out_except (storage_indices )
493438
494439 temp_shape = tuple (
495440 pw_aff_to_expr (shape_domain .dim_max (dim )) + 1
0 commit comments