diff --git a/README.md b/README.md index 6f24f6f..70c380b 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ # Overview The solvers in `linsolve` include `LinearSolver`, `LogProductSolver`, and `LinProductSolver`. -`LinearSolver` solves linear equations of the form `'a*x + b*y + c*z'`. +`LinearSolver` solves linear equations of the form `'a*x + b*y + c*z + d'`. `LogProductSolver` uses logrithms to linearize equations of the form `'x*y*z'`. `LinProductSolver` uses symbolic Taylor expansion to linearize equations of the form `'x*y + y*z'`. diff --git a/src/linsolve/linsolve.py b/src/linsolve/linsolve.py index 550c6c7..2364d74 100644 --- a/src/linsolve/linsolve.py +++ b/src/linsolve/linsolve.py @@ -12,14 +12,14 @@ describing the equation (which is parsed according to python syntax) and each value is the corresponding "measured" value of that equation. Variable names in equations are checked against keyword arguments to the solver to determine -if they are provided constants or parameters to be solved for. Parameter anmes +if they are provided constants or parameters to be solved for. Parameter names and solutions are return are returned as key:value pairs in ls.solve(). Parallel instances of equations can be evaluated by providing measured values as numpy arrays. Constants can also be arrays that comply with standard numpy broadcasting rules. Finally, weighting is implemented through an optional wgts dictionary that parallels the construction of data. -LinearSolver solves linear equations of the form 'a*x + b*y + c*z'. +LinearSolver solves linear equations of the form 'a*x + b*y + c*z + d'. LogProductSolver uses logrithms to linearize equations of the form 'x*y*z'. LinProductSolver uses symbolic Taylor expansion to linearize equations of the form 'x*y + y*z'. @@ -177,6 +177,7 @@ def __init__(self, val, **kwargs): self.wgts = kwargs.pop("wgts", np.float32(1.0)) self.has_conj = False constants = kwargs.pop("constants", kwargs) + self.additive_offset = np.float32(0.0) self.process_terms(val, constants) def process_terms(self, terms, constants): @@ -211,11 +212,17 @@ def order_terms(self, terms): for L in terms: L.sort(key=lambda x: get_name(x) in self.prms) # Validate that each term has exactly 1 unsolved parameter. + final_terms = [] for t in terms: - assert get_name(t[-1]) in self.prms + # Check if this term has no free parameters (i.e. additive constant) + if get_name(t[-1]) not in self.prms: + self.additive_offset += self.eval_consts(t) + continue + # Make sure there is no more than 1 free parameter per term for ti in t[:-1]: assert type(ti) is not str or get_name(ti) in self.consts - return terms + final_terms.append(t) + return final_terms def eval_consts(self, const_list, wgts=np.float32(1.0)): """Multiply out constants (and wgts) for placing in matrix.""" @@ -251,6 +258,8 @@ def eval(self, sol): else: total *= sol[name] rv += total + # add back in purely constant terms, which were filtered out of self.terms + rv += self.additive_offset return rv @@ -299,7 +308,7 @@ def infer_dtype(values): class LinearSolver: def __init__(self, data, wgts={}, sparse=False, **kwargs): - """Set up a linear system of equations of the form 1*a + 2*b + 3*c = 4. + """Set up a linear system of equations of the form 1*a + 2*b + 3*c + 4 = 5. Parameters ---------- @@ -430,7 +439,7 @@ def get_weighted_data(self): dtype = np.complex64 else: dtype = np.complex128 - d = np.array([self.data[k] for k in self.keys], dtype=dtype) + d = np.array([self.data[k] - eq.additive_offset for k, eq in zip(self.keys, self.eqs)], dtype=dtype) if len(self.wgts) > 0: w = np.array([self.wgts[k] for k in self.keys]) w.shape += (1,) * (d.ndim - w.ndim) @@ -571,13 +580,12 @@ def _invert_solve(self, A, y, rcond): # vectors if b.ndim was equal to a.ndim - 1. At = A.transpose([2, 1, 0]).conj() AtA = [np.dot(At[k], A[..., k]) for k in range(y.shape[-1])] - Aty = [np.dot(At[k], y[..., k])[:, None] for k in range(y.shape[-1])] + Aty = [np.dot(At[k], y[..., k])[..., None] for k in range(y.shape[-1])] # This is slower by about 50%: scipy.linalg.solve(AtA, Aty, 'her') # But this sometimes errors if singular: - print(len(AtA), len(Aty), AtA[0].shape, Aty[0].shape) - return np.linalg.solve(AtA, Aty).T[0] + return np.linalg.solve(AtA, Aty)[..., 0].T def _invert_solve_sparse(self, xs_ys_vals, y, rcond): """Use linalg.solve to solve a fully constrained (non-degenerate) system of eqs. @@ -588,7 +596,7 @@ def _invert_solve_sparse(self, xs_ys_vals, y, rcond): AtA, Aty = self._get_AtA_Aty_sparse(xs_ys_vals, y) # AtA and Aty don't end up being that sparse, usually, so don't use this: # --> x = scipy.sparse.linalg.spsolve(AtA, Aty) - return np.linalg.solve(AtA, Aty).T + return np.linalg.solve(AtA, Aty[..., None])[..., 0].T def _invert_default(self, A, y, rcond): """The default inverter, currently 'pinv'.""" diff --git a/tests/test_linsolve.py b/tests/test_linsolve.py index df14d97..676f4bc 100644 --- a/tests/test_linsolve.py +++ b/tests/test_linsolve.py @@ -124,9 +124,10 @@ def test_term_check(self): terms4 = [["c", "x", "a"], [1, "b", "y"]] with pytest.raises(AssertionError): le.order_terms(terms4) - terms5 = [[1, "a", "b"], [1, "b", "y"]] - with pytest.raises(AssertionError): - le.order_terms(terms5) + terms5 = [["a", "b"], [1, "b", "y"]] + terms = le.order_terms(terms5) + assert len(terms) == 1 + assert le.additive_offset == 8 def test_eval(self): le = linsolve.LinearEquation("a*x-b*y", a=2, b=4) @@ -138,6 +139,9 @@ def test_eval(self): sol = {"x": 3 + 3j * np.ones(10), "y": 7 + 2j * np.ones(10)} ans = np.conj(sol["x"]) - sol["y"] np.testing.assert_equal(ans, le.eval(sol)) + le = linsolve.LinearEquation("a*b+a*x-b*y", a=2, b=4) + sol = {'x': 3, 'y': 7} + assert 2 * 4 + 2 * 3 - 4 * 7 == le.eval(sol) class TestLinearSolver: @@ -276,6 +280,23 @@ def test_eval(self): result = ls.eval(sol, "a*x+b*y") np.testing.assert_almost_equal(3 * 1 + 1 * 2, list(result.values())[0]) + def test_eval_const_term(self): + x, y = 1.0, 2.0 + a, b = 3.0 * np.ones(4), 1.0 + eqs = ["a*b+a*x+y", "a+x+b*y"] + d, w = {}, {} + for eq in eqs: + d[eq], w[eq] = eval(eq) * np.ones(4), np.ones(4) + ls = linsolve.LinearSolver(d, w, a=a, b=b, sparse=self.sparse) + sol = ls.solve() + np.testing.assert_almost_equal(sol["x"], x * np.ones(4, dtype=np.float64)) + np.testing.assert_almost_equal(sol["y"], y * np.ones(4, dtype=np.float64)) + result = ls.eval(sol) + for eq in d: + np.testing.assert_almost_equal(d[eq], result[eq]) + result = ls.eval(sol, "a*b+a*x+b*y") + np.testing.assert_almost_equal(3 * 1 + 3 * 1 + 1 * 2, list(result.values())[0]) + def test_chisq(self): x = 1.0 d = {"x": 1, "a*x": 2} @@ -297,6 +318,14 @@ def test_chisq(self): chisq = ls.chisq(sol) np.testing.assert_almost_equal(sol["x"], 5.0 / 3.0, 6) np.testing.assert_almost_equal(ls.chisq(sol), 1.0 / 3.0) + x = 1.0 + d = {"1*x+1": 3.0, "x": 1.0} + w = {"1*x+1": 1.0, "x": 0.5} + ls = linsolve.LinearSolver(d, wgts=w, sparse=self.sparse) + sol = ls.solve() + chisq = ls.chisq(sol) + np.testing.assert_almost_equal(sol["x"], 5.0 / 3.0, 6) + np.testing.assert_almost_equal(ls.chisq(sol), 1.0 / 3.0) def test_dtypes(self): ls = linsolve.LinearSolver({"x_": 1.0 + 1.0j}, sparse=self.sparse) @@ -355,7 +384,7 @@ def test_degen_sol(self): class TestLinearSolverSparse(TestLinearSolver): - def setup(self): + def setup_method(self): self.sparse = True eqs = ["x+y", "x-y"] x, y = 1, 2 @@ -461,7 +490,7 @@ def test_dtype(self): class TestLogProductSolverSparse(TestLogProductSolver): - def setup(self): + def setup_method(self): self.sparse = True @@ -762,5 +791,5 @@ def test_degen_sol(self): class TestLinProductSolverSparse(TestLinProductSolver): - def setup(self): + def setup_method(self): self.sparse = True