From 816699ea6ad1a42305774f28d5a2faf9550be9b6 Mon Sep 17 00:00:00 2001 From: Aaron Parsons Date: Wed, 8 Jan 2025 16:40:22 -0800 Subject: [PATCH] Added constant-term functionality to linsolve. --- src/linsolve/linsolve.py | 15 ++++++++++++--- tests/test_linsolve.py | 35 ++++++++++++++++++++++++++++++++--- 2 files changed, 44 insertions(+), 6 deletions(-) diff --git a/src/linsolve/linsolve.py b/src/linsolve/linsolve.py index 6ea89b9..5217c7a 100644 --- a/src/linsolve/linsolve.py +++ b/src/linsolve/linsolve.py @@ -177,6 +177,7 @@ def __init__(self, val, **kwargs): self.wgts = kwargs.pop("wgts", np.float32(1.0)) self.has_conj = False constants = kwargs.pop("constants", kwargs) + self.additive_offset = np.float32(0.0) self.process_terms(val, constants) def process_terms(self, terms, constants): @@ -211,11 +212,17 @@ def order_terms(self, terms): for L in terms: L.sort(key=lambda x: get_name(x) in self.prms) # Validate that each term has exactly 1 unsolved parameter. + final_terms = [] for t in terms: - assert get_name(t[-1]) in self.prms + # Check if this term has no free parameters (i.e. additive constant) + if get_name(t[-1]) not in self.prms: + self.additive_offset += self.eval_consts(t) + continue + # Make sure there is no more than 1 free parameter per term for ti in t[:-1]: assert type(ti) is not str or get_name(ti) in self.consts - return terms + final_terms.append(t) + return final_terms def eval_consts(self, const_list, wgts=np.float32(1.0)): """Multiply out constants (and wgts) for placing in matrix.""" @@ -251,6 +258,8 @@ def eval(self, sol): else: total *= sol[name] rv += total + # add back in purely constant terms, which were filtered out of self.terms + rv += self.additive_offset return rv @@ -430,7 +439,7 @@ def get_weighted_data(self): dtype = np.complex64 else: dtype = np.complex128 - d = np.array([self.data[k] for k in self.keys], dtype=dtype) + d = np.array([self.data[k] - eq.additive_offset for k, eq in zip(self.keys, self.eqs)], dtype=dtype) if len(self.wgts) > 0: w = np.array([self.wgts[k] for k in self.keys]) w.shape += (1,) * (d.ndim - w.ndim) diff --git a/tests/test_linsolve.py b/tests/test_linsolve.py index a41bcab..676f4bc 100644 --- a/tests/test_linsolve.py +++ b/tests/test_linsolve.py @@ -124,9 +124,10 @@ def test_term_check(self): terms4 = [["c", "x", "a"], [1, "b", "y"]] with pytest.raises(AssertionError): le.order_terms(terms4) - terms5 = [[1, "a", "b"], [1, "b", "y"]] - with pytest.raises(AssertionError): - le.order_terms(terms5) + terms5 = [["a", "b"], [1, "b", "y"]] + terms = le.order_terms(terms5) + assert len(terms) == 1 + assert le.additive_offset == 8 def test_eval(self): le = linsolve.LinearEquation("a*x-b*y", a=2, b=4) @@ -138,6 +139,9 @@ def test_eval(self): sol = {"x": 3 + 3j * np.ones(10), "y": 7 + 2j * np.ones(10)} ans = np.conj(sol["x"]) - sol["y"] np.testing.assert_equal(ans, le.eval(sol)) + le = linsolve.LinearEquation("a*b+a*x-b*y", a=2, b=4) + sol = {'x': 3, 'y': 7} + assert 2 * 4 + 2 * 3 - 4 * 7 == le.eval(sol) class TestLinearSolver: @@ -276,6 +280,23 @@ def test_eval(self): result = ls.eval(sol, "a*x+b*y") np.testing.assert_almost_equal(3 * 1 + 1 * 2, list(result.values())[0]) + def test_eval_const_term(self): + x, y = 1.0, 2.0 + a, b = 3.0 * np.ones(4), 1.0 + eqs = ["a*b+a*x+y", "a+x+b*y"] + d, w = {}, {} + for eq in eqs: + d[eq], w[eq] = eval(eq) * np.ones(4), np.ones(4) + ls = linsolve.LinearSolver(d, w, a=a, b=b, sparse=self.sparse) + sol = ls.solve() + np.testing.assert_almost_equal(sol["x"], x * np.ones(4, dtype=np.float64)) + np.testing.assert_almost_equal(sol["y"], y * np.ones(4, dtype=np.float64)) + result = ls.eval(sol) + for eq in d: + np.testing.assert_almost_equal(d[eq], result[eq]) + result = ls.eval(sol, "a*b+a*x+b*y") + np.testing.assert_almost_equal(3 * 1 + 3 * 1 + 1 * 2, list(result.values())[0]) + def test_chisq(self): x = 1.0 d = {"x": 1, "a*x": 2} @@ -297,6 +318,14 @@ def test_chisq(self): chisq = ls.chisq(sol) np.testing.assert_almost_equal(sol["x"], 5.0 / 3.0, 6) np.testing.assert_almost_equal(ls.chisq(sol), 1.0 / 3.0) + x = 1.0 + d = {"1*x+1": 3.0, "x": 1.0} + w = {"1*x+1": 1.0, "x": 0.5} + ls = linsolve.LinearSolver(d, wgts=w, sparse=self.sparse) + sol = ls.solve() + chisq = ls.chisq(sol) + np.testing.assert_almost_equal(sol["x"], 5.0 / 3.0, 6) + np.testing.assert_almost_equal(ls.chisq(sol), 1.0 / 3.0) def test_dtypes(self): ls = linsolve.LinearSolver({"x_": 1.0 + 1.0j}, sparse=self.sparse)