Skip to content

Commit

Permalink
Merge branch 'master' into map_coords_missing_params
Browse files Browse the repository at this point in the history
  • Loading branch information
dpanici authored Apr 30, 2024
2 parents 1d97cab + 0e7a195 commit 2d4afa5
Show file tree
Hide file tree
Showing 4 changed files with 105 additions and 1 deletion.
1 change: 1 addition & 0 deletions desc/magnetic_fields/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""Classes for Magnetic Fields."""

from ._core import (
MagneticFieldFromUser,
OmnigenousField,
PoloidalMagneticField,
ScalarPotentialField,
Expand Down
81 changes: 80 additions & 1 deletion desc/magnetic_fields/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
from desc.optimizable import Optimizable, OptimizableCollection, optimizable_parameter
from desc.singularities import compute_B_plasma
from desc.transform import Transform
from desc.utils import copy_coeffs, flatten_list, setdefault, warnif
from desc.utils import copy_coeffs, errorif, flatten_list, setdefault, warnif
from desc.vmec_utils import ptolemy_identity_fwd, ptolemy_identity_rev


Expand Down Expand Up @@ -540,6 +540,85 @@ def save_mgrid(
file.close()


class MagneticFieldFromUser(_MagneticField, Optimizable):
"""Wrap an arbitrary function for calculating magnetic field in lab coordinates.
Parameters
----------
fun : callable
Function to compute magnetic field at arbitrary points. Should have a signature
of the form ``fun(coords, params) -> B`` where
- ``coords`` is a (n,3) array of positions in R, phi, Z coordinates where
the field is to be evaluated.
- ``params`` is an array of optional parameters, eg for optimizing the field.
- ``B`` is the returned value of the magnetic field as a (n,3) array in R,
phi, Z coordinates.
params : ndarray, optional
Default values for parameters. Defaults to an empty array.
"""

def __init__(self, fun, params=None):
errorif(not callable(fun), ValueError, "fun must be callable")
self._params = jnp.asarray(setdefault(params, jnp.array([])))

import jax

dummy_coords = np.empty((7, 3))
dummy_B = jax.eval_shape(fun, dummy_coords, self.params)
errorif(
dummy_B.shape != (7, 3),
ValueError,
"fun should return an array of the same shape as coords",
)
self._fun = fun

@optimizable_parameter
@property
def params(self):
"""ndarray: Parameters of the field allowed to vary during optimization."""
return self._params

@params.setter
def params(self, params):
self._params = params

def compute_magnetic_field(
self, coords, params=None, basis="rpz", source_grid=None
):
"""Compute magnetic field at a set of points.
Parameters
----------
coords : array-like shape(n,3)
Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates.
params : array-like, optional
Optimizable parameters, defaults to field.params.
basis : {"rpz", "xyz"}
Basis for input coordinates and returned magnetic field.
source_grid : Grid, int or None or array-like, optional
Unused by this class, only kept for API compatibility
Returns
-------
field : ndarray, shape(N,3)
magnetic field at specified points
"""
coords = jnp.atleast_2d(jnp.asarray(coords))
if params is None:
params = self.params
if basis == "xyz":
coords = xyz2rpz(coords)

B = self._fun(coords, params)
if basis == "xyz":
B = rpz2xyz_vec(B, phi=coords[:, 1])
return B


class ScaledMagneticField(_MagneticField, Optimizable):
"""Magnetic field scaled by a scalar value.
Expand Down
1 change: 1 addition & 0 deletions docs/api_fields.rst
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ a surface and computes the normal field strength on that surface.

desc.magnetic_fields.SplineMagneticField
desc.magnetic_fields.DommaschkPotentialField
desc.magnetic_fields.MagneticFieldFromUser
desc.magnetic_fields.ScalarPotentialField
desc.magnetic_fields.ToroidalMagneticField
desc.magnetic_fields.VerticalMagneticField
Expand Down
23 changes: 23 additions & 0 deletions tests/test_magnetic_fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
CurrentPotentialField,
DommaschkPotentialField,
FourierCurrentPotentialField,
MagneticFieldFromUser,
OmnigenousField,
PoloidalMagneticField,
ScalarPotentialField,
Expand Down Expand Up @@ -64,6 +65,28 @@ def test_basic_fields(self):
(tfield + vfield - pfield)([1, 0, 0.1]), [[0.4, 2, 1]]
)

@pytest.mark.unit
def test_field_from_user(self):
"""Test for MagneticFieldFromUser."""
tfield = ToroidalMagneticField(2, 1)

def fun(coords, params):
R0, B0 = params
coords = jnp.atleast_2d(jnp.asarray(coords))
bp = B0 * R0 / coords[:, 0]
brz = jnp.zeros_like(bp)
B = jnp.array([brz, bp, brz]).T
return B

ufield = MagneticFieldFromUser(fun, [tfield.R0, tfield.B0])
np.testing.assert_allclose(
tfield([1, 0, 0]), ufield([1, 0, 0], params=[tfield.R0, tfield.B0])
)
np.testing.assert_allclose(
tfield([1, 1, 0], basis="xyz"),
ufield([1, 1, 0], params=[tfield.R0, tfield.B0], basis="xyz"),
)

@pytest.mark.unit
def test_combined_fields(self):
"""Tests for sum/scaled fields."""
Expand Down

0 comments on commit 2d4afa5

Please sign in to comment.