From 979b995ea0ab6b812e19deb9c6e1277997e9fea8 Mon Sep 17 00:00:00 2001 From: nstarman Date: Wed, 6 Dec 2023 23:55:44 -0500 Subject: [PATCH] SCF Signed-off-by: nstarman --- pyproject.toml | 2 + src/galax/dynamics/_core.py | 4 +- src/galax/dynamics/mockstream/_core.py | 2 +- src/galax/dynamics/mockstream/_df/fardal.py | 12 +- src/galax/potential/_potential/__init__.py | 3 +- src/galax/potential/_potential/param/core.py | 14 +- .../potential/_potential/scf/__init__.py | 14 +- src/galax/potential/_potential/scf/bfe.py | 259 ++++++++++++++++++ .../potential/_potential/scf/bfe_helper.py | 82 ++++++ src/galax/potential/_potential/scf/coeffs.py | 108 ++++++++ .../potential/_potential/scf/coeffs_helper.py | 68 +++++ .../potential/_potential/scf/gegenbauer.py | 53 ++-- src/galax/potential/_potential/scf/utils.py | 103 ++++++- src/galax/typing.py | 2 + tests/potential/scf/test_coeff_helper.py | 75 +++++ tests/potential/scf/test_coeffs.py | 41 +++ tests/potential/scf/test_gegenbauer.py | 80 +++++- tests/potential/scf/test_utils.py | 145 ++++++++++ 18 files changed, 1009 insertions(+), 58 deletions(-) create mode 100644 src/galax/potential/_potential/scf/bfe.py create mode 100644 src/galax/potential/_potential/scf/bfe_helper.py create mode 100644 src/galax/potential/_potential/scf/coeffs.py create mode 100644 src/galax/potential/_potential/scf/coeffs_helper.py create mode 100644 tests/potential/scf/test_coeff_helper.py create mode 100644 tests/potential/scf/test_coeffs.py create mode 100644 tests/potential/scf/test_utils.py diff --git a/pyproject.toml b/pyproject.toml index 68a43b3c..9f5e6ead 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -44,6 +44,7 @@ scf = [ "sympy2jax", ] test = [ + "gala", "hypothesis[numpy]", "pytest >=6", "pytest-cov >=3", @@ -159,6 +160,7 @@ ignore = [ "F821", # undefined name <- jaxtyping "FIX002", # Line contains TODO, consider resolving the issue "N80", # Naming conventions. + "N816", # Variable in global scope should not be mixedCase "PD", # pandas-vet "PLR", # Design related pylint codes "PYI041", # Use `float` instead of `int | float` <- beartype is more strict diff --git a/src/galax/dynamics/_core.py b/src/galax/dynamics/_core.py index d18237b9..b99ccbc3 100644 --- a/src/galax/dynamics/_core.py +++ b/src/galax/dynamics/_core.py @@ -51,7 +51,7 @@ def shape(self) -> tuple[int, ...]: @property @partial_jit() def qp(self) -> BatchVec6: - """Return as a single Array[(*batch, Q + P),].""" + """Return as a single Array[float, (*batch, Q + P),].""" batch_shape, component_shapes = self._shape_tuple q = xp.broadcast_to(self.q, batch_shape + component_shapes[0:1]) p = xp.broadcast_to(self.p, batch_shape + component_shapes[1:2]) @@ -99,7 +99,7 @@ def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]: @property @partial_jit() def w(self) -> BatchVec7: - """Return as a single Array[(*batch, Q + P + T),].""" + """Return as a single Array[float, (*batch, Q + P + T)].""" batch_shape, component_shapes = self._shape_tuple q = xp.broadcast_to(self.q, batch_shape + component_shapes[0:1]) p = xp.broadcast_to(self.p, batch_shape + component_shapes[1:2]) diff --git a/src/galax/dynamics/mockstream/_core.py b/src/galax/dynamics/mockstream/_core.py index 0238f413..a06f3a41 100644 --- a/src/galax/dynamics/mockstream/_core.py +++ b/src/galax/dynamics/mockstream/_core.py @@ -38,7 +38,7 @@ def _shape_tuple(self) -> tuple[tuple[int, ...], tuple[int, int, int]]: @property @partial_jit() def w(self) -> BatchVec7: - """Return as a single Array[(*batch, Q + P + T),].""" + """Return as a single Array[float, (*batch, Q + P + T)].""" batch_shape, component_shapes = self._shape_tuple q = xp.broadcast_to(self.q, batch_shape + component_shapes[0:1]) p = xp.broadcast_to(self.p, batch_shape + component_shapes[1:2]) diff --git a/src/galax/dynamics/mockstream/_df/fardal.py b/src/galax/dynamics/mockstream/_df/fardal.py index 089a3a69..0d60f21d 100644 --- a/src/galax/dynamics/mockstream/_df/fardal.py +++ b/src/galax/dynamics/mockstream/_df/fardal.py @@ -117,7 +117,7 @@ def dphidr(potential: AbstractPotentialBase, x: Vec3, t: FloatScalar) -> Vec3: ---------- potential: AbstractPotentialBase The gravitational potential. - x: Array[(3,), Any] + x: Array[Any, (3,)] 3d position (x, y, z) in [kpc] t: Numeric Time in [Myr] @@ -143,7 +143,7 @@ def d2phidr2( ---------- potential: AbstractPotentialBase The gravitational potential. - x: Array[(3,), Any] + x: Array[Any, (3,)] 3d position (x, y, z) in [kpc] t: Numeric Time in [Myr] @@ -172,9 +172,9 @@ def orbital_angular_velocity(x: Vec3, v: Vec3, /) -> Vec3: Arguments: --------- - x: Array[(3,), Any] + x: Array[Any, (3,)] 3d position (x, y, z) in [length] - v: Array[(3,), Any] + v: Array[Any, (3,)] 3d velocity (v_x, v_y, v_z) in [length/time] Returns: @@ -199,9 +199,9 @@ def orbital_angular_velocity_mag(x: Vec3, v: Vec3, /) -> FloatScalar: Arguments: --------- - x: Array[(3,), Any] + x: Array[Any, (3,)] 3d position (x, y, z) in [kpc] - v: Array[(3,), Any] + v: Array[Any, (3,)] 3d velocity (v_x, v_y, v_z) in [kpc/Myr] Returns: diff --git a/src/galax/potential/_potential/__init__.py b/src/galax/potential/_potential/__init__.py index dcaaec26..43f0a99a 100644 --- a/src/galax/potential/_potential/__init__.py +++ b/src/galax/potential/_potential/__init__.py @@ -7,6 +7,7 @@ from .composite import * from .core import * from .param import * +from .scf import SCFPotential __all__: list[str] = [] __all__ += base.__all__ @@ -14,4 +15,4 @@ __all__ += composite.__all__ __all__ += param.__all__ __all__ += builtin.__all__ -__all__ += ["scf"] +__all__ += ["scf", "SCFPotential"] diff --git a/src/galax/potential/_potential/param/core.py b/src/galax/potential/_potential/param/core.py index a53ee791..32d4f927 100644 --- a/src/galax/potential/_potential/param/core.py +++ b/src/galax/potential/_potential/param/core.py @@ -43,14 +43,14 @@ def __call__(self, t: FloatScalar, **kwargs: Any) -> FloatArrayAnyShape: Parameters ---------- - t : Array + t : float | Array[float, ()] The time(s) at which to compute the parameter value. - **kwargs + **kwargs : Any Additional parameters to pass to the parameter function. Returns ------- - Array + Array[float, "*shape"] The parameter value at times ``t``. """ ... @@ -74,7 +74,7 @@ def __call__( Parameters ---------- - t : Array, optional + t : float | Array[float, ()], optional This is ignored and is thus optional. Note that for most :class:`~galax.potential.AbstractParameter` the time is required. @@ -83,7 +83,7 @@ def __call__( Returns ------- - Array + Array[float, "*shape"] The constant parameter value. """ # Vectorization to enable broadcasting over the time dimension. @@ -107,14 +107,14 @@ def __call__(self, t: FloatScalar, **kwargs: Any) -> FloatArrayAnyShape: Parameters ---------- - t : Array + t : float | Array[float, ()] Time(s) at which to compute the parameter value. **kwargs : Any Additional parameters to pass to the parameter function. Returns ------- - Array + Array[float, "*shape"] Parameter value(s) at the given time(s). """ ... diff --git a/src/galax/potential/_potential/scf/__init__.py b/src/galax/potential/_potential/scf/__init__.py index 7495419c..a62b0ea6 100644 --- a/src/galax/potential/_potential/scf/__init__.py +++ b/src/galax/potential/_potential/scf/__init__.py @@ -1,7 +1,11 @@ -from __future__ import annotations - -from . import gegenbauer -from .gegenbauer import * +from . import bfe, bfe_helper, coeffs, coeffs_helper +from .bfe import * +from .bfe_helper import * +from .coeffs import * +from .coeffs_helper import * __all__: list[str] = [] -__all__ += gegenbauer.__all__ +__all__ += bfe.__all__ +__all__ += bfe_helper.__all__ +__all__ += coeffs.__all__ +__all__ += coeffs_helper.__all__ diff --git a/src/galax/potential/_potential/scf/bfe.py b/src/galax/potential/_potential/scf/bfe.py new file mode 100644 index 00000000..17d7044a --- /dev/null +++ b/src/galax/potential/_potential/scf/bfe.py @@ -0,0 +1,259 @@ +"""Self-Consistent Field Potential.""" + +__all__ = ["SCFPotential", "STnlmSnapshotParameter"] + +from collections.abc import Callable +from dataclasses import KW_ONLY +from typing import Any + +import astropy.units as u +import equinox as eqx +import jax.numpy as xp +from jaxtyping import Array, Float +from typing_extensions import override + +from galax.potential._potential.core import AbstractPotential +from galax.potential._potential.param import AbstractParameter, ParameterField +from galax.typing import ArrayAnyShape, FloatLike, FloatScalar, Vec3 +from galax.utils import partial_jit, vectorize_method + +from .bfe_helper import phi_nl as calculate_phi_nl +from .bfe_helper import rho_nl as calculate_rho_nl +from .coeffs import compute_coeffs_discrete +from .gegenbauer import GegenbauerCalculator +from .utils import cartesian_to_spherical, expand_dim1, real_Ylm + +############################################################################## + + +class SCFPotential(AbstractPotential): + r"""Self-Consistent Field (SCF) potential. + + A gravitational potential represented as a basis function expansion. This + uses the self-consistent field (SCF) method of Hernquist & Ostriker (1992) + and Lowing et al. (2011), and represents all coefficients as real + quantities. + + Parameters + ---------- + m : numeric + Scale mass. + r_s : numeric + Scale length. + Snlm : Array[float, (nmax+1, lmax+1, lmax+1)] | Callable + Array of coefficients for the cos() terms of the expansion. This should + be a 3D array with shape `(nmax+1, lmax+1, lmax+1)`, where `nmax` is the + number of radial expansion terms and `lmax` is the number of spherical + harmonic `l` terms. If a callable is provided, it should accept a + single argument `t` and return the array of coefficients for that time. + Tnlm : Array[float, (nmax+1, lmax+1, lmax+1)] | Callable + Array of coefficients for the sin() terms of the expansion. This should + be a 3D array with shape `(nmax+1, lmax+1, lmax+1)`, where `nmax` is the + number of radial expansion terms and `lmax` is the number of spherical + harmonic `l` terms. If a callable is provided, it should accept a + single argument `t` and return the array of coefficients for that time. + units : iterable + Unique list of non-reducable units that specify (at minimum) the length, + mass, time, and angle units. + """ + + m: AbstractParameter = ParameterField(dimensions="mass") # type: ignore[assignment] + r_s: AbstractParameter = ParameterField(dimensions="length") # type: ignore[assignment] + Snlm: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment] + Tnlm: AbstractParameter = ParameterField(dimensions="dimensionless") # type: ignore[assignment] + + nmax: int = eqx.field(init=False, static=True, repr=False) + lmax: int = eqx.field(init=False, static=True, repr=False) + _ultra_sph: GegenbauerCalculator = eqx.field(init=False, static=True, repr=False) + + def __post_init__(self) -> None: + super().__post_init__() + + # shape parameters + shape = self.Snlm(0).shape + object.__setattr__(self, "nmax", shape[0] - 1) + object.__setattr__(self, "lmax", shape[1] - 1) + + # gegenbauer calculator + object.__setattr__(self, "_ultra_sph", GegenbauerCalculator(self.nmax)) + + # ========================================================================== + + @partial_jit() + @eqx.filter_vmap(in_axes=(None, 1, None)) # type: ignore[misc] # on `q` axis 1 + def _potential_energy_helper( + self, q: Float[Array, "3 N"], /, t: Float[Array, "1"] + ) -> Float[Array, "N"]: # type: ignore[name-defined] + r, theta, phi = cartesian_to_spherical(q) + r_s = self.r_s(t) + s = xp.atleast_1d(r / r_s)[:, None, None, None] + theta = xp.atleast_1d(theta)[:, None, None, None] + phi = xp.atleast_1d(phi)[:, None, None, None] + + ns = xp.arange(self.nmax + 1)[None, :, None, None] # ([N], n, [l], [m]) + ls = xp.arange(self.lmax + 1)[None, None, :, None] # ([N], [n], l, [m]) + phi_nl = calculate_phi_nl(s, ns, ls, gegenbauer=self._ultra_sph) + + li, mi = xp.tril_indices(self.lmax + 1) # (l*(l+1)//2,) + shape = (1, 1, self.lmax + 1, self.lmax + 1) + midx = xp.zeros(shape, dtype=int).at[:, :, li, mi].set(mi) + Ylm = xp.zeros((len(theta), 1, self.lmax + 1, self.lmax + 1)) + Ylm = Ylm.at[:, :, li, mi].set(real_Ylm(li[None], mi[None], theta[:, :, 0, 0])) + + Snlm = self.Snlm(t, r_s=r_s)[None] + Tnlm = self.Tnlm(t, r_s=r_s)[None] + + out = (self._G * self.m(t) / r_s) * xp.sum( + Ylm * phi_nl * (Snlm * xp.cos(midx * phi) + Tnlm * xp.sin(midx * phi)), + axis=(1, 2, 3), + ) + return out[0] if len(q.shape) == 1 else out + + @partial_jit() + @vectorize_method(signature="(3),()->()") + def _potential_energy(self, q: Vec3, /, t: FloatScalar) -> FloatScalar: + """Compute the potential energy at the given position(s).""" + out = self._potential_energy_helper(expand_dim1(q), t) + return out[0] if len(q.shape) == 1 else out + + # @partial_jit() + # @eqx.filter_vmap(in_axes=(None, 1, None)) # type: ignore[misc] # on `q` axis 1 + # def _gradient(self, q: Float[Array, "3"], /, t: jt.Array) -> jt.Array: + # """Compute the gradient.""" + # r, theta, phi = cartesian_to_spherical(q) + # r_s = self.r_s(t) + # s = xp.atleast_1d(r / r_s)[:, None, None, None] + # theta = xp.atleast_1d(theta)[:, None, None, None] + # phi = xp.atleast_1d(phi)[:, None, None, None] + + # ns = xp.arange(self.nmax + 1)[None, :, None, None] # ([N], n, [l], [m]) + # ls = xp.arange(self.lmax + 1)[None, None, :, None] # ([N], [n], l, [m]) + # phi_nl = calculate_phi_nl(s, ns, ls, gegenbauer=self._ultra_sph) + # dphi_nl_dr = phi_nl_grad(s, ns, ls, self._ultra_sph) + + # li, mi = xp.tril_indices(self.lmax + 1) # (l*(l+1)//2,) + # shape = (1, 1, self.lmax + 1, self.lmax + 1) + # lidx = xp.zeros(shape, dtype=int).at[:, :, li, mi].set(li) + # midx = xp.zeros(shape, dtype=int).at[:, :, li, mi].set(mi) + # mvalid = xp.zeros(shape).at[:, :, li, mi].set(1) # m <= l + # Ylm = real_Ylm(lidx, midx, theta) + # dYlm_dtheta = calculate_dYlm_dtheta(lidx, midx, theta) + + # Snlm = self.Snlm(t, r_s=r_s)[None] + # Tnlm = self.Tnlm(t, r_s=r_s)[None] + + # grad_r = xp.sum( + # (mvalid * Ylm) + # * dphi_nl_dr + # * (Snlm * xp.cos(midx * phi) + Tnlm * xp.sin(midx * phi)), + # axis=(1, 2, 3), + # ) + # grad_theta = (1 / s[:, 0, 0, 0]) * xp.sum( + # (mvalid * dYlm_dtheta) + # * phi_nl + # * (Snlm * xp.cos(midx * phi) + Tnlm * xp.sin(midx * phi)), + # axis=(1, 2, 3), + # ) + # grad_phi = (1 / s[:, 0, 0, 0]) * xp.sum( + # (mvalid * Ylm / xp.sin(theta)) + # * phi_nl + # * (Tnlm * xp.cos(midx * phi) - Snlm * xp.sin(midx * phi)), + # axis=(1, 2, 3), + # ) + # return (self._G * self.m(t) / r_s) * xp.stack([grad_r, grad_theta, grad_phi], + # axis=-1) + + # @partial_jit() + # def gradient(self, q: jt.Array, /, t: jt.Array) -> jt.Array: + # """Compute the potential energy at the given position(s).""" + # out = self._gradient(expand_dim1(q), t) + # return out[0, 0] if len(q.shape) == 1 else out[:, 0] # TODO: fix this + + @partial_jit() + @eqx.filter_vmap(in_axes=(None, 1, None)) # type: ignore[misc] # on `q` axis 1 + def density( + self, q: Float[Array, "3 N"], /, t: Float[Array, "1"] + ) -> Float[Array, "N"]: # type: ignore[name-defined] + """Compute the density at the given position(s).""" + r, theta, phi = cartesian_to_spherical(q) + r_s = self.r_s(t) + s = xp.atleast_1d(r / r_s)[:, None, None, None] + theta = xp.atleast_1d(theta)[:, None, None, None] + phi = xp.atleast_1d(phi)[:, None, None, None] + + ns = xp.arange(self.nmax + 1)[:, None, None] # (n, [l], [m]) + ls = xp.arange(self.lmax + 1)[None, :, None] # ([n], l, [m]) + + phi_nl = calculate_rho_nl(s, ns[None], ls[None], gegenbauer=self._ultra_sph) + + li, mi = xp.tril_indices(self.lmax + 1) # (l*(l+1)//2,) + shape = (1, 1, self.lmax + 1, self.lmax + 1) + midx = xp.zeros(shape, dtype=int).at[:, :, li, mi].set(mi) + Ylm = xp.zeros((len(theta), 1, self.lmax + 1, self.lmax + 1)) + Ylm = Ylm.at[:, :, li, mi].set(real_Ylm(li[None], mi[None], theta[:, :, 0, 0])) + + Snlm = self.Snlm(t, r_s=r_s)[None] + Tnlm = self.Tnlm(t, r_s=r_s)[None] + + out = (self._G * self.m(t) / r_s) * xp.sum( + Ylm * phi_nl * (Snlm * xp.cos(midx * phi) + Tnlm * xp.sin(midx * phi)), + axis=(1, 2, 3), + ) + return out[0] if len(q.shape) == 1 else out + + +# ============================================================================= + + +class STnlmSnapshotParameter(AbstractParameter): + """Parameter for the STnlm coefficients.""" + + snapshot: Callable[ # type: ignore[name-defined] + [Float[Array, "N"]], + tuple[Float[Array, "3 N"], Float[Array, "N"]], + ] + """Cartesian coordinates of the snapshot. + + This should be a callable that accepts a single argument `t` and returns + the cartesian coordinates and the masses of the snapshot at that time. + """ + + nmax: int = eqx.field(static=True, converter=int) + """Radial expansion term.""" + + lmax: int = eqx.field(static=True, converter=int) + """Spherical harmonic term.""" + + _: KW_ONLY + unit: u.Unit = eqx.field(default=u.one, static=True, converter=u.Unit) + + def __post_init__(self) -> None: + super().__post_init__() + if self.unit != u.one: + msg = "unit must be dimensionless" + raise ValueError(msg) + + @override + def __call__( # type: ignore[override] + self, t: FloatLike, *, r_s: float, **kwargs: Any + ) -> tuple[ArrayAnyShape, ArrayAnyShape]: + """Return the coefficients at the given time(s). + + Parameters + ---------- + t : float | Array[float, ()] + Time at which to evaluate the coefficients. + r_s : float | Array[float, ()] + Scale length of the potential at the given time(s. + **kwargs : Any + Additional keyword arguments are ignored. + + Returns + ------- + Snlm : Array[float, (nmax+1, lmax+1, lmax+1)] + The value of the cosine expansion coefficient. + Tnlm : Array[float, (nmax+1, lmax+1, lmax+1)] + The value of the sine expansion coefficient. + """ + xyz, m = self.snapshot(t) + return compute_coeffs_discrete(xyz, m, nmax=self.nmax, lmax=self.lmax, r_s=r_s) diff --git a/src/galax/potential/_potential/scf/bfe_helper.py b/src/galax/potential/_potential/scf/bfe_helper.py new file mode 100644 index 00000000..824697a9 --- /dev/null +++ b/src/galax/potential/_potential/scf/bfe_helper.py @@ -0,0 +1,82 @@ +"""Self-Consistent Field Potential.""" + +__all__: list[str] = [] + +import jax +import jax.numpy as xp +from jaxtyping import Array, Float + +from galax.potential._potential.scf.gegenbauer import GegenbauerCalculator +from galax.typing import IntLike +from galax.utils import partial_jit + +from .coeffs_helper import normalization_Knl +from .utils import psi_of_r + + +@partial_jit(static_argnames=("gegenbauer",)) +def rho_nl( + s: Float[Array, "N"], n: int, l: int, *, gegenbauer: GegenbauerCalculator +) -> Float[Array, "N"]: + r"""Radial density expansion terms. + + Parameters + ---------- + s : Array[float, (n,)] + Scaled radius :math:`r/r_s`. + n : int + Radial expansion term. + l : int + Spherical harmonic term. + gegenbauer : GegenbauerCalculator, keyword-only + Gegenbauer calculator. This is used to compute the Gegenbauer + polynomials efficiently. + + Returns + ------- + Array[float, (n,)] + """ + return ( + xp.sqrt(xp.pi * 4) + * (normalization_Knl(n=n, l=l) / (2 * xp.pi)) + * (s**l / (s * (1 + s) ** (2 * l + 3))) + * gegenbauer(n, 2 * l + 1.5, psi_of_r(s)) + ) + + +# ====================================================================== + + +@partial_jit(static_argnames=("gegenbauer",)) +def phi_nl( + s: Float[Array, "samples"], n: IntLike, l: IntLike, gegenbauer: GegenbauerCalculator +) -> Float[Array, "samples"]: + r"""Angular density expansion terms. + + Parameters + ---------- + s : Array[float, (n_samples,)] + Scaled radius :math:`r/r_s`. + n : int + Radial expansion term. + l : int + Spherical harmonic term. + gegenbauer : GegenbauerCalculator, keyword-only + Gegenbauer calculator. This is used to compute the Gegenbauer + polynomials efficiently. + + Returns + ------- + Array[float, (n_samples,)] + """ + return ( + -xp.sqrt(4 * xp.pi) + * (s**l / (1.0 + s) ** (2 * l + 1)) + * gegenbauer(n, 2 * l + 1.5, psi_of_r(s)) + ) + + +phi_nl_vec = xp.vectorize(phi_nl, signature="(n),(),()->(n)", excluded=(3,)) + +phi_nl_grad = jax.jit(xp.vectorize(jax.grad(phi_nl, argnums=0), excluded=(3,))) +r"""Derivative :math:`\frac{d}{dr} \phi_{nl}`.""" diff --git a/src/galax/potential/_potential/scf/coeffs.py b/src/galax/potential/_potential/scf/coeffs.py new file mode 100644 index 00000000..b60c5465 --- /dev/null +++ b/src/galax/potential/_potential/scf/coeffs.py @@ -0,0 +1,108 @@ +"""Self-Consistent Field Potential.""" + +__all__ = ["compute_coeffs_discrete"] + + +import jax +import jax.numpy as xp +from jaxtyping import Array, Float + +from galax.typing import FloatLike, IntLike +from galax.utils import partial_jit + +from .bfe_helper import phi_nl_vec +from .coeffs_helper import expansion_coeffs_Anl_discrete +from .gegenbauer import GegenbauerCalculator +from .utils import cartesian_to_spherical, real_Ylm + + +@partial_jit(static_argnames=("nmax", "lmax", "gegenbauer")) +def compute_coeffs_discrete( + xyz: Float[Array, "samples 3"], + mass: Float[Array, "samples"], # type: ignore[name-defined] + nmax: IntLike, + lmax: IntLike, + r_s: IntLike | FloatLike, + *, + gegenbauer: GegenbauerCalculator | None = None, +) -> tuple[ + Float[Array, "{nmax}+1 {lmax}+1 {lmax}+1"], + Float[Array, "{nmax}+1 {lmax}+1 {lmax}+1"], +]: + """Compute expansion coefficients for the SCF potential. + + Compute the expansion coefficients for representing the density distribution + of input points as a basis function expansion. The points, ``xyz``, are + assumed to be samples from the density distribution. + + This is Equation 15 of Lowing et al. (2011). + + Parameters + ---------- + xyz : Array[float, (n_samples, 3)] + Samples from the density distribution. + :todo:`unit support` + mass : Array[float, (n_samples,)] + Mass of each sample. + :todo:`unit support` + nmax : int + Maximum value of ``n`` for the radial expansion. + lmax : int + Maximum value of ``l`` for the spherical harmonics. + r_s : numeric + Scale radius. + :todo:`unit support` + + gegenbauer : GegenbauerCalculator, optional + Gegenbauer calculator. This is used to compute the Gegenbauer + polynomials efficiently. If not provided, a new calculator will be + created. + + Returns + ------- + Snlm : Array[float, (nmax+1, lmax+1, lmax+1)] + The value of the cosine expansion coefficient. + Tnlm : Array[float, (nmax+1, lmax+1, lmax+1)] + The value of the sine expansion coefficient. + """ + if gegenbauer is None: + ggncalc = GegenbauerCalculator(nmax=nmax) + elif gegenbauer.nmax != nmax: + msg = "gegenbauer.nmax != nmax" + raise ValueError(msg) + else: + ggncalc = gegenbauer + + rthetaphi = cartesian_to_spherical(xyz) + r = rthetaphi[..., 0] + theta = rthetaphi[..., 1] + phi = rthetaphi[..., 2] + s = r / r_s + + ns = xp.arange(nmax + 1)[:, None] # (n, l) + ls = xp.arange(lmax + 1)[None, :] # (n, l) + + Anl_til = expansion_coeffs_Anl_discrete(ns, ls) # (n, l) + phinl = phi_nl_vec(s, ns, ls, ggncalc) # (n, l, N) + + li, mi = xp.tril_indices(lmax + 1) # (l*(l+1)//2,) + lm = xp.zeros((lmax + 1, lmax + 1), dtype=int).at[li, mi].set(li) # (l, m) + ms = xp.zeros((lmax + 1, lmax + 1), dtype=int).at[li, mi].set(mi) # (l, m) + # TODO: this is VERY SLOW. Can we do better? + Ylm = real_Ylm(theta[None, None, :], lm[..., None], ms[..., None]) # (l, m, N) + + delta = jax.lax.select(ms == 0, xp.ones_like(ms), xp.zeros_like(ms)) # (l, m) + mvalid = xp.zeros((lmax + 1, lmax + 1)).at[li, mi].set(1) # select m <= l + + tmp = ( # (n, l, m, N) using broadcasting + mvalid[None, :, :, None] + * (2 - delta[None, :, :, None]) + * Anl_til[:, :, None, None] + * mass[None, None, None, :] + * phinl[:, :, None, :] + * Ylm[None, :, :, :] + ) + Snlm = xp.sum(tmp * xp.cos(ms[None, :, :, None] * phi[None, None, None]), axis=-1) + Tnlm = xp.sum(tmp * xp.sin(ms[None, :, :, None] * phi[None, None, None]), axis=-1) + + return Snlm, Tnlm diff --git a/src/galax/potential/_potential/scf/coeffs_helper.py b/src/galax/potential/_potential/scf/coeffs_helper.py new file mode 100644 index 00000000..1720dcb1 --- /dev/null +++ b/src/galax/potential/_potential/scf/coeffs_helper.py @@ -0,0 +1,68 @@ +"""Self-Consistent Field Potential.""" + +__all__: list[str] = [] + +from typing import overload + +import jax.numpy as xp +from jax.scipy.special import gamma +from jaxtyping import Array, Float, Integer + +from galax.utils import partial_jit + +from .utils import factorial + + +@overload +def normalization_Knl(n: int, l: int) -> float: + ... + + +@overload +def normalization_Knl(n: Array, l: Array) -> Array: + ... + + +def normalization_Knl( + n: Integer[Array, "*#shape"] | int, l: Integer[Array, "*#shape"] | int +) -> Float[Array, "*shape"] | float: + """SCF normalization factor. + + Parameters + ---------- + n : int + Radial expansion term. + l : int + Spherical harmonic term. + + Returns + ------- + float + """ + return 0.5 * n * (n + 4 * l + 3.0) + (l + 1) * (2 * l + 1) + + +@partial_jit() +def expansion_coeffs_Anl_discrete( + n: Integer[Array, "*#shape"], l: Integer[Array, "*#shape"] +) -> Float[Array, "*shape"]: + """Return normalization factor for the coefficients. + + Equation 16 of Lowing et al. (2011). + + Parameters + ---------- + n : int + Radial expansion term. + l : int + spherical harmonic term. + + Returns + ------- + float + """ + Knl = normalization_Knl(n=n, l=l) + prefac = -(2 ** (8.0 * l + 6)) / (4 * xp.pi * Knl) + numerator = factorial(n) * (n + 2 * l + 1.5) * gamma(2 * l + 1.5) ** 2 + denominator = gamma(n + 4.0 * l + 3.0) + return prefac * (numerator / denominator) diff --git a/src/galax/potential/_potential/scf/gegenbauer.py b/src/galax/potential/_potential/scf/gegenbauer.py index a8e00154..4cea8f6b 100644 --- a/src/galax/potential/_potential/scf/gegenbauer.py +++ b/src/galax/potential/_potential/scf/gegenbauer.py @@ -1,45 +1,54 @@ """Gegenbauer polynomials.""" -from __future__ import annotations + __all__ = ["GegenbauerCalculator"] -from typing import Protocol +from typing import Protocol, runtime_checkable import equinox as eqx import jax -import jax.typing as jt import sympy as sp import sympy2jax from jax.scipy.special import gamma -from galax.utils import partial_jit +from galax.typing import FloatLike, IntLike, VecN +from galax.utils import partial_jit, vectorize_method from .utils import factorial +@runtime_checkable class AbstractGegenbauerDerivativeTerm(Protocol): - def __call__(self, x: jt.Array, alpha: float) -> jt.Array: + def __call__(self, x: VecN, alpha: float) -> VecN: ... -# TODO: better names -def _compute_derivative_terms( +def _compute_weight_function_derivatives( nmax: int, ) -> tuple[AbstractGegenbauerDerivativeTerm, ...]: - """Generate the nth derivative of the Gegenbauer polynomials.""" + """Compute the nth derivative of the weight function for the Gegenbauer polynomials. + + .. todo:: + + This isn't quite Equation 22.2.3 of Abramowitz & Stegun (1972). + + It's too costly to have JAX compute the Gegenbauer polynomials directly, so + we instead compute the nth derivative of the weight function, and then + use the recurrence relation to compute the Gegenbauer polynomials. + """ # Make Symbols x: sp.Symbol = sp.Symbol("x") n: sp.Symbol = sp.Symbol("n", integer=True) alpha: sp.Symbol = sp.Symbol("alpha", positive=True) - # Derivative generator term - deriv_arg = (1 - x**2) ** (n + alpha - 0.5) + # weight function for the Gegenbauer polynomials + weight = (1 - x**2) ** (n + alpha - 0.5) # Compute the nth derivative term for the Gegenbauer polynomials func_list = [] for i in range(nmax + 1): # Symbolic computation of the nth derivative - fn_sympy: sp.Expr = sp.simplify(sp.diff(deriv_arg, x, i)) + fn_sympy: sp.Expr = sp.simplify(sp.diff(weight, x, i)) # Convert to a JAX function fn_jax = sympy2jax.SymbolicModule(fn_sympy) # Re-arrange the arguments @@ -56,22 +65,28 @@ class GegenbauerCalculator(eqx.Module): # type: ignore[misc] nmax: int """Maximum order of the Gegenbauer polynomials.""" - _func_list: tuple[AbstractGegenbauerDerivativeTerm, ...] = eqx.field( - init=False, static=True + _weights: tuple[AbstractGegenbauerDerivativeTerm, ...] = eqx.field( + init=False, static=True, repr=False ) - """Tuple of functions for nth derivative of the Gegenbauer polynomials. + """Tuple of weights for nth derivative of the Gegenbauer polynomials. This is computed in the __post_init__ method. """ def __post_init__(self) -> None: - object.__setattr__(self, "_func_list", _compute_derivative_terms(self.nmax)) + object.__setattr__( + self, "_weights", _compute_weight_function_derivatives(self.nmax) + ) - @partial_jit() - def __call__(self, n: int, alpha: float, x: jt.Array) -> jt.Array: + # TODO: switch requires integer n, but everything else already vectorized + @partial_jit(static_argnames=("n", "alpha")) + @vectorize_method(signature="(),(),()->()") + def __call__( + self, n: IntLike, alpha: IntLike | FloatLike, x: FloatLike | VecN + ) -> FloatLike | VecN: r"""Calculate :math:`C_n^\alpha(x)`.""" - # TODO: how to vmap methods on PyTrees so it broadcasts and `n` can be an array? - nth_deriv = jax.lax.switch(n, self._func_list, x, alpha) + # # TODO: the Gegenbauer polynomials have limits on valid inputs + nth_deriv = jax.lax.switch(n, self._weights, x, alpha) # TODO: write out full mathematical derivation factor0 = ((-1.0) ** n) / ((2**n) * factorial(n)) diff --git a/src/galax/potential/_potential/scf/utils.py b/src/galax/potential/_potential/scf/utils.py index 0abf6b86..bd29a060 100644 --- a/src/galax/potential/_potential/scf/utils.py +++ b/src/galax/potential/_potential/scf/utils.py @@ -1,12 +1,103 @@ """Utility Functions.""" -# ruff:noqa: UP037 -from __future__ import annotations +from typing import TypeVar, cast -from jax.scipy.special import gamma -from jaxtyping import Array, Integer +import jax +import jax.numpy as xp +from jax import lax +from jax._src.numpy.util import promote_args_inexact +from jax.scipy.special import sph_harm +from jaxtyping import Array, ArrayLike, Float +from galax.typing import ( + ArrayAnyShape, + BatchableIntLike, + FloatScalar, + IntLike, + Vec3, +) +from galax.utils import partial_jit, partial_vectorize -def factorial(n: Integer[Array, "1"]) -> Integer[Array, "1"]: +T = TypeVar("T", bound=ArrayLike) + + +@jax.jit # type: ignore[misc] +@partial_vectorize(signature="(3)->(3)") +def cartesian_to_spherical(xyz: Vec3, /) -> Vec3: + """Convert Cartesian coordinates to spherical coordinates. + + Parameters + ---------- + xyz : Array[float, (3,)] + Cartesian coordinates in the form (x, y, z). + + Returns + ------- + r_theta_phi : Array[float, (3,)] + Spherical radius. + Inclination (polar) angle in [0, pi] from North to South pole. + Azimuthal angle in [-pi, pi] + """ + r = xp.sqrt(xp.sum(xyz**2, axis=0)) # spherical radius + # TODO: this is a hack to avoid the ambiguity at r==0. This should be done better. + theta = jax.lax.select( + r == 0, xp.zeros_like(r), xp.arccos(xyz[2] / r) + ) # inclination angle + phi = xp.arctan2(xyz[1], xyz[0]) # azimuthal angle + return xp.array([r, theta, phi]) + + +# TODO: replace with upstream, when available +def factorial(n: T) -> T: """Factorial helper function.""" - return gamma(n + 1.0) # n! = gamma(n+1) + (n,) = promote_args_inexact("factorial", n) + return cast("T", xp.where(n < 0, 0, lax.exp(lax.lgamma(n + 1)))) + + +def psi_of_r(r: T) -> T: + r""":math:`\psi(r) = (r-1)/(r+1)`. + + Equation 9 of Lowing et al. (2011). + """ + return cast("T", (r - 1.0) / (r + 1.0)) + + +@partial_vectorize(signature="(),(),()->()", excluded=(3,)) +@partial_jit(static_argnames=("m_max",)) # TODO: should l,m be static? +def _real_Ylm(theta: FloatScalar, l: IntLike, m: IntLike, m_max: int) -> FloatScalar: + # TODO: sph_harm only supports scalars, even though it returns an array! + return sph_harm( + m, xp.atleast_1d(l), theta=0, phi=xp.atleast_1d(theta), n_max=m_max + ).real[0] + + +def real_Ylm( + theta: ArrayAnyShape, l: BatchableIntLike, m: BatchableIntLike, m_max: int = 100 +) -> ArrayAnyShape: + r"""Get the spherical harmonic :math:`Y_{lm}(\theta)` of the polar angle. + + This is different than the scipy (and thus JAX) convention, which is + :math:`Y_{lm}(\theta, \phi)`. + Note that scipy also uses the opposite convention for theta, phi where + theta is the azimuthal angle and phi is the polar angle. + + Parameters + ---------- + theta : Array[float, (n,)] + Polar angle in [0, pi]. + l, m : int | Array[int, ()] + Spherical harmonic terms. l in [0,lmax], m in [0,l]. + m_max : int, optional + Maximum order of the spherical harmonic expansion. + + Returns + ------- + Array[float, (n)] + Spherical harmonic. + """ + # TODO: raise an error if m > m_max + return _real_Ylm(theta, l, m, m_max) + + +def expand_dim1(x: Float[Array, "N"], /) -> Float[Array, "N"]: + return xp.expand_dims(x, axis=1) if len(x.shape) == 1 else x diff --git a/src/galax/typing.py b/src/galax/typing.py index 57714b63..7b5e0dd4 100644 --- a/src/galax/typing.py +++ b/src/galax/typing.py @@ -91,6 +91,8 @@ # ----------------- # Any Shape +VecShape = Float[Array, "*shape"] + FloatArrayAnyShape = Float[Array, "..."] """A float array with any shape.""" diff --git a/tests/potential/scf/test_coeff_helper.py b/tests/potential/scf/test_coeff_helper.py new file mode 100644 index 00000000..7e5a795a --- /dev/null +++ b/tests/potential/scf/test_coeff_helper.py @@ -0,0 +1,75 @@ +"""Test the Gegenbauer class.""" + +import jax +import jax.numpy as xp +import numpy as np +import pytest + +from galax.potential._potential.scf.bfe_helper import rho_nl +from galax.potential._potential.scf.coeffs_helper import ( + expansion_coeffs_Anl_discrete, + normalization_Knl, +) +from galax.potential._potential.scf.gegenbauer import GegenbauerCalculator + + +def test_normalization_Knl(): + """Test the ``normalization_Knl`` function. + + .. todo:: + + This test is not very good. It should be improved. + """ + assert normalization_Knl(0, 0) == 1 + assert normalization_Knl(1, 0) == 3 + assert normalization_Knl(2, 0) == 6 + assert normalization_Knl(0, 1) == 6 + assert normalization_Knl(0, 2) == 15 + assert normalization_Knl(1, 1) == 10 + + +# ============================================================================= + + +def test_expansion_coeffs_Anl_discrete(): + """Test the ``expansion_coeffs_Anl_discrete`` function. + + .. todo:: + + This test is not very good. It should be improved. + """ + np.testing.assert_allclose(expansion_coeffs_Anl_discrete(0, 0), -3) + np.testing.assert_allclose( + expansion_coeffs_Anl_discrete(1, 0), -0.555555, rtol=1e-5 + ) + + +# ============================================================================= + + +@jax.jit +def compare_rho_nl(s, n, l): + """Compare the ``rho_nl`` function.""" + gc = GegenbauerCalculator(10) + + mock = rho_nl(s, n, l, gegenbauer=gc) + observed = jax.lax.stop_gradient(mock) + + return -xp.sum((observed - mock) ** 2) + + +@pytest.mark.skip(reason="TODO") +def test_rho_nl(): + """Test the ``rho_nl`` function.""" + s = xp.linspace(0, 4, 100, dtype=float) + n = xp.array([1.0]) + l = xp.array([2.0]) + + first_deriv = jax.jacfwd(compare_rho_nl)(s, n, l) + assert first_deriv == 0 + + +@pytest.mark.skip(reason="TODO") +def test_phi_nl(): + """Test the ``phi_nl`` function.""" + raise NotImplementedError diff --git a/tests/potential/scf/test_coeffs.py b/tests/potential/scf/test_coeffs.py new file mode 100644 index 00000000..04958bcd --- /dev/null +++ b/tests/potential/scf/test_coeffs.py @@ -0,0 +1,41 @@ +"""Test the Coefficient Calculations.""" + +import gala.potential as gp +import jax.numpy as xp +import numpy as np + +import galax.potential as gpx + + +def test_compute_coeffs_discrete(): + """Test the ``normalization_Knl`` function. + + .. todo:: + + This test is not very good. It should be improved. + """ + # Setup + rng = np.random.default_rng(42) + particle_xyz = rng.normal(0.0, 5.0, size=(3, 10_000)) + particle_xyz[2] = np.abs(particle_xyz[2]) + particle_xyz = xp.array(particle_xyz) + + particle_mass = xp.ones(particle_xyz.shape[1]) + particle_mass = 1e12 * particle_mass / particle_mass.sum() + + nmax = 2 + lmax = 3 + r_s = 10 + + # Gala + gala_Snlm, gala_Tnlm = gp.scf.compute_coeffs_discrete( + np.array(particle_xyz), np.array(particle_mass), nmax=nmax, lmax=lmax, r_s=r_s + ) + + # Galdynamix + Snlm, Tnlm = gpx.scf.compute_coeffs_discrete( + particle_xyz, particle_mass, nmax=nmax, lmax=lmax, r_s=r_s + ) + + np.testing.assert_allclose(Snlm, gala_Snlm, rtol=1e-7) + np.testing.assert_allclose(Tnlm, gala_Tnlm, rtol=1e-7) diff --git a/tests/potential/scf/test_gegenbauer.py b/tests/potential/scf/test_gegenbauer.py index 7230fcae..0ac258b5 100644 --- a/tests/potential/scf/test_gegenbauer.py +++ b/tests/potential/scf/test_gegenbauer.py @@ -1,27 +1,85 @@ """Test the Gegenbauer class.""" -from __future__ import annotations + +from types import LambdaType import jax.numpy as xp import numpy as np +import pytest from scipy.special import gegenbauer as gegenbauer_sp -from galdynamix.potential._potential.scf import GegenbauerCalculator +from galax.potential._potential.scf.gegenbauer import ( + GegenbauerCalculator, + _compute_weight_function_derivatives, +) class TestGegenbauerCalculator: """Test the GegenbauerCalculator class.""" - def test_C_n_alpha(self): - """Test the C_n_alpha function.""" + def test_compute_weight_function_derivatives(self): + """Test the ``_compute_weight_function_derivatives`` function. + + ..todo:: + + This test is not very good, since it doesn't actually check the + correctness of the output. It just checks that the function runs. + + """ + nmax = 5 + terms = _compute_weight_function_derivatives(nmax) + assert len(terms) == nmax + 1 + + x = xp.linspace(0.02, 0.99, 10000) + alpha = 2 + got = terms[0](x, alpha) + assert got.shape == x.shape + + @pytest.mark.parametrize("nmax", [0, 2, 6]) + def test_init(self, nmax): + """Test initializing the GegenbauerCalculator.""" + gc = GegenbauerCalculator(nmax) + assert gc.nmax == nmax + + # Check the pre-computed weights + assert isinstance(gc._weights, tuple) + assert len(gc._weights) == nmax + 1 + assert all(isinstance(w, LambdaType) for w in gc._weights) + + # @pytest.mark.parametrize("nmax", [0, 1, 2, 3, 4, 5, 6]) + def test_call(self): + """Test the functor.""" + # Setup + nmax = 5 x = xp.linspace(0.02, 0.99, 10000) + gc = GegenbauerCalculator(nmax) - # Compute the Gegenbauer polynomial - ultra_sph = GegenbauerCalculator(6) - got = ultra_sph(n=xp.array([5]), alpha=2 * 0 + (3.0 / 2), x=xp.asarray(x)) + # # n > nmax should raise an error + # with pytest.raises(ValueError): + # gc(6, 0, x) + + # With ints and floats + got = gc(1, 0, x) + assert got == pytest.approx(gegenbauer_sp(n=1, alpha=0)(x), rel=1e-8) + + def test_call_vectorized(self): + """Test the functor with vectorized arguments.""" + # Setup + nmax = 5 + x = xp.linspace(0.02, 0.99, 10000) + gc = GegenbauerCalculator(nmax) - # Compare with scipy's gegenbauer function - expected = gegenbauer_sp(n=5, alpha=2 * 0 + (3 / 2))(x) + n = xp.array([3, 5]) + alpha = 1.5 + got = gc(n, alpha, x[:, None]) + expected = np.c_[ + gegenbauer_sp(n=3, alpha=alpha)(np.array(x)), + gegenbauer_sp(n=5, alpha=alpha)(np.array(x)), + ] + assert got.shape == expected.shape + assert got == pytest.approx(expected, rel=1e-8) - # Test they match - np.testing.assert_array_almost_equal(np.array(got), expected) + @pytest.mark.skip(reason="TODO with `hypothesis`") + def test_validity(self): + """Test the regions of validity of the Gegenbauer calculator.""" + raise NotImplementedError diff --git a/tests/potential/scf/test_utils.py b/tests/potential/scf/test_utils.py new file mode 100644 index 00000000..dda04f10 --- /dev/null +++ b/tests/potential/scf/test_utils.py @@ -0,0 +1,145 @@ +"""Test the Gegenbauer class.""" + +import hypothesis +import hypothesis.extra.numpy as hnp +import jax +import jax.numpy as xp +import numpy as np +import numpy.typing as npt +import scipy.special as sp +from hypothesis import assume, given +from hypothesis import strategies as st + +from galax.potential._potential.scf.utils import ( + cartesian_to_spherical, + factorial, + psi_of_r, + real_Ylm, +) + + +# TODO: use hnp.floating_dtypes() +# TODO: test more batch dimensions +def xyz_strategy() -> st.SearchStrategy[np.ndarray]: + return hnp.arrays( + dtype=np.float64, + shape=st.tuples(st.integers(1, 100), st.integers(3, 3)), + elements=st.floats(-10, 10, allow_subnormal=False, allow_nan=False), + ) + + +@given(xyz_strategy()) +def test_cartesian_to_spherical(xyz): + """Test the ``cartesian_to_spherical`` function.""" + assume(np.all(xyz.sum(axis=1) != 0)) + + n = len(xyz) + xyz = xp.asarray(xyz) + + rthetaphi = cartesian_to_spherical(xyz) + r = rthetaphi[..., 0] + theta = rthetaphi[..., 1] + phi = rthetaphi[..., 2] + + # Check + assert r.shape == (n,) + assert theta.shape == (n,) + assert phi.shape == (n,) + + assert xp.all(r >= 0) + assert xp.all(theta >= 0) & xp.all(theta <= xp.pi) + assert xp.all(phi >= -xp.pi) & xp.all(phi <= xp.pi) + + +def test_cartesian_to_spherical_jac(): + """Test the ``cartesian_to_spherical`` function.""" + # Scalar + xyz = xp.asarray([1, 0, 0], dtype=float) + assert xyz.shape == (3,) + + output = jax.jacfwd(cartesian_to_spherical)(xyz) + np.testing.assert_array_equal( + output, [[1.0, 0.0, 0.0], [-0.0, -0.0, -1.0], [0.0, 1.0, 0.0]] + ) + + # Vector + xyz = xp.asarray([[1, 0, 0], [0, 1, 0]], dtype=float) + assert xyz.shape == (2, 3) + + output = jax.jacfwd(cartesian_to_spherical)(xyz) + assert output.shape == (2, 3, 2, 3) # WTF? + np.testing.assert_array_equal( + output[0, :, 0, :], [[1.0, 0.0, 0.0], [-0.0, -0.0, -1.0], [0.0, 1.0, 0.0]] + ) + np.testing.assert_array_equal( + output[1, :, 1, :], [[0.0, 1.0, 0.0], [0.0, 0.0, -1.0], [-1.0, 0.0, 0.0]] + ) + + +# ============================================================================= + + +@given( + n=st.integers(0, 100) + | hnp.arrays(dtype=int, shape=hnp.array_shapes(), elements=st.integers(0, 100)) +) +def test_factorial(n: int | npt.NDArray[np.int_]): + """Test the ``factorial`` function.""" + got = factorial(xp.asarray(n)) + expected = sp.factorial(n) + np.testing.assert_allclose(got, expected) + + +# ============================================================================= + + +@given( + r=st.floats(0, 100) + | hnp.arrays(dtype=float, shape=hnp.array_shapes(), elements=st.floats(0, 100)), +) +def test_psi_of_r(r): + """Test the ``psi_of_r`` function.""" + got = psi_of_r(r) + expected = (r - 1) / (r + 1) + np.testing.assert_allclose(got, expected) + + +# ============================================================================= + + +def test_Ylm_jitting(): + """Test the ``real_Ylm`` function.""" + got = real_Ylm(5, 0, np.pi) + expected = np.real(sp.sph_harm(0, 5, 0, np.pi)) + np.testing.assert_allclose(got, expected) + + +@hypothesis.settings(deadline=500) +@given(l=st.integers(1, 25), m=st.integers(1, 25), theta=st.floats(0, np.pi)) +def test_real_Ylm(l, m, theta): + """Test the ``real_Ylm`` function.""" + assume(theta != 0) + assume(m <= l) + got = real_Ylm(l, m, theta) + expected = np.real(sp.sph_harm(m, l, 0, theta)) + np.testing.assert_allclose(got, expected) + + +# # TODO: mark as slow +# # TODO: test batch dimensions of l, m, theta +# @hypothesis.settings(deadline=500) +# @given( +# l=st.integers(1, 25), +# m=st.integers(1, 25), +# theta=hnp.arrays( +# dtype=np.float64, +# shape=st.integers(1, 100), +# elements=st.floats(1e-5, np.pi, allow_subnormal=False, allow_nan=False), +# ), +# ) +# def test_real_Ylm_vec(l, m, theta): +# """Test the ``real_Ylm`` function.""" +# assume(m <= l) +# got = real_Ylm(l, m, theta) +# expected = np.real(sp.sph_harm(m, l, 0, theta)) +# np.testing.assert_allclose(got, expected)