Skip to content

Commit

Permalink
Merge pull request #60 from HERA-Team/const_term
Browse files Browse the repository at this point in the history
Added ability for linsolve equations to have purely constant terms, which are just subracted off of the equated value when solving.
  • Loading branch information
AaronParsons authored Jan 9, 2025
2 parents 9f72a28 + a320f22 commit 3a9c489
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 10 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# Overview

The solvers in `linsolve` include `LinearSolver`, `LogProductSolver`, and `LinProductSolver`.
`LinearSolver` solves linear equations of the form `'a*x + b*y + c*z'`.
`LinearSolver` solves linear equations of the form `'a*x + b*y + c*z + d'`.
`LogProductSolver` uses logrithms to linearize equations of the form `'x*y*z'`.
`LinProductSolver` uses symbolic Taylor expansion to linearize equations of the
form `'x*y + y*z'`.
Expand Down
21 changes: 15 additions & 6 deletions src/linsolve/linsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,14 +12,14 @@
describing the equation (which is parsed according to python syntax) and each
value is the corresponding "measured" value of that equation. Variable names
in equations are checked against keyword arguments to the solver to determine
if they are provided constants or parameters to be solved for. Parameter anmes
if they are provided constants or parameters to be solved for. Parameter names
and solutions are return are returned as key:value pairs in ls.solve().
Parallel instances of equations can be evaluated by providing measured values
as numpy arrays. Constants can also be arrays that comply with standard numpy
broadcasting rules. Finally, weighting is implemented through an optional wgts
dictionary that parallels the construction of data.
LinearSolver solves linear equations of the form 'a*x + b*y + c*z'.
LinearSolver solves linear equations of the form 'a*x + b*y + c*z + d'.
LogProductSolver uses logrithms to linearize equations of the form 'x*y*z'.
LinProductSolver uses symbolic Taylor expansion to linearize equations of the
form 'x*y + y*z'.
Expand Down Expand Up @@ -177,6 +177,7 @@ def __init__(self, val, **kwargs):
self.wgts = kwargs.pop("wgts", np.float32(1.0))
self.has_conj = False
constants = kwargs.pop("constants", kwargs)
self.additive_offset = np.float32(0.0)
self.process_terms(val, constants)

def process_terms(self, terms, constants):
Expand Down Expand Up @@ -211,11 +212,17 @@ def order_terms(self, terms):
for L in terms:
L.sort(key=lambda x: get_name(x) in self.prms)
# Validate that each term has exactly 1 unsolved parameter.
final_terms = []
for t in terms:
assert get_name(t[-1]) in self.prms
# Check if this term has no free parameters (i.e. additive constant)
if get_name(t[-1]) not in self.prms:
self.additive_offset += self.eval_consts(t)
continue
# Make sure there is no more than 1 free parameter per term
for ti in t[:-1]:
assert type(ti) is not str or get_name(ti) in self.consts
return terms
final_terms.append(t)
return final_terms

def eval_consts(self, const_list, wgts=np.float32(1.0)):
"""Multiply out constants (and wgts) for placing in matrix."""
Expand Down Expand Up @@ -251,6 +258,8 @@ def eval(self, sol):
else:
total *= sol[name]
rv += total
# add back in purely constant terms, which were filtered out of self.terms
rv += self.additive_offset
return rv


Expand Down Expand Up @@ -299,7 +308,7 @@ def infer_dtype(values):

class LinearSolver:
def __init__(self, data, wgts={}, sparse=False, **kwargs):
"""Set up a linear system of equations of the form 1*a + 2*b + 3*c = 4.
"""Set up a linear system of equations of the form 1*a + 2*b + 3*c + 4 = 5.
Parameters
----------
Expand Down Expand Up @@ -430,7 +439,7 @@ def get_weighted_data(self):
dtype = np.complex64
else:
dtype = np.complex128
d = np.array([self.data[k] for k in self.keys], dtype=dtype)
d = np.array([self.data[k] - eq.additive_offset for k, eq in zip(self.keys, self.eqs)], dtype=dtype)
if len(self.wgts) > 0:
w = np.array([self.wgts[k] for k in self.keys])
w.shape += (1,) * (d.ndim - w.ndim)
Expand Down
35 changes: 32 additions & 3 deletions tests/test_linsolve.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,9 +124,10 @@ def test_term_check(self):
terms4 = [["c", "x", "a"], [1, "b", "y"]]
with pytest.raises(AssertionError):
le.order_terms(terms4)
terms5 = [[1, "a", "b"], [1, "b", "y"]]
with pytest.raises(AssertionError):
le.order_terms(terms5)
terms5 = [["a", "b"], [1, "b", "y"]]
terms = le.order_terms(terms5)
assert len(terms) == 1
assert le.additive_offset == 8

def test_eval(self):
le = linsolve.LinearEquation("a*x-b*y", a=2, b=4)
Expand All @@ -138,6 +139,9 @@ def test_eval(self):
sol = {"x": 3 + 3j * np.ones(10), "y": 7 + 2j * np.ones(10)}
ans = np.conj(sol["x"]) - sol["y"]
np.testing.assert_equal(ans, le.eval(sol))
le = linsolve.LinearEquation("a*b+a*x-b*y", a=2, b=4)
sol = {'x': 3, 'y': 7}
assert 2 * 4 + 2 * 3 - 4 * 7 == le.eval(sol)


class TestLinearSolver:
Expand Down Expand Up @@ -276,6 +280,23 @@ def test_eval(self):
result = ls.eval(sol, "a*x+b*y")
np.testing.assert_almost_equal(3 * 1 + 1 * 2, list(result.values())[0])

def test_eval_const_term(self):
x, y = 1.0, 2.0
a, b = 3.0 * np.ones(4), 1.0
eqs = ["a*b+a*x+y", "a+x+b*y"]
d, w = {}, {}
for eq in eqs:
d[eq], w[eq] = eval(eq) * np.ones(4), np.ones(4)
ls = linsolve.LinearSolver(d, w, a=a, b=b, sparse=self.sparse)
sol = ls.solve()
np.testing.assert_almost_equal(sol["x"], x * np.ones(4, dtype=np.float64))
np.testing.assert_almost_equal(sol["y"], y * np.ones(4, dtype=np.float64))
result = ls.eval(sol)
for eq in d:
np.testing.assert_almost_equal(d[eq], result[eq])
result = ls.eval(sol, "a*b+a*x+b*y")
np.testing.assert_almost_equal(3 * 1 + 3 * 1 + 1 * 2, list(result.values())[0])

def test_chisq(self):
x = 1.0
d = {"x": 1, "a*x": 2}
Expand All @@ -297,6 +318,14 @@ def test_chisq(self):
chisq = ls.chisq(sol)
np.testing.assert_almost_equal(sol["x"], 5.0 / 3.0, 6)
np.testing.assert_almost_equal(ls.chisq(sol), 1.0 / 3.0)
x = 1.0
d = {"1*x+1": 3.0, "x": 1.0}
w = {"1*x+1": 1.0, "x": 0.5}
ls = linsolve.LinearSolver(d, wgts=w, sparse=self.sparse)
sol = ls.solve()
chisq = ls.chisq(sol)
np.testing.assert_almost_equal(sol["x"], 5.0 / 3.0, 6)
np.testing.assert_almost_equal(ls.chisq(sol), 1.0 / 3.0)

def test_dtypes(self):
ls = linsolve.LinearSolver({"x_": 1.0 + 1.0j}, sparse=self.sparse)
Expand Down

0 comments on commit 3a9c489

Please sign in to comment.