@@ -88,17 +88,39 @@ class LoopyKeyBuilder(KeyBuilderBase):
8888 update_for_dict = KeyBuilderBase .update_for_constantdict
8989 update_for_defaultdict = KeyBuilderBase .update_for_constantdict
9090
91- def update_for_BasicSet (self , key_hash , key ): # noqa: N802
91+ def _update_for_isl_obj (self ,
92+ key_hash : Hash ,
93+ key : isl .BasicSet | isl .Set | isl .BasicMap | isl .Map
94+ ):
9295 from islpy import Printer
9396 prn = Printer .to_str (key .get_ctx ())
94- getattr (prn , "print_" + key ._base_name )(key )
97+ getattr (prn , "print_" + key ._base_name )(key ) # pyright: ignore[reportUnknownArgumentType, reportUnknownMemberType, reportAttributeAccessIssue]
9598 key_hash .update (prn .get_str ().encode ("utf8" ))
9699
97- def update_for_Map (self , key_hash , key ): # noqa: N802
100+ @override
101+ def update_for_Map (self , key_hash : Hash , key : isl .Map ):
98102 if isinstance (key , isl .Map ):
99- self .update_for_BasicSet (key_hash , key )
103+ self ._update_for_isl_obj (key_hash , key )
104+ else :
105+ super ().update_for_Map (key_hash , key )
106+
107+ def update_for_BasicMap (self , key_hash : Hash , key : isl .BasicMap ): # noqa: N802
108+ if isinstance (key , isl .BasicMap ):
109+ self ._update_for_isl_obj (key_hash , key )
110+ else :
111+ raise TypeError ("called on a non-isl type" )
112+
113+ def update_for_Set (self , key_hash : Hash , key : isl .Set ): # noqa: N802
114+ if isinstance (key , isl .Set ):
115+ self ._update_for_isl_obj (key_hash , key )
116+ else :
117+ raise TypeError ("called on a non-isl type" )
118+
119+ def update_for_BasicSet (self , key_hash : Hash , key : isl .BasicSet ): # noqa: N802
120+ if isinstance (key , isl .BasicSet ):
121+ self ._update_for_isl_obj (key_hash , key )
100122 else :
101- raise AssertionError ( )
123+ raise TypeError ( "called on a non-isl type" )
102124
103125# }}}
104126
0 commit comments