Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added ability for linsolve equations to have purely constant terms, which are just subracted off of the equated value when solving. #60

Merged
merged 4 commits into from
Jan 9, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 18 additions & 10 deletions src/linsolve/linsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
describing the equation (which is parsed according to python syntax) and each
value is the corresponding "measured" value of that equation. Variable names
in equations are checked against keyword arguments to the solver to determine
if they are provided constants or parameters to be solved for. Parameter anmes
if they are provided constants or parameters to be solved for. Parameter names
and solutions are return are returned as key:value pairs in ls.solve().
Parallel instances of equations can be evaluated by providing measured values
as numpy arrays. Constants can also be arrays that comply with standard numpy
broadcasting rules. Finally, weighting is implemented through an optional wgts
dictionary that parallels the construction of data.

LinearSolver solves linear equations of the form 'a*x + b*y + c*z'.
LinearSolver solves linear equations of the form 'a*x + b*y + c*z + d'.
LogProductSolver uses logrithms to linearize equations of the form 'x*y*z'.
LinProductSolver uses symbolic Taylor expansion to linearize equations of the
form 'x*y + y*z'.
Expand Down Expand Up @@ -177,6 +177,7 @@ def __init__(self, val, **kwargs):
self.wgts = kwargs.pop("wgts", np.float32(1.0))
self.has_conj = False
constants = kwargs.pop("constants", kwargs)
self.additive_offset = np.float32(0.0)
self.process_terms(val, constants)

def process_terms(self, terms, constants):
Expand Down Expand Up @@ -211,11 +212,17 @@ def order_terms(self, terms):
for L in terms:
L.sort(key=lambda x: get_name(x) in self.prms)
# Validate that each term has exactly 1 unsolved parameter.
final_terms = []
for t in terms:
assert get_name(t[-1]) in self.prms
# Check if this term has no free parameters (i.e. additive constant)
if get_name(t[-1]) not in self.prms:
self.additive_offset += self.eval_consts(t)
continue
# Make sure there is no more than 1 free parameter per term
for ti in t[:-1]:
assert type(ti) is not str or get_name(ti) in self.consts
return terms
final_terms.append(t)
return final_terms

def eval_consts(self, const_list, wgts=np.float32(1.0)):
"""Multiply out constants (and wgts) for placing in matrix."""
Expand Down Expand Up @@ -251,6 +258,8 @@ def eval(self, sol):
else:
total *= sol[name]
rv += total
# add back in purely constant terms, which were filtered out of self.terms
rv += self.additive_offset
return rv


Expand Down Expand Up @@ -299,7 +308,7 @@ def infer_dtype(values):

class LinearSolver:
def __init__(self, data, wgts={}, sparse=False, **kwargs):
"""Set up a linear system of equations of the form 1*a + 2*b + 3*c = 4.
"""Set up a linear system of equations of the form 1*a + 2*b + 3*c + 4 = 5.

Parameters
----------
Expand Down Expand Up @@ -430,7 +439,7 @@ def get_weighted_data(self):
dtype = np.complex64
else:
dtype = np.complex128
d = np.array([self.data[k] for k in self.keys], dtype=dtype)
d = np.array([self.data[k] - eq.additive_offset for k, eq in zip(self.keys, self.eqs)], dtype=dtype)
if len(self.wgts) > 0:
w = np.array([self.wgts[k] for k in self.keys])
w.shape += (1,) * (d.ndim - w.ndim)
Expand Down Expand Up @@ -571,13 +580,12 @@ def _invert_solve(self, A, y, rcond):
# vectors if b.ndim was equal to a.ndim - 1.
At = A.transpose([2, 1, 0]).conj()
AtA = [np.dot(At[k], A[..., k]) for k in range(y.shape[-1])]
Aty = [np.dot(At[k], y[..., k])[:, None] for k in range(y.shape[-1])]
Aty = [np.dot(At[k], y[..., k])[..., None] for k in range(y.shape[-1])]

# This is slower by about 50%: scipy.linalg.solve(AtA, Aty, 'her')

# But this sometimes errors if singular:
print(len(AtA), len(Aty), AtA[0].shape, Aty[0].shape)
return np.linalg.solve(AtA, Aty).T[0]
return np.linalg.solve(AtA, Aty)[..., 0].T

def _invert_solve_sparse(self, xs_ys_vals, y, rcond):
"""Use linalg.solve to solve a fully constrained (non-degenerate) system of eqs.
Expand All @@ -588,7 +596,7 @@ def _invert_solve_sparse(self, xs_ys_vals, y, rcond):
AtA, Aty = self._get_AtA_Aty_sparse(xs_ys_vals, y)
# AtA and Aty don't end up being that sparse, usually, so don't use this:
# --> x = scipy.sparse.linalg.spsolve(AtA, Aty)
return np.linalg.solve(AtA, Aty).T
return np.linalg.solve(AtA, Aty[..., None])[..., 0].T

def _invert_default(self, A, y, rcond):
"""The default inverter, currently 'pinv'."""
Expand Down
41 changes: 35 additions & 6 deletions tests/test_linsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,10 @@ def test_term_check(self):
terms4 = [["c", "x", "a"], [1, "b", "y"]]
with pytest.raises(AssertionError):
le.order_terms(terms4)
terms5 = [[1, "a", "b"], [1, "b", "y"]]
with pytest.raises(AssertionError):
le.order_terms(terms5)
terms5 = [["a", "b"], [1, "b", "y"]]
terms = le.order_terms(terms5)
assert len(terms) == 1
assert le.additive_offset == 8

def test_eval(self):
le = linsolve.LinearEquation("a*x-b*y", a=2, b=4)
Expand All @@ -138,6 +139,9 @@ def test_eval(self):
sol = {"x": 3 + 3j * np.ones(10), "y": 7 + 2j * np.ones(10)}
ans = np.conj(sol["x"]) - sol["y"]
np.testing.assert_equal(ans, le.eval(sol))
le = linsolve.LinearEquation("a*b+a*x-b*y", a=2, b=4)
sol = {'x': 3, 'y': 7}
assert 2 * 4 + 2 * 3 - 4 * 7 == le.eval(sol)


class TestLinearSolver:
Expand Down Expand Up @@ -276,6 +280,23 @@ def test_eval(self):
result = ls.eval(sol, "a*x+b*y")
np.testing.assert_almost_equal(3 * 1 + 1 * 2, list(result.values())[0])

def test_eval_const_term(self):
x, y = 1.0, 2.0
a, b = 3.0 * np.ones(4), 1.0
eqs = ["a*b+a*x+y", "a+x+b*y"]
d, w = {}, {}
for eq in eqs:
d[eq], w[eq] = eval(eq) * np.ones(4), np.ones(4)
ls = linsolve.LinearSolver(d, w, a=a, b=b, sparse=self.sparse)
sol = ls.solve()
np.testing.assert_almost_equal(sol["x"], x * np.ones(4, dtype=np.float64))
np.testing.assert_almost_equal(sol["y"], y * np.ones(4, dtype=np.float64))
result = ls.eval(sol)
for eq in d:
np.testing.assert_almost_equal(d[eq], result[eq])
result = ls.eval(sol, "a*b+a*x+b*y")
np.testing.assert_almost_equal(3 * 1 + 3 * 1 + 1 * 2, list(result.values())[0])

def test_chisq(self):
x = 1.0
d = {"x": 1, "a*x": 2}
Expand All @@ -297,6 +318,14 @@ def test_chisq(self):
chisq = ls.chisq(sol)
np.testing.assert_almost_equal(sol["x"], 5.0 / 3.0, 6)
np.testing.assert_almost_equal(ls.chisq(sol), 1.0 / 3.0)
x = 1.0
d = {"1*x+1": 3.0, "x": 1.0}
w = {"1*x+1": 1.0, "x": 0.5}
ls = linsolve.LinearSolver(d, wgts=w, sparse=self.sparse)
sol = ls.solve()
chisq = ls.chisq(sol)
np.testing.assert_almost_equal(sol["x"], 5.0 / 3.0, 6)
np.testing.assert_almost_equal(ls.chisq(sol), 1.0 / 3.0)

def test_dtypes(self):
ls = linsolve.LinearSolver({"x_": 1.0 + 1.0j}, sparse=self.sparse)
Expand Down Expand Up @@ -355,7 +384,7 @@ def test_degen_sol(self):


class TestLinearSolverSparse(TestLinearSolver):
def setup(self):
def setup_method(self):
self.sparse = True
eqs = ["x+y", "x-y"]
x, y = 1, 2
Expand Down Expand Up @@ -461,7 +490,7 @@ def test_dtype(self):


class TestLogProductSolverSparse(TestLogProductSolver):
def setup(self):
def setup_method(self):
self.sparse = True


Expand Down Expand Up @@ -762,5 +791,5 @@ def test_degen_sol(self):


class TestLinProductSolverSparse(TestLinProductSolver):
def setup(self):
def setup_method(self):
self.sparse = True
Loading