-
Notifications
You must be signed in to change notification settings - Fork 27
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make map_coordinates
differentiable for JAX 0.4.34
#1293
base: master
Are you sure you want to change the base?
Conversation
…tiply by all dots
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
test_build_transform_fft_midres | +0.68 +/- 7.09 | +4.21e-03 +/- 4.40e-02 | 6.25e-01 +/- 4.0e-02 | 6.21e-01 +/- 1.8e-02 |
test_build_transform_fft_highres | -1.19 +/- 4.27 | -1.22e-02 +/- 4.37e-02 | 1.01e+00 +/- 3.8e-02 | 1.02e+00 +/- 2.1e-02 |
test_equilibrium_init_lowres | +0.97 +/- 2.26 | +3.81e-02 +/- 8.84e-02 | 3.95e+00 +/- 7.4e-02 | 3.91e+00 +/- 4.8e-02 |
test_objective_compile_atf | +0.59 +/- 3.78 | +4.57e-02 +/- 2.95e-01 | 7.85e+00 +/- 1.8e-01 | 7.81e+00 +/- 2.3e-01 |
test_objective_compute_atf | +1.50 +/- 3.81 | +1.58e-04 +/- 4.00e-04 | 1.07e-02 +/- 3.4e-04 | 1.05e-02 +/- 2.1e-04 |
test_objective_jac_atf | -0.65 +/- 3.39 | -1.23e-02 +/- 6.40e-02 | 1.88e+00 +/- 5.5e-02 | 1.89e+00 +/- 3.2e-02 |
test_perturb_1 | +0.89 +/- 2.33 | +1.14e-01 +/- 2.98e-01 | 1.29e+01 +/- 1.9e-01 | 1.28e+01 +/- 2.3e-01 |
test_proximal_jac_atf | -0.11 +/- 0.97 | -8.54e-03 +/- 7.84e-02 | 8.09e+00 +/- 4.8e-02 | 8.10e+00 +/- 6.2e-02 |
test_proximal_freeb_compute | -0.46 +/- 0.97 | -8.53e-04 +/- 1.78e-03 | 1.84e-01 +/- 1.3e-03 | 1.85e-01 +/- 1.2e-03 |
test_build_transform_fft_lowres | +10.07 +/- 5.71 | +5.20e-02 +/- 2.95e-02 | 5.68e-01 +/- 2.5e-02 | 5.16e-01 +/- 1.5e-02 |
-test_equilibrium_init_medres | +11.80 +/- 3.17 | +4.80e-01 +/- 1.29e-01 | 4.54e+00 +/- 1.1e-01 | 4.07e+00 +/- 6.5e-02 |
-test_equilibrium_init_highres | +8.79 +/- 2.76 | +4.76e-01 +/- 1.49e-01 | 5.89e+00 +/- 1.4e-01 | 5.41e+00 +/- 4.2e-02 |
test_objective_compile_dshape_current | +4.02 +/- 8.78 | +1.53e-01 +/- 3.33e-01 | 3.95e+00 +/- 3.3e-01 | 3.80e+00 +/- 2.9e-02 |
test_objective_compute_dshape_current | +2.55 +/- 1.69 | +9.14e-05 +/- 6.05e-05 | 3.67e-03 +/- 4.8e-05 | 3.58e-03 +/- 3.7e-05 |
test_objective_jac_dshape_current | +1.49 +/- 5.76 | +5.96e-04 +/- 2.31e-03 | 4.07e-02 +/- 1.4e-03 | 4.01e-02 +/- 1.9e-03 |
test_perturb_2 | +0.61 +/- 2.24 | +1.07e-01 +/- 3.89e-01 | 1.75e+01 +/- 2.1e-01 | 1.74e+01 +/- 3.3e-01 |
test_proximal_freeb_jac | -0.04 +/- 1.62 | -3.09e-03 +/- 1.20e-01 | 7.45e+00 +/- 8.4e-02 | 7.46e+00 +/- 8.6e-02 |
test_solve_fixed_iter | +1.32 +/- 59.49 | +6.56e-02 +/- 2.95e+00 | 5.02e+00 +/- 2.1e+00 | 4.96e+00 +/- 2.1e+00 | |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #1293 +/- ##
==========================================
- Coverage 95.53% 95.51% -0.02%
==========================================
Files 96 96
Lines 24005 24027 +22
==========================================
+ Hits 22932 22949 +17
- Misses 1073 1078 +5
|
…o yge/customjvp_fix
…rsion depende=ing on jax version to prevent errors
|
requirements.txt
Outdated
@@ -1,8 +1,8 @@ | |||
jax[cpu] >= 0.3.2, < 0.5.0 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
minor point but can we keep this in alphabetical order?
Also don't we need to bump the minimum jax version?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yes, we need to bump it up. Are we ok with 0.4.24
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I changed this to make test_jax
work better. In the alphabetical order (when jax comes later) pip was having trouble finding proper versions.
I made some other changes after that, maybe this is not an issue anymore but basically this prioritizes jax
version over others in a way
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are the issues with older versions of jax?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
-
v0.4.23
-v0.4.13
fails due to some scipy function (which is not used in our code directly but when we import a function from jax, this scipy function is used somewhere in the same file and computer cannot read that whole file due to this error)
I don't know why pip doesn't choose proper scipy version, but it is probably due tojax
requirements which I cannot fix. -
v0.4.12
and older won't even past the pip install phase.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I didn't try that. Let me check
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
also I think we can remove the [cpu]
flag for jax, by default it installs the cpu version and this would make it less annoying for installing gpu stuff
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if we remove [cpu]
than we need to add jaxlib
to the requirements
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think all the recent versions install jaxlib automatically
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
f = _jacobi(n, alpha, beta, x, dx) | ||
df = _jacobi(n, alpha, beta, x, dx + 1) | ||
# in theory n, alpha, beta, dx aren't differentiable (they're integers) | ||
# but marking them as non-diff argnums seems to cause escaped tracer values. | ||
# probably a more elegant fix, but just setting those derivatives to zero seems | ||
# to work fine. | ||
return f, df * xdot + 0 * ndot + 0 * alphadot + 0 * betadot + 0 * dxdot |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I deleted redundant 0 multiplications because in some cases, this gives an error saying float0 cannot be used in math operations...
full_output
flags toroot
androot_scalar
to make them differentiable.test_jax
workflow with new jax versions and better dependency installation routine (i.e. previously since jax was uploaded later, rest of the packages were latest and only jax was old, this was causing incompatibilities and false-errors)mpl_test
workflowResolves #1291
root
androot_scalar
in addition tomap_coordinates_derivative