Skip to content

Commit

Permalink
FEAT: Switch to implicit backend choosing (#27)
Browse files Browse the repository at this point in the history
* DOC: some fixes for rtd

* REFACTOR: make backend switching implicit

* DOC: update api docs

* MAINT: add missing dependency

* BUG: Refactor H0 setter logic for value conversion

* MAINT: bump some versions

* MAINT: resolve deprecation warning

* BUG: fix implicit unit checks

* FORMAT: apply pre-commits

* DOC: fix pages requirements

* DOC: fix rtd configuration
  • Loading branch information
ColmTalbot authored Nov 27, 2024
1 parent 1f4dcaf commit d15ee7d
Show file tree
Hide file tree
Showing 27 changed files with 567 additions and 298 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/pages.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ jobs:
- name: Install dependencies
run: |
conda install --file doc/pages_requirements.txt
python -m pip install .
python -m pip install .[test]
- name: Build documentation
run: |
Expand Down
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
*.egg-info
build/
doc/build/
doc/source/api/
doc/source/api/_autosummary
doc/source/examples/*.ipynb
.coverage
wcosmo/_version.py
Expand Down
4 changes: 3 additions & 1 deletion .readthedocs.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,6 @@ python:
install:
- requirements: doc/pages_requirements.txt
- method: pip
path: .
path: .
extra_requirements:
- test
54 changes: 29 additions & 25 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ with :code:`numpy`-like backends, e.g., :code:`jax` and :code:`cupy`.
There are two main features leading to superior efficiency to :code:`astropy`:

- Integrals of :math:`E(z)` and related functions are performed analytically
with Pade approximations.
with Pade approximations or analytic expressions using hypergeometric functions.
- Support for :code:`jax` and :code:`cupy` backends allow hardware
acceleration, just-in-time compilation, and automatic differentiation.

Expand All @@ -17,7 +17,8 @@ The primary limitations are:
equations of state, e.g., :code:`FlatwCDM`.
- Approximations to the various integrals generally agree with :code:`astropy`
at the <0.1% level.
- The :code:`astropy` units are incompatible with non-:code:`numpy` backends.
- The :code:`astropy` units are incompatible with the :code:`cupy` backend.
Units are supported with the :code:`jax` backend using :code:`unxt`.

Installation and contribution
-----------------------------
Expand Down Expand Up @@ -61,6 +62,22 @@ To import an astropy-like cosmology
>>> cosmology = FlatwCDM(H0=70, Om0=0.3, w0=-1)
>>> cosmology.luminosity_distance(1)
The built-in cosmologies in :code:`astropy` are all available, e.g.,

.. code-block:: python
>>> from wcosmo import Planck18
>>> Planck18.luminosity_distance(1)
<Quantity 6797.43628659 Mpc>
they can also be accessed using :code:`wcosmo.available`

.. code-block:: python
>>> from wcosmo import available
>>> available["Planck18"].luminosity_distance(1)
<Quantity 6797.43628659 Mpc>
Explicit usage of :code:`astropy` units can be freely enabled/disabled.
In this case, the values will have the default units for each method.

Expand All @@ -78,32 +95,19 @@ In this case, the values will have the default units for each method.
>>> cosmology.luminosity_distance(1)
<Quantity 6607.65773208 Mpc>
GWPopulation
^^^^^^^^^^^^

The primary intention for this package is for use with :code:`GWPopulation`.
This code is automatically used in :code:`GWPopulation` when using either
:code:`gwpopulation.experimental.cosmo_models.CosmoModel` and/or
:code:`PowerLawRedshift`
Changing backend
^^^^^^^^^^^^^^^^

The backend can be switched automatically using, e.g.,

.. code-block:: python
>>> import gwpopulation
>>> gwpopulation.backend.set_backend("jax")
Manual backend setting can be done as follows:

.. code-block:: python
:code:`wcosmo` mostly relies on implicit backend switching. The backend is
determined automatically based on the input arguments. When an input value is
a :code:`Python` built-in type, the default backend is chosen using the
environment variable :code:`WCOSMO_ARRAY_API`. The default is :code:`numpy`.

>>> import jax.numpy as jnp
>>> from jax.scipy.linalg.toeplitz import toeplitz
GWPopulation
^^^^^^^^^^^^

>>> from wcosmo import wcosmo, utils
>>> wcosmo.xp = jnp
>>> utils.xp = jnp
>>> utils.toeplitz = toeplitz
The original intention for this package was for use with :code:`GWPopulation`.
This code is automatically used in :code:`GWPopulation` when using either
:code:`gwpopulation.experimental.cosmo_models.CosmoModel` and/or
:code:`PowerLawRedshift`
11 changes: 11 additions & 0 deletions doc/source/api/analytic.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
`wcosmo.analytic`
=================

.. currentmodule:: wcosmo.analytic

.. automodule:: wcosmo.analytic

.. autosummary::
:toctree: _autosummary

indefinite_integral_hypergeometric
13 changes: 13 additions & 0 deletions doc/source/api/astropy.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
`wcosmo.astropy`
================

.. currentmodule:: wcosmo.astropy

.. automodule:: wcosmo.astropy

.. autosummary::
:toctree: _autosummary

FlatLambdaCDM
FlatwCDM
WCosmoMixin
8 changes: 4 additions & 4 deletions doc/source/api/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ API Reference

.. currentmodule:: wcosmo

.. autosummary::
:toctree: .
:template: custom-module-template.rst
.. toctree::
:caption: API
:recursive:
:maxdepth: 1

analytic
astropy
integrate
taylor
utils
wcosmo
12 changes: 12 additions & 0 deletions doc/source/api/integrate.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
`wcosmo.integrate`
==================

.. currentmodule:: wcosmo.integrate

.. automodule:: wcosmo.integrate

.. autosummary::
:toctree: _autosummary

analytic_integral
indefinite_integral
13 changes: 13 additions & 0 deletions doc/source/api/taylor.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
`wcosmo.taylor`
===============

.. currentmodule:: wcosmo.taylor

.. automodule:: wcosmo.taylor

.. autosummary::
:toctree: _autosummary

flat_wcdm_pade_coefficients
flat_wcdm_taylor_expansion
indefinite_integral_pade
19 changes: 19 additions & 0 deletions doc/source/api/utils.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
`wcosmo.utils`
===============

.. currentmodule:: wcosmo.utils

.. automodule:: wcosmo.utils

.. autosummary::
:toctree: _autosummary

array_namespace
autodoc
convert_quantity_if_necessary
default_array_namespace
disable_units
enable_units
method_autodoc
maybe_jit
strip_units
25 changes: 25 additions & 0 deletions doc/source/api/wcosmo.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
`wcosmo.wcosmo`
===============

.. currentmodule:: wcosmo.wcosmo

.. automodule:: wcosmo.wcosmo

.. autosummary::
:toctree: _autosummary

absorption_distance
comoving_distance
comoving_volume
detector_to_source_frame
differential_comoving_volume
dDLdz
efunc
hubble_distance
hubble_parameter
hubble_time
inv_efunc
lookback_time
luminosity_distance
source_to_detector_frame
z_at_value
23 changes: 7 additions & 16 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,19 @@ dependencies = [
"numpy",
"scipy",
"astropy>=6.1",
"plum-dispatch",
"array-api-compat",
]
dynamic = ["version"]

[project.optional-dependencies]
test = [
"pytest-cov",
"gwpopulation",
"jax>=0.4.16",
"jax>=0.4.34",
"unxt",
]
jax = [
"jax>=0.4.34",
"unxt",
]

Expand All @@ -36,17 +41,3 @@ packages = ["wcosmo"]

[tool.setuptools_scm]
write_to = "wcosmo/_version.py"

[project.entry-points."gwpopulation.xp"]
wcosmo = "wcosmo.wcosmo"
wcosmo-analytic = "wcosmo.analytic"
wcosmo-astropy = "wcosmo.astropy"
wcosmo-integrate = "wcosmo.integrate"
wcosmo-taylor = "wcosmo.taylor"
wcosmo-utils = "wcosmo.utils"

[project.entry-points."gwpopulation.scs"]
wcosmo-analytic = "wcosmo.analytic"

[project.entry-points."gwpopulation.other"]
wcosmo-taylor = "wcosmo.taylor:scipy.linalg.toeplitz"
1 change: 1 addition & 0 deletions wcosmo/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from ._version import __version__
from .backend import AVAILABLE_BACKENDS
from .utils import disable_units, enable_units
from .wcosmo import *

Expand Down
32 changes: 0 additions & 32 deletions wcosmo/_hyp2f1_jax.py

This file was deleted.

33 changes: 18 additions & 15 deletions wcosmo/analytic.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
from typing import Union

import numpy as np
import scipy.special as scs
from plum import dispatch

from .utils import autodoc

xp = np

__all__ = ["indefinite_integral_hypergeometric"]


@dispatch
@autodoc
def indefinite_integral_hypergeometric(z, Om0, w0=-1, zpower=0):
def indefinite_integral_hypergeometric(
z: Union[float, int, np.ndarray], Om0, w0=-1, zpower=0
):
r"""
Compute the integral of :math:`(1+z)^k / E(z)` as described in
https://doi.org/10.4236/jhepgc.2021.73057.
Expand Down Expand Up @@ -61,20 +65,19 @@ def indefinite_integral_hypergeometric(z, Om0, w0=-1, zpower=0):
This has been discussed in :code:`cupy` and may be implemented in the
future (https://github.com/cupy/cupy/issues/8274).
"""
if xp.__name__ == "jax.numpy":
from ._hyp2f1_jax import hyp2f1
else:
from scipy.special import hyp2f1
return _indefinite_integral_hypergeometric(
z, Om0=Om0, w0=w0, zpower=zpower, hyp2f1=scs.hyp2f1, beta=scs.beta
)


def _indefinite_integral_hypergeometric(
z: np.ndarray, Om0, w0=-1, zpower=0, *, hyp2f1=None, beta=None
):
value = (1 + z) ** (zpower - 1 / 2)
x = (Om0 - 1) / Om0 * (1 + z) ** (3 * w0)
aa = 1 / 2
# jax will evaluate all the branches of the analytic integral and so we
# need to manually catch zero division errors.
try:
x = (Om0 - 1) / Om0 * (1 + z) ** (3 * w0)
bb = (zpower - 1 / 2) / (3 * w0)
except ZeroDivisionError:
return z * 0.0
bb = (zpower - 1 / 2) / (3 * w0)
cc = bb + 1
values = hyp2f1(aa, bb, cc, x)
normalization = scs.beta(bb, cc - bb) * values / (3 * w0 * Om0**0.5)
normalization = beta(bb, cc - bb) * values / (3 * w0 * Om0**0.5)
return value * normalization
Loading

0 comments on commit d15ee7d

Please sign in to comment.