Skip to content

Commit

Permalink
Bug in sparse fitter when weights are not binary (#23)
Browse files Browse the repository at this point in the history
* fix sparse solver when weights are not binary

* document test

* seed random number generator

* wrong name for random number generator

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
tyler-a-cox and pre-commit-ci[bot] authored Feb 13, 2025
1 parent f165ec1 commit 1403da3
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 3 deletions.
4 changes: 2 additions & 2 deletions hera_filters/dspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2821,7 +2821,7 @@ def _kron_matvec(
Given axis_1_basis (m x i) and axis_2_basis (n x j), this function computes:
Y = (axis_1_basis @ X @ axis_2_basis^T) * weights
where X is the reshaped version of xs.
where X is the reshaped version of x.
Parameters:
----------
Expand Down Expand Up @@ -2975,7 +2975,7 @@ def sparse_linear_fit_2D(
*_
)= sparse.linalg.lsqr(
A=linear_operator,
b=data.ravel(),
b=(data * weights).ravel(),
atol=atol,
btol=btol,
iter_lim=iter_lim,
Expand Down
43 changes: 42 additions & 1 deletion hera_filters/tests/test_dspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1467,7 +1467,7 @@ def test_sparse_linear_fit_2d():


# Check that the fit closely matches to the separable fit
np.testing.assert_allclose(sol, sol_sparse)
np.testing.assert_allclose(sol, sol_sparse, atol=1e-9, rtol=1e-6)

# Errors should be raised if the data, weights, and bases are not compatible
pytest.raises(
Expand All @@ -1494,3 +1494,44 @@ def test_sparse_linear_fit_2d():
axis_1_basis=time_basis,
axis_2_basis=freq_basis,
)

def test_sparse_linear_fit_2d_non_binary_wgts():
    """Check that the sparse 2D fit handles non-binary weights.

    The data are built as ``time_basis @ x_true @ freq_basis.T``, i.e. they are
    separable along the time and frequency axes by construction, and the flags
    are separable as well. The reference solution comes from
    ``separable_linear_fit_2D`` using binary (flag-only) weights; the sparse
    solver is given the same flags scaled by integer weights along the time
    axis. The test expects both solutions to agree within tolerance.
    """
    ntimes, nfreqs = 100, 50

    # Seeded generator so the flag pattern, data and weights are reproducible.
    rng = np.random.default_rng(42)

    # Basis operators along each axis
    freq_basis, _ = dspec.dpss_operator(
        np.linspace(100e6, 200e6, nfreqs), [0], [20e-9], eigenval_cutoff=[1e-12]
    )
    time_basis, _ = dspec.dpss_operator(
        np.linspace(0, ntimes * 10, ntimes), [0], [1e-3], eigenval_cutoff=[1e-12]
    )

    # Separable flags: roughly 10% of the times and 10% of the channels
    time_flags = rng.choice([True, False], p=[0.1, 0.9], size=(ntimes, 1))
    freq_flags = rng.choice([True, False], p=[0.1, 0.9], size=(1, nfreqs))

    # Noise-free data drawn from the span of the two bases
    x_true = rng.normal(0, 1, size=(time_basis.shape[-1], freq_basis.shape[-1]))
    data = np.dot(time_basis, x_true).dot(freq_basis.T)

    # Binary weights (1 where unflagged, 0 where flagged) for the reference fit
    binary_axis_1_weights = (~time_flags[:, 0]).astype(float)
    binary_axis_2_weights = (~freq_flags[0]).astype(float)

    # Separable, non-binary weights: scale the unflagged times by integers in [1, 9]
    axis_1_weights = binary_axis_1_weights * rng.integers(1, 10, size=(ntimes,))
    wgts = np.outer(axis_1_weights, binary_axis_2_weights)

    # Guard the premise of the test: the weights must not be purely 0/1
    assert np.any(wgts > 1), "weights are expected to be non-binary"

    # Reference solution from the separable fit with binary weights
    sol = dspec.separable_linear_fit_2D(
        data=data,
        axis_1_weights=binary_axis_1_weights,
        axis_2_weights=binary_axis_2_weights,
        axis_1_basis=time_basis,
        axis_2_basis=freq_basis,
    )

    # Sparse fit with the non-binary weights; the metadata output is not needed
    sol_sparse, _ = dspec.sparse_linear_fit_2D(
        data=data,
        weights=wgts,
        axis_1_basis=time_basis,
        axis_2_basis=freq_basis,
    )

    # Check that the sparse fit closely matches the separable fit
    np.testing.assert_allclose(sol, sol_sparse, atol=1e-9, rtol=1e-6)

0 comments on commit 1403da3

Please sign in to comment.