Switch back to old version of the zernike radial function #998

Merged
YigitElma merged 13 commits into master from yge/zernike
Apr 29, 2024

Conversation

YigitElma
Collaborator

Until we decide whether to use the new repo for zernike_radial, this PR switches back to the old version of the code.
Resolves #941

@PlasmaControl PlasmaControl deleted a comment from review-notebook-app bot Apr 17, 2024
@YigitElma YigitElma self-assigned this Apr 17, 2024
Contributor

github-actions bot commented Apr 18, 2024

|             benchmark_name             |         dt(%)          |         dt(s)          |        t_new(s)        |        t_old(s)        | 
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
+test_build_transform_fft_lowres         |    -69.99 +/- 2.73     | -1.21e+00 +/- 4.71e-02 |  5.18e-01 +/- 1.1e-02  |  1.73e+00 +/- 4.6e-02  |
+test_build_transform_fft_midres         |    -69.21 +/- 2.24     | -1.33e+00 +/- 4.32e-02 |  5.94e-01 +/- 5.3e-03  |  1.93e+00 +/- 4.3e-02  |
+test_build_transform_fft_highres        |    -57.20 +/- 0.78     | -1.33e+00 +/- 1.82e-02 |  9.95e-01 +/- 9.2e-03  |  2.33e+00 +/- 1.6e-02  |
+test_equilibrium_init_lowres            |    -55.98 +/- 1.12     | -5.24e+00 +/- 1.05e-01 |  4.12e+00 +/- 9.3e-02  |  9.35e+00 +/- 4.8e-02  |
+test_equilibrium_init_medres            |    -54.08 +/- 1.40     | -5.44e+00 +/- 1.41e-01 |  4.62e+00 +/- 9.5e-02  |  1.01e+01 +/- 1.0e-01  |
+test_equilibrium_init_highres           |    -45.74 +/- 1.81     | -5.37e+00 +/- 2.12e-01 |  6.37e+00 +/- 8.1e-02  |  1.17e+01 +/- 2.0e-01  |
 test_objective_compile_dshape_current   |     -1.50 +/- 7.71     | -5.57e-02 +/- 2.86e-01 |  3.65e+00 +/- 3.6e-02  |  3.71e+00 +/- 2.8e-01  |
 test_objective_compile_atf              |     -1.34 +/- 2.97     | -9.54e-02 +/- 2.12e-01 |  7.03e+00 +/- 2.0e-01  |  7.12e+00 +/- 7.3e-02  |
 test_objective_compute_dshape_current   |     +1.05 +/- 2.44     | +4.09e-05 +/- 9.52e-05 |  3.95e-03 +/- 6.8e-05  |  3.91e-03 +/- 6.7e-05  |
 test_objective_compute_atf              |     -2.36 +/- 2.30     | -4.30e-04 +/- 4.19e-04 |  1.78e-02 +/- 3.5e-04  |  1.82e-02 +/- 2.2e-04  |
 test_objective_jac_dshape_current       |     -5.85 +/- 4.50     | -2.49e-03 +/- 1.91e-03 |  4.00e-02 +/- 1.2e-03  |  4.25e-02 +/- 1.5e-03  |
 test_objective_jac_atf                  |     -3.83 +/- 2.45     | -7.53e-02 +/- 4.80e-02 |  1.89e+00 +/- 2.7e-02  |  1.96e+00 +/- 4.0e-02  |
+test_perturb_1                          |    -19.44 +/- 3.25     | -2.74e+00 +/- 4.58e-01 |  1.13e+01 +/- 1.2e-01  |  1.41e+01 +/- 4.4e-01  |
+test_perturb_2                          |    -14.23 +/- 1.75     | -2.69e+00 +/- 3.31e-01 |  1.62e+01 +/- 8.0e-02  |  1.89e+01 +/- 3.2e-01  |
 test_proximal_jac_atf                   |     -1.89 +/- 1.73     | -1.37e-01 +/- 1.25e-01 |  7.11e+00 +/- 1.0e-01  |  7.25e+00 +/- 7.1e-02  |
 test_proximal_freeb_compute             |     +1.80 +/- 1.29     | +2.29e-03 +/- 1.63e-03 |  1.29e-01 +/- 1.2e-03  |  1.27e-01 +/- 1.1e-03  |
 test_proximal_freeb_jac                 |     +0.50 +/- 1.26     | +3.62e-02 +/- 9.05e-02 |  7.21e+00 +/- 5.0e-02  |  7.18e+00 +/- 7.6e-02  |

@YigitElma YigitElma requested a review from f0uriest April 18, 2024 03:06
@YigitElma
Collaborator Author

The compute_all test is failing due to small numerical differences between the two methods, but I don't understand why the map_coordinates_derivative test fails. I put back the original function, which was capable of both forward- and reverse-mode autodiff.

@YigitElma
Collaborator Author

@f0uriest Can you take a look? I switched back to the old version, but it has some AD issues. The compute_all test failure is not a concern; it is due to the numerical method.

@YigitElma
Collaborator Author

YigitElma commented Apr 24, 2024

It seems like the problem is the r**m part, but I don't know why that was not an issue before.

desc/compute/utils.py:405: in get_transforms
    t.build()
desc/transform.py:385: in build
    self.matrices["direct1"][d[0]][d[1]][d[2]] = self.basis.evaluate(
desc/basis.py:1132: in evaluate
    radial = zernike_radial(r[:, np.newaxis], lm[:, 0], lm[:, 1], dr=derivatives[0])
desc/basis.py:1520: in zernike_radial
    out = r**m * _jacobi(n, alpha, beta, jacobi_arg, 0)
../../miniconda3/envs/desc-env/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:736: in op
    return getattr(self.aval, f"_{name}")(self, *args)
../../miniconda3/envs/desc-env/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:264: in deferring_binary_op
    return binary_op(*args)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

>   return _power(x1, x2)
E   jax._src.source_info_util.JaxStackTraceBeforeTransformation: AssertionError
E   
E   The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
E   
E   --------------------

../../miniconda3/envs/desc-env/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py:355: JaxStackTraceBeforeTransformation

The above exception was the direct cause of the following exception:
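
For readers unfamiliar with the failing line, here is a minimal sketch of what an expression like `r**m * _jacobi(n, alpha, beta, jacobi_arg, 0)` evaluates. It assumes the standard Jacobi-polynomial form of the Zernike radial polynomial, with alpha = |m|, beta = 0 and n = (l - |m|) / 2. The helper `zernike_radial_ref` is hypothetical and uses SciPy's `eval_jacobi`; it is not DESC's implementation and does not support autodiff.

```python
import numpy as np
from scipy.special import eval_jacobi


def zernike_radial_ref(r, l, m):
    """Reference Zernike radial polynomial R_l^|m|(r) via a Jacobi polynomial.

    Hypothetical helper for illustration only (not part of DESC).
    """
    m = abs(m)
    # The radial polynomial is zero unless l >= |m| and l - |m| is even.
    if l < m or (l - m) % 2 != 0:
        return np.zeros_like(r)
    n = (l - m) // 2
    jacobi_arg = 1 - 2 * r**2
    return (-1) ** n * r**m * eval_jacobi(n, m, 0, jacobi_arg)


r = np.linspace(0, 1, 5)
R20 = zernike_radial_ref(r, 2, 0)  # closed form: 2 r^2 - 1
R31 = zernike_radial_ref(r, 3, 1)  # closed form: 3 r^3 - 2 r
R40 = zernike_radial_ref(r, 4, 0)  # closed form: 6 r^4 - 6 r^2 + 1
```

Comparing against the closed-form low-order polynomials is a quick sanity check that is independent of which numerical method is used.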

@YigitElma
Collaborator Author

YigitElma commented Apr 24, 2024

If I just run jacfwd and jit on the zernike_radial function, I get the following output:

import jax
from desc.basis import zernike_radial, ZernikePolynomial

basis = ZernikePolynomial(L=10, M=10, spectral_indexing="ansi", sym="cos")
r = np.linspace(0, 1, 100)  # assumes "import numpy as np" ran in an earlier cell

J1 = jax.jacfwd(zernike_radial)(r[:, np.newaxis], basis.modes[:,0], basis.modes[:,1], 0)
J2 = jax.jit(jax.jacrev(zernike_radial))(r[:, np.newaxis], basis.modes[:,0], basis.modes[:,1])
J3 = jax.jit(jax.jacrev(zernike_radial))(r[:, np.newaxis], basis.modes[:,0], basis.modes[:,1], 0)
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[52], line 9
      J1 = jax.jacfwd(zernike_radial)(r[:, np.newaxis], basis.modes[:,0], basis.modes[:,1], 0)
      J2 = jax.jit(jax.jacrev(zernike_radial))(r[:, np.newaxis], basis.modes[:,0], basis.modes[:,1])
---->  J3 = jax.jit(jax.jacrev(zernike_radial))(r[:, np.newaxis], basis.modes[:,0], basis.modes[:,1], 0)

    [... skipping hidden 31 frame]

File ~/Codes/DESC/desc/basis.py:1519, in zernike_radial(r, l, m, dr)
   1517 s = (-1) ** n
   1518 jacobi_arg = 1 - 2 * r**2
-> 1519 if dr == 0:
   1520    out = jnp.pow(r, m) * _jacobi(n, alpha, beta, jacobi_arg, 0)
   1521 elif dr == 1:

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/desc-env/lib/python3.12/site-packages/jax/_src/core.py:1492, in concretization_function_error.<locals>.error(self, arg)
   1491 def error(self, arg):
-> 1492 raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function zernike_radial at /home/yigit/Codes/DESC/desc/basis.py:1487 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:bool[] = eq b c
    from line /home/yigit/Codes/DESC/desc/basis.py:1519:7 (zernike_radial)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

So it seems that if you jit a function that already carries a partial(jit, static_argnums=...) decorator by wrapping it in jax.jit(func), rather than just calling it, the outer jit does not honor the static_argnums setting. jit works as expected only when the parameter dr is not passed.

These two error messages are quite different from each other, though: the former complains about the r**m * _jacobi part, while the latter says dr is traced rather than static. I am confused.

Here I am applying jit directly to zernike_radial, but I get similar behavior for nested functions. For example:
import jax
import functools
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=1)
def fun(x, dx=0):
    if dx == 0:
        return x
    else:
        return x**dx
    
@jax.jit
def main(t, dx=0):
    x = t**2
    return fun(x, dx)

J1 = jax.jit(main)(jnp.array([1, 2, 3]), 2)
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[53], line 1
----> 1 asd = jax.jit(main)(jnp.array([1, 2, 3]), 2)
      2 print(asd)

    [... skipping hidden 24 frame]

Cell In[44], line 14
     11 @jax.jit
     12 def main(t, dx=0):
     13     x = t**2
---> 14     return fun(x, dx)

    [... skipping hidden 12 frame]

Cell In[44], line 6
      4 @functools.partial(jax.jit, static_argnums=1)
      5 def fun(x, dx=0):
----> 6     if dx == 0:
      7         return x
      8     else:

    [... skipping hidden 1 frame]

File ~/miniconda3/envs/desc-env/lib/python3.12/site-packages/jax/_src/core.py:1492, in concretization_function_error.<locals>.error(self, arg)
   1491 def error(self, arg):
-> 1492   raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function fun at /tmp/ipykernel_31334/1112096422.py:4 for jit. This value became a tracer due to JAX operations on these lines:

  operation a:bool[] = eq b c
    from line /tmp/ipykernel_31334/1112096422.py:6:7 (fun)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

@f0uriest
Member

Yes, if you jit a function it basically ignores any other jit decorators that are inside that function. So you need to specify dr as static at the outermost level. But in our case it's fine since the values of dr are always passed statically.

The r**m thing might be related to this?: #961 (comment)
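
The point about marking the argument static at the outermost level can be sketched with the toy `fun`/`main` pair from the earlier comment. This is a minimal illustration, assuming the only change is moving `static_argnums` onto the outer `jax.jit` call; the name `main_jit` is hypothetical.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=1)
def fun(x, dx=0):
    # dx drives Python control flow, so it must be a concrete value here.
    if dx == 0:
        return x
    return x**dx


def main(t, dx=0):
    x = t**2
    return fun(x, dx)


# Declaring dx static on the outermost jit keeps it a plain Python int
# all the way down, so the inner `if dx == 0` never sees a tracer.
main_jit = jax.jit(main, static_argnums=1)
out = main_jit(jnp.array([1.0, 2.0, 3.0]), 2)  # values 1, 16, 81
```

Note that arguments listed in `static_argnums` have to be passed positionally, as in the call above.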

@YigitElma
Collaborator Author

YigitElma commented Apr 25, 2024

Hmm, I checked the branch with the old dependencies that were changed in PR #883. The tests pass as expected. I think JAX, numpy, or scipy changed something that causes the issue. I only changed the current requirements to

jax[cpu] >= 0.3.2, <= 0.4.14
numpy >= 1.20.0, < 1.25.0
scipy >= 1.5.0, < 1.11.0

and it works. So we would have faced this issue when we bumped the dependencies even if we had not changed zernike_radial to my version. I can look into it in detail.

Also, during the process I realized that installing jax[cpu]==0.4.14 on Python 3.12 causes the issue even without the different dependency versions mentioned in #1004.

@YigitElma
Collaborator Author

> Yes, if you jit a function it basically ignores any other jit decorators that are inside that function. So you need to specify dr as static at the outermost level. But in our case it's fine since the values of dr are always passed statically.
>
> The r**m thing might be related to this?: #961 (comment)

Could this be related to this change: jax-ml/jax@69ad4df? Apparently, it was implemented right after 0.4.14.


codecov bot commented Apr 25, 2024

Codecov Report

Attention: Patch coverage is 98.78049%, with 1 line in your changes missing coverage. Please review.

Project coverage is 94.92%. Comparing base (83c231a) to head (35ec25b).

Additional details and impacted files
@@           Coverage Diff            @@
##           master     #998    +/-   ##
========================================
  Coverage   94.92%   94.92%            
========================================
  Files          87       87            
  Lines       22003    21849   -154     
========================================
- Hits        20886    20741   -145     
+ Misses       1117     1108     -9     
| Files | Coverage Δ |
| --- | --- |
| desc/equilibrium/initial_guess.py | 97.90% <ø> (-0.02%) ⬇️ |
| desc/objectives/linear_objectives.py | 94.61% <ø> (ø) |
| desc/vmec_utils.py | 83.02% <100.00%> (ø) |
| desc/basis.py | 98.34% <98.75%> (+0.98%) ⬆️ |

... and 1 file with indirect coverage changes

@dpanici
Collaborator

dpanici commented Apr 26, 2024

How does this affect, say, an eq solve on CPU and GPU? I just want to remember what the difference was.

@YigitElma
Collaborator Author

> How does this affect, say, an eq solve on CPU and GPU? I just want to remember what the difference was.

You can refer to issue #941; I have a bunch of benchmarks there. For an actual equilibrium solve, it shouldn't change much.

@YigitElma
Collaborator Author

Also, a reminder for myself: I tested running different derivative orders multiple times with the version in this PR. Since it has a static argument, JAX has to jit the function again for every different value of that argument. For the previous version, all derivative orders are compiled at the same time, so it doesn't require multiple compilations. My concern was whether JAX replaces the jitted version of the function with a new one whenever the static argument changes, and hence recompiles the function every time we change the argument. Fortunately, the answer is no: it compiles the function once per static value, stores every version in the cache, and calls the one with the matching static argument. I don't know the impact of this on memory consumption, but it is worth keeping in mind. Yes, static_argnums is cool, but it does have an effect.
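
The caching behavior described above can be checked with a small sketch. It relies on the fact that Python side effects inside a jitted function run only while JAX traces it, so a log of traced values shows when a new version is built. The function `power` and the list `trace_log` are hypothetical stand-ins, not DESC code.

```python
import functools

import jax
import jax.numpy as jnp

trace_log = []  # appended to only while JAX is tracing `power`


@functools.partial(jax.jit, static_argnums=1)
def power(x, dr):
    trace_log.append(dr)  # Python side effect: runs at trace time only
    if dr == 0:
        return x
    return x**dr


x = jnp.arange(4.0)
# Cycle through the static values twice with the same input shape and dtype.
for dr in [0, 1, 2, 0, 1, 2]:
    power(x, dr)

# Each static value is traced once; the second pass reuses cached versions.
print(trace_log)  # [0, 1, 2]
```

A new trace would also be triggered by a change in the shape or dtype of `x`, so the cache grows with the number of distinct static values and input signatures.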

@YigitElma YigitElma merged commit 8b2fd29 into master Apr 29, 2024
20 checks passed
@YigitElma YigitElma deleted the yge/zernike branch April 29, 2024 19:04
Development

Successfully merging this pull request may close these issues.

Regression in _initial_guess_surface from v0.10.4
3 participants