Skip to content

Commit

Permalink
fix test and documentation issues
Browse files Browse the repository at this point in the history
  • Loading branch information
tyler-a-cox committed Feb 6, 2025
1 parent 6e53b57 commit 8f4c626
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 17 deletions.
46 changes: 31 additions & 15 deletions hera_filters/dspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2820,15 +2820,16 @@ def _kron_matvec(
Compute the matrix-vector product (Kronecker structured) for a given vector.
Given axis_1_basis (m x i) and axis_2_basis (n x j), this function computes:
Y = axis_1_basis * X * axis_2_basis^T
Y = (axis_1_basis @ X @ axis_2_basis^T) * weights
where X is the reshaped version of xs.
Parameters:
----------
x : np.ndarray
Input vector of size (i * j,).
weights : np.ndarray
Weighting matrix of size (m, n), applied element-wise.
Weighting matrix of size (m, n), applied element-wise to the
matrix product (axis_1_basis @ X @ axis_2_basis^T).
axis_1_basis : np.ndarray
Left transformation matrix of size (m, i).
axis_2_basis : np.ndarray
Expand Down Expand Up @@ -2860,7 +2861,7 @@ def _kron_rmatvec(
Compute the adjoint matrix-vector product (Kronecker structured).
Given axis_1_array (m x i) and axis_2_array (n x j), this function computes:
Y = axis_1_array^T * X * axis_2_array
Y = axis_1_array^T.conj() * X * axis_2_array.conj()
where X is the weighted reshaped version of data * weights.
Parameters:
Expand Down Expand Up @@ -2897,7 +2898,7 @@ def sparse_linear_fit_2D(
axis_2_basis: np.ndarray,
atol: float = 1e-10,
btol: float = 1e-10,
iter_lim: int = 500,
iter_lim: int = None,
**kwargs
) -> np.ndarray:
"""
Expand All @@ -2907,7 +2908,7 @@ def sparse_linear_fit_2D(
the design matrix is represented implicitly as the Kronecker product of `axis_1_basis`
and `axis_2_basis`. The solution is computed using `scipy.sparse.linalg.lsqr`. Note the
the convergence of the LSQR algorithm is not guaranteed for this problem, and highly
dependent on the conditioning of the basis/weighting matrices.
Parameters:
-----------
Expand All @@ -2916,16 +2917,16 @@ def sparse_linear_fit_2D(
weights : np.ndarray
A weight matrix of the same shape as `data`, applied element-wise.
axis_1_basis : np.ndarray
Basis basis along the first axis, shape (m, i).
Fitting basis along the first axis, shape (m, i).
axis_2_basis : np.ndarray
Basis basis along the second axis, shape (n, j).
Fitting basis along the second axis, shape (n, j).
atol, btol : float, optional, default 1e-10
Stopping tolerances for `lsqr`. The algorithm terminates when the
``norm(r) <= atol * norm(A) * norm(x) + btol * norm(b)``, where A is the
implicit Kronecker product of `axis_1_basis` and `axis_2_basis`, and b is the
flattened `data` array, x is the solution, and r is the residual.
iter_lim : int, optional
Maximum number of iterations for `lsqr`, default is 200.
Maximum number of iterations for `lsqr`, default is None
**kwargs : dict
Additional keyword arguments passed to `scipy.sparse.linalg.lsqr`.
Expand All @@ -2938,11 +2939,20 @@ def sparse_linear_fit_2D(
from sparse.linalg.lsqr.
"""
if data.shape != weights.shape:
raise ValueError("Shape mismatch: `weights` must have the same shape as `data`.")
raise ValueError(
f"Shape mismatch: `weights` (shape: {weights.shape}) must have the same" \
f"shape as `data` (shape {data.shape})."
)
if data.shape[0] != axis_1_basis.shape[0]:
raise ValueError("Shape mismatch: `axis_1_basis` must match the first dimension of `data`.")
raise ValueError(
f"Shape mismatch: `axis_1_basis` (shape: {axis_1_basis.shape}) must match" \
f"the first dimension of `data` (shape: {data.shape})."
)
if data.shape[1] != axis_2_basis.shape[0]:
raise ValueError("Shape mismatch: `axis_2_basis` must match the second dimension of `data`.")
raise ValueError(
f"Shape mismatch: `axis_2_basis` (shape: {axis_2_basis.shape}) must match" \
f"the second dimension of `data` (shape: {data.shape})."
)

# Define the shape of the implicit A matrix (Kronecker product of basis)
full_operator_shape = (
Expand Down Expand Up @@ -2999,19 +3009,25 @@ def separable_linear_fit_2D(
axis_2_weights : np.ndarray
Weights along the second axis, shape `(n,)`.
axis_1_basis : np.ndarray
Basis basis along the first axis, shape `(m, i)`.
Fitting basis along the first axis, shape `(m, i)`.
axis_2_basis : np.ndarray
Basis basis along the second axis, shape `(n, j)`.
Fitting basis along the second axis, shape `(n, j)`.
Returns:
--------
x : np.ndarray
The computed least-squares solution of shape `(i, j)`.
"""
if data.shape[0] != axis_1_basis.shape[0]:
raise ValueError("Shape mismatch: `axis_1_basis` must match the first dimension of `data`.")
raise ValueError(
f"Shape mismatch: `axis_1_basis` (shape: {axis_1_basis.shape}) must match" \
f"the first dimension of `data` (shape: {data.shape})."
)
if data.shape[1] != axis_2_basis.shape[0]:
raise ValueError("Shape mismatch: `axis_2_basis` must match the second dimension of `data`.")
raise ValueError(
f"Shape mismatch: `axis_2_basis` (shape: {axis_2_basis.shape}) must match" \
f"the second dimension of `data` (shape: {data.shape})."
)

# Compute pseudo-inverses for the weighted least-squares operators
axis_1_XTXinv = np.linalg.pinv((axis_1_basis.T.conj() * axis_1_weights) @ axis_1_basis)
Expand Down
4 changes: 2 additions & 2 deletions hera_filters/tests/test_dspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1378,7 +1378,7 @@ def test_separable_linear_fit_2D():
# By construction, the data is separable in the time and frequency directions
# and the flags are also separable. The fit should be able to recover the
# true data in the unflagged region.
rng = np.random.default_rng()
rng = np.random.default_rng(42)
freq_basis, _ = dspec.dpss_operator(np.linspace(100e6, 200e6, nfreqs), [0], [20e-9], eigenval_cutoff=[1e-12])
time_basis, _ = dspec.dpss_operator(np.linspace(0, ntimes * 10, ntimes), [0], [1e-3], eigenval_cutoff=[1e-12])
time_flags = rng.choice([True, False], p=[0.1, 0.9], size=(ntimes, 1))
Expand Down Expand Up @@ -1429,7 +1429,7 @@ def test_sparse_linear_fit_2d():
# By construction, the data is separable in the time and frequency directions
# and the flags are also separable. The fit should be able to recover the
# true data in the unflagged region.
rng = np.random.default_rng()
rng = np.random.default_rng(42)
freq_basis, _ = dspec.dpss_operator(np.linspace(100e6, 200e6, nfreqs), [0], [20e-9], eigenval_cutoff=[1e-12])
time_basis, _ = dspec.dpss_operator(np.linspace(0, ntimes * 10, ntimes), [0], [1e-3], eigenval_cutoff=[1e-12])
time_flags = rng.choice([True, False], p=[0.1, 0.9], size=(ntimes, 1))
Expand Down

0 comments on commit 8f4c626

Please sign in to comment.