2121import pickle
2222import sys
2323from typing import Any , Generic , Mapping , TypeVar
24+ import unittest
2425
2526from absl .testing import absltest
2627from absl .testing import parameterized
3031import cloudpickle
3132import jax
3233import numpy as np
33- import tree
34+
35+ # dm-tree is not compatible with Python 3.13.
36+ try :
37+ import tree # pylint:disable=g-import-not-at-top
38+ except ImportError :
39+ tree = None
3440
3541chex_dataclass = dataclass .dataclass
3642mappable_dataclass = dataclass .mappable_dataclass
@@ -209,7 +215,9 @@ def _init_testdata(self, test_type):
209215 self .dcls_tree_size = 18
210216 self .dcls_tree_size_no_dicts = 14
211217
218+ @unittest .skipIf (tree is None , 'dm-tree is not compatible with Python 3.13' )
212219 def testFlattenAndUnflatten (self , test_type ):
220+ assert tree is not None
213221 self ._init_testdata (test_type )
214222
215223 self .assertEqual (self .dcls_flattened , tree .flatten (self .dcls_with_map ))
@@ -223,29 +231,37 @@ def testFlattenAndUnflatten(self, test_type):
223231 self .assertEqual (dataclass_in_seq ,
224232 tree .unflatten_as (dataclass_in_seq , dataclass_in_seq_flat ))
225233
234+ @unittest .skipIf (tree is None , 'dm-tree is not compatible with Python 3.13' )
226235 def testFlattenUpTo (self , test_type ):
236+ assert tree is not None
227237 self ._init_testdata (test_type )
228238 structure = copy .copy (self .dcls_with_map )
229239 structure .k_dclass_with_map = None # Do not flatten 'k_dclass_with_map'
230240 self .assertEqual (self .dcls_flattened_up_to ,
231241 tree .flatten_up_to (structure , self .dcls_with_map ))
232242
243+ @unittest .skipIf (tree is None , 'dm-tree is not compatible with Python 3.13' )
233244 def testFlattenWithPath (self , test_type ):
245+ assert tree is not None
234246 self ._init_testdata (test_type )
235247
236248 self .assertEqual (
237249 tree .flatten_with_path (self .dcls_with_map ),
238250 self .dcls_flattened_with_path )
239251
252+ @unittest .skipIf (tree is None , 'dm-tree is not compatible with Python 3.13' )
240253 def testFlattenWithPathUpTo (self , test_type ):
254+ assert tree is not None
241255 self ._init_testdata (test_type )
242256 structure = copy .copy (self .dcls_with_map )
243257 structure .k_dclass_with_map = None # Do not flatten 'k_dclass_with_map'
244258 self .assertEqual (
245259 tree .flatten_with_path_up_to (structure , self .dcls_with_map ),
246260 self .dcls_flattened_with_path_up_to )
247261
262+ @unittest .skipIf (tree is None , 'dm-tree is not compatible with Python 3.13' )
248263 def testMapStructure (self , test_type ):
264+ assert tree is not None
249265 self ._init_testdata (test_type )
250266
251267 add_one_to_ints_fn = lambda x : x + 1 if isinstance (x , int ) else x
@@ -256,7 +272,9 @@ def testMapStructure(self, test_type):
256272 self .dcls_with_map_inc_ints .k_int * 10 )
257273 self .assertEqual (mapped_inc_ints .k_non_init , mapped_inc_ints .k_int * 10 )
258274
275+ @unittest .skipIf (tree is None , 'dm-tree is not compatible with Python 3.13' )
259276 def testMapStructureUpTo (self , test_type ):
277+ assert tree is not None
260278 self ._init_testdata (test_type )
261279
262280 structure = copy .copy (self .dcls_with_map )
@@ -273,7 +291,9 @@ def testMapStructureUpTo(self, test_type):
273291 self .dcls_with_map_inc_ints .k_int * 10 )
274292 self .assertEqual (mapped_inc_ints .k_non_init , mapped_inc_ints .k_int * 10 )
275293
294+ @unittest .skipIf (tree is None , 'dm-tree is not compatible with Python 3.13' )
276295 def testMapStructureWithPath (self , test_type ):
296+ assert tree is not None
277297 self ._init_testdata (test_type )
278298
279299 add_one_to_ints_fn = lambda path , x : x + 1 if isinstance (x , int ) else x
@@ -285,7 +305,9 @@ def testMapStructureWithPath(self, test_type):
285305 self .dcls_with_map_inc_ints .k_int * 10 )
286306 self .assertEqual (mapped_inc_ints .k_non_init , mapped_inc_ints .k_int * 10 )
287307
308+ @unittest .skipIf (tree is None , 'dm-tree is not compatible with Python 3.13' )
288309 def testMapStructureWithPathUpTo (self , test_type ):
310+ assert tree is not None
289311 self ._init_testdata (test_type )
290312
291313 structure = copy .copy (self .dcls_with_map )
@@ -303,7 +325,9 @@ def testMapStructureWithPathUpTo(self, test_type):
303325 self .dcls_with_map_inc_ints .k_int * 10 )
304326 self .assertEqual (mapped_inc_ints .k_non_init , mapped_inc_ints .k_int * 10 )
305327
328+ @unittest .skipIf (tree is None , 'dm-tree is not compatible with Python 3.13' )
306329 def testTraverse (self , test_type ):
330+ assert tree is not None
307331 self ._init_testdata (test_type )
308332
309333 visited = []
@@ -358,7 +382,7 @@ def test_tree_flatten_with_keys(self):
358382 )
359383 leaves = [l for _ , l in keys_and_leaves ]
360384 new_obj = treedef .unflatten (leaves )
361- self . assertEqual (new_obj , obj )
385+ asserts . assert_trees_all_equal (new_obj , obj )
362386
363387 def test_tree_map_with_keys (self ):
364388 obj = dummy_dataclass ()
0 commit comments