diff --git a/desc/magnetic_fields/__init__.py b/desc/magnetic_fields/__init__.py index 094fbc724a..0a8f18abd8 100644 --- a/desc/magnetic_fields/__init__.py +++ b/desc/magnetic_fields/__init__.py @@ -1,6 +1,7 @@ """Classes for Magnetic Fields.""" from ._core import ( + MagneticFieldFromUser, OmnigenousField, PoloidalMagneticField, ScalarPotentialField, diff --git a/desc/magnetic_fields/_core.py b/desc/magnetic_fields/_core.py index 6eee8f3da2..6900d9cbd5 100644 --- a/desc/magnetic_fields/_core.py +++ b/desc/magnetic_fields/_core.py @@ -24,7 +24,7 @@ from desc.optimizable import Optimizable, OptimizableCollection, optimizable_parameter from desc.singularities import compute_B_plasma from desc.transform import Transform -from desc.utils import copy_coeffs, flatten_list, setdefault, warnif +from desc.utils import copy_coeffs, errorif, flatten_list, setdefault, warnif from desc.vmec_utils import ptolemy_identity_fwd, ptolemy_identity_rev @@ -540,6 +540,85 @@ def save_mgrid( file.close() +class MagneticFieldFromUser(_MagneticField, Optimizable): + """Wrap an arbitrary function for calculating magnetic field in lab coordinates. + + Parameters + ---------- + fun : callable + Function to compute magnetic field at arbitrary points. Should have a signature + of the form ``fun(coords, params) -> B`` where + + - ``coords`` is a (n,3) array of positions in R, phi, Z coordinates where + the field is to be evaluated. + - ``params`` is an array of optional parameters, eg for optimizing the field. + - ``B`` is the returned value of the magnetic field as a (n,3) array in R, + phi, Z coordinates. + + params : ndarray, optional + Default values for parameters. Defaults to an empty array. + + """ + + def __init__(self, fun, params=None): + errorif(not callable(fun), ValueError, "fun must be callable") + self._params = jnp.asarray(setdefault(params, jnp.array([]))) + + import jax + + dummy_coords = np.empty((7, 3)) + dummy_B = jax.eval_shape(fun, dummy_coords, self.params) + errorif( + dummy_B.shape != (7, 3), + ValueError, + "fun should return an array of the same shape as coords", + ) + self._fun = fun + + @optimizable_parameter + @property + def params(self): + """ndarray: Parameters of the field allowed to vary during optimization.""" + return self._params + + @params.setter + def params(self, params): + self._params = params + + def compute_magnetic_field( + self, coords, params=None, basis="rpz", source_grid=None + ): + """Compute magnetic field at a set of points. + + Parameters + ---------- + coords : array-like shape(n,3) + Nodes to evaluate field at in [R,phi,Z] or [X,Y,Z] coordinates. + params : array-like, optional + Optimizable parameters, defaults to field.params. + basis : {"rpz", "xyz"} + Basis for input coordinates and returned magnetic field. + source_grid : Grid, int or None or array-like, optional + Unused by this class, only kept for API compatibility + + Returns + ------- + field : ndarray, shape(N,3) + magnetic field at specified points + + """ + coords = jnp.atleast_2d(jnp.asarray(coords)) + if params is None: + params = self.params + if basis == "xyz": + coords = xyz2rpz(coords) + + B = self._fun(coords, params) + if basis == "xyz": + B = rpz2xyz_vec(B, phi=coords[:, 1]) + return B + + class ScaledMagneticField(_MagneticField, Optimizable): """Magnetic field scaled by a scalar value. diff --git a/docs/api_fields.rst b/docs/api_fields.rst index 63c0648f8a..32b5cc572b 100644 --- a/docs/api_fields.rst +++ b/docs/api_fields.rst @@ -23,6 +23,7 @@ a surface and computes the normal field strength on that surface. desc.magnetic_fields.SplineMagneticField desc.magnetic_fields.DommaschkPotentialField + desc.magnetic_fields.MagneticFieldFromUser desc.magnetic_fields.ScalarPotentialField desc.magnetic_fields.ToroidalMagneticField desc.magnetic_fields.VerticalMagneticField diff --git a/tests/test_magnetic_fields.py b/tests/test_magnetic_fields.py index ae5b51c35f..9e4103f0dc 100644 --- a/tests/test_magnetic_fields.py +++ b/tests/test_magnetic_fields.py @@ -15,6 +15,7 @@ CurrentPotentialField, DommaschkPotentialField, FourierCurrentPotentialField, + MagneticFieldFromUser, OmnigenousField, PoloidalMagneticField, ScalarPotentialField, @@ -64,6 +65,28 @@ def test_basic_fields(self): (tfield + vfield - pfield)([1, 0, 0.1]), [[0.4, 2, 1]] ) + @pytest.mark.unit + def test_field_from_user(self): + """Test for MagneticFieldFromUser.""" + tfield = ToroidalMagneticField(2, 1) + + def fun(coords, params): + R0, B0 = params + coords = jnp.atleast_2d(jnp.asarray(coords)) + bp = B0 * R0 / coords[:, 0] + brz = jnp.zeros_like(bp) + B = jnp.array([brz, bp, brz]).T + return B + + ufield = MagneticFieldFromUser(fun, [tfield.R0, tfield.B0]) + np.testing.assert_allclose( + tfield([1, 0, 0]), ufield([1, 0, 0], params=[tfield.R0, tfield.B0]) + ) + np.testing.assert_allclose( + tfield([1, 1, 0], basis="xyz"), + ufield([1, 1, 0], params=[tfield.R0, tfield.B0], basis="xyz"), + ) + @pytest.mark.unit def test_combined_fields(self): """Tests for sum/scaled fields."""