Skip to content

Commit

Permalink
Change to use jax.Array and jax.typing
Browse files Browse the repository at this point in the history
  • Loading branch information
eelregit committed Oct 23, 2023
1 parent e830c2f commit 62a9aa7
Show file tree
Hide file tree
Showing 13 changed files with 103 additions and 110 deletions.
24 changes: 12 additions & 12 deletions pmwd/boltzmann.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,14 @@ def transfer_fit(k, cosmo, conf):
Parameters
----------
k : array_like
k : ArrayLike
Wavenumbers in [1/L].
cosmo : Cosmology
conf : Configuration
Returns
-------
T : jax.numpy.ndarray of (k * 1.).dtype
T : jax.Array of (k * 1.).dtype
Matter transfer function.
.. _Transfer Function:
Expand Down Expand Up @@ -129,14 +129,14 @@ def transfer(k, cosmo, conf):
Parameters
----------
k : array_like
k : ArrayLike
Wavenumbers in [1/L].
cosmo : Cosmology
conf : Configuration
Returns
-------
T : jax.numpy.ndarray of (k * 1.).dtype
T : jax.Array of (k * 1.).dtype
Matter transfer function.
Raises
Expand Down Expand Up @@ -238,7 +238,7 @@ def growth(a, cosmo, conf, order=1, deriv=0):
Parameters
----------
a : array_like
a : ArrayLike
Scale factors.
cosmo : Cosmology
conf : Configuration
Expand All @@ -249,7 +249,7 @@ def growth(a, cosmo, conf, order=1, deriv=0):
Returns
-------
D : jax.numpy.ndarray of (a * 1.).dtype
D : jax.Array of (a * 1.).dtype
Growth functions or derivatives.
Raises
Expand Down Expand Up @@ -297,16 +297,16 @@ def varlin(R, a, cosmo, conf):
Parameters
----------
R : array_like
R : ArrayLike
Scales in [L].
a : array_like or None
a : ArrayLike or None
Scale factors. If None, output is not scaled by growth.
cosmo : Cosmology
conf : Configuration
Returns
-------
sigma2 : jax.numpy.ndarray of (k * a * 1.).dtype
sigma2 : jax.Array of (k * a * 1.).dtype
Linear matter overdensity variance.
Raises
Expand Down Expand Up @@ -401,16 +401,16 @@ def linear_power(k, a, cosmo, conf):
Parameters
----------
k : array_like
k : ArrayLike
Wavenumbers in [1/L].
a : array_like or None
a : ArrayLike or None
Scale factors. If None, output is not scaled by growth.
cosmo : Cosmology
conf : Configuration
Returns
-------
Plin : jax.numpy.ndarray of (k * a * 1.).dtype
Plin : jax.Array of (k * a * 1.).dtype
Linear matter power spectrum in [L^3].
Raises
Expand Down
8 changes: 4 additions & 4 deletions pmwd/configuration.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
import math
from typing import ClassVar, Optional, Tuple, Union

from numpy.typing import DTypeLike
import jax
from jax import ensure_compile_time_eval
from jax.typing import DTypeLike
import jax.numpy as jnp
from jax.tree_util import tree_map
from mcfit import TophatVar
Expand Down Expand Up @@ -33,11 +33,11 @@ class Configuration:
ratio, to determine the mesh shape from that of the particle grid. The mesh grid
cannot be smaller than the particle grid (int or float values must not be
smaller than 1) and the two grids must have the same aspect ratio.
cosmo_dtype : dtype_like, optional
cosmo_dtype : DTypeLike, optional
Float dtype for Cosmology and Configuration.
pmid_dtype : dtype_like, optional
pmid_dtype : DTypeLike, optional
Signed integer dtype for particle or mesh grid indices.
float_dtype : dtype_like, optional
float_dtype : DTypeLike, optional
Float dtype for other particle and mesh quantities.
k_pivot_Mpc : float, optional
Primordial scalar power spectrum pivot scale in 1/Mpc.
Expand Down
59 changes: 28 additions & 31 deletions pmwd/cosmology.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,17 @@
from dataclasses import field
from functools import partial
from operator import add, sub
from typing import ClassVar, Optional, Union
from typing import ClassVar, Optional

import numpy as np
from jax import value_and_grad
from jax import Array, value_and_grad
from jax.typing import ArrayLike
import jax.numpy as jnp
from jax.tree_util import tree_map

from pmwd.tree_util import pytree_dataclass
from pmwd.configuration import Configuration


FloatParam = Union[float, jnp.ndarray]


@partial(pytree_dataclass, aux_fields="conf", frozen=True)
class Cosmology:
"""Cosmological and configuration parameters, "immutable" as a frozen dataclass.
Expand All @@ -34,45 +31,45 @@ class Cosmology:
----------
conf : Configuration
Configuration parameters.
A_s_1e9 : float or jax.numpy.ndarray
A_s_1e9 : float ArrayLike
Primordial scalar power spectrum amplitude, multiplied by 1e9.
n_s : float or jax.numpy.ndarray
n_s : float ArrayLike
Primordial scalar power spectrum spectral index.
Omega_m : float or jax.numpy.ndarray
Omega_m : float ArrayLike
Total matter density parameter today.
Omega_b : float or jax.numpy.ndarray
Omega_b : float ArrayLike
Baryonic matter density parameter today.
Omega_k_ : None, float, or jax.numpy.ndarray, optional
Omega_k_ : None or float ArrayLike, optional
Spatial curvature density parameter today. Default is None.
w_0_ : None, float, or jax.numpy.ndarray, optional
w_0_ : None or float ArrayLike, optional
Dark energy equation of state constant parameter. Default is None.
w_a_ : None, float, or jax.numpy.ndarray, optional
w_a_ : None or float ArrayLike, optional
Dark energy equation of state linear parameter. Default is None.
h : float or jax.numpy.ndarray
h : float ArrayLike
Hubble constant in unit of 100 [km/s/Mpc].
"""

conf: Configuration = field(repr=False)

A_s_1e9: FloatParam
n_s: FloatParam
Omega_m: FloatParam
Omega_b: FloatParam
h: FloatParam
A_s_1e9: ArrayLike
n_s: ArrayLike
Omega_m: ArrayLike
Omega_b: ArrayLike
h: ArrayLike

Omega_k_: Optional[FloatParam] = None
Omega_k_: Optional[ArrayLike] = None
Omega_k_fixed: ClassVar[float] = 0
w_0_: Optional[FloatParam] = None
w_0_: Optional[ArrayLike] = None
w_0_fixed: ClassVar[float] = -1
w_a_: Optional[FloatParam] = None
w_a_: Optional[ArrayLike] = None
w_a_fixed: ClassVar[float] = 0

transfer: Optional[jnp.ndarray] = field(default=None, compare=False)
transfer: Optional[Array] = field(default=None, compare=False)

growth: Optional[jnp.ndarray] = field(default=None, compare=False)
growth: Optional[Array] = field(default=None, compare=False)

varlin: Optional[jnp.ndarray] = field(default=None, compare=False)
varlin: Optional[Array] = field(default=None, compare=False)

def __post_init__(self):
if self._is_transforming():
Expand Down Expand Up @@ -191,13 +188,13 @@ def E2(a, cosmo):
Parameters
----------
a : array_like
a : ArrayLike
Scale factors.
cosmo : Cosmology
Returns
-------
E2 : jax.numpy.ndarray of cosmo.conf.cosmo_dtype
E2 : jax.Array of cosmo.conf.cosmo_dtype
Squared Hubble parameter time scaling factors.
Notes
Expand Down Expand Up @@ -229,13 +226,13 @@ def H_deriv(a, cosmo):
Parameters
----------
a : array_like
a : ArrayLike
Scale factors.
cosmo : Cosmology
Returns
-------
dlnH_dlna : jax.numpy.ndarray of cosmo.conf.cosmo_dtype
dlnH_dlna : jax.Array of cosmo.conf.cosmo_dtype
Hubble parameter derivatives.
"""
Expand All @@ -250,13 +247,13 @@ def Omega_m_a(a, cosmo):
Parameters
----------
a : array_like
a : ArrayLike
Scale factors.
cosmo : Cosmology
Returns
-------
Omega : jax.numpy.ndarray of cosmo.conf.cosmo_dtype
Omega : jax.Array of cosmo.conf.cosmo_dtype
Matter density parameters.
Notes
Expand Down
8 changes: 4 additions & 4 deletions pmwd/gather.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,18 @@ def gather(ptcl, conf, mesh, val=0, offset=0, cell_size=None):
----------
ptcl : Particles
conf : Configuration
mesh : array_like
mesh : ArrayLike
Input mesh.
val : array_like, optional
val : ArrayLike, optional
Input values, can be 0D.
offset : array_like, optional
offset : ArrayLike, optional
Offset of mesh to particle grid. If 0D, the value is used in each dimension.
cell_size : float, optional
Mesh cell size in [L]. Default is ``conf.cell_size``.
Returns
-------
val : jax.numpy.ndarray
val : jax.Array
Output values.
"""
Expand Down
4 changes: 2 additions & 2 deletions pmwd/lpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def levi_civita(indices):
Parameters
----------
indices : array_like
indices : ArrayLike
Returns
-------
Expand Down Expand Up @@ -140,7 +140,7 @@ def lpt(modes, cosmo, conf):
Parameters
----------
modes : jax.numpy.ndarray
modes : jax.Array
Linear matter overdensity Fourier modes in [L^3].
cosmo : Cosmology
conf : Configuration
Expand Down
6 changes: 3 additions & 3 deletions pmwd/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ def white_noise(seed, conf, real=False, unit_abs=False, negate=False):
Returns
-------
modes : jax.numpy.ndarray of conf.float_dtype
modes : jax.Array of conf.float_dtype
White noise modes.
"""
Expand Down Expand Up @@ -75,7 +75,7 @@ def linear_modes(modes, cosmo, conf, a=None):
Parameters
----------
modes : jax.numpy.ndarray
modes : jax.Array
Fourier or real modes with white noise prior.
cosmo : Cosmology
conf : Configuration
Expand All @@ -84,7 +84,7 @@ def linear_modes(modes, cosmo, conf, a=None):
Returns
-------
modes : jax.numpy.ndarray of conf.float_dtype
modes : jax.Array of conf.float_dtype
Linear matter overdensity Fourier modes in [L^3].
Notes
Expand Down
Loading

0 comments on commit 62a9aa7

Please sign in to comment.