Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make map_coordinates differentiable for JAX 0.4.34 #1293

Open
wants to merge 42 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
c3f399d
fix the custom_jvp problem with new jax version, we don't need to mul…
YigitElma Oct 4, 2024
641207d
Merge branch 'master' into yge/customjvp_fix
YigitElma Oct 4, 2024
7e2c884
revert back to latest JAX version
YigitElma Oct 4, 2024
ec9a9f8
update jax_test versions
YigitElma Oct 4, 2024
b71afc6
have jax_tests use python 3.10 so can test recent versions
dpanici Oct 6, 2024
979a5fb
fix python version
dpanici Oct 6, 2024
449eb70
revert back to multiplying by 0s
YigitElma Oct 6, 2024
2a16943
Merge branch 'yge/customjvp_fix' of github.com:PlasmaControl/DESC int…
YigitElma Oct 6, 2024
21e5f8d
fix python version syntax
YigitElma Oct 6, 2024
8137caf
fix jax test to install jax first and then decide other dependency ve…
YigitElma Oct 6, 2024
1b48da4
update matplotlib latest version on mpl test, force 3.7.2 for none ca…
YigitElma Oct 6, 2024
52957ad
fix jax test
YigitElma Oct 6, 2024
531c5c4
fix incorrect name of dev-requirements
dpanici Oct 7, 2024
c4e75f7
update jax dependency test
YigitElma Oct 7, 2024
743b0fe
Merge branch 'yge/customjvp_fix' of github.com:PlasmaControl/DESC int…
YigitElma Oct 7, 2024
fed17e5
clean-up the dependency installation and print installed dependencies…
YigitElma Oct 7, 2024
e225308
take jax dependency to the top of the file
YigitElma Oct 7, 2024
9631353
add custom_jvp to zernike_radial directly, this won't work for jax 0.…
YigitElma Oct 7, 2024
891c914
fix missing argument dr problem
YigitElma Oct 7, 2024
ff3a0e9
re-add the _jacobi custom_jvp for the test but actually not needed fo…
YigitElma Oct 7, 2024
b4218ac
Merge branch 'master' into yge/customjvp_fix
YigitElma Oct 11, 2024
49c2ac8
use nondiff_argnums for zernike_radial custom_jvp
YigitElma Oct 15, 2024
e55d84e
Merge branch 'yge/customjvp_fix' of github.com:PlasmaControl/DESC int…
YigitElma Oct 15, 2024
4c87756
revert back to old version, nondiff creates some unexpected tracers f…
YigitElma Oct 15, 2024
7256bce
Merge branch 'master' into yge/customjvp_fix
YigitElma Oct 15, 2024
937a547
fix auxilary return values in a hacky way until JAX people reply
YigitElma Oct 19, 2024
7660d57
Merge branch 'master' into yge/customjvp_fix
YigitElma Oct 19, 2024
92bf3d1
apply the same fix to root_scalar, add docstring, apply same to numpy…
YigitElma Oct 19, 2024
e10b2e4
fix constant_offset_surface function
YigitElma Oct 19, 2024
c40eaf1
make full_output case also differentiable, increase coverage
YigitElma Oct 19, 2024
b7675e6
fix zeta phi problem causing nans
YigitElma Oct 20, 2024
6364dba
make test_map_coordinates_derivative test different cases of map_coor…
YigitElma Oct 20, 2024
b376c67
move matplotlib changes to new PR
YigitElma Oct 20, 2024
b6061ce
move matplotlib changes to new PR
YigitElma Oct 20, 2024
895589a
add root and root_scalar tests as well as their derivatives
YigitElma Oct 20, 2024
3d94830
revert float64 stuff
YigitElma Oct 21, 2024
7a276e5
bump minimum version of jax to 0.4.24
YigitElma Oct 21, 2024
31d0f8a
Merge branch 'master' into yge/customjvp_fix
YigitElma Oct 21, 2024
0542f0e
just checking jax version
YigitElma Oct 21, 2024
3c83c04
try to test scipy
YigitElma Oct 21, 2024
732a951
back to previous version
YigitElma Oct 21, 2024
ee24838
fix conda requirements
YigitElma Oct 21, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 19 additions & 18 deletions .github/workflows/jax_tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -12,34 +12,35 @@ jobs:
strategy:
fail-fast: false
matrix:
jax-version: [0.3.0, 0.3.1, 0.3.2, 0.3.3, 0.3.4, 0.3.5,
0.3.6, 0.3.7, 0.3.8, 0.3.9, 0.3.10, 0.3.11,
0.3.12, 0.3.13, 0.3.14, 0.3.15, 0.3.16, 0.3.17,
0.3.19, 0.3.20, 0.3.21, 0.3.22, 0.3.23, 0.3.24,
0.3.25, 0.4.1, 0.4.2, 0.4.3, 0.4.4, 0.4.5,
0.4.6, 0.4.7, 0.4.8, 0.4.9, 0.4.10, 0.4.11,
0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17, 0.4.18,
jax-version: [0.4.12, 0.4.13, 0.4.14, 0.4.16, 0.4.17,
0.4.18, 0.4.19, 0.4.20, 0.4.21, 0.4.22, 0.4.23,
0.4.24, 0.4.25]
0.4.24, 0.4.25, 0.4.26, 0.4.27, 0.4.28, 0.4.29,
0.4.30, 0.4.31, 0.4.33, 0.4.34]
# 0.4.32 is not available on PyPI
# earlier jax versions are not compatible with other
# dependencies as of 2024-10-04
group: [1, 2]
steps:
- uses: actions/checkout@v4
- name: Set up Python 3.9
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: 3.9
cache: pip
- name: Install dependencies
python-version: '3.10'
- name: Upgrade pip
run: |
python -m pip install --upgrade pip
pip install -r devtools/dev-requirements.txt
pip install matplotlib==3.5.0
- name: Remove jax
- name: Install dependencies with given JAX version
run: |
pip uninstall jax jaxlib -y
- name: install jax
sed -i '/jax/d' ./requirements.txt
sed -i '1i\jax[cpu] == ${{ matrix.jax-version }}' ./requirements.txt
cat ./requirements.txt
pip install -r ./devtools/dev-requirements.txt
pip install matplotlib==3.7.2
- name: Verify dependencies
run: |
pip install "jax[cpu]==${{ matrix.jax-version }}"
python --version
pip --version
pip list
- name: Test with pytest
run: |
pwd
Expand Down
73 changes: 56 additions & 17 deletions desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,9 +87,12 @@
treedef_is_leaf,
)

trapezoid = (
jnp.trapezoid if hasattr(jnp, "trapezoid") else jax.scipy.integrate.trapezoid
)
if hasattr(jnp, "trapezoid"):
trapezoid = jnp.trapezoid # for JAX 0.4.26 and later
elif hasattr(jax.scipy, "integrate"):
trapezoid = jax.scipy.integrate.trapezoid

Check warning on line 93 in desc/backend.py

View check run for this annotation

Codecov / codecov/patch

desc/backend.py#L92-L93

Added lines #L92 - L93 were not covered by tests
else:
trapezoid = jnp.trapz # for older versions of JAX, deprecated by jax 0.4.16

Check warning on line 95 in desc/backend.py

View check run for this annotation

Codecov / codecov/patch

desc/backend.py#L95

Added line #L95 was not covered by tests

def execute_on_cpu(func):
"""Decorator to set default device to CPU for a function.
Expand Down Expand Up @@ -198,6 +201,7 @@
maxiter_ls=5,
alpha=0.1,
fixup=None,
full_output=False,
):
"""Find x where fun(x, *args) == 0.

Expand Down Expand Up @@ -225,6 +229,9 @@
fixup : callable, optional
Function to modify x after each update, ie to enforce periodicity. Should
have a signature of the form fixup(x, *args) -> x'.
full_output : bool, optional
If True, also return a tuple where the first element is the residual from
the root finding and the second is the number of iterations.

Returns
-------
Expand Down Expand Up @@ -269,18 +276,25 @@
xk1, fk1 = backtrack(xk1, fk1, d)
return xk1, fk1, k1 + 1

state = guess, res(guess), 0
state = guess, res(guess), 0.0
f0uriest marked this conversation as resolved.
Show resolved Hide resolved
state = jax.lax.while_loop(condfun, bodyfun, state)
return state[0], state[1:]
if full_output:
return state[0], state[1:]
else:
return state[0]

def tangent_solve(g, y):
A = jax.jacfwd(g)(y)
return y / A

x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
)
return x, (abs(res), niter)
if full_output:
x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
)
return x, (abs(res), niter)
else:
x = jax.lax.custom_root(res, x0, solve, tangent_solve, has_aux=False)
return x

def root(
fun,
Expand All @@ -292,6 +306,7 @@
maxiter_ls=0,
alpha=0.1,
fixup=None,
full_output=False,
):
"""Find x where fun(x, *args) == 0.

Expand Down Expand Up @@ -319,6 +334,9 @@
fixup : callable, optional
Function to modify x after each update, ie to enforce periodicity. Should
have a signature of the form fixup(x, *args) -> 1d array.
full_output : bool, optional
If True, also return a tuple where the first element is the residual from
the root finding and the second is the number of iterations.

Returns
-------
Expand Down Expand Up @@ -386,19 +404,26 @@
state = (
jnp.atleast_1d(jnp.asarray(guess)),
jnp.atleast_1d(resfun(guess)),
0,
0.0,
)
state = jax.lax.while_loop(condfun, bodyfun, state)
return state[0], state[1:]
if full_output:
return state[0], state[1:]
else:
return state[0]

def tangent_solve(g, y):
A = jnp.atleast_2d(jax.jacfwd(g)(y))
return _lstsq(A, jnp.atleast_1d(y))

x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
)
return x, (safenorm(res), niter)
if full_output:
x, (res, niter) = jax.lax.custom_root(
res, x0, solve, tangent_solve, has_aux=True
)
return x, (safenorm(res), niter)
else:
x = jax.lax.custom_root(res, x0, solve, tangent_solve, has_aux=False)
return x


# we can't really test the numpy backend stuff in automated testing, so we ignore it
Expand Down Expand Up @@ -708,6 +733,7 @@
maxiter_ls=5,
alpha=0.1,
fixup=None,
full_output=False,
):
"""Find x where fun(x, *args) == 0.

Expand Down Expand Up @@ -735,6 +761,9 @@
fixup : callable, optional
Function to modify x after each update, ie to enforce periodicity. Should
have a signature of the form fixup(x) -> x'.
full_output : bool, optional
If True, also return a tuple where the first element is the residual from
the root finding and the second is the number of iterations.

Returns
-------
Expand All @@ -747,7 +776,10 @@
out = scipy.optimize.root_scalar(
fun, args, x0=x0, fprime=jac, xtol=tol, rtol=tol
)
return out.root, out
if full_output:
return out.root, out
else:
return out.root

def root(
fun,
Expand All @@ -759,6 +791,7 @@
maxiter_ls=0,
alpha=0.1,
fixup=None,
full_output=False,
):
"""Find x where fun(x, *args) == 0.

Expand Down Expand Up @@ -786,6 +819,9 @@
fixup : callable, optional
Function to modify x after each update, ie to enforce periodicity. Should
have a signature of the form fixup(x, *args) -> 1d array.
full_output : bool, optional
If True, also return a tuple where the first element is the residual from
the root finding and the second is the number of iterations.

Returns
-------
Expand All @@ -800,7 +836,10 @@
will solve it in a least squares sense.
"""
out = scipy.optimize.root(fun, x0, args, jac=jac, tol=tol)
return out.x, out
if full_output:
return out.x, out
else:
return out.x

def flatnonzero(a, size=None, fill_value=0):
"""A numpy implementation of jnp.flatnonzero."""
Expand Down
12 changes: 6 additions & 6 deletions desc/basis.py
YigitElma marked this conversation as resolved.
Show resolved Hide resolved
Original file line number Diff line number Diff line change
Expand Up @@ -1573,7 +1573,7 @@ def zernike_radial(r, l, m, dr=0):
"Analytic radial derivatives of Zernike polynomials for order>4 "
+ "have not been implemented."
)
return s * jnp.where((l - m) % 2 == 0, out, 0)
return s * jnp.where((l - m) % 2 == 0, out, 0.0)


def power_coeffs(l):
Expand Down Expand Up @@ -1732,7 +1732,7 @@ def _binom_body_fun(i, b_n):
return b


@custom_jvp
@functools.partial(custom_jvp, nondiff_argnums=(4,))
@jit
@jnp.vectorize
def _jacobi(n, alpha, beta, x, dx=0):
Expand Down Expand Up @@ -1804,13 +1804,13 @@ def _jacobi_body_fun(kk, d_p_a_b_x):


@_jacobi.defjvp
def _jacobi_jvp(x, xdot):
(n, alpha, beta, x, dx) = x
(ndot, alphadot, betadot, xdot, dxdot) = xdot
def _jacobi_jvp(dx, x, xdot):
(n, alpha, beta, x) = x
(*_, xdot) = xdot
f = _jacobi(n, alpha, beta, x, dx)
df = _jacobi(n, alpha, beta, x, dx + 1)
# in theory n, alpha, beta, dx aren't differentiable (they're integers)
# but marking them as non-diff argnums seems to cause escaped tracer values.
# probably a more elegant fix, but just setting those derivatives to zero seems
# to work fine.
return f, df * xdot + 0 * ndot + 0 * alphadot + 0 * betadot + 0 * dxdot
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I deleted redundant 0 multiplications because in some cases, this gives an error saying float0 cannot be used in math operations...

return f, df * xdot
36 changes: 27 additions & 9 deletions desc/equilibrium/coords.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,14 +208,18 @@
fixup=fixup,
tol=tol,
maxiter=maxiter,
full_output=full_output,
**kwargs,
)
)
)
# See description here
# https://github.com/PlasmaControl/DESC/pull/504#discussion_r1194172532
# except we make sure properly handle periodic coordinates.
yk, (res, niter) = vecroot(yk, coords)
if full_output:
yk, (res, niter) = vecroot(yk, coords)

Check warning on line 220 in desc/equilibrium/coords.py

View check run for this annotation

Codecov / codecov/patch

desc/equilibrium/coords.py#L220

Added line #L220 was not covered by tests
else:
yk = vecroot(yk, coords)

out = compute(yk, outbasis)
if full_output:
Expand Down Expand Up @@ -363,18 +367,28 @@
fixup=fixup,
tol=tol,
maxiter=maxiter,
full_output=full_output,
**kwargs,
)
)
)
rho, theta_PEST, zeta = coords.T
theta, (res, niter) = vecroot(
# Assume λ=0 for default initial guess.
setdefault(guess, theta_PEST),
theta_PEST,
rho,
zeta,
)
if full_output:
theta, (res, niter) = vecroot(

Check warning on line 377 in desc/equilibrium/coords.py

View check run for this annotation

Codecov / codecov/patch

desc/equilibrium/coords.py#L377

Added line #L377 was not covered by tests
# Assume λ=0 for default initial guess.
setdefault(guess, theta_PEST),
theta_PEST,
rho,
zeta,
)
else:
theta = vecroot(
# Assume λ=0 for default initial guess.
setdefault(guess, theta_PEST),
theta_PEST,
rho,
zeta,
)
out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta])
if full_output:
return out, (res, niter)
Expand Down Expand Up @@ -466,6 +480,7 @@
fixup=fixup,
tol=tol,
maxiter=maxiter,
full_output=full_output,
**kwargs,
)
)
Expand All @@ -474,7 +489,10 @@
if guess is None:
# Assume λ=0 for default initial guess.
guess = alpha + iota * zeta
theta, (res, niter) = vecroot(guess, alpha, rho, zeta, iota)
if full_output:
theta, (res, niter) = vecroot(guess, alpha, rho, zeta, iota)
else:
theta = vecroot(guess, alpha, rho, zeta, iota)

out = jnp.column_stack([rho, jnp.atleast_1d(theta.squeeze()), zeta])
if full_output:
Expand Down
15 changes: 12 additions & 3 deletions desc/geometry/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -741,10 +741,19 @@
n, r, r_offset = n_and_r_jax(nodes)
return jnp.arctan(r_offset[0, 1] / r_offset[0, 0]) - zeta

vecroot = jit(vmap(lambda x0, *p: root_scalar(fun_jax, x0, jac=None, args=p)))
zetas, (res, niter) = vecroot(
grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2]
vecroot = jit(
vmap(
lambda x0, *p: root_scalar(
fun_jax, x0, jac=None, args=p, full_output=full_output
)
)
)
if full_output:
zetas, (res, niter) = vecroot(
grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2]
)
else:
zetas = vecroot(grid.nodes[:, 2], grid.nodes[:, 1], grid.nodes[:, 2])

Check warning on line 756 in desc/geometry/surface.py

View check run for this annotation

Codecov / codecov/patch

desc/geometry/surface.py#L756

Added line #L756 was not covered by tests

zetas = np.asarray(zetas)
nodes = np.vstack((np.ones_like(grid.nodes[:, 1]), grid.nodes[:, 1], zetas)).T
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
jax[cpu] >= 0.4.13, < 0.5.0
colorama
diffrax >= 0.4.1
h5py >= 3.0.0, < 4.0
interpax >= 0.3.3
jax[cpu] >= 0.3.2, < 0.4.34
matplotlib >= 3.5.0, < 4.0.0
mpmath >= 1.0.0, < 2.0
netcdf4 >= 1.5.4, < 2.0
Expand Down
2 changes: 1 addition & 1 deletion requirements_conda.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ dependencies:
- pip:
# Conda only parses a single list of pip requirements.
# If two pip lists are given, all but the last list is skipped.
- jax[cpu] >= 0.4.13, < 0.5.0
- interpax >= 0.3.3
- jax[cpu] >= 0.3.2, < 0.5.0
- nvgpu
- orthax
- plotly >= 5.16, < 6.0
Expand Down
Loading