Skip to content

Commit b4bb3b9

Browse files
dhgarretteChexDev
authored andcommitted
Use format_shape_matcher when constructing the assert_tree_shape error message.
PiperOrigin-RevId: 799998984
1 parent 79c0a58 commit b4bb3b9

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

chex/_src/asserts.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1276,7 +1276,7 @@ def _assert_fn(path, leaf):
12761276
if not _shape_matches(leaf.shape, expected_shape):
12771277
errors.append((
12781278
f"Tree leaf '{_ai.format_tree_path(path)}' has shape {leaf.shape}"
1279-
f" but expected {expected_shape}."
1279+
f" but expected {_ai.format_shape_matcher(expected_shape)}."
12801280
))
12811281

12821282
for path, leaf in jax.tree_util.tree_flatten_with_path(tree)[0]:

0 commit comments

Comments
 (0)