diff --git a/pmwd/boltzmann.py b/pmwd/boltzmann.py index 20e9f9d5..62be5b35 100644 --- a/pmwd/boltzmann.py +++ b/pmwd/boltzmann.py @@ -34,14 +34,14 @@ def transfer_fit(k, cosmo, conf): Parameters ---------- - k : array_like + k : ArrayLike Wavenumbers in [1/L]. cosmo : Cosmology conf : Configuration Returns ------- - T : jax.numpy.ndarray of (k * 1.).dtype + T : jax.Array of (k * 1.).dtype Matter transfer function. .. _Transfer Function: @@ -129,14 +129,14 @@ def transfer(k, cosmo, conf): Parameters ---------- - k : array_like + k : ArrayLike Wavenumbers in [1/L]. cosmo : Cosmology conf : Configuration Returns ------- - T : jax.numpy.ndarray of (k * 1.).dtype + T : jax.Array of (k * 1.).dtype Matter transfer function. Raises @@ -238,7 +238,7 @@ def growth(a, cosmo, conf, order=1, deriv=0): Parameters ---------- - a : array_like + a : ArrayLike Scale factors. cosmo : Cosmology conf : Configuration @@ -249,7 +249,7 @@ def growth(a, cosmo, conf, order=1, deriv=0): Returns ------- - D : jax.numpy.ndarray of (a * 1.).dtype + D : jax.Array of (a * 1.).dtype Growth functions or derivatives. Raises @@ -297,16 +297,16 @@ def varlin(R, a, cosmo, conf): Parameters ---------- - R : array_like + R : ArrayLike Scales in [L]. - a : array_like or None + a : ArrayLike or None Scale factors. If None, output is not scaled by growth. cosmo : Cosmology conf : Configuration Returns ------- - sigma2 : jax.numpy.ndarray of (k * a * 1.).dtype + sigma2 : jax.Array of (k * a * 1.).dtype Linear matter overdensity variance. Raises @@ -401,16 +401,16 @@ def linear_power(k, a, cosmo, conf): Parameters ---------- - k : array_like + k : ArrayLike Wavenumbers in [1/L]. - a : array_like or None + a : ArrayLike or None Scale factors. If None, output is not scaled by growth. cosmo : Cosmology conf : Configuration Returns ------- - Plin : jax.numpy.ndarray of (k * a * 1.).dtype + Plin : jax.Array of (k * a * 1.).dtype Linear matter power spectrum in [L^3]. Raises diff --git a/pmwd/configuration.py b/pmwd/configuration.py index 5509fd27..b21661ce 100644 --- a/pmwd/configuration.py +++ b/pmwd/configuration.py @@ -2,9 +2,9 @@ import math from typing import ClassVar, Optional, Tuple, Union -from numpy.typing import DTypeLike import jax from jax import ensure_compile_time_eval +from jax.typing import DTypeLike import jax.numpy as jnp from jax.tree_util import tree_map from mcfit import TophatVar diff --git a/pmwd/cosmology.py b/pmwd/cosmology.py index ca1d2672..104555e2 100644 --- a/pmwd/cosmology.py +++ b/pmwd/cosmology.py @@ -1,10 +1,10 @@ from dataclasses import field from functools import partial from operator import add, sub -from typing import ClassVar, Optional, Union +from typing import ClassVar, Optional -import numpy as np -from jax import value_and_grad +from jax import Array, value_and_grad +from jax.typing import ArrayLike import jax.numpy as jnp from jax.tree_util import tree_map @@ -12,9 +12,6 @@ from pmwd.configuration import Configuration -FloatParam = Union[float, jnp.ndarray] - - @partial(pytree_dataclass, aux_fields="conf", frozen=True) class Cosmology: """Cosmological and configuration parameters, "immutable" as a frozen dataclass. @@ -34,45 +31,45 @@ class Cosmology: ---------- conf : Configuration Configuration parameters. - A_s_1e9 : float or jax.numpy.ndarray + A_s_1e9 : float ArrayLike Primordial scalar power spectrum amplitude, multiplied by 1e9. - n_s : float or jax.numpy.ndarray + n_s : float ArrayLike Primordial scalar power spectrum spectral index. - Omega_m : float or jax.numpy.ndarray + Omega_m : float ArrayLike Total matter density parameter today. - Omega_b : float or jax.numpy.ndarray + Omega_b : float ArrayLike Baryonic matter density parameter today. - Omega_k_ : None, float, or jax.numpy.ndarray, optional + Omega_k_ : None or float ArrayLike, optional Spatial curvature density parameter today. Default is None. - w_0_ : None, float, or jax.numpy.ndarray, optional + w_0_ : None or float ArrayLike, optional Dark energy equation of state constant parameter. Default is None. - w_a_ : None, float, or jax.numpy.ndarray, optional + w_a_ : None or float ArrayLike, optional Dark energy equation of state linear parameter. Default is None. - h : float or jax.numpy.ndarray + h : float ArrayLike Hubble constant in unit of 100 [km/s/Mpc]. """ conf: Configuration = field(repr=False) - A_s_1e9: FloatParam - n_s: FloatParam - Omega_m: FloatParam - Omega_b: FloatParam - h: FloatParam + A_s_1e9: ArrayLike + n_s: ArrayLike + Omega_m: ArrayLike + Omega_b: ArrayLike + h: ArrayLike - Omega_k_: Optional[FloatParam] = None + Omega_k_: Optional[ArrayLike] = None Omega_k_fixed: ClassVar[float] = 0 - w_0_: Optional[FloatParam] = None + w_0_: Optional[ArrayLike] = None w_0_fixed: ClassVar[float] = -1 - w_a_: Optional[FloatParam] = None + w_a_: Optional[ArrayLike] = None w_a_fixed: ClassVar[float] = 0 - transfer: Optional[jnp.ndarray] = field(default=None, compare=False) + transfer: Optional[Array] = field(default=None, compare=False) - growth: Optional[jnp.ndarray] = field(default=None, compare=False) + growth: Optional[Array] = field(default=None, compare=False) - varlin: Optional[jnp.ndarray] = field(default=None, compare=False) + varlin: Optional[Array] = field(default=None, compare=False) def __post_init__(self): if self._is_transforming(): @@ -191,13 +188,13 @@ def E2(a, cosmo): Parameters ---------- - a : array_like + a : ArrayLike Scale factors. cosmo : Cosmology Returns ------- - E2 : jax.numpy.ndarray of cosmo.conf.cosmo_dtype + E2 : jax.Array of cosmo.conf.cosmo_dtype Squared Hubble parameter time scaling factors. Notes @@ -229,13 +226,13 @@ def H_deriv(a, cosmo): Parameters ---------- - a : array_like + a : ArrayLike Scale factors. cosmo : Cosmology Returns ------- - dlnH_dlna : jax.numpy.ndarray of cosmo.conf.cosmo_dtype + dlnH_dlna : jax.Array of cosmo.conf.cosmo_dtype Hubble parameter derivatives. """ @@ -250,13 +247,13 @@ def Omega_m_a(a, cosmo): Parameters ---------- - a : array_like + a : ArrayLike Scale factors. cosmo : Cosmology Returns ------- - Omega : jax.numpy.ndarray of cosmo.conf.cosmo_dtype + Omega : jax.Array of cosmo.conf.cosmo_dtype Matter density parameters. Notes diff --git a/pmwd/gather.py b/pmwd/gather.py index 3862e68b..70c9383b 100644 --- a/pmwd/gather.py +++ b/pmwd/gather.py @@ -12,18 +12,18 @@ def gather(ptcl, conf, mesh, val=0, offset=0, cell_size=None): ---------- ptcl : Particles conf : Configuration - mesh : array_like + mesh : ArrayLike Input mesh. - val : array_like, optional + val : ArrayLike, optional Input values, can be 0D. - offset : array_like, optional + offset : ArrayLike, optional Offset of mesh to particle grid. If 0D, the value is used in each dimension. cell_size : float, optional Mesh cell size in [L]. Default is ``conf.cell_size``. Returns ------- - val : jax.numpy.ndarray + val : jax.Array Output values. """ diff --git a/pmwd/lpt.py b/pmwd/lpt.py index ea8a3b5e..4200ac9c 100644 --- a/pmwd/lpt.py +++ b/pmwd/lpt.py @@ -81,7 +81,7 @@ def levi_civita(indices): Parameters ---------- - indices : array_like + indices : ArrayLike Returns ------- @@ -140,7 +140,7 @@ def lpt(modes, cosmo, conf): Parameters ---------- - modes : jax.numpy.ndarray + modes : jax.Array Linear matter overdensity Fourier modes in [L^3]. cosmo : Cosmology conf : Configuration diff --git a/pmwd/modes.py b/pmwd/modes.py index 08a7ca72..64d44a30 100644 --- a/pmwd/modes.py +++ b/pmwd/modes.py @@ -27,7 +27,7 @@ def white_noise(seed, conf, real=False, unit_abs=False, negate=False): Returns ------- - modes : jax.numpy.ndarray of conf.float_dtype + modes : jax.Array of conf.float_dtype White noise modes. """ @@ -75,7 +75,7 @@ def linear_modes(modes, cosmo, conf, a=None): Parameters ---------- - modes : jax.numpy.ndarray + modes : jax.Array Fourier or real modes with white noise prior. cosmo : Cosmology conf : Configuration @@ -84,7 +84,7 @@ def linear_modes(modes, cosmo, conf, a=None): Returns ------- - modes : jax.numpy.ndarray of conf.float_dtype + modes : jax.Array of conf.float_dtype Linear matter overdensity Fourier modes in [L^3]. Notes diff --git a/pmwd/particles.py b/pmwd/particles.py index fc193c3c..bf296a07 100644 --- a/pmwd/particles.py +++ b/pmwd/particles.py @@ -2,9 +2,9 @@ from functools import partial from itertools import accumulate from operator import itemgetter, mul -from typing import Optional, Any, Union +from typing import Optional, Any -from numpy.typing import ArrayLike +from jax.typing import ArrayLike import jax.numpy as jnp from jax.tree_util import tree_map @@ -15,37 +15,34 @@ from pmwd.pm_util import enmesh -ArrayLike = Union[ArrayLike, jnp.ndarray] - - @partial(pytree_dataclass, aux_fields="conf", frozen=True) class Particles: """Particle state. Particles are indexable. - Array-likes are converted to ``jax.numpy.ndarray`` of ``conf.pmid_dtype`` or + Array-likes are converted to ``jax.Array`` of ``conf.pmid_dtype`` or ``conf.float_dtype`` at instantiation. Parameters ---------- conf : Configuration Configuration parameters. - pmid : array_like + pmid : ArrayLike Particle IDs by mesh indices, of signed int dtype. They are the nearest mesh grid points from particles' Lagrangian positions. It can save memory compared to the raveled particle IDs, e.g., 6 bytes for 3 times int16 versus 8 bytes for uint64. Call ``raveled_id`` for the raveled IDs. - disp : array_like + disp : ArrayLike # FIXME after adding the CUDA scatter and gather ops Particle comoving displacements from pmid in [L]. For displacements from particles' grid Lagrangian positions, use ``ptcl_rpos(ptcl, Particles.gen_grid(ptcl.conf), ptcl.conf)``. It can save the particle locations with much more uniform precision than positions, whereever they are. Call ``pos`` for the positions. - vel : array_like, optional + vel : ArrayLike, optional Particle canonical velocities in [H_0 L]. - acc : array_like, optional + acc : ArrayLike, optional Particle accelerations in [H_0^2 L]. attr : pytree, optional Particle attributes (custom features). @@ -90,7 +87,7 @@ def from_pos(cls, conf, pos, wrap=True): Parameters ---------- conf : Configuration - pos : array_like + pos : ArrayLike Particle positions in [L]. wrap : bool, optional Whether to wrap around the periodic boundaries. @@ -168,7 +165,7 @@ def raveled_id(self, dtype=jnp.uint64, wrap=False): Returns ------- - raveled_id : jax.numpy.ndarray + raveled_id : jax.Array Particle raveled IDs. """ @@ -196,7 +193,7 @@ def pos(self, dtype=jnp.float64, wrap=True): Returns ------- - pos : jax.numpy.ndarray + pos : jax.Array Particle positions in [L]. """ @@ -222,7 +219,7 @@ def ptcl_enmesh(ptcl, conf, offset=0, cell_size=None, mesh_shape=None, ---------- ptcl : Particles conf : Configuration - offset : array_like, optional + offset : ArrayLike, optional Offset of mesh to particle grid. If 0D, the value is used in each dimension. cell_size : float, optional Mesh cell size in [L]. Default is ``conf.cell_size``. @@ -233,17 +230,17 @@ def ptcl_enmesh(ptcl, conf, offset=0, cell_size=None, mesh_shape=None, drop : bool, optional Whether to set negative out-of-bounds indices of ``ind`` to ``mesh_shape``, avoiding some of them being treated as in bounds, thus allowing them to be - dropped by ``add()`` and ``get()`` of ``jax.numpy.ndarray.at``. + dropped by ``add()`` and ``get()`` of ``jax.Array.at``. grad : bool, optional Whether to return ``frac_grad``, gradients of ``frac``. Returns ------- - ind : (ptcl_num, 2**dim, dim) jax.numpy.ndarray + ind : (ptcl_num, 2**dim, dim) jax.Array Mesh indices. - frac : (ptcl_num, 2**dim) jax.numpy.ndarray + frac : (ptcl_num, 2**dim) jax.Array Multilinear fractions on the mesh. - frac_grad : (ptcl_num, 2**dim, dim) jax.numpy.ndarray + frac_grad : (ptcl_num, 2**dim, dim) jax.Array Multilinear fraction gradients on the mesh. """ @@ -267,7 +264,7 @@ def ptcl_rpos(ptcl, ref, conf, wrap=True): Parameters ---------- ptcl : Particles - ref : array_like or Particles + ref : ArrayLike or Particles Reference points or particles. conf : Configuration wrap : bool, optional @@ -275,7 +272,7 @@ def ptcl_rpos(ptcl, ref, conf, wrap=True): Returns ------- - rpos : jax.numpy.ndarray of conf.float_dtype + rpos : jax.Array of conf.float_dtype Particle relative positions in [L]. """ @@ -300,16 +297,16 @@ def ptcl_rsd(ptcl, los, a, cosmo): Parameters ---------- ptcl : Particles - los : array_like + los : ArrayLike Line-of-sight **unit vectors**, global or per particle. Vector norms are *not* checked. - a : array_like + a : ArrayLike Scale factors, global or per particle. cosmo : Cosmology Returns ------- - rsd : jax.numpy.ndarray of cosmo.conf.float_dtype + rsd : jax.Array of cosmo.conf.float_dtype Particle redshift-space distortion displacements in [L]. """ @@ -333,13 +330,13 @@ def ptcl_los(ptcl, obs, conf): Parameters ---------- ptcl : Particles - obs : array_like or Particles + obs : ArrayLike or Particles Observer position. conf : Configuration Returns ------- - los : jax.numpy.ndarray of conf.float_dtype + los : jax.Array of conf.float_dtype Particles line-of-sight unit vectors. """ diff --git a/pmwd/pm_util.py b/pmwd/pm_util.py index 3ff61849..f7373c48 100644 --- a/pmwd/pm_util.py +++ b/pmwd/pm_util.py @@ -36,32 +36,32 @@ def enmesh(i1, d1, a1, s1, b12, a2, s2, grad): Parameters ---------- - i1 : (num, dim) array_like + i1 : (num, dim) ArrayLike Integer coordinates of points on grid 1. - d1 : (num, dim) array_like + d1 : (num, dim) ArrayLike Float displacements from the points on grid 1. a1 : float Cell size of grid 1. s1 : dim-tuple of int, or None Periodic boundary shape of grid 1. If None, no wrapping. - b12 : array_like + b12 : ArrayLike Offset of origin of grid 2 to that of grid 1. a2 : float or None Cell size of grid 2. If None, ``a2`` is the same as ``a1``. s2 : dim-tuple of int, or None Shape of grid 2. If not None, negative out-of-bounds indices of ``i2`` are set to ``s2``, avoiding some of them being treated as in bounds, thus allowing them - to be dropped by ``add()`` and ``get()`` of ``jax.numpy.ndarray.at``. + to be dropped by ``add()`` and ``get()`` of ``jax.Array.at``. grad : bool Whether to return gradients of ``f2``. Returns ------- - i2 : (num, 2**dim, dim) jax.numpy.ndarray + i2 : (num, 2**dim, dim) jax.Array Mesh indices on grid 2. - f2 : (num, 2**dim) jax.numpy.ndarray + f2 : (num, 2**dim) jax.Array Multilinear fractions on grid 2. - f2_grad : (num, 2**dim, dim) jax.numpy.ndarray + f2_grad : (num, 2**dim, dim) jax.Array Multilinear fraction gradients on grid 2. Notes @@ -170,7 +170,7 @@ def rfftnfreq(shape, spacing, dtype=jnp.float64): Returns ------- - kvec : list of jax.numpy.ndarray + kvec : list of jax.Array Wavevectors. """ diff --git a/pmwd/scatter.py b/pmwd/scatter.py index 4c7ffba3..9a8f0505 100644 --- a/pmwd/scatter.py +++ b/pmwd/scatter.py @@ -12,18 +12,18 @@ def scatter(ptcl, conf, mesh=None, val=None, offset=0, cell_size=None): ---------- ptcl : Particles conf : Configuration - mesh : array_like, optional + mesh : ArrayLike, optional Input mesh. Default is a ``zeros`` array of ``conf.mesh_shape + val.shape[1:]``. - val : array_like, optional + val : ArrayLike, optional Input values, can be 0D. Default is ``conf.mesh_size / conf.ptcl_num``. - offset : array_like, optional + offset : ArrayLike, optional Offset of mesh to particle grid. If 0D, the value is used in each dimension. cell_size : float, optional Mesh cell size in [L]. Default is ``conf.cell_size``. Returns ------- - mesh : jax.numpy.ndarray + mesh : jax.Array Output mesh. """ diff --git a/pmwd/spec_util.py b/pmwd/spec_util.py index 61e1f504..93253889 100644 --- a/pmwd/spec_util.py +++ b/pmwd/spec_util.py @@ -13,17 +13,17 @@ def powspec(f, spacing, bins=1j/3, g=None, deconv=0, cut_zero=True, cut_nyq=True Parameters ---------- - f : array_like + f : ArrayLike The field, with the last 3 axes for FFT and the other summed over. spacing : float Field grid spacing. - bins : float, complex, or 1D array_like, optional + bins : float, complex, or 1D ArrayLike, optional Wavenumber bins. A real number sets the linear spaced bin width in unit of the smallest fundamental in 3D (right edge inclusive starting from zero); an imaginary number sets the log spaced bin width in octave (left edge inclusive starting from the smallest fundamental in 3D); and an array sets the bin edges directly (right edge inclusive and must starting from zero). - g : array_like, optional + g : ArrayLike, optional Another field of the same shape for cross correlation. deconv : int, optional Power of sinc factors to deconvolve in the power spectrum. @@ -36,13 +36,13 @@ def powspec(f, spacing, bins=1j/3, g=None, deconv=0, cut_zero=True, cut_nyq=True Returns ------- - k : jax.numpy.ndarray + k : jax.Array Wavenumber. - P : jax.numpy.ndarray + P : jax.Array Power spectrum. - N : jax.numpy.ndarray + N : jax.Array Number of modes. - bins : jax.numpy.ndarray + bins : jax.Array Wavenumber bins. """ diff --git a/pmwd/test_util.py b/pmwd/test_util.py index f3f8e8a5..256ed4ad 100644 --- a/pmwd/test_util.py +++ b/pmwd/test_util.py @@ -115,11 +115,11 @@ def check_custom_vjp(fun, primals, partial_args=(), partial_kwargs={}, Returns ------- - cot : jax.numpy.ndarray + cot : jax.Array Input cotangents by custom vjp. - cot_orig : jax.numpy.ndarray + cot_orig : jax.Array Input cotangents by automatic vjp. - cot_diff : jax.numpy.ndarray + cot_diff : jax.Array Input cotangent differences. Raises diff --git a/pmwd/vis_util.py b/pmwd/vis_util.py index 2e408e5c..e121c44c 100644 --- a/pmwd/vis_util.py +++ b/pmwd/vis_util.py @@ -16,7 +16,7 @@ def simshow(x, figsize=(9, 7), dpi=72, cmap='inferno', norm=None, colorbar=True, Parameters ---------- - x : array_like + x : ArrayLike 2D field. figsize : 2-tuple of float, optional Width and height in inches. @@ -80,7 +80,7 @@ class CosmicWebNorm(FuncNorm): Parameters ---------- - x : array_like + x : ArrayLike Density field. q : float, optional Underdensity fraction in colormap. diff --git a/setup.py b/setup.py index 49941e0d..6ab274f0 100644 --- a/setup.py +++ b/setup.py @@ -25,7 +25,6 @@ python_requires='>=3.7', install_requires=[ 'jax>=0.4.7', - 'numpy>=1.20', # numpy.typing 'mcfit>=0.0.18', # jax backend ], extras_require={