diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index bd0433f..cf7ef05 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,12 +1,12 @@ repos: - repo: https://github.com/psf/black - rev: stable + rev: 22.12.0 hooks: - id: black - language_version: python3.9 + language_version: python3.10 - repo: https://github.com/asottile/reorder_python_imports - rev: v2.3.0 + rev: v3.9.0 hooks: - id: reorder-python-imports diff --git a/jax_cosmo/angular_cl.py b/jax_cosmo/angular_cl.py index 6db6c9f..1f43c11 100644 --- a/jax_cosmo/angular_cl.py +++ b/jax_cosmo/angular_cl.py @@ -11,6 +11,7 @@ import jax_cosmo.power as power import jax_cosmo.transfer as tklib from jax_cosmo.scipy.integrate import simps +from jax_cosmo.scipy.interpolate import InterpolatedUnivariateSpline from jax_cosmo.utils import a2z from jax_cosmo.utils import z2a @@ -50,7 +51,12 @@ def find_index(a, b): def angular_cl( - cosmo, ell, probes, transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.halofit + cosmo, + ell, + probes, + transfer_fn=tklib.Eisenstein_Hu, + nonlinear_fn=power.halofit, + npoints=128, ): """ Computes angular Cls for the provided probes @@ -90,12 +96,17 @@ def combine_kernels(inds): # Now kernels has shape [ncls, na] kernels = lax.map(combine_kernels, cl_index) - result = pk * kernels * bkgrd.dchioverda(cosmo, a) / np.clip(chi ** 2, 1.0) + result = pk * kernels * bkgrd.dchioverda(cosmo, a) / np.clip(chi**2, 1.0) - # We transpose the result just to make sure that na is first - return result.T + return result - return simps(integrand, z2a(zmax), 1.0, 512) / const.c ** 2 + atab = np.linspace(z2a(zmax), 1.0, npoints) + eval_integral = vmap( + lambda x: np.squeeze( + InterpolatedUnivariateSpline(atab, x).integral(z2a(zmax), 1.0) + ) + ) + return eval_integral(integrand(atab)) / const.c**2 return cl(ell) diff --git a/jax_cosmo/background.py b/jax_cosmo/background.py index 1a2a182..ba7566c 100644 --- a/jax_cosmo/background.py +++ b/jax_cosmo/background.py @@ -291,7 +291,7 @@ def dchioverda(cosmo, a): \frac{d \chi}{da}(a) = \frac{R_H}{a^2 E(a)} """ - return const.rh / (a ** 2 * np.sqrt(Esqr(cosmo, a))) + return const.rh / (a**2 * np.sqrt(Esqr(cosmo, a))) def transverse_comoving_distance(cosmo, a): diff --git a/jax_cosmo/power.py b/jax_cosmo/power.py index 4d28cb1..62dcfb7 100644 --- a/jax_cosmo/power.py +++ b/jax_cosmo/power.py @@ -16,7 +16,7 @@ def primordial_matter_power(cosmo, k): """Primordial power spectrum Pk = k^n """ - return k ** cosmo.n_s + return k**cosmo.n_s def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwargs): @@ -45,9 +45,9 @@ def linear_matter_power(cosmo, k, a=1.0, transfer_fn=tklib.Eisenstein_Hu, **kwar g = bkgrd.growth_factor(cosmo, a) t = transfer_fn(cosmo, k, **kwargs) - pknorm = cosmo.sigma8 ** 2 / sigmasqr(cosmo, 8.0, transfer_fn, **kwargs) + pknorm = cosmo.sigma8**2 / sigmasqr(cosmo, 8.0, transfer_fn, **kwargs) - pk = primordial_matter_power(cosmo, k) * t ** 2 * g ** 2 + pk = primordial_matter_power(cosmo, k) * t**2 * g**2 # Apply normalisation pk = pk * pknorm @@ -76,7 +76,7 @@ def int_sigma(logk): return k * (k * w) ** 2 * pk y = romb(int_sigma, np.log10(kmin), np.log10(kmax), divmax=7) - return 1.0 / (2.0 * np.pi ** 2.0) * y + return 1.0 / (2.0 * np.pi**2.0) * y def linear(cosmo, k, a, transfer_fn): @@ -103,10 +103,10 @@ def int_sigma(logk): pk = linear_matter_power(cosmo, k, transfer_fn=transfer_fn) g = bkgrd.growth_factor(cosmo, np.atleast_1d(a)) return ( - np.expand_dims(pk * k ** 3, axis=1) - * np.exp(-(y ** 2)) - / (2.0 * np.pi ** 2) - * g ** 2 + np.expand_dims(pk * k**3, axis=1) + * np.exp(-(y**2)) + / (2.0 * np.pi**2) + * g**2 ) sigma = simps(int_sigma, np.log(1e-4), np.log(1e4), 256) @@ -125,13 +125,13 @@ def integrand(logk): pk = linear_matter_power(cosmo, k, transfer_fn=transfer_fn) g = np.expand_dims(bkgrd.growth_factor(cosmo, np.atleast_1d(a)), 0) res = ( - np.expand_dims(pk * k ** 3, axis=1) - * np.exp(-(y ** 2)) - * g ** 2 - / (2.0 * np.pi ** 2) + np.expand_dims(pk * k**3, axis=1) + * np.exp(-(y**2)) + * g**2 + / (2.0 * np.pi**2) ) - dneff_dlogk = 2 * res * y ** 2 - dC_dlogk = 4 * res * (y ** 2 - y ** 4) + dneff_dlogk = 2 * res * y**2 + dC_dlogk = 4 * res * (y**2 - y**4) return np.stack([dneff_dlogk, dC_dlogk], axis=1) res = simps(integrand, np.log(1e-4), np.log(1e4), 256) @@ -185,44 +185,44 @@ def halofit(cosmo, k, a, transfer_fn, prescription="takahashi2012"): a_n = 10 ** ( 1.4861 + 1.8369 * n - + 1.6762 * n ** 2 - + 0.7940 * n ** 3 - + 0.1670 * n ** 4 + + 1.6762 * n**2 + + 0.7940 * n**3 + + 0.1670 * n**4 - 0.6206 * C ) - b_n = 10 ** (0.9463 + 0.9466 * n + 0.3084 * n ** 2 - 0.9400 * C) - c_n = 10 ** (-0.2807 + 0.6669 * n + 0.3214 * n ** 2 - 0.0793 * C) + b_n = 10 ** (0.9463 + 0.9466 * n + 0.3084 * n**2 - 0.9400 * C) + c_n = 10 ** (-0.2807 + 0.6669 * n + 0.3214 * n**2 - 0.0793 * C) gamma_n = 0.8649 + 0.2989 * n + 0.1631 * C - alpha_n = 1.3884 + 0.3700 * n - 0.1452 * n ** 2 - beta_n = 0.8291 + 0.9854 * n + 0.3401 * n ** 2 + alpha_n = 1.3884 + 0.3700 * n - 0.1452 * n**2 + beta_n = 0.8291 + 0.9854 * n + 0.3401 * n**2 mu_n = 10 ** (-3.5442 + 0.1908 * n) nu_n = 10 ** (0.9585 + 1.2857 * n) elif prescription == "takahashi2012": a_n = 10 ** ( 1.5222 + 2.8553 * n - + 2.3706 * n ** 2 - + 0.9903 * n ** 3 - + 0.2250 * n ** 4 + + 2.3706 * n**2 + + 0.9903 * n**3 + + 0.2250 * n**4 - 0.6038 * C + 0.1749 * om_de * (1 + w) ) b_n = 10 ** ( -0.5642 + 0.5864 * n - + 0.5716 * n ** 2 + + 0.5716 * n**2 - 1.5474 * C + 0.2279 * om_de * (1 + w) ) - c_n = 10 ** (0.3698 + 2.0404 * n + 0.8161 * n ** 2 + 0.5869 * C) + c_n = 10 ** (0.3698 + 2.0404 * n + 0.8161 * n**2 + 0.5869 * C) gamma_n = 0.1971 - 0.0843 * n + 0.8460 * C - alpha_n = np.abs(6.0835 + 1.3373 * n - 0.1959 * n ** 2 - 5.5274 * C) + alpha_n = np.abs(6.0835 + 1.3373 * n - 0.1959 * n**2 - 5.5274 * C) beta_n = ( 2.0379 - 0.7354 * n - + 0.3157 * n ** 2 - + 1.2490 * n ** 3 - + 0.3980 * n ** 4 + + 0.3157 * n**2 + + 1.2490 * n**3 + + 0.3980 * n**4 - 0.1682 * C ) mu_n = 0.0 @@ -232,7 +232,7 @@ def halofit(cosmo, k, a, transfer_fn, prescription="takahashi2012"): f1a = om_m ** (-0.0732) f2a = om_m ** (-0.1423) - f3a = om_m ** 0.0725 + f3a = om_m**0.0725 f1b = om_m ** (-0.0307) f2b = om_m ** (-0.0585) f3b = om_m ** (0.0743) @@ -248,21 +248,21 @@ def halofit(cosmo, k, a, transfer_fn, prescription="takahashi2012"): else: raise NotImplementedError - f = lambda x: x / 4.0 + x ** 2 / 8.0 + f = lambda x: x / 4.0 + x**2 / 8.0 - d2l = k ** 3 * pklin / (2.0 * np.pi ** 2) + d2l = k**3 * pklin / (2.0 * np.pi**2) y = k / k_nl # Eq C2 d2q = d2l * ((1.0 + d2l) ** beta_n / (1 + alpha_n * d2l)) * np.exp(-f(y)) d2hprime = ( - a_n * y ** (3 * f1) / (1.0 + b_n * y ** f2 + (c_n * f3 * y) ** (3.0 - gamma_n)) + a_n * y ** (3 * f1) / (1.0 + b_n * y**f2 + (c_n * f3 * y) ** (3.0 - gamma_n)) ) - d2h = d2hprime / (1.0 + mu_n / y + nu_n / y ** 2) + d2h = d2hprime / (1.0 + mu_n / y + nu_n / y**2) # Eq. C1 d2nl = d2q + d2h - pk_nl = 2.0 * np.pi ** 2 / k ** 3 * d2nl + pk_nl = 2.0 * np.pi**2 / k**3 * d2nl return pk_nl.squeeze() diff --git a/jax_cosmo/probes.py b/jax_cosmo/probes.py index fe47895..fab6880 100644 --- a/jax_cosmo/probes.py +++ b/jax_cosmo/probes.py @@ -70,7 +70,7 @@ def integrand_single(z_prime): radial_kernel = radial_kernel[inv] # Constant term - constant_factor = 3.0 * const.H0 ** 2 * cosmo.Omega_m / 2.0 / const.c + constant_factor = 3.0 * const.H0**2 * cosmo.Omega_m / 2.0 / const.c # Ell dependent factor ell_factor = np.sqrt((ell - 1) * (ell) * (ell + 1) * (ell + 2)) / (ell + 0.5) ** 2 return constant_factor * ell_factor * radial_kernel @@ -222,7 +222,7 @@ def noise(self): sigma_e = np.array([s for s in self.config["sigma_e"]]) else: sigma_e = self.config["sigma_e"] - return sigma_e ** 2 / ngals + return sigma_e**2 / ngals @register_pytree_node_class diff --git a/jax_cosmo/redshift.py b/jax_cosmo/redshift.py index e02e9c7..c7fb4b2 100644 --- a/jax_cosmo/redshift.py +++ b/jax_cosmo/redshift.py @@ -75,7 +75,7 @@ class smail_nz(redshift_distribution): def pz_fn(self, z): a, b, z0 = self.params - return z ** a * np.exp(-((z / z0) ** b)) + return z**a * np.exp(-((z / z0) ** b)) @register_pytree_node_class @@ -118,7 +118,7 @@ class kde_nz(redshift_distribution): def _kernel(self, bw, X, x): """Gaussian kernel for KDE""" return (1.0 / np.sqrt(2 * np.pi) / bw) * np.exp( - -((X - x) ** 2) / (bw ** 2 * 2.0) + -((X - x) ** 2) / (bw**2 * 2.0) ) def pz_fn(self, z): diff --git a/jax_cosmo/scipy/integrate.py b/jax_cosmo/scipy/integrate.py index f19dae3..baca247 100644 --- a/jax_cosmo/scipy/integrate.py +++ b/jax_cosmo/scipy/integrate.py @@ -61,7 +61,7 @@ def _romberg_diff(b, c, k): Compute the differences for the Romberg quadrature corrections. See Forman Acton's "Real Computing Made Real," p 143. """ - tmp = 4.0 ** k + tmp = 4.0**k return (tmp * c - b) / (tmp - 1.0) @@ -143,7 +143,7 @@ def scan_fn(carry, y): return (x, k + 1), x for i in range(1, divmax + 1): - n = 2 ** i + n = 2**i ordsum = ordsum + _difftrapn(vfunc, interval, n) x = intrange * ordsum / n diff --git a/jax_cosmo/scipy/interpolate.py b/jax_cosmo/scipy/interpolate.py index 9c34e7a..d1389de 100644 --- a/jax_cosmo/scipy/interpolate.py +++ b/jax_cosmo/scipy/interpolate.py @@ -258,11 +258,11 @@ def __call__(self, x): if self.k == 2: t, a, b, c = self._compute_coeffs(x) - result = a + b * t + c * t ** 2 + result = a + b * t + c * t**2 if self.k == 3: t, a, b, c, d = self._compute_coeffs(x) - result = a + b * t + c * t ** 2 + d * t ** 3 + result = a + b * t + c * t**2 + d * t**3 return result @@ -300,7 +300,7 @@ def _compute_coeffs(self, xs): dt = (x - knots[:-1])[ind] b = coefficients[ind] b1 = coefficients[ind + 1] - a = y[ind] - b * dt - (b1 - b) * dt ** 2 / (2 * h) + a = y[ind] - b * dt - (b1 - b) * dt**2 / (2 * h) c = (b1 - b) / (2 * h) result = (t, a, b, c) @@ -343,7 +343,7 @@ def derivative(self, x, n=1): if self.k == 3: t, a, b, c, d = self._compute_coeffs(x) if n == 1: - result = b + 2 * c * t + 3 * d * t ** 2 + result = b + 2 * c * t + 3 * d * t**2 if n == 2: result = 2 * c + 6 * d * t if n == 3: @@ -382,20 +382,20 @@ def antiderivative(self, xs): a = y[:-1] b = coefficients h = np.diff(knots) - cst = np.concatenate([np.zeros(1), np.cumsum(a * h + b * h ** 2 / 2)]) - return cst[ind] + a[ind] * t + b[ind] * t ** 2 / 2 + cst = np.concatenate([np.zeros(1), np.cumsum(a * h + b * h**2 / 2)]) + return cst[ind] + a[ind] * t + b[ind] * t**2 / 2 if self.k == 2: h = np.diff(knots) dt = x - knots[:-1] b = coefficients[:-1] b1 = coefficients[1:] - a = y - b * dt - (b1 - b) * dt ** 2 / (2 * h) + a = y - b * dt - (b1 - b) * dt**2 / (2 * h) c = (b1 - b) / (2 * h) cst = np.concatenate( - [np.zeros(1), np.cumsum(a * h + b * h ** 2 / 2 + c * h ** 3 / 3)] + [np.zeros(1), np.cumsum(a * h + b * h**2 / 2 + c * h**3 / 3)] ) - return cst[ind] + a[ind] * t + b[ind] * t ** 2 / 2 + c[ind] * t ** 3 / 3 + return cst[ind] + a[ind] * t + b[ind] * t**2 / 2 + c[ind] * t**3 / 3 if self.k == 3: h = np.diff(knots) @@ -408,15 +408,15 @@ def antiderivative(self, xs): cst = np.concatenate( [ np.zeros(1), - np.cumsum(a * h + b * h ** 2 / 2 + c * h ** 3 / 3 + d * h ** 4 / 4), + np.cumsum(a * h + b * h**2 / 2 + c * h**3 / 3 + d * h**4 / 4), ] ) return ( cst[ind] + a[ind] * t - + b[ind] * t ** 2 / 2 - + c[ind] * t ** 3 / 3 - + d[ind] * t ** 4 / 4 + + b[ind] * t**2 / 2 + + c[ind] * t**3 / 3 + + d[ind] * t**4 / 4 ) def integral(self, a, b): @@ -440,3 +440,32 @@ def integral(self, a, b): sign = -1 xs = np.array([a, b]) return sign * np.diff(self.antiderivative(xs)) + + +def splint(func, a, b, k=3, N=128): + """Function that computes an integration with a spline function + slightly different from the original splint from scipy + """ + x = np.linspace(a, b, N) + return InterpolatedUnivariateSpline(x, func(x), k=k).integral(a, b) + + +# def splint_fwd(func, a, b, **kwargs): +# result = splint(func, a, b, **kwargs) +# aux = (a, b, kwargs) +# return result, aux + +# def splint_bwd(func, aux, grad): +# a, b, kwargs = aux + +# grad_a = -grad * func(a) +# grad_b = grad * func(b) + +# grad_args = [] +# for i in range(len(args)): +# def _vjp_func(_t, *_args): +# return jax.grad(func, i)(_t, *_args) +# grad_args.append(grad * quad(_vjp_func, a, b, args)) +# grad_args = tuple(grad_args) + +# return grad_a, grad_b, grad_args diff --git a/jax_cosmo/transfer.py b/jax_cosmo/transfer.py index 23def59..ce66c80 100644 --- a/jax_cosmo/transfer.py +++ b/jax_cosmo/transfer.py @@ -45,7 +45,7 @@ def Eisenstein_Hu(cosmo, k, type="eisenhu_osc"): # - sh_d : sound horizon at drag epoch # - k_silk : Silk damping scale T_2_7_sqr = (const.tcmb / 2.7) ** 2 - h2 = cosmo.h ** 2 + h2 = cosmo.h**2 w_m = cosmo.Omega_m * h2 w_b = cosmo.Omega_b * h2 fb = cosmo.Omega_b / cosmo.Omega_m @@ -111,7 +111,7 @@ def Eisenstein_Hu(cosmo, k, type="eisenhu_osc"): # EH98 (11, 12) a1 = np.power(46.9 * w_m, 0.670) * (1.0 + np.power(32.1 * w_m, -0.532)) a2 = np.power(12.0 * w_m, 0.424) * (1.0 + np.power(45.0 * w_m, -0.582)) - alpha_c = np.power(a1, -fb) * np.power(a2, -(fb ** 3)) + alpha_c = np.power(a1, -fb) * np.power(a2, -(fb**3)) b1 = 0.944 / (1.0 + np.power(458.0 * w_m, -0.708)) b2 = np.power(0.395 * w_m, -0.0266) beta_c = 1.0 + b1 * (np.power(fc, b2) - 1.0) diff --git a/tests/test_power.py b/tests/test_power.py index 3c49939..8e08ed4 100644 --- a/tests/test_power.py +++ b/tests/test_power.py @@ -38,7 +38,7 @@ def test_eisenstein_hu(): # Computing matter power spectrum pk_ccl = ccl.linear_matter_power(cosmo_ccl, k, 1.0) pk_jax = ( - power.linear_matter_power(cosmo_jax, k / cosmo_jax.h, a=1.0) / cosmo_jax.h ** 3 + power.linear_matter_power(cosmo_jax, k / cosmo_jax.h, a=1.0) / cosmo_jax.h**3 ) assert_allclose(pk_ccl, pk_jax, rtol=0.5e-2) @@ -81,7 +81,7 @@ def test_halofit(): transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.halofit, ) - / cosmo_jax.h ** 3 + / cosmo_jax.h**3 ) assert_allclose(pk_ccl, pk_jax, rtol=0.5e-2) @@ -130,7 +130,7 @@ def test_halofit_nl_scales(): transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.halofit, ) - / cosmo_jax.h ** 3 + / cosmo_jax.h**3 ) assert_allclose(pk_ccl, pk_jax, rtol=0.5e-2) @@ -145,7 +145,7 @@ def test_halofit_nl_scales(): transfer_fn=tklib.Eisenstein_Hu, nonlinear_fn=power.halofit, ) - / cosmo_jax.h ** 3 + / cosmo_jax.h**3 ) # We relax the test here, because actually CCL is not accurate in this regime assert_allclose(pk_ccl, pk_jax, rtol=2e-2)