Skip to content

Commit 512d26e

Browse files
Jake VanderPlasChexDev
authored andcommitted
Chex: add Python 3.13 CI tests
This requires making dm-tree a testing requirement only for Python < 3.13, because there is no py313-compatible dm-tree release. PiperOrigin-RevId: 716138572
1 parent 2fe940a commit 512d26e

File tree

3 files changed

+29
-4
lines changed

3 files changed

+29
-4
lines changed

.github/workflows/ci.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ jobs:
3030
- python-version: "3.9"
3131
os: "ubuntu-latest"
3232
jax-version: "0.4.27" # Keep this in sync with version in requirements.txt
33-
- python-version: "3.12"
33+
- python-version: "3.13"
3434
os: "ubuntu-latest"
3535
jax-version: "nightly"
3636

chex/_src/dataclass_test.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
import pickle
2222
import sys
2323
from typing import Any, Generic, Mapping, TypeVar
24+
import unittest
2425

2526
from absl.testing import absltest
2627
from absl.testing import parameterized
@@ -30,7 +31,12 @@
3031
import cloudpickle
3132
import jax
3233
import numpy as np
33-
import tree
34+
35+
# dm-tree is not compatible with Python 3.13.
36+
try:
37+
import tree # pylint:disable=g-import-not-at-top
38+
except ImportError:
39+
tree = None
3440

3541
chex_dataclass = dataclass.dataclass
3642
mappable_dataclass = dataclass.mappable_dataclass
@@ -209,7 +215,9 @@ def _init_testdata(self, test_type):
209215
self.dcls_tree_size = 18
210216
self.dcls_tree_size_no_dicts = 14
211217

218+
@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
212219
def testFlattenAndUnflatten(self, test_type):
220+
assert tree is not None
213221
self._init_testdata(test_type)
214222

215223
self.assertEqual(self.dcls_flattened, tree.flatten(self.dcls_with_map))
@@ -223,29 +231,37 @@ def testFlattenAndUnflatten(self, test_type):
223231
self.assertEqual(dataclass_in_seq,
224232
tree.unflatten_as(dataclass_in_seq, dataclass_in_seq_flat))
225233

234+
@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
226235
def testFlattenUpTo(self, test_type):
236+
assert tree is not None
227237
self._init_testdata(test_type)
228238
structure = copy.copy(self.dcls_with_map)
229239
structure.k_dclass_with_map = None # Do not flatten 'k_dclass_with_map'
230240
self.assertEqual(self.dcls_flattened_up_to,
231241
tree.flatten_up_to(structure, self.dcls_with_map))
232242

243+
@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
233244
def testFlattenWithPath(self, test_type):
245+
assert tree is not None
234246
self._init_testdata(test_type)
235247

236248
self.assertEqual(
237249
tree.flatten_with_path(self.dcls_with_map),
238250
self.dcls_flattened_with_path)
239251

252+
@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
240253
def testFlattenWithPathUpTo(self, test_type):
254+
assert tree is not None
241255
self._init_testdata(test_type)
242256
structure = copy.copy(self.dcls_with_map)
243257
structure.k_dclass_with_map = None # Do not flatten 'k_dclass_with_map'
244258
self.assertEqual(
245259
tree.flatten_with_path_up_to(structure, self.dcls_with_map),
246260
self.dcls_flattened_with_path_up_to)
247261

262+
@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
248263
def testMapStructure(self, test_type):
264+
assert tree is not None
249265
self._init_testdata(test_type)
250266

251267
add_one_to_ints_fn = lambda x: x + 1 if isinstance(x, int) else x
@@ -256,7 +272,9 @@ def testMapStructure(self, test_type):
256272
self.dcls_with_map_inc_ints.k_int * 10)
257273
self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10)
258274

275+
@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
259276
def testMapStructureUpTo(self, test_type):
277+
assert tree is not None
260278
self._init_testdata(test_type)
261279

262280
structure = copy.copy(self.dcls_with_map)
@@ -273,7 +291,9 @@ def testMapStructureUpTo(self, test_type):
273291
self.dcls_with_map_inc_ints.k_int * 10)
274292
self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10)
275293

294+
@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
276295
def testMapStructureWithPath(self, test_type):
296+
assert tree is not None
277297
self._init_testdata(test_type)
278298

279299
add_one_to_ints_fn = lambda path, x: x + 1 if isinstance(x, int) else x
@@ -285,7 +305,9 @@ def testMapStructureWithPath(self, test_type):
285305
self.dcls_with_map_inc_ints.k_int * 10)
286306
self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10)
287307

308+
@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
288309
def testMapStructureWithPathUpTo(self, test_type):
310+
assert tree is not None
289311
self._init_testdata(test_type)
290312

291313
structure = copy.copy(self.dcls_with_map)
@@ -303,7 +325,9 @@ def testMapStructureWithPathUpTo(self, test_type):
303325
self.dcls_with_map_inc_ints.k_int * 10)
304326
self.assertEqual(mapped_inc_ints.k_non_init, mapped_inc_ints.k_int * 10)
305327

328+
@unittest.skipIf(tree is None, 'dm-tree is not compatible with Python 3.13')
306329
def testTraverse(self, test_type):
330+
assert tree is not None
307331
self._init_testdata(test_type)
308332

309333
visited = []
@@ -358,7 +382,7 @@ def test_tree_flatten_with_keys(self):
358382
)
359383
leaves = [l for _, l in keys_and_leaves]
360384
new_obj = treedef.unflatten(leaves)
361-
self.assertEqual(new_obj, obj)
385+
asserts.assert_trees_all_equal(new_obj, obj)
362386

363387
def test_tree_map_with_keys(self):
364388
obj = dummy_dataclass()

requirements/requirements-test.txt

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
11
cloudpickle==2.2.0
2-
dm-tree>=0.1.5
2+
# dm-tree is not compatible with Python 3.13.
3+
dm-tree>=0.1.5; python_version < "3.13"

0 commit comments

Comments
 (0)