Skip to content

Commit

Permalink
Merge branch 'fieldline_compute' into integrate_on_boundary
Browse files Browse the repository at this point in the history
  • Loading branch information
unalmis committed Jul 2, 2024
2 parents cc01bce + 369a013 commit 56ead44
Show file tree
Hide file tree
Showing 20 changed files with 1,237 additions and 527 deletions.
45 changes: 44 additions & 1 deletion desc/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,8 +71,10 @@
switch = jax.lax.switch
while_loop = jax.lax.while_loop
vmap = jax.vmap
scan = jax.lax.scan
bincount = jnp.bincount
repeat = jnp.repeat
take = jnp.take
scan = jax.lax.scan
from jax import custom_jvp
from jax.experimental.ode import odeint
from jax.scipy.linalg import block_diag, cho_factor, cho_solve, qr, solve_triangular
Expand Down Expand Up @@ -635,6 +637,13 @@ def bincount(x, weights=None, minlength=None, length=None):
"""Same as np.bincount but with a dummy parameter to match jnp.bincount API."""
return np.bincount(x, weights, minlength)

def repeat(a, repeats, axis=None, total_repeat_length=None):
"""A numpy implementation of jnp.repeat."""
out = np.repeat(a, repeats, axis)
if total_repeat_length is not None:
out = out[:total_repeat_length]
return out

def custom_jvp(fun, *args, **kwargs):
"""Dummy function for custom_jvp without JAX."""
fun.defjvp = lambda *args, **kwargs: None
Expand Down Expand Up @@ -744,3 +753,37 @@ def root(
"""
out = scipy.optimize.root(fun, x0, args, jac=jac, tol=tol)
return out.x, out

def take(
a,
indices,
axis=None,
out=None,
mode="fill",
unique_indices=False,
indices_are_sorted=False,
fill_value=None,
):
"""A numpy implementation of jnp.take."""
if mode == "fill":
if fill_value is None:
# copy jax logic
# https://jax.readthedocs.io/en/latest/_modules/jax/_src/lax/slicing.html#gather
if np.issubdtype(a.dtype, np.inexact):
fill_value = np.nan
elif np.issubdtype(a.dtype, np.signedinteger):
fill_value = np.iinfo(a.dtype).min
elif np.issubdtype(a.dtype, np.unsignedinteger):
fill_value = np.iinfo(a.dtype).max
elif a.dtype == np.bool_:
fill_value = True
else:
raise ValueError(f"Unsupported dtype {a.dtype}.")
out = np.where(
(-a.size <= indices) & (indices < a.size),
np.take(a, indices, axis, out, mode="wrap"),
fill_value,
)
else:
out = np.take(a, indices, axis, out, mode)
return out
2 changes: 2 additions & 0 deletions desc/compute/_basis_vectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1514,6 +1514,7 @@ def _e_sup_zeta_zz(params, transforms, profiles, data, **kwargs):
parameterization=[
"desc.equilibrium.equilibrium.Equilibrium",
"desc.geometry.surface.FourierRZToroidalSurface",
"desc.geometry.surface.ZernikeRZToroidalSection",
],
basis="basis",
)
Expand Down Expand Up @@ -3540,6 +3541,7 @@ def _n_zeta(params, transforms, profiles, data, **kwargs):
parameterization=[
"desc.equilibrium.equilibrium.Equilibrium",
"desc.geometry.surface.FourierRZToroidalSurface",
"desc.geometry.surface.ZernikeRZToroidalSection",
],
)
def _e_sub_theta_rp(params, transforms, profiles, data, **kwargs):
Expand Down
1 change: 1 addition & 0 deletions desc/compute/_bootstrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
coordinates="r",
data=["sqrt(g)", "V_r(r)", "|B|", "<|B|^2>", "max_tz |B|"],
axis_limit_data=["sqrt(g)_r", "V_rr(r)"],
resolution_requirement="tz",
n_gauss="int: Number of quadrature points to use for estimating trapped fraction. "
+ "Default 20.",
)
Expand Down
6 changes: 6 additions & 0 deletions desc/compute/_equil.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,6 +376,7 @@ def _J_dot_B(params, transforms, profiles, data, **kwargs):
coordinates="r",
data=["J*sqrt(g)", "B", "V_r(r)"],
axis_limit_data=["(J*sqrt(g))_r", "V_rr(r)"],
resolution_requirement="tz",
)
def _J_dot_B_fsa(params, transforms, profiles, data, **kwargs):
J = transforms["grid"].replace_at_axis(
Expand Down Expand Up @@ -534,6 +535,7 @@ def _Fmag(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="",
data=["|F|", "sqrt(g)", "V"],
resolution_requirement="rtz",
)
def _Fmag_vol(params, transforms, profiles, data, **kwargs):
data["<|F|>_vol"] = (
Expand Down Expand Up @@ -655,6 +657,7 @@ def _F_anisotropic(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="",
data=["|B|", "sqrt(g)"],
resolution_requirement="rtz",
)
def _W_B(params, transforms, profiles, data, **kwargs):
data["W_B"] = jnp.sum(
Expand All @@ -675,6 +678,7 @@ def _W_B(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="",
data=["B", "sqrt(g)"],
resolution_requirement="rtz",
)
def _W_Bpol(params, transforms, profiles, data, **kwargs):
data["W_Bpol"] = jnp.sum(
Expand All @@ -697,6 +701,7 @@ def _W_Bpol(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="",
data=["B", "sqrt(g)"],
resolution_requirement="rtz",
)
def _W_Btor(params, transforms, profiles, data, **kwargs):
data["W_Btor"] = jnp.sum(
Expand All @@ -718,6 +723,7 @@ def _W_Btor(params, transforms, profiles, data, **kwargs):
coordinates="",
data=["p", "sqrt(g)"],
gamma="float: Adiabatic index. Default 0",
resolution_requirement="rtz",
)
def _W_p(params, transforms, profiles, data, **kwargs):
data["W_p"] = jnp.sum(data["p"] * data["sqrt(g)"] * transforms["grid"].weights) / (
Expand Down
14 changes: 12 additions & 2 deletions desc/compute/_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -2598,6 +2598,7 @@ def _grad_B(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="",
data=["sqrt(g)", "|B|", "V"],
resolution_requirement="rtz",
)
def _B_vol(params, transforms, profiles, data, **kwargs):
data["<|B|>_vol"] = (
Expand All @@ -2617,11 +2618,12 @@ def _B_vol(params, transforms, profiles, data, **kwargs):
transforms={"grid": []},
profiles=[],
coordinates="",
data=["sqrt(g)", "|B|", "V"],
data=["sqrt(g)", "|B|^2", "V"],
resolution_requirement="rtz",
)
def _B_rms(params, transforms, profiles, data, **kwargs):
data["<|B|>_rms"] = jnp.sqrt(
jnp.sum(data["|B|"] ** 2 * data["sqrt(g)"] * transforms["grid"].weights)
jnp.sum(data["|B|^2"] * data["sqrt(g)"] * transforms["grid"].weights)
/ data["V"]
)
return data
Expand All @@ -2640,6 +2642,7 @@ def _B_rms(params, transforms, profiles, data, **kwargs):
coordinates="r",
data=["sqrt(g)", "|B|"],
axis_limit_data=["sqrt(g)_r"],
resolution_requirement="tz",
)
def _B_fsa(params, transforms, profiles, data, **kwargs):
data["<|B|>"] = surface_averages(
Expand All @@ -2665,6 +2668,7 @@ def _B_fsa(params, transforms, profiles, data, **kwargs):
coordinates="r",
data=["sqrt(g)", "|B|^2"],
axis_limit_data=["sqrt(g)_r"],
resolution_requirement="tz",
)
def _B2_fsa(params, transforms, profiles, data, **kwargs):
data["<|B|^2>"] = surface_averages(
Expand All @@ -2690,6 +2694,7 @@ def _B2_fsa(params, transforms, profiles, data, **kwargs):
coordinates="r",
data=["sqrt(g)", "|B|"],
axis_limit_data=["sqrt(g)_r"],
resolution_requirement="tz",
)
def _1_over_B_fsa(params, transforms, profiles, data, **kwargs):
data["<1/|B|>"] = surface_averages(
Expand All @@ -2715,6 +2720,7 @@ def _1_over_B_fsa(params, transforms, profiles, data, **kwargs):
coordinates="r",
data=["sqrt(g)", "sqrt(g)_r", "B", "B_r", "|B|^2", "V_r(r)", "V_rr(r)"],
axis_limit_data=["sqrt(g)_rr", "V_rrr(r)"],
resolution_requirement="tz",
)
def _B2_fsa_r(params, transforms, profiles, data, **kwargs):
integrate = surface_integrals_map(transforms["grid"])
Expand Down Expand Up @@ -2877,6 +2883,7 @@ def _gradB2mag(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="",
data=["|grad(|B|^2)|/2mu0", "sqrt(g)", "V"],
resolution_requirement="rtz",
)
def _gradB2mag_vol(params, transforms, profiles, data, **kwargs):
data["<|grad(|B|^2)|/2mu0>_vol"] = (
Expand Down Expand Up @@ -3077,6 +3084,7 @@ def _B_dot_grad_B_mag(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="",
data=["|(B*grad)B|", "sqrt(g)", "V"],
resolution_requirement="rtz",
)
def _B_dot_grad_B_mag_vol(params, transforms, profiles, data, **kwargs):
data["<|(B*grad)B|>_vol"] = (
Expand Down Expand Up @@ -3214,6 +3222,7 @@ def _B_dot_gradB_z(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="r",
data=["|B|"],
resolution_requirement="tz",
)
def _min_tz_modB(params, transforms, profiles, data, **kwargs):
data["min_tz |B|"] = surface_min(transforms["grid"], data["|B|"])
Expand All @@ -3232,6 +3241,7 @@ def _min_tz_modB(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="r",
data=["|B|"],
resolution_requirement="tz",
)
def _max_tz_modB(params, transforms, profiles, data, **kwargs):
data["max_tz |B|"] = surface_max(transforms["grid"], data["|B|"])
Expand Down
57 changes: 25 additions & 32 deletions desc/compute/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from desc.backend import jnp

from ..grid import QuadratureGrid
from .data_index import register_compute_fun
from .utils import cross, dot, line_integrals, safenorm, surface_integrals

Expand All @@ -26,10 +27,15 @@
transforms={"grid": []},
profiles=[],
coordinates="",
data=["sqrt(g)"],
data=["sqrt(g)", "V(r)"],
resolution_requirement="rtz", # If grid has LCFS then don't need radial resolution.
)
def _V(params, transforms, profiles, data, **kwargs):
data["V"] = jnp.sum(data["sqrt(g)"] * transforms["grid"].weights)
if isinstance(transforms["grid"], QuadratureGrid):
data["V"] = jnp.sum(data["sqrt(g)"] * transforms["grid"].weights)
else:
# Use divergence theorem to compute on LCFS.
data["V"] = jnp.max(data["V(r)"])

Check warning on line 38 in desc/compute/_geometry.py

View check run for this annotation

Codecov / codecov/patch

desc/compute/_geometry.py#L38

Added line #L38 was not covered by tests
return data


Expand All @@ -46,6 +52,7 @@ def _V(params, transforms, profiles, data, **kwargs):
coordinates="",
data=["e_theta", "e_zeta", "x"],
parameterization="desc.geometry.surface.FourierRZToroidalSurface",
resolution_requirement="tz",
)
def _V_FourierRZToroidalSurface(params, transforms, profiles, data, **kwargs):
# divergence theorem: integral(dV div [0, 0, Z]) = integral(dS dot [0, 0, Z])
Expand Down Expand Up @@ -73,6 +80,7 @@ def _V_FourierRZToroidalSurface(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="r",
data=["e_theta", "e_zeta", "Z"],
resolution_requirement="tz",
)
def _V_of_r(params, transforms, profiles, data, **kwargs):
# divergence theorem: integral(dV div [0, 0, Z]) = integral(dS dot [0, 0, Z])
Expand All @@ -97,6 +105,7 @@ def _V_of_r(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="r",
data=["sqrt(g)"],
resolution_requirement="tz",
)
def _V_r_of_r(params, transforms, profiles, data, **kwargs):
# eq. 4.9.10 in W.D. D'haeseleer et al. (1991) doi:10.1007/978-3-642-75595-8.
Expand All @@ -117,6 +126,7 @@ def _V_r_of_r(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="r",
data=["sqrt(g)_r"],
resolution_requirement="tz",
)
def _V_rr_of_r(params, transforms, profiles, data, **kwargs):
# The sign of sqrt(g) is enforced to be non-negative.
Expand All @@ -137,6 +147,7 @@ def _V_rr_of_r(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="r",
data=["sqrt(g)_rr"],
resolution_requirement="tz",
)
def _V_rrr_of_r(params, transforms, profiles, data, **kwargs):
# The sign of sqrt(g) is enforced to be non-negative.
Expand All @@ -149,43 +160,21 @@ def _V_rrr_of_r(params, transforms, profiles, data, **kwargs):
label="A(\\zeta)",
units="m^{2}",
units_long="square meters",
description="Cross-sectional area as function of zeta",
description="Enclosed cross-sectional (constant phi surface) area, "
"as function of zeta",
dim=1,
params=[],
transforms={"grid": []},
profiles=[],
coordinates="z",
data=["|e_rho x e_theta|"],
data=["Z", "n_rho", "e_theta|r,p", "rho"],
parameterization=[
"desc.equilibrium.equilibrium.Equilibrium",
"desc.geometry.surface.FourierRZToroidalSurface",
"desc.geometry.surface.ZernikeRZToroidalSection",
],
)
def _A_of_z(params, transforms, profiles, data, **kwargs):
data["A(z)"] = surface_integrals(
transforms["grid"],
data["|e_rho x e_theta|"],
surface_label="zeta",
expand_out=True,
)
return data


@register_compute_fun(
name="A(z)",
label="A(\\zeta)",
units="m^{2}",
units_long="square meters",
description="Enclosed cross-sectional (constant phi surface) area, "
"as function of zeta",
dim=1,
params=[],
transforms={"grid": []},
profiles=[],
coordinates="z",
data=["Z", "n_rho", "e_theta|r,p", "rho"],
parameterization=["desc.geometry.surface.FourierRZToroidalSurface"],
# FIXME: Add source grid requirement once omega is nonzero.
resolution_requirement="rt", # If grid has LCFS then don't need radial resolution.
)
def _A_of_z_FourierRZToroidalSurface(params, transforms, profiles, data, **kwargs):
# Denote any vector v = [vᴿ, v^ϕ, vᶻ] with a tuple of its contravariant components.
Expand Down Expand Up @@ -231,6 +220,7 @@ def _A_of_z_FourierRZToroidalSurface(params, transforms, profiles, data, **kwarg
"desc.equilibrium.equilibrium.Equilibrium",
"desc.geometry.core.Surface",
],
resolution_requirement="z",
)
def _A(params, transforms, profiles, data, **kwargs):
data["A"] = jnp.mean(
Expand Down Expand Up @@ -273,13 +263,12 @@ def _A_of_r(params, transforms, profiles, data, **kwargs):
"desc.equilibrium.equilibrium.Equilibrium",
"desc.geometry.surface.FourierRZToroidalSurface",
],
resolution_requirement="tz",
)
def _S(params, transforms, profiles, data, **kwargs):
data["S"] = jnp.max(
surface_integrals(
transforms["grid"],
data["|e_theta x e_zeta|"],
expand_out=False,
transforms["grid"], data["|e_theta x e_zeta|"], expand_out=False
)
)
return data
Expand All @@ -297,6 +286,7 @@ def _S(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="r",
data=["|e_theta x e_zeta|"],
resolution_requirement="tz",
)
def _S_of_r(params, transforms, profiles, data, **kwargs):
data["S(r)"] = surface_integrals(transforms["grid"], data["|e_theta x e_zeta|"])
Expand All @@ -315,6 +305,7 @@ def _S_of_r(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="r",
data=["|e_theta x e_zeta|_r"],
resolution_requirement="tz",
)
def _S_r_of_r(params, transforms, profiles, data, **kwargs):
data["S_r(r)"] = surface_integrals(transforms["grid"], data["|e_theta x e_zeta|_r"])
Expand All @@ -334,6 +325,7 @@ def _S_r_of_r(params, transforms, profiles, data, **kwargs):
profiles=[],
coordinates="r",
data=["|e_theta x e_zeta|_rr"],
resolution_requirement="tz",
)
def _S_rr_of_r(params, transforms, profiles, data, **kwargs):
data["S_rr(r)"] = surface_integrals(
Expand Down Expand Up @@ -424,6 +416,7 @@ def _R0_over_a(params, transforms, profiles, data, **kwargs):
"desc.equilibrium.equilibrium.Equilibrium",
"desc.geometry.core.Surface",
],
resolution_requirement="rt", # just need r near lcfs
)
def _perimeter_of_z(params, transforms, profiles, data, **kwargs):
max_rho = jnp.max(data["rho"])
Expand Down
Loading

0 comments on commit 56ead44

Please sign in to comment.