From 96863de4a8adba8ae46e97ac166cb4f9dc6cbfad Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 11 Mar 2025 13:22:26 +0100 Subject: [PATCH 01/22] Add pydantic curation model and improve curation format and merging rules --- pyproject.toml | 1 + src/spikeinterface/core/sorting_tools.py | 13 +- src/spikeinterface/core/sortinganalyzer.py | 22 +- .../curation/curation_format.py | 191 ++++++++---------- src/spikeinterface/curation/curation_model.py | 144 +++++++++++++ .../curation/tests/test_curation_format.py | 14 +- 6 files changed, 256 insertions(+), 129 deletions(-) create mode 100644 src/spikeinterface/curation/curation_model.py diff --git a/pyproject.toml b/pyproject.toml index 97ba77299e..10a21f3595 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,7 @@ full = [ "pandas", "scipy", "scikit-learn", + "pydantic", "networkx", "distinctipy", "matplotlib>=3.6", # matplotlib.colormaps diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 47ce8cf848..ca8c731040 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -232,8 +232,13 @@ def random_spikes_selection( def apply_merges_to_sorting( - sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_extra=False, new_id_strategy="append" -): + sorting: BaseSorting, + merge_unit_groups: list[list[int | str]] | list[tuple[int | str]], + new_unit_ids: list[int | str] | None = None, + censor_ms: float | None = None, + return_extra: bool = False, + new_id_strategy: str = "append", +) -> NumpySorting | tuple[NumpySorting, np.ndarray, list[int | str]]: """ Apply a resolved representation of the merges to a sorting object. @@ -245,9 +250,9 @@ def apply_merges_to_sorting( Parameters ---------- - sorting : Sorting + sorting : BaseSorting The Sorting object to apply merges. - merge_unit_groups : list/tuple of lists/tuples + merge_unit_groups : list of lists/tuples A list of lists for every merge group. Each element needs to have at least two elements (two units to merge), but it can also have more (merge multiple units at once). new_unit_ids : list | None, default: None diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 85d405c443..15478b8001 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1138,18 +1138,18 @@ def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "Sortin def merge_units( self, - merge_unit_groups, - new_unit_ids=None, - censor_ms=None, - merging_mode="soft", - sparsity_overlap=0.75, - new_id_strategy="append", - return_new_unit_ids=False, - format="memory", - folder=None, - verbose=False, + merge_unit_groups: list[list[str | int]] | list[tuple[str | int]], + new_unit_ids: list[int | str] | None = None, + censor_ms: float | None = None, + merging_mode: str = "soft", + sparsity_overlap: float = 0.75, + new_id_strategy: str = "append", + return_new_unit_ids: bool = False, + format: str = "memory", + folder: Path | str | None = None, + verbose: bool = False, **job_kwargs, - ) -> "SortingAnalyzer": + ) -> "SortingAnalyzer | tuple[SortingAnalyzer, list[int | str]]": """ This method is equivalent to `save_as()`but with a list of merges that have to be achieved. 
Merges units by creating a new SortingAnalyzer object with the appropriate merges diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 80f251ca43..186bb34568 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -1,14 +1,14 @@ -from itertools import combinations +from __future__ import annotations +import copy import numpy as np +from spikeinterface import curation from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting -import copy - -supported_curation_format_versions = {"1"} +from spikeinterface.curation.curation_model import CurationModel -def validate_curation_dict(curation_dict): +def validate_curation_dict(curation_dict: dict): """ Validate that the curation dictionary given as parameter complies with the format @@ -19,61 +19,11 @@ def validate_curation_dict(curation_dict): curation_dict : dict """ + # this will validate the format of the curation_dict + CurationModel(**curation_dict) - # format - if "format_version" not in curation_dict: - raise ValueError("No version_format") - - if curation_dict["format_version"] not in supported_curation_format_versions: - raise ValueError( - f"Format version ({curation_dict['format_version']}) not supported. " - f"Only {supported_curation_format_versions} are valid" - ) - - # unit_ids - labeled_unit_set = set([lbl["unit_id"] for lbl in curation_dict["manual_labels"]]) - merged_units_set = set(sum(curation_dict["merge_unit_groups"], [])) - removed_units_set = set(curation_dict["removed_units"]) - - if curation_dict["unit_ids"] is not None: - # old format v0 did not contain unit_ids so this can contains None - unit_set = set(curation_dict["unit_ids"]) - if not labeled_unit_set.issubset(unit_set): - raise ValueError("Curation format: some labeled units are not in the unit list") - if not merged_units_set.issubset(unit_set): - raise ValueError("Curation format: some merged units are not in the unit list") - if not removed_units_set.issubset(unit_set): - raise ValueError("Curation format: some removed units are not in the unit list") - for group in curation_dict["merge_unit_groups"]: - if len(group) < 2: - raise ValueError("Curation format: 'merge_unit_groups' must be list of list with at least 2 elements") - - all_merging_groups = [set(group) for group in curation_dict["merge_unit_groups"]] - for gp_1, gp_2 in combinations(all_merging_groups, 2): - if len(gp_1.intersection(gp_2)) != 0: - raise ValueError("Curation format: some units belong to multiple merge groups") - if len(removed_units_set.intersection(merged_units_set)) != 0: - raise ValueError("Curation format: some units were merged and deleted") - - # Check the labels exclusivity - for lbl in curation_dict["manual_labels"]: - for label_key in curation_dict["label_definitions"].keys(): - if label_key in lbl: - unit_id = lbl["unit_id"] - label_value = lbl[label_key] - if not isinstance(label_value, list): - raise ValueError(f"Curation format: manual_labels {unit_id} is invalid shoudl be a list") - - is_exclusive = curation_dict["label_definitions"][label_key]["exclusive"] - - if is_exclusive and not len(label_value) <= 1: - raise ValueError( - f"Curation format: manual_labels {unit_id} {label_key} are exclusive labels. 
{label_value} is invalid" - ) - - -def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_format="1"): +def convert_from_sortingview_curation_format_v0(sortingview_dict: dict, destination_format: str = "1"): """ Converts the old sortingview curation format (v0) into a curation dictionary new format (v1) Couple of caveats: @@ -99,7 +49,6 @@ def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_fo if "mergeGroups" not in sortingview_dict.keys(): sortingview_dict["mergeGroups"] = [] merge_groups = sortingview_dict["mergeGroups"] - merged_units = sum(merge_groups, []) first_unit_id = next(iter(sortingview_dict["labelsByUnit"].keys())) if str.isdigit(first_unit_id): @@ -115,13 +64,19 @@ def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_fo all_labels.extend(l_labels) # recorver the correct type for unit_id unit_id = unit_id_type(unit_id_) - all_units.append(unit_id) + if unit_id not in all_units: + all_units.append(unit_id) manual_labels.append({"unit_id": unit_id, general_cat: l_labels}) labels_def = {"all_labels": {"name": "all_labels", "label_options": list(set(all_labels)), "exclusive": False}} + for merge_group in merge_groups: + for unit_id in merge_group: + if unit_id not in all_units: + all_units.append(unit_id) + curation_dict = { "format_version": destination_format, - "unit_ids": None, + "unit_ids": all_units, "label_definitions": labels_def, "manual_labels": manual_labels, "merge_unit_groups": merge_groups, @@ -131,7 +86,7 @@ def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_fo return curation_dict -def curation_label_to_vectors(curation_dict): +def curation_label_to_vectors(curation_dict_or_model: dict | CurationModel): """ Transform the curation dict into dict of vectors. For label category with exclusive=True : a column is created and values are the unique label. 
@@ -141,43 +96,46 @@ def curation_label_to_vectors(curation_dict):

     Parameters
     ----------
-    curation_dict : dict
-        A curation dictionary
+    curation_dict_or_model : dict | CurationModel
+        A curation dictionary or model

     Returns
     -------
     labels : dict of numpy vector

     """
-    unit_ids = list(curation_dict["unit_ids"])
+    if isinstance(curation_dict_or_model, dict):
+        curation_model = CurationModel(**curation_dict_or_model)
+    else:
+        curation_model = curation_dict_or_model
+    unit_ids = list(curation_model.unit_ids)
     n = len(unit_ids)
     labels = {}
-    for label_key, label_def in curation_dict["label_definitions"].items():
-        if label_def["exclusive"]:
+    for label_key, label_def in curation_model.label_definitions.items():
+        if label_def.exclusive:
             assert label_key not in labels, f"{label_key} is already a key"
             labels[label_key] = [""] * n
-            for lbl in curation_dict["manual_labels"]:
-                value = lbl.get(label_key, [])
-                if len(value) == 1:
-                    unit_index = unit_ids.index(lbl["unit_id"])
-                    labels[label_key][unit_index] = value[0]
+            for manual_label in curation_model.manual_labels:
+                values = manual_label.labels.get(label_key, [])
+                if len(values) == 1:
+                    unit_index = unit_ids.index(manual_label.unit_id)
+                    labels[label_key][unit_index] = values[0]
             labels[label_key] = np.array(labels[label_key])
         else:
-            for label_opt in label_def["label_options"]:
+            for label_opt in label_def.label_options:
                 assert label_opt not in labels, f"{label_opt} is already a key"
                 labels[label_opt] = np.zeros(n, dtype=bool)
-            for lbl in curation_dict["manual_labels"]:
-                values = lbl.get(label_key, [])
+            for manual_label in curation_model.manual_labels:
+                values = manual_label.labels.get(label_key, [])
                 for value in values:
-                    unit_index = unit_ids.index(lbl["unit_id"])
+                    unit_index = unit_ids.index(manual_label.unit_id)
                     labels[value][unit_index] = True
-
     return labels


-def clean_curation_dict(curation_dict):
+def clean_curation_dict(curation_dict: dict):
     """
     In some cases the curation_dict can have inconsistencies (like in the sorting view format).
     For instance, some unit_ids are both in 'merge_unit_groups' and 'removed_units'.
@@ -200,7 +158,7 @@ def clean_curation_dict(curation_dict):
     return curation_dict


-def curation_label_to_dataframe(curation_dict):
+def curation_label_to_dataframe(curation_dict_or_model: dict | CurationModel):
     """
     Transform the curation dict into a pandas dataframe.
     For label category with exclusive=True : a column is created and values are the unique label.
@@ -220,11 +178,18 @@ def curation_label_to_dataframe(curation_dict):
     """
     import pandas as pd

-    labels = pd.DataFrame(curation_label_to_vectors(curation_dict), index=curation_dict["unit_ids"])
+    if isinstance(curation_dict_or_model, dict):
+        curation_model = CurationModel(**curation_dict_or_model)
+    else:
+        curation_model = curation_dict_or_model
+
+    labels = pd.DataFrame(curation_label_to_vectors(curation_model), index=curation_model.unit_ids)
     return labels


-def apply_curation_labels(sorting, new_unit_ids, curation_dict):
+def apply_curation_labels(
+    sorting: BaseSorting, new_unit_ids: list[int | str], curation_dict_or_model: dict | CurationModel
+):
     """
     Apply manual labels after merges.
@@ -233,25 +198,29 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict):
     * for merged group, when exclusive=True, if all have the same label then this label is applied
     * for merged group, when exclusive=False, if one unit has the label then the new one also has it
     """
+    if isinstance(curation_dict_or_model, dict):
+        curation_model = CurationModel(**curation_dict_or_model)
+    else:
+        curation_model = curation_dict_or_model

     # Please note that manual_labels is done on the unit_ids before the merge!!!
-    manual_labels = curation_label_to_vectors(curation_dict)
+    manual_labels = curation_label_to_vectors(curation_model)

     # apply on non merged
     for key, values in manual_labels.items():
         all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype)
         for unit_ind, unit_id in enumerate(sorting.unit_ids):
             if unit_id not in new_unit_ids:
-                ind = list(curation_dict["unit_ids"]).index(unit_id)
+                ind = list(curation_model.unit_ids).index(unit_id)
                 all_values[unit_ind] = values[ind]
         sorting.set_property(key, all_values)

-    for new_unit_id, old_group_ids in zip(new_unit_ids, curation_dict["merge_unit_groups"]):
-        for label_key, label_def in curation_dict["label_definitions"].items():
-            if label_def["exclusive"]:
+    for new_unit_id, old_group_ids in zip(new_unit_ids, curation_model.merge_unit_groups):
+        for label_key, label_def in curation_model.label_definitions.items():
+            if label_def.exclusive:
                 group_values = []
                 for unit_id in old_group_ids:
-                    ind = curation_dict["unit_ids"].index(unit_id)
+                    ind = list(curation_model.unit_ids).index(unit_id)
                     value = manual_labels[label_key][ind]
                     if value != "":
                         group_values.append(value)
@@ -260,10 +229,10 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict):
                     sorting.set_property(key, values=group_values[:1], ids=[new_unit_id])
             else:

-                for key in label_def["label_options"]:
+                for key in label_def.label_options:
                     group_values = []
                     for unit_id in old_group_ids:
-                        ind = curation_dict["unit_ids"].index(unit_id)
+                        ind = list(curation_model.unit_ids).index(unit_id)
                         value = manual_labels[key][ind]
                         group_values.append(value)
                     new_value = np.any(group_values)
@@ -271,13 +240,13 @@ def apply_curation(
-    sorting_or_analyzer,
-    curation_dict,
-    censor_ms=None,
-    new_id_strategy="append",
-    merging_mode="soft",
-    sparsity_overlap=0.75,
-    verbose=False,
+    sorting_or_analyzer: BaseSorting | SortingAnalyzer,
+    curation_dict_or_model: dict | CurationModel,
+    censor_ms: float | None = None,
+    new_id_strategy: str = "append",
+    merging_mode: str = "soft",
+    sparsity_overlap: float = 0.75,
+    verbose: bool = False,
     **job_kwargs,
 ):
     """
@@ -294,9 +263,9 @@
     Parameters
     ----------
     sorting_or_analyzer : Sorting | SortingAnalyzer
-        The Sorting object to apply merges.
-    curation_dict : dict
-        The curation dict.
+        The Sorting or SortingAnalyzer object to apply the curation to.
+    curation_dict_or_model : dict | CurationModel
+        The curation dict or model.
     censor_ms : float | None, default: None
         When applying the merges, any consecutive spikes within the `censor_ms` are removed. This can be thought
         of as the desired refractory period. If `censor_ms=None`, no spikes are discarded.
@@ -324,30 +293,34 @@ def apply_curation( """ - validate_curation_dict(curation_dict) - if not np.array_equal(np.asarray(curation_dict["unit_ids"]), sorting_or_analyzer.unit_ids): + if isinstance(curation_dict_or_model, dict): + curation_model = CurationModel(**curation_dict_or_model) + else: + curation_model = curation_dict_or_model + + if not np.array_equal(np.asarray(curation_model.unit_ids), sorting_or_analyzer.unit_ids): raise ValueError("unit_ids from the curation_dict do not match the one from Sorting or SortingAnalyzer") if isinstance(sorting_or_analyzer, BaseSorting): sorting = sorting_or_analyzer - sorting = sorting.remove_units(curation_dict["removed_units"]) + sorting = sorting.remove_units(curation_model.removed_units) sorting, _, new_unit_ids = apply_merges_to_sorting( sorting, - curation_dict["merge_unit_groups"], + curation_model.merge_unit_groups, censor_ms=censor_ms, return_extra=True, new_id_strategy=new_id_strategy, ) - apply_curation_labels(sorting, new_unit_ids, curation_dict) + apply_curation_labels(sorting, new_unit_ids, curation_model) return sorting elif isinstance(sorting_or_analyzer, SortingAnalyzer): analyzer = sorting_or_analyzer - if len(curation_dict["removed_units"]) > 0: - analyzer = analyzer.remove_units(curation_dict["removed_units"]) - if len(curation_dict["merge_unit_groups"]) > 0: + if len(curation_model.removed_units) > 0: + analyzer = analyzer.remove_units(curation_model.removed_units) + if len(curation_model.merge_unit_groups) > 0: analyzer, new_unit_ids = analyzer.merge_units( - curation_dict["merge_unit_groups"], + curation_model.merge_unit_groups, censor_ms=censor_ms, merging_mode=merging_mode, sparsity_overlap=sparsity_overlap, @@ -359,7 +332,7 @@ def apply_curation( ) else: new_unit_ids = [] - apply_curation_labels(analyzer.sorting, new_unit_ids, curation_dict) + apply_curation_labels(analyzer.sorting, new_unit_ids, curation_model) return analyzer else: raise TypeError( diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py new file mode 100644 index 0000000000..5e7459df7c --- /dev/null +++ b/src/spikeinterface/curation/curation_model.py @@ -0,0 +1,144 @@ +from pydantic import BaseModel, Field, field_validator, model_validator +from typing import List, Dict, Union, Optional +from itertools import combinations + +supported_curation_format_versions = {"1"} + + +class LabelDefinition(BaseModel): + name: str = Field(..., description="Name of the label") + label_options: List[str] = Field(..., description="List of possible label options") + exclusive: bool = Field(..., description="Whether the label is exclusive") + + +class ManualLabel(BaseModel): + unit_id: Union[int, str] = Field(..., description="ID of the unit") + labels: Dict[str, List[str]] = Field(..., description="Dictionary of labels for the unit") + + +class CurationModel(BaseModel): + format_version: str = Field(..., description="Version of the curation format") + unit_ids: List[Union[int, str]] = Field(..., description="List of unit IDs") + label_definitions: Dict[str, LabelDefinition] = Field(..., description="Dictionary of label definitions") + manual_labels: List[ManualLabel] = Field(..., description="List of manual labels") + merge_unit_groups: List[List[Union[int, str]]] = Field(..., description="List of groups of units to be merged") + removed_units: List[Union[int, str]] = Field(..., description="List of removed unit IDs") + merge_new_unit_ids: Optional[List[Union[int, str]]] = Field( + default=None, description="List of new 
unit IDs after merging" + ) + + @field_validator("format_version") + def check_format_version(cls, v): + if v not in supported_curation_format_versions: + raise ValueError(f"Format version ({v}) not supported. Only {supported_curation_format_versions} are valid") + return v + + @field_validator("label_definitions", mode="before") + def add_label_definition_name(cls, v): + if v is None: + v = {} + else: + v_copy = v.copy() + for key in v_copy.keys(): + v[key]["name"] = key + return v + + @model_validator(mode="before") + def check_manual_labels(cls, values): + unit_ids = values["unit_ids"] + manual_labels = values["manual_labels"] + if manual_labels is None: + values["manual_labels"] = [] + else: + for manual_label in manual_labels: + unit_id = manual_label["unit_id"] + labels = manual_label.get("labels") + if labels is None: + labels = set(manual_label.keys()) - {"unit_id"} + manual_label["labels"] = {} + for label in labels: + if label not in values["label_definitions"]: + raise ValueError(f"Manual label {unit_id} has an unknown label {label}") + manual_label["labels"][label] = manual_label[label] + if unit_id not in unit_ids: + raise ValueError(f"Manual label unit_id {unit_id} is not in the unit list") + return values + + @model_validator(mode="before") + def check_merge_unit_groups(cls, values): + unit_ids = values["unit_ids"] + merge_unit_groups = values.get("merge_unit_groups", []) + for merge_group in merge_unit_groups: + for unit_id in merge_group: + if unit_id not in unit_ids: + raise ValueError(f"Merge unit group unit_id {unit_id} is not in the unit list") + if len(merge_group) < 2: + raise ValueError("Merge unit groups must have at least 2 elements") + return values + + @model_validator(mode="before") + def check_merge_new_unit_ids(cls, values): + unit_ids = values["unit_ids"] + merge_new_unit_ids = values.get("merge_new_unit_ids") + if merge_new_unit_ids is not None: + merge_unit_groups = values.get("merge_unit_groups") + assert merge_unit_groups is not None, "Merge unit groups must be defined if merge new unit ids are defined" + if len(merge_unit_groups) != len(merge_new_unit_ids): + raise ValueError("Merge unit groups and new unit ids must have the same length") + if len(merge_new_unit_ids) > 0: + for new_unit_id in merge_new_unit_ids: + if new_unit_id in unit_ids: + raise ValueError(f"New unit ID {new_unit_id} is already in the unit list") + return values + + @model_validator(mode="before") + def check_removed_units(cls, values): + unit_ids = values["unit_ids"] + removed_units = values.get("removed_units", []) + for unit_id in removed_units: + if unit_id not in unit_ids: + raise ValueError(f"Removed unit_id {unit_id} is not in the unit list") + return values + + @model_validator(mode="after") + def validate_curation_dict(cls, values): + labeled_unit_set = set([lbl.unit_id for lbl in values.manual_labels]) + merged_units_set = set(sum(values.merge_unit_groups, [])) + removed_units_set = set(values.removed_units) + unit_ids = values.unit_ids + + unit_set = set(unit_ids) + if not labeled_unit_set.issubset(unit_set): + raise ValueError("Curation format: some labeled units are not in the unit list") + if not merged_units_set.issubset(unit_set): + raise ValueError("Curation format: some merged units are not in the unit list") + if not removed_units_set.issubset(unit_set): + raise ValueError("Curation format: some removed units are not in the unit list") + + for group in values.merge_unit_groups: + if len(group) < 2: + raise ValueError("Curation format: 'merge_unit_groups' must be 
list of list with at least 2 elements") + + all_merging_groups = [set(group) for group in values.merge_unit_groups] + for gp_1, gp_2 in combinations(all_merging_groups, 2): + if len(gp_1.intersection(gp_2)) != 0: + raise ValueError("Curation format: some units belong to multiple merge groups") + if len(removed_units_set.intersection(merged_units_set)) != 0: + raise ValueError("Curation format: some units were merged and deleted") + + for manual_label in values.manual_labels: + for label_key in values.label_definitions.keys(): + if label_key in manual_label.labels: + unit_id = manual_label.unit_id + label_value = manual_label.labels[label_key] + if not isinstance(label_value, list): + raise ValueError(f"Curation format: manual_labels {unit_id} is invalid should be a list") + + is_exclusive = values.label_definitions[label_key].exclusive + + if is_exclusive and not len(label_value) <= 1: + raise ValueError( + f"Curation format: manual_labels {unit_id} {label_key} are exclusive labels. {label_value} is invalid" + ) + + return values diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index af9d8e1eac..0d9562f404 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -1,5 +1,8 @@ import pytest +from pydantic import BaseModel, ValidationError, field_validator + + from pathlib import Path import json import numpy as np @@ -183,10 +186,11 @@ def test_apply_curation(): if __name__ == "__main__": - # test_curation_format_validation() - # test_to_from_json() - # test_convert_from_sortingview_curation_format_v0() - # test_curation_label_to_vectors() - # test_curation_label_to_dataframe() + test_curation_format_validation() + test_curation_format_validation() + test_to_from_json() + test_convert_from_sortingview_curation_format_v0() + test_curation_label_to_vectors() + test_curation_label_to_dataframe() test_apply_curation() From 1ce611c0c33beeb97800d90f9ab70d7ec8516a01 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 11 Mar 2025 13:25:33 +0100 Subject: [PATCH 02/22] Update src/spikeinterface/curation/curation_model.py --- src/spikeinterface/curation/curation_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index 5e7459df7c..0e2f1da870 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -38,8 +38,7 @@ def add_label_definition_name(cls, v): if v is None: v = {} else: - v_copy = v.copy() - for key in v_copy.keys(): + for key in list(v.keys()): v[key]["name"] = key return v From 3464987d1d987db126377a00dbde8c48c66268c2 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 11 Mar 2025 14:45:05 +0100 Subject: [PATCH 03/22] Move pydantic to core --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 89e69d90f9..a142391310 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "neo>=0.14.0", "probeinterface>=0.2.23", "packaging", + "pydantic", ] [build-system] @@ -97,7 +98,6 @@ full = [ "pandas", "scipy", "scikit-learn", - "pydantic", "networkx", "distinctipy", "matplotlib>=3.6", # matplotlib.colormaps From 677f90c18bee7b0c0ca86b3f9f2eb3c5a3ab689b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 25 Mar 2025 11:21:26 -0400 Subject: [PATCH 04/22] wip --- 
src/spikeinterface/curation/curation_model.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index 0e2f1da870..f7c7415543 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -5,6 +5,14 @@ supported_curation_format_versions = {"1"} +# TODO: splitting +# - add split_units to curation model +# - add split_mode to curation model +# - add split to apply_curation +# - add split_units to SortingAnalyzer +# - add _split_units to extensions + + class LabelDefinition(BaseModel): name: str = Field(..., description="Name of the label") label_options: List[str] = Field(..., description="List of possible label options") @@ -21,11 +29,14 @@ class CurationModel(BaseModel): unit_ids: List[Union[int, str]] = Field(..., description="List of unit IDs") label_definitions: Dict[str, LabelDefinition] = Field(..., description="Dictionary of label definitions") manual_labels: List[ManualLabel] = Field(..., description="List of manual labels") - merge_unit_groups: List[List[Union[int, str]]] = Field(..., description="List of groups of units to be merged") removed_units: List[Union[int, str]] = Field(..., description="List of removed unit IDs") + merge_unit_groups: List[List[Union[int, str]]] = Field(..., description="List of groups of units to be merged") merge_new_unit_ids: Optional[List[Union[int, str]]] = Field( default=None, description="List of new unit IDs after merging" ) + split_units: Optional[Dict[Union[int, str], List[List[int]]]] = Field( + default=None, description="Dictionary of units to be split" + ) @field_validator("format_version") def check_format_version(cls, v): From 86a3ab47b7fbffd67bb5f53d46074edaa0dce69a Mon Sep 17 00:00:00 2001 From: "jain.anoushka24" Date: Wed, 26 Mar 2025 13:57:39 +0100 Subject: [PATCH 05/22] Enhance CurationModel: Add split_units validation --- src/spikeinterface/curation/curation_model.py | 71 +++++++++++++++---- .../curation/tests/test_curation_model.py | 29 ++++++++ 2 files changed, 87 insertions(+), 13 deletions(-) create mode 100644 src/spikeinterface/curation/tests/test_curation_model.py diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index f7c7415543..b8802bb9a1 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -1,12 +1,11 @@ from pydantic import BaseModel, Field, field_validator, model_validator from typing import List, Dict, Union, Optional -from itertools import combinations - +from itertools import combinations, chain supported_curation_format_versions = {"1"} # TODO: splitting -# - add split_units to curation model +# - add split_units to curation model done # - add split_mode to curation model # - add split to apply_curation # - add split_units to SortingAnalyzer @@ -27,17 +26,20 @@ class ManualLabel(BaseModel): class CurationModel(BaseModel): format_version: str = Field(..., description="Version of the curation format") unit_ids: List[Union[int, str]] = Field(..., description="List of unit IDs") - label_definitions: Dict[str, LabelDefinition] = Field(..., description="Dictionary of label definitions") - manual_labels: List[ManualLabel] = Field(..., description="List of manual labels") - removed_units: List[Union[int, str]] = Field(..., description="List of removed unit IDs") - merge_unit_groups: List[List[Union[int, str]]] = Field(..., 
description="List of groups of units to be merged") + label_definitions: Optional[Dict[str, LabelDefinition]] = Field(default = None, description="Dictionary of label definitions") + manual_labels: Optional[List[ManualLabel]]= Field(default = None, description="List of manual labels") + removed_units: Optional[List[Union[int, str]]] = Field(default = None, description="List of removed unit IDs") + merge_unit_groups: Optional[List[List[Union[int, str]]]] = Field(default = None, description="List of groups of units to be merged") merge_new_unit_ids: Optional[List[Union[int, str]]] = Field( default=None, description="List of new unit IDs after merging" ) - split_units: Optional[Dict[Union[int, str], List[List[int]]]] = Field( + split_units: Optional[Dict[Union[int, str], Union[List[List[int]], List[int]]]] = Field( default=None, description="Dictionary of units to be split" ) + + + @field_validator("format_version") def check_format_version(cls, v): if v not in supported_curation_format_versions: @@ -56,7 +58,7 @@ def add_label_definition_name(cls, v): @model_validator(mode="before") def check_manual_labels(cls, values): unit_ids = values["unit_ids"] - manual_labels = values["manual_labels"] + manual_labels = values.get("manual_labels") if manual_labels is None: values["manual_labels"] = [] else: @@ -84,6 +86,8 @@ def check_merge_unit_groups(cls, values): raise ValueError(f"Merge unit group unit_id {unit_id} is not in the unit list") if len(merge_group) < 2: raise ValueError("Merge unit groups must have at least 2 elements") + else: + values["merge_unit_groups"] = merge_unit_groups return values @model_validator(mode="before") @@ -101,19 +105,60 @@ def check_merge_new_unit_ids(cls, values): raise ValueError(f"New unit ID {new_unit_id} is already in the unit list") return values + + + @model_validator(mode="before") + def check_split_units(cls, values): + # we want to get split_units as a dictionary + # if method 1 Union[List[List[int] is used we want to check there no duplicates in any list of split_units: contacenate the list number of unique elements should be equal to the length of the list + # if method 2 Union[List[int] is used we want to check list dont have duplicate + # both these methods are possible + + split_units = values.get("split_units", {}) + unit_ids = values["unit_ids"] + for unit_id, split in split_units.items(): + if unit_id not in unit_ids: + raise ValueError(f"Split unit_id {unit_id} is not in the unit list") + if len(split) == 0: + raise ValueError(f"Split unit_id {unit_id} has no split") + if not isinstance(split[0], list): # uses method 1 + split = [split] + if len(split) > 1: # uses method 2 + # concatenate the list and check if the number of unique elements is equal to the length of the list + flatten = list(chain.from_iterable(split)) + if len(flatten) != len(set(flatten)): + raise ValueError(f"Split unit_id {unit_id} has duplicate units in the split") + # if len(set(sum(split))) != len(sum(split)): + # raise ValueError(f"Split unit_id {unit_id} has duplicate units in the split") + elif len(split) == 1: # uses method 1 + # check the list dont have duplicates + if len(split[0]) != len(set(split[0])): + raise ValueError(f"Split unit_id {unit_id} has duplicate units in the split") + return values + + @model_validator(mode="before") def check_removed_units(cls, values): unit_ids = values["unit_ids"] removed_units = values.get("removed_units", []) - for unit_id in removed_units: - if unit_id not in unit_ids: - raise ValueError(f"Removed unit_id {unit_id} is not in the 
unit list") + if removed_units is None: + + for unit_id in removed_units: + if unit_id not in unit_ids: + raise ValueError(f"Removed unit_id {unit_id} is not in the unit list") + + else: + values["removed_units"] = removed_units + return values @model_validator(mode="after") def validate_curation_dict(cls, values): labeled_unit_set = set([lbl.unit_id for lbl in values.manual_labels]) - merged_units_set = set(sum(values.merge_unit_groups, [])) + if len(values.merge_unit_groups)>0: + merged_units_set = set(sum(values.merge_unit_groups)) + else: + merged_units_set = set() removed_units_set = set(values.removed_units) unit_ids = values.unit_ids diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py new file mode 100644 index 0000000000..648189a650 --- /dev/null +++ b/src/spikeinterface/curation/tests/test_curation_model.py @@ -0,0 +1,29 @@ +import pytest + +from pydantic import BaseModel, ValidationError, field_validator + + +from pathlib import Path +import json +import numpy as np + +from spikeinterface.curation.curation_model import CurationModel + +values_1 = { "format_version": "1", + "unit_ids": [1, 2, 3], + "split_units": {1: [1, 2], 2: [2, 3],3: [4,5]} +} + + +values_2 = { "format_version": "1", + "unit_ids": [1, 2, 3, 4], + "split_units": { + 1: [[1, 2], [3, 4]], + 2: [[2, 3], [4, 1]] + } +} + +curation_model1 = CurationModel(**values_1) +curation_model = CurationModel(**values_2) + + \ No newline at end of file From d4e0f84d4ea1f670a5f2cc54a718503b67fa0ca7 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 26 Mar 2025 10:45:07 -0400 Subject: [PATCH 06/22] Add splitting sorting to curation format --- src/spikeinterface/core/__init__.py | 7 +- src/spikeinterface/core/sorting_tools.py | 129 ++++++++++++++++++ .../curation/curation_format.py | 102 ++++++++------ src/spikeinterface/curation/curation_model.py | 70 +++++----- .../curation/tests/test_curation_format.py | 26 +++- .../curation/tests/test_curation_model.py | 38 +++--- 6 files changed, 279 insertions(+), 93 deletions(-) diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index fb2e173b3e..fdcfc73c27 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -109,7 +109,12 @@ get_chunk_with_margin, order_channels_by_depth, ) -from .sorting_tools import spike_vector_to_spike_trains, random_spikes_selection, apply_merges_to_sorting +from .sorting_tools import ( + spike_vector_to_spike_trains, + random_spikes_selection, + apply_merges_to_sorting, + apply_splits_to_sorting, +) from .waveform_tools import extract_waveforms_to_buffers, estimate_templates, estimate_templates_with_accumulator from .snippets_tools import snippets_from_sorting diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index ca8c731040..b60b0c1b94 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -231,6 +231,7 @@ def random_spikes_selection( return random_spikes_indices +### MERGING ZONE ### def apply_merges_to_sorting( sorting: BaseSorting, merge_unit_groups: list[list[int | str]] | list[tuple[int | str]], @@ -445,3 +446,131 @@ def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ raise ValueError("wrong new_id_strategy") return new_unit_ids + + +### SPLITTING ZONE ### +def apply_splits_to_sorting(sorting, unit_splits, new_unit_ids=None, return_extra=False, new_id_strategy="append"): + 
spikes = sorting.to_spike_vector().copy()
+
+    num_spikes = sorting.count_num_spikes_per_unit()
+
+    # take care of single-list splits
+    full_unit_splits = {}
+    for unit_id, split_indices in unit_splits.items():
+        if not isinstance(split_indices[0], (list, np.ndarray)):
+            # a flat list of indices: the complement becomes the second split
+            split_2 = np.arange(num_spikes[unit_id])
+            split_2 = split_2[~np.isin(split_2, split_indices)]
+            new_split_indices = [split_indices, split_2]
+        else:
+            new_split_indices = split_indices
+        full_unit_splits[unit_id] = new_split_indices
+
+    new_unit_ids = generate_unit_ids_for_split(
+        sorting.unit_ids, full_unit_splits, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy
+    )
+    old_unit_ids = sorting.unit_ids
+    all_unit_ids = list(old_unit_ids)
+    for split_unit, split_new_units in zip(full_unit_splits, new_unit_ids):
+        all_unit_ids.remove(split_unit)
+        all_unit_ids.extend(split_new_units)
+
+    num_seg = sorting.get_num_segments()
+    assert num_seg == 1
+    seg_lims = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2))
+    segment_slices = [(seg_lims[i], seg_lims[i + 1]) for i in range(num_seg)]
+
+    # using this function avoids the mask approach and simplifies the algorithm a lot
+    spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices]
+    spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True)
+
+    # TODO deal with segments in splits
+    for unit_id in old_unit_ids:
+        if unit_id in full_unit_splits:
+            split_indices = full_unit_splits[unit_id]
+            new_split_ids = new_unit_ids[list(full_unit_splits.keys()).index(unit_id)]
+
+            for split, new_unit_id in zip(split_indices, new_split_ids):
+                new_unit_index = all_unit_ids.index(new_unit_id)
+                for segment_index in range(num_seg):
+                    spike_inds = spike_indices[segment_index][unit_id]
+                    spikes["unit_index"][spike_inds[split]] = new_unit_index
+        else:
+            new_unit_index = all_unit_ids.index(unit_id)
+            for segment_index in range(num_seg):
+                spike_inds = spike_indices[segment_index][unit_id]
+                spikes["unit_index"][spike_inds] = new_unit_index
+    sorting = NumpySorting(spikes, sorting.sampling_frequency, all_unit_ids)
+
+    if return_extra:
+        return sorting, new_unit_ids
+    else:
+        return sorting
+
+
+def generate_unit_ids_for_split(old_unit_ids, unit_splits, new_unit_ids=None, new_id_strategy="append"):
+    """
+    Function to generate new unit ids during a splitting procedure. If new_unit_ids
+    are provided, it will return these unit ids, checking that they have the same
+    length as the values of `unit_splits`.
+
+    Parameters
+    ----------
+    old_unit_ids : np.array
+        The old unit_ids.
+    unit_splits : dict
+        A dict mapping each unit id to split to its list of spike-index groups.
+    new_unit_ids : list | None, default: None
+        Optional new unit_ids for split units. If given, it needs to have the same length as `unit_splits`.
+        If None, new ids will be generated.
+    new_id_strategy : "append" | "split", default: "append"
+        The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
+
+        * "append" : new_unit_ids will be added at the end of max(sorting.unit_ids)
+        * "split" : new_unit_ids will be the split unit_id with a "-{i}" suffix.
+          Only works if unit_ids are str, otherwise switches to "append"
+
+    Returns
+    -------
+    new_unit_ids : list of lists
+        The new unit_ids associated with the splits.
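+
+    Examples
+    --------
+    A sketch with hypothetical ids: for ``old_unit_ids=[0, 1, 2]``,
+    ``unit_splits={1: [indices_a, indices_b]}`` and ``new_id_strategy="append"``,
+    the function returns ``[[3, 4]]``: unit 1 is replaced by two new units
+    appended after the current maximum id.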
+    """
+    assert new_id_strategy in ["append", "split"], "new_id_strategy should be 'append' or 'split'"
+    old_unit_ids = np.asarray(old_unit_ids)
+
+    if new_unit_ids is not None:
+        for split_unit, new_split_ids in zip(unit_splits.values(), new_unit_ids):
+            # then only doing a consistency check
+            assert len(split_unit) == len(new_split_ids), "new_unit_ids should have the same len as unit_splits.values"
+            # new_unit_ids can also be part of old_unit_ids only inside the same group:
+            assert all(
+                new_split_id not in old_unit_ids for new_split_id in new_split_ids
+            ), "new_unit_ids already exist outside the split groups"
+    else:
+        dtype = old_unit_ids.dtype
+        new_unit_ids = []
+        for unit_to_split, split_indices in unit_splits.items():
+            num_splits = len(split_indices)
+            # select new_unit_ids greater than the max id, even greater than the numerical str ids
+            if new_id_strategy == "append":
+                if np.issubdtype(dtype, np.character):
+                    # dtype str
+                    if all(p.isdigit() for p in old_unit_ids):
+                        # All str are digit : we can generate a max
+                        m = max(int(p) for p in old_unit_ids) + 1
+                        new_unit_ids.append([str(m + i) for i in range(num_splits)])
+                    else:
+                        # we cannot automatically find new names
+                        new_unit_ids.append([f"split{i}" for i in range(num_splits)])
+                else:
+                    # dtype int
+                    new_unit_ids.append(list(max(old_unit_ids) + 1 + np.arange(num_splits, dtype=dtype)))
+                old_unit_ids = np.concatenate([old_unit_ids, new_unit_ids[-1]])
+            elif new_id_strategy == "split":
+                if np.issubdtype(dtype, np.character):
+                    new_unit_ids.append([f"{unit_to_split}-{i}" for i in np.arange(len(split_indices))])
+                else:
+                    # dtype int
+                    new_unit_ids.append(list(max(old_unit_ids) + 1 + np.arange(num_splits, dtype=dtype)))
+                old_unit_ids = np.concatenate([old_unit_ids, new_unit_ids[-1]])
+
+    return new_unit_ids
diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py
index 186bb34568..540c508575 100644
--- a/src/spikeinterface/curation/curation_format.py
+++ b/src/spikeinterface/curation/curation_format.py
@@ -2,9 +2,9 @@

 import copy
 import numpy as np
+from itertools import chain

-from spikeinterface import curation
-from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting
+from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting, apply_splits_to_sorting

 from spikeinterface.curation.curation_model import CurationModel
@@ -187,16 +187,15 @@ def curation_label_to_dataframe(curation_dict_or_model: dict | CurationModel):
     return labels


-def apply_curation_labels(
-    sorting: BaseSorting, new_unit_ids: list[int | str], curation_dict_or_model: dict | CurationModel
-):
+def apply_curation_labels(sorting: BaseSorting, curation_dict_or_model: dict | CurationModel):
     """
-    Apply manual labels after merges.
+    Apply manual labels after merges/splits.

     Rules:
-    * label for non merge is applied first
+    * label for non merged units is applied first
     * for merged group, when exclusive=True, if all have the same label then this label is applied
     * for merged group, when exclusive=False, if one unit has the label then the new one also has it
+    * for split units, the original label is applied to all split units
     """
     if isinstance(curation_dict_or_model, dict):
         curation_model = CurationModel(**curation_dict_or_model)
     else:
         curation_model = curation_dict_or_model

     # Please note that manual_labels is done on the unit_ids before the merge!!!
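+    # note: the indices computed below refer to positions in curation_model.unit_ids (the ids before merge/split)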
manual_labels = curation_label_to_vectors(curation_model)

-    # apply on non merged
+    # apply on non merged / split
+    merge_new_unit_ids = curation_model.merge_new_unit_ids if curation_model.merge_new_unit_ids is not None else []
+    split_new_unit_ids = (
+        list(chain(*curation_model.split_new_unit_ids)) if curation_model.split_new_unit_ids is not None else []
+    )
+    merged_split_units = merge_new_unit_ids + split_new_unit_ids
     for key, values in manual_labels.items():
         all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype)
         for unit_ind, unit_id in enumerate(sorting.unit_ids):
-            if unit_id not in new_unit_ids:
+            if unit_id not in merged_split_units:
                 ind = list(curation_model.unit_ids).index(unit_id)
                 all_values[unit_ind] = values[ind]
         sorting.set_property(key, all_values)

-    for new_unit_id, old_group_ids in zip(new_unit_ids, curation_model.merge_unit_groups):
-        for label_key, label_def in curation_model.label_definitions.items():
-            if label_def.exclusive:
-                group_values = []
-                for unit_id in old_group_ids:
-                    ind = list(curation_model.unit_ids).index(unit_id)
-                    value = manual_labels[label_key][ind]
-                    if value != "":
-                        group_values.append(value)
-                if len(set(group_values)) == 1:
-                    # all group has the same label or empty
-                    sorting.set_property(key, values=group_values[:1], ids=[new_unit_id])
-            else:
-
-                for key in label_def.label_options:
-                    group_values = []
-                    for unit_id in old_group_ids:
-                        ind = list(curation_model.unit_ids).index(unit_id)
-                        value = manual_labels[key][ind]
-                        group_values.append(value)
-                    new_value = np.any(group_values)
-                    sorting.set_property(key, values=[new_value], ids=[new_unit_id])
+    # merges
+    if len(merge_new_unit_ids) > 0:
+        for new_unit_id, old_group_ids in zip(curation_model.merge_new_unit_ids, curation_model.merge_unit_groups):
+            for label_key, label_def in curation_model.label_definitions.items():
+                if label_def.exclusive:
+                    group_values = []
+                    for unit_id in old_group_ids:
+                        ind = list(curation_model.unit_ids).index(unit_id)
+                        value = manual_labels[label_key][ind]
+                        if value != "":
+                            group_values.append(value)
+                    if len(set(group_values)) == 1:
+                        # all group has the same label or empty
+                        sorting.set_property(label_key, values=group_values[:1], ids=[new_unit_id])
+                else:
+                    for key in label_def.label_options:
+                        group_values = []
+                        for unit_id in old_group_ids:
+                            ind = list(curation_model.unit_ids).index(unit_id)
+                            value = manual_labels[key][ind]
+                            group_values.append(value)
+                        new_value = np.any(group_values)
+                        sorting.set_property(key, values=[new_value], ids=[new_unit_id])
+
+    # splits
+    if len(split_new_unit_ids) > 0:
+        for old_unit, new_ids in zip(curation_model.split_units, curation_model.split_new_unit_ids):
+            ind = list(curation_model.unit_ids).index(old_unit)
+            # each manual label of the original unit is propagated to all units resulting from the split
+            for label_key, values in manual_labels.items():
+                sorting.set_property(label_key, values=[values[ind]] * len(new_ids), ids=new_ids)
@@ -255,7 +268,8 @@ def apply_curation(
     Steps are done in this order:
       1. Apply removal using curation_dict["removed_units"]
       2. Apply merges using curation_dict["merge_unit_groups"]
-      3. Set labels using curation_dict["manual_labels"]
+      3. Apply splits using curation_dict["split_units"]
+      4. Set labels using curation_dict["manual_labels"]

     A new Sorting or SortingAnalyzer (in memory) is returned.
     The user (an adult) has the responsability to save it somewhere (or not).
@@ -304,14 +318,25 @@ def apply_curation( if isinstance(sorting_or_analyzer, BaseSorting): sorting = sorting_or_analyzer sorting = sorting.remove_units(curation_model.removed_units) - sorting, _, new_unit_ids = apply_merges_to_sorting( - sorting, - curation_model.merge_unit_groups, - censor_ms=censor_ms, - return_extra=True, - new_id_strategy=new_id_strategy, - ) - apply_curation_labels(sorting, new_unit_ids, curation_model) + new_unit_ids = sorting.unit_ids + if len(curation_model.merge_unit_groups) > 0: + sorting, _, new_unit_ids = apply_merges_to_sorting( + sorting, + curation_model.merge_unit_groups, + censor_ms=censor_ms, + return_extra=True, + new_id_strategy=new_id_strategy, + ) + curation_model.merge_new_unit_ids = new_unit_ids + if len(curation_model.split_units) > 0: + sorting, new_unit_ids = apply_splits_to_sorting( + sorting, + curation_model.split_units, + new_id_strategy=new_id_strategy, + return_extra=True, + ) + curation_model.split_new_unit_ids = new_unit_ids + apply_curation_labels(sorting, curation_model) return sorting elif isinstance(sorting_or_analyzer, SortingAnalyzer): @@ -330,9 +355,10 @@ def apply_curation( verbose=verbose, **job_kwargs, ) + curation_model.merge_new_unit_ids = new_unit_ids else: new_unit_ids = [] - apply_curation_labels(analyzer.sorting, new_unit_ids, curation_model) + apply_curation_labels(analyzer.sorting, curation_model) return analyzer else: raise TypeError( diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index b8802bb9a1..0ada9c4777 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -1,13 +1,14 @@ from pydantic import BaseModel, Field, field_validator, model_validator from typing import List, Dict, Union, Optional from itertools import combinations, chain + supported_curation_format_versions = {"1"} # TODO: splitting -# - add split_units to curation model done -# - add split_mode to curation model -# - add split to apply_curation +# - add split_units to curation model done V +# - add split_mode to curation model X +# - add split to apply_curation V # - add split_units to SortingAnalyzer # - add _split_units to extensions @@ -26,19 +27,23 @@ class ManualLabel(BaseModel): class CurationModel(BaseModel): format_version: str = Field(..., description="Version of the curation format") unit_ids: List[Union[int, str]] = Field(..., description="List of unit IDs") - label_definitions: Optional[Dict[str, LabelDefinition]] = Field(default = None, description="Dictionary of label definitions") - manual_labels: Optional[List[ManualLabel]]= Field(default = None, description="List of manual labels") - removed_units: Optional[List[Union[int, str]]] = Field(default = None, description="List of removed unit IDs") - merge_unit_groups: Optional[List[List[Union[int, str]]]] = Field(default = None, description="List of groups of units to be merged") + label_definitions: Optional[Dict[str, LabelDefinition]] = Field( + default=None, description="Dictionary of label definitions" + ) + manual_labels: Optional[List[ManualLabel]] = Field(default=None, description="List of manual labels") + removed_units: Optional[List[Union[int, str]]] = Field(default=None, description="List of removed unit IDs") + merge_unit_groups: Optional[List[List[Union[int, str]]]] = Field( + default=None, description="List of groups of units to be merged" + ) merge_new_unit_ids: Optional[List[Union[int, str]]] = Field( - default=None, description="List of new unit IDs after 
merging" + default=None, description="List of new unit IDs for each merge group" ) split_units: Optional[Dict[Union[int, str], Union[List[List[int]], List[int]]]] = Field( - default=None, description="Dictionary of units to be split" + default=None, description="Dictionary of units to be split. TODO more description needed" + ) + split_new_unit_ids: Optional[List[Union[int, str]]] = Field( + default=None, description="List of new unit IDs for each unit split" ) - - - @field_validator("format_version") def check_format_version(cls, v): @@ -46,14 +51,16 @@ def check_format_version(cls, v): raise ValueError(f"Format version ({v}) not supported. Only {supported_curation_format_versions} are valid") return v - @field_validator("label_definitions", mode="before") - def add_label_definition_name(cls, v): - if v is None: - v = {} + @model_validator(mode="before") + def add_label_definition_name(cls, values): + label_definitions = values.get("label_definitions") + if label_definitions is None: + label_definitions = {} else: - for key in list(v.keys()): - v[key]["name"] = key - return v + for key in list(label_definitions.keys()): + label_definitions[key]["name"] = key + values["label_definitions"] = label_definitions + return values @model_validator(mode="before") def check_manual_labels(cls, values): @@ -105,15 +112,9 @@ def check_merge_new_unit_ids(cls, values): raise ValueError(f"New unit ID {new_unit_id} is already in the unit list") return values - - @model_validator(mode="before") def check_split_units(cls, values): - # we want to get split_units as a dictionary - # if method 1 Union[List[List[int] is used we want to check there no duplicates in any list of split_units: contacenate the list number of unique elements should be equal to the length of the list - # if method 2 Union[List[int] is used we want to check list dont have duplicate # both these methods are possible - split_units = values.get("split_units", {}) unit_ids = values["unit_ids"] for unit_id, split in split_units.items(): @@ -121,32 +122,30 @@ def check_split_units(cls, values): raise ValueError(f"Split unit_id {unit_id} is not in the unit list") if len(split) == 0: raise ValueError(f"Split unit_id {unit_id} has no split") - if not isinstance(split[0], list): # uses method 1 + if not isinstance(split[0], list): # uses method 1 split = [split] - if len(split) > 1: # uses method 2 + if len(split) > 1: # uses method 2 # concatenate the list and check if the number of unique elements is equal to the length of the list flatten = list(chain.from_iterable(split)) if len(flatten) != len(set(flatten)): raise ValueError(f"Split unit_id {unit_id} has duplicate units in the split") - # if len(set(sum(split))) != len(sum(split)): - # raise ValueError(f"Split unit_id {unit_id} has duplicate units in the split") - elif len(split) == 1: # uses method 1 + elif len(split) == 1: # uses method 1 # check the list dont have duplicates if len(split[0]) != len(set(split[0])): raise ValueError(f"Split unit_id {unit_id} has duplicate units in the split") + values["split_units"] = split_units return values - @model_validator(mode="before") def check_removed_units(cls, values): unit_ids = values["unit_ids"] removed_units = values.get("removed_units", []) if removed_units is None: - + for unit_id in removed_units: if unit_id not in unit_ids: raise ValueError(f"Removed unit_id {unit_id} is not in the unit list") - + else: values["removed_units"] = removed_units @@ -155,10 +154,7 @@ def check_removed_units(cls, values): @model_validator(mode="after") def 
validate_curation_dict(cls, values): labeled_unit_set = set([lbl.unit_id for lbl in values.manual_labels]) - if len(values.merge_unit_groups)>0: - merged_units_set = set(sum(values.merge_unit_groups)) - else: - merged_units_set = set() + merged_units_set = set(chain.from_iterable(values.merge_unit_groups)) removed_units_set = set(values.removed_units) unit_ids = values.unit_ids diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index 0d9562f404..45c1ad6b25 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -87,6 +87,16 @@ "removed_units": ["u31", "u42"], # Can not be in the merged_units } +curation_with_split = { + "format_version": "1", + "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], + "split_units": { + 1: np.arange(10), + 2: np.arange(10, 20), + }, +} + + # This is a failure example with duplicated merge duplicate_merge = curation_ids_int.copy() duplicate_merge["merge_unit_groups"] = [[3, 6, 10], [10, 14, 20]] @@ -173,7 +183,7 @@ def test_curation_label_to_dataframe(): def test_apply_curation(): recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205) - sorting._main_ids = np.array([1, 2, 3, 6, 10, 14, 20, 31, 42]) + sorting = sorting.rename_units(np.array([1, 2, 3, 6, 10, 14, 20, 31, 42])) analyzer = create_sorting_analyzer(sorting, recording, sparse=False) sorting_curated = apply_curation(sorting, curation_ids_int) @@ -185,6 +195,20 @@ def test_apply_curation(): assert "quality" in analyzer_curated.sorting.get_property_keys() +def test_apply_curation_with_split(): + recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205) + sorting = sorting.rename_units(np.array([1, 2, 3, 6, 10, 14, 20, 31, 42])) + analyzer = create_sorting_analyzer(sorting, recording, sparse=False) + + sorting_curated = apply_curation(sorting, curation_with_split) + assert len(sorting_curated.unit_ids) == len(sorting.unit_ids) + 2 + + assert 1 not in sorting_curated.unit_ids + assert 2 not in sorting_curated.unit_ids + assert 43 in sorting_curated.unit_ids + assert 44 in sorting_curated.unit_ids + + if __name__ == "__main__": test_curation_format_validation() test_curation_format_validation() diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py index 648189a650..9dbc6fac22 100644 --- a/src/spikeinterface/curation/tests/test_curation_model.py +++ b/src/spikeinterface/curation/tests/test_curation_model.py @@ -1,29 +1,35 @@ import pytest -from pydantic import BaseModel, ValidationError, field_validator - - +from pydantic import ValidationError from pathlib import Path -import json import numpy as np from spikeinterface.curation.curation_model import CurationModel -values_1 = { "format_version": "1", - "unit_ids": [1, 2, 3], - "split_units": {1: [1, 2], 2: [2, 3],3: [4,5]} -} +valid_split_1 = {"format_version": "1", "unit_ids": [1, 2, 3], "split_units": {1: [1, 2], 2: [2, 3], 3: [4, 5]}} -values_2 = { "format_version": "1", +valid_split_2 = { + "format_version": "1", "unit_ids": [1, 2, 3, 4], - "split_units": { - 1: [[1, 2], [3, 4]], - 2: [[2, 3], [4, 1]] - } + "split_units": {1: [[1, 2], [3, 4]], 2: [[2, 3], [4, 1]]}, } -curation_model1 = CurationModel(**values_1) -curation_model = CurationModel(**values_2) +invalid_split_1 = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "split_units": {1: [[1, 2], 
[2, 3]], 2: [2, 3], 3: [4, 5]},
+}
+
+invalid_split_2 = {"format_version": "1", "unit_ids": [1, 2, 3], "split_units": {4: [[1, 2], [2, 3]]}}
+
+
+def test_unit_split():
+    CurationModel(**valid_split_1)
+    CurationModel(**valid_split_2)
-
\ No newline at end of file
+    # should raise an error
+    with pytest.raises(ValidationError):
+        CurationModel(**invalid_split_1)
+    with pytest.raises(ValidationError):
+        CurationModel(**invalid_split_2)

From c7316bb2fb803b59bb99c67a3eb93a3d809da107 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 26 Mar 2025 11:55:13 -0400
Subject: [PATCH 07/22] (wip) Add split_units to SortingAnalyzer

---
 src/spikeinterface/core/sorting_tools.py   |  48 +++++-
 src/spikeinterface/core/sortinganalyzer.py | 162 ++++++++++++++++++---
 2 files changed, 185 insertions(+), 25 deletions(-)

diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py
index b60b0c1b94..292780bb38 100644
--- a/src/spikeinterface/core/sorting_tools.py
+++ b/src/spikeinterface/core/sorting_tools.py
@@ -468,11 +468,7 @@ def apply_splits_to_sorting(sorting, unit_splits, new_unit_ids=None, return_extr
     new_unit_ids = generate_unit_ids_for_split(
         sorting.unit_ids, full_unit_splits, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy
     )
-    old_unit_ids = sorting.unit_ids
-    all_unit_ids = list(old_unit_ids)
-    for split_unit, split_new_units in zip(full_unit_splits, new_unit_ids):
-        all_unit_ids.remove(split_unit)
-        all_unit_ids.extend(split_new_units)
+    all_unit_ids = _get_ids_after_splitting(sorting.unit_ids, full_unit_splits, new_unit_ids)
 
     num_seg = sorting.get_num_segments()
     assert num_seg == 1
@@ -574,3 +570,45 @@ def generate_unit_ids_for_split(old_unit_ids, unit_splits, new_unit_ids=None, ne
         old_unit_ids = np.concatenate([old_unit_ids, new_unit_ids[-1]])
 
     return new_unit_ids
+
+
+def _get_ids_after_splitting(old_unit_ids, split_units, new_unit_ids):
+    """
+    Function to get the list of unique unit_ids after some splits, given the new_unit_ids.
+
+    Every new unit_id will be added at the end if not already present.
+
+    Parameters
+    ----------
+    old_unit_ids : np.array
+        The old unit_ids.
+    split_units : dict
+        A dict of units to split. Each value defines how the spikes of the unit are partitioned
+        into at least two new units.
+    new_unit_ids : list | None
+        The new unit_ids for the split units. If given, it needs to have the same length as the
+        `split_units` values.
+
+    Returns
+    -------
+    all_unit_ids : np.array
+        The unit_ids that will be present after the splits
+    """
+    old_unit_ids = np.asarray(old_unit_ids)
+    dtype = old_unit_ids.dtype
+    if dtype.kind == "U":
+        # the new dtype can be longer
+        dtype = "U"
+
+    assert len(new_unit_ids) == len(split_units), "new_unit_ids should have the same len as split_units"
+    for new_unit_in_split, unit_to_split in zip(new_unit_ids, split_units.keys()):
+        assert len(new_unit_in_split) == len(
+            split_units[unit_to_split]
+        ), "new_unit_ids should have the same len as split_units values"
+
+    all_unit_ids = list(old_unit_ids.copy())
+    for split_unit, split_new_units in zip(split_units, new_unit_ids):
+        all_unit_ids.remove(split_unit)
+        all_unit_ids.extend(split_new_units)
+    return np.array(all_unit_ids, dtype=dtype)
diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index a53b4c5cb9..62a8e3b6aa 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -31,7 +31,12 @@
     is_path_remote,
     clean_zarr_folder_name,
 )
-from .sorting_tools import generate_unit_ids_for_merge_group, _get_ids_after_merging
+from .sorting_tools import (
+    generate_unit_ids_for_merge_group,
+    _get_ids_after_merging,
+    generate_unit_ids_for_split,
+    _get_ids_after_splitting,
+)
 from .job_tools import split_job_kwargs
 from .numpyextractors import NumpySorting
 from .sparsity import ChannelSparsity, estimate_sparsity
@@ -867,17 +872,19 @@ def are_units_mergeable(
         else:
             return mergeable
 
-    def _save_or_select_or_merge(
+    def _save_or_select_or_merge_or_split(
        self,
        format="binary_folder",
        folder=None,
        unit_ids=None,
        merge_unit_groups=None,
+       split_units=None,
        censor_ms=None,
        merging_mode="soft",
        sparsity_overlap=0.75,
        verbose=False,
-       new_unit_ids=None,
+       merge_new_unit_ids=None,
+       split_new_unit_ids=None,
        backend_options=None,
        **job_kwargs,
    ) -> "SortingAnalyzer":
@@ -896,6 +903,8 @@ def _save_or_select_or_merge(
         merge_unit_groups : list/tuple of lists/tuples or None, default: None
             A list of lists for every merge group. Each element needs to have at least two elements
             (two units to merge). If `merge_unit_groups` is not None, `new_unit_ids` must be given.
+        split_units : dict or None, default: None
+            A dictionary with the keys being the unit ids to split and the values being the split indices.
         censor_ms : None or float, default: None
             When merging units, any spikes violating this refractory period will be discarded.
         merging_mode : "soft" | "hard", default: "soft"
@@ -904,8 +913,10 @@ def _save_or_select_or_merge(
         sparsity_overlap : float, default 0.75
             The percentage of overlap that units should share in order to accept merges. If this criterion is not
             achieved, soft merging will not be performed.
-        new_unit_ids : list or None, default: None
+        merge_new_unit_ids : list or None, default: None
             The new unit ids for merged units. Required if `merge_unit_groups` is not None.
+        split_new_unit_ids : list or None, default: None
+            The new unit ids for split units. Required if `split_units` is not None.
         verbose : bool, default: False
             If True, output is verbose.
        backend_options : dict | None, default: None
@@ -943,8 +954,8 @@ def _save_or_select_or_merge(
             )
 
             for unit_index, unit_id in enumerate(all_unit_ids):
-                if unit_id in new_unit_ids:
-                    merge_unit_group = tuple(merge_unit_groups[new_unit_ids.index(unit_id)])
+                if unit_id in merge_new_unit_ids:
+                    merge_unit_group = tuple(merge_unit_groups[merge_new_unit_ids.index(unit_id)])
                     if not mergeable[merge_unit_group]:
                         raise Exception(
                             f"The sparsity of {merge_unit_group} do not overlap enough for a soft merge using "
                             f"a sparsity threshold of {sparsity_overlap}. You can either lower the threshold or use "
                             "a hard merge."
@@ -967,25 +978,35 @@ def _save_or_select_or_merge(
             # if the original sorting object is not available anymore (kilosort folder deleted, ....), take the copy
             sorting_provenance = self.sorting
 
-        if merge_unit_groups is None:
+        if merge_unit_groups is None and split_units is None:
             # when only some unit_ids then the sorting must be sliced
             # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!!
             sorting_provenance = sorting_provenance.select_units(unit_ids)
-        else:
+        elif merge_unit_groups is not None:
+            assert split_units is None, "split_units must be None when merge_unit_groups is given"
             from spikeinterface.core.sorting_tools import apply_merges_to_sorting
 
             sorting_provenance, keep_mask, _ = apply_merges_to_sorting(
                 sorting=sorting_provenance,
                 merge_unit_groups=merge_unit_groups,
-                new_unit_ids=new_unit_ids,
+                new_unit_ids=merge_new_unit_ids,
                 censor_ms=censor_ms,
                 return_extra=True,
             )
             if censor_ms is None:
                 # in this case having keep_mask None is faster instead of having a vector of ones
                 keep_mask = None
-            # TODO: sam/pierre would create a curation field / curation.json with the applied merges.
-            # What do you think?
+        elif split_units is not None:
+            assert merge_unit_groups is None, "merge_unit_groups must be None when split_units is not None"
+            from spikeinterface.core.sorting_tools import apply_splits_to_sorting
+
+            sorting_provenance = apply_splits_to_sorting(
+                sorting=sorting_provenance,
+                split_units=split_units,
+                new_unit_ids=split_new_unit_ids,
+            )
+            # TODO: sam/pierre would create a curation field / curation.json with the applied merges.
+            # What do you think?
backend_options = {} if backend_options is None else backend_options @@ -1034,24 +1055,31 @@ def _save_or_select_or_merge( recompute_dict = {} for extension_name, extension in sorted_extensions.items(): - if merge_unit_groups is None: + if merge_unit_groups is None and split_units is None: # copy full or select new_sorting_analyzer.extensions[extension_name] = extension.copy( new_sorting_analyzer, unit_ids=unit_ids ) - else: + elif merge_unit_groups is not None: # merge if merging_mode == "soft": new_sorting_analyzer.extensions[extension_name] = extension.merge( new_sorting_analyzer, merge_unit_groups=merge_unit_groups, - new_unit_ids=new_unit_ids, + new_unit_ids=merge_new_unit_ids, keep_mask=keep_mask, verbose=verbose, **job_kwargs, ) elif merging_mode == "hard": recompute_dict[extension_name] = extension.params + else: + # split + # TODO + print("Splitting extension needs to be implemented") + # new_sorting_analyzer.extensions[extension_name] = extension.split( + # new_sorting_analyzer, split_units=split_units, new_unit_ids=split_new_unit_ids, verbose=verbose + # ) if merge_unit_groups is not None and merging_mode == "hard" and len(recompute_dict) > 0: new_sorting_analyzer.compute_several_extensions(recompute_dict, save=True, verbose=verbose, **job_kwargs) @@ -1081,7 +1109,7 @@ def save_as(self, format="memory", folder=None, backend_options=None) -> "Sortin """ if format == "zarr": folder = clean_zarr_folder_name(folder) - return self._save_or_select_or_merge(format=format, folder=folder, backend_options=backend_options) + return self._save_or_select_or_merge_or_split(format=format, folder=folder, backend_options=backend_options) def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyzer": """ @@ -1108,7 +1136,7 @@ def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyz # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! 
if format == "zarr":
             folder = clean_zarr_folder_name(folder)
-        return self._save_or_select_or_merge(format=format, folder=folder, unit_ids=unit_ids)
+        return self._save_or_select_or_merge_or_split(format=format, folder=folder, unit_ids=unit_ids)
 
     def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "SortingAnalyzer":
         """
@@ -1136,7 +1164,7 @@ def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "Sortin
         unit_ids = self.unit_ids[~np.isin(self.unit_ids, remove_unit_ids)]
         if format == "zarr":
             folder = clean_zarr_folder_name(folder)
-        return self._save_or_select_or_merge(format=format, folder=folder, unit_ids=unit_ids)
+        return self._save_or_select_or_merge_or_split(format=format, folder=folder, unit_ids=unit_ids)
 
     def merge_units(
         self,
@@ -1222,7 +1250,7 @@ def merge_units(
         )
 
         all_unit_ids = _get_ids_after_merging(self.unit_ids, merge_unit_groups, new_unit_ids=new_unit_ids)
 
-        new_analyzer = self._save_or_select_or_merge(
+        new_analyzer = self._save_or_select_or_merge_or_split(
             format=format,
             folder=folder,
             merge_unit_groups=merge_unit_groups,
@@ -1231,7 +1259,80 @@ def merge_units(
             merging_mode=merging_mode,
             sparsity_overlap=sparsity_overlap,
             verbose=verbose,
-            new_unit_ids=new_unit_ids,
+            merge_new_unit_ids=new_unit_ids,
             **job_kwargs,
         )
         if return_new_unit_ids:
             return new_analyzer, new_unit_ids
         else:
             return new_analyzer
+
+    def split_units(
+        self,
+        split_units: dict[str | int, list[int] | list[list[int]]],
+        new_unit_ids: list[list[int | str]] | None = None,
+        new_id_strategy: str = "append",
+        return_new_unit_ids: bool = False,
+        format: str = "memory",
+        folder: Path | str | None = None,
+        verbose: bool = False,
+        **job_kwargs,
+    ) -> "SortingAnalyzer | tuple[SortingAnalyzer, list[int | str]]":
+        """
+        This method is equivalent to `save_as()` but with a dictionary of splits that have to be applied.
+        Split units by creating a new SortingAnalyzer object with the appropriate splits.
+
+        Extensions are also updated to display the split `unit_ids`.
+
+        Parameters
+        ----------
+        split_units : dict
+            A dictionary with the keys being the unit ids to split and the values being the split indices.
+        new_unit_ids : None | list, default: None
+            The new unit_ids for the split units. If given, it needs one list of new ids per split unit,
+            each with the same length as the corresponding number of splits. If None, new unit_ids are
+            generated according to `new_id_strategy`.
+        new_id_strategy : "append" | "split", default: "append"
+            The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
+
+            * "append" : new unit_ids are appended after max(sorting.unit_ids)
+            * "split" : new unit_ids are the original unit_id to split with a -{split_index} suffix
+        return_new_unit_ids : bool, default False
+            Also return new_unit_ids, which are the ids of the new units.
+        folder : Path | None, default: None
+            The new folder where the analyzer with split units is copied for `format` "binary_folder" or "zarr"
+        format : "memory" | "binary_folder" | "zarr", default: "memory"
+            The format of SortingAnalyzer
+        verbose : bool, default: False
+            Whether to display calculations (such as sparsity estimation)
+
+        Returns
+        -------
+        analyzer : SortingAnalyzer
+            The newly created `SortingAnalyzer` with the split units
+        """
+
+        if format == "zarr":
+            folder = clean_zarr_folder_name(folder)
+
+        if len(split_units) == 0:
+            # TODO I think we should raise an error or at least make a copy and not return itself
+            if return_new_unit_ids:
+                return self, []
+            else:
+                return self
+
+        # TODO: add some checks
+
+        new_unit_ids = generate_unit_ids_for_split(self.unit_ids, split_units, new_unit_ids, new_id_strategy)
+        all_unit_ids = _get_ids_after_splitting(self.unit_ids, split_units, new_unit_ids=new_unit_ids)
+
+        new_analyzer = self._save_or_select_or_merge_or_split(
+            format=format,
+            folder=folder,
+            split_units=split_units,
+            unit_ids=all_unit_ids,
+            verbose=verbose,
+            split_new_unit_ids=new_unit_ids,
+            **job_kwargs,
+        )
+        if return_new_unit_ids:
+            return new_analyzer, new_unit_ids
+        else:
+            return new_analyzer
+
     def copy(self):
         """
         Create a copy of SortingAnalyzer with format "memory".
         """
-        return self._save_or_select_or_merge(format="memory", folder=None)
+        return self._save_or_select_or_merge_or_split(format="memory", folder=None)
 
     def is_read_only(self) -> bool:
         if self.format == "memory":
@@ -2048,6 +2149,10 @@ def _merge_extension_data(
         # must be implemented in subclass
         raise NotImplementedError
 
+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        # must be implemented in subclass
+        raise NotImplementedError
+
     def _get_pipeline_nodes(self):
         # must be implemented in subclass only if use_nodepipeline=True
         raise NotImplementedError
@@ -2283,6 +2388,23 @@ def merge(
         new_extension.save()
         return new_extension
 
+    def split(
+        self,
+        new_sorting_analyzer,
+        split_units,
+        new_unit_ids,
+        verbose=False,
+        **job_kwargs,
+    ):
+        new_extension = self.__class__(new_sorting_analyzer)
+        new_extension.params = self.params.copy()
+        new_extension.data = self._split_extension_data(
+            split_units, new_unit_ids, new_sorting_analyzer, verbose=verbose, **job_kwargs
+        )
+        new_extension.run_info = copy(self.run_info)
+        new_extension.save()
+        return new_extension
+
     def run(self, save=True, **kwargs):
         if save and not self.sorting_analyzer.is_read_only():
             # NB: this call to _save_params() also resets the folder or zarr group

From 2e21923412319a1c496a463ef17f65d6214f3b91 Mon Sep 17 00:00:00 2001
From: Alessio Buccino
Date: Wed, 26 Mar 2025 15:32:48 -0400
Subject: [PATCH 08/22] Add split_units to sorting analyzer

---
 .../core/analyzer_extension_core.py           |  52 ++++++++-
 src/spikeinterface/core/sorting_tools.py      |  30 +++---
 src/spikeinterface/core/sortinganalyzer.py    | 100 ++++++++++++------
 .../postprocessing/amplitude_scalings.py      |   3 +
 .../postprocessing/correlograms.py            |   9 +-
 src/spikeinterface/postprocessing/isi.py      |  24 +++++
 .../postprocessing/principal_component.py     |   4 +
 .../postprocessing/spike_amplitudes.py        |   4 +
 .../postprocessing/spike_locations.py         |   4 +
 .../postprocessing/template_metrics.py        |  21 ++++
 .../postprocessing/template_similarity.py     |  53 ++++++++++
 .../postprocessing/unit_locations.py          |  26 ++++-
 .../quality_metric_calculator.py              |  28 +++++
 13 files changed, 305 insertions(+), 53 deletions(-)

diff --git a/src/spikeinterface/core/analyzer_extension_core.py 
b/src/spikeinterface/core/analyzer_extension_core.py
index 447bbe562e..4834e864d5 100644
--- a/src/spikeinterface/core/analyzer_extension_core.py
+++ b/src/spikeinterface/core/analyzer_extension_core.py
@@ -94,6 +94,11 @@ def _merge_extension_data(
         new_data["random_spikes_indices"] = np.flatnonzero(selected_mask[keep_mask])
         return new_data
 
+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        new_data = dict()
+        new_data["random_spikes_indices"] = self.data["random_spikes_indices"].copy()
+        return new_data
+
     def _get_data(self):
         return self.data["random_spikes_indices"]
 
@@ -245,8 +250,6 @@ def _select_extension_data(self, unit_ids):
     def _merge_extension_data(
         self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs
     ):
-        new_data = dict()
-
         waveforms = self.data["waveforms"]
         some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes()
         if keep_mask is not None:
@@ -277,6 +280,11 @@ def _merge_extension_data(
 
         return dict(waveforms=waveforms)
 
+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        # splitting only affects random spikes, not waveforms
+        new_data = dict(waveforms=self.data["waveforms"].copy())
+        return new_data
+
     def get_waveforms_one_unit(self, unit_id, force_dense: bool = False):
         """
         Returns the waveforms of a unit id.
@@ -556,6 +564,42 @@ def _merge_extension_data(
 
         return new_data
 
+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        new_data = dict()
+        for operator, arr in self.data.items():
+            # we first copy the unsplit units
+            new_array = np.zeros((len(new_sorting_analyzer.unit_ids), arr.shape[1], arr.shape[2]), dtype=arr.dtype)
+            new_analyzer_unit_ids = list(new_sorting_analyzer.unit_ids)
+            unsplit_unit_ids = [unit_id for unit_id in self.sorting_analyzer.unit_ids if unit_id not in split_units]
+            new_indices = np.array([new_analyzer_unit_ids.index(unit_id) for unit_id in unsplit_unit_ids])
+            old_indices = self.sorting_analyzer.sorting.ids_to_indices(unsplit_unit_ids)
+            new_array[new_indices, ...] = arr[old_indices, ...]
+
+            for split_unit_id, new_splits in zip(split_units, new_unit_ids):
+                if new_sorting_analyzer.has_extension("waveforms"):
+                    for new_unit_id in new_splits:
+                        split_unit_index = new_sorting_analyzer.sorting.id_to_index(new_unit_id)
+                        wfs = new_sorting_analyzer.get_extension("waveforms").get_waveforms_one_unit(
+                            new_unit_id, force_dense=True
+                        )
+
+                        if operator == "average":
+                            new_template = np.average(wfs, axis=0)
+                        elif operator == "std":
+                            new_template = np.std(wfs, axis=0)
+                        elif operator == "median":
+                            new_template = np.median(wfs, axis=0)
+                        elif "percentile" in operator:
+                            _, percentile = operator.split("_")
+                            new_template = np.percentile(wfs, float(percentile), axis=0)
+                        new_array[split_unit_index, ...] = new_template
+                else:
+                    old_template = arr[self.sorting_analyzer.sorting.ids_to_indices([split_unit_id])[0], ...]
+                    new_indices = np.array([new_analyzer_unit_ids.index(unit_id) for unit_id in new_splits])
+                    new_array[new_indices, ...] 
= np.tile(old_template, (len(new_splits), 1, 1))
+            new_data[operator] = new_array
+        return new_data
+
     def _get_data(self, operator="average", percentile=None, outputs="numpy"):
         if operator != "percentile":
             key = operator
@@ -729,6 +773,10 @@ def _merge_extension_data(
         # this does not depend on units
         return self.data.copy()
 
+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        # this does not depend on units
+        return self.data.copy()
+
     def _run(self, verbose=False):
         self.data["noise_levels"] = get_noise_levels(
             self.sorting_analyzer.recording, return_scaled=self.sorting_analyzer.return_scaled, **self.params
diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py
index 292780bb38..d6ba6a7e6b 100644
--- a/src/spikeinterface/core/sorting_tools.py
+++ b/src/spikeinterface/core/sorting_tools.py
@@ -452,23 +452,14 @@ def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_
 
 def apply_splits_to_sorting(sorting, unit_splits, new_unit_ids=None, return_extra=False, new_id_strategy="append"):
     spikes = sorting.to_spike_vector().copy()
-    num_spikes = sorting.count_num_spikes_per_unit()
-    # take care of single-list splits
-    full_unit_splits = {}
-    for unit_id, split_indices in unit_splits.items():
-        if not isinstance(split_indices[0], (list, np.ndarray)):
-            split_2 = np.arange(num_spikes[unit_id])
-            split_2 = split_2[~np.isin(split_2, split_indices)]
-            new_split_indices = [split_indices, split_2]
-        else:
-            new_split_indices = split_indices
-        full_unit_splits[unit_id] = new_split_indices
+    full_unit_splits = _get_full_unit_splits(unit_splits, sorting)
 
     new_unit_ids = generate_unit_ids_for_split(
         sorting.unit_ids, full_unit_splits, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy
     )
     all_unit_ids = _get_ids_after_splitting(sorting.unit_ids, full_unit_splits, new_unit_ids)
+    all_unit_ids = list(all_unit_ids)
 
     num_seg = sorting.get_num_segments()
     assert num_seg == 1
@@ -480,7 +471,7 @@ def apply_splits_to_sorting(sorting, unit_splits, new_unit_ids=None, return_extr
     spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True)
 
     # TODO deal with segments in splits
-    for unit_id in old_unit_ids:
+    for unit_id in sorting.unit_ids:
         if unit_id in full_unit_splits:
             split_indices = full_unit_splits[unit_id]
             new_split_ids = new_unit_ids[list(full_unit_splits.keys()).index(unit_id)]
@@ -572,6 +563,21 @@ def generate_unit_ids_for_split(old_unit_ids, unit_splits, new_unit_ids=None, ne
     return new_unit_ids
 
 
+def _get_full_unit_splits(unit_splits, sorting):
+    # take care of single-list splits
+    full_unit_splits = {}
+    num_spikes = sorting.count_num_spikes_per_unit()
+    for unit_id, split_indices in unit_splits.items():
+        if not isinstance(split_indices[0], (list, np.ndarray)):
+            split_2 = np.arange(num_spikes[unit_id])
+            split_2 = split_2[~np.isin(split_2, split_indices)]
+            new_split_indices = [split_indices, split_2]
+        else:
+            new_split_indices = split_indices
+        full_unit_splits[unit_id] = new_split_indices
+    return full_unit_splits
+
+
 def _get_ids_after_splitting(old_unit_ids, split_units, new_unit_ids):
     """
     Function to get the list of unique unit_ids after some splits, given the new_unit_ids.
diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py
index 62a8e3b6aa..d7b96b32b3 100644
--- a/src/spikeinterface/core/sortinganalyzer.py
+++ b/src/spikeinterface/core/sortinganalyzer.py
@@ -36,6 +36,7 @@ 
_get_ids_after_merging, generate_unit_ids_for_split, _get_ids_after_splitting, + _get_full_unit_splits, ) from .job_tools import split_job_kwargs from .numpyextractors import NumpySorting @@ -939,36 +940,63 @@ def _save_or_select_or_merge_or_split( else: recording = None - if self.sparsity is not None and unit_ids is None and merge_unit_groups is None: - sparsity = self.sparsity - elif self.sparsity is not None and unit_ids is not None and merge_unit_groups is None: - sparsity_mask = self.sparsity.mask[np.isin(self.unit_ids, unit_ids), :] - sparsity = ChannelSparsity(sparsity_mask, unit_ids, self.channel_ids) - elif self.sparsity is not None and merge_unit_groups is not None: - all_unit_ids = unit_ids - sparsity_mask = np.zeros((len(all_unit_ids), self.sparsity.mask.shape[1]), dtype=bool) - mergeable, masks = self.are_units_mergeable( - merge_unit_groups, - sparsity_overlap=sparsity_overlap, - return_masks=True, - ) + has_removed = unit_ids is not None + has_merges = merge_unit_groups is not None + has_splits = split_units is not None + assert not has_merges if has_splits else True, "Cannot merge and split at the same time" + + if self.sparsity is not None: + if not has_removed and not has_merges and not has_splits: + # no changes in units + sparsity = self.sparsity + elif has_removed and not has_merges and not has_splits: + # remove units + sparsity_mask = self.sparsity.mask[np.isin(self.unit_ids, unit_ids), :] + sparsity = ChannelSparsity(sparsity_mask, unit_ids, self.channel_ids) + elif has_merges: + # merge units + all_unit_ids = unit_ids + sparsity_mask = np.zeros((len(all_unit_ids), self.sparsity.mask.shape[1]), dtype=bool) + mergeable, masks = self.are_units_mergeable( + merge_unit_groups, + sparsity_overlap=sparsity_overlap, + return_masks=True, + ) - for unit_index, unit_id in enumerate(all_unit_ids): - if unit_id in merge_new_unit_ids: - merge_unit_group = tuple(merge_unit_groups[merge_new_unit_ids.index(unit_id)]) - if not mergeable[merge_unit_group]: - raise Exception( - f"The sparsity of {merge_unit_group} do not overlap enough for a soft merge using " - f"a sparsity threshold of {sparsity_overlap}. You can either lower the threshold or use " - "a hard merge." - ) + for unit_index, unit_id in enumerate(all_unit_ids): + if unit_id in merge_new_unit_ids: + merge_unit_group = tuple(merge_unit_groups[merge_new_unit_ids.index(unit_id)]) + if not mergeable[merge_unit_group]: + raise Exception( + f"The sparsity of {merge_unit_group} do not overlap enough for a soft merge using " + f"a sparsity threshold of {sparsity_overlap}. You can either lower the threshold or use " + "a hard merge." 
+ ) + else: + sparsity_mask[unit_index] = masks[merge_unit_group] else: - sparsity_mask[unit_index] = masks[merge_unit_group] - else: - # This means that the unit is already in the previous sorting - index = self.sorting.id_to_index(unit_id) - sparsity_mask[unit_index] = self.sparsity.mask[index] - sparsity = ChannelSparsity(sparsity_mask, list(all_unit_ids), self.channel_ids) + # This means that the unit is already in the previous sorting + index = self.sorting.id_to_index(unit_id) + sparsity_mask[unit_index] = self.sparsity.mask[index] + sparsity = ChannelSparsity(sparsity_mask, list(all_unit_ids), self.channel_ids) + elif has_splits: + # split units + all_unit_ids = unit_ids + original_unit_ids = self.unit_ids + sparsity_mask = np.zeros((len(all_unit_ids), self.sparsity.mask.shape[1]), dtype=bool) + for unit_index, unit_id in enumerate(all_unit_ids): + if unit_id not in original_unit_ids: + # then it is a new unit + # we assign the original sparsity + for split_unit, new_unit_ids in zip(split_units, split_new_unit_ids): + if unit_id in new_unit_ids: + original_unit_index = self.sorting.id_to_index(split_unit) + sparsity_mask[unit_index] = self.sparsity.mask[original_unit_index] + break + else: + original_unit_index = self.sorting.id_to_index(unit_id) + sparsity_mask[unit_index] = self.sparsity.mask[original_unit_index] + sparsity = ChannelSparsity(sparsity_mask, list(all_unit_ids), self.channel_ids) else: sparsity = None @@ -1002,7 +1030,7 @@ def _save_or_select_or_merge_or_split( sorting_provenance = apply_splits_to_sorting( sorting=sorting_provenance, - split_units=split_units, + unit_splits=split_units, new_unit_ids=split_new_unit_ids, ) # TODO: sam/pierre would create a curation field / curation.json with the applied merges. @@ -1075,13 +1103,14 @@ def _save_or_select_or_merge_or_split( recompute_dict[extension_name] = extension.params else: # split - # TODO - print("Splitting extension needs to be implemented") - # new_sorting_analyzer.extensions[extension_name] = extension.split( - # new_sorting_analyzer, split_units=split_units, new_unit_ids=split_new_unit_ids, verbose=verbose - # ) + try: + new_sorting_analyzer.extensions[extension_name] = extension.split( + new_sorting_analyzer, split_units=split_units, new_unit_ids=split_new_unit_ids, verbose=verbose + ) + except NotImplementedError: + recompute_dict[extension_name] = extension.params - if merge_unit_groups is not None and merging_mode == "hard" and len(recompute_dict) > 0: + if len(recompute_dict) > 0: new_sorting_analyzer.compute_several_extensions(recompute_dict, save=True, verbose=verbose, **job_kwargs) return new_sorting_analyzer @@ -1322,6 +1351,7 @@ def split_units( return self # TODO: add some checks + split_units = _get_full_unit_splits(split_units, self.sorting) new_unit_ids = generate_unit_ids_for_split(self.unit_ids, split_units, new_unit_ids, new_id_strategy) all_unit_ids = _get_ids_after_splitting(self.unit_ids, split_units, new_unit_ids=new_unit_ids) diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 278151a930..ff926a998d 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py @@ -128,6 +128,9 @@ def _merge_extension_data( return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + return self.data.copy() + def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording diff 
--git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py index 5e30d7c68b..d41beb595f 100644 --- a/src/spikeinterface/postprocessing/correlograms.py +++ b/src/spikeinterface/postprocessing/correlograms.py @@ -154,9 +154,6 @@ def _merge_extension_data( if unit_involved_in_merge is False: old_to_new_unit_index_map[old_unit_index] = new_sorting_analyzer.sorting.id_to_index(old_unit) - need_to_append = False - delete_from = 1 - correlograms, new_bins = deepcopy(self.get_data()) for new_unit_id, merge_unit_group in zip(new_unit_ids, merge_unit_groups): @@ -188,6 +185,12 @@ def _merge_extension_data( new_data = dict(ccgs=new_correlograms, bins=new_bins) return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # TODO: for now we just copy + new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params) + new_data = dict(ccgs=new_ccgs, bins=new_bins) + return new_data + def _run(self, verbose=False): ccgs, bins = _compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params) self.data["ccgs"] = ccgs diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py index 542f829f21..03bd9d71a8 100644 --- a/src/spikeinterface/postprocessing/isi.py +++ b/src/spikeinterface/postprocessing/isi.py @@ -1,6 +1,7 @@ from __future__ import annotations import numpy as np +from itertools import chain from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension @@ -80,6 +81,29 @@ def _merge_extension_data( new_extension_data = dict(isi_histograms=new_isi_hists, bins=new_bins) return new_extension_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + new_bins = self.data["bins"] + arr = self.data["isi_histograms"] + num_dims = arr.shape[1] + all_new_units = new_sorting_analyzer.unit_ids + new_isi_hists = np.zeros((len(all_new_units), num_dims), dtype=arr.dtype) + + # compute all new isi at once + new_unit_ids_f = list(chain(*new_unit_ids)) + new_sorting = new_sorting_analyzer.sorting.select_units(new_unit_ids_f) + only_new_hist, _ = _compute_isi_histograms(new_sorting, **self.params) + + for unit_ind, unit_id in enumerate(all_new_units): + if unit_id not in new_unit_ids_f: + keep_unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + new_isi_hists[unit_ind, :] = arr[keep_unit_index, :] + else: + new_unit_index = new_sorting.id_to_index(unit_id) + new_isi_hists[unit_ind, :] = only_new_hist[new_unit_index, :] + + new_extension_data = dict(isi_histograms=new_isi_hists, bins=new_bins) + return new_extension_data + def _run(self, verbose=False): isi_histograms, bins = _compute_isi_histograms(self.sorting_analyzer.sorting, **self.params) self.data["isi_histograms"] = isi_histograms diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py index dd3a8febd7..c340b7ff50 100644 --- a/src/spikeinterface/postprocessing/principal_component.py +++ b/src/spikeinterface/postprocessing/principal_component.py @@ -149,6 +149,10 @@ def _merge_extension_data( new_data[k] = v return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # splitting only changes random spikes assignments + return self.data.copy() + def get_pca_model(self): """ Returns the scikit-learn PCA model 
objects. diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 577dc948c3..4b7a4e8eae 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -92,6 +92,10 @@ def _merge_extension_data( return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # splitting only changes random spikes assignments + return self.data.copy() + def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 6995fc04da..c33b9bb8aa 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -105,6 +105,10 @@ def _merge_extension_data( ### in a merged could be different. Should be discussed return dict(spike_locations=new_spike_locations) + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # splitting only changes random spikes assignments + return self.data.copy() + def _get_pipeline_nodes(self): from spikeinterface.sortingcomponents.peak_localization import get_localization_pipeline_nodes diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index d78b1e3809..e077dab482 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -8,6 +8,7 @@ import numpy as np import warnings +from itertools import chain from copy import deepcopy from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension @@ -195,6 +196,26 @@ def _merge_extension_data( new_data = dict(metrics=metrics) return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + import pandas as pd + + metric_names = self.params["metric_names"] + old_metrics = self.data["metrics"] + + all_unit_ids = new_sorting_analyzer.unit_ids + new_unit_ids_f = list(chain(*new_unit_ids)) + not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids_f)] + + metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) + + metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] + metrics.loc[new_unit_ids_f, :] = self._compute_metrics( + new_sorting_analyzer, new_unit_ids_f, verbose, metric_names, **job_kwargs + ) + + new_data = dict(metrics=metrics) + return new_data + def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): """ Compute template metrics. 
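
The per-unit extensions in this commit (isi, template_metrics above, and template_similarity, unit_locations, quality_metrics below) all implement `_split_extension_data` with the same shape: rows of units untouched by the split are carried over from the old arrays, and only rows of the newly created units are recomputed. A minimal self-contained sketch of that pattern, outside the patch; `compute_rows` is a hypothetical stand-in for the extension-specific computation (e.g. `_compute_isi_histograms`), not a spikeinterface function:

import numpy as np

def propagate_split(old_unit_ids, all_new_unit_ids, new_unit_ids_flat, old_array, compute_rows):
    # recompute rows only for the units created by the split
    recomputed = compute_rows(new_unit_ids_flat)
    out = np.zeros((len(all_new_unit_ids),) + old_array.shape[1:], dtype=old_array.dtype)
    old_index = {u: i for i, u in enumerate(old_unit_ids)}
    for i, unit_id in enumerate(all_new_unit_ids):
        if unit_id in old_index:
            # unit untouched by the split: reuse the old row
            out[i] = old_array[old_index[unit_id]]
        else:
            # newly created unit: take the freshly computed row
            out[i] = recomputed[new_unit_ids_flat.index(unit_id)]
    return out

# toy check: units 1 and 2, with unit 2 split into new units 3 and 4
old = np.array([[1.0], [2.0]])
new = propagate_split([1, 2], [1, 3, 4], [3, 4], old, lambda ids: np.array([[30.0], [40.0]]))
assert new.tolist() == [[1.0], [30.0], [40.0]]
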
diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py
index 1928e12edc..5469c7fe5a 100644
--- a/src/spikeinterface/postprocessing/template_similarity.py
+++ b/src/spikeinterface/postprocessing/template_similarity.py
@@ -2,6 +2,7 @@
 
 import numpy as np
 import warnings
+from itertools import chain
 
 from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension
 from spikeinterface.core.template_tools import get_dense_templates_array
@@ -128,6 +129,58 @@ def _merge_extension_data(
 
         return dict(similarity=similarity)
 
+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        num_shifts = int(self.params["max_lag_ms"] * self.sorting_analyzer.sampling_frequency / 1000)
+        all_templates_array = get_dense_templates_array(
+            new_sorting_analyzer, return_scaled=self.sorting_analyzer.return_scaled
+        )
+
+        new_unit_ids_f = list(chain(*new_unit_ids))
+        keep = np.isin(new_sorting_analyzer.unit_ids, new_unit_ids_f)
+        new_templates_array = all_templates_array[keep, :, :]
+        if new_sorting_analyzer.sparsity is None:
+            new_sparsity = None
+        else:
+            new_sparsity = ChannelSparsity(
+                new_sorting_analyzer.sparsity.mask[keep, :], new_unit_ids_f, new_sorting_analyzer.channel_ids
+            )
+
+        new_similarity = compute_similarity_with_templates_array(
+            new_templates_array,
+            all_templates_array,
+            method=self.params["method"],
+            num_shifts=num_shifts,
+            support=self.params["support"],
+            sparsity=new_sparsity,
+            other_sparsity=new_sorting_analyzer.sparsity,
+        )
+
+        old_similarity = self.data["similarity"]
+
+        all_new_unit_ids = new_sorting_analyzer.unit_ids
+        n = all_new_unit_ids.size
+        similarity = np.zeros((n, n), dtype=old_similarity.dtype)
+
+        # copy old similarity for units not involved in the split
+        for unit_ind1, unit_id1 in enumerate(all_new_unit_ids):
+            if unit_id1 not in new_unit_ids_f:
+                old_ind1 = self.sorting_analyzer.sorting.id_to_index(unit_id1)
+                for unit_ind2, unit_id2 in enumerate(all_new_unit_ids):
+                    if unit_id2 not in new_unit_ids_f:
+                        old_ind2 = self.sorting_analyzer.sorting.id_to_index(unit_id2)
+                        s = old_similarity[old_ind1, old_ind2]
+                        similarity[unit_ind1, unit_ind2] = s
+                        similarity[unit_ind2, unit_ind1] = s
+
+        # insert new similarity both ways
+        for unit_ind, unit_id in enumerate(all_new_unit_ids):
+            if unit_id in new_unit_ids_f:
+                new_index = list(new_unit_ids_f).index(unit_id)
+                similarity[unit_ind, :] = new_similarity[new_index, :]
+                similarity[:, unit_ind] = new_similarity[new_index, :]
+
+        return dict(similarity=similarity)
+
     def _run(self, verbose=False):
         num_shifts = int(self.params["max_lag_ms"] * self.sorting_analyzer.sampling_frequency / 1000)
         templates_array = get_dense_templates_array(
diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py
index 5618499770..ea297f7b6c 100644
--- a/src/spikeinterface/postprocessing/unit_locations.py
+++ b/src/spikeinterface/postprocessing/unit_locations.py
@@ -1,7 +1,7 @@
 from __future__ import annotations
 
 import numpy as np
-import warnings
+from itertools import chain
 
 from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension
 from .localization_tools import _unit_location_methods
@@ -88,6 +88,30 @@ def _merge_extension_data(
 
         return dict(unit_locations=unit_location)
 
+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        old_unit_locations = 
self.data["unit_locations"] + num_dims = old_unit_locations.shape[1] + + method = self.params.get("method") + method_kwargs = self.params.copy() + method_kwargs.pop("method") + func = _unit_location_methods[method] + new_unit_ids_f = list(chain(*new_unit_ids)) + new_unit_locations = func(new_sorting_analyzer, unit_ids=new_unit_ids_f, **method_kwargs) + assert new_unit_locations.shape[0] == len(new_unit_ids_f) + + all_new_unit_ids = new_sorting_analyzer.unit_ids + unit_location = np.zeros((len(all_new_unit_ids), num_dims), dtype=old_unit_locations.dtype) + for unit_index, unit_id in enumerate(all_new_unit_ids): + if unit_id not in new_unit_ids_f: + old_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + unit_location[unit_index] = old_unit_locations[old_index] + else: + new_index = list(new_unit_ids_f).index(unit_id) + unit_location[unit_index] = new_unit_locations[new_index] + + return dict(unit_locations=unit_location) + def _run(self, verbose=False): method = self.params.get("method") method_kwargs = self.params.copy() diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 134849e70f..055fefc78c 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -3,6 +3,7 @@ from __future__ import annotations import warnings +from itertools import chain from copy import deepcopy import numpy as np @@ -158,6 +159,33 @@ def _merge_extension_data( new_data = dict(metrics=metrics) return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + import pandas as pd + + metric_names = self.params["metric_names"] + old_metrics = self.data["metrics"] + + all_unit_ids = new_sorting_analyzer.unit_ids + new_unit_ids_f = list(chain(*new_unit_ids)) + not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids_f)] + + # this creates a new metrics dictionary, but the dtype for everything will be + # object. So we will need to fix this later after computing metrics + metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) + metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] + metrics.loc[new_unit_ids_f, :] = self._compute_metrics( + new_sorting_analyzer, new_unit_ids_f, verbose, metric_names, **job_kwargs + ) + + # we need to fix the dtypes after we compute everything because we have nans + # we can iterate through the columns and convert them back to the dtype + # of the original quality dataframe. + for column in old_metrics.columns: + metrics[column] = metrics[column].astype(old_metrics[column].dtype) + + new_data = dict(metrics=metrics) + return new_data + def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): """ Compute quality metrics. 
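
With the pieces above in place, a split can be driven end to end from a SortingAnalyzer. A hedged usage sketch, mirroring the tests in this series; `generate_ground_truth_recording` and `create_sorting_analyzer` are existing spikeinterface helpers (import path assumed), and the half-spiketrain split is illustrative only:

import numpy as np
from spikeinterface.core import generate_ground_truth_recording, create_sorting_analyzer

recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=0)
analyzer = create_sorting_analyzer(sorting, recording, sparse=False)
analyzer.compute(["random_spikes", "templates"])

# single-list form: the listed spike indices become one new unit, the remaining spikes the other
unit_to_split = sorting.unit_ids[0]
num_spikes = sorting.count_num_spikes_per_unit()[unit_to_split]
splits = {unit_to_split: np.arange(num_spikes // 2)}

split_analyzer, new_unit_ids = analyzer.split_units(split_units=splits, return_new_unit_ids=True)
# the split unit is gone and two new units take its place
assert unit_to_split not in split_analyzer.unit_ids
assert len(split_analyzer.unit_ids) == len(sorting.unit_ids) + 1
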
From 4afdb805c04bbdbe2ea6cc099341532824337150 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Wed, 26 Mar 2025 15:50:26 -0400 Subject: [PATCH 09/22] Propagate SortingAnalyzer.split_units to apply_curation --- src/spikeinterface/curation/curation_format.py | 11 +++++++++-- src/spikeinterface/curation/curation_model.py | 8 -------- 2 files changed, 9 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 540c508575..bc121090f1 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -356,8 +356,15 @@ def apply_curation( **job_kwargs, ) curation_model.merge_new_unit_ids = new_unit_ids - else: - new_unit_ids = [] + if len(curation_model.split_units) > 0: + analyzer, new_unit_ids = analyzer.split_units( + curation_model.split_units, + new_id_strategy=new_id_strategy, + return_new_unit_ids=True, + format="memory", + verbose=verbose, + ) + curation_model.split_new_unit_ids = new_unit_ids apply_curation_labels(analyzer.sorting, curation_model) return analyzer else: diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index 0ada9c4777..f5cc035676 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -5,14 +5,6 @@ supported_curation_format_versions = {"1"} -# TODO: splitting -# - add split_units to curation model done V -# - add split_mode to curation model X -# - add split to apply_curation V -# - add split_units to SortingAnalyzer -# - add _split_units to extensions - - class LabelDefinition(BaseModel): name: str = Field(..., description="Name of the label") label_options: List[str] = Field(..., description="List of possible label options") From 62bfb7f955718136511d7bdcbdb0986bd4f1a490 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Mar 2025 09:22:33 -0400 Subject: [PATCH 10/22] Extend CurationModel tests --- src/spikeinterface/curation/curation_model.py | 28 ++- .../curation/tests/test_curation_model.py | 221 ++++++++++++++++-- 2 files changed, 221 insertions(+), 28 deletions(-) diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index f5cc035676..b50db1e69c 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -1,13 +1,14 @@ from pydantic import BaseModel, Field, field_validator, model_validator from typing import List, Dict, Union, Optional from itertools import combinations, chain +import numpy as np supported_curation_format_versions = {"1"} class LabelDefinition(BaseModel): name: str = Field(..., description="Name of the label") - label_options: List[str] = Field(..., description="List of possible label options") + label_options: List[str] = Field(..., description="List of possible label options", min_length=2) exclusive: bool = Field(..., description="Whether the label is exclusive") @@ -47,11 +48,16 @@ def check_format_version(cls, v): def add_label_definition_name(cls, values): label_definitions = values.get("label_definitions") if label_definitions is None: - label_definitions = {} - else: - for key in list(label_definitions.keys()): - label_definitions[key]["name"] = key - values["label_definitions"] = label_definitions + values["label_definitions"] = {} + return values + if isinstance(values["label_definitions"], dict): + if label_definitions is None: + label_definitions = {} + else: + for key in 
list(label_definitions.keys()): + if isinstance(label_definitions[key], dict): + label_definitions[key]["name"] = key + values["label_definitions"] = label_definitions return values @model_validator(mode="before") @@ -70,7 +76,11 @@ def check_manual_labels(cls, values): for label in labels: if label not in values["label_definitions"]: raise ValueError(f"Manual label {unit_id} has an unknown label {label}") - manual_label["labels"][label] = manual_label[label] + if label not in manual_label["labels"]: + if label in manual_label: + manual_label["labels"][label] = manual_label[label] + else: + raise ValueError(f"Manual label {unit_id} has no value for label {label}") if unit_id not in unit_ids: raise ValueError(f"Manual label unit_id {unit_id} is not in the unit list") return values @@ -114,7 +124,7 @@ def check_split_units(cls, values): raise ValueError(f"Split unit_id {unit_id} is not in the unit list") if len(split) == 0: raise ValueError(f"Split unit_id {unit_id} has no split") - if not isinstance(split[0], list): # uses method 1 + if not isinstance(split[0], (list, np.ndarray)): # uses method 1 split = [split] if len(split) > 1: # uses method 2 # concatenate the list and check if the number of unique elements is equal to the length of the list @@ -123,7 +133,7 @@ def check_split_units(cls, values): raise ValueError(f"Split unit_id {unit_id} has duplicate units in the split") elif len(split) == 1: # uses method 1 # check the list dont have duplicates - if len(split[0]) != len(set(split[0])): + if len(split[0]) != len(set(list(split[0]))): raise ValueError(f"Split unit_id {unit_id} has duplicate units in the split") values["split_units"] = split_units return values diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py index 9dbc6fac22..12db0984c8 100644 --- a/src/spikeinterface/curation/tests/test_curation_model.py +++ b/src/spikeinterface/curation/tests/test_curation_model.py @@ -1,35 +1,218 @@ import pytest from pydantic import ValidationError -from pathlib import Path import numpy as np -from spikeinterface.curation.curation_model import CurationModel +from spikeinterface.curation.curation_model import CurationModel, LabelDefinition -valid_split_1 = {"format_version": "1", "unit_ids": [1, 2, 3], "split_units": {1: [1, 2], 2: [2, 3], 3: [4, 5]}} +# Test data for format version +def test_format_version(): + # Valid format version + CurationModel(format_version="1", unit_ids=[1, 2, 3]) + + # Invalid format version + with pytest.raises(ValidationError): + CurationModel(format_version="2", unit_ids=[1, 2, 3]) + with pytest.raises(ValidationError): + CurationModel(format_version="0", unit_ids=[1, 2, 3]) -valid_split_2 = { - "format_version": "1", - "unit_ids": [1, 2, 3, 4], - "split_units": {1: [[1, 2], [3, 4]], 2: [[2, 3], [4, 1]]}, -} -invalid_split_1 = { - "format_version": "1", - "unit_ids": [1, 2, 3], - "split_units": {1: [[1, 2], [2, 3]], 2: [2, 3], 3: [4, 5]}, -} +# Test data for label definitions +def test_label_definitions(): + valid_label_def = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow", "fast"], exclusive=False), + }, + } -invalid_split_2 = {"format_version": "1", "unit_ids": [1, 2, 3], "split_units": {4: [[1, 2], [2, 3]]}} + model = CurationModel(**valid_label_def) + assert "quality" in model.label_definitions + 
assert model.label_definitions["quality"].name == "quality" + assert model.label_definitions["quality"].exclusive is True + + # Test invalid label definition + with pytest.raises(ValidationError): + LabelDefinition(name="quality", label_options=[], exclusive=True) # Empty options should be invalid +# Test manual labels +def test_manual_labels(): + valid_labels = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow", "fast"], exclusive=False), + }, + "manual_labels": [ + {"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst", "fast"]}}, + {"unit_id": 2, "labels": {"quality": ["noise"]}}, + ], + } + + model = CurationModel(**valid_labels) + assert len(model.manual_labels) == 2 + + # Test invalid unit ID + invalid_unit = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True) + }, + "manual_labels": [{"unit_id": 4, "labels": {"quality": ["good"]}}], # Non-existent unit + } + with pytest.raises(ValidationError): + CurationModel(**invalid_unit) + + # Test violation of exclusive label + invalid_exclusive = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True) + }, + "manual_labels": [ + {"unit_id": 1, "labels": {"quality": ["good", "noise"]}} # Multiple values for exclusive label + ], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_exclusive) + + +# Test merge functionality +def test_merge_units(): + valid_merge = { + "format_version": "1", + "unit_ids": [1, 2, 3, 4], + "merge_unit_groups": [[1, 2], [3, 4]], + "merge_new_unit_ids": [5, 6], + } + + model = CurationModel(**valid_merge) + assert len(model.merge_unit_groups) == 2 + assert len(model.merge_new_unit_ids) == 2 + + # Test invalid merge group (single unit) + invalid_merge_group = {"format_version": "1", "unit_ids": [1, 2, 3], "merge_unit_groups": [[1], [2, 3]]} + with pytest.raises(ValidationError): + CurationModel(**invalid_merge_group) + + # Test overlapping merge groups + invalid_overlap = {"format_version": "1", "unit_ids": [1, 2, 3], "merge_unit_groups": [[1, 2], [2, 3]]} + with pytest.raises(ValidationError): + CurationModel(**invalid_overlap) + + # Test merge new unit IDs length mismatch + invalid_new_ids = { + "format_version": "1", + "unit_ids": [1, 2, 3, 4], + "merge_unit_groups": [[1, 2], [3, 4]], + "merge_new_unit_ids": [5], # Missing one ID + } + with pytest.raises(ValidationError): + CurationModel(**invalid_new_ids) + + +# Test removed units +def test_removed_units(): + valid_remove = {"format_version": "1", "unit_ids": [1, 2, 3], "removed_units": [2]} + + model = CurationModel(**valid_remove) + assert len(model.removed_units) == 1 + + # Test removing non-existent unit + invalid_remove = {"format_version": "1", "unit_ids": [1, 2, 3], "removed_units": [4]} # Non-existent unit + with pytest.raises(ValidationError): + CurationModel(**invalid_remove) + + # Test conflict between merge and remove + invalid_merge_remove = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "merge_unit_groups": [[1, 2]], + "removed_units": [1], # Unit is both merged and removed + } + with pytest.raises(ValidationError): + CurationModel(**invalid_merge_remove) + + +# Test complete model with multiple 
operations +def test_complete_model(): + complete_model = { + "format_version": "1", + "unit_ids": [1, 2, 3, 4, 5], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow"], exclusive=False), + }, + "manual_labels": [{"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst"]}}], + "merge_unit_groups": [[2, 3]], + "merge_new_unit_ids": [6], + "split_units": {4: [[1, 2], [3, 4]]}, + "removed_units": [5], + } + + model = CurationModel(**complete_model) + assert model.format_version == "1" + assert len(model.unit_ids) == 5 + assert len(model.label_definitions) == 2 + assert len(model.manual_labels) == 1 + assert len(model.merge_unit_groups) == 1 + assert len(model.merge_new_unit_ids) == 1 + assert len(model.split_units) == 1 + assert len(model.removed_units) == 1 + + +# Test unit splitting functionality def test_unit_split(): - CurationModel(**valid_split_1) - CurationModel(**valid_split_2) + # Test simple split (method 1) + valid_simple_split = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "split_units": { + 1: [1, 2], # Split unit 1 into two parts + 2: [2, 3], # Split unit 2 into two parts + 3: [4, 5], # Split unit 3 into two parts + }, + } + model = CurationModel(**valid_simple_split) + assert len(model.split_units) == 3 - # shold raise error + # Test complex split with multiple groups (method 2) + valid_complex_split = { + "format_version": "1", + "unit_ids": [1, 2, 3, 4], + "split_units": { + 1: [[1, 2], [3, 4]], # Split unit 1 into two groups + 2: [[2, 3], [4, 1]], # Split unit 2 into two groups + }, + } + model = CurationModel(**valid_complex_split) + assert len(model.split_units) == 2 + + # Test invalid mixing of methods + invalid_mixed_methods = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "split_units": { + 1: [[1, 2], [2, 3]], # Using method 2 + 2: [2, 3], # Using method 1 + 3: [4, 5], # Using method 1 + }, + } with pytest.raises(ValidationError): - CurationModel(**invalid_split_1) + CurationModel(**invalid_mixed_methods) + + # Test invalid unit ID + invalid_unit_id = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "split_units": {4: [[1, 2], [2, 3]]}, # Unit 4 doesn't exist in unit_ids + } with pytest.raises(ValidationError): - CurationModel(**invalid_split_2) + CurationModel(**invalid_unit_id) From 40fe01be5e1a57e3824f1eae856af60316d8c82f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Mar 2025 09:32:33 -0400 Subject: [PATCH 11/22] Add analyzer split to curation tests --- src/spikeinterface/curation/tests/test_curation_format.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index 45c1ad6b25..23fc69925e 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -208,6 +208,14 @@ def test_apply_curation_with_split(): assert 43 in sorting_curated.unit_ids assert 44 in sorting_curated.unit_ids + analyzer_curated = apply_curation(analyzer, curation_with_split) + assert len(analyzer_curated.sorting.unit_ids) == len(analyzer.sorting.unit_ids) + 2 + + assert 1 not in analyzer_curated.unit_ids + assert 2 not in analyzer_curated.unit_ids + assert 43 in analyzer_curated.unit_ids + assert 44 in analyzer_curated.unit_ids + if __name__ == "__main__": test_curation_format_validation() From 
ca6f2e0b17de9162c9615eb3e841d69613d59d02 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Mar 2025 10:42:07 -0400 Subject: [PATCH 12/22] wip: add split tests in postprocessing --- .../tests/test_multi_extensions.py | 158 +++++++++++++----- 1 file changed, 113 insertions(+), 45 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py index be0070d94a..0c8c2649af 100644 --- a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py +++ b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py @@ -11,8 +11,46 @@ ) from spikeinterface.core.generate import inject_some_split_units +# even if this is in postprocessing, we make an extension for quality metrics +extension_dict = { + "noise_levels": dict(), + "random_spikes": dict(), + "waveforms": dict(), + "templates": dict(), + "principal_components": dict(), + "spike_amplitudes": dict(), + "template_similarity": dict(), + "correlograms": dict(), + "isi_histograms": dict(), + "amplitude_scalings": dict(handle_collisions=False), # otherwise hard mode could fail due to dropped spikes + "spike_locations": dict(method="center_of_mass"), # trick to avoid UserWarning + "unit_locations": dict(), + "template_metrics": dict(), + "quality_metrics": dict(metric_names=["firing_rate", "isi_violation", "snr"]), +} +extension_data_type = { + "noise_levels": None, + "templates": "unit", + "isi_histograms": "unit", + "unit_locations": "unit", + "spike_amplitudes": "spike", + "amplitude_scalings": "spike", + "spike_locations": "spike", + "quality_metrics": "pandas", + "template_metrics": "pandas", + "correlograms": "matrix", + "template_similarity": "matrix", + "principal_components": "random", + "waveforms": "random", + "random_spikes": "random_spikes", +} +data_with_miltiple_returns = ["isi_histograms", "correlograms"] +# due to incremental PCA, hard computation could result in different results for PCA +# the model is differents always +random_computation = ["principal_components"] -def get_dataset(): + +def get_dataset_with_splits(): recording, sorting = generate_ground_truth_recording( durations=[30.0], sampling_frequency=16000.0, @@ -36,15 +74,15 @@ def get_dataset(): sort_by_amp = np.argsort(list(get_template_extremum_amplitude(analyzer_raw).values()))[::-1] split_ids = sorting.unit_ids[sort_by_amp][:3] - sorting_with_splits, other_ids = inject_some_split_units( + sorting_with_splits, split_unit_ids = inject_some_split_units( sorting, num_split=3, split_ids=split_ids, output_ids=True, seed=0 ) - return recording, sorting_with_splits, other_ids + return recording, sorting_with_splits, split_unit_ids @pytest.fixture(scope="module") def dataset(): - return get_dataset() + return get_dataset_with_splits() @pytest.mark.parametrize("sparse", [False, True]) @@ -54,52 +92,14 @@ def test_SortingAnalyzer_merge_all_extensions(dataset, sparse): recording, sorting, other_ids = dataset sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=sparse) + extension_dict_merge = extension_dict.copy() # we apply the merges according to the artificial splits merges = [list(v) for v in other_ids.values()] split_unit_ids = np.ravel(merges) unmerged_unit_ids = sorting_analyzer.unit_ids[~np.isin(sorting_analyzer.unit_ids, split_unit_ids)] - # even if this is in postprocessing, we make an extension for quality metrics - extension_dict = { - "noise_levels": dict(), - "random_spikes": dict(), - "waveforms": dict(), - 
"templates": dict(), - "principal_components": dict(), - "spike_amplitudes": dict(), - "template_similarity": dict(), - "correlograms": dict(), - "isi_histograms": dict(), - "amplitude_scalings": dict(handle_collisions=False), # otherwise hard mode could fail due to dropped spikes - "spike_locations": dict(method="center_of_mass"), # trick to avoid UserWarning - "unit_locations": dict(), - "template_metrics": dict(), - "quality_metrics": dict(metric_names=["firing_rate", "isi_violation", "snr"]), - } - extension_data_type = { - "noise_levels": None, - "templates": "unit", - "isi_histograms": "unit", - "unit_locations": "unit", - "spike_amplitudes": "spike", - "amplitude_scalings": "spike", - "spike_locations": "spike", - "quality_metrics": "pandas", - "template_metrics": "pandas", - "correlograms": "matrix", - "template_similarity": "matrix", - "principal_components": "random", - "waveforms": "random", - "random_spikes": "random_spikes", - } - data_with_miltiple_returns = ["isi_histograms", "correlograms"] - - # due to incremental PCA, hard computation could result in different results for PCA - # the model is differents always - random_computation = ["principal_components"] - - sorting_analyzer.compute(extension_dict, n_jobs=1) + sorting_analyzer.compute(extension_dict_merge, n_jobs=1) # TODO: still some UserWarnings for n_jobs, where from? t0 = time.perf_counter() @@ -165,6 +165,74 @@ def test_SortingAnalyzer_merge_all_extensions(dataset, sparse): assert np.allclose(data_hard_merged[f], data_soft_merged[f], rtol=0.1) +@pytest.mark.parametrize("sparse", [False, True]) +def test_SortingAnalyzer_split_all_extensions(dataset, sparse): + set_global_job_kwargs(n_jobs=1) + + recording, sorting, _ = dataset + + sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=sparse) + extension_dict_split = extension_dict.copy() + sorting_analyzer.compute(extension_dict, n_jobs=1) + + # we randomly apply splits (at half of spiketrain) + num_spikes = sorting.count_num_spikes_per_unit() + + units_to_split = [sorting_analyzer.unit_ids[1], sorting_analyzer.unit_ids[5]] + unsplit_unit_ids = sorting_analyzer.unit_ids[~np.isin(sorting_analyzer.unit_ids, units_to_split)] + splits = {} + for unit in units_to_split: + splits[unit] = np.arange(num_spikes[unit] // 2) + + analyzer_split, split_unit_ids = sorting_analyzer.split_units(split_units=splits, return_new_unit_ids=True) + split_unit_ids = list(np.concatenate(split_unit_ids)) + + # also do a full recopute + analyzer_hard = create_sorting_analyzer(analyzer_split.sorting, recording, format="memory", sparse=sparse) + # we propagate random spikes to avoid random spikes to be recomputed + extension_dict_ = extension_dict_split.copy() + extension_dict_.pop("random_spikes") + analyzer_hard.extensions["random_spikes"] = analyzer_split.extensions["random_spikes"] + analyzer_hard.compute(extension_dict_, n_jobs=1) + + for ext in extension_dict: + # 1. 
check that data are exactly the same for unchanged units between original/split + data_original = sorting_analyzer.get_extension(ext).get_data() + data_split = analyzer_split.get_extension(ext).get_data() + data_recompute = analyzer_hard.get_extension(ext).get_data() + if ext in data_with_miltiple_returns: + data_original = data_original[0] + data_split = data_split[0] + data_recompute = data_recompute[0] + data_original_unsplit = get_extension_data_for_units( + sorting_analyzer, data_original, unsplit_unit_ids, extension_data_type[ext] + ) + data_split_unsplit = get_extension_data_for_units( + analyzer_split, data_split, unsplit_unit_ids, extension_data_type[ext] + ) + + np.testing.assert_array_equal(data_original_unsplit, data_split_unsplit) + + # 2. check that split data are the same for extension split and recompute + data_split_soft = get_extension_data_for_units( + analyzer_split, data_split, split_unit_ids, extension_data_type[ext] + ) + data_split_hard = get_extension_data_for_units( + analyzer_hard, data_recompute, split_unit_ids, extension_data_type[ext] + ) + # TODO: fix amplitude scalings + failing_extensions = [] + if ext not in random_computation + failing_extensions: + if extension_data_type[ext] == "pandas": + data_split_soft = data_split_soft.dropna().to_numpy().astype("float") + data_split_hard = data_split_hard.dropna().to_numpy().astype("float") + if data_split_hard.dtype.fields is None: + assert np.allclose(data_split_hard, data_split_soft, rtol=0.1) + else: + for f in data_split_hard.dtype.fields: + assert np.allclose(data_split_hard[f], data_split_soft[f], rtol=0.1) + + def get_extension_data_for_units(sorting_analyzer, data, unit_ids, ext_data_type): unit_indices = sorting_analyzer.sorting.ids_to_indices(unit_ids) spike_vector = sorting_analyzer.sorting.to_spike_vector() @@ -191,5 +259,5 @@ def get_extension_data_for_units(sorting_analyzer, data, unit_ids, ext_data_type if __name__ == "__main__": - dataset = get_dataset() + dataset = get_dataset_with_splits() test_SortingAnalyzer_merge_all_extensions(dataset, False) From 58b62fb00f5010a4fb9b8a51f1d5ec301825ed1b Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Mar 2025 13:18:14 -0400 Subject: [PATCH 13/22] wip - modify model --- src/spikeinterface/curation/curation_model.py | 18 +++++++++++++++++- 1 file changed, 17 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index b50db1e69c..9b4672f904 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -1,8 +1,9 @@ from pydantic import BaseModel, Field, field_validator, model_validator -from typing import List, Dict, Union, Optional +from typing import List, Dict, Union, Optional, Literal from itertools import combinations, chain import numpy as np + supported_curation_format_versions = {"1"} @@ -17,6 +18,21 @@ class ManualLabel(BaseModel): labels: Dict[str, List[str]] = Field(..., description="Dictionary of labels for the unit") +class Merges(BaseModel): + merge_unit_groups: List[List[Union[int, str]]] = Field(..., description="List of groups of units to be merged") + merge_new_unit_ids: List[Union[int, str]] = Field(..., description="List of new unit IDs for each merge group") + + +class Split(BaseModel): + unit_id: Union[int, str] = Field(..., description="ID of the unit") + split_mode: Literal["indices", "labels"] = Field(default="indices", description="Mode of the split") + split_indices: 
Optional[Union[List[List[int]]]] = Field(default=None, description="Information about the split") + split_labels = Optional[List[int]] = Field(default=None, description="List of labels for the split") + split_new_unit_ids: Optional[List[Union[int, str]]] = Field( + default=None, description="List of new unit IDs for each unit split" + ) + + class CurationModel(BaseModel): format_version: str = Field(..., description="Version of the curation format") unit_ids: List[Union[int, str]] = Field(..., description="List of unit IDs") From dbfa315296376063995b4ebb5ebae9b6d24ed727 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Mar 2025 13:39:11 -0400 Subject: [PATCH 14/22] Refactor curation model to include merges and splits --- src/spikeinterface/curation/curation_model.py | 233 +++++++++++---- .../curation/tests/test_curation_model.py | 279 ++++++++++++++++++ 2 files changed, 462 insertions(+), 50 deletions(-) create mode 100644 src/spikeinterface/curation/tests/test_curation_model.py diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index 0e2f1da870..a595f32eac 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -1,13 +1,15 @@ from pydantic import BaseModel, Field, field_validator, model_validator -from typing import List, Dict, Union, Optional -from itertools import combinations +from typing import List, Dict, Union, Optional, Literal +from itertools import combinations, chain +import numpy as np + supported_curation_format_versions = {"1"} class LabelDefinition(BaseModel): name: str = Field(..., description="Name of the label") - label_options: List[str] = Field(..., description="List of possible label options") + label_options: List[str] = Field(..., description="List of possible label options", min_length=2) exclusive: bool = Field(..., description="Whether the label is exclusive") @@ -16,16 +18,39 @@ class ManualLabel(BaseModel): labels: Dict[str, List[str]] = Field(..., description="Dictionary of labels for the unit") +class Merge(BaseModel): + merge_unit_group: List[Union[int, str]] = Field(..., description="List of groups of units to be merged") + merge_new_unit_ids: Optional[Union[int, str]] = Field(default=None, description="New unit IDs for the merge group") + + +class Split(BaseModel): + unit_id: Union[int, str] = Field(..., description="ID of the unit") + split_mode: Literal["indices", "labels"] = Field( + default="indices", + description=( + "Mode of the split. The split can be defined by indices or labels. " + "If indices, the split is defined by the a list of lists of indices of spikes within spikes " + "belonging to the unit (`split_indices`). " + "If labels, the split is defined by a list of labels for each spike (`split_labels`). 
" + ), + ) + split_indices: Optional[Union[List[List[int]]]] = Field(default=None, description="List of indices for the split") + split_labels: Optional[List[int]] = Field(default=None, description="List of labels for the split") + split_new_unit_ids: Optional[List[Union[int, str]]] = Field( + default=None, description="List of new unit IDs for each split" + ) + + class CurationModel(BaseModel): format_version: str = Field(..., description="Version of the curation format") unit_ids: List[Union[int, str]] = Field(..., description="List of unit IDs") - label_definitions: Dict[str, LabelDefinition] = Field(..., description="Dictionary of label definitions") - manual_labels: List[ManualLabel] = Field(..., description="List of manual labels") - merge_unit_groups: List[List[Union[int, str]]] = Field(..., description="List of groups of units to be merged") - removed_units: List[Union[int, str]] = Field(..., description="List of removed unit IDs") - merge_new_unit_ids: Optional[List[Union[int, str]]] = Field( - default=None, description="List of new unit IDs after merging" + label_definitions: Optional[Dict[str, LabelDefinition]] = Field( + default=None, description="Dictionary of label definitions" ) + manual_labels: Optional[List[ManualLabel]] = Field(default=None, description="List of manual labels") + removed: Optional[List[Union[int, str]]] = Field(default=None, description="List of removed unit IDs") + merges: Optional[List[Merge]] = Field(default=None, description="List of merges") + splits: Optional[List[Split]] = Field(default=None, description="List of splits") @field_validator("format_version") def check_format_version(cls, v): @@ -33,19 +58,26 @@ def check_format_version(cls, v): raise ValueError(f"Format version ({v}) not supported. Only {supported_curation_format_versions} are valid") return v - @field_validator("label_definitions", mode="before") - def add_label_definition_name(cls, v): - if v is None: - v = {} - else: - for key in list(v.keys()): - v[key]["name"] = key - return v + @model_validator(mode="before") + def add_label_definition_name(cls, values): + label_definitions = values.get("label_definitions") + if label_definitions is None: + values["label_definitions"] = {} + return values + if isinstance(values["label_definitions"], dict): + if label_definitions is None: + label_definitions = {} + else: + for key in list(label_definitions.keys()): + if isinstance(label_definitions[key], dict): + label_definitions[key]["name"] = key + values["label_definitions"] = label_definitions + return values @model_validator(mode="before") def check_manual_labels(cls, values): unit_ids = values["unit_ids"] - manual_labels = values["manual_labels"] + manual_labels = values.get("manual_labels") if manual_labels is None: values["manual_labels"] = [] else: @@ -58,52 +90,148 @@ def check_manual_labels(cls, values): for label in labels: if label not in values["label_definitions"]: raise ValueError(f"Manual label {unit_id} has an unknown label {label}") - manual_label["labels"][label] = manual_label[label] + if label not in manual_label["labels"]: + if label in manual_label: + manual_label["labels"][label] = manual_label[label] + else: + raise ValueError(f"Manual label {unit_id} has no value for label {label}") if unit_id not in unit_ids: raise ValueError(f"Manual label unit_id {unit_id} is not in the unit list") return values @model_validator(mode="before") - def check_merge_unit_groups(cls, values): + def check_merges(cls, values): unit_ids = values["unit_ids"] - merge_unit_groups = 
values.get("merge_unit_groups", []) - for merge_group in merge_unit_groups: - for unit_id in merge_group: + merges = values.get("merges") + if merges is None: + values["merges"] = [] + return values + elif isinstance(merges, list): + # Convert list of lists to Merge objects + for i, merge in enumerate(merges): + if isinstance(merge, list): + merges[i] = Merge(merge_unit_group=merge) + elif isinstance(merges, dict): + # Convert dict format to list of Merge objects if needed + merge_list = [] + for merge_new_id, merge_group in merges.items(): + merge_list.append({"merge_unit_group": merge_group, "merge_new_unit_ids": merge_new_id}) + merges = merge_list + values["merges"] = merges + + # Convert dict items to Merge objects if needed + for i, merge in enumerate(merges): + if isinstance(merge, dict): + merges[i] = Merge(**merge) + + for merge in merges: + # Check unit ids exist + for unit_id in merge.merge_unit_group: if unit_id not in unit_ids: raise ValueError(f"Merge unit group unit_id {unit_id} is not in the unit list") - if len(merge_group) < 2: + + # Check minimum group size + if len(merge.merge_unit_group) < 2: raise ValueError("Merge unit groups must have at least 2 elements") + + # Check new unit id not already used + if merge.merge_new_unit_ids is not None: + if merge.merge_new_unit_ids in unit_ids: + raise ValueError(f"New unit ID {merge.merge_new_unit_ids} is already in the unit list") + return values @model_validator(mode="before") - def check_merge_new_unit_ids(cls, values): + def check_splits(cls, values): unit_ids = values["unit_ids"] - merge_new_unit_ids = values.get("merge_new_unit_ids") - if merge_new_unit_ids is not None: - merge_unit_groups = values.get("merge_unit_groups") - assert merge_unit_groups is not None, "Merge unit groups must be defined if merge new unit ids are defined" - if len(merge_unit_groups) != len(merge_new_unit_ids): - raise ValueError("Merge unit groups and new unit ids must have the same length") - if len(merge_new_unit_ids) > 0: - for new_unit_id in merge_new_unit_ids: - if new_unit_id in unit_ids: - raise ValueError(f"New unit ID {new_unit_id} is already in the unit list") + splits = values.get("splits") + if splits is None: + values["splits"] = [] + return values + + # Convert dict format to list of Split objects if needed + if isinstance(splits, dict): + split_list = [] + for unit_id, split_data in splits.items(): + # If split_data is a list of indices, assume indices mode + if isinstance(split_data[0], (list, np.ndarray)) if split_data else False: + split_list.append({"unit_id": unit_id, "split_mode": "indices", "split_indices": split_data}) + # Otherwise assume it's a list of labels + else: + split_list.append({"unit_id": unit_id, "split_mode": "labels", "split_labels": split_data}) + splits = split_list + values["splits"] = splits + + # Convert dict items to Split objects if needed + for i, split in enumerate(splits): + if isinstance(split, dict): + splits[i] = Split(**split) + + for split in splits: + # Check unit exists + if split.unit_id not in unit_ids: + raise ValueError(f"Split unit_id {split.unit_id} is not in the unit list") + + # Check split definition based on mode + if split.split_mode == "indices": + if split.split_indices is None: + raise ValueError(f"Split unit {split.unit_id} has no split_indices defined") + if len(split.split_indices) < 1: + raise ValueError(f"Split unit {split.unit_id} has empty split_indices") + # Check no duplicate indices across splits + all_indices = list(chain.from_iterable(split.split_indices)) + if 
len(all_indices) != len(set(all_indices)): + raise ValueError(f"Split unit {split.unit_id} has duplicate indices") + + elif split.split_mode == "labels": + if split.split_labels is None: + raise ValueError(f"Split unit {split.unit_id} has no split_labels defined") + if len(split.split_labels) == 0: + raise ValueError(f"Split unit {split.unit_id} has empty split_labels") + + # Check new unit ids if provided + if split.split_new_unit_ids is not None: + if split.split_mode == "indices": + if len(split.split_new_unit_ids) != len(split.split_indices): + raise ValueError( + f"Number of new unit IDs does not match number of splits for unit {split.unit_id}" + ) + elif split.split_mode == "labels": + if len(split.split_new_unit_ids) != len(set(split.split_labels)): + raise ValueError( + f"Number of new unit IDs does not match number of unique labels for unit {split.unit_id}" + ) + + # Check new ids not already used + for new_id in split.split_new_unit_ids: + if new_id in unit_ids: + raise ValueError(f"New unit ID {new_id} is already in the unit list") + return values @model_validator(mode="before") - def check_removed_units(cls, values): + def check_removed(cls, values): unit_ids = values["unit_ids"] - removed_units = values.get("removed_units", []) - for unit_id in removed_units: - if unit_id not in unit_ids: - raise ValueError(f"Removed unit_id {unit_id} is not in the unit list") + removed = values.get("removed", []) + if removed is None: + + for unit_id in removed: + if unit_id not in unit_ids: + raise ValueError(f"Removed unit_id {unit_id} is not in the unit list") + + else: + values["removed"] = removed + return values @model_validator(mode="after") def validate_curation_dict(cls, values): - labeled_unit_set = set([lbl.unit_id for lbl in values.manual_labels]) - merged_units_set = set(sum(values.merge_unit_groups, [])) - removed_units_set = set(values.removed_units) + labeled_unit_set = set([lbl.unit_id for lbl in values.manual_labels]) if values.manual_labels else set() + merged_units_set = ( + set(chain.from_iterable(merge.merge_unit_group for merge in values.merges)) if values.merges else set() + ) + split_units_set = set(split.unit_id for split in values.splits) if values.splits else set() + removed_set = set(values.removed) if values.removed else set() unit_ids = values.unit_ids unit_set = set(unit_ids) @@ -111,19 +239,24 @@ def validate_curation_dict(cls, values): raise ValueError("Curation format: some labeled units are not in the unit list") if not merged_units_set.issubset(unit_set): raise ValueError("Curation format: some merged units are not in the unit list") - if not removed_units_set.issubset(unit_set): + if not split_units_set.issubset(unit_set): + raise ValueError("Curation format: some split units are not in the unit list") + if not removed_set.issubset(unit_set): raise ValueError("Curation format: some removed units are not in the unit list") - for group in values.merge_unit_groups: - if len(group) < 2: - raise ValueError("Curation format: 'merge_unit_groups' must be list of list with at least 2 elements") - - all_merging_groups = [set(group) for group in values.merge_unit_groups] + # Check for units being merged multiple times + all_merging_groups = [set(merge.merge_unit_group) for merge in values.merges] if values.merges else [] for gp_1, gp_2 in combinations(all_merging_groups, 2): if len(gp_1.intersection(gp_2)) != 0: raise ValueError("Curation format: some units belong to multiple merge groups") - if len(removed_units_set.intersection(merged_units_set)) != 0: + + # Check 
no overlaps between operations + if len(removed_set.intersection(merged_units_set)) != 0: raise ValueError("Curation format: some units were merged and deleted") + if len(removed_set.intersection(split_units_set)) != 0: + raise ValueError("Curation format: some units were split and deleted") + if len(merged_units_set.intersection(split_units_set)) != 0: + raise ValueError("Curation format: some units were both merged and split") for manual_label in values.manual_labels: for label_key in values.label_definitions.keys(): diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py new file mode 100644 index 0000000000..29292c495c --- /dev/null +++ b/src/spikeinterface/curation/tests/test_curation_model.py @@ -0,0 +1,279 @@ +import pytest + +from pydantic import ValidationError +import numpy as np + +from spikeinterface.curation.curation_model import CurationModel, LabelDefinition + + +# Test data for format version +def test_format_version(): + # Valid format version + CurationModel(format_version="1", unit_ids=[1, 2, 3]) + + # Invalid format version + with pytest.raises(ValidationError): + CurationModel(format_version="2", unit_ids=[1, 2, 3]) + with pytest.raises(ValidationError): + CurationModel(format_version="0", unit_ids=[1, 2, 3]) + + +# Test data for label definitions +def test_label_definitions(): + valid_label_def = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow", "fast"], exclusive=False), + }, + } + + model = CurationModel(**valid_label_def) + assert "quality" in model.label_definitions + assert model.label_definitions["quality"].name == "quality" + assert model.label_definitions["quality"].exclusive is True + + # Test invalid label definition + with pytest.raises(ValidationError): + LabelDefinition(name="quality", label_options=[], exclusive=True) # Empty options should be invalid + + +# Test manual labels +def test_manual_labels(): + valid_labels = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow", "fast"], exclusive=False), + }, + "manual_labels": [ + {"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst", "fast"]}}, + {"unit_id": 2, "labels": {"quality": ["noise"]}}, + ], + } + + model = CurationModel(**valid_labels) + assert len(model.manual_labels) == 2 + + # Test invalid unit ID + invalid_unit = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True) + }, + "manual_labels": [{"unit_id": 4, "labels": {"quality": ["good"]}}], # Non-existent unit + } + with pytest.raises(ValidationError): + CurationModel(**invalid_unit) + + # Test violation of exclusive label + invalid_exclusive = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True) + }, + "manual_labels": [ + {"unit_id": 1, "labels": {"quality": ["good", "noise"]}} # Multiple values for exclusive label + ], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_exclusive) + + +# Test merge functionality +def test_merge_units(): + 
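+    # (editorial note) `merges` accepts three input shapes, all normalized by the
+    # model's before-validators: a list of Merge-style dicts, a plain list of
+    # unit-id groups, or a dict mapping new_unit_id -> merge group. The cases
+    # below exercise each shape in turn.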
# Test list format + valid_merge = { + "format_version": "1", + "unit_ids": [1, 2, 3, 4], + "merges": [ + {"merge_unit_group": [1, 2], "merge_new_unit_ids": 5}, + {"merge_unit_group": [3, 4], "merge_new_unit_ids": 6}, + ], + } + + model = CurationModel(**valid_merge) + assert len(model.merges) == 2 + assert model.merges[0].merge_new_unit_ids == 5 + assert model.merges[1].merge_new_unit_ids == 6 + + # Test dictionary format + valid_merge_dict = {"format_version": "1", "unit_ids": [1, 2, 3, 4], "merges": {5: [1, 2], 6: [3, 4]}} + + model = CurationModel(**valid_merge_dict) + assert len(model.merges) == 2 + merge_new_ids = {merge.merge_new_unit_ids for merge in model.merges} + assert merge_new_ids == {5, 6} + + # Test invalid merge group (single unit) + invalid_merge_group = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "merges": [{"merge_unit_group": [1], "merge_new_unit_ids": 4}], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_merge_group) + + # Test overlapping merge groups + invalid_overlap = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "merges": [ + {"merge_unit_group": [1, 2], "merge_new_unit_ids": 4}, + {"merge_unit_group": [2, 3], "merge_new_unit_ids": 5}, + ], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_overlap) + + +# Test split functionality +def test_split_units(): + # Test indices mode with list format + valid_split_indices = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "splits": [ + { + "unit_id": 1, + "split_mode": "indices", + "split_indices": [[0, 1, 2], [3, 4, 5]], + "split_new_unit_ids": [4, 5], + } + ], + } + + model = CurationModel(**valid_split_indices) + assert len(model.splits) == 1 + assert model.splits[0].split_mode == "indices" + assert len(model.splits[0].split_indices) == 2 + + # Test labels mode with list format + valid_split_labels = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "splits": [ + {"unit_id": 1, "split_mode": "labels", "split_labels": [0, 0, 1, 1, 0, 2], "split_new_unit_ids": [4, 5, 6]} + ], + } + + model = CurationModel(**valid_split_labels) + assert len(model.splits) == 1 + assert model.splits[0].split_mode == "labels" + assert len(set(model.splits[0].split_labels)) == 3 + + # Test dictionary format with indices + valid_split_dict = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "splits": { + 1: [[0, 1, 2], [3, 4, 5]], # Split unit 1 into two parts + 2: [[0, 1], [2, 3], [4, 5]], # Split unit 2 into three parts + }, + } + + model = CurationModel(**valid_split_dict) + assert len(model.splits) == 2 + assert all(split.split_mode == "indices" for split in model.splits) + + # Test invalid unit ID + invalid_unit_id = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "splits": [{"unit_id": 4, "split_mode": "indices", "split_indices": [[0, 1], [2, 3]]}], # Non-existent unit + } + with pytest.raises(ValidationError): + CurationModel(**invalid_unit_id) + + # Test invalid new unit IDs count for indices mode + invalid_new_ids = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "splits": [ + { + "unit_id": 1, + "split_mode": "indices", + "split_indices": [[0, 1], [2, 3]], + "split_new_unit_ids": [4], # Should have 2 new IDs for 2 splits + } + ], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_new_ids) + + +# Test removed units +def test_removed_units(): + valid_remove = {"format_version": "1", "unit_ids": [1, 2, 3], "removed": [2]} + + model = CurationModel(**valid_remove) + assert len(model.removed) == 1 + + # Test removing non-existent unit + 
invalid_remove = {"format_version": "1", "unit_ids": [1, 2, 3], "removed": [4]} # Non-existent unit + with pytest.raises(ValidationError): + CurationModel(**invalid_remove) + + # Test conflict between merge and remove + invalid_merge_remove = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "merges": [{"merge_unit_group": [1, 2], "merge_new_unit_ids": 4}], + "removed": [1], # Unit is both merged and removed + } + with pytest.raises(ValidationError): + CurationModel(**invalid_merge_remove) + + +# Test complete model with multiple operations +def test_complete_model(): + complete_model = { + "format_version": "1", + "unit_ids": [1, 2, 3, 4, 5], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow"], exclusive=False), + }, + "manual_labels": [{"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst"]}}], + "merges": [{"merge_unit_group": [2, 3], "merge_new_unit_ids": 6}], + "splits": [ + {"unit_id": 4, "split_mode": "indices", "split_indices": [[0, 1], [2, 3]], "split_new_unit_ids": [7, 8]} + ], + "removed": [5], + } + + model = CurationModel(**complete_model) + assert model.format_version == "1" + assert len(model.unit_ids) == 5 + assert len(model.label_definitions) == 2 + assert len(model.manual_labels) == 1 + assert len(model.merges) == 1 + assert len(model.splits) == 1 + assert len(model.removed) == 1 + + # Test dictionary format for complete model + complete_model_dict = { + "format_version": "1", + "unit_ids": [1, 2, 3, 4, 5], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow"], exclusive=False), + }, + "manual_labels": [{"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst"]}}], + "merges": {6: [2, 3]}, + "splits": {4: [[0, 1], [2, 3]]}, + "removed": [5], + } + + model = CurationModel(**complete_model_dict) + assert model.format_version == "1" + assert len(model.unit_ids) == 5 + assert len(model.label_definitions) == 2 + assert len(model.manual_labels) == 1 + assert len(model.merges) == 1 + assert len(model.splits) == 1 + assert len(model.removed) == 1 From 82526b0dc40457c1e14154af701d5aeab26bffd3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Mar 2025 13:46:35 -0400 Subject: [PATCH 15/22] Add merge list to tests --- src/spikeinterface/curation/tests/test_curation_model.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py index 29292c495c..1ebb36296d 100644 --- a/src/spikeinterface/curation/tests/test_curation_model.py +++ b/src/spikeinterface/curation/tests/test_curation_model.py @@ -109,6 +109,15 @@ def test_merge_units(): merge_new_ids = {merge.merge_new_unit_ids for merge in model.merges} assert merge_new_ids == {5, 6} + # Test list format + valid_merge_list = { + "format_version": "1", + "unit_ids": [1, 2, 3, 4], + "merges": [[1, 2], [3, 4]], # Merge each pair into a new unit + } + model = CurationModel(**valid_merge_list) + assert len(model.merges) == 2 + # Test invalid merge group (single unit) invalid_merge_group = { "format_version": "1", From 482f0be221177f31f341e4a70e7a1621f2850fdb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Mar 2025 15:20:09 -0400 Subject: [PATCH 16/22] Simplify and centralize conversion and checks --- 
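Editor's note (sits after the `---` separator, so it is ignored by `git am`): a
minimal usage sketch of the centralized conversion this patch introduces,
assuming the CurationModel API shown in the diff below; the input values are
illustrative and not taken from the test suite.

    from spikeinterface.curation.curation_model import CurationModel

    # a v1-style dict: `merge_unit_groups` / `removed_units` are upgraded by
    # CurationModel.convert_old_format before the field checks run
    v1_dict = {
        "format_version": "1",
        "unit_ids": [1, 2, 3],
        "merge_unit_groups": [[1, 2]],
        "removed_units": [3],
    }
    model = CurationModel(**v1_dict)
    assert model.merges[0].merge_unit_group == [1, 2]
    assert model.removed == [3]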
.../curation/curation_format.py | 122 ++--------- src/spikeinterface/curation/curation_model.py | 207 ++++++++++++------ .../curation/sortingview_curation.py | 143 ++---------- .../curation/tests/test_curation_format.py | 165 +++++++++----- .../curation/tests/test_curation_model.py | 46 ++-- 5 files changed, 320 insertions(+), 363 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 186bb34568..d634735344 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -1,9 +1,7 @@ from __future__ import annotations -import copy import numpy as np -from spikeinterface import curation from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting from spikeinterface.curation.curation_model import CurationModel @@ -23,69 +21,6 @@ def validate_curation_dict(curation_dict: dict): CurationModel(**curation_dict) -def convert_from_sortingview_curation_format_v0(sortingview_dict: dict, destination_format: str = "1"): - """ - Converts the old sortingview curation format (v0) into a curation dictionary new format (v1) - Couple of caveats: - * The list of units is not available in the original sortingview dictionary. We set it to None - * Labels can not be mutually exclusive. - * Labels have no category, so we regroup them under the "all_labels" category - - Parameters - ---------- - sortingview_dict : dict - Dictionary containing the curation information from sortingview - destination_format : str - Version of the format to use. - Default to "1" - - Returns - ------- - curation_dict : dict - A curation dictionary - """ - - assert destination_format == "1" - if "mergeGroups" not in sortingview_dict.keys(): - sortingview_dict["mergeGroups"] = [] - merge_groups = sortingview_dict["mergeGroups"] - - first_unit_id = next(iter(sortingview_dict["labelsByUnit"].keys())) - if str.isdigit(first_unit_id): - unit_id_type = int - else: - unit_id_type = str - - all_units = [] - all_labels = [] - manual_labels = [] - general_cat = "all_labels" - for unit_id_, l_labels in sortingview_dict["labelsByUnit"].items(): - all_labels.extend(l_labels) - # recorver the correct type for unit_id - unit_id = unit_id_type(unit_id_) - if unit_id not in all_units: - all_units.append(unit_id) - manual_labels.append({"unit_id": unit_id, general_cat: l_labels}) - labels_def = {"all_labels": {"name": "all_labels", "label_options": list(set(all_labels)), "exclusive": False}} - - for merge_group in merge_groups: - for unit_id in merge_group: - if unit_id not in all_units: - all_units.append(unit_id) - - curation_dict = { - "format_version": destination_format, - "unit_ids": all_units, - "label_definitions": labels_def, - "manual_labels": manual_labels, - "merge_unit_groups": merge_groups, - "removed_units": [], - } - - return curation_dict - - def curation_label_to_vectors(curation_dict_or_model: dict | CurationModel): """ Transform the curation dict into dict of vectors. @@ -135,29 +70,6 @@ def curation_label_to_vectors(curation_dict_or_model: dict | CurationModel): return labels -def clean_curation_dict(curation_dict: dict): - """ - In some cases the curation_dict can have inconsistencies (like in the sorting view format). - For instance, some unit_ids are both in 'merge_unit_groups' and 'removed_units'. - This is ambiguous! 
-
-    This cleaner helper function ensures units tagged as `removed_units` are removed from the `merge_unit_groups`
-    """
-    curation_dict = copy.deepcopy(curation_dict)
-
-    clean_merge_unit_groups = []
-    for group in curation_dict["merge_unit_groups"]:
-        clean_group = []
-        for unit_id in group:
-            if unit_id not in curation_dict["removed_units"]:
-                clean_group.append(unit_id)
-        if len(clean_group) > 1:
-            clean_merge_unit_groups.append(clean_group)
-
-    curation_dict["merge_unit_groups"] = clean_merge_unit_groups
-    return curation_dict
-
-
 def curation_label_to_dataframe(curation_dict_or_model: dict | CurationModel):
     """
     Transform the curation dict into a pandas dataframe.
@@ -215,7 +127,8 @@ def apply_curation_labels(
                     all_values[unit_ind] = values[ind]
                 sorting.set_property(key, all_values)

-    for new_unit_id, old_group_ids in zip(new_unit_ids, curation_model.merge_unit_groups):
+    for new_unit_id, merge in zip(new_unit_ids, curation_model.merges):
+        old_group_ids = merge.merge_unit_group
         for label_key, label_def in curation_model.label_definitions.items():
             if label_def.exclusive:
                 group_values = []
@@ -253,8 +166,8 @@ def apply_curation(
     Apply curation dict to a Sorting or a SortingAnalyzer.

     Steps are done in this order:
-      1. Apply removal using curation_dict["removed_units"]
-      2. Apply merges using curation_dict["merge_unit_groups"]
+      1. Apply removal using curation_dict["removed"]
+      2. Apply merges using curation_dict["merges"]
       3. Set labels using curation_dict["manual_labels"]

     A new Sorting or SortingAnalyzer (in memory) is returned.
@@ -303,24 +216,27 @@ def apply_curation(

     if isinstance(sorting_or_analyzer, BaseSorting):
         sorting = sorting_or_analyzer
-        sorting = sorting.remove_units(curation_model.removed_units)
-        sorting, _, new_unit_ids = apply_merges_to_sorting(
-            sorting,
-            curation_model.merge_unit_groups,
-            censor_ms=censor_ms,
-            return_extra=True,
-            new_id_strategy=new_id_strategy,
-        )
+        sorting = sorting.remove_units(curation_model.removed)
+        if len(curation_model.merges) > 0:
+            sorting, _, new_unit_ids = apply_merges_to_sorting(
+                sorting,
+                merge_unit_groups=[m.merge_unit_group for m in curation_model.merges],
+                censor_ms=censor_ms,
+                return_extra=True,
+                new_id_strategy=new_id_strategy,
+            )
+        else:
+            new_unit_ids = []
         apply_curation_labels(sorting, new_unit_ids, curation_model)
         return sorting

     elif isinstance(sorting_or_analyzer, SortingAnalyzer):
         analyzer = sorting_or_analyzer
-        if len(curation_model.removed_units) > 0:
-            analyzer = analyzer.remove_units(curation_model.removed_units)
-        if len(curation_model.merge_unit_groups) > 0:
+        if len(curation_model.removed) > 0:
+            analyzer = analyzer.remove_units(curation_model.removed)
+        if len(curation_model.merges) > 0:
             analyzer, new_unit_ids = analyzer.merge_units(
-                curation_model.merge_unit_groups,
+                merge_unit_groups=[m.merge_unit_group for m in curation_model.merges],
                 censor_ms=censor_ms,
                 merging_mode=merging_mode,
                 sparsity_overlap=sparsity_overlap,
diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py
index a595f32eac..787af88111 100644
--- a/src/spikeinterface/curation/curation_model.py
+++ b/src/spikeinterface/curation/curation_model.py
@@ -1,12 +1,9 @@
-from pydantic import BaseModel, Field, field_validator, model_validator
-from typing import List, Dict, Union, Optional, Literal
-from itertools import combinations, chain
+from pydantic import BaseModel, Field, model_validator, field_validator
+from typing import List, Dict, Union, Optional, Literal, Tuple
+from itertools 
import chain, combinations import numpy as np -supported_curation_format_versions = {"1"} - - class LabelDefinition(BaseModel): name: str = Field(..., description="Name of the label") label_options: List[str] = Field(..., description="List of possible label options", min_length=2) @@ -42,6 +39,9 @@ class Split(BaseModel): class CurationModel(BaseModel): + supported_versions: Tuple[Literal["1"], Literal["2"]] = Field( + default=["1", "2"], description="Supported versions of the curation format" + ) format_version: str = Field(..., description="Version of the curation format") unit_ids: List[Union[int, str]] = Field(..., description="List of unit IDs") label_definitions: Optional[Dict[str, LabelDefinition]] = Field( @@ -52,78 +52,81 @@ class CurationModel(BaseModel): merges: Optional[List[Merge]] = Field(default=None, description="List of merges") splits: Optional[List[Split]] = Field(default=None, description="List of splits") - @field_validator("format_version") - def check_format_version(cls, v): - if v not in supported_curation_format_versions: - raise ValueError(f"Format version ({v}) not supported. Only {supported_curation_format_versions} are valid") - return v - - @model_validator(mode="before") - def add_label_definition_name(cls, values): - label_definitions = values.get("label_definitions") + @field_validator("label_definitions", mode="before") + def add_label_definition_name(cls, label_definitions): if label_definitions is None: - values["label_definitions"] = {} - return values - if isinstance(values["label_definitions"], dict): - if label_definitions is None: - label_definitions = {} - else: - for key in list(label_definitions.keys()): - if isinstance(label_definitions[key], dict): - label_definitions[key]["name"] = key - values["label_definitions"] = label_definitions - return values - - @model_validator(mode="before") + return {} + if isinstance(label_definitions, dict): + label_definitions = dict(label_definitions) + for key in list(label_definitions.keys()): + if isinstance(label_definitions[key], dict): + label_definitions[key] = dict(label_definitions[key]) + label_definitions[key]["name"] = key + return label_definitions + return label_definitions + + @classmethod def check_manual_labels(cls, values): - unit_ids = values["unit_ids"] + values = dict(values) + unit_ids = list(values["unit_ids"]) manual_labels = values.get("manual_labels") if manual_labels is None: values["manual_labels"] = [] else: - for manual_label in manual_labels: + manual_labels = list(manual_labels) + for i, manual_label in enumerate(manual_labels): + manual_label = dict(manual_label) unit_id = manual_label["unit_id"] labels = manual_label.get("labels") if labels is None: labels = set(manual_label.keys()) - {"unit_id"} manual_label["labels"] = {} + else: + manual_label["labels"] = {k: list(v) for k, v in labels.items()} for label in labels: if label not in values["label_definitions"]: raise ValueError(f"Manual label {unit_id} has an unknown label {label}") if label not in manual_label["labels"]: if label in manual_label: - manual_label["labels"][label] = manual_label[label] + manual_label["labels"][label] = list(manual_label[label]) else: raise ValueError(f"Manual label {unit_id} has no value for label {label}") if unit_id not in unit_ids: raise ValueError(f"Manual label unit_id {unit_id} is not in the unit list") + manual_labels[i] = manual_label + values["manual_labels"] = manual_labels return values - @model_validator(mode="before") + @classmethod def check_merges(cls, values): - unit_ids = 
values["unit_ids"] + values = dict(values) + unit_ids = list(values["unit_ids"]) merges = values.get("merges") if merges is None: values["merges"] = [] return values - elif isinstance(merges, list): - # Convert list of lists to Merge objects - for i, merge in enumerate(merges): - if isinstance(merge, list): - merges[i] = Merge(merge_unit_group=merge) - elif isinstance(merges, dict): - # Convert dict format to list of Merge objects if needed + + if isinstance(merges, dict): + # Convert dict format to list of Merge objects merge_list = [] for merge_new_id, merge_group in merges.items(): - merge_list.append({"merge_unit_group": merge_group, "merge_new_unit_ids": merge_new_id}) + merge_list.append({"merge_unit_group": list(merge_group), "merge_new_unit_ids": merge_new_id}) merges = merge_list - values["merges"] = merges - # Convert dict items to Merge objects if needed + # Make a copy of the list + merges = list(merges) + + # Convert items to Merge objects for i, merge in enumerate(merges): + if isinstance(merge, list): + merge = {"merge_unit_group": list(merge)} if isinstance(merge, dict): + merge = dict(merge) + if "merge_unit_group" in merge: + merge["merge_unit_group"] = list(merge["merge_unit_group"]) merges[i] = Merge(**merge) + # Validate merges for merge in merges: # Check unit ids exist for unit_id in merge.merge_unit_group: @@ -139,46 +142,62 @@ def check_merges(cls, values): if merge.merge_new_unit_ids in unit_ids: raise ValueError(f"New unit ID {merge.merge_new_unit_ids} is already in the unit list") + values["merges"] = merges return values - @model_validator(mode="before") + @classmethod def check_splits(cls, values): - unit_ids = values["unit_ids"] + values = dict(values) + unit_ids = list(values["unit_ids"]) splits = values.get("splits") if splits is None: values["splits"] = [] return values - # Convert dict format to list of Split objects if needed + # Convert dict format to list format if isinstance(splits, dict): split_list = [] for unit_id, split_data in splits.items(): - # If split_data is a list of indices, assume indices mode if isinstance(split_data[0], (list, np.ndarray)) if split_data else False: - split_list.append({"unit_id": unit_id, "split_mode": "indices", "split_indices": split_data}) - # Otherwise assume it's a list of labels + split_list.append( + { + "unit_id": unit_id, + "split_mode": "indices", + "split_indices": [list(indices) for indices in split_data], + } + ) else: - split_list.append({"unit_id": unit_id, "split_mode": "labels", "split_labels": split_data}) + split_list.append({"unit_id": unit_id, "split_mode": "labels", "split_labels": list(split_data)}) splits = split_list - values["splits"] = splits - # Convert dict items to Split objects if needed + # Make a copy of the list + splits = list(splits) + + # Convert items to Split objects for i, split in enumerate(splits): if isinstance(split, dict): + split = dict(split) + if "split_indices" in split: + split["split_indices"] = [list(indices) for indices in split["split_indices"]] + if "split_labels" in split: + split["split_labels"] = list(split["split_labels"]) + if "split_new_unit_ids" in split: + split["split_new_unit_ids"] = list(split["split_new_unit_ids"]) splits[i] = Split(**split) + # Validate splits for split in splits: # Check unit exists if split.unit_id not in unit_ids: raise ValueError(f"Split unit_id {split.unit_id} is not in the unit list") - # Check split definition based on mode + # Validate based on mode if split.split_mode == "indices": if split.split_indices is None: raise 
ValueError(f"Split unit {split.unit_id} has no split_indices defined") if len(split.split_indices) < 1: raise ValueError(f"Split unit {split.unit_id} has empty split_indices") - # Check no duplicate indices across splits + # Check no duplicate indices all_indices = list(chain.from_iterable(split.split_indices)) if len(all_indices) != len(set(all_indices)): raise ValueError(f"Split unit {split.unit_id} has duplicate indices") @@ -189,7 +208,7 @@ def check_splits(cls, values): if len(split.split_labels) == 0: raise ValueError(f"Split unit {split.unit_id} has empty split_labels") - # Check new unit ids if provided + # Validate new unit IDs if split.split_new_unit_ids is not None: if split.split_mode == "indices": if len(split.split_new_unit_ids) != len(split.split_indices): @@ -202,30 +221,94 @@ def check_splits(cls, values): f"Number of new unit IDs does not match number of unique labels for unit {split.unit_id}" ) - # Check new ids not already used for new_id in split.split_new_unit_ids: if new_id in unit_ids: raise ValueError(f"New unit ID {new_id} is already in the unit list") + values["splits"] = splits return values - @model_validator(mode="before") + @classmethod def check_removed(cls, values): - unit_ids = values["unit_ids"] - removed = values.get("removed", []) + values = dict(values) + unit_ids = list(values["unit_ids"]) + removed = values.get("removed") if removed is None: - + values["removed"] = [] + else: + removed = list(removed) for unit_id in removed: if unit_id not in unit_ids: raise ValueError(f"Removed unit_id {unit_id} is not in the unit list") - - else: values["removed"] = removed + return values + + @classmethod + def convert_old_format(cls, values): + format_version = values.get("format_version", "0") + if format_version != "2": + values = dict(values) + if format_version == "0": + print("Conversion from format version v0 (sortingview) to v2") + if "mergeGroups" not in values.keys(): + values["mergeGroups"] = [] + merge_groups = values["mergeGroups"] + + first_unit_id = next(iter(values["labelsByUnit"].keys())) + if str.isdigit(first_unit_id): + unit_id_type = int + else: + unit_id_type = str + + all_units = [] + all_labels = [] + manual_labels = [] + general_cat = "all_labels" + for unit_id_, l_labels in values["labelsByUnit"].items(): + all_labels.extend(l_labels) + unit_id = unit_id_type(unit_id_) + if unit_id not in all_units: + all_units.append(unit_id) + manual_labels.append({"unit_id": unit_id, general_cat: list(l_labels)}) + labels_def = { + "all_labels": {"name": "all_labels", "label_options": list(set(all_labels)), "exclusive": False} + } + + values = { + "format_version": "2", + "unit_ids": values["unit_ids"], + "label_definitions": labels_def, + "manual_labels": list(manual_labels), + "merges": [{"merge_unit_group": list(group)} for group in merge_groups], + "splits": [], + "removed": [], + } + elif values["format_version"] == "1": + merge_unit_groups = values.get("merge_unit_groups") + if merge_unit_groups is not None: + values["merges"] = [{"merge_unit_group": list(group)} for group in merge_unit_groups] + removed_units = values.get("removed_units") + if removed_units is not None: + values["removed"] = list(removed_units) + return values + @model_validator(mode="before") + def validate_fields(cls, values): + values = dict(values) + values = cls.convert_old_format(values) + values = cls.check_manual_labels(values) + values = cls.check_merges(values) + values = cls.check_splits(values) + values = cls.check_removed(values) return values 
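+    # (editorial sketch, not part of the original patch) the before-validator
+    # above is what lets compact inputs normalize before the cross-checks in
+    # the after-validator below, e.g.:
+    #   CurationModel(format_version="2", unit_ids=[1, 2, 3, 4],
+    #                 merges={10: [1, 2]}, splits={3: [[0, 1], [2, 3]]})
+    # yields Merge/Split objects with merge_new_unit_ids == 10 and
+    # split_mode == "indices".
+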
@model_validator(mode="after") def validate_curation_dict(cls, values): + if values.format_version not in values.supported_versions: + raise ValueError( + f"Format version {values.format_version} not supported. Only {values.supported_versions} are valid" + ) + labeled_unit_set = set([lbl.unit_id for lbl in values.manual_labels]) if values.manual_labels else set() merged_units_set = ( set(chain.from_iterable(merge.merge_unit_group for merge in values.merges)) if values.merges else set() diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index f33051309c..8970463831 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -7,13 +7,11 @@ import numpy as np from pathlib import Path -from .curationsorting import CurationSorting from .curation_format import ( - convert_from_sortingview_curation_format_v0, apply_curation, curation_label_to_vectors, - clean_curation_dict, ) +from .curation_model import CurationModel def get_kachery(): @@ -34,6 +32,7 @@ def get_kachery(): ) +# TODO: fix sortingview curation with new format def apply_sortingview_curation( sorting_or_analyzer, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=None ): @@ -82,15 +81,20 @@ def apply_sortingview_curation( except: raise Exception(f"Could not retrieve curation from SortingView uri: {uri_or_json}") - # convert to new format - if "format_version" not in curation_dict: - curation_dict = convert_from_sortingview_curation_format_v0(curation_dict) - unit_ids = sorting_or_analyzer.unit_ids + curation_dict["unit_ids"] = unit_ids + curation_model = CurationModel(**curation_dict) - # this is a hack because it was not in the old format - curation_dict["unit_ids"] = list(unit_ids) + if not skip_merge: + sorting_curated = apply_curation(sorting_or_analyzer, curation_model) + else: + sorting_curated = sorting_or_analyzer + # now remove units based on labels + curation_model.merges = [] + curation_model.unit_ids = sorting_curated.unit_ids + + # this is a hack because it was not in the old format if exclude_labels is not None: assert include_labels is None, "Use either `include_labels` or `exclude_labels` to filter units." 
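+        # (editorial note) curation_label_to_vectors expands each non-exclusive
+        # label option into a boolean vector aligned with curation_dict["unit_ids"],
+        # so masking unit_ids with it yields the units carrying that label.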
manual_labels = curation_label_to_vectors(curation_dict) @@ -99,7 +103,7 @@ def apply_sortingview_curation( remove_mask = manual_labels[k] removed_units.extend(unit_ids[remove_mask]) removed_units = np.unique(removed_units) - curation_dict["removed_units"] = removed_units + curation_model.removed = removed_units if include_labels is not None: manual_labels = curation_label_to_vectors(curation_dict) @@ -108,122 +112,9 @@ def apply_sortingview_curation( remove_mask = ~manual_labels[k] removed_units.extend(unit_ids[remove_mask]) removed_units = np.unique(removed_units) - curation_dict["removed_units"] = removed_units + curation_model.removed = removed_units - if skip_merge: - curation_dict["merge_unit_groups"] = [] - - # cleaner to ensure validity - curation_dict = clean_curation_dict(curation_dict) - - # apply - sorting_curated = apply_curation(sorting_or_analyzer, curation_dict, new_id_strategy="join") + if len(curation_model.removed) > 0: + sorting_curated = apply_curation(sorting_curated, curation_model) return sorting_curated - - -# TODO @alessio you remove this after testing -def apply_sortingview_curation_legacy( - sorting, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=False -): - """ - Apply curation from SortingView manual curation. - First, merges (if present) are applied. Then labels are loaded and units - are optionally filtered based on exclude_labels and include_labels. - - Parameters - ---------- - sorting : BaseSorting - The sorting object to be curated - uri_or_json : str or Path - The URI curation link from SortingView or the path to the curation json file - exclude_labels : list, default: None - Optional list of labels to exclude (e.g. ["reject", "noise"]). - Mutually exclusive with include_labels - include_labels : list, default: None - Optional list of labels to include (e.g. ["accept"]). 
- Mutually exclusive with exclude_labels, by default None - skip_merge : bool, default: False - If True, merges are not applied (only labels) - verbose : bool, default: False - If True, output is verbose - - Returns - ------- - sorting_curated : BaseSorting - The curated sorting - """ - ka = get_kachery() - curation_sorting = CurationSorting(sorting, make_graph=False, properties_policy="keep") - - # get sorting view curation - if Path(uri_or_json).suffix == ".json" and not str(uri_or_json).startswith("gh://"): - with open(uri_or_json, "r") as f: - sortingview_curation_dict = json.load(f) - else: - try: - sortingview_curation_dict = ka.load_json(uri=uri_or_json) - except: - raise Exception(f"Could not retrieve curation from SortingView uri: {uri_or_json}") - - unit_ids_dtype = sorting.unit_ids.dtype - - # STEP 1: merge groups - labels_dict = sortingview_curation_dict["labelsByUnit"] - if "mergeGroups" in sortingview_curation_dict and not skip_merge: - merge_groups = sortingview_curation_dict["mergeGroups"] - for merge_group in merge_groups: - # Store labels of units that are about to be merged - labels_to_inherit = [] - for unit in merge_group: - labels_to_inherit.extend(labels_dict.get(str(unit), [])) - labels_to_inherit = list(set(labels_to_inherit)) # Remove duplicates - - if verbose: - print(f"Merging {merge_group}") - if unit_ids_dtype.kind in ("U", "S"): - merge_group = [str(unit) for unit in merge_group] - # if unit dtype is str, set new id as "{unit1}-{unit2}" - new_unit_id = "-".join(merge_group) - curation_sorting.merge(merge_group, new_unit_id=new_unit_id) - else: - # in this case, the CurationSorting takes care of finding a new unused int - curation_sorting.merge(merge_group, new_unit_id=None) - new_unit_id = curation_sorting.max_used_id # merged unit id - labels_dict[str(new_unit_id)] = labels_to_inherit - - # STEP 2: gather and apply sortingview curation labels - # In sortingview, a unit is not required to have all labels. - # For example, the first 3 units could be labeled as "accept". - # In this case, the first 3 values of the property "accept" will be True, the rest False - - # Initialize the properties dictionary - properties = { - label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) - for labels in labels_dict.values() - for label in labels - } - - # Populate the properties dictionary - for unit_index, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): - unit_id_str = str(unit_id) - if unit_id_str in labels_dict: - for label in labels_dict[unit_id_str]: - properties[label][unit_index] = True - - for prop_name, prop_values in properties.items(): - curation_sorting.current_sorting.set_property(prop_name, prop_values) - - if include_labels is not None or exclude_labels is not None: - units_to_remove = [] - unit_ids = curation_sorting.current_sorting.unit_ids - assert include_labels or exclude_labels, "Use either `include_labels` or `exclude_labels` to filter units." 
- if include_labels: - for include_label in include_labels: - units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(include_label) == False]) - if exclude_labels: - for exclude_label in exclude_labels: - units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(exclude_label) == True]) - units_to_remove = np.unique(units_to_remove) - curation_sorting.remove_units(units_to_remove) - return curation_sorting.current_sorting diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index 0d9562f404..c3ed4a115f 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -1,8 +1,5 @@ import pytest -from pydantic import BaseModel, ValidationError, field_validator - - from pathlib import Path import json import numpy as np @@ -11,15 +8,34 @@ from spikeinterface.curation.curation_format import ( validate_curation_dict, - convert_from_sortingview_curation_format_v0, curation_label_to_vectors, curation_label_to_dataframe, apply_curation, ) +""" +v1 = { + 'format_version': '1', + 'unit_ids': List[int | str], + 'label_definitions': { + 'category_key1': + { + 'label_options': List[str], + 'exclusive': bool} + }, + 'manual_labels': [ + { + 'unit_id': str or int, + 'category_key1': List[str], + } + ], + 'removed_units': List[int | str] # Can not be in the merged_units + 'merge_unit_groups': List[List[int | str]], # one cell goes into at most one list +} -"""example = { - 'unit_ids': List[str, int], +v2 = { + 'format_version': '2', + 'unit_ids': List[int | int], 'label_definitions': { 'category_key1': { @@ -27,18 +43,38 @@ 'exclusive': bool} }, 'manual_labels': [ - {'unit_id': str or int, - category_key1': List[str], + { + 'unit_id': str | int, + 'category_key1': List[str], } ], - 'merge_unit_groups': List[List[unit_ids]], # one cell goes into at most one list - 'removed_units': List[unit_ids] # Can not be in the merged_units -} -""" + 'removed': List[unit_ids], # Can not be in the merged_units + 'merges': [ + { + 'merge_unit_group': List[unit_ids], + 'merge_new_unit_id': int | str (optional) + } + ], + 'splits': [ + { + 'unit_id': int | str + 'split_mode': 'indices' or 'labels', + 'split_indices': List[List[int]], + 'split_labels': List[int]], + 'split_new_unit_ids': List[int | str] + } + ] + +sortingview_curation = { + 'mergeGroups': List[List[int | str]], + 'labelsByUnit': { + 'unit_id': List[str] + } +""" curation_ids_int = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], "label_definitions": { "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, @@ -51,19 +87,21 @@ {"unit_id": 1, "quality": ["good"]}, { "unit_id": 2, - "quality": [ - "noise", - ], + "quality": ["noise"], "putative_type": ["excitatory", "pyramidal"], }, {"unit_id": 3, "putative_type": ["inhibitory"]}, ], - "merge_unit_groups": [[3, 6], [10, 14, 20]], # one cell goes into at most one list - "removed_units": [31, 42], # Can not be in the merged_units + "merges": [{"merge_unit_group": [3, 6]}, {"merge_unit_group": [10, 14, 20]}], + "splits": [], + "removed": [31, 42], } +# Test dictionary format for merges +curation_ids_int_dict = {**curation_ids_int, "merges": {50: [3, 6], 51: [10, 14, 20]}} + curation_ids_str = { - "format_version": "1", + "format_version": "2", "unit_ids": ["u1", "u2", "u3", "u6", "u10", "u14", "u20", "u31", "u42"], "label_definitions": { 
"quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, @@ -76,40 +114,65 @@ {"unit_id": "u1", "quality": ["good"]}, { "unit_id": "u2", - "quality": [ - "noise", - ], + "quality": ["noise"], "putative_type": ["excitatory", "pyramidal"], }, {"unit_id": "u3", "putative_type": ["inhibitory"]}, ], - "merge_unit_groups": [["u3", "u6"], ["u10", "u14", "u20"]], # one cell goes into at most one list - "removed_units": ["u31", "u42"], # Can not be in the merged_units + "merges": [{"merge_unit_group": ["u3", "u6"]}, {"merge_unit_group": ["u10", "u14", "u20"]}], + "splits": [], + "removed": ["u31", "u42"], } -# This is a failure example with duplicated merge -duplicate_merge = curation_ids_int.copy() -duplicate_merge["merge_unit_groups"] = [[3, 6, 10], [10, 14, 20]] +# Test dictionary format for merges with string IDs +curation_ids_str_dict = {**curation_ids_str, "merges": {"u50": ["u3", "u6"], "u51": ["u10", "u14", "u20"]}} +# Test with splits +curation_with_splits = { + **curation_ids_int, + "splits": [ + {"unit_id": 2, "split_mode": "indices", "split_indices": [[0, 1, 2], [3, 4, 5]], "split_new_unit_ids": [50, 51]} + ], +} + +# Test dictionary format for splits +curation_with_splits_dict = {**curation_ids_int, "splits": {2: [[0, 1, 2], [3, 4, 5]]}} + +# This is a failure example with duplicated merge +duplicate_merge = {**curation_ids_int, "merges": [{"merge_unit_group": [3, 6, 10]}, {"merge_unit_group": [10, 14, 20]}]} # This is a failure example with unit 3 both in removed and merged -merged_and_removed = curation_ids_int.copy() -merged_and_removed["merge_unit_groups"] = [[3, 6], [10, 14, 20]] -merged_and_removed["removed_units"] = [3, 31, 42] +merged_and_removed = { + **curation_ids_int, + "merges": [{"merge_unit_group": [3, 6]}, {"merge_unit_group": [10, 14, 20]}], + "removed": [3, 31, 42], +} -# this is a failure because unit 99 is not in the initial list -unknown_merged_unit = curation_ids_int.copy() -unknown_merged_unit["merge_unit_groups"] = [[3, 6, 99], [10, 14, 20]] +# This is a failure because unit 99 is not in the initial list +unknown_merged_unit = { + **curation_ids_int, + "merges": [{"merge_unit_group": [3, 6, 99]}, {"merge_unit_group": [10, 14, 20]}], +} -# this is a failure because unit 99 is not in the initial list -unknown_removed_unit = curation_ids_int.copy() -unknown_removed_unit["removed_units"] = [31, 42, 99] +# This is a failure because unit 99 is not in the initial list +unknown_removed_unit = {**curation_ids_int, "removed": [31, 42, 99]} def test_curation_format_validation(): + # Test basic formats + print(curation_ids_int) validate_curation_dict(curation_ids_int) + print(curation_ids_int) validate_curation_dict(curation_ids_str) + # Test dictionary formats + validate_curation_dict(curation_ids_int_dict) + validate_curation_dict(curation_ids_str_dict) + + # Test splits + validate_curation_dict(curation_with_splits) + validate_curation_dict(curation_with_splits_dict) + with pytest.raises(ValueError): # Raised because duplicated merged units validate_curation_dict(duplicate_merge) @@ -125,13 +188,13 @@ def test_curation_format_validation(): def test_to_from_json(): - json.loads(json.dumps(curation_ids_int, indent=4)) json.loads(json.dumps(curation_ids_str, indent=4)) + json.loads(json.dumps(curation_ids_int_dict, indent=4)) + json.loads(json.dumps(curation_with_splits, indent=4)) def test_convert_from_sortingview_curation_format_v0(): - parent_folder = Path(__file__).parent for filename in ( "sv-sorting-curation.json", @@ -139,18 +202,13 @@ 
def test_convert_from_sortingview_curation_format_v0(): "sv-sorting-curation-str.json", "sv-sorting-curation-false-positive.json", ): - json_file = parent_folder / filename with open(json_file, "r") as f: curation_v0 = json.load(f) - # print(curation_v0) - curation_v1 = convert_from_sortingview_curation_format_v0(curation_v0) - # print(curation_v1) - validate_curation_dict(curation_v1) + validate_curation_dict(curation_v0) def test_curation_label_to_vectors(): - labels = curation_label_to_vectors(curation_ids_int) assert "quality" in labels assert "excitatory" in labels @@ -161,36 +219,45 @@ def test_curation_label_to_vectors(): def test_curation_label_to_dataframe(): - df = curation_label_to_dataframe(curation_ids_int) assert "quality" in df.columns assert "excitatory" in df.columns print(df) df = curation_label_to_dataframe(curation_ids_str) - # print(df) + print(df) def test_apply_curation(): recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205) - sorting._main_ids = np.array([1, 2, 3, 6, 10, 14, 20, 31, 42]) + sorting = sorting.rename_units([1, 2, 3, 6, 10, 14, 20, 31, 42]) analyzer = create_sorting_analyzer(sorting, recording, sparse=False) + # Test with list format sorting_curated = apply_curation(sorting, curation_ids_int) assert sorting_curated.get_property("quality", ids=[1])[0] == "good" assert sorting_curated.get_property("quality", ids=[2])[0] == "noise" assert sorting_curated.get_property("excitatory", ids=[2])[0] + # Test with dictionary format + sorting_curated = apply_curation(sorting, curation_ids_int_dict) + assert sorting_curated.get_property("quality", ids=[1])[0] == "good" + assert sorting_curated.get_property("quality", ids=[2])[0] == "noise" + assert sorting_curated.get_property("excitatory", ids=[2])[0] + + # Test with splits + sorting_curated = apply_curation(sorting, curation_with_splits) + assert sorting_curated.get_property("quality", ids=[1])[0] == "good" + + # Test analyzer analyzer_curated = apply_curation(analyzer, curation_ids_int) assert "quality" in analyzer_curated.sorting.get_property_keys() if __name__ == "__main__": - test_curation_format_validation() test_curation_format_validation() test_to_from_json() test_convert_from_sortingview_curation_format_v0() test_curation_label_to_vectors() test_curation_label_to_dataframe() - test_apply_curation() diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py index 1ebb36296d..3db904a480 100644 --- a/src/spikeinterface/curation/tests/test_curation_model.py +++ b/src/spikeinterface/curation/tests/test_curation_model.py @@ -13,15 +13,15 @@ def test_format_version(): # Invalid format version with pytest.raises(ValidationError): - CurationModel(format_version="2", unit_ids=[1, 2, 3]) + CurationModel(format_version="3", unit_ids=[1, 2, 3]) with pytest.raises(ValidationError): - CurationModel(format_version="0", unit_ids=[1, 2, 3]) + CurationModel(format_version="0.1", unit_ids=[1, 2, 3]) # Test data for label definitions def test_label_definitions(): valid_label_def = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "label_definitions": { "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), @@ -42,7 +42,7 @@ def test_label_definitions(): # Test manual labels def test_manual_labels(): valid_labels = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "label_definitions": { "quality": LabelDefinition(name="quality", 
label_options=["good", "noise"], exclusive=True), @@ -59,7 +59,7 @@ def test_manual_labels(): # Test invalid unit ID invalid_unit = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "label_definitions": { "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True) @@ -71,7 +71,7 @@ def test_manual_labels(): # Test violation of exclusive label invalid_exclusive = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "label_definitions": { "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True) @@ -88,7 +88,7 @@ def test_manual_labels(): def test_merge_units(): # Test list format valid_merge = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": [ {"merge_unit_group": [1, 2], "merge_new_unit_ids": 5}, @@ -102,7 +102,7 @@ def test_merge_units(): assert model.merges[1].merge_new_unit_ids == 6 # Test dictionary format - valid_merge_dict = {"format_version": "1", "unit_ids": [1, 2, 3, 4], "merges": {5: [1, 2], 6: [3, 4]}} + valid_merge_dict = {"format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": {5: [1, 2], 6: [3, 4]}} model = CurationModel(**valid_merge_dict) assert len(model.merges) == 2 @@ -111,7 +111,7 @@ def test_merge_units(): # Test list format valid_merge_list = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": [[1, 2], [3, 4]], # Merge each pair into a new unit } @@ -120,7 +120,7 @@ def test_merge_units(): # Test invalid merge group (single unit) invalid_merge_group = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "merges": [{"merge_unit_group": [1], "merge_new_unit_ids": 4}], } @@ -129,7 +129,7 @@ def test_merge_units(): # Test overlapping merge groups invalid_overlap = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "merges": [ {"merge_unit_group": [1, 2], "merge_new_unit_ids": 4}, @@ -144,7 +144,7 @@ def test_merge_units(): def test_split_units(): # Test indices mode with list format valid_split_indices = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "splits": [ { @@ -163,7 +163,7 @@ def test_split_units(): # Test labels mode with list format valid_split_labels = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "splits": [ {"unit_id": 1, "split_mode": "labels", "split_labels": [0, 0, 1, 1, 0, 2], "split_new_unit_ids": [4, 5, 6]} @@ -177,7 +177,7 @@ def test_split_units(): # Test dictionary format with indices valid_split_dict = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "splits": { 1: [[0, 1, 2], [3, 4, 5]], # Split unit 1 into two parts @@ -191,7 +191,7 @@ def test_split_units(): # Test invalid unit ID invalid_unit_id = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "splits": [{"unit_id": 4, "split_mode": "indices", "split_indices": [[0, 1], [2, 3]]}], # Non-existent unit } @@ -200,7 +200,7 @@ def test_split_units(): # Test invalid new unit IDs count for indices mode invalid_new_ids = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "splits": [ { @@ -217,19 +217,19 @@ def test_split_units(): # Test removed units def test_removed_units(): - valid_remove = {"format_version": "1", "unit_ids": [1, 2, 3], "removed": [2]} + valid_remove = {"format_version": "2", "unit_ids": [1, 2, 3], "removed": [2]} model = CurationModel(**valid_remove) assert len(model.removed) == 1 # Test removing non-existent unit - 
invalid_remove = {"format_version": "1", "unit_ids": [1, 2, 3], "removed": [4]} # Non-existent unit + invalid_remove = {"format_version": "2", "unit_ids": [1, 2, 3], "removed": [4]} # Non-existent unit with pytest.raises(ValidationError): CurationModel(**invalid_remove) # Test conflict between merge and remove invalid_merge_remove = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "merges": [{"merge_unit_group": [1, 2], "merge_new_unit_ids": 4}], "removed": [1], # Unit is both merged and removed @@ -241,7 +241,7 @@ def test_removed_units(): # Test complete model with multiple operations def test_complete_model(): complete_model = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3, 4, 5], "label_definitions": { "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), @@ -256,7 +256,7 @@ def test_complete_model(): } model = CurationModel(**complete_model) - assert model.format_version == "1" + assert model.format_version == "2" assert len(model.unit_ids) == 5 assert len(model.label_definitions) == 2 assert len(model.manual_labels) == 1 @@ -266,7 +266,7 @@ def test_complete_model(): # Test dictionary format for complete model complete_model_dict = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3, 4, 5], "label_definitions": { "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), @@ -279,7 +279,7 @@ def test_complete_model(): } model = CurationModel(**complete_model_dict) - assert model.format_version == "1" + assert model.format_version == "2" assert len(model.unit_ids) == 5 assert len(model.label_definitions) == 2 assert len(model.manual_labels) == 1 From f122db720b8ec39ce1231747290d9b3fe05b723f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Mar 2025 15:49:39 -0400 Subject: [PATCH 17/22] Fix sortingview tests --- .../curation/curation_format.py | 3 ++- .../curation/sortingview_curation.py | 27 +++++++++++-------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index d634735344..ea1c19e719 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -182,11 +182,12 @@ def apply_curation( censor_ms : float | None, default: None When applying the merges, any consecutive spikes within the `censor_ms` are removed. This can be thought of as the desired refractory period. If `censor_ms=None`, no spikes are discarded. - new_id_strategy : "append" | "take_first", default: "append" + new_id_strategy : "append" | "take_first" | "join", default: "append" The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. * "append" : new_units_ids will be added at the end of max(sorting.unit_ids) * "take_first" : new_unit_ids will be the first unit_id of every list of merges + * "join" : new_unit_ids will be the concatenation of all unit_ids of every list of merges merging_mode : "soft" | "hard", default: "soft" How merges are performed for SortingAnalyzer. If the `merge_mode` is "soft" , merges will be approximated, with no reloading of the waveforms. This will lead to approximations. 
If `merge_mode` is "hard", recomputations are accurately diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index 8970463831..fe21b72263 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -11,7 +11,7 @@ apply_curation, curation_label_to_vectors, ) -from .curation_model import CurationModel +from .curation_model import CurationModel, Merge def get_kachery(): @@ -32,7 +32,6 @@ def get_kachery(): ) -# TODO: fix sortingview curation with new format def apply_sortingview_curation( sorting_or_analyzer, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=None ): @@ -85,14 +84,8 @@ def apply_sortingview_curation( curation_dict["unit_ids"] = unit_ids curation_model = CurationModel(**curation_dict) - if not skip_merge: - sorting_curated = apply_curation(sorting_or_analyzer, curation_model) - else: - sorting_curated = sorting_or_analyzer - - # now remove units based on labels - curation_model.merges = [] - curation_model.unit_ids = sorting_curated.unit_ids + if skip_merge: + curation_model.merges = [] # this is a hack because it was not in the old format if exclude_labels is not None: @@ -114,7 +107,19 @@ def apply_sortingview_curation( removed_units = np.unique(removed_units) curation_model.removed = removed_units + # make merges and removed units if len(curation_model.removed) > 0: - sorting_curated = apply_curation(sorting_curated, curation_model) + clean_merges = [] + for merge in curation_model.merges: + clean_merge = [] + for unit_id in merge.merge_unit_group: + if unit_id not in curation_model.removed: + clean_merge.append(unit_id) + if len(clean_merge) > 1: + clean_merges.append(Merge(merge_unit_group=clean_merge)) + curation_model.merges = clean_merges + + # apply curation + sorting_curated = apply_curation(sorting_or_analyzer, curation_model, new_id_strategy="join") return sorting_curated From 4f14e90295dc2dab228bf26355fb98449b46f983 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Mar 2025 15:57:42 -0400 Subject: [PATCH 18/22] Fix sortingview conversion --- src/spikeinterface/curation/curation_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index 787af88111..94a277b3b7 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -273,10 +273,13 @@ def convert_old_format(cls, values): labels_def = { "all_labels": {"name": "all_labels", "label_options": list(set(all_labels)), "exclusive": False} } + for merge_group in merge_groups: + all_units.extend(merge_group) + all_units = list(set(all_units)) values = { "format_version": "2", - "unit_ids": values["unit_ids"], + "unit_ids": values.get("unit_ids", all_units), "label_definitions": labels_def, "manual_labels": list(manual_labels), "merges": [{"merge_unit_group": list(group)} for group in merge_groups], From d2f220a0abfefa9712f9fa0f9810f8b8ed0a6025 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 28 Mar 2025 16:41:48 -0400 Subject: [PATCH 19/22] Fix test-multi-extensions --- .../tests/test_multi_extensions.py | 92 +++++++++++++++---- 1 file changed, 76 insertions(+), 16 deletions(-) diff --git a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py index 0c8c2649af..8c512c2109 100644 --- 
a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py
+++ b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py
@@ -48,9 +48,22 @@
 # due to incremental PCA, hard computation could result in different results for PCA
 # the model is differents always
 random_computation = ["principal_components"]

+# For some extensions (templates, amplitude_scalings), the templates change slightly after merges/splits,
+# so we allow a relative tolerance
+# (amplitude_scalings are the most sensitive!)
+extensions_with_rel_tolerance_merge = {
+    "amplitude_scalings": 1e-1,
+    "templates": 1e-3,
+    "template_similarity": 1e-3,
+    "unit_locations": 1e-3,
+    "template_metrics": 1e-3,
+    "quality_metrics": 1e-3,
+}
+extensions_with_rel_tolerance_splits = {"amplitude_scalings": 1e-1}

-def get_dataset_with_splits():
+def get_dataset_to_merge():
+    # generate a dataset with some split units to minimize merge errors
     recording, sorting = generate_ground_truth_recording(
         durations=[30.0],
         sampling_frequency=16000.0,
@@ -58,6 +71,7 @@
         num_units=10,
         generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
         noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
+        generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=2.0, maximum_z=15.0, minimum_distance=20),
         seed=2205,
     )

@@ -80,16 +94,49 @@
     return recording, sorting_with_splits, split_unit_ids


+def get_dataset_to_split():
+    # generate a dataset and return the largest units to split, to minimize split errors
+    recording, sorting = generate_ground_truth_recording(
+        durations=[30.0],
+        sampling_frequency=16000.0,
+        num_channels=10,
+        num_units=10,
+        generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
+        noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
+        seed=2205,
+    )
+
+    channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
+    unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
+    recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
+    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)
+
+    # since templates are going to be averaged and this might be a problem for amplitude scaling
+    # we select the units with the largest templates to split
+    analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False)
+    analyzer_raw.compute(["random_spikes", "templates"])
+    # select the 2 largest templates to split
+    sort_by_amp = np.argsort(list(get_template_extremum_amplitude(analyzer_raw).values()))[::-1]
+    large_units = sorting.unit_ids[sort_by_amp][:2]
+
+    return recording, sorting, large_units
+
+
+@pytest.fixture(scope="module")
+def dataset_to_merge():
+    return get_dataset_to_merge()
+
+
 @pytest.fixture(scope="module")
-def dataset():
-    return get_dataset_with_splits()
+def dataset_to_split():
+    return get_dataset_to_split()


 @pytest.mark.parametrize("sparse", [False, True])
-def test_SortingAnalyzer_merge_all_extensions(dataset, sparse):
+def test_SortingAnalyzer_merge_all_extensions(dataset_to_merge, sparse):
     set_global_job_kwargs(n_jobs=1)

-    recording, sorting, other_ids = dataset
+    recording, sorting, other_ids = dataset_to_merge
     sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=sparse)

     extension_dict_merge = extension_dict.copy()
@@ -155,21 +202,29 @@
         )

         if ext not in random_computation:
+            if ext in 
extensions_with_rel_tolerance_merge: + rtol = extensions_with_rel_tolerance_merge[ext] + else: + rtol = 0 if extension_data_type[ext] == "pandas": data_hard_merged = data_hard_merged.dropna().to_numpy().astype("float") data_soft_merged = data_soft_merged.dropna().to_numpy().astype("float") if data_hard_merged.dtype.fields is None: - assert np.allclose(data_hard_merged, data_soft_merged, rtol=0.1) + if not np.allclose(data_hard_merged, data_soft_merged, rtol=rtol): + max_error = np.max(np.abs(data_hard_merged - data_soft_merged)) + raise Exception(f"Failed for {ext} - max error {max_error}") else: for f in data_hard_merged.dtype.fields: - assert np.allclose(data_hard_merged[f], data_soft_merged[f], rtol=0.1) + if not np.allclose(data_hard_merged[f], data_soft_merged[f], rtol=rtol): + max_error = np.max(np.abs(data_hard_merged[f] - data_soft_merged[f])) + raise Exception(f"Failed for {ext} - field {f} - max error {max_error}") @pytest.mark.parametrize("sparse", [False, True]) -def test_SortingAnalyzer_split_all_extensions(dataset, sparse): +def test_SortingAnalyzer_split_all_extensions(dataset_to_split, sparse): set_global_job_kwargs(n_jobs=1) - recording, sorting, _ = dataset + recording, sorting, units_to_split = dataset_to_split sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=sparse) extension_dict_split = extension_dict.copy() @@ -178,7 +233,6 @@ def test_SortingAnalyzer_split_all_extensions(dataset, sparse): # we randomly apply splits (at half of spiketrain) num_spikes = sorting.count_num_spikes_per_unit() - units_to_split = [sorting_analyzer.unit_ids[1], sorting_analyzer.unit_ids[5]] unsplit_unit_ids = sorting_analyzer.unit_ids[~np.isin(sorting_analyzer.unit_ids, units_to_split)] splits = {} for unit in units_to_split: @@ -220,17 +274,23 @@ def test_SortingAnalyzer_split_all_extensions(dataset, sparse): data_split_hard = get_extension_data_for_units( analyzer_hard, data_recompute, split_unit_ids, extension_data_type[ext] ) - # TODO: fix amplitude scalings - failing_extensions = [] - if ext not in random_computation + failing_extensions: + if ext not in random_computation: + if ext in extensions_with_rel_tolerance_splits: + rtol = extensions_with_rel_tolerance_splits[ext] + else: + rtol = 0 if extension_data_type[ext] == "pandas": data_split_soft = data_split_soft.dropna().to_numpy().astype("float") data_split_hard = data_split_hard.dropna().to_numpy().astype("float") if data_split_hard.dtype.fields is None: - assert np.allclose(data_split_hard, data_split_soft, rtol=0.1) + if not np.allclose(data_split_hard, data_split_soft, rtol=rtol): + max_error = np.max(np.abs(data_split_hard - data_split_soft)) + raise Exception(f"Failed for {ext} - max error {max_error}") else: for f in data_split_hard.dtype.fields: - assert np.allclose(data_split_hard[f], data_split_soft[f], rtol=0.1) + if not np.allclose(data_split_hard[f], data_split_soft[f], rtol=rtol): + max_error = np.max(np.abs(data_split_hard[f] - data_split_soft[f])) + raise Exception(f"Failed for {ext} - field {f} - max error {max_error}") def get_extension_data_for_units(sorting_analyzer, data, unit_ids, ext_data_type): @@ -259,5 +319,5 @@ def get_extension_data_for_units(sorting_analyzer, data, unit_ids, ext_data_type if __name__ == "__main__": - dataset = get_dataset_with_splits() + dataset = get_dataset_to_merge() test_SortingAnalyzer_merge_all_extensions(dataset, False) From 317f87c291a9b712664e5f951be8a4b2b711724a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 28 Mar 2025 16:45:34 
-0400 Subject: [PATCH 20/22] merge_new_unit_ids -> merge_new_unit_id --- src/spikeinterface/curation/curation_model.py | 20 +++++++++---------- .../curation/tests/test_curation_model.py | 20 +++++++++---------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index 94a277b3b7..afcbef17d5 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -17,7 +17,7 @@ class ManualLabel(BaseModel): class Merge(BaseModel): merge_unit_group: List[Union[int, str]] = Field(..., description="List of groups of units to be merged") - merge_new_unit_ids: Optional[Union[int, str]] = Field(default=None, description="New unit IDs for the merge group") + merge_new_unit_id: Optional[Union[int, str]] = Field(default=None, description="New unit IDs for the merge group") class Split(BaseModel): @@ -67,7 +67,7 @@ def add_label_definition_name(cls, label_definitions): @classmethod def check_manual_labels(cls, values): - values = dict(values) + unit_ids = list(values["unit_ids"]) manual_labels = values.get("manual_labels") if manual_labels is None: @@ -99,7 +99,7 @@ def check_manual_labels(cls, values): @classmethod def check_merges(cls, values): - values = dict(values) + unit_ids = list(values["unit_ids"]) merges = values.get("merges") if merges is None: @@ -110,7 +110,7 @@ def check_merges(cls, values): # Convert dict format to list of Merge objects merge_list = [] for merge_new_id, merge_group in merges.items(): - merge_list.append({"merge_unit_group": list(merge_group), "merge_new_unit_ids": merge_new_id}) + merge_list.append({"merge_unit_group": list(merge_group), "merge_new_unit_id": merge_new_id}) merges = merge_list # Make a copy of the list @@ -138,16 +138,16 @@ def check_merges(cls, values): raise ValueError("Merge unit groups must have at least 2 elements") # Check new unit id not already used - if merge.merge_new_unit_ids is not None: - if merge.merge_new_unit_ids in unit_ids: - raise ValueError(f"New unit ID {merge.merge_new_unit_ids} is already in the unit list") + if merge.merge_new_unit_id is not None: + if merge.merge_new_unit_id in unit_ids: + raise ValueError(f"New unit ID {merge.merge_new_unit_id} is already in the unit list") values["merges"] = merges return values @classmethod def check_splits(cls, values): - values = dict(values) + unit_ids = list(values["unit_ids"]) splits = values.get("splits") if splits is None: @@ -230,7 +230,6 @@ def check_splits(cls, values): @classmethod def check_removed(cls, values): - values = dict(values) unit_ids = list(values["unit_ids"]) removed = values.get("removed") if removed is None: @@ -246,8 +245,6 @@ def check_removed(cls, values): @classmethod def convert_old_format(cls, values): format_version = values.get("format_version", "0") - if format_version != "2": - values = dict(values) if format_version == "0": print("Conversion from format version v0 (sortingview) to v2") if "mergeGroups" not in values.keys(): @@ -298,6 +295,7 @@ def convert_old_format(cls, values): @model_validator(mode="before") def validate_fields(cls, values): values = dict(values) + values["label_definitions"] = values.get("label_definitions", {}) values = cls.convert_old_format(values) values = cls.check_manual_labels(values) values = cls.check_merges(values) diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py index 3db904a480..7354ac1892 100644 --- 
a/src/spikeinterface/curation/tests/test_curation_model.py +++ b/src/spikeinterface/curation/tests/test_curation_model.py @@ -91,22 +91,22 @@ def test_merge_units(): "format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": [ - {"merge_unit_group": [1, 2], "merge_new_unit_ids": 5}, - {"merge_unit_group": [3, 4], "merge_new_unit_ids": 6}, + {"merge_unit_group": [1, 2], "merge_new_unit_id": 5}, + {"merge_unit_group": [3, 4], "merge_new_unit_id": 6}, ], } model = CurationModel(**valid_merge) assert len(model.merges) == 2 - assert model.merges[0].merge_new_unit_ids == 5 - assert model.merges[1].merge_new_unit_ids == 6 + assert model.merges[0].merge_new_unit_id == 5 + assert model.merges[1].merge_new_unit_id == 6 # Test dictionary format valid_merge_dict = {"format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": {5: [1, 2], 6: [3, 4]}} model = CurationModel(**valid_merge_dict) assert len(model.merges) == 2 - merge_new_ids = {merge.merge_new_unit_ids for merge in model.merges} + merge_new_ids = {merge.merge_new_unit_id for merge in model.merges} assert merge_new_ids == {5, 6} # Test list format @@ -122,7 +122,7 @@ def test_merge_units(): invalid_merge_group = { "format_version": "2", "unit_ids": [1, 2, 3], - "merges": [{"merge_unit_group": [1], "merge_new_unit_ids": 4}], + "merges": [{"merge_unit_group": [1], "merge_new_unit_id": 4}], } with pytest.raises(ValidationError): CurationModel(**invalid_merge_group) @@ -132,8 +132,8 @@ def test_merge_units(): "format_version": "2", "unit_ids": [1, 2, 3], "merges": [ - {"merge_unit_group": [1, 2], "merge_new_unit_ids": 4}, - {"merge_unit_group": [2, 3], "merge_new_unit_ids": 5}, + {"merge_unit_group": [1, 2], "merge_new_unit_id": 4}, + {"merge_unit_group": [2, 3], "merge_new_unit_id": 5}, ], } with pytest.raises(ValidationError): @@ -231,7 +231,7 @@ def test_removed_units(): invalid_merge_remove = { "format_version": "2", "unit_ids": [1, 2, 3], - "merges": [{"merge_unit_group": [1, 2], "merge_new_unit_ids": 4}], + "merges": [{"merge_unit_group": [1, 2], "merge_new_unit_id": 4}], "removed": [1], # Unit is both merged and removed } with pytest.raises(ValidationError): @@ -248,7 +248,7 @@ def test_complete_model(): "tags": LabelDefinition(name="tags", label_options=["burst", "slow"], exclusive=False), }, "manual_labels": [{"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst"]}}], - "merges": [{"merge_unit_group": [2, 3], "merge_new_unit_ids": 6}], + "merges": [{"merge_unit_group": [2, 3], "merge_new_unit_id": 6}], "splits": [ {"unit_id": 4, "split_mode": "indices", "split_indices": [[0, 1], [2, 3]], "split_new_unit_ids": [7, 8]} ], From d4fa8bf0da57d9449c8ca91de8a64d03075696a4 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 28 Mar 2025 16:52:25 -0400 Subject: [PATCH 21/22] Deal with multi-segment --- src/spikeinterface/core/sorting_tools.py | 24 ++++++++++++------------ 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index d6ba6a7e6b..f9d05d29ca 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -452,17 +452,16 @@ def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_ def apply_splits_to_sorting(sorting, unit_splits, new_unit_ids=None, return_extra=False, new_id_strategy="append"): spikes = sorting.to_spike_vector().copy() - # take care of single-list splits - full_unit_splits = _get_full_unit_splits(unit_splits, sorting) + # here we assume that unit_splits 
split_indices are already full. + # this is true when running via apply_curation new_unit_ids = generate_unit_ids_for_split( - sorting.unit_ids, full_unit_splits, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy + sorting.unit_ids, unit_splits, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy ) - all_unit_ids = _get_ids_after_splitting(sorting.unit_ids, full_unit_splits, new_unit_ids) + all_unit_ids = _get_ids_after_splitting(sorting.unit_ids, unit_splits, new_unit_ids) all_unit_ids = list(all_unit_ids) num_seg = sorting.get_num_segments() - assert num_seg == 1 seg_lims = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2)) segment_slices = [(seg_lims[i], seg_lims[i + 1]) for i in range(num_seg)] @@ -470,17 +469,18 @@ def apply_splits_to_sorting(sorting, unit_splits, new_unit_ids=None, return_extr spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices] spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True) - # TODO deal with segments in splits + # split_indices are a concatenation across segments for unit_id in sorting.unit_ids: - if unit_id in full_unit_splits: - split_indices = full_unit_splits[unit_id] - new_split_ids = new_unit_ids[list(full_unit_splits.keys()).index(unit_id)] + if unit_id in unit_splits: + split_indices = unit_splits[unit_id] + new_split_ids = new_unit_ids[list(unit_splits.keys()).index(unit_id)] for split, new_unit_id in zip(split_indices, new_split_ids): new_unit_index = all_unit_ids.index(new_unit_id) - for segment_index in range(num_seg): - spike_inds = spike_indices[segment_index][unit_id] - spikes["unit_index"][spike_inds[split]] = new_unit_index + spike_indices_unit = np.concatenate( + spike_indices[segment_index][unit_id] for segment_index in range(num_seg) + ) + spikes["unit_index"][spike_indices_unit[split]] = new_unit_index else: new_unit_index = all_unit_ids.index(unit_id) for segment_index in range(num_seg): From 659ecff46fee32ce59671c56af0dd57245da2d0c Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Sat, 29 Mar 2025 11:52:05 -0400 Subject: [PATCH 22/22] Extend splitting-tests to multi-segment and mask labels --- src/spikeinterface/core/sorting_tools.py | 5 +- .../curation/tests/test_curation_format.py | 121 ++++++++++++++++-- 2 files changed, 116 insertions(+), 10 deletions(-) diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index f9d05d29ca..24718d4e5d 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -469,7 +469,6 @@ def apply_splits_to_sorting(sorting, unit_splits, new_unit_ids=None, return_extr spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices] spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True) - # split_indices are a concatenation across segments for unit_id in sorting.unit_ids: if unit_id in unit_splits: split_indices = unit_splits[unit_id] @@ -477,8 +476,10 @@ def apply_splits_to_sorting(sorting, unit_splits, new_unit_ids=None, return_extr for split, new_unit_id in zip(split_indices, new_split_ids): new_unit_index = all_unit_ids.index(new_unit_id) + # split_indices are a concatenation across segments with absolute indices + # so we need to concatenate the spike indices across segments spike_indices_unit = np.concatenate( - spike_indices[segment_index][unit_id] for segment_index in range(num_seg) + [spike_indices[segment_index][unit_id] for segment_index in range(num_seg)] ) 
spikes["unit_index"][spike_indices_unit[split]] = new_unit_index else: diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index fef1f9d6f7..d0126d5460 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -127,7 +127,6 @@ # Test dictionary format for merges with string IDs curation_ids_str_dict = {**curation_ids_str, "merges": {"u50": ["u3", "u6"], "u51": ["u10", "u14", "u20"]}} - # This is a failure example with duplicated merge duplicate_merge = curation_ids_int.copy() duplicate_merge["merge_unit_groups"] = [[3, 6, 10], [10, 14, 20]] @@ -292,11 +291,117 @@ def test_apply_curation_with_split(): assert analyzer_curated.sorting.get_property("pyramidal", ids=[unit_id])[0] +def test_apply_curation_with_split_multi_segment(): + recording, sorting = generate_ground_truth_recording(durations=[10.0, 10.0], num_units=9, seed=2205) + sorting = sorting.rename_units(np.array([1, 2, 3, 6, 10, 14, 20, 31, 42])) + analyzer = create_sorting_analyzer(sorting, recording, sparse=False) + num_segments = sorting.get_num_segments() + + curation_with_splits_multi_segment = curation_with_splits.copy() + + # we make a split so that each subsplit will have all spikes from different segments + split_unit_id = curation_with_splits_multi_segment["splits"][0]["unit_id"] + sv = sorting.to_spike_vector() + unit_index = sorting.id_to_index(split_unit_id) + spikes_from_split_unit = sv[sv["unit_index"] == unit_index] + + split_indices = [] + cum_spikes = 0 + for segment_index in range(num_segments): + spikes_in_segment = spikes_from_split_unit[spikes_from_split_unit["segment_index"] == segment_index] + split_indices.append(np.arange(0, len(spikes_in_segment)) + cum_spikes) + cum_spikes += len(spikes_in_segment) + + curation_with_splits_multi_segment["splits"][0]["split_indices"] = split_indices + + sorting_curated = apply_curation(sorting, curation_with_splits_multi_segment) + + assert len(sorting_curated.unit_ids) == len(sorting.unit_ids) + 1 + assert 2 not in sorting_curated.unit_ids + assert 43 in sorting_curated.unit_ids + assert 44 in sorting_curated.unit_ids + + # check that spike trains are correctly split across segments + for seg_index in range(num_segments): + st_43 = sorting_curated.get_unit_spike_train(43, segment_index=seg_index) + st_44 = sorting_curated.get_unit_spike_train(44, segment_index=seg_index) + if seg_index == 0: + assert len(st_43) > 0 + assert len(st_44) == 0 + else: + assert len(st_43) == 0 + assert len(st_44) > 0 + + +def test_apply_curation_splits_with_mask(): + recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205) + sorting = sorting.rename_units(np.array([1, 2, 3, 6, 10, 14, 20, 31, 42])) + analyzer = create_sorting_analyzer(sorting, recording, sparse=False) + + # Get number of spikes for unit 2 + num_spikes = sorting.count_num_spikes_per_unit()[2] + + # Create split labels that assign spikes to 3 different clusters + split_labels = np.zeros(num_spikes, dtype=int) + split_labels[: num_spikes // 3] = 0 # First third to cluster 0 + split_labels[num_spikes // 3 : 2 * num_spikes // 3] = 1 # Second third to cluster 1 + split_labels[2 * num_spikes // 3 :] = 2 # Last third to cluster 2 + + curation_with_mask_split = { + "format_version": "2", + "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], + "label_definitions": { + "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": 
True},
+            "putative_type": {
+                "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral"],
+                "exclusive": False,
+            },
+        },
+        "manual_labels": [
+            {"unit_id": 2, "quality": ["good"], "putative_type": ["excitatory", "pyramidal"]},
+        ],
+        "splits": [
+            {
+                "unit_id": 2,
+                "split_mode": "labels",
+                "split_labels": split_labels.tolist(),
+                "split_new_unit_ids": [43, 44, 45],
+            }
+        ],
+    }
+
+    sorting_curated = apply_curation(sorting, curation_with_mask_split)
+
+    # Check results
+    assert len(sorting_curated.unit_ids) == len(sorting.unit_ids) + 2  # Original units - 1 (split) + 3 (new)
+    assert 2 not in sorting_curated.unit_ids  # Original unit should be removed
+
+    # Check new split units
+    split_unit_ids = [43, 44, 45]
+    for unit_id in split_unit_ids:
+        assert unit_id in sorting_curated.unit_ids
+        # Check properties are propagated
+        assert sorting_curated.get_property("quality", ids=[unit_id])[0] == "good"
+        assert sorting_curated.get_property("excitatory", ids=[unit_id])[0]
+        assert sorting_curated.get_property("pyramidal", ids=[unit_id])[0]
+
+    # Check analyzer
+    analyzer_curated = apply_curation(analyzer, curation_with_mask_split)
+    assert len(analyzer_curated.sorting.unit_ids) == len(analyzer.sorting.unit_ids) + 2
+
+    # Verify split sizes
+    spike_counts = analyzer_curated.sorting.count_num_spikes_per_unit()
+    assert spike_counts[43] == num_spikes // 3  # First third
+    assert spike_counts[44] == num_spikes // 3  # Second third
+    assert spike_counts[45] == num_spikes - 2 * (num_spikes // 3)  # Remainder
+
+
 if __name__ == "__main__":
-    # test_curation_format_validation()
-    # test_to_from_json()
-    # test_convert_from_sortingview_curation_format_v0()
-    # test_curation_label_to_vectors()
-    # test_curation_label_to_dataframe()
-    # test_apply_curation()
-    test_apply_curation_with_split()
+    test_curation_format_validation()
+    test_to_from_json()
+    test_convert_from_sortingview_curation_format_v0()
+    test_curation_label_to_vectors()
+    test_curation_label_to_dataframe()
+    test_apply_curation()
+    test_apply_curation_with_split()
+    test_apply_curation_with_split_multi_segment()
+    test_apply_curation_splits_with_mask()
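
For quick reference, here is a minimal end-to-end sketch of the v2 curation format introduced by this series. It is illustrative only: the unit ids and labels are made up, and the imports simply mirror the ones used in the tests above.

    from spikeinterface.core import generate_ground_truth_recording
    from spikeinterface.curation.curation_format import validate_curation_dict, apply_curation

    recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=0)
    sorting = sorting.rename_units([1, 2, 3, 4, 5])

    curation = {
        "format_version": "2",
        "unit_ids": [1, 2, 3, 4, 5],
        "label_definitions": {
            "quality": {"label_options": ["good", "noise"], "exclusive": True},
        },
        "manual_labels": [{"unit_id": 1, "quality": ["good"]}],
        # "merges" accepts either a list of {"merge_unit_group": ...} entries
        # or a {new_unit_id: merge_unit_group} dict, as exercised in the tests
        "merges": {10: [2, 3]},
        "splits": [],
        "removed": [5],
    }

    validate_curation_dict(curation)  # raises a ValueError if the dict is invalid
    sorting_curated = apply_curation(sorting, curation)
    # unit 5 is removed, units 2 and 3 are merged into the new unit 10
    print(sorting_curated.unit_ids)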
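Because "split_indices" index into the concatenation of a unit's spikes across all segments (the convention the last two patches establish), building a segment-wise split by hand takes a little bookkeeping. The helper below is a hypothetical sketch, not an API added by the patches; it relies only on to_spike_vector, id_to_index, and get_num_segments, the same calls used in test_apply_curation_with_split_multi_segment.

    import numpy as np

    def split_indices_by_segment(sorting, unit_id):
        # Return one group of spike indices per segment for `unit_id`.
        # Indices are offsets into the concatenation of the unit's spikes
        # over all segments (segment 0 first, then segment 1, and so on).
        spike_vector = sorting.to_spike_vector()
        unit_spikes = spike_vector[spike_vector["unit_index"] == sorting.id_to_index(unit_id)]
        groups = []
        offset = 0
        for segment_index in range(sorting.get_num_segments()):
            num_in_segment = int(np.sum(unit_spikes["segment_index"] == segment_index))
            groups.append(np.arange(num_in_segment) + offset)
            offset += num_in_segment
        return groups

Passing these groups as "split_indices" (with one new unit id per group) reproduces the segment-wise split asserted in the multi-segment test.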