Skip to content

Commit

Permalink
Initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
fmeirinhos committed Jan 18, 2024
1 parent 7e9c370 commit a115430
Showing 1 changed file with 29 additions and 53 deletions.
82 changes: 29 additions & 53 deletions kramersmoyal/kmc.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
import numpy as np
from scipy.signal import convolve
from scipy.special import factorial
from itertools import product

from .binning import histogramdd
from .kernels import silvermans_rule, epanechnikov

def km(timeseries: np.ndarray, bins: str='default', powers: int=4,
def km(timeseries: np.ndarray, powers: np.ndarray, bins: np.ndarray=None,
kernel: callable=epanechnikov, bw: float=None, tol: float=1e-10,
conv_method: str='auto', center_edges: bool=True,
full: bool=False) -> (np.ndarray, np.ndarray):
Expand All @@ -21,7 +20,7 @@ def km(timeseries: np.ndarray, bins: str='default', powers: int=4,
The D-dimensional timeseries `(N, D)`. The timeseries of length `N`
and dimensions `D`.
bins: int or list or np.ndarray or string (default `default`)
bins: list or np.ndarray (default `None`)
The number of bins. This is the underlying space for the Kramers─Moyal
coefficients to be estimated. If desired, bins along each dimension can
be given as monotonically increasing bin edges (tuple or list), e.g.,
Expand Down Expand Up @@ -112,70 +111,47 @@ def km(timeseries: np.ndarray, bins: str='default', powers: int=4,
373(39), 3507─3512, 2009.
"""

# Check finiteness, dimensions, and existence of the time series
timeseries = np.asarray_chkfinite(timeseries, dtype=float)
if len(timeseries.shape) == 1:
timeseries = timeseries.reshape(-1, 1)

# Safety check: the data must be vertical, i.e. shaped (N, dims)
assert timeseries.shape[1] < timeseries.shape[0], \
"Timeseries seems to be (D, N) shaped, transpose it: Timeseries.T"

assert len(timeseries.shape) == 2, "Timeseries must be (N, D) shape"
assert timeseries.shape[0] > 0, "No data in timeseries"

timeseries = np.atleast_2d(np.asarray_chkfinite(timeseries, dtype=float))
n, dims = timeseries.shape

# Transforming powers into the right shape
if isinstance(powers, int):
# complicated way of obtaining powers in all dimensions
powers = np.array(sorted(product(*(range(powers + 1),) * dims),
key=lambda x: (max(x), x)))

powers = np.asarray_chkfinite(powers, dtype=float)
if len(powers.shape) == 1:
powers = powers.reshape(-1, 1)

if not (powers[0] == [0] * dims).all():
powers = np.array([[0] * dims, *powers])

assert (powers[0] == [0] * dims).all(), "First power must be zero"
assert dims == powers.shape[1], "Powers not matching timeseries' dimension"
if dims >= n:
raise ValueError("Timeseries should be transposed to (N, D) shape.")

powers = np.atleast_2d(np.asarray_chkfinite(powers, dtype=float))

# NOTE: `_km` always requires the first element of `powers` to map to the pdf
if not np.array_equal(powers[0], [0] * dims):
powers = np.vstack((np.zeros((1, dims)), powers))
remove_pdf = True
else:
remove_pdf = False

# Check and adjust bins
if isinstance(bins, str):
if bins == 'default':
bins = [5000] if dims == 1 else bins
bins = [100] * 2 if dims == 2 else bins
bins = [25] * 3 if dims == 3 else bins
assert dims < 4, "If dimension of timeseries > 3, set bins manually"
if dims != powers.shape[1]:
raise ValueError("Powers dimensions do not match timeseries dimensions.")

if isinstance(bins, int):
bins = [int(bins**(1/dims))] * dims
if bins is None:
# Sturges' formula
bins = np.full((dims,), np.ceil(np.log2(n) + 1), dtype=int)

if isinstance(bins, (list, tuple)):
assert all(isinstance(ele, (int, np.ndarray)) for ele in bins), \
"list or tuples of bins must either be ints or arrays"
bins = np.asarray_chkfinite(bins, dtype=int)

# bins = np.asarray_chkfinite(bins, dtype=int)
assert dims == len(bins), "Bins not matching timeseries' dimension"
if dims != len(bins):
raise ValueError("Bins dimensions do not match timeseries dimensions.")

if bw is None:
bw = silvermans_rule(timeseries)
elif callable(bw):
bw = bw(timeseries)
assert bw > 0.0, "Bandwidth must be > 0"
bw = bw(timeseries) if callable(bw) else silvermans_rule(timeseries) if bw is None else bw
if bw <= 0.0:
raise ValueError("Bandwidth must be positive.")

# This is where the calculations take place
kmc, edges = _km(timeseries, bins, powers, kernel, bw, tol, conv_method)

if remove_pdf:
kmc = kmc[1:, ...]

if center_edges:
edges = [edge[:-1] + 0.5 * (edge[1] - edge[0]) for edge in edges]

if not full:
return (kmc, edges)
else:
return (kmc, edges, bw, powers)
return (kmc, edges, bw, powers) if full else (kmc, edges)


def _km(timeseries: np.ndarray, bins: np.ndarray, powers: np.ndarray,
Expand Down

0 comments on commit a115430

Please sign in to comment.