diff --git a/s2fft/utils/signal_generator.py b/s2fft/utils/signal_generator.py index 799de9b9..256c135c 100644 --- a/s2fft/utils/signal_generator.py +++ b/s2fft/utils/signal_generator.py @@ -6,6 +6,12 @@ from s2fft.sampling import s2_samples as samples from s2fft.sampling import so3_samples as wigner_samples +TYPE_CHECKING = False +if TYPE_CHECKING: + from types import ModuleType + + import jax + def complex_normal( rng: np.random.Generator, @@ -74,6 +80,7 @@ def generate_flm( spin: int = 0, reality: bool = False, using_torch: bool = False, + size: tuple[int, ...] | int | None = None, ) -> np.ndarray | torch.Tensor: r""" Generate a 2D set of random harmonic coefficients. @@ -94,29 +101,39 @@ def generate_flm( using_torch (bool, optional): Desired frontend functionality. Defaults to False. + size (tuple[int, ...] | int | None, optional): Shape of realisations. + Returns: np.ndarray: Random set of spherical harmonic coefficients. """ - flm = np.zeros(samples.flm_shape(L), dtype=np.complex128) + # always turn size into a tuple of int + if size is None: + size = () + elif isinstance(size, int): + size = (size,) + elif not (isinstance(size, tuple) and all(isinstance(i, int) for i in size)): + raise TypeError("size must be int or tuple of int") + + flm = np.zeros((*size, *samples.flm_shape(L)), dtype=np.complex128) min_el = max(L_lower, abs(spin)) # m = 0 coefficients are always real - flm[min_el:L, L - 1] = rng.standard_normal(L - min_el) + flm[..., min_el:L, L - 1] = rng.standard_normal((*size, L - min_el)) # Construct arrays of m and el indices for entries in flm corresponding to complex- # valued coefficients (m > 0) el_indices, m_indices = complex_el_and_m_indices(L, min_el) - len_indices = len(m_indices) + rand_size = (*size, len(m_indices)) # Generate independent complex coefficients for positive m - flm[el_indices, L - 1 + m_indices] = complex_normal(rng, len_indices, var=2) + flm[..., el_indices, L - 1 + m_indices] = complex_normal(rng, rand_size, var=2) if reality: # Real-valued signal so set complex coefficients for negative m using conjugate # symmetry such that flm[el, L - 1 - m] = (-1)**m * flm[el, L - 1 + m].conj - flm[el_indices, L - 1 - m_indices] = (-1) ** m_indices * ( - flm[el_indices, L - 1 + m_indices].conj() + flm[..., el_indices, L - 1 - m_indices] = (-1) ** m_indices * ( + flm[..., el_indices, L - 1 + m_indices].conj() ) else: # Non-real signal so generate independent complex coefficients for negative m - flm[el_indices, L - 1 - m_indices] = complex_normal(rng, len_indices, var=2) + flm[..., el_indices, L - 1 - m_indices] = complex_normal(rng, rand_size, var=2) return torch.from_numpy(flm) if using_torch else flm @@ -199,3 +216,71 @@ def generate_flmn( rng, len_indices, var=2 ) return torch.from_numpy(flmn) if using_torch else flmn + + +def _get_array_namespace(obj: np.ndarray | jax.Array) -> ModuleType: + """Return the correct array namespace for numpy or jax arrays.""" + from sys import modules + + if (numpy := modules.get("numpy")) and isinstance(obj, numpy.ndarray): + return numpy + if (jax := modules.get("jax")) and isinstance(obj, jax.Array): + return jax.numpy + raise TypeError(f"unknown array type: {type(obj)!r}") + + +def generate_flm_from_spectra( + rng: np.random.Generator, + spectra: np.ndarray | jax.Array, +) -> np.ndarray | jax.Array: + r""" + Generate a stack of random harmonic coefficients from power spectra. + + The input power spectra must be a stack of shape *(K, K, L)* where + *K* is the number of fields to be sampled, and *L* is the harmonic + band-limit. + + Args: + rng (Generator): Random number generator. + + spectra (np.ndarray | jax.Array): Stack of angular power spectra. + + Returns: + np.ndarray | jax.Array: A stack of random spherical harmonic + coefficients with the given power spectra. + + """ + # get an ArrayAPI-ish namespace from spectra + xp = _get_array_namespace(spectra) + + # check input + if spectra.ndim != 3 or spectra.shape[0] != spectra.shape[1]: + raise ValueError("shape of spectra must be (K, K, L)") + + # K is the number of fields, L is the band limit + *_, K, L = spectra.shape + + # permute shape (K, K, L) -> (L, K, K) + spectra = xp.permute_dims(spectra, (2, 0, 1)) + + # SVD for matrix square root + # not using cholesky() here because matrix may be semi-definite + # divides spectra by 2 for correct amplitude + u, s, vh = xp.linalg.svd(spectra / 2, full_matrices=False) + + # compute the matrix square root for sampling + a = u @ (xp.sqrt(s[..., None]) * vh) + + # permute shape (L, K, K) -> (K, K, L) + a = xp.permute_dims(a, (1, 2, 0)) + + # sample the random coefficients + # always use reality=True; one can assemble complex fields from them + # shape of flm is (K, L, M) + flm = generate_flm(rng, L, reality=True, size=K) + + # compute the matrix multiplication by hand, because we have a mix of + # contraction (dim=K) and broadcasting (dim=L) + flm = (a[..., None] * flm).sum(axis=-3) + + return flm diff --git a/tests/test_signal_generator.py b/tests/test_signal_generator.py index 4f9be7d0..dbf64abf 100644 --- a/tests/test_signal_generator.py +++ b/tests/test_signal_generator.py @@ -1,5 +1,7 @@ +import jax.numpy as jnp import numpy as np import pytest +from jax.test_util import check_grads import s2fft import s2fft.sampling as smp @@ -55,6 +57,14 @@ def check_flm_conjugate_symmetry(flm, L, min_el): assert flm[el, L - 1 - m] == (-1) ** m * flm[el, L - 1 + m].conj() +def check_flm_unequal(flm1, flm2, L, min_el): + """assert that two passed flm are elementwise unequal""" + for el in range(L): + for m in range(L): + if not (el < min_el or m > el): + assert flm1[el, L - 1 + m] != flm2[el, L - 1 - m] + + @pytest.mark.parametrize("L", L_values_to_test) @pytest.mark.parametrize("L_lower", L_lower_to_test) @pytest.mark.parametrize("spin", spin_to_test) @@ -76,6 +86,26 @@ def test_generate_flm(rng, L, L_lower, spin, reality): assert np.allclose(f_complex.real, f_real) +@pytest.mark.parametrize("L", L_values_to_test) +@pytest.mark.parametrize("L_lower", L_lower_to_test) +@pytest.mark.parametrize("spin", spin_to_test) +@pytest.mark.parametrize("reality", reality_values_to_test) +def test_generate_flm_size(rng, L, L_lower, spin, reality): + if reality and spin != 0: + pytest.skip("Reality only valid for scalar fields (spin=0).") + + size = 2 + flm = gen.generate_flm(rng, L, L_lower, spin, reality, size=size) + assert flm.shape == (size,) + smp.s2_samples.flm_shape(L) + check_flm_zeros(flm[0], L, max(L_lower, abs(spin))) + check_flm_zeros(flm[1], L, max(L_lower, abs(spin))) + check_flm_unequal(flm[0], flm[1], L, max(L_lower, abs(spin))) + + size = (3, 4) + flm = gen.generate_flm(rng, L, L_lower, spin, reality, size=size) + assert flm.shape == size + smp.s2_samples.flm_shape(L) + + def check_flmn_zeros(flmn, L, N, L_lower): for n in range(-N + 1, N): min_el = max(L_lower, abs(n)) @@ -117,3 +147,92 @@ def test_generate_flmn(rng, L, N, L_lower, reality): assert np.allclose(f_complex.imag, 0) f_real = s2fft.wigner.inverse(flmn, L, N, reality=True, L_lower=L_lower) assert np.allclose(f_complex.real, f_real) + + +def gaussian_covariance(spectra): + """Gaussian covariance for a stack of spectra. + + If the shape of *spectra* is *(K, K, L)*, the shape of the + covariance is *(L, C, C)*, where ``C = K * (K + 1) // 2`` + is the number of independent spectra. + + """ + _, K, L = spectra.shape + row, col = np.tril_indices(K) + cov = np.zeros((L, row.size, col.size)) + ell = np.arange(L) + for i, j in np.ndindex(row.size, col.size): + cov[:, i, j] = ( + spectra[row[i], row[j]] * spectra[col[i], col[j]] + + spectra[row[i], col[j]] * spectra[col[i], row[j]] + ) / (2 * ell + 1) + return cov + + +@pytest.mark.parametrize("L", L_values_to_test) +@pytest.mark.parametrize("xp", [np, jnp]) +def test_generate_flm_from_spectra(rng, L, xp): + # number of fields to generate + K = 4 + + # correlation matrix for fields, applied to all ell + corr = xp.asarray( + [ + [1.0, 0.1, -0.1, 0.1], + [0.1, 1.0, 0.1, -0.1], + [-0.1, 0.1, 1.0, 0.1], + [0.1, -0.1, 0.1, 1.0], + ], + ) + + ell = xp.arange(L) + + # auto-spectra are power laws + powers = xp.arange(1, K + 1) + auto = 1 / (2 * ell + 1) ** powers[:, None] + + # compute the spectra from auto and corr + spectra = xp.sqrt(auto[:, None, :] * auto[None, :, :]) * corr[:, :, None] + assert spectra.shape == (K, K, L) + + # generate random flm from spectra + flm = s2fft.utils.signal_generator.generate_flm_from_spectra(rng, spectra) + assert flm.shape == (K, L, 2 * L - 1) + + # compute the realised spectra + re, im = flm.real, flm.imag + result = ( + re[None, :, :, :] * re[:, None, :, :] + im[None, :, :, :] * im[:, None, :, :] + ) + result = result.sum(axis=-1) / (2 * ell + 1) + + # compute covariance of sampled spectra + cov = gaussian_covariance(spectra) + + # data vector, remove duplicate entries, and put L dim first + x = result - spectra + x = x[np.tril_indices(K)] + x = x.T + + # compute chi2/n of realised spectra + y = xp.linalg.solve(cov, x[..., None])[..., 0] + n = x.size + chi2_n = (x * y).sum() / n + + # make sure chi2/n is as expected + sigma = np.sqrt(2 / n) + assert np.fabs(chi2_n - 1.0) < 3 * sigma + + +@pytest.mark.parametrize("L", L_values_to_test) +def test_generate_flm_from_spectra_grads(L): + # fixed set of power spectra + ell = jnp.arange(L) + cl = 1 / (2 * ell + 1) + spectra = cl.reshape(1, 1, L) + + def func(x): + rng = np.random.default_rng(42) + return s2fft.utils.signal_generator.generate_flm_from_spectra(rng, x) + + check_grads(func, (spectra,), 1)