From c252b9c1b895cb59847eb976ebfacd9415606e88 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Mon, 24 Feb 2025 18:08:03 +0100 Subject: [PATCH 1/2] Expose gather mode to tridesclous2 and spykingcircus2 --- src/spikeinterface/core/node_pipeline.py | 6 ++--- .../sorters/internal/spyking_circus2.py | 16 +++++++++---- .../internal/tests/test_spykingcircus2.py | 24 +++++++++++++++++-- .../internal/tests/test_tridesclous2.py | 23 +++++++++++++++++- .../sorters/internal/tridesclous2.py | 19 +++++++++++---- .../sortingcomponents/matching/main.py | 24 +++++++++++++++---- 6 files changed, 93 insertions(+), 19 deletions(-) diff --git a/src/spikeinterface/core/node_pipeline.py b/src/spikeinterface/core/node_pipeline.py index d510204467..d0493645c7 100644 --- a/src/spikeinterface/core/node_pipeline.py +++ b/src/spikeinterface/core/node_pipeline.py @@ -560,14 +560,14 @@ def run_node_pipeline( The classical job_kwargs job_name : str The name of the pipeline used for the progress_bar - gather_mode : "memory" | "npz" - + gather_mode : "memory" | "npy" + How to gather the output of the nodes. gather_kwargs : dict OPtions to control the "gather engine". See GatherToMemory or GatherToNpy. squeeze_output : bool, default True If only one output node then squeeze the tuple folder : str | Path | None - Used for gather_mode="npz" + Used for gather_mode="npy" names : list of str Names of outputs. verbose : bool, default False diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 36b1383229..681c66ada3 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -51,7 +51,7 @@ class Spykingcircus2Sorter(ComponentsBasedSorter): }, }, "clustering": {"legacy": True}, - "matching": {"method": "circus-omp-svd"}, + "matching": {"method": "circus-omp-svd", "gather_mode": "memory"}, "apply_preprocessing": True, "matched_filtering": True, "cache_preprocessing": {"mode": "memory", "memory_limit": 0.5, "delete_cache": True}, @@ -321,14 +321,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces matching_method = params["matching"].pop("method") - matching_params = params["matching"].copy() + gather_mode = params["matching"].pop("gather_mode", "memory") + gather_kwargs = params["matching"].pop("gather_kwargs", {}) + matching_params = params["matching"].get("method_kwargs", {}).copy() matching_params["templates"] = templates if matching_method is not None: + if gather_mode == "npy": + gather_kwargs["folder"] = gather_kwargs.get("folder", sorter_output_folder / "matching") spikes = find_spikes_from_templates( - recording_w, matching_method, method_kwargs=matching_params, **job_kwargs + recording_w, + matching_method, + method_kwargs=matching_params, + gather_mode=gather_mode, + gather_kwargs=gather_kwargs, + **job_kwargs, ) - if params["debug"]: fitting_folder = sorter_output_folder / "fitting" fitting_folder.mkdir(parents=True, exist_ok=True) diff --git a/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py b/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py index df6e3821bb..8ab81e6c7c 100644 --- a/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py +++ b/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py @@ -1,8 +1,7 @@ import unittest from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite - -from spikeinterface.sorters import Spykingcircus2Sorter +from spikeinterface.sorters import Spykingcircus2Sorter, run_sorter from pathlib import Path @@ -10,6 +9,27 @@ class SpykingCircus2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase): SorterClass = Spykingcircus2Sorter + def test_with_numpy_gather(self): + recording = self.recording + sorter_name = self.SorterClass.sorter_name + output_folder = self.cache_folder / sorter_name + sorter_params = self.SorterClass.default_params() + + sorter_params["matching"]["gather_mode"] = "npy" + + sorting = run_sorter( + sorter_name, + recording, + folder=output_folder, + remove_existing_folder=True, + delete_output_folder=False, + verbose=False, + raise_error=True, + **sorter_params, + ) + assert (output_folder / "sorter_output" / "matching").is_dir() + assert (output_folder / "sorter_output" / "matching" / "spikes.npy").is_file() + if __name__ == "__main__": from spikeinterface import set_global_job_kwargs diff --git a/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py b/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py index b256dd1328..659a73ec50 100644 --- a/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py @@ -2,7 +2,7 @@ from spikeinterface.sorters.tests.common_tests import SorterCommonTestSuite -from spikeinterface.sorters import Tridesclous2Sorter +from spikeinterface.sorters import Tridesclous2Sorter, run_sorter from pathlib import Path @@ -10,6 +10,27 @@ class Tridesclous2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase): SorterClass = Tridesclous2Sorter + def test_with_numpy_gather(self): + recording = self.recording + sorter_name = self.SorterClass.sorter_name + output_folder = self.cache_folder / sorter_name + sorter_params = self.SorterClass.default_params() + + sorter_params["matching"]["gather_mode"] = "npy" + + sorting = run_sorter( + sorter_name, + recording, + folder=output_folder, + remove_existing_folder=True, + delete_output_folder=False, + verbose=False, + raise_error=True, + **sorter_params, + ) + assert (output_folder / "sorter_output" / "matching").is_dir() + assert (output_folder / "sorter_output" / "matching" / "spikes.npy").is_file() + if __name__ == "__main__": test = Tridesclous2SorterCommonTestSuite() diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 65dfb2ed45..7dc08e2d70 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -63,7 +63,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter): }, # "matching": {"method": "tridesclous", "method_kwargs": {"peak_shift_ms": 0.2, "radius_um": 100.0}}, # "matching": {"method": "circus-omp-svd", "method_kwargs": {}}, - "matching": {"method": "wobble", "method_kwargs": {}}, + "matching": {"method": "wobble", "method_kwargs": {}, "gather_mode": "memory"}, "job_kwargs": {"n_jobs": -1}, "save_array": True, } @@ -232,13 +232,22 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): templates = remove_empty_templates(templates) ## peeler - matching_method = params["matching"]["method"] - matching_params = params["matching"]["method_kwargs"].copy() + matching_method = params["matching"].pop("method") + gather_mode = params["matching"].pop("gather_mode", "memory") + gather_kwargs = params["matching"].pop("gather_kwargs", {}) + matching_params = params["matching"].get("matching_kwargs", {}).copy() matching_params["templates"] = templates - if params["matching"]["method"] in ("tdc-peeler",): + if matching_method in ("tdc-peeler",): matching_params["noise_levels"] = noise_levels + if gather_mode == "npy": + gather_kwargs = {"folder": gather_kwargs.get("folder", sorter_output_folder / "matching")} spikes = find_spikes_from_templates( - recording_for_peeler, method=matching_method, method_kwargs=matching_params, **job_kwargs + recording_for_peeler, + method=matching_method, + method_kwargs=matching_params, + gather_mode=gather_mode, + gather_kwargs=gather_kwargs, + **job_kwargs, ) if params["save_array"]: diff --git a/src/spikeinterface/sortingcomponents/matching/main.py b/src/spikeinterface/sortingcomponents/matching/main.py index f423d55e2a..d0ff8775ed 100644 --- a/src/spikeinterface/sortingcomponents/matching/main.py +++ b/src/spikeinterface/sortingcomponents/matching/main.py @@ -11,7 +11,14 @@ def find_spikes_from_templates( - recording, method="naive", method_kwargs={}, extra_outputs=False, verbose=False, **job_kwargs + recording, + method="naive", + method_kwargs={}, + extra_outputs=False, + gather_mode="memory", + gather_kwargs=None, + verbose=False, + **job_kwargs, ) -> np.ndarray | tuple[np.ndarray, dict]: """Find spike from a recording from given templates. @@ -25,10 +32,14 @@ def find_spikes_from_templates( Keyword arguments for the chosen method extra_outputs : bool If True then a dict is also returned is also returned - **job_kwargs : dict - Parameters for ChunkRecordingExecutor + gather_mode : "memory" | "npy", default: "memory" + If "memory" then the output is gathered in memory, if "npy" then the output is gathered on disk + gather_kwargs : dict, optional + The kwargs for the gather method verbose : Bool, default: False If True, output is verbose + **job_kwargs : keyword arguments + Parameters for ChunkRecordingExecutor Returns ------- @@ -47,13 +58,18 @@ def find_spikes_from_templates( node0 = method_class(recording, **method_kwargs) nodes = [node0] + gather_kwargs = gather_kwargs or {} + names = ["spikes"] + spikes = run_node_pipeline( recording, nodes, job_kwargs, job_name=f"find spikes ({method})", - gather_mode="memory", + gather_mode=gather_mode, squeeze_output=True, + names=names, + **gather_kwargs, ) if extra_outputs: outputs = node0.get_extra_outputs() From c2f5bdbe41015d42d1cf7fb5468de0ff391a2078 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 7 Mar 2025 15:19:03 +0100 Subject: [PATCH 2/2] Apply suggestions from code review --- src/spikeinterface/sorters/internal/spyking_circus2.py | 4 ++-- .../sorters/internal/tests/test_spykingcircus2.py | 1 + .../sorters/internal/tests/test_tridesclous2.py | 1 + src/spikeinterface/sorters/internal/tridesclous2.py | 4 ++-- 4 files changed, 6 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/sorters/internal/spyking_circus2.py b/src/spikeinterface/sorters/internal/spyking_circus2.py index 681c66ada3..ea01d3c6ca 100644 --- a/src/spikeinterface/sorters/internal/spyking_circus2.py +++ b/src/spikeinterface/sorters/internal/spyking_circus2.py @@ -322,13 +322,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## We launch a OMP matching pursuit by full convolution of the templates and the raw traces matching_method = params["matching"].pop("method") gather_mode = params["matching"].pop("gather_mode", "memory") - gather_kwargs = params["matching"].pop("gather_kwargs", {}) matching_params = params["matching"].get("method_kwargs", {}).copy() matching_params["templates"] = templates if matching_method is not None: + gather_kwargs = {} if gather_mode == "npy": - gather_kwargs["folder"] = gather_kwargs.get("folder", sorter_output_folder / "matching") + gather_kwargs["folder"] = sorter_output_folder / "matching" spikes = find_spikes_from_templates( recording_w, matching_method, diff --git a/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py b/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py index 8ab81e6c7c..5188dd8329 100644 --- a/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py +++ b/src/spikeinterface/sorters/internal/tests/test_spykingcircus2.py @@ -9,6 +9,7 @@ class SpykingCircus2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase): SorterClass = Spykingcircus2Sorter + @unittest.skip("performance reason") def test_with_numpy_gather(self): recording = self.recording sorter_name = self.SorterClass.sorter_name diff --git a/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py b/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py index 659a73ec50..1f1f109d28 100644 --- a/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tests/test_tridesclous2.py @@ -10,6 +10,7 @@ class Tridesclous2SorterCommonTestSuite(SorterCommonTestSuite, unittest.TestCase): SorterClass = Tridesclous2Sorter + @unittest.skip("performance reason") def test_with_numpy_gather(self): recording = self.recording sorter_name = self.SorterClass.sorter_name diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 7dc08e2d70..8fb2fcfaf3 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -234,13 +234,13 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ## peeler matching_method = params["matching"].pop("method") gather_mode = params["matching"].pop("gather_mode", "memory") - gather_kwargs = params["matching"].pop("gather_kwargs", {}) matching_params = params["matching"].get("matching_kwargs", {}).copy() matching_params["templates"] = templates if matching_method in ("tdc-peeler",): matching_params["noise_levels"] = noise_levels + gather_kwargs = {} if gather_mode == "npy": - gather_kwargs = {"folder": gather_kwargs.get("folder", sorter_output_folder / "matching")} + gather_kwargs["folder"] = sorter_output_folder / "matching" spikes = find_spikes_from_templates( recording_for_peeler, method=matching_method,