Skip to content

Adding support for TKEO operation #3668

Open
@jesusdpa1

Description

@jesusdpa1

Hi, I have written this implementation of the Teager-Kaiser energy operator (TKEO), but I am not sure where it should be added for a pull request.

from enum import Enum, auto

import numpy as np
from spikeinterface.core import get_chunk_with_margin
from spikeinterface.core.core_tools import define_function_from_class
from spikeinterface.preprocessing.basepreprocessor import (
    BasePreprocessor,
    BasePreprocessorSegment,
)


class TKEOMethod(Enum):
    """Selectable variants of the TKEO computation.

    Member values are the consecutive integers 1, 2, 3 in declaration order.
    """

    # Li et al. 2007 (2 samples)
    LI_2007 = 1
    # Deburchgrave et al. 2008 (4 samples)
    DEBURCHGRAVE_2008 = 2
    # Original Teager-Kaiser method
    ORIGINAL = 3


class TKEORecording(BasePreprocessor):
    """Recording whose traces are the TKEO of the parent recording's traces.

    One ``TKEORecordingSegment`` is created per segment of the parent
    recording; the actual computation happens in that segment class.

    Parameters
    ----------
    recording : recording object
        The parent recording to transform.
    margin_ms : float, default: 5.0
        Extra context fetched on each side of a requested chunk, in
        milliseconds. It is converted to samples with the parent's sampling
        frequency and is never allowed to drop below 2 samples, because the
        operator leaves up to 2 border samples of every processed chunk at
        zero.
    dtype : dtype-like or None, default: None
        Output dtype. ``None`` means the parent's dtype. Unsigned integer
        dtypes are replaced by the signed dtype of the same width.
    tkeo_method : TKEOMethod or str, default: TKEOMethod.DEBURCHGRAVE_2008
        Which variant to compute. A string is interpreted as the
        (case-insensitive) name of a ``TKEOMethod`` member.
    add_reflect_padding : bool, default: False
        Forwarded to ``get_chunk_with_margin`` by the segments.

    Raises
    ------
    ValueError
        If ``tkeo_method`` is neither a ``TKEOMethod`` member nor the name
        of one.
    """

    # Smallest margin (in samples) that keeps the zeroed border samples of a
    # processed chunk out of the returned frames.
    _MIN_MARGIN_SAMPLES = 2

    def __init__(
        self,
        recording,
        margin_ms=5.0,
        dtype=None,
        tkeo_method=TKEOMethod.DEBURCHGRAVE_2008,
        add_reflect_padding=False,
    ):
        dtype = self._fix_dtype(recording, dtype)
        # Fail early on an invalid method instead of at the first get_traces().
        tkeo_method = self._fix_method(tkeo_method)

        BasePreprocessor.__init__(self, recording, dtype=dtype)
        self.annotate(is_tkeo=True)

        # The operator is non-linear, so the parent's additive offset is reset.
        # NOTE(review): the channel gains are left untouched - confirm this is
        # the intended scaling for the TKEO output.
        if "offset_to_uV" in self.get_property_keys():
            self.set_channel_offsets(0)

        margin = int(margin_ms * recording.get_sampling_frequency() / 1000.0)
        margin = max(margin, self._MIN_MARGIN_SAMPLES)
        for parent_segment in recording._recording_segments:
            self.add_recording_segment(
                TKEORecordingSegment(
                    parent_segment,
                    margin,
                    dtype,
                    tkeo_method=tkeo_method,
                    add_reflect_padding=add_reflect_padding,
                )
            )

        # Only plain, serializable values are stored (same reason the dtype is
        # stored as a string): the method is kept by name and converted back
        # to a TKEOMethod by _fix_method when the object is rebuilt.
        self._kwargs = dict(
            recording=recording,
            margin_ms=margin_ms,
            dtype=dtype.str,
            tkeo_method=tkeo_method.name,
            add_reflect_padding=add_reflect_padding,
        )

    @staticmethod
    def _fix_dtype(recording, dtype):
        """Return the output dtype as a signed/float ``np.dtype``.

        ``None`` falls back to the dtype of ``recording``; an unsigned integer
        dtype is replaced by the signed integer dtype of the same width.
        """
        if dtype is None:
            dtype = recording.get_dtype()
        dtype = np.dtype(dtype)

        # if uint --> force int
        if dtype.kind == "u":
            dtype = np.dtype(dtype.str.replace("u", "i"))

        return dtype

    @staticmethod
    def _fix_method(tkeo_method):
        """Return ``tkeo_method`` as a ``TKEOMethod`` member.

        Accepts a member or the case-insensitive name of one; anything else
        raises ``ValueError``.
        """
        if isinstance(tkeo_method, TKEOMethod):
            return tkeo_method
        if isinstance(tkeo_method, str):
            try:
                return TKEOMethod[tkeo_method.upper()]
            except KeyError:
                pass
        valid_names = [method.name for method in TKEOMethod]
        raise ValueError(
            f"Unknown TKEO method: {tkeo_method!r}. Expected one of {valid_names}"
        )


class TKEORecordingSegment(BasePreprocessorSegment):
    """Segment that computes the TKEO of its parent segment on demand.

    Parameters
    ----------
    parent_recording_segment : recording segment
        Segment the raw traces are read from.
    margin : int
        Number of extra samples requested on each side of a chunk.
    dtype : np.dtype
        Dtype of the traces returned by ``get_traces``.
    tkeo_method : TKEOMethod, default: TKEOMethod.DEBURCHGRAVE_2008
        Variant computed by ``apply_tkeo``.
    add_reflect_padding : bool, default: False
        Forwarded to ``get_chunk_with_margin``.
    """

    def __init__(
        self,
        parent_recording_segment,
        margin,
        dtype,
        tkeo_method=TKEOMethod.DEBURCHGRAVE_2008,
        add_reflect_padding=False,
    ):
        BasePreprocessorSegment.__init__(self, parent_recording_segment)
        self.margin = margin
        self.add_reflect_padding = add_reflect_padding
        self.dtype = dtype
        self.tkeo_method = tkeo_method

    def get_traces(self, start_frame, end_frame, channel_indices):
        """Return the TKEO of the requested frames, cast to ``self.dtype``.

        The parent chunk is fetched with ``self.margin`` extra samples, the
        operator is applied to the whole chunk, and the margins are removed
        again. For integer output dtypes the values are rounded and saturated
        to the dtype's range before the cast.
        """
        traces_chunk, left_margin, right_margin = get_chunk_with_margin(
            self.parent_recording_segment,
            start_frame,
            end_frame,
            channel_indices,
            self.margin,
            add_reflect_padding=self.add_reflect_padding,
        )

        # Apply TKEO with selected method; the result has as many rows as
        # traces_chunk, so the margins can be trimmed by position.
        tkeo_traces = self.apply_tkeo(traces_chunk)
        stop = tkeo_traces.shape[0] - right_margin
        tkeo_traces = tkeo_traces[left_margin:stop, :]

        if np.issubdtype(self.dtype, np.integer):
            # TKEO values grow with the square/cube of the input amplitude and
            # can exceed the integer range; saturate instead of relying on an
            # out-of-range float -> int cast.
            limits = np.iinfo(self.dtype)
            tkeo_traces = np.clip(tkeo_traces.round(), limits.min, limits.max)

        return tkeo_traces.astype(self.dtype)

    def apply_tkeo(self, traces):
        """
        Apply TKEO based on selected method

        With ``x`` the input along the first axis, the methods compute:

        * ``LI_2007``: ``|x[n] * (x[n+1]**2 - x[n] * x[n+2])``|, last 2 rows zero
        * ``DEBURCHGRAVE_2008``: same expression, first 2 and last 2 rows zero
        * ``ORIGINAL``: ``|x[n]**2 - x[n-1] * x[n+1]|``, first and last row zero

        NOTE(review): LI_2007 and DEBURCHGRAVE_2008 evaluate the same
        expression and differ only in which border rows are zeroed - confirm
        both formulas against the cited papers.

        Parameters:
        -----------
        traces : np.ndarray
            Input traces, indexed by sample along the first axis

        Returns:
        --------
        np.ndarray
            TKEO-transformed traces with the same shape as ``traces``.
            Floating-point input keeps its dtype; any other input is
            computed and returned as float64.

        Raises:
        -------
        ValueError
            If ``self.tkeo_method`` is not a known method
        """
        # Products of fixed-width integers (e.g. int16 samples) wrap around
        # silently, so integer input is promoted before any arithmetic.
        if not np.issubdtype(traces.dtype, np.floating):
            traces = traces.astype(np.float64)

        # Rows not written below stay at zero, which keeps the output aligned
        # sample-for-sample with the input for every method.
        result = np.zeros_like(traces)

        if self.tkeo_method == TKEOMethod.LI_2007:
            # Li et al. 2007 method (2 samples)
            result[:-2] = traces[:-2] * (
                traces[1:-1] ** 2 - traces[:-2] * traces[2:]
            )

        elif self.tkeo_method == TKEOMethod.DEBURCHGRAVE_2008:
            # Deburchgrave et al. 2008 method (4 samples)
            result[2:-2] = traces[2:-2] * (
                traces[3:-1] ** 2 - traces[2:-2] * traces[4:]
            )

        elif self.tkeo_method == TKEOMethod.ORIGINAL:
            # Original Teager-Kaiser method
            result[1:-1] = traces[1:-1] ** 2 - traces[:-2] * traces[2:]

        else:
            raise ValueError(f"Unknown TKEO method: {self.tkeo_method}")

        return np.abs(result)


# Function-style entry point for TKEORecording, built by
# define_function_from_class and registered under the name "tkeo_transform".
_TKEO_FUNCTION_NAME = "tkeo_transform"
tkeo_transform = define_function_from_class(source_class=TKEORecording, name=_TKEO_FUNCTION_NAME)

Metadata

Metadata

Assignees

No one assigned

    Labels

    preprocessing - Related to preprocessing module

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions