Skip to content

Commit

Permalink
add fit method
Browse files Browse the repository at this point in the history
  • Loading branch information
HugoMVale committed Oct 13, 2023
1 parent f67407f commit f2c45d6
Showing 1 changed file with 67 additions and 43 deletions.
110 changes: 67 additions & 43 deletions src/polykin/physprops/property_equation.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,13 +192,13 @@ def plot(self,
Trange = (273.15, 373.15)

try:
shape = getattr(self, '_shape')
shape = self._shape
except AttributeError:
shape = None
if shape is not None:
print("Plot method not yet implemented for array-like equations.")
else:
TK = np.linspace(Trange[0], Trange[1], 100)
TK = np.linspace(*Trange, 100)
y = self.__call__(TK, 'K')
TC = TK - 273.15
if Tunit == 'C':
Expand All @@ -218,47 +218,71 @@ def plot(self,
if return_objects:
return (fig, ax)

# def fit(self,
# T: FloatVector,
# Y: FloatVector,
# sigmaY: Optional[FloatVector] = None,
# fitonly: list[str]=[],
# log: bool = False,
# plot: bool = True,
# ):

# # select parameters to be fitted
# params = self._pnames[0]
# if fitonly:
# params = set(fitonly) & set(params)
# args = {p: getattr(self, p) for p in params}

# # log transform
# if log:
# ydata = np.log(Y)
# else:
# ydata = Y

# def ffit(x, p):
# for
# self.eval(x)

# solution = curve_fit(ffit,
# xdata=T,
# ydata=ydata,
# p0=p0,
# sigma=sigmaY,
# absolute_sigma=True,
# full_output=True)
# if solution[4]:
# popt = solution[0]
# cov = solution[1]
# print("Fit successful")
# print(popt)
# print(cov)
# else:
# print("Fit error: ", solution[3])
# pass
def fit(self,
T: FloatVector,
Y: FloatVector,
sigmaY: Optional[FloatVector] = None,
fitonly: list[str] = [],
logY: bool = False,
plot: bool = True,
) -> dict:

# Current parameter values
pnames = self._pnames[0] + self._pnames[1]
pdict = {pname: pvalue for pname, pvalue in zip(pnames, self.pvalues)}

# Select parameters to be fitted
pnames_fit = self._pnames[0]
if fitonly:
pnames_fit = set(fitonly) & set(pnames_fit)
p0 = [pdict[pname] for pname in pnames_fit]

# Fit function
def ffit(x, *p):
for pname, pvalue in zip(pnames_fit, p):
pdict[pname] = pvalue
Yfit = self.equation(T=x, **pdict)
if logY:
Yfit = np.log(Yfit)
return Yfit

solution = curve_fit(ffit,
xdata=T,
ydata=np.log(Y) if logY else Y,
p0=p0,
sigma=sigmaY,
absolute_sigma=True,
full_output=True)
result = {}
result['success'] = bool(solution[4])
if solution[4]:
popt = solution[0]
pcov = solution[1]
print("Fit successful.")
for pname, pvalue in zip(pnames_fit, popt):
print(f"{pname}: {pvalue}")
print("Covariance:")
print(pcov)
result['covariance'] = pcov

# Update attributes
self.Trange = (min(T), max(T))
for pname, pvalue in zip(pnames_fit, popt):
pdict[pname] = pvalue
self.pvalues = tuple(pdict.values())
result['parameters'] = pdict

# plot
if plot:
kind = 'semilogy' if logY else 'linear'
fig, ax = self.plot(kind=kind, return_objects=True)
ax.plot(T, Y, 'o', mfc='none')
result['plot'] = (fig, ax)
else:
print("Fit error: ", solution[3])
result['message'] = solution[3]

return result


# %% Functions
Expand Down

0 comments on commit f2c45d6

Please sign in to comment.