Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files Browse the repository at this point in the history
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 5, 2025
1 parent 5d92436 commit d908630
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 33 deletions.
38 changes: 19 additions & 19 deletions hera_filters/dspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -2811,9 +2811,9 @@ def dayenu_mat_inv(x, filter_centers, filter_half_widths,
return sdwi_mat

def _kron_matvec(
x: np.ndarray,
weights: np.ndarray,
axis_1_basis: np.ndarray,
x: np.ndarray,
weights: np.ndarray,
axis_1_basis: np.ndarray,
axis_2_basis: np.ndarray
)-> np.ndarray:
"""
Expand All @@ -2840,20 +2840,20 @@ def _kron_matvec(
Flattened result of size (m * n,).
"""
i, j = axis_1_basis.shape[1], axis_2_basis.shape[1]

# Reshape v into (m, n) matrix using Fortran order to match vectorization
X = x.reshape((i, j))

# Compute the transformation
Y = (axis_1_basis @ X) @ axis_2_basis.T

# Apply the weight W and return flattened result
return (Y * weights).ravel()

def _kron_rmatvec(
data: np.ndarray,
weights: np.ndarray,
axis_1_basis: np.ndarray,
data: np.ndarray,
weights: np.ndarray,
axis_1_basis: np.ndarray,
axis_2_basis: np.ndarray
) -> np.ndarray:
"""
Expand All @@ -2880,13 +2880,13 @@ def _kron_rmatvec(
Flattened result of size (i * j,).
"""
m, n = axis_1_basis.shape[0], axis_2_basis.shape[0]

# Reshape u into (m, n) matrix using Fortran order and apply W
X = (data.reshape((m, n))) * weights

# Compute the transformation
Y = (axis_1_basis.T.conj() @ X) @ axis_2_basis.conj()

# Return flattened result
return Y.ravel()

Expand Down Expand Up @@ -2920,7 +2920,7 @@ def sparse_linear_fit_2D(
axis_2_basis : np.ndarray
Basis basis along the second axis, shape (n, j).
atol, btol : float, optional, default 1e-10
Stopping tolerances for `lsqr`. The algorithm terminates when the
Stopping tolerances for `lsqr`. The algorithm terminates when the
``norm(r) <= atol * norm(A) * norm(x) + btol * norm(b)``, where A is the
implicit Kronecker product of `axis_1_basis` and `axis_2_basis`, and b is the
flattened `data` array, x is the solution, and r is the residual.
Expand Down Expand Up @@ -2959,10 +2959,10 @@ def sparse_linear_fit_2D(
meta = {}
# Solve the least-squares problem using LSQR
(
x,
meta['istop'],
meta['iter_num'],
*_
x,
meta['istop'],
meta['iter_num'],
*_
)= sparse.linalg.lsqr(
A=linear_operator,
b=data.ravel(),
Expand All @@ -2987,7 +2987,7 @@ def separable_linear_fit_2D(
"""
Solves a separable linear least-squares problem using weighted basis basis.
This function fits the input `data` using a least-squares approach with
This function fits the input `data` using a least-squares approach with
separable weighting along two axes. The solution is computed using pseudo-inverses.
Parameters:
Expand Down Expand Up @@ -3024,4 +3024,4 @@ def separable_linear_fit_2D(
# Compute the final solution
x = axis_1_operator @ data @ axis_2_operator.T

return x
return x
28 changes: 14 additions & 14 deletions hera_filters/tests/test_dspec.py
Original file line number Diff line number Diff line change
Expand Up @@ -1402,8 +1402,8 @@ def test_separable_linear_fit_2D():

# Errors should be raised if the data, weights, and bases are not compatible
pytest.raises(
ValueError,
dspec.separable_linear_fit_2D,
ValueError,
dspec.separable_linear_fit_2D,
data=data,
axis_1_weights=(~time_flags[:, 0]).astype(float),
axis_2_weights=(~freq_flags[0]).astype(float),
Expand All @@ -1412,8 +1412,8 @@ def test_separable_linear_fit_2D():
)

pytest.raises(
ValueError,
dspec.separable_linear_fit_2D,
ValueError,
dspec.separable_linear_fit_2D,
data=data,
axis_1_weights=(~time_flags[:, 0]).astype(float),
axis_2_weights=(~freq_flags[0]).astype(float),
Expand Down Expand Up @@ -1471,26 +1471,26 @@ def test_sparse_linear_fit_2d():

# Errors should be raised if the data, weights, and bases are not compatible
pytest.raises(
ValueError,
dspec.sparse_linear_fit_2D,
data=data,
ValueError,
dspec.sparse_linear_fit_2D,
data=data,
weights=(~flags).astype(float),
axis_1_basis=time_basis,
axis_2_basis=time_basis,
)
pytest.raises(
ValueError,
dspec.sparse_linear_fit_2D,
data=data,
ValueError,
dspec.sparse_linear_fit_2D,
data=data,
weights=(~flags).astype(float),
axis_1_basis=freq_basis,
axis_2_basis=freq_basis,
)
pytest.raises(
ValueError,
dspec.sparse_linear_fit_2D,
data=data,
ValueError,
dspec.sparse_linear_fit_2D,
data=data,
weights=(~flags).astype(float).T,
axis_1_basis=time_basis,
axis_2_basis=freq_basis,
)
)

0 comments on commit d908630

Please sign in to comment.