Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Flexible fit functions #195

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 47 additions & 2 deletions fooof/core/funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
- They are left available for easy swapping back in, if desired.
"""

from inspect import isfunction

import numpy as np

from scipy.stats import norm

from fooof.core.errors import InconsistentDataError

###################################################################################################
Expand Down Expand Up @@ -41,6 +45,43 @@ def gaussian_function(xs, *params):
return ys


def skewed_gaussian_function(xs, *params):
"""Skewed gaussian fitting function.

Parameters
----------
xs : 1d array
Input x-axis values.
*params : float
Parameters that define the skewed gaussian function (center, height, width, alpha).

Returns
-------
ys : 1d array
Output values for skewed gaussian function.
"""

ys = np.zeros_like(xs)

for ii in range(0, len(params), 4):

ctr, hgt, wid, alpha = params[ii:ii+4]

# Gaussian distribution
ys = gaussian_function(xs, ctr, hgt, wid)

# Skewed cumulative distribution function
cdf = norm.cdf(alpha * ((xs - ctr) / wid))

# Skew the gaussian
ys = ys * cdf

# Rescale height
ys = (ys / np.max(ys)) * hgt

return ys


def expo_function(xs, *params):
"""Exponential fitting function, for fitting aperiodic component with a 'knee'.

Expand Down Expand Up @@ -167,7 +208,9 @@ def get_pe_func(periodic_mode):

"""

if periodic_mode == 'gaussian':
if isfunction(periodic_mode):
pe_func = periodic_mode
elif periodic_mode == 'gaussian':
pe_func = gaussian_function
else:
raise ValueError("Requested periodic mode not understood.")
Expand All @@ -194,7 +237,9 @@ def get_ap_func(aperiodic_mode):
If the specified aperiodic mode label is not understood.
"""

if aperiodic_mode == 'fixed':
if isfunction(aperiodic_mode):
ap_func = aperiodic_mode
elif aperiodic_mode == 'fixed':
ap_func = expo_nk_function
elif aperiodic_mode == 'knee':
ap_func = expo_function
Expand Down
28 changes: 18 additions & 10 deletions fooof/objs/fit.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
from fooof.core.reports import save_report_fm
from fooof.core.modutils import copy_doc_func_to_method
from fooof.core.utils import group_three, check_array_dim
from fooof.core.funcs import gaussian_function, get_ap_func, infer_ap_func
from fooof.core.funcs import get_pe_func, get_ap_func, infer_ap_func
from fooof.core.errors import (FitError, NoModelError, DataError,
NoDataError, InconsistentDataError)
from fooof.core.strings import (gen_settings_str, gen_results_fm_str,
Expand Down Expand Up @@ -154,8 +154,9 @@ class FOOOF():
"""
# pylint: disable=attribute-defined-outside-init

def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_height=0.0,
peak_threshold=2.0, aperiodic_mode='fixed', verbose=True):
def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf,
min_peak_height=0.0, peak_threshold=2.0, aperiodic_mode='fixed',
periodic_mode='gaussian', verbose=True):
"""Initialize object with desired settings."""

# Set input settings
Expand All @@ -164,6 +165,7 @@ def __init__(self, peak_width_limits=(0.5, 12.0), max_n_peaks=np.inf, min_peak_h
self.min_peak_height = min_peak_height
self.peak_threshold = peak_threshold
self.aperiodic_mode = aperiodic_mode
self.periodic_mode = periodic_mode
self.verbose = verbose

## PRIVATE SETTINGS
Expand Down Expand Up @@ -439,6 +441,9 @@ def fit(self, freqs=None, power_spectrum=None, freq_range=None):
if self.verbose:
self._check_width_limits()

# Determine the aperiodic and periodic fit funcs
self._set_fit_funcs()

# In rare cases, the model fails to fit, and so uses try / except
try:

Expand Down Expand Up @@ -715,6 +720,11 @@ def set_check_data_mode(self, check_data):

self._check_data = check_data

def _set_fit_funcs(self):
"""Set the requested aperiodic and periodic fit functions."""

self._pe_func = get_pe_func(self.periodic_mode)
self._ap_func = get_ap_func(self.aperiodic_mode)

def _check_width_limits(self):
"""Check and warn about peak width limits / frequency resolution interaction."""
Expand Down Expand Up @@ -762,8 +772,7 @@ def _simple_ap_fit(self, freqs, power_spectrum):
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
freqs, power_spectrum, p0=guess,
aperiodic_params, _ = curve_fit(self._ap_func, freqs, power_spectrum, p0=guess,
maxfev=self._maxfev, bounds=ap_bounds)
except RuntimeError:
raise FitError("Model fitting failed due to not finding parameters in "
Expand Down Expand Up @@ -818,9 +827,8 @@ def _robust_ap_fit(self, freqs, power_spectrum):
try:
with warnings.catch_warnings():
warnings.simplefilter("ignore")
aperiodic_params, _ = curve_fit(get_ap_func(self.aperiodic_mode),
freqs_ignore, spectrum_ignore, p0=popt,
maxfev=self._maxfev, bounds=ap_bounds)
aperiodic_params, _ = curve_fit(self._ap_func, freqs_ignore, spectrum_ignore,
p0=popt, maxfev=self._maxfev, bounds=ap_bounds)
except RuntimeError:
raise FitError("Model fitting failed due to not finding "
"parameters in the robust aperiodic fit.")
Expand Down Expand Up @@ -904,7 +912,7 @@ def _fit_peaks(self, flat_iter):

# Collect guess parameters and subtract this guess gaussian from the data
guess = np.vstack((guess, (guess_freq, guess_height, guess_std)))
peak_gauss = gaussian_function(self.freqs, guess_freq, guess_height, guess_std)
peak_gauss = self._pe_func(self.freqs, guess_freq, guess_height, guess_std)
flat_iter = flat_iter - peak_gauss

# Check peaks based on edges, and on overlap, dropping any that violate requirements
Expand Down Expand Up @@ -963,7 +971,7 @@ def _fit_peak_guess(self, guess):

# Fit the peaks
try:
gaussian_params, _ = curve_fit(gaussian_function, self.freqs, self._spectrum_flat,
gaussian_params, _ = curve_fit(self._pe_func, self.freqs, self._spectrum_flat,
p0=guess, maxfev=self._maxfev, bounds=gaus_param_bounds)
except RuntimeError:
raise FitError("Model fitting failed due to not finding "
Expand Down
14 changes: 14 additions & 0 deletions fooof/tests/core/test_funcs.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,20 @@ def test_gaussian_function():
assert max(ys) == hgt
assert np.allclose([i/sum(ys) for i in ys], norm.pdf(xs, ctr, wid))

def test_skewed_gaussian_function():

ctr, hgt, wid, alpha = 50, 5, 10, 4

xs = np.arange(1, 100)
ys_gaussian = gaussian_function(xs, ctr, hgt, wid)
ys = skewed_gaussian_function(xs, ctr, hgt, wid, alpha)

assert np.all(ys)

# Positive alphas shift the max to the right
assert np.argmax(ys) >= np.argmax(ys_gaussian)
assert np.max(ys) == np.max(ys_gaussian) == hgt

def test_expo_function():

off, knee, exp = 10, 5, 2
Expand Down