diff --git a/pyproject.toml b/pyproject.toml index bc8736a..144b127 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ test = [ "pytest-cov", "gwpopulation", "jax>=0.4.16", + "unxt", ] [tool.setuptools] @@ -38,8 +39,9 @@ write_to = "wcosmo/_version.py" [project.entry-points."gwpopulation.xp"] wcosmo = "wcosmo.wcosmo" -wcosmo-utils = "wcosmo.utils" +wcosmo-astropy = "wcosmo.astropy" wcosmo-taylor = "wcosmo.taylor" +wcosmo-utils = "wcosmo.utils" [project.entry-points."gwpopulation.other"] wcosmo-taylor = "wcosmo.taylor:scipy.linalg.toeplitz" diff --git a/wcosmo/astropy.py b/wcosmo/astropy.py index d90610c..a53e99c 100644 --- a/wcosmo/astropy.py +++ b/wcosmo/astropy.py @@ -10,13 +10,12 @@ """ import sys -from dataclasses import dataclass, field import astropy.cosmology as _acosmo import numpy as xp from astropy import units -from .utils import autodoc, method_autodoc, strip_units +from .utils import autodoc, convert_quantity_if_necessary, method_autodoc, strip_units from .wcosmo import * USE_UNITS = True @@ -60,6 +59,14 @@ class WCosmoMixin: frame to the detector frame, see :func:`source_to_detector_frame` """ + @property + def H0(self): + return self._H0 + + @H0.setter + def H0(self, value): + self._H0 = convert_quantity_if_necessary(value, unit="km s^-1 Mpc^-1") + @property def _kwargs(self): kwargs = {"H0": self.H0, "Om0": self.Om0, "w0": self.w0} @@ -364,11 +371,20 @@ def __init__( self.meta = meta +_known_cosmologies = dict() + + def __getattr__(name): - if name not in __all__: + if f"{name}_{xp.__name__}" in _known_cosmologies: + return _known_cosmologies[f"{name}_{xp.__name__}"] + elif name not in __all__: alt = _acosmo.__getattr__(name) - cosmo = FlatLambdaCDM(**alt.parameters) - setattr(sys.modules[__name__], name, cosmo) + params = { + key: convert_quantity_if_necessary(arg) + for key, arg in alt.parameters.items() + } + cosmo = FlatLambdaCDM(**params) + _known_cosmologies[f"{name}_{xp.__name__}"] = cosmo return cosmo diff --git a/wcosmo/constants.py b/wcosmo/constants.py index 18f5af5..2274194 100644 --- a/wcosmo/constants.py +++ b/wcosmo/constants.py @@ -10,6 +10,8 @@ from astropy import units +from .utils import convert_quantity_if_necessary + __all__ = ["USE_UNITS", "c_km_per_s", "gyr_km_per_s_mpc"] @@ -18,7 +20,7 @@ def __getattr__(name): if value is None: raise AttributeError(f"module {__name__!r} has no attribute {name!r}") if USE_UNITS: - value = value << _UNITS[name] + value = convert_quantity_if_necessary(value, _UNITS[name]) return value diff --git a/wcosmo/test/test_astropy.py b/wcosmo/test/test_astropy.py index 5543e83..63e000c 100644 --- a/wcosmo/test/test_astropy.py +++ b/wcosmo/test/test_astropy.py @@ -5,7 +5,7 @@ def test_astropy_cosmology_not_clobbered(): """See https://github.com/ColmTalbot/wcosmo/issues/15""" import astropy.cosmology - import wcosmo.astropy + import wcosmo.astropy # noqa assert "wcosmo" not in astropy.cosmology.Planck15.__module__ @@ -16,22 +16,17 @@ def test_jits(): from astropy.cosmology import FlatLambdaCDM from jax import jit - import wcosmo from wcosmo.astropy import FlatwCDM - from wcosmo.utils import disable_units + + gwpopulation.set_backend("jax") @jit def test_func(h0): cosmo = FlatwCDM(h0, 0.1, -1) return cosmo.luminosity_distance(0.1) - gwpopulation.set_backend("jax") - disable_units() - - assert ( - abs( - float(test_func(67.0)) - - FlatLambdaCDM(67.0, 0.1).luminosity_distance(0.1).value - ) - < 1 - ) + my_result = test_func(67.0) + their_result = FlatLambdaCDM(67.0, 0.1).luminosity_distance(0.1) + + assert abs(float(my_result.value) - their_result.value) < 1 + assert my_result.unit == their_result.unit diff --git a/wcosmo/utils.py b/wcosmo/utils.py index 6bdc0e9..09c0159 100644 --- a/wcosmo/utils.py +++ b/wcosmo/utils.py @@ -143,3 +143,48 @@ def strip_units(value): if isinstance(value, np.float64): value = value.item() return value + + +def convert_quantity_if_necessary(arg, unit=None): + """ + Helper function to convert between :code:`astropy` and :code:`unxt` + quantities and non-unitful values. + + The order of precedence is as follows: + + - If using :code:`jax.numpy` as the backend, the input is an + :code:`astropy` or :code:`unxt` quantity or :code:`unit` is specified, + convert to a :code:`unxt` quantity with the provided unit. + - If using :code:`jax.numpy` as the backend, the input is not a quantiy + and no unit is provided, return the input. + - If a unit and an :code:`astropy` quantity are provided, convert the + input to an :code:`astropy` quantity with the provided unit + - If a unit is provided, convert the input to an :code:`astropy` + quantity with the provided unit. + - Else return the input as is. + + Parameters + ========== + arg: Union[astropy.units.Quantity, unxt.Quantity, array_like] + The array to convert + unit: Optional[astropy.units.Unit, str] + The unit to convert to + + Returns + ======= + Union[astropy.units.Quantity, unxt.Quantity, array_like] + The converted array + """ + from astropy.units import Quantity as _Quantity + + if xp.__name__ == "jax.numpy" and (isinstance(arg, _Quantity) or unit is not None): + from unxt import Quantity + + if unit is None: + return Quantity.from_(arg) + return Quantity.from_(arg, unit) + elif unit is not None and isinstance(arg, _Quantity): + return arg.to(unit) + elif unit is not None: + return _Quantity(arg, unit) + return arg