diff --git a/hera_filters/dspec.py b/hera_filters/dspec.py index f52461b..29bfb02 100644 --- a/hera_filters/dspec.py +++ b/hera_filters/dspec.py @@ -2811,9 +2811,9 @@ def dayenu_mat_inv(x, filter_centers, filter_half_widths, return sdwi_mat def _kron_matvec( - x: np.ndarray, - weights: np.ndarray, - axis_1_basis: np.ndarray, + x: np.ndarray, + weights: np.ndarray, + axis_1_basis: np.ndarray, axis_2_basis: np.ndarray )-> np.ndarray: """ @@ -2840,20 +2840,20 @@ def _kron_matvec( Flattened result of size (m * n,). """ i, j = axis_1_basis.shape[1], axis_2_basis.shape[1] - + # Reshape v into (m, n) matrix using Fortran order to match vectorization X = x.reshape((i, j)) - + # Compute the transformation Y = (axis_1_basis @ X) @ axis_2_basis.T - + # Apply the weight W and return flattened result return (Y * weights).ravel() def _kron_rmatvec( - data: np.ndarray, - weights: np.ndarray, - axis_1_basis: np.ndarray, + data: np.ndarray, + weights: np.ndarray, + axis_1_basis: np.ndarray, axis_2_basis: np.ndarray ) -> np.ndarray: """ @@ -2880,13 +2880,13 @@ def _kron_rmatvec( Flattened result of size (i * j,). """ m, n = axis_1_basis.shape[0], axis_2_basis.shape[0] - + # Reshape u into (m, n) matrix using Fortran order and apply W X = (data.reshape((m, n))) * weights - + # Compute the transformation Y = (axis_1_basis.T.conj() @ X) @ axis_2_basis.conj() - + # Return flattened result return Y.ravel() @@ -2920,7 +2920,7 @@ def sparse_linear_fit_2D( axis_2_basis : np.ndarray Basis basis along the second axis, shape (n, j). atol, btol : float, optional, default 1e-10 - Stopping tolerances for `lsqr`. The algorithm terminates when the + Stopping tolerances for `lsqr`. The algorithm terminates when the ``norm(r) <= atol * norm(A) * norm(x) + btol * norm(b)``, where A is the implicit Kronecker product of `axis_1_basis` and `axis_2_basis`, and b is the flattened `data` array, x is the solution, and r is the residual. @@ -2959,10 +2959,10 @@ def sparse_linear_fit_2D( meta = {} # Solve the least-squares problem using LSQR ( - x, - meta['istop'], - meta['iter_num'], - *_ + x, + meta['istop'], + meta['iter_num'], + *_ )= sparse.linalg.lsqr( A=linear_operator, b=data.ravel(), @@ -2987,7 +2987,7 @@ def separable_linear_fit_2D( """ Solves a separable linear least-squares problem using weighted basis basis. - This function fits the input `data` using a least-squares approach with + This function fits the input `data` using a least-squares approach with separable weighting along two axes. The solution is computed using pseudo-inverses. Parameters: @@ -3024,4 +3024,4 @@ def separable_linear_fit_2D( # Compute the final solution x = axis_1_operator @ data @ axis_2_operator.T - return x \ No newline at end of file + return x diff --git a/hera_filters/tests/test_dspec.py b/hera_filters/tests/test_dspec.py index 4aa067a..0432180 100644 --- a/hera_filters/tests/test_dspec.py +++ b/hera_filters/tests/test_dspec.py @@ -1402,8 +1402,8 @@ def test_separable_linear_fit_2D(): # Errors should be raised if the data, weights, and bases are not compatible pytest.raises( - ValueError, - dspec.separable_linear_fit_2D, + ValueError, + dspec.separable_linear_fit_2D, data=data, axis_1_weights=(~time_flags[:, 0]).astype(float), axis_2_weights=(~freq_flags[0]).astype(float), @@ -1412,8 +1412,8 @@ def test_separable_linear_fit_2D(): ) pytest.raises( - ValueError, - dspec.separable_linear_fit_2D, + ValueError, + dspec.separable_linear_fit_2D, data=data, axis_1_weights=(~time_flags[:, 0]).astype(float), axis_2_weights=(~freq_flags[0]).astype(float), @@ -1471,26 +1471,26 @@ def test_sparse_linear_fit_2d(): # Errors should be raised if the data, weights, and bases are not compatible pytest.raises( - ValueError, - dspec.sparse_linear_fit_2D, - data=data, + ValueError, + dspec.sparse_linear_fit_2D, + data=data, weights=(~flags).astype(float), axis_1_basis=time_basis, axis_2_basis=time_basis, ) pytest.raises( - ValueError, - dspec.sparse_linear_fit_2D, - data=data, + ValueError, + dspec.sparse_linear_fit_2D, + data=data, weights=(~flags).astype(float), axis_1_basis=freq_basis, axis_2_basis=freq_basis, ) pytest.raises( - ValueError, - dspec.sparse_linear_fit_2D, - data=data, + ValueError, + dspec.sparse_linear_fit_2D, + data=data, weights=(~flags).astype(float).T, axis_1_basis=time_basis, axis_2_basis=freq_basis, - ) \ No newline at end of file + )