Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial implementation of diagonal coordinates #3

Merged
merged 1 commit into from
Jun 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ classifiers = [
dynamic = ["version"]
requires-python = ">=3.12"
dependencies = [
"jax",
"healpy",
"matplotlib",
"numpy",
Expand All @@ -27,6 +28,7 @@ dependencies = [
# On a mac, install optional dependencies with `pip install '.[dev]'` (include the single quotes)
[project.optional-dependencies]
dev = [
"jax[cpu]", # Run jax code on CPU
"jupyter", # Clears output from Jupyter notebooks
"pre-commit", # Used to run checks before finalizing a git commit
"pytest",
Expand Down
131 changes: 131 additions & 0 deletions src/healpix_geometry_analysis/coordinates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,131 @@
"""Healpix tile coordinates math"""

import dataclasses
from typing import Self

import jax.numpy as jnp


@dataclasses.dataclass
class HealpixCoordinates:
"""Healpix tile coordinates derived from diagonal indices

Parameters
----------
nside : int
Healpix nside parameter, 2^order
"""

nside: int
"""Healpix nside parameter, 2^order"""

@classmethod
def from_order(cls, order: int) -> Self:
"""Create HealpixCoordinates using healpix order (depth)"""
return cls(nside=1 << order)

def xyz(self, k, kp):
"""Cartesian coordinates on the unit sphere from diagonal indices

Parameters
----------
k : float
NW-SE diagonal index
kp : float
NE-SW diagonal index

Returns
-------
x : float
Cartesian x coordinate
y : float
Cartesian y coordinate
z : float
Cartesian z coordinate
"""
phi, z = self.phi_z(k, kp)
x = jnp.cos(phi) * jnp.sqrt(1 - z**2)
y = jnp.sin(phi) * jnp.sqrt(1 - z**2)
return x, y, z

def lonlat_radians(self, k, kp):
"""Longitude and latitude in radians from diagonal indices

Parameters
----------
k : float
SW-NE diagonal index
kp : float
SE-NW diagonal index

Returns
-------
lon : float
Longitude in radians
lat : float
Latitude in radians
"""
phi, z = self.phi_z(k, kp)
lon = phi
lat = jnp.arcsin(z)
return lon, lat

def lonlat_degrees(self, k, kp):
"""Longitude and latitude in degrees from diagonal indices

Parameters
----------
k : float
NW-SE diagonal index
kp : float
NE-SW diagonal index

Returns
-------
lon : float
Longitude in degrees
lat : float
Latitude in degrees
"""
lon, lat = self.lonlat_radians(k, kp)
return jnp.degrees(lon), jnp.degrees(lat)

def phi_z(self, k, kp):
"""Cylindrical coordinates from diagonal indices

Parameters
----------
k : float
NW-SE diagonal index
kp : float
NE-SW diagonal index

Returns
-------
phi : float
Longitude in radians
z : float
Sine of the latitude
"""
eq_phi, eq_z = self._eq(k, kp)
polar_phi, polar_z = self._polar(k, kp)
z = jnp.where(eq_z <= 2 / 3, eq_z, polar_z)
phi = jnp.where(eq_z <= 2 / 3, eq_phi, polar_phi)
return phi, z

def _eq(self, k, kp):
"""Cylidrical coordinates assuming the equatorial region"""
z = 2 / 3 * (2 - (kp + k) / self.nside)
phi = jnp.pi / 4 / self.nside * (self.nside - kp + k)
return phi, z

def _polar(self, k, kp):
"""Cylindrical coordinates assuming the polar region"""
j = jnp.abs(k) - 0.5
i = jnp.abs(kp) + jnp.abs(k)

z = 1 - (i / self.nside) ** 2 / 3
phi = 0.5 * jnp.pi * (j + 0.5) / i
phi = jnp.where(kp >= 0, phi, jnp.pi - phi)
phi = jnp.where(k >= 0, phi, -phi)
return phi, z
9 changes: 9 additions & 0 deletions src/healpix_geometry_analysis/enable_x64.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
import jax


def enable_x64():
"""Make Jax to use double precision by default

It must be run before any other Jax code.
"""
jax.config.update("jax_enable_x64", True)
3 changes: 3 additions & 0 deletions tests/healpix_geometry_analysis/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from healpix_geometry_analysis.enable_x64 import enable_x64

enable_x64()
71 changes: 71 additions & 0 deletions tests/healpix_geometry_analysis/test_coordinates.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import healpy as hp
import jax.numpy as jnp
import pytest
from healpix_geometry_analysis.coordinates import HealpixCoordinates
from numpy.testing import assert_allclose


@pytest.mark.parametrize("order", [0, 1, 2, 4, 8, 16, 20])
def test_equatorial_region(order):
"""Check if the equatorial region works as expected."""
nside = 1 << order

lon_ = jnp.radians(jnp.linspace(0.0, 90.0, 33))
lat_ = jnp.arcsin(jnp.linspace(-2 / 3, 2 / 3 - 1 / nside, 33))
lon, lat = jnp.meshgrid(lon_, lat_)

# Use healpy to get the pixel centers
tile_id = hp.ang2pix(nside, jnp.degrees(lon), jnp.degrees(lat), lonlat=True, nest=True)
lonlat_center = hp.pix2ang(nside, tile_id, lonlat=True, nest=True)
xyz_center = hp.ang2vec(*lonlat_center, lonlat=True)
phi_center, z_center = jnp.radians(lonlat_center[0]), jnp.sin(jnp.radians(lonlat_center[1]))

# Get the diagonal indices
k_c = 3 * nside / 4 * (2 / 3 - z_center + 8 * phi_center / (3 * jnp.pi))
kp_c = nside + 3 * nside / 4 * (2 / 3 - z_center - 8 * phi_center / (3 * jnp.pi))

coords = HealpixCoordinates(nside)
phi, z = coords.phi_z(k_c, kp_c)

assert_allclose(phi, phi_center, atol=1e-3 / nside, rtol=1e-3 / nside)
assert_allclose(z, z_center, atol=1e-3 / nside, rtol=1e-3 / nside)

lonlat_degrees = coords.lonlat_degrees(k_c, kp_c)
assert_allclose(lonlat_degrees, lonlat_center, atol=1e-3 / nside, rtol=1e-3 / nside)

xyz = coords.xyz(k_c, kp_c)
assert_allclose(xyz, xyz_center.T, atol=1e-3 / nside, rtol=1e-3 / nside)


@pytest.mark.parametrize("order", [0, 1, 2, 5, 9, 13, 19])
def test_polar_region(order):
"""Check if the North polar region works good"""
nside = 1 << order

lon_ = jnp.radians(jnp.linspace(0.0, 90 - 1e-3 / nside, 33))
lat_ = jnp.arcsin(jnp.linspace(2 / 3, 1, 33))
lon, lat = jnp.meshgrid(lon_, lat_)

# Use healpix to get the pixel centers
tile_id = hp.ang2pix(nside, jnp.degrees(lon), jnp.degrees(lat), lonlat=True, nest=True)
lonlat_center = hp.pix2ang(nside, tile_id, lonlat=True, nest=True)
xyz_center = hp.ang2vec(*lonlat_center, lonlat=True)
phi_center, z_center = jnp.radians(lonlat_center[0]), jnp.sin(jnp.radians(lonlat_center[1]))

# Get the diagonal indices
i_c = jnp.sqrt(3) * nside * jnp.sqrt(1 - z_center)
j_c = 2 * i_c / jnp.pi * phi_center - 0.5
k_c = j_c + 0.5
kp_c = i_c - j_c - 0.5

coords = HealpixCoordinates.from_order(order)
phi, z = coords.phi_z(k_c, kp_c)

assert_allclose(phi, phi_center, atol=1e-3 / nside, rtol=1e-3 / nside)
assert_allclose(z, z_center, atol=1e-3 / nside, rtol=1e-3 / nside)

lonlat_degrees = coords.lonlat_degrees(k_c, kp_c)
assert_allclose(lonlat_degrees, lonlat_center, atol=1e-3 / nside, rtol=1e-3 / nside)

xyz = coords.xyz(k_c, kp_c)
assert_allclose(xyz, xyz_center.T, atol=1e-3 / nside, rtol=1e-3 / nside)
Loading