Open
Description
Hi I have written this implementation of TKEO but I am not sure where this can be added for a pull request
from enum import Enum, auto
import numpy as np
from spikeinterface.core import get_chunk_with_margin
from spikeinterface.core.core_tools import define_function_from_class
from spikeinterface.preprocessing.basepreprocessor import (
BasePreprocessor,
BasePreprocessorSegment,
)
class TKEOMethod(Enum):
"""Enumeration of TKEO calculation methods"""
LI_2007 = auto() # Li et al. 2007 (2 samples)
DEBURCHGRAVE_2008 = auto() # Deburchgrave et al. 2008 (4 samples)
ORIGINAL = auto() # Original Teager-Kaiser method
class TKEORecording(BasePreprocessor):
def __init__(
self,
recording,
margin_ms=5.0,
dtype=None,
tkeo_method=TKEOMethod.DEBURCHGRAVE_2008,
add_reflect_padding=False,
):
dtype = self._fix_dtype(recording, dtype)
BasePreprocessor.__init__(self, recording, dtype=dtype)
self.annotate(is_tkeo=True)
if "offset_to_uV" in self.get_property_keys():
self.set_channel_offsets(0)
margin = int(margin_ms * recording.get_sampling_frequency() / 1000.0)
for parent_segment in recording._recording_segments:
self.add_recording_segment(
TKEORecordingSegment(
parent_segment,
margin,
dtype,
tkeo_method=tkeo_method,
add_reflect_padding=add_reflect_padding,
)
)
self._kwargs = dict(
recording=recording,
margin_ms=margin_ms,
dtype=dtype.str,
tkeo_method=tkeo_method,
add_reflect_padding=add_reflect_padding,
)
@staticmethod
def _fix_dtype(recording, dtype):
if dtype is None:
dtype = recording.get_dtype()
dtype = np.dtype(dtype)
# if uint --> force int
if dtype.kind == "u":
dtype = np.dtype(dtype.str.replace("u", "i"))
return dtype
class TKEORecordingSegment(BasePreprocessorSegment):
def __init__(
self,
parent_recording_segment,
margin,
dtype,
tkeo_method=TKEOMethod.DEBURCHGRAVE_2008,
add_reflect_padding=False,
):
BasePreprocessorSegment.__init__(self, parent_recording_segment)
self.margin = margin
self.add_reflect_padding = add_reflect_padding
self.dtype = dtype
self.tkeo_method = tkeo_method
def get_traces(self, start_frame, end_frame, channel_indices):
traces_chunk, left_margin, right_margin = get_chunk_with_margin(
self.parent_recording_segment,
start_frame,
end_frame,
channel_indices,
self.margin,
add_reflect_padding=self.add_reflect_padding,
)
# Apply TKEO with selected method
tkeo_traces = self.apply_tkeo(traces_chunk)
if right_margin > 0:
tkeo_traces = tkeo_traces[left_margin:-right_margin, :]
else:
tkeo_traces = tkeo_traces[left_margin:, :]
if np.issubdtype(self.dtype, np.integer):
tkeo_traces = tkeo_traces.round()
return tkeo_traces.astype(self.dtype)
def apply_tkeo(self, traces):
"""
Apply TKEO based on selected method
Parameters:
-----------
traces : np.ndarray
Input traces
Returns:
--------
np.ndarray
TKEO-transformed traces
"""
if self.tkeo_method == TKEOMethod.LI_2007:
# Li et al. 2007 method (2 samples)
return np.abs(traces[:-2] * (traces[1:-1] ** 2 - traces[:-2] * traces[2:]))
elif self.tkeo_method == TKEOMethod.DEBURCHGRAVE_2008:
# Deburchgrave et al. 2008 method (4 samples)
result = np.zeros_like(traces)
result[2:-2] = traces[2:-2] * (
traces[3:-1] ** 2 - traces[2:-2] * traces[4:]
)
return np.abs(result)
elif self.tkeo_method == TKEOMethod.ORIGINAL:
# Original Teager-Kaiser method
result = np.zeros_like(traces)
result[1:-1] = traces[1:-1] ** 2 - traces[:-2] * traces[2:]
return np.abs(result)
else:
raise ValueError(f"Unknown TKEO method: {self.tkeo_method}")
tkeo_transform = define_function_from_class(
source_class=TKEORecording, name="tkeo_transform"
)