Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove basis kwarg from _compute function calls in magnetic field classes #1137

Merged
merged 3 commits into from
Jul 24, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 3 additions & 2 deletions desc/coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,12 +248,13 @@ def compute_magnetic_field(
else:
data = compute_fun(
self,
name=["x", "x_s", "ds"],
names=["x", "x_s", "ds"],
params=params,
transforms=transforms,
profiles={},
basis="xyz",
)
data["x_s"] = rpz2xyz_vec(data["x_s"], phi=data["x"][:, 1])
data["x"] = rpz2xyz(data["x"])

B = biot_savart_quad(
coords, data["x"], data["x_s"] * data["ds"][:, None], current
Expand Down
9 changes: 3 additions & 6 deletions desc/magnetic_fields/_current_potential.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from desc.backend import fori_loop, jnp
from desc.basis import DoubleFourierSeries
from desc.compute import rpz2xyz, rpz2xyz_vec, xyz2rpz, xyz2rpz_vec
from desc.compute import rpz2xyz, rpz2xyz_vec, xyz2rpz_vec
from desc.compute.utils import _compute as compute_fun
from desc.geometry import FourierRZToroidalSurface
from desc.grid import LinearGrid
Expand Down Expand Up @@ -646,12 +646,10 @@ def _compute_magnetic_field_from_CurrentPotentialField(

# compute surface current, and store grid quantities
# needed for integration in class
# TODO: does this have to be xyz, or can it be computed in rpz as well?
if not params or not transforms:
data = field.compute(
["K", "x"],
grid=source_grid,
basis="xyz",
params=params,
transforms=transforms,
jitable=True,
Expand All @@ -663,11 +661,10 @@ def _compute_magnetic_field_from_CurrentPotentialField(
params=params,
transforms=transforms,
profiles={},
basis="xyz",
)

_rs = xyz2rpz(data["x"])
_K = xyz2rpz_vec(data["K"], phi=source_grid.nodes[:, 2])
_rs = data["x"]
_K = data["K"]

# surface element, must divide by NFP to remove the NFP multiple on the
# surface grid weights, as we account for that when doing the for loop
Expand Down
9 changes: 7 additions & 2 deletions tests/test_coils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
MixedCoilSet,
SplineXYZCoil,
)
from desc.compute import get_transforms, xyz2rpz, xyz2rpz_vec
from desc.compute import get_params, get_transforms, xyz2rpz, xyz2rpz_vec
from desc.examples import get
from desc.geometry import FourierRZCurve, FourierRZToroidalSurface
from desc.grid import LinearGrid
Expand Down Expand Up @@ -45,8 +45,13 @@ def test_biot_savart_all_coils(self):
# FourierXYZCoil
coil = FourierXYZCoil(I)
transforms = get_transforms(["x", "x_s", "ds"], coil, coil_grid)
params = get_params(["x", "x_s", "ds"], coil)
B_xyz = coil.compute_magnetic_field(
grid_xyz, basis="xyz", source_grid=coil_grid, transforms=transforms
grid_xyz,
basis="xyz",
source_grid=coil_grid,
transforms=transforms,
params=params,
)
B_rpz = coil.compute_magnetic_field(
grid_rpz, basis="rpz", source_grid=coil_grid
Expand Down
11 changes: 10 additions & 1 deletion tests/test_magnetic_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from desc.backend import jit, jnp
from desc.basis import DoubleFourierSeries
from desc.compute import rpz2xyz_vec, xyz2rpz_vec
from desc.compute.utils import get_params, get_transforms
from desc.examples import get
from desc.geometry import FourierRZToroidalSurface
from desc.grid import LinearGrid
Expand Down Expand Up @@ -340,8 +341,16 @@ def test_fourier_current_potential_field(self):
field.change_resolution(3, 3)
field.change_Phi_resolution(2, 2)

params = get_params(["K", "x"], field)
transforms = get_transforms(["K", "x"], field, grid=surface_grid)

np.testing.assert_allclose(
field.compute_magnetic_field([10.0, 0, 0], source_grid=surface_grid),
field.compute_magnetic_field(
[10.0, 0, 0],
source_grid=surface_grid,
params=params,
transforms=transforms,
),
correct_field(10.0, 0, 0),
atol=1e-16,
rtol=1e-8,
Expand Down
Loading