Skip to content

Commit f08636e

Browse files
Jake VanderPlastree-math authors
Jake VanderPlas
authored and
tree-math authors
committed
Replace deprecated jax.tree_* functions with jax.tree.*
The top-level `jax.tree_*` aliases have long been deprecated, and will soon be removed. Alternate APIs are in `jax.tree_util`, with shorter aliases in the `jax.tree` submodule, added in JAX version 0.4.25. PiperOrigin-RevId: 633773679
1 parent 0727453 commit f08636e

File tree

1 file changed

+2
-2
lines changed

1 file changed

+2
-2
lines changed

tree_math/_src/structs_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,8 +41,8 @@ class StructsTest(test_util.TestCase):
4141
dict(testcase_name='Arrays', x=TestStruct(np.eye(10), np.ones([3, 4, 5])))
4242
)
4343
def testFlattenUnflatten(self, x):
44-
leaves, structure = jax.tree_flatten(x)
45-
y = jax.tree_unflatten(structure, leaves)
44+
leaves, structure = jax.tree.flatten(x)
45+
y = jax.tree.unflatten(structure, leaves)
4646
np.testing.assert_allclose(x.a, y.a)
4747
np.testing.assert_allclose(x.b, y.b)
4848

0 commit comments

Comments
 (0)