diff --git a/pyproject.toml b/pyproject.toml index 93293a4..b15f642 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ classifiers = [ dynamic = ["version"] requires-python = ">=3.12" dependencies = [ + "jax", "healpy", "matplotlib", "numpy", @@ -27,6 +28,7 @@ dependencies = [ # On a mac, install optional dependencies with `pip install '.[dev]'` (include the single quotes) [project.optional-dependencies] dev = [ + "jax[cpu]", # Run jax code on CPU "jupyter", # Clears output from Jupyter notebooks "pre-commit", # Used to run checks before finalizing a git commit "pytest", diff --git a/src/healpix_geometry_analysis/coordinates.py b/src/healpix_geometry_analysis/coordinates.py new file mode 100644 index 0000000..880969f --- /dev/null +++ b/src/healpix_geometry_analysis/coordinates.py @@ -0,0 +1,131 @@ +"""Healpix tile coordinates math""" + +import dataclasses +from typing import Self + +import jax.numpy as jnp + + +@dataclasses.dataclass +class HealpixCoordinates: + """Healpix tile coordinates derived from diagonal indices + + Parameters + ---------- + nside : int + Healpix nside parameter, 2^order + """ + + nside: int + """Healpix nside parameter, 2^order""" + + @classmethod + def from_order(cls, order: int) -> Self: + """Create HealpixCoordinates using healpix order (depth)""" + return cls(nside=1 << order) + + def xyz(self, k, kp): + """Cartesian coordinates on the unit sphere from diagonal indices + + Parameters + ---------- + k : float + NW-SE diagonal index + kp : float + NE-SW diagonal index + + Returns + ------- + x : float + Cartesian x coordinate + y : float + Cartesian y coordinate + z : float + Cartesian z coordinate + """ + phi, z = self.phi_z(k, kp) + x = jnp.cos(phi) * jnp.sqrt(1 - z**2) + y = jnp.sin(phi) * jnp.sqrt(1 - z**2) + return x, y, z + + def lonlat_radians(self, k, kp): + """Longitude and latitude in radians from diagonal indices + + Parameters + ---------- + k : float + SW-NE diagonal index + kp : float + SE-NW diagonal index + + Returns + ------- + lon : float + Longitude in radians + lat : float + Latitude in radians + """ + phi, z = self.phi_z(k, kp) + lon = phi + lat = jnp.arcsin(z) + return lon, lat + + def lonlat_degrees(self, k, kp): + """Longitude and latitude in degrees from diagonal indices + + Parameters + ---------- + k : float + NW-SE diagonal index + kp : float + NE-SW diagonal index + + Returns + ------- + lon : float + Longitude in degrees + lat : float + Latitude in degrees + """ + lon, lat = self.lonlat_radians(k, kp) + return jnp.degrees(lon), jnp.degrees(lat) + + def phi_z(self, k, kp): + """Cylindrical coordinates from diagonal indices + + Parameters + ---------- + k : float + NW-SE diagonal index + kp : float + NE-SW diagonal index + + Returns + ------- + phi : float + Longitude in radians + z : float + Sine of the latitude + """ + eq_phi, eq_z = self._eq(k, kp) + polar_phi, polar_z = self._polar(k, kp) + z = jnp.where(eq_z <= 2 / 3, eq_z, polar_z) + phi = jnp.where(eq_z <= 2 / 3, eq_phi, polar_phi) + return phi, z + + def _eq(self, k, kp): + """Cylidrical coordinates assuming the equatorial region""" + z = 2 / 3 * (2 - (kp + k) / self.nside) + phi = jnp.pi / 4 / self.nside * (self.nside - kp + k) + return phi, z + + def _polar(self, k, kp): + """Cylindrical coordinates assuming the polar region""" + j = jnp.abs(k) - 0.5 + i = jnp.abs(kp) + jnp.abs(k) + + z = 1 - (i / self.nside) ** 2 / 3 + phi = 0.5 * jnp.pi * (j + 0.5) / i + phi = jnp.where(kp >= 0, phi, jnp.pi - phi) + phi = jnp.where(k >= 0, phi, -phi) + return phi, z diff --git a/src/healpix_geometry_analysis/enable_x64.py b/src/healpix_geometry_analysis/enable_x64.py new file mode 100644 index 0000000..912276b --- /dev/null +++ b/src/healpix_geometry_analysis/enable_x64.py @@ -0,0 +1,9 @@ +import jax + + +def enable_x64(): + """Make Jax to use double precision by default + + It must be run before any other Jax code. + """ + jax.config.update("jax_enable_x64", True) diff --git a/tests/healpix_geometry_analysis/conftest.py b/tests/healpix_geometry_analysis/conftest.py index e69de29..1d6ed43 100644 --- a/tests/healpix_geometry_analysis/conftest.py +++ b/tests/healpix_geometry_analysis/conftest.py @@ -0,0 +1,3 @@ +from healpix_geometry_analysis.enable_x64 import enable_x64 + +enable_x64() diff --git a/tests/healpix_geometry_analysis/test_coordinates.py b/tests/healpix_geometry_analysis/test_coordinates.py new file mode 100644 index 0000000..8ed4348 --- /dev/null +++ b/tests/healpix_geometry_analysis/test_coordinates.py @@ -0,0 +1,71 @@ +import healpy as hp +import jax.numpy as jnp +import pytest +from healpix_geometry_analysis.coordinates import HealpixCoordinates +from numpy.testing import assert_allclose + + +@pytest.mark.parametrize("order", [0, 1, 2, 4, 8, 16, 20]) +def test_equatorial_region(order): + """Check if the equatorial region works as expected.""" + nside = 1 << order + + lon_ = jnp.radians(jnp.linspace(0.0, 90.0, 33)) + lat_ = jnp.arcsin(jnp.linspace(-2 / 3, 2 / 3 - 1 / nside, 33)) + lon, lat = jnp.meshgrid(lon_, lat_) + + # Use healpy to get the pixel centers + tile_id = hp.ang2pix(nside, jnp.degrees(lon), jnp.degrees(lat), lonlat=True, nest=True) + lonlat_center = hp.pix2ang(nside, tile_id, lonlat=True, nest=True) + xyz_center = hp.ang2vec(*lonlat_center, lonlat=True) + phi_center, z_center = jnp.radians(lonlat_center[0]), jnp.sin(jnp.radians(lonlat_center[1])) + + # Get the diagonal indices + k_c = 3 * nside / 4 * (2 / 3 - z_center + 8 * phi_center / (3 * jnp.pi)) + kp_c = nside + 3 * nside / 4 * (2 / 3 - z_center - 8 * phi_center / (3 * jnp.pi)) + + coords = HealpixCoordinates(nside) + phi, z = coords.phi_z(k_c, kp_c) + + assert_allclose(phi, phi_center, atol=1e-3 / nside, rtol=1e-3 / nside) + assert_allclose(z, z_center, atol=1e-3 / nside, rtol=1e-3 / nside) + + lonlat_degrees = coords.lonlat_degrees(k_c, kp_c) + assert_allclose(lonlat_degrees, lonlat_center, atol=1e-3 / nside, rtol=1e-3 / nside) + + xyz = coords.xyz(k_c, kp_c) + assert_allclose(xyz, xyz_center.T, atol=1e-3 / nside, rtol=1e-3 / nside) + + +@pytest.mark.parametrize("order", [0, 1, 2, 5, 9, 13, 19]) +def test_polar_region(order): + """Check if the North polar region works good""" + nside = 1 << order + + lon_ = jnp.radians(jnp.linspace(0.0, 90 - 1e-3 / nside, 33)) + lat_ = jnp.arcsin(jnp.linspace(2 / 3, 1, 33)) + lon, lat = jnp.meshgrid(lon_, lat_) + + # Use healpix to get the pixel centers + tile_id = hp.ang2pix(nside, jnp.degrees(lon), jnp.degrees(lat), lonlat=True, nest=True) + lonlat_center = hp.pix2ang(nside, tile_id, lonlat=True, nest=True) + xyz_center = hp.ang2vec(*lonlat_center, lonlat=True) + phi_center, z_center = jnp.radians(lonlat_center[0]), jnp.sin(jnp.radians(lonlat_center[1])) + + # Get the diagonal indices + i_c = jnp.sqrt(3) * nside * jnp.sqrt(1 - z_center) + j_c = 2 * i_c / jnp.pi * phi_center - 0.5 + k_c = j_c + 0.5 + kp_c = i_c - j_c - 0.5 + + coords = HealpixCoordinates.from_order(order) + phi, z = coords.phi_z(k_c, kp_c) + + assert_allclose(phi, phi_center, atol=1e-3 / nside, rtol=1e-3 / nside) + assert_allclose(z, z_center, atol=1e-3 / nside, rtol=1e-3 / nside) + + lonlat_degrees = coords.lonlat_degrees(k_c, kp_c) + assert_allclose(lonlat_degrees, lonlat_center, atol=1e-3 / nside, rtol=1e-3 / nside) + + xyz = coords.xyz(k_c, kp_c) + assert_allclose(xyz, xyz_center.T, atol=1e-3 / nside, rtol=1e-3 / nside)