Skip to content

Commit

Permalink
REVERT: no longer directly subclass astropy to avoid units issues (#21)
Browse files Browse the repository at this point in the history
* REVERT: no longer directly subclass astropy to avoid units issues

* TEST: make astropy jit test more sensible
  • Loading branch information
ColmTalbot authored Oct 4, 2024
1 parent 4ee193c commit b423d82
Show file tree
Hide file tree
Showing 2 changed files with 187 additions and 14 deletions.
171 changes: 157 additions & 14 deletions wcosmo/astropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,27 +204,170 @@ def de_density_scale(self, z):
return (z + 1) ** (3 * (1 + self.w0))


# give these classes dummy names to avoid some kind of namespace collision
# that overwrites the astropy classes, see,
# https://github.com/ColmTalbot/wcosmo/issues/15
@dataclass(frozen=True)
class _FlatwCDM(WCosmoMixin, _acosmo.FlatwCDM):
pass


@dataclass(frozen=True)
class _FlatLambdaCDM(WCosmoMixin, _acosmo.FlatLambdaCDM):
w0: float = field(init=False, default=-1)
class FlatwCDM(WCosmoMixin):
def __init__(
self,
H0,
Om0,
w0=-1,
Tcmb0=None,
Neff=None,
m_nu=None,
Ob0=None,
*,
zmin=1e-4,
zmax=100,
name=None,
meta=None,
):
"""FLRW cosmology with a constant dark energy EoS and no spatial curvature.
This has one additional attribute beyond those of FLRW.
Docstring copied from :code:`astropy.cosmology.flrw.wcdm.FlatwCDM`
Parameters
----------
H0 : float or scalar quantity-like ['frequency']
Hubble constant at z = 0. If a float, must be in [km/sec/Mpc].
Om0 : float
Omega matter: density of non-relativistic matter in units of the
critical density at z=0.
w0 : float, optional
Dark energy equation of state at all redshifts. This is
pressure/density for dark energy in units where c=1. A cosmological
constant has w0=-1.0.
Tcmb0 : float or scalar quantity-like ['temperature'], optional
Temperature of the CMB z=0. If a float, must be in [K]. Default: 0 [K].
Setting this to zero will turn off both photons and neutrinos
(even massive ones).
Neff : float, optional
Effective number of Neutrino species. Default 3.04.
m_nu : quantity-like ['energy', 'mass'] or array-like, optional
Mass of each neutrino species in [eV] (mass-energy equivalency enabled).
If this is a scalar Quantity, then all neutrino species are assumed to
have that mass. Otherwise, the mass of each species. The actual number
of neutrino species (and hence the number of elements of m_nu if it is
not scalar) must be the floor of Neff. Typically this means you should
provide three neutrino masses unless you are considering something like
a sterile neutrino.
Ob0 : float or None, optional
Omega baryons: density of baryonic matter in units of the critical
density at z=0. If this is set to None (the default), any computation
that requires its value will raise an exception.
name : str or None (optional, keyword-only)
Name for this cosmological object.
meta : mapping or None (optional, keyword-only)
Metadata for the cosmology, e.g., a reference.
Examples
--------
>>> from astropy.cosmology import FlatwCDM
>>> cosmo = FlatwCDM(H0=70, Om0=0.3, w0=-0.9)
The comoving distance in Mpc at redshift z:
>>> z = 0.5
>>> dc = cosmo.comoving_distance(z)
"""
self.H0 = H0
self.Om0 = Om0
self.w0 = w0
self.zmin = zmin
self.zmax = zmax
self.name = name
self.meta = meta


class FlatLambdaCDM(WCosmoMixin):
def __init__(
self,
H0,
Om0,
Tcmb0=None,
Neff=None,
m_nu=None,
Ob0=None,
*,
zmin=1e-4,
zmax=100,
name=None,
meta=None,
):
"""FLRW cosmology with a cosmological constant and no curvature.
This has no additional attributes beyond those of FLRW.
Docstring copied from :code:`astropy.cosmology.flrw.lambdacdm.FlatLambdaCDM`
FlatwCDM = _FlatwCDM
FlatLambdaCDM = _FlatLambdaCDM
Parameters
----------
H0 : float or scalar quantity-like ['frequency']
Hubble constant at z = 0. If a float, must be in [km/sec/Mpc].
Om0 : float
Omega matter: density of non-relativistic matter in units of the
critical density at z=0.
Tcmb0 : float or scalar quantity-like ['temperature'], optional
Temperature of the CMB z=0. If a float, must be in [K]. Default: 0 [K].
Setting this to zero will turn off both photons and neutrinos
(even massive ones).
Neff : float, optional
Effective number of Neutrino species. Default 3.04.
m_nu : quantity-like ['energy', 'mass'] or array-like, optional
Mass of each neutrino species in [eV] (mass-energy equivalency enabled).
If this is a scalar Quantity, then all neutrino species are assumed to
have that mass. Otherwise, the mass of each species. The actual number
of neutrino species (and hence the number of elements of m_nu if it is
not scalar) must be the floor of Neff. Typically this means you should
provide three neutrino masses unless you are considering something like
a sterile neutrino.
Ob0 : float or None, optional
Omega baryons: density of baryonic matter in units of the critical
density at z=0. If this is set to None (the default), any computation
that requires its value will raise an exception.
name : str or None (optional, keyword-only)
Name for this cosmological object.
meta : mapping or None (optional, keyword-only)
Metadata for the cosmology, e.g., a reference.
Examples
--------
>>> from astropy.cosmology import FlatLambdaCDM
>>> cosmo = FlatLambdaCDM(H0=70, Om0=0.3)
The comoving distance in Mpc at redshift z:
>>> z = 0.5
>>> dc = cosmo.comoving_distance(z)
"""
self.H0 = H0
self.Om0 = Om0
self.w0 = -1
self.zmin = zmin
self.zmax = zmax
self.name = name
self.meta = meta


def __getattr__(name):
if name not in __all__:
alt = _acosmo.__getattr__(name)
cosmo = _FlatLambdaCDM(**alt.parameters)
cosmo = FlatLambdaCDM(**alt.parameters)
setattr(sys.modules[__name__], name, cosmo)
return cosmo

Expand Down
30 changes: 30 additions & 0 deletions wcosmo/test/test_astropy.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,37 @@
import pytest


def test_astropy_cosmology_not_clobbered():
"""See https://github.com/ColmTalbot/wcosmo/issues/15"""
import astropy.cosmology

import wcosmo.astropy

assert "wcosmo" not in astropy.cosmology.Planck15.__module__


def test_jits():
pytest.importorskip("jax")
import gwpopulation
from astropy.cosmology import FlatLambdaCDM
from jax import jit

import wcosmo
from wcosmo.astropy import FlatwCDM
from wcosmo.utils import disable_units

@jit
def test_func(h0):
cosmo = FlatwCDM(h0, 0.1, -1)
return cosmo.luminosity_distance(0.1)

gwpopulation.set_backend("jax")
disable_units()

assert (
abs(
float(test_func(67.0))
- FlatLambdaCDM(67.0, 0.1).luminosity_distance(0.1).value
)
< 1
)

0 comments on commit b423d82

Please sign in to comment.