diff --git a/hera_filters/dspec.py b/hera_filters/dspec.py index 29bfb02..ec6b086 100644 --- a/hera_filters/dspec.py +++ b/hera_filters/dspec.py @@ -2820,7 +2820,7 @@ def _kron_matvec( Compute the matrix-vector product (Kronecker structured) for a given vector. Given axis_1_basis (m x i) and axis_2_basis (n x j), this function computes: - Y = axis_1_basis * X * axis_2_basis^T + Y = (axis_1_basis @ X @ axis_2_basis^T) * weights where X is the reshaped version of xs. Parameters: @@ -2828,7 +2828,8 @@ def _kron_matvec( x : np.ndarray Input vector of size (i * j,). weights : np.ndarray - Weighting matrix of size (m, n), applied element-wise. + Weighting matrix of size (m, n), applied element-wise to the + matrix product (axis_1_basis @ X @ axis_2_basis^T). axis_1_basis : np.ndarray Left transformation matrix of size (m, i). axis_2_basis : np.ndarray @@ -2860,7 +2861,7 @@ def _kron_rmatvec( Compute the adjoint matrix-vector product (Kronecker structured). Given axis_1_array (m x i) and axis_2_array (n x j), this function computes: - Y = axis_1_array^T * X * axis_2_array + Y = axis_1_array^T.conj() * X * axis_2_array.conj() where X is the weighted reshaped version of data * weights. Parameters: @@ -2897,7 +2898,7 @@ def sparse_linear_fit_2D( axis_2_basis: np.ndarray, atol: float = 1e-10, btol: float = 1e-10, - iter_lim: int = 500, + iter_lim: int = None, **kwargs ) -> np.ndarray: """ @@ -2907,7 +2908,7 @@ def sparse_linear_fit_2D( the design matrix is represented implicitly as the Kronecker product of `axis_1_basis` and `axis_2_basis`. The solution is computed using `scipy.sparse.linalg.lsqr`. Note the the convergence of the LSQR algorithm is not guaranteed for this problem, and highly - + dependent on the conditioning of the basis/weighting matrices. Parameters: ----------- @@ -2916,16 +2917,16 @@ def sparse_linear_fit_2D( weights : np.ndarray A weight matrix of the same shape as `data`, applied element-wise. axis_1_basis : np.ndarray - Basis basis along the first axis, shape (m, i). + Fitting basis along the first axis, shape (m, i). axis_2_basis : np.ndarray - Basis basis along the second axis, shape (n, j). + Fitting basis along the second axis, shape (n, j). atol, btol : float, optional, default 1e-10 Stopping tolerances for `lsqr`. The algorithm terminates when the ``norm(r) <= atol * norm(A) * norm(x) + btol * norm(b)``, where A is the implicit Kronecker product of `axis_1_basis` and `axis_2_basis`, and b is the flattened `data` array, x is the solution, and r is the residual. iter_lim : int, optional - Maximum number of iterations for `lsqr`, default is 200. + Maximum number of iterations for `lsqr`, default is None **kwargs : dict Additional keyword arguments passed to `scipy.sparse.linalg.lsqr`. @@ -2938,11 +2939,20 @@ def sparse_linear_fit_2D( from sparse.linalg.lsqr. """ if data.shape != weights.shape: - raise ValueError("Shape mismatch: `weights` must have the same shape as `data`.") + raise ValueError( + f"Shape mismatch: `weights` (shape: {weights.shape}) must have the same" \ + f"shape as `data` (shape {data.shape})." + ) if data.shape[0] != axis_1_basis.shape[0]: - raise ValueError("Shape mismatch: `axis_1_basis` must match the first dimension of `data`.") + raise ValueError( + f"Shape mismatch: `axis_1_basis` (shape: {axis_1_basis.shape}) must match" \ + f"the first dimension of `data` (shape: {data.shape})." + ) if data.shape[1] != axis_2_basis.shape[0]: - raise ValueError("Shape mismatch: `axis_2_basis` must match the second dimension of `data`.") + raise ValueError( + f"Shape mismatch: `axis_2_basis` (shape: {axis_2_basis.shape}) must match" \ + f"the second dimension of `data` (shape: {data.shape})." + ) # Define the shape of the implicit A matrix (Kronecker product of basis) full_operator_shape = ( @@ -2999,9 +3009,9 @@ def separable_linear_fit_2D( axis_2_weights : np.ndarray Weights along the second axis, shape `(n,)`. axis_1_basis : np.ndarray - Basis basis along the first axis, shape `(m, i)`. + Fitting basis along the first axis, shape `(m, i)`. axis_2_basis : np.ndarray - Basis basis along the second axis, shape `(n, j)`. + Fitting basis along the second axis, shape `(n, j)`. Returns: -------- @@ -3009,9 +3019,15 @@ def separable_linear_fit_2D( The computed least-squares solution of shape `(i, j)`. """ if data.shape[0] != axis_1_basis.shape[0]: - raise ValueError("Shape mismatch: `axis_1_basis` must match the first dimension of `data`.") + raise ValueError( + f"Shape mismatch: `axis_1_basis` (shape: {axis_1_basis.shape}) must match" \ + f"the first dimension of `data` (shape: {data.shape})." + ) if data.shape[1] != axis_2_basis.shape[0]: - raise ValueError("Shape mismatch: `axis_2_basis` must match the second dimension of `data`.") + raise ValueError( + f"Shape mismatch: `axis_2_basis` (shape: {axis_2_basis.shape}) must match" \ + f"the second dimension of `data` (shape: {data.shape})." + ) # Compute pseudo-inverses for the weighted least-squares operators axis_1_XTXinv = np.linalg.pinv((axis_1_basis.T.conj() * axis_1_weights) @ axis_1_basis) diff --git a/hera_filters/tests/test_dspec.py b/hera_filters/tests/test_dspec.py index 0432180..3da57d3 100644 --- a/hera_filters/tests/test_dspec.py +++ b/hera_filters/tests/test_dspec.py @@ -1378,7 +1378,7 @@ def test_separable_linear_fit_2D(): # By construction, the data is separable in the time and frequency directions # and the flags are also separable. The fit should be able to recover the # true data in the unflagged region. - rng = np.random.default_rng() + rng = np.random.default_rng(42) freq_basis, _ = dspec.dpss_operator(np.linspace(100e6, 200e6, nfreqs), [0], [20e-9], eigenval_cutoff=[1e-12]) time_basis, _ = dspec.dpss_operator(np.linspace(0, ntimes * 10, ntimes), [0], [1e-3], eigenval_cutoff=[1e-12]) time_flags = rng.choice([True, False], p=[0.1, 0.9], size=(ntimes, 1)) @@ -1429,7 +1429,7 @@ def test_sparse_linear_fit_2d(): # By construction, the data is separable in the time and frequency directions # and the flags are also separable. The fit should be able to recover the # true data in the unflagged region. - rng = np.random.default_rng() + rng = np.random.default_rng(42) freq_basis, _ = dspec.dpss_operator(np.linspace(100e6, 200e6, nfreqs), [0], [20e-9], eigenval_cutoff=[1e-12]) time_basis, _ = dspec.dpss_operator(np.linspace(0, ntimes * 10, ntimes), [0], [1e-3], eigenval_cutoff=[1e-12]) time_flags = rng.choice([True, False], p=[0.1, 0.9], size=(ntimes, 1))