Replace deprecated `jax.tree_*` functions with `jax.tree.*` #579
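The change is a mechanical rename of the deprecated top-level aliases (for example `jax.tree_map` becomes `jax.tree.map`). The sketch below shows one way such a rename could be automated. It is a hypothetical helper (`migrate_tree_calls` and `_RENAMES` are illustrative names, not part of JAX or of this PR). It assumes the listed aliases map one-to-one onto `jax.tree.*` functions with unchanged call signatures. Because it only rewrites text, it does not handle aliased imports such as `from jax import tree_map`. It deliberately leaves the non-deprecated `jax.tree_util.*` functions alone.

```python
import re

# Deprecated top-level aliases -> their jax.tree.* replacements.
# Assumption: each pair has the same call signature, so a plain rename suffices.
_RENAMES = {
    "tree_map": "map",
    "tree_leaves": "leaves",
    "tree_flatten": "flatten",
    "tree_unflatten": "unflatten",
    "tree_structure": "structure",
    "tree_transpose": "transpose",
    "tree_reduce": "reduce",
}

# Matches e.g. "jax.tree_map" but not "jax.tree_util.tree_map" or "jax.tree_mapper":
# the trailing \b rejects longer identifiers that merely start with an alias name.
_PATTERN = re.compile(r"\bjax\.(" + "|".join(_RENAMES) + r")\b")


def migrate_tree_calls(source: str) -> str:
    """Rewrite deprecated jax.tree_* references in `source` to jax.tree.* (hypothetical helper)."""
    return _PATTERN.sub(lambda m: "jax.tree." + _RENAMES[m.group(1)], source)


if __name__ == "__main__":
    before = (
        "params = jax.tree_map(lambda x: x * 2, params)\n"
        "n = len(jax.tree_leaves(params))\n"
    )
    print(migrate_tree_calls(before))
```

Working on source text rather than on a parsed AST keeps the sketch short. A real migration would still need a review pass, since strings, comments, and re-exported aliases are matched or missed purely by spelling.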
| Job | Run time |
|---|---|
| | 52s |
| | 1m 8s |
| | 2m 0s |