Skip to content

Commit

Permalink
Benchmarking code improvement
Browse files Browse the repository at this point in the history
  • Loading branch information
MerlinDumeur committed May 25, 2024
1 parent 1198bf2 commit ffb6513
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 3 deletions.
78 changes: 78 additions & 0 deletions pymultifracs/robust/benchmark.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,91 @@
from dataclasses import dataclass, field
from typing import Any, Callable

import numpy as np
import pandas as pd

from tqdm.auto import tqdm
from joblib import Parallel, delayed

from .. import wavelet_analysis, mfa
from .robust import get_outliers
from ..simul.noisy import gen_noisy


def get_grid(param_grid):

series = [
pd.DataFrame({name: signal_param})
for name, signal_param in param_grid.items()]

out = series[0]

for s in series[1:]:
out = out.merge(s, how='cross')

return out


@dataclass
class Benchmark:
signal_param_grid: dict[str, np.ndarray]
noise_param_grid: dict[str, np.ndarray]
signal_gen_func: Callable
noise_gen_func: Callable
estimation_grid: dict[str, Callable]
WT_params: dict[str, Any]
results: pd.DataFrame = field(init=False, repr=False)

def run(self, n_rep):

results = {}

signal_grid = get_grid(self.signal_param_grid)
noise_grid = get_grid(self.noise_param_grid)

for signal_params in signal_grid.itertuples(index=False):

signal_names = [*signal_params._fields]
signal_params = signal_params._asdict()

# X = self.signal_gen_func(**signal_params)
# print(X.shape)

X = np.c_[
*[self.signal_gen_func(**signal_params)
for i in range(n_rep)]]

# for repetition in range(n_rep):

for noise_params in noise_grid.itertuples(index=False):

noise_names = [*noise_params._fields]
noise_params = noise_params._asdict()

X_noisy = self.noise_gen_func(X, **noise_params)

for method, est_fun in tqdm(self.estimation_grid.items()):

WT = wavelet_analysis(X_noisy, **self.WT_params)

results[(method, *signal_params.values(), *noise_params.values())] = [est_fun(WT)]

self.results = pd.DataFrame.from_dict(results).transpose()

for i, name in enumerate(signal_names):
if name in noise_names:

signal_names[i] = name + '_signal'
noise_names[noise_names.index(name)] = name + '_noise'

self.results.index.names = [
'method', *signal_names, *noise_names]
self.results.columns.names = ['cumulants']

def plot(self):
pass


def estimate(gen_func, robust_cm=False, bootstrap_weight=False,
outlier_detect=False, alpha=1, generalized=False,
gen_func_kwargs=None, robust_kwargs=None):
Expand Down
4 changes: 2 additions & 2 deletions pymultifracs/robust/robust.py
Original file line number Diff line number Diff line change
Expand Up @@ -1016,8 +1016,8 @@ def cluster_reject_leaders(j1, j2, cm, leaders, pelt_beta, verbose=False,
continue

right_edge = np.nanmax(agg[:, j, idx_range, idx_signal])
bins = np.sort(
np.r_[1, 1-np.geomspace(1 - right_edge, 1, N_bins-1)])
# bins = np.sort(
# np.r_[1, 1-np.geomspace(1 - right_edge, 1, N_bins-1)])

for i in range(len(samples)):

Expand Down
5 changes: 4 additions & 1 deletion pymultifracs/simul/mrw.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# import matplotlib.pyplot as plt


def mrw(shape, H, lam, L, sigma=1, method='cme', z0=(None, None)):
def mrw(shape, H, lam, L=None, sigma=1, method='cme', z0=(None, None)):
'''
Create a realization of fractional Brownian motion using circulant
matrix embedding.
Expand Down Expand Up @@ -53,6 +53,9 @@ def mrw(shape, H, lam, L, sigma=1, method='cme', z0=(None, None)):
if not 0 <= H <= 1:
raise ValueError('H must satisfy 0 <= H <= 1')

if L is None:
L = N

if L > N:
raise ValueError('Integral scale L is larger than data length N')

Expand Down
3 changes: 3 additions & 0 deletions pymultifracs/wavelet.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,9 @@ def filtering2(approx, wt):

if approx.shape[0] % 2 == 1:
return -high[:-1], low[fp:lp]

# if approx.shape[0] % 2 == 1:
return -high[:-1], low[fp:lp]

if lp == -1:
low_slice = np.s_[fp:]
Expand Down

0 comments on commit ffb6513

Please sign in to comment.