Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implemented initial conditions with local non-Gaussianity #26

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
766 changes: 766 additions & 0 deletions docs/examples/png_example.ipynb

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion docs/examples/quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -699,7 +699,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.10"
"version": "3.10.4"
}
},
"nbformat": 4,
Expand Down
2 changes: 1 addition & 1 deletion pmwd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from pmwd.configuration import Configuration
from pmwd.cosmology import Cosmology, SimpleLCDM, Planck18, E2, H_deriv, Omega_m_a
from pmwd.boltzmann import (transfer_integ, transfer_fit, transfer, growth_integ,
growth, varlin_integ, varlin, boltzmann, linear_power)
growth, varlin_integ, varlin, boltzmann, linear_power, linear_transfer)
from pmwd.particles import (Particles, ptcl_enmesh,
ptcl_pos, ptcl_rpos, ptcl_rsd, ptcl_los)
from pmwd.scatter import scatter
Expand Down
45 changes: 45 additions & 0 deletions pmwd/boltzmann.py
Original file line number Diff line number Diff line change
Expand Up @@ -453,3 +453,48 @@ def linear_power(k, a, cosmo, conf):
Plin *= D**2

return Plin.astype(float_dtype)

def linear_transfer(k, a, cosmo, conf):
r"""Linear matter transfer function at given wavenumbers and scale factors.

Parameters
----------
k : array_like
Wavenumbers in [1/L].
a : array_like or None
Scale factors. If None, output is not scaled by growth.
cosmo : Cosmology
conf : Configuration

Returns
-------
Tlin : jax.numpy.ndarray of (k * a * 1.).dtype
Linear matter transfer function.

Raises
------
ValueError
If not in 3D.

"""

if conf.dim != 3:
raise ValueError(f'dim={conf.dim} not supported')

k = jnp.asarray(k)
float_dtype = jnp.promote_types(k.dtype, float)

T = transfer(k, cosmo, conf)

# TF: the 3/5 is because the primordial amplitude A_s is given for \zeta instead of \Phi
Tlin = (3/5) * (2/3) * (conf.c / conf.H_0)**2 / cosmo.Omega_m * T

if a is not None:
a = jnp.asarray(a)
float_dtype = jnp.promote_types(float_dtype, a.dtype)

D = growth(a, cosmo, conf)

Tlin *= D

return Tlin.astype(float_dtype)
5 changes: 4 additions & 1 deletion pmwd/cosmology.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,8 @@ class Cosmology:
Dark energy equation of state linear parameter. Default is None.
h : float ArrayLike
Hubble constant in unit of 100 [km/s/Mpc].

f_nl_loc : float or jax.numpy.ndarray, optional
amplitude of local primordial non-Gaussianity
"""

conf: Configuration = field(repr=False)
Expand All @@ -64,6 +65,8 @@ class Cosmology:
w_0_fixed: ClassVar[float] = -1
w_a_: Optional[ArrayLike] = None
w_a_fixed: ClassVar[float] = 0
f_nl_loc: Optional[ArrayLike] = None
f_nl_loc_fixed: ClassVar[float] = 0

transfer: Optional[Array] = field(default=None, compare=False)

Expand Down
43 changes: 37 additions & 6 deletions pmwd/modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from jax import random
import jax.numpy as jnp

from pmwd.boltzmann import linear_power
from pmwd.boltzmann import linear_power, linear_transfer
from pmwd.pm_util import fftfreq, fftfwd, fftinv


Expand Down Expand Up @@ -65,7 +65,7 @@ def _safe_sqrt_bwd(y, y_cot):


@partial(jit, static_argnums=4)
@partial(checkpoint, static_argnums=4)
# @partial(checkpoint, static_argnums=4)
def linear_modes(modes, cosmo, conf, a=None, real=False):
"""Linear matter overdensity Fourier or real modes.

Expand Down Expand Up @@ -100,12 +100,43 @@ def linear_modes(modes, cosmo, conf, a=None, real=False):
if a is not None:
a = jnp.asarray(a, dtype=conf.float_dtype)

Plin = linear_power(k, a, cosmo, conf)

if jnp.isrealobj(modes):
modes = fftfwd(modes, norm='ortho')

modes *= _safe_sqrt(Plin * conf.box_vol)

if cosmo.f_nl_loc is not None:
Tlin = linear_transfer(k, a, cosmo, conf)*k*k
Pprim = 2*jnp.pi**2. * cosmo.A_s * (k/cosmo.k_pivot)**(cosmo.n_s-1.)\
* k**(-3.)

modes *= _safe_sqrt(Pprim / conf.box_vol)
modes = modes.at[0,0,0].set(0.+0.j)

# TF: To generate non-Gaussian primordial field without aliassing effects, we generate the square of the field at a higher grid size
# TF: When squaring the field in real space, the generated higher frequency modes can be accomodated on the larger grid and don't 'fold back' over the relevant modes.
modes_NG = jnp.zeros(shape=(conf.ptcl_grid_shape[0]*2,conf.ptcl_grid_shape[1]*2,conf.ptcl_grid_shape[2] + 1),dtype=modes.dtype)
# TF: We fill the higher resolution box only halfway with the previously generated modes (note factor of 8 for 2**3 times more gridpoints):
modes_NG = modes_NG.at[conf.ptcl_grid_shape[0]-conf.ptcl_grid_shape[0]//2:conf.ptcl_grid_shape[0]+conf.ptcl_grid_shape[0]//2,conf.ptcl_grid_shape[1]-conf.ptcl_grid_shape[1]//2:conf.ptcl_grid_shape[1]+conf.ptcl_grid_shape[1]//2,:conf.ptcl_grid_shape[2]//2+1].set(jnp.fft.fftshift(modes*jnp.sqrt(8),axes=[0,1]))
modes_NG = jnp.fft.ifftshift(modes_NG,axes=[0,1])
# TF: Move to real space, square and back to Fourier space
modes_NG = fftfwd(fftinv(modes_NG, norm='ortho')**2., norm='ortho')
# TF: After squaring, downsample back to the target resolution in Fourier space
modes_NG = jnp.fft.fftshift(modes_NG,axes=[0,1])
modes_NG = modes_NG[conf.ptcl_grid_shape[0]-conf.ptcl_grid_shape[0]//2:conf.ptcl_grid_shape[0]+conf.ptcl_grid_shape[0]//2,conf.ptcl_grid_shape[1]-conf.ptcl_grid_shape[1]//2:conf.ptcl_grid_shape[1]+conf.ptcl_grid_shape[1]//2,:conf.ptcl_grid_shape[2]//2+1]/jnp.sqrt(8)
modes_NG = jnp.fft.ifftshift(modes_NG,axes=[0,1])

# TF: And now to real space again to do the addition in the proper way
modes = fftinv(modes, norm='ortho')
modes_NG = fftinv(modes_NG, norm='ortho')

# TF: add the non-guassian field, factor of 3/5 is because we are generating \zeta and f_nl is defined for \Phi
modes = modes + 3/5 * cosmo.f_nl_loc * jnp.sqrt(conf.ptcl_num) * (modes_NG - jnp.mean(modes_NG))

# TF: apply transfer function
modes = fftfwd(modes, norm='ortho')
modes *= Tlin * conf.box_vol
else:
Plin = linear_power(k, a, cosmo, conf)
modes *= _safe_sqrt(Plin * conf.box_vol)

if real:
modes = fftinv(modes, shape=conf.ptcl_grid_shape, norm=conf.ptcl_spacing)
Expand Down