Replace deprecated jax.tree_* functions with jax.tree.* #578
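The following is a minimal sketch of the kind of substitution the title describes: top-level aliases such as `jax.tree_map` or `jax.tree_leaves` give way to their counterparts in the `jax.tree` module. It does not reproduce the call sites actually changed in #578, which are not shown here. The pytree `params` is hypothetical example data, and the sketch assumes a JAX version that provides `jax.tree`. The old names appear only in comments so the code runs whether or not they still exist.

```python
import jax
import jax.numpy as jnp

# Hypothetical example pytree: a dict holding an array and a tuple of arrays.
params = {"w": jnp.ones((2, 2)), "b": (jnp.zeros(2), jnp.arange(3.0))}

# Before: jax.tree_map(lambda x: x * 2, params)
doubled = jax.tree.map(lambda x: x * 2, params)

# Before: jax.tree_leaves(params)
# Dict keys are visited in sorted order, so the leaves are b[0], b[1], w.
leaves = jax.tree.leaves(params)

# Before: jax.tree_flatten(params) / jax.tree_unflatten(treedef, flat)
flat, treedef = jax.tree.flatten(params)
rebuilt = jax.tree.unflatten(treedef, flat)

# Before: jax.tree_structure(params)
same_structure = jax.tree.structure(params) == treedef

# Before: jax.tree_reduce(...)
# Sum every element of every leaf: 0 + (0 + 1 + 2) + 4 * 1 = 7.0
total = jax.tree.reduce(lambda a, b: a + b, jax.tree.map(jnp.sum, params))
```

Each replacement keeps the same arguments and semantics as the alias it replaces, so the migration is a mechanical rename rather than a behavioral change.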