Skip to content

Commit

Permalink
Remove redundant parameter to Bounce2D. Last edit unless reviewers re…
Browse files Browse the repository at this point in the history
…quest changes.
  • Loading branch information
unalmis committed Oct 21, 2024
1 parent 332c053 commit 7c8902a
Show file tree
Hide file tree
Showing 4 changed files with 33 additions and 66 deletions.
22 changes: 11 additions & 11 deletions desc/integrals/bounce_integral.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ class Bounce2D(Bounce):
Fᵢ : λ, ζ₁, ζ₂ ↦ ∫ᵢ f(λ, ζ, {Gⱼ}) dζ
If the map G is multivalued at a physical location, then it is still
permissible if separable into a single valued and multivalued parts.
permissible if separable into single valued and multivalued parts.
In that case, supply the single valued parts, which will be interpolated
with FFTs, and use the provided coordinates θ,ζ ∈ ℝ to compose G.
Expand Down Expand Up @@ -253,7 +253,6 @@ def __init__(
self,
grid,
data,
iota,
theta,
Y_B=None,
num_transit=32,
Expand All @@ -280,14 +279,12 @@ def __init__(
----------
grid : Grid
Tensor-product grid in (ρ, θ, ζ) with uniformly spaced nodes
[0, 2π) × [0, 2π/NFP). The ζ nodes should be strictly increasing.
(θ, ζ) ∈ [0, 2π) × [0, 2π/NFP). The ζ coordinates (the unique values prior
to taking the tensor-product) must be strictly increasing.
Below shape notation defines ``M=grid.num_theta`` and ``N=grid.num_zeta``.
data : dict[str, jnp.ndarray]
Data evaluated on ``grid``.
Must include names in ``Bounce2D.required_names``.
iota : jnp.ndarray
Shape (num rho, ).
Rotational transform.
theta : jnp.ndarray
Shape (num rho, X, Y).
DESC coordinates θ sourced from the Clebsch coordinates
Expand Down Expand Up @@ -332,7 +329,7 @@ def __init__(
Chebyshev series algorithm is not yet implemented.
When using splines, it is recommended to reduce the ``num_well``
parameter in the ``points`` method from ``3*Y_B*num_transit`` to
``Y_B*num_transit``.
at most ``Y_B*num_transit``.
"""
errorif(grid.sym, NotImplementedError, msg="Need grid that works with FFTs.")
Expand All @@ -353,7 +350,9 @@ def __init__(
"B^zeta": _fourier(
grid, jnp.abs(data["B^zeta"]) * Lref / Bref, is_reshaped
),
"T(z)": fourier_chebyshev(theta, iota, alpha, num_transit),
"T(z)": fourier_chebyshev(
theta, grid.compress(data["iota"]), alpha, num_transit
),
}
Y_B = setdefault(Y_B, theta.shape[-1] * 2)
if spline:
Expand Down Expand Up @@ -600,17 +599,18 @@ def integrate(
as the indices that correspond to that field line.
f : list[jnp.ndarray] or jnp.ndarray
Shape (num rho, M, N).
Real scalar-valued (2π × 2π/NFP) periodic in (θ, ζ) functions evaluated
on the ``grid`` supplied to construct this object. These functions
Real scalar-valued periodic functions in (θ, ζ) ∈ [0, 2π) × [0, 2π/NFP)
evaluated on the ``grid`` supplied to construct this object. These functions
should be arguments to the callable ``integrand``. Use the method
``Bounce2D.reshape_data`` to reshape the data into the expected shape.
weight : jnp.ndarray
Shape (num rho, M, N).
Real scalar-valued periodic functions in (θ, ζ) ∈ [0, 2π) × [0, 2π/NFP)
evaluated on the ``grid`` supplied to construct this object.
If supplied, the bounce integral labeled by well j is weighted such that
the returned value is w(j) ∫ f(λ, ℓ) dℓ, where w(j) is ``weight``
interpolated to the deepest point in that magnetic well. Use the method
``Bounce2D.reshape_data`` to reshape the data into the expected shape.
It is assumed ``weight`` is a periodic function.
points : tuple[jnp.ndarray]
Shape (num rho, num pitch, num well).
Optional, output of method ``self.points``.
Expand Down
9 changes: 6 additions & 3 deletions desc/integrals/bounce_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -878,7 +878,8 @@ def interp_fft_to_argmin(
h : jnp.ndarray
Shape (..., grid.num_theta, grid.num_zeta)
Periodic function evaluated on tensor-product grid in (ρ, θ, ζ) with
uniformly spaced nodes [0, 2π) × [0, 2π/NFP).
uniformly spaced nodes (θ, ζ) ∈ [0, 2π) × [0, 2π/NFP).
Preferably power of 2 for ``grid.num_theta`` and ``grid.num_zeta``.
points : jnp.ndarray
Shape (..., num well).
Boundaries to detect argmin between.
Expand Down Expand Up @@ -962,7 +963,9 @@ def get_fieldline(alpha_0, iota, num_transit, period):
"""
# Δϕ (∂α/∂ϕ) = Δϕ ι̅ = Δϕ ι/2π = Δϕ data["iota"]
return alpha_0 + period * jnp.expand_dims(iota, -1) * jnp.arange(num_transit)
return alpha_0 + period * (
iota if iota.size == 1 else iota[:, jnp.newaxis]
) * jnp.arange(num_transit)


def fourier_chebyshev(theta, iota, alpha, num_transit):
Expand All @@ -976,7 +979,7 @@ def fourier_chebyshev(theta, iota, alpha, num_transit):
``FourierChebyshevSeries.nodes(M,N,domain=(0,2*jnp.pi))``.
iota : jnp.ndarray
Shape (num rho, ).
Rotational transform.
Rotational transform normalized by 2π.
alpha : float
Starting field line poloidal label.
num_transit : int
Expand Down
3 changes: 2 additions & 1 deletion desc/integrals/interp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,7 +437,8 @@ def _fourier(grid, f, is_reshaped=False):
Parameters
----------
grid : Grid
Tensor-product grid in (θ, ζ) with uniformly spaced nodes [0, 2π) × [0, 2π/NFP).
Tensor-product grid in (ρ, θ, ζ) with uniformly spaced nodes
(θ, ζ) ∈ [0, 2π) × [0, 2π/NFP).
Preferably power of 2 for ``grid.num_theta`` and ``grid.num_zeta``.
f : jnp.ndarray
Function evaluated on ``grid``.
Expand Down
65 changes: 14 additions & 51 deletions tests/test_integrals.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
from tests.test_interp_utils import _f_1d, _f_1d_nyquist_freq
from tests.test_plotting import tol_1d

from desc.backend import jit, jnp
from desc.backend import jnp
from desc.basis import FourierZernikeBasis
from desc.equilibrium import Equilibrium
from desc.equilibrium.coords import get_rtz_grid
Expand All @@ -42,7 +42,6 @@
from desc.integrals.bounce_utils import (
_get_extrema,
bounce_points,
get_fieldline,
get_pitch_inv_quad,
interp_fft_to_argmin,
interp_to_argmin,
Expand Down Expand Up @@ -1029,10 +1028,10 @@ def test_bounce_quadrature(self, is_strong, quad, automorphism):
integrand = lambda B, pitch: jnp.sqrt(1 - m * pitch * B)
truth = v * 2 * ellipe(m)
bounce = Bounce1D(
grid=Grid.create_meshgrid([1, 0, knots], coordinates="raz"),
data=data,
quad=quad,
automorphism=automorphism,
Grid.create_meshgrid([1, 0, knots], coordinates="raz"),
data,
quad,
automorphism,
check=True,
)
points = bounce.points(pitch_inv, num_well=1)
Expand Down Expand Up @@ -1613,22 +1612,6 @@ def fun2(pitch):
class TestBounce2D:
"""Test bounce integration that uses 2D pseudo-spectral methods."""

@pytest.mark.unit
@pytest.mark.parametrize(
"alpha_0, iota, num_period, period",
[
(0, np.sqrt(2), 1, 2 * np.pi),
(0, np.arange(1, 3) * np.sqrt(2), 5, 2 * np.pi),
],
)
def test_get_fieldline(self, alpha_0, iota, num_period, period):
"""Test field line label updating works with jit."""
fieldline = jit(get_fieldline, static_argnums=2)(
alpha_0, iota, num_period, period
)
shape = (iota.size, num_period) if np.ndim(iota) else (num_period,)
assert fieldline.shape == shape

@pytest.mark.unit
def test_interp_fft_to_argmin(self):
"""Test interpolation of h to argmin of g.""" # noqa: D202
Expand All @@ -1643,15 +1626,13 @@ def g(z):
[1, 0, fourier_pts(nyquist)], period=(np.inf, 2 * np.pi, 2 * np.pi)
)
bounce = Bounce2D(
grid=grid,
data=dict.fromkeys(Bounce2D.required_names, g(grid.nodes[:, 2])),
grid,
dict.fromkeys(Bounce2D.required_names, g(grid.nodes[:, 2])),
# dummy value; h depends on ζ alone,so doesn't matter what θ(α, ζ) is
theta=grid.meshgrid_reshape(grid.nodes[:, 1], "rtz"),
Y_B=2 * nyquist,
num_transit=1,
spline=True,
# dummy values
iota=1,
# h depends on ζ alone,so doesn't matter what θ(α, ζ) is
theta=grid.meshgrid_reshape(grid.nodes[:, 1], "rtz"),
)
np.testing.assert_allclose(
interp_fft_to_argmin(
Expand Down Expand Up @@ -1705,14 +1686,7 @@ def test_bounce2d_checks(self):
theta = Bounce2D.compute_theta(eq, X=8, Y=64, rho=rho)
# 5. Make the bounce integration operator.
bounce = Bounce2D(
grid,
data,
iota=grid.compress(data["iota"]),
theta=theta,
num_transit=2,
quad=leggauss(3),
check=True,
spline=False,
grid, data, theta, num_transit=2, quad=leggauss(3), check=True, spline=False
)
pitch_inv, _ = bounce.get_pitch_inv_quad(
min_B=grid.compress(data["min_tz |B|"]),
Expand Down Expand Up @@ -1779,15 +1753,7 @@ def test_bounce2d_checks(self):
_, _ = bounce.plot_theta(l, show=False)

# make sure tests pass when spline=True
b = Bounce2D(
grid,
data,
iota=grid.compress(data["iota"]),
theta=theta,
num_transit=2,
check=True,
spline=True,
)
b = Bounce2D(grid, data, theta, num_transit=2, check=True, spline=True)
b.check_points(b.points(pitch_inv), pitch_inv, plot=False)
_, _ = b.plot(l, pitch_inv[l], show=False)

Expand Down Expand Up @@ -1830,12 +1796,9 @@ def test_binormal_drift_bounce2d(self):
grid_data[name] = grid_data[name] * data["normalization"]

bounce = Bounce2D(
grid=grid,
data=grid_data,
iota=data["iota"],
theta=Bounce2D.compute_theta(
eq, X=8, Y=8, rho=data["rho"], iota=data["iota"]
),
grid,
grid_data,
Bounce2D.compute_theta(eq, X=8, Y=8, rho=data["rho"], iota=data["iota"]),
num_transit=3,
alpha=data["alpha"] - 2.5 * np.pi * data["iota"],
Bref=data["Bref"],
Expand Down

0 comments on commit 7c8902a

Please sign in to comment.