Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

PEST coordinate system basis vectors #1090

Merged
merged 22 commits into from
Jul 22, 2024
Merged
Show file tree
Hide file tree
Changes from 14 commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
42e6112
Genearlize PEST coordinate system compute funs to genearlized toroida…
unalmis Jul 1, 2024
b65351a
BUGFIX: don't change basis to xyz erroneously assuming input is in rpz
unalmis Jul 1, 2024
4c4553a
BUGFIX: don't change basis to xyz erroneously assuming input is in rpz
unalmis Jul 1, 2024
a145e6f
Fix issue with merge
unalmis Jul 1, 2024
455a9a3
Fix issue with merge
unalmis Jul 1, 2024
bd840b9
Merge branch 'basis_bug' into sfl_coordinates
unalmis Jul 1, 2024
031d29e
Add guards to compute funs that should be fixed in GitHub pull reques…
unalmis Jul 1, 2024
571f187
add missing basis kwarg to e_phi|R,Z
unalmis Jul 2, 2024
41033d7
Merge branch 'basis_bug' into sfl_coordinates
unalmis Jul 2, 2024
7aa74ce
review suggestions
unalmis Jul 2, 2024
e7e6572
Merge branch 'master' into basis_bug
dpanici Jul 3, 2024
4266028
Merge branch 'basis_bug' into sfl_coordinates
unalmis Jul 3, 2024
4192967
Review requests
unalmis Jul 3, 2024
35ec354
Remove e_phi|R,Z
unalmis Jul 3, 2024
eebeb90
Merge branch 'master' into basis_bug
unalmis Jul 5, 2024
6341542
Merge branch 'master' into basis_bug
unalmis Jul 5, 2024
92717c0
Merge branch 'basis_bug' into sfl_coordinates
unalmis Jul 5, 2024
449d332
Remove code that merge didn't remove
unalmis Jul 5, 2024
33e1e81
Use master_compute_data from master
unalmis Jul 5, 2024
8e9f43a
Merge branch 'master' into sfl_coordinates
unalmis Jul 11, 2024
a2c5cf9
Fix merge conflicts from basing branch of now closed basis_bug
unalmis Jul 11, 2024
2dad8f8
Merge branch 'master' into sfl_coordinates
unalmis Jul 20, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
239 changes: 87 additions & 152 deletions desc/compute/_basis_vectors.py

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions desc/compute/_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from desc.backend import jnp

from ..utils import errorif
from .data_index import register_compute_fun
from .geom_utils import rotation_matrix, rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec
from .utils import cross, dot, safenormalize
Expand Down Expand Up @@ -156,6 +157,9 @@ def _phi_Curve(params, transforms, profiles, data, **kwargs):
parameterization="desc.geometry.core.Curve",
)
def _Z_Curve(params, transforms, profiles, data, **kwargs):
errorif(
kwargs.get("basis", "rpz").lower() not in {"rpz", "xyz"}, NotImplementedError
)
data["Z"] = data["x"][:, 2]
return data

Expand Down
8 changes: 8 additions & 0 deletions desc/compute/_equil.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from desc.backend import jnp

from ..utils import errorif
from .data_index import register_compute_fun
from .utils import cross, dot, safediv, safenorm, surface_averages

Expand Down Expand Up @@ -231,6 +232,7 @@ def _J_sqrt_g_r(params, transforms, profiles, data, **kwargs):
data=["J"],
)
def _J_R(params, transforms, profiles, data, **kwargs):
errorif(kwargs.get("basis", "rpz").lower() != "rpz", NotImplementedError)
data["J_R"] = data["J"][:, 0]
return data

Expand All @@ -249,6 +251,7 @@ def _J_R(params, transforms, profiles, data, **kwargs):
data=["J"],
)
def _J_phi(params, transforms, profiles, data, **kwargs):
errorif(kwargs.get("basis", "rpz").lower() != "rpz", NotImplementedError)
data["J_phi"] = data["J"][:, 1]
return data

Expand All @@ -267,6 +270,9 @@ def _J_phi(params, transforms, profiles, data, **kwargs):
data=["J"],
)
def _J_Z(params, transforms, profiles, data, **kwargs):
errorif(
kwargs.get("basis", "rpz").lower() not in {"rpz", "xyz"}, NotImplementedError
)
data["J_Z"] = data["J"][:, 2]
return data

Expand Down Expand Up @@ -677,6 +683,7 @@ def _W_B(params, transforms, profiles, data, **kwargs):
data=["B", "sqrt(g)"],
)
def _W_Bpol(params, transforms, profiles, data, **kwargs):
errorif(kwargs.get("basis", "rpz").lower() != "rpz", NotImplementedError)
data["W_Bpol"] = jnp.sum(
dot(data["B"][:, (0, 2)], data["B"][:, (0, 2)])
* data["sqrt(g)"]
Expand All @@ -699,6 +706,7 @@ def _W_Bpol(params, transforms, profiles, data, **kwargs):
data=["B", "sqrt(g)"],
)
def _W_Btor(params, transforms, profiles, data, **kwargs):
errorif(kwargs.get("basis", "rpz").lower() != "rpz", NotImplementedError)
data["W_Btor"] = jnp.sum(
data["B"][:, 1] ** 2 * data["sqrt(g)"] * transforms["grid"].weights
) / (2 * mu_0)
Expand Down
6 changes: 6 additions & 0 deletions desc/compute/_field.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

from desc.backend import jnp

from ..utils import errorif
from .data_index import register_compute_fun
from .utils import (
cross,
Expand Down Expand Up @@ -140,6 +141,7 @@ def _B(params, transforms, profiles, data, **kwargs):
data=["B"],
)
def _B_R(params, transforms, profiles, data, **kwargs):
errorif(kwargs.get("basis", "rpz").lower() != "rpz", NotImplementedError)
data["B_R"] = data["B"][:, 0]
return data

Expand All @@ -158,6 +160,7 @@ def _B_R(params, transforms, profiles, data, **kwargs):
data=["B"],
)
def _B_phi(params, transforms, profiles, data, **kwargs):
errorif(kwargs.get("basis", "rpz").lower() != "rpz", NotImplementedError)
data["B_phi"] = data["B"][:, 1]
return data

Expand All @@ -176,6 +179,9 @@ def _B_phi(params, transforms, profiles, data, **kwargs):
data=["B"],
)
def _B_Z(params, transforms, profiles, data, **kwargs):
errorif(
kwargs.get("basis", "rpz").lower() not in {"rpz", "xyz"}, NotImplementedError
)
data["B_Z"] = data["B"][:, 2]
return data

Expand Down
8 changes: 8 additions & 0 deletions desc/compute/_geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@

from desc.backend import jnp

from ..utils import errorif
from .data_index import register_compute_fun
from .utils import cross, dot, line_integrals, surface_integrals

Expand Down Expand Up @@ -48,6 +49,9 @@ def _V(params, transforms, profiles, data, **kwargs):
parameterization="desc.geometry.surface.FourierRZToroidalSurface",
)
def _V_FourierRZToroidalSurface(params, transforms, profiles, data, **kwargs):
errorif(
kwargs.get("basis", "rpz").lower() not in {"rpz", "xyz"}, NotImplementedError
)
# divergence theorem: integral(dV div [0, 0, Z]) = integral(dS dot [0, 0, Z])
data["V"] = jnp.max( # take max in case there are multiple surfaces for some reason
jnp.abs(
Expand Down Expand Up @@ -75,6 +79,9 @@ def _V_FourierRZToroidalSurface(params, transforms, profiles, data, **kwargs):
data=["e_theta", "e_zeta", "Z"],
)
def _V_of_r(params, transforms, profiles, data, **kwargs):
errorif(
kwargs.get("basis", "rpz").lower() not in {"rpz", "xyz"}, NotImplementedError
)
# divergence theorem: integral(dV div [0, 0, Z]) = integral(dS dot [0, 0, Z])
data["V(r)"] = jnp.abs(
surface_integrals(
Expand Down Expand Up @@ -186,6 +193,7 @@ def _A_of_z(params, transforms, profiles, data, **kwargs):
parameterization=["desc.geometry.surface.FourierRZToroidalSurface"],
)
def _A_of_z_FourierRZToroidalSurface(params, transforms, profiles, data, **kwargs):
errorif(kwargs.get("basis", "rpz").lower() != "rpz", NotImplementedError)
# divergence theorem: integral(dA div [0, 0, Z]) = integral(ds n dot [0, 0, Z])
# but we need only the part of n in the R,Z plane
n = data["n_rho"]
Expand Down
12 changes: 8 additions & 4 deletions desc/compute/_metric.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,17 +40,21 @@ def _sqrtg(params, transforms, profiles, data, **kwargs):
label="\\sqrt{g}_{PEST}",
units="m^{3}",
units_long="cubic meters",
description="Jacobian determinant of PEST flux coordinate system",
description="Jacobian determinant of (ρ,ϑ,ϕ) coordinate system or"
" straight field line PEST coordinates. ϕ increases counterclockwise"
" when viewed from above (cylindrical R,ϕ plane with Z out of page).",
dim=1,
params=[],
transforms={},
profiles=[],
coordinates="rtz",
data=["e_rho", "e_theta_PEST", "e_phi"],
data=["sqrt(g)", "theta_PEST_t", "phi_z", "theta_PEST_z", "phi_t"],
)
def _sqrtg_pest(params, transforms, profiles, data, **kwargs):
data["sqrt(g)_PEST"] = dot(
data["e_rho"], cross(data["e_theta_PEST"], data["e_phi"])
# Same as dot(data["e_rho|v,p"], cross(data["e_vartheta"], data["e_phi|r,v"])), but
# more efficient as it avoids computing radial derivatives of the stream functions.
data["sqrt(g)_PEST"] = data["sqrt(g)"] / (
data["theta_PEST_t"] * data["phi_z"] - data["theta_PEST_z"] * data["phi_t"]
)
return data

Expand Down
7 changes: 7 additions & 0 deletions desc/compute/_surface.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
from desc.backend import jnp

from ..utils import errorif
from .data_index import register_compute_fun
from .geom_utils import rpz2xyz, xyz2rpz, xyz2rpz_vec

# TODO: review when zeta no longer equals phi


@register_compute_fun(
name="x",
Expand All @@ -26,6 +29,7 @@
def _x_FourierRZToroidalSurface(params, transforms, profiles, data, **kwargs):
R = transforms["R"].transform(params["R_lmn"])
Z = transforms["Z"].transform(params["Z_lmn"])
# TODO: change when zeta no longer equals phi
phi = transforms["grid"].nodes[:, 2]
coords = jnp.stack([R, phi, Z], axis=1)
if kwargs.get("basis", "rpz").lower() == "xyz":
Expand Down Expand Up @@ -217,6 +221,9 @@ def _phi_z_Surface(params, transforms, profiles, data, **kwargs):
parameterization="desc.geometry.core.Surface",
)
def _Z_Surface(params, transforms, profiles, data, **kwargs):
errorif(
kwargs.get("basis", "rpz").lower() not in {"rpz", "xyz"}, NotImplementedError
)
data["Z"] = data["x"][:, 2]
return data

Expand Down
Binary file modified tests/inputs/master_compute_data.pkl
Binary file not shown.
14 changes: 14 additions & 0 deletions tests/test_compute_funs.py
Original file line number Diff line number Diff line change
Expand Up @@ -1683,3 +1683,17 @@ def test_surface_equilibrium_geometry():
rtol=3e-13,
atol=1e-13,
)


@pytest.mark.unit
def test_basis_kwarg():
"""Test that we don't change basis to xyz erroneously assuming input is in rpz."""
eq = get("W7-X")
names = ["b", "e^rho", "e_theta_PEST", "phi"]
data_rpz = eq.compute(names, basis="rpz")
data_xyz = eq.compute(names, basis="xyz")
for name in names:
if name != "phi":
np.testing.assert_allclose(
rpz2xyz_vec(data_rpz[name], phi=data_rpz["phi"]), data_xyz[name]
)
15 changes: 13 additions & 2 deletions tests/test_data_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import desc.compute
from desc.compute import data_index
from desc.compute.data_index import _class_inheritance
from desc.utils import errorif


class TestDataIndex:
Expand Down Expand Up @@ -39,6 +40,11 @@ def get_parameterization(fun, default="desc.equilibrium.equilibrium.Equilibrium"
matches.discard("")
return matches if matches else {default}

@staticmethod
def _is_function(func):
# JITed functions are not functions according to inspect.
return inspect.isfunction(func) or callable(func)
unalmis marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.unit
def test_data_index_deps(self):
"""Ensure developers do not add extra (or forget needed) dependencies.
Expand Down Expand Up @@ -74,7 +80,7 @@ def test_data_index_deps(self):
pattern_params = re.compile(r"params\[(.*?)]")
for module_name, module in inspect.getmembers(desc.compute, inspect.ismodule):
if module_name[0] == "_":
for _, fun in inspect.getmembers(module, inspect.isfunction):
for _, fun in inspect.getmembers(module, self._is_function):
# quantities that this function computes
names = self.get_matches(fun, pattern_names)
# dependencies queried in source code of this function
Expand All @@ -97,7 +103,6 @@ def test_data_index_deps(self):

for p in data_index:
for name, val in data_index[p].items():
print(name)
err_msg = f"Parameterization: {p}. Name: {name}."
deps = val["dependencies"]
data = set(deps["data"])
Expand All @@ -111,6 +116,12 @@ def test_data_index_deps(self):
assert len(profiles) == len(deps["profiles"]), err_msg
assert len(params) == len(deps["params"]), err_msg
# assert correct dependencies are queried
errorif(
name not in queried_deps[p],
AssertionError,
"Did you reuse the function name (i.e. def_...) for"
f" '{name}' for some other quantity?",
)
assert queried_deps[p][name]["data"] == data | axis_limit_data, err_msg
assert queried_deps[p][name]["profiles"] == profiles, err_msg
assert queried_deps[p][name]["params"] == params, err_msg