Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Switching to integration via Splines #104

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
repos:
- repo: https://github.com/psf/black
rev: stable
rev: 22.12.0
hooks:
- id: black
language_version: python3.9
language_version: python3.10

- repo: https://github.com/asottile/reorder_python_imports
rev: v2.3.0
rev: v3.9.0
hooks:
- id: reorder-python-imports

21 changes: 16 additions & 5 deletions jax_cosmo/angular_cl.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import jax_cosmo.power as power
import jax_cosmo.transfer as tklib
from jax_cosmo.scipy.integrate import simps
from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline
from jax_cosmo.utils import a2z
from jax_cosmo.utils import z2a

Expand Down Expand Up @@ -50,7 +51,12 @@ def find_index(a, b):


def angular_cl(
cosmo, ell, probes, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.halofit
cosmo,
ell,
probes,
transfer_fn=tklib.Eisenstein_Hu,
nonlinear_fn=power.halofit,
npoints=128,
):
"""
Computes angular Cls for the provided probes
Expand Down Expand Up @@ -90,12 +96,17 @@ def combine_kernels(inds):
# Now kernels has shape [ncls, na]
kernels = lax.map(combine_kernels, cl_index)

result = pk * kernels * bkgrd.dchioverda(cosmo, a) / np.clip(chi ** 2, 1.0)
result = pk * kernels * bkgrd.dchioverda(cosmo, a) / np.clip(chi**2, 1.0)

# We transpose the result just to make sure that na is first
return result.T
return result

return simps(integrand, z2a(zmax), 1.0, 512) / const.c ** 2
atab = np.linspace(z2a(zmax), 1.0, npoints)
eval_integral = vmap(
lambda x: np.squeeze(
InterpolatedUnivariateSpline(atab, x).integral(z2a(zmax), 1.0)
)
)
return eval_integral(integrand(atab)) / const.c**2

return cl(ell)

Expand Down
2 changes: 1 addition & 1 deletion jax_cosmo/background.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def dchioverda(cosmo, a):

\frac{d \chi}{da}(a) = \frac{R_H}{a^2 E(a)}
"""
return const.rh / (a ** 2 * np.sqrt(Esqr(cosmo, a)))
return const.rh / (a**2 * np.sqrt(Esqr(cosmo, a)))


def transverse_comoving_distance(cosmo, a):
Expand Down
72 changes: 36 additions & 36 deletions jax_cosmo/power.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ def primordial_matter_power(cosmo, k):
"""Primordial power spectrum
Pk = k^n
"""
return k ** cosmo.n_s
return k**cosmo.n_s


def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwargs):
Expand Down Expand Up @@ -45,9 +45,9 @@ def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwar
g = bkgrd.growth_factor(cosmo, a)
t = transfer_fn(cosmo, k, **kwargs)

pknorm = cosmo.sigma8 ** 2 / sigmasqr(cosmo, 8.0, transfer_fn, **kwargs)
pknorm = cosmo.sigma8**2 / sigmasqr(cosmo, 8.0, transfer_fn, **kwargs)

pk = primordial_matter_power(cosmo, k) * t ** 2 * g ** 2
pk = primordial_matter_power(cosmo, k) * t**2 * g**2

# Apply normalisation
pk = pk * pknorm
Expand Down Expand Up @@ -76,7 +76,7 @@ def int_sigma(logk):
return k * (k * w) ** 2 * pk

y = romb(int_sigma, np.log10(kmin), np.log10(kmax), divmax=7)
return 1.0 / (2.0 * np.pi ** 2.0) * y
return 1.0 / (2.0 * np.pi**2.0) * y


def linear(cosmo, k, a, transfer_fn):
Expand All @@ -103,10 +103,10 @@ def int_sigma(logk):
pk = linear_matter_power(cosmo, k, transfer_fn=transfer_fn)
g = bkgrd.growth_factor(cosmo, np.atleast_1d(a))
return (
np.expand_dims(pk * k ** 3, axis=1)
* np.exp(-(y ** 2))
/ (2.0 * np.pi ** 2)
* g ** 2
np.expand_dims(pk * k**3, axis=1)
* np.exp(-(y**2))
/ (2.0 * np.pi**2)
* g**2
)

sigma = simps(int_sigma, np.log(1e-4), np.log(1e4), 256)
Expand All @@ -125,13 +125,13 @@ def integrand(logk):
pk = linear_matter_power(cosmo, k, transfer_fn=transfer_fn)
g = np.expand_dims(bkgrd.growth_factor(cosmo, np.atleast_1d(a)), 0)
res = (
np.expand_dims(pk * k ** 3, axis=1)
* np.exp(-(y ** 2))
* g ** 2
/ (2.0 * np.pi ** 2)
np.expand_dims(pk * k**3, axis=1)
* np.exp(-(y**2))
* g**2
/ (2.0 * np.pi**2)
)
dneff_dlogk = 2 * res * y ** 2
dC_dlogk = 4 * res * (y ** 2 - y ** 4)
dneff_dlogk = 2 * res * y**2
dC_dlogk = 4 * res * (y**2 - y**4)
return np.stack([dneff_dlogk, dC_dlogk], axis=1)

res = simps(integrand, np.log(1e-4), np.log(1e4), 256)
Expand Down Expand Up @@ -185,44 +185,44 @@ def halofit(cosmo, k, a, transfer_fn, prescription="takahashi2012"):
a_n = 10 ** (
1.4861
+ 1.8369 * n
+ 1.6762 * n ** 2
+ 0.7940 * n ** 3
+ 0.1670 * n ** 4
+ 1.6762 * n**2
+ 0.7940 * n**3
+ 0.1670 * n**4
- 0.6206 * C
)
b_n = 10 ** (0.9463 + 0.9466 * n + 0.3084 * n ** 2 - 0.9400 * C)
c_n = 10 ** (-0.2807 + 0.6669 * n + 0.3214 * n ** 2 - 0.0793 * C)
b_n = 10 ** (0.9463 + 0.9466 * n + 0.3084 * n**2 - 0.9400 * C)
c_n = 10 ** (-0.2807 + 0.6669 * n + 0.3214 * n**2 - 0.0793 * C)
gamma_n = 0.8649 + 0.2989 * n + 0.1631 * C
alpha_n = 1.3884 + 0.3700 * n - 0.1452 * n ** 2
beta_n = 0.8291 + 0.9854 * n + 0.3401 * n ** 2
alpha_n = 1.3884 + 0.3700 * n - 0.1452 * n**2
beta_n = 0.8291 + 0.9854 * n + 0.3401 * n**2
mu_n = 10 ** (-3.5442 + 0.1908 * n)
nu_n = 10 ** (0.9585 + 1.2857 * n)
elif prescription == "takahashi2012":
a_n = 10 ** (
1.5222
+ 2.8553 * n
+ 2.3706 * n ** 2
+ 0.9903 * n ** 3
+ 0.2250 * n ** 4
+ 2.3706 * n**2
+ 0.9903 * n**3
+ 0.2250 * n**4
- 0.6038 * C
+ 0.1749 * om_de * (1 + w)
)
b_n = 10 ** (
-0.5642
+ 0.5864 * n
+ 0.5716 * n ** 2
+ 0.5716 * n**2
- 1.5474 * C
+ 0.2279 * om_de * (1 + w)
)
c_n = 10 ** (0.3698 + 2.0404 * n + 0.8161 * n ** 2 + 0.5869 * C)
c_n = 10 ** (0.3698 + 2.0404 * n + 0.8161 * n**2 + 0.5869 * C)
gamma_n = 0.1971 - 0.0843 * n + 0.8460 * C
alpha_n = np.abs(6.0835 + 1.3373 * n - 0.1959 * n ** 2 - 5.5274 * C)
alpha_n = np.abs(6.0835 + 1.3373 * n - 0.1959 * n**2 - 5.5274 * C)
beta_n = (
2.0379
- 0.7354 * n
+ 0.3157 * n ** 2
+ 1.2490 * n ** 3
+ 0.3980 * n ** 4
+ 0.3157 * n**2
+ 1.2490 * n**3
+ 0.3980 * n**4
- 0.1682 * C
)
mu_n = 0.0
Expand All @@ -232,7 +232,7 @@ def halofit(cosmo, k, a, transfer_fn, prescription="takahashi2012"):

f1a = om_m ** (-0.0732)
f2a = om_m ** (-0.1423)
f3a = om_m ** 0.0725
f3a = om_m**0.0725
f1b = om_m ** (-0.0307)
f2b = om_m ** (-0.0585)
f3b = om_m ** (0.0743)
Expand All @@ -248,21 +248,21 @@ def halofit(cosmo, k, a, transfer_fn, prescription="takahashi2012"):
else:
raise NotImplementedError

f = lambda x: x / 4.0 + x ** 2 / 8.0
f = lambda x: x / 4.0 + x**2 / 8.0

d2l = k ** 3 * pklin / (2.0 * np.pi ** 2)
d2l = k**3 * pklin / (2.0 * np.pi**2)

y = k / k_nl

# Eq C2
d2q = d2l * ((1.0 + d2l) ** beta_n / (1 + alpha_n * d2l)) * np.exp(-f(y))
d2hprime = (
a_n * y ** (3 * f1) / (1.0 + b_n * y ** f2 + (c_n * f3 * y) ** (3.0 - gamma_n))
a_n * y ** (3 * f1) / (1.0 + b_n * y**f2 + (c_n * f3 * y) ** (3.0 - gamma_n))
)
d2h = d2hprime / (1.0 + mu_n / y + nu_n / y ** 2)
d2h = d2hprime / (1.0 + mu_n / y + nu_n / y**2)
# Eq. C1
d2nl = d2q + d2h
pk_nl = 2.0 * np.pi ** 2 / k ** 3 * d2nl
pk_nl = 2.0 * np.pi**2 / k**3 * d2nl

return pk_nl.squeeze()

Expand Down
4 changes: 2 additions & 2 deletions jax_cosmo/probes.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def integrand_single(z_prime):
radial_kernel = radial_kernel[inv]

# Constant term
constant_factor = 3.0 * const.H0 ** 2 * cosmo.Omega_m / 2.0 / const.c
constant_factor = 3.0 * const.H0**2 * cosmo.Omega_m / 2.0 / const.c
# Ell dependent factor
ell_factor = np.sqrt((ell - 1) * (ell) * (ell + 1) * (ell + 2)) / (ell + 0.5) ** 2
return constant_factor * ell_factor * radial_kernel
Expand Down Expand Up @@ -222,7 +222,7 @@ def noise(self):
sigma_e = np.array([s for s in self.config["sigma_e"]])
else:
sigma_e = self.config["sigma_e"]
return sigma_e ** 2 / ngals
return sigma_e**2 / ngals


@register_pytree_node_class
Expand Down
4 changes: 2 additions & 2 deletions jax_cosmo/redshift.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class smail_nz(redshift_distribution):

def pz_fn(self, z):
a, b, z0 = self.params
return z ** a * np.exp(-((z / z0) ** b))
return z**a * np.exp(-((z / z0) ** b))


@register_pytree_node_class
Expand Down Expand Up @@ -118,7 +118,7 @@ class kde_nz(redshift_distribution):
def _kernel(self, bw, X, x):
"""Gaussian kernel for KDE"""
return (1.0 / np.sqrt(2 * np.pi) / bw) * np.exp(
-((X - x) ** 2) / (bw ** 2 * 2.0)
-((X - x) ** 2) / (bw**2 * 2.0)
)

def pz_fn(self, z):
Expand Down
4 changes: 2 additions & 2 deletions jax_cosmo/scipy/integrate.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ def _romberg_diff(b, c, k):
Compute the differences for the Romberg quadrature corrections.
See Forman Acton's "Real Computing Made Real," p 143.
"""
tmp = 4.0 ** k
tmp = 4.0**k
return (tmp * c - b) / (tmp - 1.0)


Expand Down Expand Up @@ -143,7 +143,7 @@ def scan_fn(carry, y):
return (x, k + 1), x

for i in range(1, divmax + 1):
n = 2 ** i
n = 2**i
ordsum = ordsum + _difftrapn(vfunc, interval, n)

x = intrange * ordsum / n
Expand Down
55 changes: 42 additions & 13 deletions jax_cosmo/scipy/interpolate.py
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,11 @@ def __call__(self, x):

if self.k == 2:
t, a, b, c = self._compute_coeffs(x)
result = a + b * t + c * t ** 2
result = a + b * t + c * t**2

if self.k == 3:
t, a, b, c, d = self._compute_coeffs(x)
result = a + b * t + c * t ** 2 + d * t ** 3
result = a + b * t + c * t**2 + d * t**3

return result

Expand Down Expand Up @@ -300,7 +300,7 @@ def _compute_coeffs(self, xs):
dt = (x - knots[:-1])[ind]
b = coefficients[ind]
b1 = coefficients[ind + 1]
a = y[ind] - b * dt - (b1 - b) * dt ** 2 / (2 * h)
a = y[ind] - b * dt - (b1 - b) * dt**2 / (2 * h)
c = (b1 - b) / (2 * h)
result = (t, a, b, c)

Expand Down Expand Up @@ -343,7 +343,7 @@ def derivative(self, x, n=1):
if self.k == 3:
t, a, b, c, d = self._compute_coeffs(x)
if n == 1:
result = b + 2 * c * t + 3 * d * t ** 2
result = b + 2 * c * t + 3 * d * t**2
if n == 2:
result = 2 * c + 6 * d * t
if n == 3:
Expand Down Expand Up @@ -382,20 +382,20 @@ def antiderivative(self, xs):
a = y[:-1]
b = coefficients
h = np.diff(knots)
cst = np.concatenate([np.zeros(1), np.cumsum(a * h + b * h ** 2 / 2)])
return cst[ind] + a[ind] * t + b[ind] * t ** 2 / 2
cst = np.concatenate([np.zeros(1), np.cumsum(a * h + b * h**2 / 2)])
return cst[ind] + a[ind] * t + b[ind] * t**2 / 2

if self.k == 2:
h = np.diff(knots)
dt = x - knots[:-1]
b = coefficients[:-1]
b1 = coefficients[1:]
a = y - b * dt - (b1 - b) * dt ** 2 / (2 * h)
a = y - b * dt - (b1 - b) * dt**2 / (2 * h)
c = (b1 - b) / (2 * h)
cst = np.concatenate(
[np.zeros(1), np.cumsum(a * h + b * h ** 2 / 2 + c * h ** 3 / 3)]
[np.zeros(1), np.cumsum(a * h + b * h**2 / 2 + c * h**3 / 3)]
)
return cst[ind] + a[ind] * t + b[ind] * t ** 2 / 2 + c[ind] * t ** 3 / 3
return cst[ind] + a[ind] * t + b[ind] * t**2 / 2 + c[ind] * t**3 / 3

if self.k == 3:
h = np.diff(knots)
Expand All @@ -408,15 +408,15 @@ def antiderivative(self, xs):
cst = np.concatenate(
[
np.zeros(1),
np.cumsum(a * h + b * h ** 2 / 2 + c * h ** 3 / 3 + d * h ** 4 / 4),
np.cumsum(a * h + b * h**2 / 2 + c * h**3 / 3 + d * h**4 / 4),
]
)
return (
cst[ind]
+ a[ind] * t
+ b[ind] * t ** 2 / 2
+ c[ind] * t ** 3 / 3
+ d[ind] * t ** 4 / 4
+ b[ind] * t**2 / 2
+ c[ind] * t**3 / 3
+ d[ind] * t**4 / 4
)

def integral(self, a, b):
Expand All @@ -440,3 +440,32 @@ def integral(self, a, b):
sign = -1
xs = np.array([a, b])
return sign * np.diff(self.antiderivative(xs))


def splint(func, a, b, k=3, N=128):
"""Function that computes an integration with a spline function
slightly different from the original splint from scipy
"""
x = np.linspace(a, b, N)
return InterpolatedUnivariateSpline(x, func(x), k=k).integral(a, b)


# def splint_fwd(func, a, b, **kwargs):
# result = splint(func, a, b, **kwargs)
# aux = (a, b, kwargs)
# return result, aux

# def splint_bwd(func, aux, grad):
# a, b, kwargs = aux

# grad_a = -grad * func(a)
# grad_b = grad * func(b)

# grad_args = []
# for i in range(len(args)):
# def _vjp_func(_t, *_args):
# return jax.grad(func, i)(_t, *_args)
# grad_args.append(grad * quad(_vjp_func, a, b, args))
# grad_args = tuple(grad_args)

# return grad_a, grad_b, grad_args
Loading