From b423d828b359eaa8aab9db80c8eddce17281e84f Mon Sep 17 00:00:00 2001 From: Colm Talbot Date: Fri, 4 Oct 2024 11:21:16 -0500 Subject: [PATCH] REVERT: no longer directly subclass astropy to avoid units issues (#21) * REVERT: no longer directly subclass astropy to avoid units issues * TEST: make astropy jit test more sensible --- wcosmo/astropy.py | 171 +++++++++++++++++++++++++++++++++--- wcosmo/test/test_astropy.py | 30 +++++++ 2 files changed, 187 insertions(+), 14 deletions(-) diff --git a/wcosmo/astropy.py b/wcosmo/astropy.py index 0d52aaf..d90610c 100644 --- a/wcosmo/astropy.py +++ b/wcosmo/astropy.py @@ -204,27 +204,170 @@ def de_density_scale(self, z): return (z + 1) ** (3 * (1 + self.w0)) -# give these classes dummy names to avoid some kind of namespace collision -# that overwrites the astropy classes, see, -# https://github.com/ColmTalbot/wcosmo/issues/15 -@dataclass(frozen=True) -class _FlatwCDM(WCosmoMixin, _acosmo.FlatwCDM): - pass - - -@dataclass(frozen=True) -class _FlatLambdaCDM(WCosmoMixin, _acosmo.FlatLambdaCDM): - w0: float = field(init=False, default=-1) +class FlatwCDM(WCosmoMixin): + def __init__( + self, + H0, + Om0, + w0=-1, + Tcmb0=None, + Neff=None, + m_nu=None, + Ob0=None, + *, + zmin=1e-4, + zmax=100, + name=None, + meta=None, + ): + """FLRW cosmology with a constant dark energy EoS and no spatial curvature. + + This has one additional attribute beyond those of FLRW. + + Docstring copied from :code:`astropy.cosmology.flrw.wcdm.FlatwCDM` + Parameters + ---------- + H0 : float or scalar quantity-like ['frequency'] + Hubble constant at z = 0. If a float, must be in [km/sec/Mpc]. + + Om0 : float + Omega matter: density of non-relativistic matter in units of the + critical density at z=0. + + w0 : float, optional + Dark energy equation of state at all redshifts. This is + pressure/density for dark energy in units where c=1. A cosmological + constant has w0=-1.0. + + Tcmb0 : float or scalar quantity-like ['temperature'], optional + Temperature of the CMB z=0. If a float, must be in [K]. Default: 0 [K]. + Setting this to zero will turn off both photons and neutrinos + (even massive ones). + + Neff : float, optional + Effective number of Neutrino species. Default 3.04. + + m_nu : quantity-like ['energy', 'mass'] or array-like, optional + Mass of each neutrino species in [eV] (mass-energy equivalency enabled). + If this is a scalar Quantity, then all neutrino species are assumed to + have that mass. Otherwise, the mass of each species. The actual number + of neutrino species (and hence the number of elements of m_nu if it is + not scalar) must be the floor of Neff. Typically this means you should + provide three neutrino masses unless you are considering something like + a sterile neutrino. + + Ob0 : float or None, optional + Omega baryons: density of baryonic matter in units of the critical + density at z=0. If this is set to None (the default), any computation + that requires its value will raise an exception. + + name : str or None (optional, keyword-only) + Name for this cosmological object. + + meta : mapping or None (optional, keyword-only) + Metadata for the cosmology, e.g., a reference. + + Examples + -------- + >>> from astropy.cosmology import FlatwCDM + >>> cosmo = FlatwCDM(H0=70, Om0=0.3, w0=-0.9) + + The comoving distance in Mpc at redshift z: + + >>> z = 0.5 + >>> dc = cosmo.comoving_distance(z) + """ + self.H0 = H0 + self.Om0 = Om0 + self.w0 = w0 + self.zmin = zmin + self.zmax = zmax + self.name = name + self.meta = meta + + +class FlatLambdaCDM(WCosmoMixin): + def __init__( + self, + H0, + Om0, + Tcmb0=None, + Neff=None, + m_nu=None, + Ob0=None, + *, + zmin=1e-4, + zmax=100, + name=None, + meta=None, + ): + """FLRW cosmology with a cosmological constant and no curvature. + + This has no additional attributes beyond those of FLRW. + + Docstring copied from :code:`astropy.cosmology.flrw.lambdacdm.FlatLambdaCDM` -FlatwCDM = _FlatwCDM -FlatLambdaCDM = _FlatLambdaCDM + Parameters + ---------- + H0 : float or scalar quantity-like ['frequency'] + Hubble constant at z = 0. If a float, must be in [km/sec/Mpc]. + + Om0 : float + Omega matter: density of non-relativistic matter in units of the + critical density at z=0. + + Tcmb0 : float or scalar quantity-like ['temperature'], optional + Temperature of the CMB z=0. If a float, must be in [K]. Default: 0 [K]. + Setting this to zero will turn off both photons and neutrinos + (even massive ones). + + Neff : float, optional + Effective number of Neutrino species. Default 3.04. + + m_nu : quantity-like ['energy', 'mass'] or array-like, optional + Mass of each neutrino species in [eV] (mass-energy equivalency enabled). + If this is a scalar Quantity, then all neutrino species are assumed to + have that mass. Otherwise, the mass of each species. The actual number + of neutrino species (and hence the number of elements of m_nu if it is + not scalar) must be the floor of Neff. Typically this means you should + provide three neutrino masses unless you are considering something like + a sterile neutrino. + + Ob0 : float or None, optional + Omega baryons: density of baryonic matter in units of the critical + density at z=0. If this is set to None (the default), any computation + that requires its value will raise an exception. + + name : str or None (optional, keyword-only) + Name for this cosmological object. + + meta : mapping or None (optional, keyword-only) + Metadata for the cosmology, e.g., a reference. + + Examples + -------- + >>> from astropy.cosmology import FlatLambdaCDM + >>> cosmo = FlatLambdaCDM(H0=70, Om0=0.3) + + The comoving distance in Mpc at redshift z: + + >>> z = 0.5 + >>> dc = cosmo.comoving_distance(z) + """ + self.H0 = H0 + self.Om0 = Om0 + self.w0 = -1 + self.zmin = zmin + self.zmax = zmax + self.name = name + self.meta = meta def __getattr__(name): if name not in __all__: alt = _acosmo.__getattr__(name) - cosmo = _FlatLambdaCDM(**alt.parameters) + cosmo = FlatLambdaCDM(**alt.parameters) setattr(sys.modules[__name__], name, cosmo) return cosmo diff --git a/wcosmo/test/test_astropy.py b/wcosmo/test/test_astropy.py index f58134c..5543e83 100644 --- a/wcosmo/test/test_astropy.py +++ b/wcosmo/test/test_astropy.py @@ -1,3 +1,6 @@ +import pytest + + def test_astropy_cosmology_not_clobbered(): """See https://github.com/ColmTalbot/wcosmo/issues/15""" import astropy.cosmology @@ -5,3 +8,30 @@ def test_astropy_cosmology_not_clobbered(): import wcosmo.astropy assert "wcosmo" not in astropy.cosmology.Planck15.__module__ + + +def test_jits(): + pytest.importorskip("jax") + import gwpopulation + from astropy.cosmology import FlatLambdaCDM + from jax import jit + + import wcosmo + from wcosmo.astropy import FlatwCDM + from wcosmo.utils import disable_units + + @jit + def test_func(h0): + cosmo = FlatwCDM(h0, 0.1, -1) + return cosmo.luminosity_distance(0.1) + + gwpopulation.set_backend("jax") + disable_units() + + assert ( + abs( + float(test_func(67.0)) + - FlatLambdaCDM(67.0, 0.1).luminosity_distance(0.1).value + ) + < 1 + )