Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make map_coordinates differentiable for JAX 0.4.34 #1293

Open
wants to merge 42 commits into
base: master
Choose a base branch
from

Conversation

YigitElma
Copy link
Collaborator

@YigitElma YigitElma commented Oct 4, 2024

  • Adds full_output flags to root and root_scalar to make them differentiable.
  • While working on jax problems, I used this PR to update our test_jax workflow with new jax versions and better dependency installation routine (i.e. previously since jax was uploaded later, rest of the packages were latest and only jax was old, this was causing incompatibilities and false-errors)
  • We missed to add default matplotlib version for pytests in case cached environment not found in Higher splitting for Github Actions and Cache venv #1213. Added lines to unit, regression and notebook workflows fix that
  • Adds newer versions of matplotlib to the mpl_test workflow

Resolves #1291

  • Maybe? Add tests for differentiability of root and root_scalar in addition to map_coordinates_derivative

@YigitElma YigitElma added test_jax Run tests against different versions of JAX easy Short and simple to code or review bug fix Something was fixed labels Oct 4, 2024
@YigitElma YigitElma self-assigned this Oct 4, 2024
Copy link
Contributor

github-actions bot commented Oct 4, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
 test_build_transform_fft_midres         |     +0.68 +/- 7.09     | +4.21e-03 +/- 4.40e-02 |  6.25e-01 +/- 4.0e-02  |  6.21e-01 +/- 1.8e-02  |
 test_build_transform_fft_highres        |     -1.19 +/- 4.27     | -1.22e-02 +/- 4.37e-02 |  1.01e+00 +/- 3.8e-02  |  1.02e+00 +/- 2.1e-02  |
 test_equilibrium_init_lowres            |     +0.97 +/- 2.26     | +3.81e-02 +/- 8.84e-02 |  3.95e+00 +/- 7.4e-02  |  3.91e+00 +/- 4.8e-02  |
 test_objective_compile_atf              |     +0.59 +/- 3.78     | +4.57e-02 +/- 2.95e-01 |  7.85e+00 +/- 1.8e-01  |  7.81e+00 +/- 2.3e-01  |
 test_objective_compute_atf              |     +1.50 +/- 3.81     | +1.58e-04 +/- 4.00e-04 |  1.07e-02 +/- 3.4e-04  |  1.05e-02 +/- 2.1e-04  |
 test_objective_jac_atf                  |     -0.65 +/- 3.39     | -1.23e-02 +/- 6.40e-02 |  1.88e+00 +/- 5.5e-02  |  1.89e+00 +/- 3.2e-02  |
 test_perturb_1                          |     +0.89 +/- 2.33     | +1.14e-01 +/- 2.98e-01 |  1.29e+01 +/- 1.9e-01  |  1.28e+01 +/- 2.3e-01  |
 test_proximal_jac_atf                   |     -0.11 +/- 0.97     | -8.54e-03 +/- 7.84e-02 |  8.09e+00 +/- 4.8e-02  |  8.10e+00 +/- 6.2e-02  |
 test_proximal_freeb_compute             |     -0.46 +/- 0.97     | -8.53e-04 +/- 1.78e-03 |  1.84e-01 +/- 1.3e-03  |  1.85e-01 +/- 1.2e-03  |
 test_build_transform_fft_lowres         |    +10.07 +/- 5.71     | +5.20e-02 +/- 2.95e-02 |  5.68e-01 +/- 2.5e-02  |  5.16e-01 +/- 1.5e-02  |
-test_equilibrium_init_medres            |    +11.80 +/- 3.17     | +4.80e-01 +/- 1.29e-01 |  4.54e+00 +/- 1.1e-01  |  4.07e+00 +/- 6.5e-02  |
-test_equilibrium_init_highres           |     +8.79 +/- 2.76     | +4.76e-01 +/- 1.49e-01 |  5.89e+00 +/- 1.4e-01  |  5.41e+00 +/- 4.2e-02  |
 test_objective_compile_dshape_current   |     +4.02 +/- 8.78     | +1.53e-01 +/- 3.33e-01 |  3.95e+00 +/- 3.3e-01  |  3.80e+00 +/- 2.9e-02  |
 test_objective_compute_dshape_current   |     +2.55 +/- 1.69     | +9.14e-05 +/- 6.05e-05 |  3.67e-03 +/- 4.8e-05  |  3.58e-03 +/- 3.7e-05  |
 test_objective_jac_dshape_current       |     +1.49 +/- 5.76     | +5.96e-04 +/- 2.31e-03 |  4.07e-02 +/- 1.4e-03  |  4.01e-02 +/- 1.9e-03  |
 test_perturb_2                          |     +0.61 +/- 2.24     | +1.07e-01 +/- 3.89e-01 |  1.75e+01 +/- 2.1e-01  |  1.74e+01 +/- 3.3e-01  |
 test_proximal_freeb_jac                 |     -0.04 +/- 1.62     | -3.09e-03 +/- 1.20e-01 |  7.45e+00 +/- 8.4e-02  |  7.46e+00 +/- 8.6e-02  |
 test_solve_fixed_iter                   |     +1.32 +/- 59.49    | +6.56e-02 +/- 2.95e+00 |  5.02e+00 +/- 2.1e+00  |  4.96e+00 +/- 2.1e+00  |

Copy link

codecov bot commented Oct 4, 2024

Codecov Report

Attention: Patch coverage is 85.36585% with 6 lines in your changes missing coverage. Please review.

Project coverage is 95.51%. Comparing base (84a051b) to head (ee24838).

Files with missing lines Patch % Lines
desc/backend.py 86.36% 3 Missing ⚠️
desc/equilibrium/coords.py 77.77% 2 Missing ⚠️
desc/geometry/surface.py 75.00% 1 Missing ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##           master    #1293      +/-   ##
==========================================
- Coverage   95.53%   95.51%   -0.02%     
==========================================
  Files          96       96              
  Lines       24005    24027      +22     
==========================================
+ Hits        22932    22949      +17     
- Misses       1073     1078       +5     
Files with missing lines Coverage Δ
desc/basis.py 98.18% <100.00%> (ø)
desc/geometry/surface.py 96.60% <75.00%> (-0.25%) ⬇️
desc/equilibrium/coords.py 88.25% <77.77%> (-0.13%) ⬇️
desc/backend.py 90.39% <86.36%> (+0.21%) ⬆️

... and 1 file with indirect coverage changes

@dpanici dpanici added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Oct 5, 2024
@dpanici dpanici added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Oct 6, 2024
@YigitElma YigitElma marked this pull request as draft October 6, 2024 18:56
@dpanici dpanici added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Oct 6, 2024
@dpanici dpanici added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Oct 7, 2024
@YigitElma
Copy link
Collaborator Author

YigitElma commented Oct 19, 2024

Check older versions of JAX, they seem to fail for full_output=True case.

@YigitElma YigitElma added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Oct 20, 2024
@YigitElma YigitElma added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX test_mpl labels Oct 20, 2024
desc/basis.py Show resolved Hide resolved
requirements.txt Outdated
@@ -1,8 +1,8 @@
jax[cpu] >= 0.3.2, < 0.5.0
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

minor point but can we keep this in alphabetical order?

Also don't we need to bump the minimum jax version?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we need to bump it up. Are we ok with 0.4.24?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I changed this to make test_jax work better. In the alphabetical order (when jax comes later) pip was having trouble finding proper versions.

I made some other changes after that, maybe this is not an issue anymore but basically this prioritizes jax version over others in a way

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What are the issues with older versions of jax?

Copy link
Collaborator Author

@YigitElma YigitElma Oct 21, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • v0.4.23-v0.4.13 fails due to some scipy function (which is not used in our code directly but when we import a function from jax, this scipy function is used somewhere in the same file and computer cannot read that whole file due to this error)
    image
    I don't know why pip doesn't choose proper scipy version, but it is probably due to jax requirements which I cannot fix.

  • v0.4.12 and older won't even past the pip install phase.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I didn't try that. Let me check

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

also I think we can remove the [cpu] flag for jax, by default it installs the cpu version and this would make it less annoying for installing gpu stuff

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if we remove [cpu] than we need to add jaxlib to the requirements

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think all the recent versions install jaxlib automatically

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

btw, even if I install older scipy version, I get new errors
image
I just checked 0.4.13 locally because I need to add some logic to CI to test it. If you wanna test, feel free to do it here

requirements_conda.yml Outdated Show resolved Hide resolved
desc/backend.py Show resolved Hide resolved
f = _jacobi(n, alpha, beta, x, dx)
df = _jacobi(n, alpha, beta, x, dx + 1)
# in theory n, alpha, beta, dx aren't differentiable (they're integers)
# but marking them as non-diff argnums seems to cause escaped tracer values.
# probably a more elegant fix, but just setting those derivatives to zero seems
# to work fine.
return f, df * xdot + 0 * ndot + 0 * alphadot + 0 * betadot + 0 * dxdot
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I deleted redundant 0 multiplications because in some cases, this gives an error saying float0 cannot be used in math operations...

@YigitElma YigitElma added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Oct 21, 2024
@YigitElma YigitElma added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Oct 21, 2024
@YigitElma YigitElma added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Oct 21, 2024
@YigitElma YigitElma added test_jax Run tests against different versions of JAX and removed test_jax Run tests against different versions of JAX labels Oct 21, 2024
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
bug fix Something was fixed test_jax Run tests against different versions of JAX
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Error in map_coordinates with jax==0.4.34
3 participants