Replace deprecated jax.tree_* functions with jax.tree.* #580
| Job | Run time |
|---|---|
| | 53s |
| | 52s |
| | 1m 45s |
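The title describes a mechanical API migration from the old top-level `jax.tree_*` helpers to the `jax.tree` namespace. Below is a minimal sketch of that kind of substitution. It is not the PR's actual diff, since the changed call sites are not shown here. It assumes a JAX release new enough to provide `jax.tree` (the namespace appeared around JAX 0.4.25). The `params` pytree is a hypothetical example chosen only for illustration.

```python
import jax
import jax.numpy as jnp

# Hypothetical pytree used only to demonstrate the renamed calls.
params = {"w": jnp.ones((2, 3)), "b": jnp.zeros((3,))}

# Old: jax.tree_map(lambda x: x * 2, params)
doubled = jax.tree.map(lambda x: x * 2, params)

# Old: jax.tree_leaves(params)
# Dict keys are traversed in sorted order, so the leaves are [b, w].
leaves = jax.tree.leaves(params)

# Old: jax.tree_flatten(params) and jax.tree_unflatten(treedef, flat)
flat, treedef = jax.tree.flatten(params)
rebuilt = jax.tree.unflatten(treedef, flat)

# Old: jax.tree_structure(params)
structure = jax.tree.structure(params)

# Old: jax.tree_reduce(fn, params, 0)
# Sums the element counts of all leaves: 3 (b) + 6 (w).
total_size = jax.tree.reduce(lambda acc, x: acc + x.size, params, 0)
```

Only the attribute path changes in each case. The call signatures (function first, then the tree) stay the same, which makes this kind of replacement safe to apply with a search-and-replace followed by a test run.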