diff --git a/hera_filters/dspec.py b/hera_filters/dspec.py index 0249925..087727e 100644 --- a/hera_filters/dspec.py +++ b/hera_filters/dspec.py @@ -2899,6 +2899,8 @@ def sparse_linear_fit_2D( atol: float = 1e-10, btol: float = 1e-10, iter_lim: int = None, + precondition_solver: bool = False, + eig_scaling_factor: float = 1e-1, **kwargs ) -> np.ndarray: """ @@ -2927,6 +2929,22 @@ def sparse_linear_fit_2D( flattened `data` array, x is the solution, and r is the residual. iter_lim : int, optional Maximum number of iterations for `lsqr`, default is None + precondition_solver : bool, optional, default False + If True, the solver will apply a preconditioner to the basis matrices before + solving the least-squares problem. This option is useful when the input weights + are frequency or time dependent and are either very large or very small, or when + the basis matrices are ill-conditioned due to large stretches of zeros. + The preconditioner is computed using the the inverse of the regularized Gramian + matrix (X^T W X) of the basis matrices. Prior to computing the inverse, the eigenvalues + of the Gramian matrix are regularized by adding a small value proportional to the smallest + eigenvalue. This helps to stabilize the computation of the inverse. The regularization + factor is computed as the minimum eigenvalue of the Gramian matrix multiplied by the + `eig_scaling_factor` parameter. + eig_scaling_factor : float, optional, default 1e-1 + Regularization factor for the eigenvalues of the Gramian matrix. The factor + is computed as the minimum eigenvalue of the Gramian matrix multiplied by + `eig_scaling_factor`. Reasonable values are typically in the range of 1e-1 + to 1e-3. **kwargs : dict Additional keyword arguments passed to `scipy.sparse.linalg.lsqr`. @@ -2960,6 +2978,32 @@ def sparse_linear_fit_2D( axis_1_basis.shape[-1] * axis_2_basis.shape[-1], # i * j ) + if precondition_solver: + # Compute separate preconditioners for the two axes + # Start by computing separable weights for the two axes + u, s, v = sparse.linalg.svds(weights, k=1) + axis_1_wgts = np.abs(u[:, 0] * np.sqrt(s[0])) + axis_2_wgts = np.abs(v[0] * np.sqrt(s[0])) + + # Compute the preconditioner for the first axis + XTX_axis_1 = np.dot(axis_1_basis.T.conj() * axis_1_wgts, axis_1_basis) + eigenval = sparse.linalg.eigs(XTX_axis_1, k=1, which='SM', return_eigenvectors=False) + axis_1_lambda = np.abs(eigenval) * eig_scaling_factor + axis_1_pcond = np.linalg.pinv( + XTX_axis_1 + np.eye(XTX_axis_1.shape[0]) * axis_1_lambda + ) + + # Compute the preconditioner for the second axis + XTX_axis_2 = np.dot(axis_2_basis.T.conj() * axis_2_wgts, axis_2_basis) + eigenval = sparse.linalg.eigs(XTX_axis_2, k=1, which='SM', return_eigenvectors=False) + axis_2_lambda = np.abs(eigenval) * eig_scaling_factor + axis_2_pcond = np.linalg.pinv( + XTX_axis_2 + np.eye(XTX_axis_2.shape[0]) * axis_2_lambda + ) + + axis_1_basis = np.dot(axis_1_basis, axis_1_pcond) + axis_2_basis = np.dot(axis_2_basis, axis_2_pcond) + # Define the implicit LinearOperator representing the Kronecker product linear_operator = sparse.linalg.LinearOperator( full_operator_shape, @@ -2985,6 +3029,9 @@ def sparse_linear_fit_2D( # Reshape output x = x.reshape(axis_1_basis.shape[-1], axis_2_basis.shape[-1]) + if precondition_solver: + x = np.dot(axis_1_pcond, x).dot(axis_2_pcond) + return x, meta def separable_linear_fit_2D( diff --git a/hera_filters/tests/test_dspec.py b/hera_filters/tests/test_dspec.py index 931836c..7c3c0a7 100644 --- a/hera_filters/tests/test_dspec.py +++ b/hera_filters/tests/test_dspec.py @@ -1535,3 +1535,58 @@ def test_sparse_linear_fit_2d_non_binary_wgts(): # Check that the fit closely matches to the separable fit np.testing.assert_allclose(sol, sol_sparse, atol=1e-9, rtol=1e-6) + +def test_precondition_sparse_solver(): + # test that separable linear fit works as expected. + ntimes, nfreqs = 100, 50 + + # Generate some data/flags + # By construction, the data is separable in the time and frequency directions + # and the flags are also separable. The fit should be able to recover the + # true data in the unflagged region. + rng = np.random.default_rng(42) + freq_basis, _ = dspec.dpss_operator(np.linspace(100e6, 200e6, nfreqs), [0], [20e-9], eigenval_cutoff=[1e-12]) + time_basis, _ = dspec.dpss_operator(np.linspace(0, ntimes * 10, ntimes), [0], [1e-3], eigenval_cutoff=[1e-12]) + time_flags = rng.choice([True, False], p=[0.1, 0.9], size=(ntimes, 1)) + freq_flags = rng.choice([True, False], p=[0.1, 0.9], size=(1, nfreqs)) + x_true = rng.normal(0, 1, size=(time_basis.shape[-1], freq_basis.shape[-1])) + data = np.dot(time_basis, x_true).dot(freq_basis.T) + freqs = np.linspace(100e6, 200e6, nfreqs) + + # Generate separable, non-binary weights + axis_1_weights = (~time_flags[:, 0]).astype(float) * rng.integers(1, 10, size=(ntimes,)) + axis_2_weights = (~freq_flags[0]).astype(float) + wgts = np.outer(axis_1_weights, axis_2_weights) + + # Add frequency dependence to the weights to make the problem more ill-conditioned + wgts *= (freqs / 150e6) ** -3.5 + + # Fit the data + sol = dspec.separable_linear_fit_2D( + data=data, + axis_1_weights=(~time_flags[:, 0]).astype(float), + axis_2_weights=(~freq_flags[0]).astype(float), + axis_1_basis=time_basis, + axis_2_basis=freq_basis, + ) + + sol_sparse, meta = dspec.sparse_linear_fit_2D( + data=data, + weights=wgts, + axis_1_basis=time_basis, + axis_2_basis=freq_basis, + precondition_solver=False + ) + + sol_sparse_precond, meta_precond = dspec.sparse_linear_fit_2D( + data=data, + weights=wgts, + axis_1_basis=time_basis, + axis_2_basis=freq_basis, + precondition_solver=True + ) + + # Check that the fit closely matches to the separable fit + np.testing.assert_allclose(sol, sol_sparse, atol=1e-8, rtol=1e-6) + np.testing.assert_allclose(sol, sol_sparse_precond, atol=1e-8, rtol=1e-6) + np.testing.assert_array_less(meta_precond['iter_num'], meta['iter_num'])