Skip to content

Gaussian random coefficients from power spectra #296

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 92 additions & 7 deletions s2fft/utils/signal_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,12 @@
from s2fft.sampling import s2_samples as samples
from s2fft.sampling import so3_samples as wigner_samples

TYPE_CHECKING = False
if TYPE_CHECKING:
from types import ModuleType

import jax


def complex_normal(
rng: np.random.Generator,
Expand Down Expand Up @@ -74,6 +80,7 @@ def generate_flm(
spin: int = 0,
reality: bool = False,
using_torch: bool = False,
size: tuple[int, ...] | int | None = None,
) -> np.ndarray | torch.Tensor:
r"""
Generate a 2D set of random harmonic coefficients.
Expand All @@ -94,29 +101,39 @@ def generate_flm(

using_torch (bool, optional): Desired frontend functionality. Defaults to False.

size (tuple[int, ...] | int | None, optional): Shape of realisations.

Returns:
np.ndarray: Random set of spherical harmonic coefficients.

"""
flm = np.zeros(samples.flm_shape(L), dtype=np.complex128)
# always turn size into a tuple of int
if size is None:
size = ()
elif isinstance(size, int):
size = (size,)
elif not (isinstance(size, tuple) and all(isinstance(i, int) for i in size)):
raise TypeError("size must be int or tuple of int")

flm = np.zeros((*size, *samples.flm_shape(L)), dtype=np.complex128)
min_el = max(L_lower, abs(spin))
# m = 0 coefficients are always real
flm[min_el:L, L - 1] = rng.standard_normal(L - min_el)
flm[..., min_el:L, L - 1] = rng.standard_normal((*size, L - min_el))
# Construct arrays of m and el indices for entries in flm corresponding to complex-
# valued coefficients (m > 0)
el_indices, m_indices = complex_el_and_m_indices(L, min_el)
len_indices = len(m_indices)
rand_size = (*size, len(m_indices))
# Generate independent complex coefficients for positive m
flm[el_indices, L - 1 + m_indices] = complex_normal(rng, len_indices, var=2)
flm[..., el_indices, L - 1 + m_indices] = complex_normal(rng, rand_size, var=2)
if reality:
# Real-valued signal so set complex coefficients for negative m using conjugate
# symmetry such that flm[el, L - 1 - m] = (-1)**m * flm[el, L - 1 + m].conj
flm[el_indices, L - 1 - m_indices] = (-1) ** m_indices * (
flm[el_indices, L - 1 + m_indices].conj()
flm[..., el_indices, L - 1 - m_indices] = (-1) ** m_indices * (
flm[..., el_indices, L - 1 + m_indices].conj()
)
else:
# Non-real signal so generate independent complex coefficients for negative m
flm[el_indices, L - 1 - m_indices] = complex_normal(rng, len_indices, var=2)
flm[..., el_indices, L - 1 - m_indices] = complex_normal(rng, rand_size, var=2)
return torch.from_numpy(flm) if using_torch else flm


Expand Down Expand Up @@ -199,3 +216,71 @@ def generate_flmn(
rng, len_indices, var=2
)
return torch.from_numpy(flmn) if using_torch else flmn


def _get_array_namespace(obj: np.ndarray | jax.Array) -> ModuleType:
"""Return the correct array namespace for numpy or jax arrays."""
from sys import modules

if (numpy := modules.get("numpy")) and isinstance(obj, numpy.ndarray):
return numpy
if (jax := modules.get("jax")) and isinstance(obj, jax.Array):
return jax.numpy
raise TypeError(f"unknown array type: {type(obj)!r}")


def generate_flm_from_spectra(
rng: np.random.Generator,
spectra: np.ndarray | jax.Array,
) -> np.ndarray | jax.Array:
r"""
Generate a stack of random harmonic coefficients from power spectra.

The input power spectra must be a stack of shape *(K, K, L)* where
*K* is the number of fields to be sampled, and *L* is the harmonic
band-limit.

Args:
rng (Generator): Random number generator.

spectra (np.ndarray | jax.Array): Stack of angular power spectra.

Returns:
np.ndarray | jax.Array: A stack of random spherical harmonic
coefficients with the given power spectra.

"""
# get an ArrayAPI-ish namespace from spectra
xp = _get_array_namespace(spectra)

# check input
if spectra.ndim != 3 or spectra.shape[0] != spectra.shape[1]:
raise ValueError("shape of spectra must be (K, K, L)")

# K is the number of fields, L is the band limit
*_, K, L = spectra.shape

# permute shape (K, K, L) -> (L, K, K)
spectra = xp.permute_dims(spectra, (2, 0, 1))

# SVD for matrix square root
# not using cholesky() here because matrix may be semi-definite
# divides spectra by 2 for correct amplitude
u, s, vh = xp.linalg.svd(spectra / 2, full_matrices=False)

# compute the matrix square root for sampling
a = u @ (xp.sqrt(s[..., None]) * vh)

# permute shape (L, K, K) -> (K, K, L)
a = xp.permute_dims(a, (1, 2, 0))

# sample the random coefficients
# always use reality=True; one can assemble complex fields from them
# shape of flm is (K, L, M)
flm = generate_flm(rng, L, reality=True, size=K)

# compute the matrix multiplication by hand, because we have a mix of
# contraction (dim=K) and broadcasting (dim=L)
flm = (a[..., None] * flm).sum(axis=-3)

return flm
119 changes: 119 additions & 0 deletions tests/test_signal_generator.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import jax.numpy as jnp
import numpy as np
import pytest
from jax.test_util import check_grads

import s2fft
import s2fft.sampling as smp
Expand Down Expand Up @@ -55,6 +57,14 @@ def check_flm_conjugate_symmetry(flm, L, min_el):
assert flm[el, L - 1 - m] == (-1) ** m * flm[el, L - 1 + m].conj()


def check_flm_unequal(flm1, flm2, L, min_el):
"""assert that two passed flm are elementwise unequal"""
for el in range(L):
for m in range(L):
if not (el < min_el or m > el):
assert flm1[el, L - 1 + m] != flm2[el, L - 1 - m]


@pytest.mark.parametrize("L", L_values_to_test)
@pytest.mark.parametrize("L_lower", L_lower_to_test)
@pytest.mark.parametrize("spin", spin_to_test)
Expand All @@ -76,6 +86,26 @@ def test_generate_flm(rng, L, L_lower, spin, reality):
assert np.allclose(f_complex.real, f_real)


@pytest.mark.parametrize("L", L_values_to_test)
@pytest.mark.parametrize("L_lower", L_lower_to_test)
@pytest.mark.parametrize("spin", spin_to_test)
@pytest.mark.parametrize("reality", reality_values_to_test)
def test_generate_flm_size(rng, L, L_lower, spin, reality):
if reality and spin != 0:
pytest.skip("Reality only valid for scalar fields (spin=0).")

size = 2
flm = gen.generate_flm(rng, L, L_lower, spin, reality, size=size)
assert flm.shape == (size,) + smp.s2_samples.flm_shape(L)
check_flm_zeros(flm[0], L, max(L_lower, abs(spin)))
check_flm_zeros(flm[1], L, max(L_lower, abs(spin)))
check_flm_unequal(flm[0], flm[1], L, max(L_lower, abs(spin)))

size = (3, 4)
flm = gen.generate_flm(rng, L, L_lower, spin, reality, size=size)
assert flm.shape == size + smp.s2_samples.flm_shape(L)


def check_flmn_zeros(flmn, L, N, L_lower):
for n in range(-N + 1, N):
min_el = max(L_lower, abs(n))
Expand Down Expand Up @@ -117,3 +147,92 @@ def test_generate_flmn(rng, L, N, L_lower, reality):
assert np.allclose(f_complex.imag, 0)
f_real = s2fft.wigner.inverse(flmn, L, N, reality=True, L_lower=L_lower)
assert np.allclose(f_complex.real, f_real)


def gaussian_covariance(spectra):
"""Gaussian covariance for a stack of spectra.

If the shape of *spectra* is *(K, K, L)*, the shape of the
covariance is *(L, C, C)*, where ``C = K * (K + 1) // 2``
is the number of independent spectra.

"""
_, K, L = spectra.shape
row, col = np.tril_indices(K)
cov = np.zeros((L, row.size, col.size))
ell = np.arange(L)
for i, j in np.ndindex(row.size, col.size):
cov[:, i, j] = (
spectra[row[i], row[j]] * spectra[col[i], col[j]]
+ spectra[row[i], col[j]] * spectra[col[i], row[j]]
) / (2 * ell + 1)
return cov


@pytest.mark.parametrize("L", L_values_to_test)
@pytest.mark.parametrize("xp", [np, jnp])
def test_generate_flm_from_spectra(rng, L, xp):
# number of fields to generate
K = 4

# correlation matrix for fields, applied to all ell
corr = xp.asarray(
[
[1.0, 0.1, -0.1, 0.1],
[0.1, 1.0, 0.1, -0.1],
[-0.1, 0.1, 1.0, 0.1],
[0.1, -0.1, 0.1, 1.0],
],
)

ell = xp.arange(L)

# auto-spectra are power laws
powers = xp.arange(1, K + 1)
auto = 1 / (2 * ell + 1) ** powers[:, None]

# compute the spectra from auto and corr
spectra = xp.sqrt(auto[:, None, :] * auto[None, :, :]) * corr[:, :, None]
assert spectra.shape == (K, K, L)

# generate random flm from spectra
flm = s2fft.utils.signal_generator.generate_flm_from_spectra(rng, spectra)
assert flm.shape == (K, L, 2 * L - 1)

# compute the realised spectra
re, im = flm.real, flm.imag
result = (
re[None, :, :, :] * re[:, None, :, :] + im[None, :, :, :] * im[:, None, :, :]
)
result = result.sum(axis=-1) / (2 * ell + 1)

# compute covariance of sampled spectra
cov = gaussian_covariance(spectra)

# data vector, remove duplicate entries, and put L dim first
x = result - spectra
x = x[np.tril_indices(K)]
x = x.T

# compute chi2/n of realised spectra
y = xp.linalg.solve(cov, x[..., None])[..., 0]
n = x.size
chi2_n = (x * y).sum() / n

# make sure chi2/n is as expected
sigma = np.sqrt(2 / n)
assert np.fabs(chi2_n - 1.0) < 3 * sigma


@pytest.mark.parametrize("L", L_values_to_test)
def test_generate_flm_from_spectra_grads(L):
# fixed set of power spectra
ell = jnp.arange(L)
cl = 1 / (2 * ell + 1)
spectra = cl.reshape(1, 1, L)

def func(x):
rng = np.random.default_rng(42)
return s2fft.utils.signal_generator.generate_flm_from_spectra(rng, x)

check_grads(func, (spectra,), 1)
Loading