Switch back to old version of the zernike radial function #998
| benchmark_name | dt(%) | dt(s) | t_new(s) | t_old(s) |
| -------------------------------------- | ---------------------- | ---------------------- | ---------------------- | ---------------------- |
+test_build_transform_fft_lowres | -69.99 +/- 2.73 | -1.21e+00 +/- 4.71e-02 | 5.18e-01 +/- 1.1e-02 | 1.73e+00 +/- 4.6e-02 |
+test_build_transform_fft_midres | -69.21 +/- 2.24 | -1.33e+00 +/- 4.32e-02 | 5.94e-01 +/- 5.3e-03 | 1.93e+00 +/- 4.3e-02 |
+test_build_transform_fft_highres | -57.20 +/- 0.78 | -1.33e+00 +/- 1.82e-02 | 9.95e-01 +/- 9.2e-03 | 2.33e+00 +/- 1.6e-02 |
+test_equilibrium_init_lowres | -55.98 +/- 1.12 | -5.24e+00 +/- 1.05e-01 | 4.12e+00 +/- 9.3e-02 | 9.35e+00 +/- 4.8e-02 |
+test_equilibrium_init_medres | -54.08 +/- 1.40 | -5.44e+00 +/- 1.41e-01 | 4.62e+00 +/- 9.5e-02 | 1.01e+01 +/- 1.0e-01 |
+test_equilibrium_init_highres | -45.74 +/- 1.81 | -5.37e+00 +/- 2.12e-01 | 6.37e+00 +/- 8.1e-02 | 1.17e+01 +/- 2.0e-01 |
test_objective_compile_dshape_current | -1.50 +/- 7.71 | -5.57e-02 +/- 2.86e-01 | 3.65e+00 +/- 3.6e-02 | 3.71e+00 +/- 2.8e-01 |
test_objective_compile_atf | -1.34 +/- 2.97 | -9.54e-02 +/- 2.12e-01 | 7.03e+00 +/- 2.0e-01 | 7.12e+00 +/- 7.3e-02 |
test_objective_compute_dshape_current | +1.05 +/- 2.44 | +4.09e-05 +/- 9.52e-05 | 3.95e-03 +/- 6.8e-05 | 3.91e-03 +/- 6.7e-05 |
test_objective_compute_atf | -2.36 +/- 2.30 | -4.30e-04 +/- 4.19e-04 | 1.78e-02 +/- 3.5e-04 | 1.82e-02 +/- 2.2e-04 |
test_objective_jac_dshape_current | -5.85 +/- 4.50 | -2.49e-03 +/- 1.91e-03 | 4.00e-02 +/- 1.2e-03 | 4.25e-02 +/- 1.5e-03 |
test_objective_jac_atf | -3.83 +/- 2.45 | -7.53e-02 +/- 4.80e-02 | 1.89e+00 +/- 2.7e-02 | 1.96e+00 +/- 4.0e-02 |
+test_perturb_1 | -19.44 +/- 3.25 | -2.74e+00 +/- 4.58e-01 | 1.13e+01 +/- 1.2e-01 | 1.41e+01 +/- 4.4e-01 |
+test_perturb_2 | -14.23 +/- 1.75 | -2.69e+00 +/- 3.31e-01 | 1.62e+01 +/- 8.0e-02 | 1.89e+01 +/- 3.2e-01 |
test_proximal_jac_atf | -1.89 +/- 1.73 | -1.37e-01 +/- 1.25e-01 | 7.11e+00 +/- 1.0e-01 | 7.25e+00 +/- 7.1e-02 |
test_proximal_freeb_compute | +1.80 +/- 1.29 | +2.29e-03 +/- 1.63e-03 | 1.29e-01 +/- 1.2e-03 | 1.27e-01 +/- 1.1e-03 |
test_proximal_freeb_jac | +0.50 +/- 1.26 | +3.62e-02 +/- 9.05e-02 | 7.21e+00 +/- 5.0e-02 | 7.18e+00 +/- 7.6e-02 | |
The compute_all test is failing due to small numerical differences between the two methods, but I don't understand why the map_coordinates_derivative test fails. I replaced the original function, which was capable of both forward- and reverse-mode autodiff. |
@f0uriest Can you take a look? I switched back to the old version, but it has some AD issues. The compute_all failure is not a problem; it is due to the numerical method. |
It seems like the problem is
desc/compute/utils.py:405: in get_transforms
t.build()
desc/transform.py:385: in build
self.matrices["direct1"][d[0]][d[1]][d[2]] = self.basis.evaluate(
desc/basis.py:1132: in evaluate
radial = zernike_radial(r[:, np.newaxis], lm[:, 0], lm[:, 1], dr=derivatives[0])
desc/basis.py:1520: in zernike_radial
out = r**m * _jacobi(n, alpha, beta, jacobi_arg, 0)
../../miniconda3/envs/desc-env/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:736: in op
return getattr(self.aval, f"_{name}")(self, *args)
../../miniconda3/envs/desc-env/lib/python3.12/site-packages/jax/_src/numpy/array_methods.py:264: in deferring_binary_op
return binary_op(*args)
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _
> return _power(x1, x2)
E jax._src.source_info_util.JaxStackTraceBeforeTransformation: AssertionError
E
E The preceding stack trace is the source of the JAX operation that, once transformed by JAX, triggered the following exception.
E
E --------------------
../../miniconda3/envs/desc-env/lib/python3.12/site-packages/jax/_src/numpy/ufuncs.py:355: JaxStackTraceBeforeTransformation
The above exception was the direct cause of the following exception: |
If I just run jacfwd and jit on the zernike_radial function, I get the following output:
import jax
import numpy as np
from desc.basis import zernike_radial, ZernikePolynomial
basis = ZernikePolynomial(L=10, M=10, spectral_indexing="ansi", sym="cos")
r = np.linspace(0, 1, 100)
J1 = jax.jacfwd(zernike_radial)(r[:, np.newaxis], basis.modes[:,0], basis.modes[:,1], 0)
J2 = jax.jit(jax.jacrev(zernike_radial))(r[:, np.newaxis], basis.modes[:,0], basis.modes[:,1])
J3 = jax.jit(jax.jacrev(zernike_radial))(r[:, np.newaxis], basis.modes[:,0], basis.modes[:,1], 0)
So it seems that if you jit a function that already has a partial(jit, static_argnums=...) decorator by wrapping it again with jax.jit(func), the outer jit does not honor the inner static_argnums. jit works as expected only if the parameter dr is not given. However, the two error messages are quite different from each other: the former complains about the r**m*_jacobi part, while the latter says dr is not static but traced. I am confused.
Here, I am calling jit directly on zernike_radial, but I get similar behavior for nested functions. For example,
import jax
import functools
import jax.numpy as jnp

@functools.partial(jax.jit, static_argnums=1)
def fun(x, dx=0):
    if dx == 0:
        return x
    else:
        return x**dx

@jax.jit
def main(t, dx=0):
    x = t**2
    return fun(x, dx)

J1 = jax.jit(main)(jnp.array([1, 2, 3]), 2)
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[53], line 1
----> 1 asd = jax.jit(main)(jnp.array([1, 2, 3]), 2)
      2 print(asd)
[... skipping hidden 24 frame]
Cell In[44], line 14
     11 @jax.jit
     12 def main(t, dx=0):
     13     x = t**2
---> 14     return fun(x, dx)
[... skipping hidden 12 frame]
Cell In[44], line 6
      4 @functools.partial(jax.jit, static_argnums=1)
      5 def fun(x, dx=0):
----> 6     if dx == 0:
      7         return x
      8     else:
[... skipping hidden 1 frame]
File ~/miniconda3/envs/desc-env/lib/python3.12/site-packages/jax/_src/core.py:1492, in concretization_function_error.<locals>.error(self, arg)
   1491 def error(self, arg):
-> 1492     raise TracerBoolConversionError(arg)
TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function fun at /tmp/ipykernel_31334/1112096422.py:4 for jit. This value became a tracer due to JAX operations on these lines:
  operation a:bool[] = eq b c
    from line /tmp/ipykernel_31334/1112096422.py:6:7 (fun)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError |
Yes, if you jit a function, the outer jit basically ignores any other jit decorators inside that function. So you need to specify dr as static at the outermost level. But in our case it's fine, since dr is always passed as a static value. The |
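A minimal sketch of the point above, reusing the toy fun/main functions from the earlier comment. The name main_static is hypothetical and only for illustration; the idea is that marking dx as static on the outermost jit keeps it a concrete Python int all the way down, so the inner `if dx == 0` branch can be evaluated. The exact exception raised by the failing call may differ between JAX versions, so the sketch catches it broadly.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnums=1)
def fun(x, dx=0):
    # dx must be a concrete Python value for this branch to be evaluated.
    if dx == 0:
        return x
    return x**dx


def main(t, dx=0):
    x = t**2
    return fun(x, dx)


t = jnp.array([1.0, 2.0, 3.0])

# Outer jit without static_argnums: dx is traced inside main, so the
# Python-level `if` in fun cannot be resolved.
try:
    jax.jit(main)(t, 2)
    traced_call_failed = False
except Exception:  # exact exception type may vary across JAX versions
    traced_call_failed = True

# Outer jit with dx marked static (main_static is a hypothetical name):
# dx stays a plain Python int in both main and fun.
main_static = jax.jit(main, static_argnums=1)
result = main_static(t, 2)
print(traced_call_failed, result)
```

The same pattern should apply when the outer transformation is jax.jit(jax.jacrev(...)): the static argument has to be declared on the outermost jit, or bound beforehand with functools.partial or a closure.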
Hmm, I checked the branch with the old dependencies (the ones changed in PR #883). The tests pass as expected. I think either
and it works. So, we would have faced this issue when we bumped dependencies even if we had not changed the And also during the process, I realized that installing |
Could this be related to this change? jax-ml/jax@69ad4df Apparently, this was implemented right after 0.4.14. |
Codecov Report
Attention: Patch coverage is
Additional details and impacted files
@@ Coverage Diff @@
## master #998 +/- ##
========================================
Coverage 94.92% 94.92%
========================================
Files 87 87
Lines 22003 21849 -154
========================================
- Hits 20886 20741 -145
+ Misses 1117 1108 -9
|
How does this affect, say, an eq solve on CPU and GPU? Just want to remember what the difference was. |
You can refer to issue #941; I have a bunch of benchmarks there. For an actual equilibrium solve, it shouldn't change much. |
Also, a reminder for myself: I tested calling the version in this PR with different derivative orders multiple times. Since the derivative order is a static argument, JAX has to compile the function again for every new value of that argument. In the previous version, all derivative orders were compiled together, so it didn't require multiple compilations. My concern was that JAX might replace the compiled function with the newly compiled one each time the static argument changes, and hence recompile every time we change the argument. Fortunately, the answer is no: JAX stores every compiled version in its cache and calls the one matching the static argument. I don't know the impact of this on memory consumption, but it is worth keeping in mind. Yes, static_argnums is cool, but it comes at a cost. |
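The caching behavior described above can be checked with a small sketch. The function toy_radial and the list trace_log are hypothetical stand-ins, not part of DESC; the sketch relies only on the fact that Python side effects inside a jitted function run at trace time, so each entry in trace_log corresponds to one trace/compile.

```python
import functools

import jax
import jax.numpy as jnp

trace_log = []  # records the static value each time JAX traces the function


@functools.partial(jax.jit, static_argnums=1)
def toy_radial(r, dr=0):
    """Hypothetical stand-in for a function with a static derivative order."""
    trace_log.append(dr)  # Python side effect: runs only while tracing
    if dr == 0:
        return r**3
    return 3 * r**2


r = jnp.linspace(0.0, 1.0, 5)
toy_radial(r, 0)  # traces and compiles the dr=0 version
toy_radial(r, 1)  # traces and compiles the dr=1 version
toy_radial(r, 0)  # cache hit: the dr=0 version is reused
toy_radial(r, 1)  # cache hit: the dr=1 version is reused
print(trace_log)
```

Each distinct static value (and input shape/dtype) adds one cached executable, which is where the memory concern mentioned above would come from.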
Until we decide whether to use a new repo for zernike_radial, we switch back to the old version of the code.
Resolves #941