Skip to content

Commit

Permalink
Merge pull request #53 from HERA-Team/ast-update
Browse files Browse the repository at this point in the history
Updates for numpy 2
  • Loading branch information
tyler-a-cox authored Jul 5, 2024
2 parents a1f9316 + e0a8831 commit b11ea51
Show file tree
Hide file tree
Showing 3 changed files with 39 additions and 8 deletions.
23 changes: 23 additions & 0 deletions .github/workflows/warnings_tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
name: Warnings Tests
on: [push, pull_request]

jobs:
tests:
name: Warning Tests
runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- name: Setup Python
uses: actions/setup-python@v4
with:
python-version: '3.12'

- name: Install linsolve
run: |
python -m pip install --upgrade pip
pip install -e ".[dev]"
- name: Run Tests
run: |
pytest -W error
14 changes: 10 additions & 4 deletions src/linsolve/linsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def ast_getterms(n):
if type(n) is ast.Name:
return [[n.id]]
elif type(n) is ast.Constant or type(n) is ast.Constant:
return [[n.n]]
return [[n.value]]
elif type(n) is ast.Expression:
return ast_getterms(n.body)
elif type(n) is ast.UnaryOp:
Expand Down Expand Up @@ -564,14 +564,20 @@ def _invert_solve(self, A, y, rcond):
methods.
"""
# As of numpy 1.8, solve works on stacks of matrices
# Change in numpy 2.0:
# The b array is only treated as a shape (M,) column vector if it is
# exactly 1-dimensional. In all other instances it is treated as a stack
# of (M, K) matrices. Previously b would be treated as a stack of (M,)
# vectors if b.ndim was equal to a.ndim - 1.
At = A.transpose([2, 1, 0]).conj()
AtA = [np.dot(At[k], A[..., k]) for k in range(y.shape[-1])]
Aty = [np.dot(At[k], y[..., k]) for k in range(y.shape[-1])]
Aty = [np.dot(At[k], y[..., k])[:, None] for k in range(y.shape[-1])]

# This is slower by about 50%: scipy.linalg.solve(AtA, Aty, 'her')

# But this sometimes errors if singular:
return np.linalg.solve(AtA, Aty).T
print(len(AtA), len(Aty), AtA[0].shape, Aty[0].shape)
return np.linalg.solve(AtA, Aty).T[0]

def _invert_solve_sparse(self, xs_ys_vals, y, rcond):
"""Use linalg.solve to solve a fully constrained (non-degenerate) system of eqs.
Expand Down Expand Up @@ -690,7 +696,7 @@ def eval(self, sol, keys=None):
def _chisq(self, sol, data, wgts, evaluator):
"""Internal adaptable chisq calculator."""
if len(wgts) == 0:
sigma2 = {k: 1.0 for k in list(data.keys())} # equal weights
sigma2 = dict.fromkeys(data.keys(), value=1.0) # equal weights
else:
sigma2 = {k: wgts[k] ** -1 for k in list(wgts.keys())}
evaluated = evaluator(sol, keys=data)
Expand Down
10 changes: 6 additions & 4 deletions tests/test_linsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ def test_eval(self):


class TestLinearSolver:
def setup(self):
def setup_class(self):
self.sparse = False
eqs = ["x+y", "x-y"]
x, y = 1, 2
Expand Down Expand Up @@ -366,7 +366,7 @@ def setup(self):


class TestLogProductSolver:
def setup(self):
def setup_class(self):
self.sparse = False

def test_init(self):
Expand Down Expand Up @@ -466,7 +466,7 @@ def setup(self):


class TestLinProductSolver:
def setup(self):
def setup_class(self):
self.sparse = False

def test_init(self):
Expand All @@ -490,7 +490,9 @@ def test_init(self):
np.testing.assert_almost_equal(eval(k), 0.002)
assert len(ls.ls.prms) == 3

ls = linsolve.LinProductSolver(d, sol0, w, sparse=self.sparse, build_solver=False)
ls = linsolve.LinProductSolver(
d, sol0, w, sparse=self.sparse, build_solver=False
)
assert not hasattr(ls, "ls")
assert ls.dtype == np.complex64

Expand Down

0 comments on commit b11ea51

Please sign in to comment.