Skip to content

Commit

Permalink
Added constant-term functionality to linsolve.
Browse files Browse the repository at this point in the history
  • Loading branch information
HERA-Observer committed Jan 9, 2025
1 parent fd44e85 commit 816699e
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 6 deletions.
15 changes: 12 additions & 3 deletions src/linsolve/linsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,7 @@ def __init__(self, val, **kwargs):
self.wgts = kwargs.pop("wgts", np.float32(1.0))
self.has_conj = False
constants = kwargs.pop("constants", kwargs)
self.additive_offset = np.float32(0.0)
self.process_terms(val, constants)

def process_terms(self, terms, constants):
Expand Down Expand Up @@ -211,11 +212,17 @@ def order_terms(self, terms):
for L in terms:
L.sort(key=lambda x: get_name(x) in self.prms)
# Validate that each term has exactly 1 unsolved parameter.
final_terms = []
for t in terms:
assert get_name(t[-1]) in self.prms
# Check if this term has no free parameters (i.e. additive constant)
if get_name(t[-1]) not in self.prms:
self.additive_offset += self.eval_consts(t)
continue
# Make sure there is no more than 1 free parameter per term
for ti in t[:-1]:
assert type(ti) is not str or get_name(ti) in self.consts
return terms
final_terms.append(t)
return final_terms

def eval_consts(self, const_list, wgts=np.float32(1.0)):
"""Multiply out constants (and wgts) for placing in matrix."""
Expand Down Expand Up @@ -251,6 +258,8 @@ def eval(self, sol):
else:
total *= sol[name]
rv += total
# add back in purely constant terms, which were filtered out of self.terms
rv += self.additive_offset
return rv


Expand Down Expand Up @@ -430,7 +439,7 @@ def get_weighted_data(self):
dtype = np.complex64
else:
dtype = np.complex128
d = np.array([self.data[k] for k in self.keys], dtype=dtype)
d = np.array([self.data[k] - eq.additive_offset for k, eq in zip(self.keys, self.eqs)], dtype=dtype)
if len(self.wgts) > 0:
w = np.array([self.wgts[k] for k in self.keys])
w.shape += (1,) * (d.ndim - w.ndim)
Expand Down
35 changes: 32 additions & 3 deletions tests/test_linsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,10 @@ def test_term_check(self):
terms4 = [["c", "x", "a"], [1, "b", "y"]]
with pytest.raises(AssertionError):
le.order_terms(terms4)
terms5 = [[1, "a", "b"], [1, "b", "y"]]
with pytest.raises(AssertionError):
le.order_terms(terms5)
terms5 = [["a", "b"], [1, "b", "y"]]
terms = le.order_terms(terms5)
assert len(terms) == 1
assert le.additive_offset == 8

def test_eval(self):
le = linsolve.LinearEquation("a*x-b*y", a=2, b=4)
Expand All @@ -138,6 +139,9 @@ def test_eval(self):
sol = {"x": 3 + 3j * np.ones(10), "y": 7 + 2j * np.ones(10)}
ans = np.conj(sol["x"]) - sol["y"]
np.testing.assert_equal(ans, le.eval(sol))
le = linsolve.LinearEquation("a*b+a*x-b*y", a=2, b=4)
sol = {'x': 3, 'y': 7}
assert 2 * 4 + 2 * 3 - 4 * 7 == le.eval(sol)


class TestLinearSolver:
Expand Down Expand Up @@ -276,6 +280,23 @@ def test_eval(self):
result = ls.eval(sol, "a*x+b*y")
np.testing.assert_almost_equal(3 * 1 + 1 * 2, list(result.values())[0])

def test_eval_const_term(self):
x, y = 1.0, 2.0
a, b = 3.0 * np.ones(4), 1.0
eqs = ["a*b+a*x+y", "a+x+b*y"]
d, w = {}, {}
for eq in eqs:
d[eq], w[eq] = eval(eq) * np.ones(4), np.ones(4)
ls = linsolve.LinearSolver(d, w, a=a, b=b, sparse=self.sparse)
sol = ls.solve()
np.testing.assert_almost_equal(sol["x"], x * np.ones(4, dtype=np.float64))
np.testing.assert_almost_equal(sol["y"], y * np.ones(4, dtype=np.float64))
result = ls.eval(sol)
for eq in d:
np.testing.assert_almost_equal(d[eq], result[eq])
result = ls.eval(sol, "a*b+a*x+b*y")
np.testing.assert_almost_equal(3 * 1 + 3 * 1 + 1 * 2, list(result.values())[0])

def test_chisq(self):
x = 1.0
d = {"x": 1, "a*x": 2}
Expand All @@ -297,6 +318,14 @@ def test_chisq(self):
chisq = ls.chisq(sol)
np.testing.assert_almost_equal(sol["x"], 5.0 / 3.0, 6)
np.testing.assert_almost_equal(ls.chisq(sol), 1.0 / 3.0)
x = 1.0
d = {"1*x+1": 3.0, "x": 1.0}
w = {"1*x+1": 1.0, "x": 0.5}
ls = linsolve.LinearSolver(d, wgts=w, sparse=self.sparse)
sol = ls.solve()
chisq = ls.chisq(sol)
np.testing.assert_almost_equal(sol["x"], 5.0 / 3.0, 6)
np.testing.assert_almost_equal(ls.chisq(sol), 1.0 / 3.0)

def test_dtypes(self):
ls = linsolve.LinearSolver({"x_": 1.0 + 1.0j}, sparse=self.sparse)
Expand Down

0 comments on commit 816699e

Please sign in to comment.