Skip to content

Commit

Permalink
Initial HealpixCoordinates
Browse files Browse the repository at this point in the history
  • Loading branch information
hombit committed Jun 18, 2024
1 parent f32aa97 commit 7737fbf
Show file tree
Hide file tree
Showing 5 changed files with 215 additions and 0 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ classifiers = [
dynamic = ["version"]
requires-python = ">=3.12"
dependencies = [
"jax",
"healpy",
"matplotlib",
"numpy",
Expand Down
131 changes: 131 additions & 0 deletions src/healpix_geometry_analysis/coordinates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""Healpix tile coordinates math"""

import dataclasses
from typing import Self

import jax.numpy as jnp


@dataclasses.dataclass
class HealpixCoordinates:
"""Healpix tile coordinates derived from diagonal indices
Parameters
----------
nside : int
Healpix nside parameter, 2^order
"""

nside: int
"""Healpix nside parameter, 2^order"""

@classmethod
def from_order(cls, order: int) -> Self:
"""Create HealpixCoordinates using healpix order (depth)"""
return cls(nside=1 << order)

def xyz(self, k, kp):
"""Cartesian coordinates on the unit sphere from diagonal indices
Parameters
----------
k : float
NW-SE diagonal index
kp : float
NE-SW diagonal index
Returns
-------
x : float
Cartesian x coordinate
y : float
Cartesian y coordinate
z : float
Cartesian z coordinate
"""
phi, z = self.phi_z(k, kp)
x = jnp.cos(phi) * jnp.sqrt(1 - z**2)
y = jnp.sin(phi) * jnp.sqrt(1 - z**2)
return x, y, z

def lonlat_radians(self, k, kp):
"""Longitude and latitude in radians from diagonal indices
Parameters
----------
k : float
SW-NE diagonal index
kp : float
SE-NW diagonal index
Returns
-------
lon : float
Longitude in radians
lat : float
Latitude in radians
"""
phi, z = self.phi_z(k, kp)
lon = phi
lat = jnp.arcsin(z)
return lon, lat

def lonlat_degrees(self, k, kp):
"""Longitude and latitude in degrees from diagonal indices
Parameters
----------
k : float
NW-SE diagonal index
kp : float
NE-SW diagonal index
Returns
-------
lon : float
Longitude in degrees
lat : float
Latitude in degrees
"""
lon, lat = self.lonlat_radians(k, kp)
return jnp.degrees(lon), jnp.degrees(lat)

def phi_z(self, k, kp):
"""Cylindrical coordinates from diagonal indices
Parameters
----------
k : float
NW-SE diagonal index
kp : float
NE-SW diagonal index
Returns
-------
phi : float
Longitude in radians
z : float
Sine of the latitude
"""
eq_phi, eq_z = self._eq(k, kp)
polar_phi, polar_z = self._polar(k, kp)
z = jnp.where(eq_z <= 2 / 3, eq_z, polar_z)
phi = jnp.where(eq_z <= 2 / 3, eq_phi, polar_phi)
return phi, z

def _eq(self, k, kp):
"""Cylidrical coordinates assuming the equatorial region"""
z = 2 / 3 * (2 - (kp + k) / self.nside)
phi = jnp.pi / 4 / self.nside * (self.nside - kp + k)
return phi, z

def _polar(self, k, kp):
"""Cylindrical coordinates assuming the polar region"""
j = jnp.abs(k) - 0.5
i = jnp.abs(kp) + jnp.abs(k)

z = 1 - (i / self.nside) ** 2 / 3
phi = 0.5 * jnp.pi * (j + 0.5) / i
phi = jnp.where(kp >= 0, phi, jnp.pi - phi)
phi = jnp.where(k >= 0, phi, -phi)
return phi, z
9 changes: 9 additions & 0 deletions src/healpix_geometry_analysis/enable_x64.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import jax


def enable_x64():
"""Make Jax to use double precision by default
It must be run before any other Jax code.
"""
jax.config.update("jax_enable_x64", True)
3 changes: 3 additions & 0 deletions tests/healpix_geometry_analysis/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from healpix_geometry_analysis.enable_x64 import enable_x64

enable_x64()
71 changes: 71 additions & 0 deletions tests/healpix_geometry_analysis/test_coordinates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import healpy as hp
import jax.numpy as jnp
import pytest
from healpix_geometry_analysis.coordinates import HealpixCoordinates
from numpy.testing import assert_allclose


@pytest.mark.parametrize("order", [0, 1, 2, 4, 8, 16, 20])
def test_equatorial_region(order):
"""Check if the equatorial region works as expected."""
nside = 1 << order

lon_ = jnp.radians(jnp.linspace(0.0, 90.0, 33))
lat_ = jnp.arcsin(jnp.linspace(-2 / 3, 2 / 3 - 1 / nside, 33))
lon, lat = jnp.meshgrid(lon_, lat_)

# Use healpy to get the pixel centers
tile_id = hp.ang2pix(nside, jnp.degrees(lon), jnp.degrees(lat), lonlat=True, nest=True)
lonlat_center = hp.pix2ang(nside, tile_id, lonlat=True, nest=True)
xyz_center = hp.ang2vec(*lonlat_center, lonlat=True)
phi_center, z_center = jnp.radians(lonlat_center[0]), jnp.sin(jnp.radians(lonlat_center[1]))

# Get the diagonal indices
k_c = 3 * nside / 4 * (2 / 3 - z_center + 8 * phi_center / (3 * jnp.pi))
kp_c = nside + 3 * nside / 4 * (2 / 3 - z_center - 8 * phi_center / (3 * jnp.pi))

coords = HealpixCoordinates(nside)
phi, z = coords.phi_z(k_c, kp_c)

assert_allclose(phi, phi_center, atol=1e-3 / nside, rtol=1e-3 / nside)
assert_allclose(z, z_center, atol=1e-3 / nside, rtol=1e-3 / nside)

lonlat_degrees = coords.lonlat_degrees(k_c, kp_c)
assert_allclose(lonlat_degrees, lonlat_center, atol=1e-3 / nside, rtol=1e-3 / nside)

xyz = coords.xyz(k_c, kp_c)
assert_allclose(xyz, xyz_center.T, atol=1e-3 / nside, rtol=1e-3 / nside)


@pytest.mark.parametrize("order", [0, 1, 2, 5, 9, 13, 19])
def test_polar_region(order):
"""Check if the North polar region works good"""
nside = 1 << order

lon_ = jnp.radians(jnp.linspace(0.0, 90 - 1e-3 / nside, 33))
lat_ = jnp.arcsin(jnp.linspace(2 / 3, 1, 33))
lon, lat = jnp.meshgrid(lon_, lat_)

# Use healpix to get the pixel centers
tile_id = hp.ang2pix(nside, jnp.degrees(lon), jnp.degrees(lat), lonlat=True, nest=True)
lonlat_center = hp.pix2ang(nside, tile_id, lonlat=True, nest=True)
xyz_center = hp.ang2vec(*lonlat_center, lonlat=True)
phi_center, z_center = jnp.radians(lonlat_center[0]), jnp.sin(jnp.radians(lonlat_center[1]))

# Get the diagonal indices
i_c = jnp.sqrt(3) * nside * jnp.sqrt(1 - z_center)
j_c = 2 * i_c / jnp.pi * phi_center - 0.5
k_c = j_c + 0.5
kp_c = i_c - j_c - 0.5

coords = HealpixCoordinates.from_order(order)
phi, z = coords.phi_z(k_c, kp_c)

assert_allclose(phi, phi_center, atol=1e-3 / nside, rtol=1e-3 / nside)
assert_allclose(z, z_center, atol=1e-3 / nside, rtol=1e-3 / nside)

lonlat_degrees = coords.lonlat_degrees(k_c, kp_c)
assert_allclose(lonlat_degrees, lonlat_center, atol=1e-3 / nside, rtol=1e-3 / nside)

xyz = coords.xyz(k_c, kp_c)
assert_allclose(xyz, xyz_center.T, atol=1e-3 / nside, rtol=1e-3 / nside)

0 comments on commit 7737fbf

Please sign in to comment.