Skip to content

Commit

Permalink
NHT changes plots combined fit (#166)
Browse files Browse the repository at this point in the history
* NHT changes plots combined fit

* feat: Exception for illegal combination added and test fixed.

---------

Co-authored-by: Fabian Joswig <[email protected]>
  • Loading branch information
nils-ht and fjosw authored Mar 17, 2023
1 parent 83204ce commit 2363b75
Show file tree
Hide file tree
Showing 2 changed files with 46 additions and 4 deletions.
14 changes: 10 additions & 4 deletions pyerrors/correlators.py
Original file line number Diff line number Diff line change
Expand Up @@ -788,7 +788,7 @@ def set_prange(self, prange):
self.prange = prange
return

def show(self, x_range=None, comp=None, y_range=None, logscale=False, plateau=None, fit_res=None, ylabel=None, save=None, auto_gamma=False, hide_sigma=None, references=None, title=None):
def show(self, x_range=None, comp=None, y_range=None, logscale=False, plateau=None, fit_res=None, fit_key=None, ylabel=None, save=None, auto_gamma=False, hide_sigma=None, references=None, title=None):
"""Plots the correlator using the tag of the correlator as label if available.
Parameters
Expand All @@ -804,6 +804,8 @@ def show(self, x_range=None, comp=None, y_range=None, logscale=False, plateau=No
Plateau value to be visualized in the figure.
fit_res : Fit_result
Fit_result object to be visualized.
fit_key : str
Key for the fit function in Fit_result.fit_function (for combined fits).
ylabel : str
Label for the y-axis.
save : str
Expand Down Expand Up @@ -883,9 +885,13 @@ def show(self, x_range=None, comp=None, y_range=None, logscale=False, plateau=No

if fit_res:
x_samples = np.arange(x_range[0], x_range[1] + 1, 0.05)
ax1.plot(x_samples,
fit_res.fit_function([o.value for o in fit_res.fit_parameters], x_samples),
ls='-', marker=',', lw=2)
if isinstance(fit_res.fit_function, dict):
if fit_key:
ax1.plot(x_samples, fit_res.fit_function[fit_key]([o.value for o in fit_res.fit_parameters], x_samples), ls='-', marker=',', lw=2)
else:
raise ValueError("Please provide a 'fit_key' for visualizing combined fits.")
else:
ax1.plot(x_samples, fit_res.fit_function([o.value for o in fit_res.fit_parameters], x_samples), ls='-', marker=',', lw=2)

ax1.set_xlabel(r'$x_0 / a$')
if ylabel:
Expand Down
36 changes: 36 additions & 0 deletions tests/fits_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -681,6 +681,42 @@ def func_b(a,x):

pe.least_squares(xs, ys, funcs, num_grad=True)

def test_plot_combined_fit_function():

def func_exp1(x):
return 0.3*anp.exp(0.5*x)

def func_exp2(x):
return 0.3*anp.exp(0.8*x)

xvals_b = np.arange(0,6)
xvals_a = np.arange(0,8)

def func_a(a,x):
return a[0]*anp.exp(a[1]*x)

def func_b(a,x):
return a[0]*anp.exp(a[2]*x)

corr_a = pe.Corr([pe.Obs([np.random.normal(item, item*1.5, 1000)],['ensemble1']) for item in func_exp1(xvals_a)])
corr_b = pe.Corr([pe.Obs([np.random.normal(item, item*1.4, 1000)],['ensemble1']) for item in func_exp2(xvals_b)])

funcs = {'a':func_a, 'b':func_b}
xs = {'a':xvals_a, 'b':xvals_b}
ys = {'a': [o[0] for o in corr_a.content],
'b': [o[0] for o in corr_b.content]}

corr_a.gm()
corr_b.gm()

comb_fit = pe.least_squares(xs, ys, funcs)

with pytest.raises(ValueError):
corr_a.show(x_range=[xs["a"][0], xs["a"][-1]], fit_res=comb_fit)

corr_a.show(x_range=[xs["a"][0], xs["a"][-1]], fit_res=comb_fit, fit_key="a")
corr_b.show(x_range=[xs["b"][0], xs["b"][-1]], fit_res=comb_fit, fit_key="b")


def test_combined_fit_invalid_fit_functions():
def func1(a, x):
Expand Down

0 comments on commit 2363b75

Please sign in to comment.