Skip to content

Commit

Permalink
FEAT: enable jax units using unxt (#23)
Browse files Browse the repository at this point in the history
* FEAT: enable jax units using unxt

* BLD: fix which modules get xp and add unxt as a test dependency
  • Loading branch information
ColmTalbot authored Oct 15, 2024
1 parent b423d82 commit ff85eab
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 20 deletions.
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ test = [
"pytest-cov",
"gwpopulation",
"jax>=0.4.16",
"unxt",
]

[tool.setuptools]
Expand All @@ -38,8 +39,9 @@ write_to = "wcosmo/_version.py"

[project.entry-points."gwpopulation.xp"]
wcosmo = "wcosmo.wcosmo"
wcosmo-utils = "wcosmo.utils"
wcosmo-astropy = "wcosmo.astropy"
wcosmo-taylor = "wcosmo.taylor"
wcosmo-utils = "wcosmo.utils"

[project.entry-points."gwpopulation.other"]
wcosmo-taylor = "wcosmo.taylor:scipy.linalg.toeplitz"
26 changes: 21 additions & 5 deletions wcosmo/astropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,12 @@
"""

import sys
from dataclasses import dataclass, field

import astropy.cosmology as _acosmo
import numpy as xp
from astropy import units

from .utils import autodoc, method_autodoc, strip_units
from .utils import autodoc, convert_quantity_if_necessary, method_autodoc, strip_units
from .wcosmo import *

USE_UNITS = True
Expand Down Expand Up @@ -60,6 +59,14 @@ class WCosmoMixin:
frame to the detector frame, see :func:`source_to_detector_frame`
"""

@property
def H0(self):
return self._H0

@H0.setter
def H0(self, value):
self._H0 = convert_quantity_if_necessary(value, unit="km s^-1 Mpc^-1")

@property
def _kwargs(self):
kwargs = {"H0": self.H0, "Om0": self.Om0, "w0": self.w0}
Expand Down Expand Up @@ -364,11 +371,20 @@ def __init__(
self.meta = meta


_known_cosmologies = dict()


def __getattr__(name):
if name not in __all__:
if f"{name}_{xp.__name__}" in _known_cosmologies:
return _known_cosmologies[f"{name}_{xp.__name__}"]
elif name not in __all__:
alt = _acosmo.__getattr__(name)
cosmo = FlatLambdaCDM(**alt.parameters)
setattr(sys.modules[__name__], name, cosmo)
params = {
key: convert_quantity_if_necessary(arg)
for key, arg in alt.parameters.items()
}
cosmo = FlatLambdaCDM(**params)
_known_cosmologies[f"{name}_{xp.__name__}"] = cosmo
return cosmo


Expand Down
4 changes: 3 additions & 1 deletion wcosmo/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@

from astropy import units

from .utils import convert_quantity_if_necessary

__all__ = ["USE_UNITS", "c_km_per_s", "gyr_km_per_s_mpc"]


Expand All @@ -18,7 +20,7 @@ def __getattr__(name):
if value is None:
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
if USE_UNITS:
value = value << _UNITS[name]
value = convert_quantity_if_necessary(value, _UNITS[name])
return value


Expand Down
21 changes: 8 additions & 13 deletions wcosmo/test/test_astropy.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ def test_astropy_cosmology_not_clobbered():
"""See https://github.com/ColmTalbot/wcosmo/issues/15"""
import astropy.cosmology

import wcosmo.astropy
import wcosmo.astropy # noqa

assert "wcosmo" not in astropy.cosmology.Planck15.__module__

Expand All @@ -16,22 +16,17 @@ def test_jits():
from astropy.cosmology import FlatLambdaCDM
from jax import jit

import wcosmo
from wcosmo.astropy import FlatwCDM
from wcosmo.utils import disable_units

gwpopulation.set_backend("jax")

@jit
def test_func(h0):
cosmo = FlatwCDM(h0, 0.1, -1)
return cosmo.luminosity_distance(0.1)

gwpopulation.set_backend("jax")
disable_units()

assert (
abs(
float(test_func(67.0))
- FlatLambdaCDM(67.0, 0.1).luminosity_distance(0.1).value
)
< 1
)
my_result = test_func(67.0)
their_result = FlatLambdaCDM(67.0, 0.1).luminosity_distance(0.1)

assert abs(float(my_result.value) - their_result.value) < 1
assert my_result.unit == their_result.unit
45 changes: 45 additions & 0 deletions wcosmo/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,48 @@ def strip_units(value):
if isinstance(value, np.float64):
value = value.item()
return value


def convert_quantity_if_necessary(arg, unit=None):
"""
Helper function to convert between :code:`astropy` and :code:`unxt`
quantities and non-unitful values.
The order of precedence is as follows:
- If using :code:`jax.numpy` as the backend, the input is an
:code:`astropy` or :code:`unxt` quantity or :code:`unit` is specified,
convert to a :code:`unxt` quantity with the provided unit.
- If using :code:`jax.numpy` as the backend, the input is not a quantiy
and no unit is provided, return the input.
- If a unit and an :code:`astropy` quantity are provided, convert the
input to an :code:`astropy` quantity with the provided unit
- If a unit is provided, convert the input to an :code:`astropy`
quantity with the provided unit.
- Else return the input as is.
Parameters
==========
arg: Union[astropy.units.Quantity, unxt.Quantity, array_like]
The array to convert
unit: Optional[astropy.units.Unit, str]
The unit to convert to
Returns
=======
Union[astropy.units.Quantity, unxt.Quantity, array_like]
The converted array
"""
from astropy.units import Quantity as _Quantity

if xp.__name__ == "jax.numpy" and (isinstance(arg, _Quantity) or unit is not None):
from unxt import Quantity

if unit is None:
return Quantity.from_(arg)
return Quantity.from_(arg, unit)
elif unit is not None and isinstance(arg, _Quantity):
return arg.to(unit)
elif unit is not None:
return _Quantity(arg, unit)
return arg

0 comments on commit ff85eab

Please sign in to comment.