Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add preconditioner to sparse solver #24

Merged
merged 9 commits into from
Feb 14, 2025
47 changes: 47 additions & 0 deletions hera_filters/dspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2899,6 +2899,8 @@ def sparse_linear_fit_2D(
atol: float = 1e-10,
btol: float = 1e-10,
iter_lim: int = None,
precondition_solver: bool = False,
eig_scaling_factor: float = 1e-1,
**kwargs
) -> np.ndarray:
"""
Expand Down Expand Up @@ -2927,6 +2929,22 @@ def sparse_linear_fit_2D(
flattened `data` array, x is the solution, and r is the residual.
iter_lim : int, optional
Maximum number of iterations for `lsqr`, default is None
precondition_solver : bool, optional, default False
If True, the solver will apply a preconditioner to the basis matrices before
solving the least-squares problem. This option is useful when the input weights
are frequency or time dependent and are either very large or very small, or when
the basis matrices are ill-conditioned due to large stretches of zeros.
The preconditioner is computed using the inverse of the regularized Gramian
matrix (X^T W X) of the basis matrices. Prior to computing the inverse, the eigenvalues
of the Gramian matrix are regularized by adding a small value proportional to the smallest
eigenvalue. This helps to stabilize the computation of the inverse. The regularization
factor is computed as the minimum eigenvalue of the Gramian matrix multiplied by the
`eig_scaling_factor` parameter.
eig_scaling_factor : float, optional, default 1e-1
Regularization factor for the eigenvalues of the Gramian matrix. The factor
is computed as the minimum eigenvalue of the Gramian matrix multiplied by
`eig_scaling_factor`. Reasonable values are typically in the range of 1e-1
to 1e-3.
**kwargs : dict
Additional keyword arguments passed to `scipy.sparse.linalg.lsqr`.

Expand Down Expand Up @@ -2960,6 +2978,32 @@ def sparse_linear_fit_2D(
axis_1_basis.shape[-1] * axis_2_basis.shape[-1], # i * j
)

if precondition_solver:
# Compute separate preconditioners for the two axes
# Start by computing separable weights for the two axes
u, s, v = sparse.linalg.svds(weights, k=1)
axis_1_wgts = np.abs(u[:, 0] * np.sqrt(s[0]))
axis_2_wgts = np.abs(v[0] * np.sqrt(s[0]))

# Compute the preconditioner for the first axis
XTX_axis_1 = np.dot(axis_1_basis.T.conj() * axis_1_wgts, axis_1_basis)
eigenval = sparse.linalg.eigs(XTX_axis_1, k=1, which='SR', return_eigenvectors=False)
axis_1_lambda = eigenval.real * eig_scaling_factor
axis_1_pcond = np.linalg.pinv(
XTX_axis_1 + np.eye(XTX_axis_1.shape[0]) * axis_1_lambda
)

# Compute the preconditioner for the second axis
XTX_axis_2 = np.dot(axis_2_basis.T.conj() * axis_2_wgts, axis_2_basis)
eigenval = sparse.linalg.eigs(XTX_axis_2, k=1, which='SR', return_eigenvectors=False)
axis_2_lambda = eigenval.real * eig_scaling_factor
axis_2_pcond = np.linalg.pinv(
XTX_axis_2 + np.eye(XTX_axis_2.shape[0]) * axis_2_lambda
)

axis_1_basis = np.dot(axis_1_basis, axis_1_pcond)
axis_2_basis = np.dot(axis_2_basis, axis_2_pcond)

# Define the implicit LinearOperator representing the Kronecker product
linear_operator = sparse.linalg.LinearOperator(
full_operator_shape,
Expand All @@ -2985,6 +3029,9 @@ def sparse_linear_fit_2D(
# Reshape output
x = x.reshape(axis_1_basis.shape[-1], axis_2_basis.shape[-1])

if precondition_solver:
x = np.dot(axis_1_pcond, x).dot(axis_2_pcond)

return x, meta

def separable_linear_fit_2D(
Expand Down
55 changes: 55 additions & 0 deletions hera_filters/tests/test_dspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1535,3 +1535,58 @@ def test_sparse_linear_fit_2d_non_binary_wgts():

# Check that the fit closely matches to the separable fit
np.testing.assert_allclose(sol, sol_sparse, atol=1e-9, rtol=1e-6)

def test_precondition_sparse_solver():
    """Check that the preconditioned sparse solver agrees with the other solvers.

    The sparse 2D fit is run with and without `precondition_solver` on weights
    that are not separable, and both solutions are compared against the
    separable fit. The preconditioned run must also report a smaller
    `iter_num` than the unpreconditioned run.
    """
    ntimes, nfreqs = 100, 50
    rng = np.random.default_rng(42)

    # Sample grids and the DPSS bases built on them
    freqs = np.linspace(100e6, 200e6, nfreqs)
    times = np.linspace(0, ntimes * 10, ntimes)
    freq_basis, _ = dspec.dpss_operator(
        freqs, [0], [20e-9], eigenval_cutoff=[1e-12]
    )
    time_basis, _ = dspec.dpss_operator(
        times, [0], [1e-3], eigenval_cutoff=[1e-12]
    )

    # Flags are drawn independently along each axis, so they are separable
    flag_options = [True, False]
    flag_probs = [0.1, 0.9]
    time_flags = rng.choice(flag_options, p=flag_probs, size=(ntimes, 1))
    freq_flags = rng.choice(flag_options, p=flag_probs, size=(1, nfreqs))

    # The data is an exact product of the two bases with random coefficients
    n_time_modes = time_basis.shape[-1]
    n_freq_modes = freq_basis.shape[-1]
    x_true = rng.normal(0, 1, size=(n_time_modes, n_freq_modes))
    data = np.dot(time_basis, x_true).dot(freq_basis.T)

    # Binary (0/1) weights derived from the flags along each axis
    time_unflagged = (~time_flags[:, 0]).astype(float)
    freq_unflagged = (~freq_flags[0]).astype(float)

    # Non-binary weights: random integer scaling along the time axis, then a
    # power law in frequency to make the problem more ill-conditioned
    axis_1_weights = time_unflagged * rng.integers(1, 10, size=(ntimes,))
    axis_2_weights = freq_unflagged
    wgts = np.outer(axis_1_weights, axis_2_weights)
    wgts *= (freqs / 150e6) ** -3.5

    # Reference solution: separable fit using only the binary flag weights
    sol = dspec.separable_linear_fit_2D(
        data=data,
        axis_1_weights=time_unflagged,
        axis_2_weights=freq_unflagged,
        axis_1_basis=time_basis,
        axis_2_basis=freq_basis,
    )

    # Sparse fits on the full weights, without and with the preconditioner
    shared_kwargs = dict(
        data=data,
        weights=wgts,
        axis_1_basis=time_basis,
        axis_2_basis=freq_basis,
    )
    sol_sparse, meta = dspec.sparse_linear_fit_2D(
        precondition_solver=False, **shared_kwargs
    )
    sol_sparse_precond, meta_precond = dspec.sparse_linear_fit_2D(
        precondition_solver=True, **shared_kwargs
    )

    # Both sparse solutions must closely match the separable fit
    for candidate in (sol_sparse, sol_sparse_precond):
        np.testing.assert_allclose(sol, candidate, atol=1e-8, rtol=1e-6)

    # The preconditioned solve must take fewer iterations
    np.testing.assert_array_less(meta_precond['iter_num'], meta['iter_num'])