From 96863de4a8adba8ae46e97ac166cb4f9dc6cbfad Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 11 Mar 2025 13:22:26 +0100 Subject: [PATCH 01/11] Add pydantic curation model and improve curation format and merging rules --- pyproject.toml | 1 + src/spikeinterface/core/sorting_tools.py | 13 +- src/spikeinterface/core/sortinganalyzer.py | 22 +- .../curation/curation_format.py | 191 ++++++++---------- src/spikeinterface/curation/curation_model.py | 144 +++++++++++++ .../curation/tests/test_curation_format.py | 14 +- 6 files changed, 256 insertions(+), 129 deletions(-) create mode 100644 src/spikeinterface/curation/curation_model.py diff --git a/pyproject.toml b/pyproject.toml index 97ba77299e..10a21f3595 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -97,6 +97,7 @@ full = [ "pandas", "scipy", "scikit-learn", + "pydantic", "networkx", "distinctipy", "matplotlib>=3.6", # matplotlib.colormaps diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py index 47ce8cf848..ca8c731040 100644 --- a/src/spikeinterface/core/sorting_tools.py +++ b/src/spikeinterface/core/sorting_tools.py @@ -232,8 +232,13 @@ def random_spikes_selection( def apply_merges_to_sorting( - sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_extra=False, new_id_strategy="append" -): + sorting: BaseSorting, + merge_unit_groups: list[list[int | str]] | list[tuple[int | str]], + new_unit_ids: list[int | str] | None = None, + censor_ms: float | None = None, + return_extra: bool = False, + new_id_strategy: str = "append", +) -> NumpySorting | tuple[NumpySorting, np.ndarray, list[int | str]]: """ Apply a resolved representation of the merges to a sorting object. @@ -245,9 +250,9 @@ def apply_merges_to_sorting( Parameters ---------- - sorting : Sorting + sorting : BaseSorting The Sorting object to apply merges. - merge_unit_groups : list/tuple of lists/tuples + merge_unit_groups : list of lists/tuples A list of lists for every merge group. 
Each element needs to have at least two elements (two units to merge), but it can also have more (merge multiple units at once). new_unit_ids : list | None, default: None diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index 85d405c443..15478b8001 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -1138,18 +1138,18 @@ def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "Sortin def merge_units( self, - merge_unit_groups, - new_unit_ids=None, - censor_ms=None, - merging_mode="soft", - sparsity_overlap=0.75, - new_id_strategy="append", - return_new_unit_ids=False, - format="memory", - folder=None, - verbose=False, + merge_unit_groups: list[list[str | int]] | list[tuple[str | int]], + new_unit_ids: list[int | str] | None = None, + censor_ms: float | None = None, + merging_mode: str = "soft", + sparsity_overlap: float = 0.75, + new_id_strategy: str = "append", + return_new_unit_ids: bool = False, + format: str = "memory", + folder: Path | str | None = None, + verbose: bool = False, **job_kwargs, - ) -> "SortingAnalyzer": + ) -> "SortingAnalyzer | tuple[SortingAnalyzer, list[int | str]]": """ This method is equivalent to `save_as()`but with a list of merges that have to be achieved. 
Merges units by creating a new SortingAnalyzer object with the appropriate merges diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 80f251ca43..186bb34568 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -1,14 +1,14 @@ -from itertools import combinations +from __future__ import annotations +import copy import numpy as np +from spikeinterface import curation from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting -import copy - -supported_curation_format_versions = {"1"} +from spikeinterface.curation.curation_model import CurationModel -def validate_curation_dict(curation_dict): +def validate_curation_dict(curation_dict: dict): """ Validate that the curation dictionary given as parameter complies with the format @@ -19,61 +19,11 @@ def validate_curation_dict(curation_dict): curation_dict : dict """ + # this will validate the format of the curation_dict + CurationModel(**curation_dict) - # format - if "format_version" not in curation_dict: - raise ValueError("No version_format") - - if curation_dict["format_version"] not in supported_curation_format_versions: - raise ValueError( - f"Format version ({curation_dict['format_version']}) not supported. 
" - f"Only {supported_curation_format_versions} are valid" - ) - - # unit_ids - labeled_unit_set = set([lbl["unit_id"] for lbl in curation_dict["manual_labels"]]) - merged_units_set = set(sum(curation_dict["merge_unit_groups"], [])) - removed_units_set = set(curation_dict["removed_units"]) - - if curation_dict["unit_ids"] is not None: - # old format v0 did not contain unit_ids so this can contains None - unit_set = set(curation_dict["unit_ids"]) - if not labeled_unit_set.issubset(unit_set): - raise ValueError("Curation format: some labeled units are not in the unit list") - if not merged_units_set.issubset(unit_set): - raise ValueError("Curation format: some merged units are not in the unit list") - if not removed_units_set.issubset(unit_set): - raise ValueError("Curation format: some removed units are not in the unit list") - for group in curation_dict["merge_unit_groups"]: - if len(group) < 2: - raise ValueError("Curation format: 'merge_unit_groups' must be list of list with at least 2 elements") - - all_merging_groups = [set(group) for group in curation_dict["merge_unit_groups"]] - for gp_1, gp_2 in combinations(all_merging_groups, 2): - if len(gp_1.intersection(gp_2)) != 0: - raise ValueError("Curation format: some units belong to multiple merge groups") - if len(removed_units_set.intersection(merged_units_set)) != 0: - raise ValueError("Curation format: some units were merged and deleted") - - # Check the labels exclusivity - for lbl in curation_dict["manual_labels"]: - for label_key in curation_dict["label_definitions"].keys(): - if label_key in lbl: - unit_id = lbl["unit_id"] - label_value = lbl[label_key] - if not isinstance(label_value, list): - raise ValueError(f"Curation format: manual_labels {unit_id} is invalid shoudl be a list") - - is_exclusive = curation_dict["label_definitions"][label_key]["exclusive"] - - if is_exclusive and not len(label_value) <= 1: - raise ValueError( - f"Curation format: manual_labels {unit_id} {label_key} are exclusive 
labels. {label_value} is invalid" - ) - - -def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_format="1"): +def convert_from_sortingview_curation_format_v0(sortingview_dict: dict, destination_format: str = "1"): """ Converts the old sortingview curation format (v0) into a curation dictionary new format (v1) Couple of caveats: @@ -99,7 +49,6 @@ def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_fo if "mergeGroups" not in sortingview_dict.keys(): sortingview_dict["mergeGroups"] = [] merge_groups = sortingview_dict["mergeGroups"] - merged_units = sum(merge_groups, []) first_unit_id = next(iter(sortingview_dict["labelsByUnit"].keys())) if str.isdigit(first_unit_id): @@ -115,13 +64,19 @@ def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_fo all_labels.extend(l_labels) # recorver the correct type for unit_id unit_id = unit_id_type(unit_id_) - all_units.append(unit_id) + if unit_id not in all_units: + all_units.append(unit_id) manual_labels.append({"unit_id": unit_id, general_cat: l_labels}) labels_def = {"all_labels": {"name": "all_labels", "label_options": list(set(all_labels)), "exclusive": False}} + for merge_group in merge_groups: + for unit_id in merge_group: + if unit_id not in all_units: + all_units.append(unit_id) + curation_dict = { "format_version": destination_format, - "unit_ids": None, + "unit_ids": all_units, "label_definitions": labels_def, "manual_labels": manual_labels, "merge_unit_groups": merge_groups, @@ -131,7 +86,7 @@ def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_fo return curation_dict -def curation_label_to_vectors(curation_dict): +def curation_label_to_vectors(curation_dict_or_model: dict | CurationModel): """ Transform the curation dict into dict of vectors. For label category with exclusive=True : a column is created and values are the unique label. 
@@ -141,43 +96,46 @@ def curation_label_to_vectors(curation_dict): Parameters ---------- - curation_dict : dict - A curation dictionary + curation_dict : dict or CurationModel + A curation dictionary or model Returns ------- labels : dict of numpy vector """ - unit_ids = list(curation_dict["unit_ids"]) + if isinstance(curation_dict_or_model, dict): + curation_model = CurationModel(**curation_dict_or_model) + else: + curation_model = curation_dict_or_model + unit_ids = list(curation_model.unit_ids) n = len(unit_ids) labels = {} - for label_key, label_def in curation_dict["label_definitions"].items(): - if label_def["exclusive"]: + for label_key, label_def in curation_model.label_definitions.items(): + if label_def.exclusive: assert label_key not in labels, f"{label_key} is already a key" labels[label_key] = [""] * n - for lbl in curation_dict["manual_labels"]: - value = lbl.get(label_key, []) - if len(value) == 1: - unit_index = unit_ids.index(lbl["unit_id"]) - labels[label_key][unit_index] = value[0] + for manual_label in curation_model.manual_labels: + values = manual_label.labels.get(label_key, []) + if len(values) == 1: + unit_index = unit_ids.index(manual_label.unit_id) + labels[label_key][unit_index] = values[0] labels[label_key] = np.array(labels[label_key]) else: - for label_opt in label_def["label_options"]: + for label_opt in label_def.label_options: assert label_opt not in labels, f"{label_opt} is already a key" labels[label_opt] = np.zeros(n, dtype=bool) - for lbl in curation_dict["manual_labels"]: - values = lbl.get(label_key, []) + for manual_label in curation_model.manual_labels: + values = manual_label.labels.get(label_key, []) for value in values: - unit_index = unit_ids.index(lbl["unit_id"]) + unit_index = unit_ids.index(manual_label.unit_id) labels[value][unit_index] = True - return labels -def clean_curation_dict(curation_dict): +def clean_curation_dict(curation_dict: dict): """ In some cases the curation_dict can have inconsistencies (like in 
the sorting view format). For instance, some unit_ids are both in 'merge_unit_groups' and 'removed_units'. @@ -200,7 +158,7 @@ def clean_curation_dict(curation_dict): return curation_dict -def curation_label_to_dataframe(curation_dict): +def curation_label_to_dataframe(curation_dict_or_model: dict | CurationModel): """ Transform the curation dict into a pandas dataframe. For label category with exclusive=True : a column is created and values are the unique label. @@ -220,11 +178,18 @@ def curation_label_to_dataframe(curation_dict): """ import pandas as pd - labels = pd.DataFrame(curation_label_to_vectors(curation_dict), index=curation_dict["unit_ids"]) + if isinstance(curation_dict_or_model, dict): + curation_model = CurationModel(**curation_dict_or_model) + else: + curation_model = curation_dict_or_model + + labels = pd.DataFrame(curation_label_to_vectors(curation_model), index=curation_model.unit_ids) return labels -def apply_curation_labels(sorting, new_unit_ids, curation_dict): +def apply_curation_labels( + sorting: BaseSorting, new_unit_ids: list[int, str], curation_dict_or_model: dict | CurationModel +): """ Apply manual labels after merges. @@ -233,25 +198,29 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict): * for merged group, when exclusive=True, if all have the same label then this label is applied * for merged group, when exclusive=False, if one unit has the label then the new one have also it """ + if isinstance(curation_dict_or_model, dict): + curation_model = CurationModel(**curation_dict_or_model) + else: + curation_model = curation_dict_or_model # Please note that manual_labels is done on the unit_ids before the merge!!! 
- manual_labels = curation_label_to_vectors(curation_dict) + manual_labels = curation_label_to_vectors(curation_model) # apply on non merged for key, values in manual_labels.items(): all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype) for unit_ind, unit_id in enumerate(sorting.unit_ids): if unit_id not in new_unit_ids: - ind = list(curation_dict["unit_ids"]).index(unit_id) + ind = list(curation_model.unit_ids).index(unit_id) all_values[unit_ind] = values[ind] sorting.set_property(key, all_values) - for new_unit_id, old_group_ids in zip(new_unit_ids, curation_dict["merge_unit_groups"]): - for label_key, label_def in curation_dict["label_definitions"].items(): - if label_def["exclusive"]: + for new_unit_id, old_group_ids in zip(new_unit_ids, curation_model.merge_unit_groups): + for label_key, label_def in curation_model.label_definitions.items(): + if label_def.exclusive: group_values = [] for unit_id in old_group_ids: - ind = curation_dict["unit_ids"].index(unit_id) + ind = list(curation_model.unit_ids).index(unit_id) value = manual_labels[label_key][ind] if value != "": group_values.append(value) @@ -260,10 +229,10 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict): sorting.set_property(key, values=group_values[:1], ids=[new_unit_id]) else: - for key in label_def["label_options"]: + for key in label_def.label_options: group_values = [] for unit_id in old_group_ids: - ind = curation_dict["unit_ids"].index(unit_id) + ind = list(curation_model.unit_ids).index(unit_id) value = manual_labels[key][ind] group_values.append(value) new_value = np.any(group_values) @@ -271,13 +240,13 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict): def apply_curation( - sorting_or_analyzer, - curation_dict, - censor_ms=None, - new_id_strategy="append", - merging_mode="soft", - sparsity_overlap=0.75, - verbose=False, + sorting_or_analyzer: BaseSorting | SortingAnalyzer, + curation_dict_or_model: dict | CurationModel, + censor_ms: float | None = 
None, + new_id_strategy: str = "append", + merging_mode: str = "soft", + sparsity_overlap: float = 0.75, + verbose: bool = False, **job_kwargs, ): """ @@ -294,9 +263,9 @@ def apply_curation( Parameters ---------- sorting_or_analyzer : Sorting | SortingAnalyzer - The Sorting object to apply merges. - curation_dict : dict - The curation dict. + The Sorting or SortingAnalyzer object to apply merges. + curation_dict : dict or CurationModel + The curation dict or model. censor_ms : float | None, default: None When applying the merges, any consecutive spikes within the `censor_ms` are removed. This can be thought of as the desired refractory period. If `censor_ms=None`, no spikes are discarded. @@ -324,30 +293,34 @@ def apply_curation( """ - validate_curation_dict(curation_dict) - if not np.array_equal(np.asarray(curation_dict["unit_ids"]), sorting_or_analyzer.unit_ids): + if isinstance(curation_dict_or_model, dict): + curation_model = CurationModel(**curation_dict_or_model) + else: + curation_model = curation_dict_or_model + + if not np.array_equal(np.asarray(curation_model.unit_ids), sorting_or_analyzer.unit_ids): raise ValueError("unit_ids from the curation_dict do not match the one from Sorting or SortingAnalyzer") if isinstance(sorting_or_analyzer, BaseSorting): sorting = sorting_or_analyzer - sorting = sorting.remove_units(curation_dict["removed_units"]) + sorting = sorting.remove_units(curation_model.removed_units) sorting, _, new_unit_ids = apply_merges_to_sorting( sorting, - curation_dict["merge_unit_groups"], + curation_model.merge_unit_groups, censor_ms=censor_ms, return_extra=True, new_id_strategy=new_id_strategy, ) - apply_curation_labels(sorting, new_unit_ids, curation_dict) + apply_curation_labels(sorting, new_unit_ids, curation_model) return sorting elif isinstance(sorting_or_analyzer, SortingAnalyzer): analyzer = sorting_or_analyzer - if len(curation_dict["removed_units"]) > 0: - analyzer = analyzer.remove_units(curation_dict["removed_units"]) - if 
len(curation_dict["merge_unit_groups"]) > 0: + if len(curation_model.removed_units) > 0: + analyzer = analyzer.remove_units(curation_model.removed_units) + if len(curation_model.merge_unit_groups) > 0: analyzer, new_unit_ids = analyzer.merge_units( - curation_dict["merge_unit_groups"], + curation_model.merge_unit_groups, censor_ms=censor_ms, merging_mode=merging_mode, sparsity_overlap=sparsity_overlap, @@ -359,7 +332,7 @@ def apply_curation( ) else: new_unit_ids = [] - apply_curation_labels(analyzer.sorting, new_unit_ids, curation_dict) + apply_curation_labels(analyzer.sorting, new_unit_ids, curation_model) return analyzer else: raise TypeError( diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py new file mode 100644 index 0000000000..5e7459df7c --- /dev/null +++ b/src/spikeinterface/curation/curation_model.py @@ -0,0 +1,144 @@ +from pydantic import BaseModel, Field, field_validator, model_validator +from typing import List, Dict, Union, Optional +from itertools import combinations + +supported_curation_format_versions = {"1"} + + +class LabelDefinition(BaseModel): + name: str = Field(..., description="Name of the label") + label_options: List[str] = Field(..., description="List of possible label options") + exclusive: bool = Field(..., description="Whether the label is exclusive") + + +class ManualLabel(BaseModel): + unit_id: Union[int, str] = Field(..., description="ID of the unit") + labels: Dict[str, List[str]] = Field(..., description="Dictionary of labels for the unit") + + +class CurationModel(BaseModel): + format_version: str = Field(..., description="Version of the curation format") + unit_ids: List[Union[int, str]] = Field(..., description="List of unit IDs") + label_definitions: Dict[str, LabelDefinition] = Field(..., description="Dictionary of label definitions") + manual_labels: List[ManualLabel] = Field(..., description="List of manual labels") + merge_unit_groups: List[List[Union[int, str]]] = 
Field(..., description="List of groups of units to be merged") + removed_units: List[Union[int, str]] = Field(..., description="List of removed unit IDs") + merge_new_unit_ids: Optional[List[Union[int, str]]] = Field( + default=None, description="List of new unit IDs after merging" + ) + + @field_validator("format_version") + def check_format_version(cls, v): + if v not in supported_curation_format_versions: + raise ValueError(f"Format version ({v}) not supported. Only {supported_curation_format_versions} are valid") + return v + + @field_validator("label_definitions", mode="before") + def add_label_definition_name(cls, v): + if v is None: + v = {} + else: + v_copy = v.copy() + for key in v_copy.keys(): + v[key]["name"] = key + return v + + @model_validator(mode="before") + def check_manual_labels(cls, values): + unit_ids = values["unit_ids"] + manual_labels = values["manual_labels"] + if manual_labels is None: + values["manual_labels"] = [] + else: + for manual_label in manual_labels: + unit_id = manual_label["unit_id"] + labels = manual_label.get("labels") + if labels is None: + labels = set(manual_label.keys()) - {"unit_id"} + manual_label["labels"] = {} + for label in labels: + if label not in values["label_definitions"]: + raise ValueError(f"Manual label {unit_id} has an unknown label {label}") + manual_label["labels"][label] = manual_label[label] + if unit_id not in unit_ids: + raise ValueError(f"Manual label unit_id {unit_id} is not in the unit list") + return values + + @model_validator(mode="before") + def check_merge_unit_groups(cls, values): + unit_ids = values["unit_ids"] + merge_unit_groups = values.get("merge_unit_groups", []) + for merge_group in merge_unit_groups: + for unit_id in merge_group: + if unit_id not in unit_ids: + raise ValueError(f"Merge unit group unit_id {unit_id} is not in the unit list") + if len(merge_group) < 2: + raise ValueError("Merge unit groups must have at least 2 elements") + return values + + @model_validator(mode="before") 
+ def check_merge_new_unit_ids(cls, values): + unit_ids = values["unit_ids"] + merge_new_unit_ids = values.get("merge_new_unit_ids") + if merge_new_unit_ids is not None: + merge_unit_groups = values.get("merge_unit_groups") + assert merge_unit_groups is not None, "Merge unit groups must be defined if merge new unit ids are defined" + if len(merge_unit_groups) != len(merge_new_unit_ids): + raise ValueError("Merge unit groups and new unit ids must have the same length") + if len(merge_new_unit_ids) > 0: + for new_unit_id in merge_new_unit_ids: + if new_unit_id in unit_ids: + raise ValueError(f"New unit ID {new_unit_id} is already in the unit list") + return values + + @model_validator(mode="before") + def check_removed_units(cls, values): + unit_ids = values["unit_ids"] + removed_units = values.get("removed_units", []) + for unit_id in removed_units: + if unit_id not in unit_ids: + raise ValueError(f"Removed unit_id {unit_id} is not in the unit list") + return values + + @model_validator(mode="after") + def validate_curation_dict(cls, values): + labeled_unit_set = set([lbl.unit_id for lbl in values.manual_labels]) + merged_units_set = set(sum(values.merge_unit_groups, [])) + removed_units_set = set(values.removed_units) + unit_ids = values.unit_ids + + unit_set = set(unit_ids) + if not labeled_unit_set.issubset(unit_set): + raise ValueError("Curation format: some labeled units are not in the unit list") + if not merged_units_set.issubset(unit_set): + raise ValueError("Curation format: some merged units are not in the unit list") + if not removed_units_set.issubset(unit_set): + raise ValueError("Curation format: some removed units are not in the unit list") + + for group in values.merge_unit_groups: + if len(group) < 2: + raise ValueError("Curation format: 'merge_unit_groups' must be list of list with at least 2 elements") + + all_merging_groups = [set(group) for group in values.merge_unit_groups] + for gp_1, gp_2 in combinations(all_merging_groups, 2): + if 
len(gp_1.intersection(gp_2)) != 0: + raise ValueError("Curation format: some units belong to multiple merge groups") + if len(removed_units_set.intersection(merged_units_set)) != 0: + raise ValueError("Curation format: some units were merged and deleted") + + for manual_label in values.manual_labels: + for label_key in values.label_definitions.keys(): + if label_key in manual_label.labels: + unit_id = manual_label.unit_id + label_value = manual_label.labels[label_key] + if not isinstance(label_value, list): + raise ValueError(f"Curation format: manual_labels {unit_id} is invalid should be a list") + + is_exclusive = values.label_definitions[label_key].exclusive + + if is_exclusive and not len(label_value) <= 1: + raise ValueError( + f"Curation format: manual_labels {unit_id} {label_key} are exclusive labels. {label_value} is invalid" + ) + + return values diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index af9d8e1eac..0d9562f404 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -1,5 +1,8 @@ import pytest +from pydantic import BaseModel, ValidationError, field_validator + + from pathlib import Path import json import numpy as np @@ -183,10 +186,11 @@ def test_apply_curation(): if __name__ == "__main__": - # test_curation_format_validation() - # test_to_from_json() - # test_convert_from_sortingview_curation_format_v0() - # test_curation_label_to_vectors() - # test_curation_label_to_dataframe() + test_curation_format_validation() + test_curation_format_validation() + test_to_from_json() + test_convert_from_sortingview_curation_format_v0() + test_curation_label_to_vectors() + test_curation_label_to_dataframe() test_apply_curation() From 1ce611c0c33beeb97800d90f9ab70d7ec8516a01 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 11 Mar 2025 13:25:33 +0100 Subject: [PATCH 02/11] Update 
src/spikeinterface/curation/curation_model.py --- src/spikeinterface/curation/curation_model.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index 5e7459df7c..0e2f1da870 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -38,8 +38,7 @@ def add_label_definition_name(cls, v): if v is None: v = {} else: - v_copy = v.copy() - for key in v_copy.keys(): + for key in list(v.keys()): v[key]["name"] = key return v From 3464987d1d987db126377a00dbde8c48c66268c2 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Tue, 11 Mar 2025 14:45:05 +0100 Subject: [PATCH 03/11] Move pydantic to core --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 89e69d90f9..a142391310 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "neo>=0.14.0", "probeinterface>=0.2.23", "packaging", + "pydantic", ] [build-system] @@ -97,7 +98,6 @@ full = [ "pandas", "scipy", "scikit-learn", - "pydantic", "networkx", "distinctipy", "matplotlib>=3.6", # matplotlib.colormaps From dbfa315296376063995b4ebb5ebae9b6d24ed727 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Mar 2025 13:39:11 -0400 Subject: [PATCH 04/11] Refactor curation model to include merges and splits --- src/spikeinterface/curation/curation_model.py | 233 +++++++++++---- .../curation/tests/test_curation_model.py | 279 ++++++++++++++++++ 2 files changed, 462 insertions(+), 50 deletions(-) create mode 100644 src/spikeinterface/curation/tests/test_curation_model.py diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index 0e2f1da870..a595f32eac 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -1,13 +1,15 @@ from pydantic import BaseModel, 
Field, field_validator, model_validator -from typing import List, Dict, Union, Optional -from itertools import combinations +from typing import List, Dict, Union, Optional, Literal +from itertools import combinations, chain +import numpy as np + supported_curation_format_versions = {"1"} class LabelDefinition(BaseModel): name: str = Field(..., description="Name of the label") - label_options: List[str] = Field(..., description="List of possible label options") + label_options: List[str] = Field(..., description="List of possible label options", min_length=2) exclusive: bool = Field(..., description="Whether the label is exclusive") @@ -16,16 +18,39 @@ class ManualLabel(BaseModel): labels: Dict[str, List[str]] = Field(..., description="Dictionary of labels for the unit") +class Merge(BaseModel): + merge_unit_group: List[Union[int, str]] = Field(..., description="List of groups of units to be merged") + merge_new_unit_ids: Optional[Union[int, str]] = Field(default=None, description="New unit IDs for the merge group") + + +class Split(BaseModel): + unit_id: Union[int, str] = Field(..., description="ID of the unit") + split_mode: Literal["indices", "labels"] = Field( + default="indices", + description=( + "Mode of the split. The split can be defined by indices or labels. " + "If indices, the split is defined by the a list of lists of indices of spikes within spikes " + "belonging to the unit (`split_indices`). " + "If labels, the split is defined by a list of labels for each spike (`split_labels`). 
" + ), + ) + split_indices: Optional[Union[List[List[int]]]] = Field(default=None, description="List of indices for the split") + split_labels: Optional[List[int]] = Field(default=None, description="List of labels for the split") + split_new_unit_ids: Optional[List[Union[int, str]]] = Field( + default=None, description="List of new unit IDs for each split" + ) + + class CurationModel(BaseModel): format_version: str = Field(..., description="Version of the curation format") unit_ids: List[Union[int, str]] = Field(..., description="List of unit IDs") - label_definitions: Dict[str, LabelDefinition] = Field(..., description="Dictionary of label definitions") - manual_labels: List[ManualLabel] = Field(..., description="List of manual labels") - merge_unit_groups: List[List[Union[int, str]]] = Field(..., description="List of groups of units to be merged") - removed_units: List[Union[int, str]] = Field(..., description="List of removed unit IDs") - merge_new_unit_ids: Optional[List[Union[int, str]]] = Field( - default=None, description="List of new unit IDs after merging" + label_definitions: Optional[Dict[str, LabelDefinition]] = Field( + default=None, description="Dictionary of label definitions" ) + manual_labels: Optional[List[ManualLabel]] = Field(default=None, description="List of manual labels") + removed: Optional[List[Union[int, str]]] = Field(default=None, description="List of removed unit IDs") + merges: Optional[List[Merge]] = Field(default=None, description="List of merges") + splits: Optional[List[Split]] = Field(default=None, description="List of splits") @field_validator("format_version") def check_format_version(cls, v): @@ -33,19 +58,26 @@ def check_format_version(cls, v): raise ValueError(f"Format version ({v}) not supported. 
Only {supported_curation_format_versions} are valid") return v - @field_validator("label_definitions", mode="before") - def add_label_definition_name(cls, v): - if v is None: - v = {} - else: - for key in list(v.keys()): - v[key]["name"] = key - return v + @model_validator(mode="before") + def add_label_definition_name(cls, values): + label_definitions = values.get("label_definitions") + if label_definitions is None: + values["label_definitions"] = {} + return values + if isinstance(values["label_definitions"], dict): + if label_definitions is None: + label_definitions = {} + else: + for key in list(label_definitions.keys()): + if isinstance(label_definitions[key], dict): + label_definitions[key]["name"] = key + values["label_definitions"] = label_definitions + return values @model_validator(mode="before") def check_manual_labels(cls, values): unit_ids = values["unit_ids"] - manual_labels = values["manual_labels"] + manual_labels = values.get("manual_labels") if manual_labels is None: values["manual_labels"] = [] else: @@ -58,52 +90,148 @@ def check_manual_labels(cls, values): for label in labels: if label not in values["label_definitions"]: raise ValueError(f"Manual label {unit_id} has an unknown label {label}") - manual_label["labels"][label] = manual_label[label] + if label not in manual_label["labels"]: + if label in manual_label: + manual_label["labels"][label] = manual_label[label] + else: + raise ValueError(f"Manual label {unit_id} has no value for label {label}") if unit_id not in unit_ids: raise ValueError(f"Manual label unit_id {unit_id} is not in the unit list") return values @model_validator(mode="before") - def check_merge_unit_groups(cls, values): + def check_merges(cls, values): unit_ids = values["unit_ids"] - merge_unit_groups = values.get("merge_unit_groups", []) - for merge_group in merge_unit_groups: - for unit_id in merge_group: + merges = values.get("merges") + if merges is None: + values["merges"] = [] + return values + elif isinstance(merges, 
list): + # Convert list of lists to Merge objects + for i, merge in enumerate(merges): + if isinstance(merge, list): + merges[i] = Merge(merge_unit_group=merge) + elif isinstance(merges, dict): + # Convert dict format to list of Merge objects if needed + merge_list = [] + for merge_new_id, merge_group in merges.items(): + merge_list.append({"merge_unit_group": merge_group, "merge_new_unit_ids": merge_new_id}) + merges = merge_list + values["merges"] = merges + + # Convert dict items to Merge objects if needed + for i, merge in enumerate(merges): + if isinstance(merge, dict): + merges[i] = Merge(**merge) + + for merge in merges: + # Check unit ids exist + for unit_id in merge.merge_unit_group: if unit_id not in unit_ids: raise ValueError(f"Merge unit group unit_id {unit_id} is not in the unit list") - if len(merge_group) < 2: + + # Check minimum group size + if len(merge.merge_unit_group) < 2: raise ValueError("Merge unit groups must have at least 2 elements") + + # Check new unit id not already used + if merge.merge_new_unit_ids is not None: + if merge.merge_new_unit_ids in unit_ids: + raise ValueError(f"New unit ID {merge.merge_new_unit_ids} is already in the unit list") + return values @model_validator(mode="before") - def check_merge_new_unit_ids(cls, values): + def check_splits(cls, values): unit_ids = values["unit_ids"] - merge_new_unit_ids = values.get("merge_new_unit_ids") - if merge_new_unit_ids is not None: - merge_unit_groups = values.get("merge_unit_groups") - assert merge_unit_groups is not None, "Merge unit groups must be defined if merge new unit ids are defined" - if len(merge_unit_groups) != len(merge_new_unit_ids): - raise ValueError("Merge unit groups and new unit ids must have the same length") - if len(merge_new_unit_ids) > 0: - for new_unit_id in merge_new_unit_ids: - if new_unit_id in unit_ids: - raise ValueError(f"New unit ID {new_unit_id} is already in the unit list") + splits = values.get("splits") + if splits is None: + values["splits"] = 
[] + return values + + # Convert dict format to list of Split objects if needed + if isinstance(splits, dict): + split_list = [] + for unit_id, split_data in splits.items(): + # If split_data is a list of indices, assume indices mode + if isinstance(split_data[0], (list, np.ndarray)) if split_data else False: + split_list.append({"unit_id": unit_id, "split_mode": "indices", "split_indices": split_data}) + # Otherwise assume it's a list of labels + else: + split_list.append({"unit_id": unit_id, "split_mode": "labels", "split_labels": split_data}) + splits = split_list + values["splits"] = splits + + # Convert dict items to Split objects if needed + for i, split in enumerate(splits): + if isinstance(split, dict): + splits[i] = Split(**split) + + for split in splits: + # Check unit exists + if split.unit_id not in unit_ids: + raise ValueError(f"Split unit_id {split.unit_id} is not in the unit list") + + # Check split definition based on mode + if split.split_mode == "indices": + if split.split_indices is None: + raise ValueError(f"Split unit {split.unit_id} has no split_indices defined") + if len(split.split_indices) < 1: + raise ValueError(f"Split unit {split.unit_id} has empty split_indices") + # Check no duplicate indices across splits + all_indices = list(chain.from_iterable(split.split_indices)) + if len(all_indices) != len(set(all_indices)): + raise ValueError(f"Split unit {split.unit_id} has duplicate indices") + + elif split.split_mode == "labels": + if split.split_labels is None: + raise ValueError(f"Split unit {split.unit_id} has no split_labels defined") + if len(split.split_labels) == 0: + raise ValueError(f"Split unit {split.unit_id} has empty split_labels") + + # Check new unit ids if provided + if split.split_new_unit_ids is not None: + if split.split_mode == "indices": + if len(split.split_new_unit_ids) != len(split.split_indices): + raise ValueError( + f"Number of new unit IDs does not match number of splits for unit {split.unit_id}" + ) + elif 
split.split_mode == "labels": + if len(split.split_new_unit_ids) != len(set(split.split_labels)): + raise ValueError( + f"Number of new unit IDs does not match number of unique labels for unit {split.unit_id}" + ) + + # Check new ids not already used + for new_id in split.split_new_unit_ids: + if new_id in unit_ids: + raise ValueError(f"New unit ID {new_id} is already in the unit list") + return values @model_validator(mode="before") - def check_removed_units(cls, values): + def check_removed(cls, values): unit_ids = values["unit_ids"] - removed_units = values.get("removed_units", []) - for unit_id in removed_units: - if unit_id not in unit_ids: - raise ValueError(f"Removed unit_id {unit_id} is not in the unit list") + removed = values.get("removed", []) + if removed is None: + + for unit_id in removed: + if unit_id not in unit_ids: + raise ValueError(f"Removed unit_id {unit_id} is not in the unit list") + + else: + values["removed"] = removed + return values @model_validator(mode="after") def validate_curation_dict(cls, values): - labeled_unit_set = set([lbl.unit_id for lbl in values.manual_labels]) - merged_units_set = set(sum(values.merge_unit_groups, [])) - removed_units_set = set(values.removed_units) + labeled_unit_set = set([lbl.unit_id for lbl in values.manual_labels]) if values.manual_labels else set() + merged_units_set = ( + set(chain.from_iterable(merge.merge_unit_group for merge in values.merges)) if values.merges else set() + ) + split_units_set = set(split.unit_id for split in values.splits) if values.splits else set() + removed_set = set(values.removed) if values.removed else set() unit_ids = values.unit_ids unit_set = set(unit_ids) @@ -111,19 +239,24 @@ def validate_curation_dict(cls, values): raise ValueError("Curation format: some labeled units are not in the unit list") if not merged_units_set.issubset(unit_set): raise ValueError("Curation format: some merged units are not in the unit list") - if not removed_units_set.issubset(unit_set): + if 
not split_units_set.issubset(unit_set): + raise ValueError("Curation format: some split units are not in the unit list") + if not removed_set.issubset(unit_set): raise ValueError("Curation format: some removed units are not in the unit list") - for group in values.merge_unit_groups: - if len(group) < 2: - raise ValueError("Curation format: 'merge_unit_groups' must be list of list with at least 2 elements") - - all_merging_groups = [set(group) for group in values.merge_unit_groups] + # Check for units being merged multiple times + all_merging_groups = [set(merge.merge_unit_group) for merge in values.merges] if values.merges else [] for gp_1, gp_2 in combinations(all_merging_groups, 2): if len(gp_1.intersection(gp_2)) != 0: raise ValueError("Curation format: some units belong to multiple merge groups") - if len(removed_units_set.intersection(merged_units_set)) != 0: + + # Check no overlaps between operations + if len(removed_set.intersection(merged_units_set)) != 0: raise ValueError("Curation format: some units were merged and deleted") + if len(removed_set.intersection(split_units_set)) != 0: + raise ValueError("Curation format: some units were split and deleted") + if len(merged_units_set.intersection(split_units_set)) != 0: + raise ValueError("Curation format: some units were both merged and split") for manual_label in values.manual_labels: for label_key in values.label_definitions.keys(): diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py new file mode 100644 index 0000000000..29292c495c --- /dev/null +++ b/src/spikeinterface/curation/tests/test_curation_model.py @@ -0,0 +1,279 @@ +import pytest + +from pydantic import ValidationError +import numpy as np + +from spikeinterface.curation.curation_model import CurationModel, LabelDefinition + + +# Test data for format version +def test_format_version(): + # Valid format version + CurationModel(format_version="1", unit_ids=[1, 2, 3]) + + # 
Invalid format version + with pytest.raises(ValidationError): + CurationModel(format_version="2", unit_ids=[1, 2, 3]) + with pytest.raises(ValidationError): + CurationModel(format_version="0", unit_ids=[1, 2, 3]) + + +# Test data for label definitions +def test_label_definitions(): + valid_label_def = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow", "fast"], exclusive=False), + }, + } + + model = CurationModel(**valid_label_def) + assert "quality" in model.label_definitions + assert model.label_definitions["quality"].name == "quality" + assert model.label_definitions["quality"].exclusive is True + + # Test invalid label definition + with pytest.raises(ValidationError): + LabelDefinition(name="quality", label_options=[], exclusive=True) # Empty options should be invalid + + +# Test manual labels +def test_manual_labels(): + valid_labels = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow", "fast"], exclusive=False), + }, + "manual_labels": [ + {"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst", "fast"]}}, + {"unit_id": 2, "labels": {"quality": ["noise"]}}, + ], + } + + model = CurationModel(**valid_labels) + assert len(model.manual_labels) == 2 + + # Test invalid unit ID + invalid_unit = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True) + }, + "manual_labels": [{"unit_id": 4, "labels": {"quality": ["good"]}}], # Non-existent unit + } + with pytest.raises(ValidationError): + CurationModel(**invalid_unit) + + # Test violation of exclusive label + 
invalid_exclusive = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True) + }, + "manual_labels": [ + {"unit_id": 1, "labels": {"quality": ["good", "noise"]}} # Multiple values for exclusive label + ], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_exclusive) + + +# Test merge functionality +def test_merge_units(): + # Test list format + valid_merge = { + "format_version": "1", + "unit_ids": [1, 2, 3, 4], + "merges": [ + {"merge_unit_group": [1, 2], "merge_new_unit_ids": 5}, + {"merge_unit_group": [3, 4], "merge_new_unit_ids": 6}, + ], + } + + model = CurationModel(**valid_merge) + assert len(model.merges) == 2 + assert model.merges[0].merge_new_unit_ids == 5 + assert model.merges[1].merge_new_unit_ids == 6 + + # Test dictionary format + valid_merge_dict = {"format_version": "1", "unit_ids": [1, 2, 3, 4], "merges": {5: [1, 2], 6: [3, 4]}} + + model = CurationModel(**valid_merge_dict) + assert len(model.merges) == 2 + merge_new_ids = {merge.merge_new_unit_ids for merge in model.merges} + assert merge_new_ids == {5, 6} + + # Test invalid merge group (single unit) + invalid_merge_group = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "merges": [{"merge_unit_group": [1], "merge_new_unit_ids": 4}], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_merge_group) + + # Test overlapping merge groups + invalid_overlap = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "merges": [ + {"merge_unit_group": [1, 2], "merge_new_unit_ids": 4}, + {"merge_unit_group": [2, 3], "merge_new_unit_ids": 5}, + ], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_overlap) + + +# Test split functionality +def test_split_units(): + # Test indices mode with list format + valid_split_indices = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "splits": [ + { + "unit_id": 1, + "split_mode": "indices", + 
"split_indices": [[0, 1, 2], [3, 4, 5]], + "split_new_unit_ids": [4, 5], + } + ], + } + + model = CurationModel(**valid_split_indices) + assert len(model.splits) == 1 + assert model.splits[0].split_mode == "indices" + assert len(model.splits[0].split_indices) == 2 + + # Test labels mode with list format + valid_split_labels = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "splits": [ + {"unit_id": 1, "split_mode": "labels", "split_labels": [0, 0, 1, 1, 0, 2], "split_new_unit_ids": [4, 5, 6]} + ], + } + + model = CurationModel(**valid_split_labels) + assert len(model.splits) == 1 + assert model.splits[0].split_mode == "labels" + assert len(set(model.splits[0].split_labels)) == 3 + + # Test dictionary format with indices + valid_split_dict = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "splits": { + 1: [[0, 1, 2], [3, 4, 5]], # Split unit 1 into two parts + 2: [[0, 1], [2, 3], [4, 5]], # Split unit 2 into three parts + }, + } + + model = CurationModel(**valid_split_dict) + assert len(model.splits) == 2 + assert all(split.split_mode == "indices" for split in model.splits) + + # Test invalid unit ID + invalid_unit_id = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "splits": [{"unit_id": 4, "split_mode": "indices", "split_indices": [[0, 1], [2, 3]]}], # Non-existent unit + } + with pytest.raises(ValidationError): + CurationModel(**invalid_unit_id) + + # Test invalid new unit IDs count for indices mode + invalid_new_ids = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "splits": [ + { + "unit_id": 1, + "split_mode": "indices", + "split_indices": [[0, 1], [2, 3]], + "split_new_unit_ids": [4], # Should have 2 new IDs for 2 splits + } + ], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_new_ids) + + +# Test removed units +def test_removed_units(): + valid_remove = {"format_version": "1", "unit_ids": [1, 2, 3], "removed": [2]} + + model = CurationModel(**valid_remove) + assert len(model.removed) == 1 + + # Test removing 
non-existent unit + invalid_remove = {"format_version": "1", "unit_ids": [1, 2, 3], "removed": [4]} # Non-existent unit + with pytest.raises(ValidationError): + CurationModel(**invalid_remove) + + # Test conflict between merge and remove + invalid_merge_remove = { + "format_version": "1", + "unit_ids": [1, 2, 3], + "merges": [{"merge_unit_group": [1, 2], "merge_new_unit_ids": 4}], + "removed": [1], # Unit is both merged and removed + } + with pytest.raises(ValidationError): + CurationModel(**invalid_merge_remove) + + +# Test complete model with multiple operations +def test_complete_model(): + complete_model = { + "format_version": "1", + "unit_ids": [1, 2, 3, 4, 5], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow"], exclusive=False), + }, + "manual_labels": [{"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst"]}}], + "merges": [{"merge_unit_group": [2, 3], "merge_new_unit_ids": 6}], + "splits": [ + {"unit_id": 4, "split_mode": "indices", "split_indices": [[0, 1], [2, 3]], "split_new_unit_ids": [7, 8]} + ], + "removed": [5], + } + + model = CurationModel(**complete_model) + assert model.format_version == "1" + assert len(model.unit_ids) == 5 + assert len(model.label_definitions) == 2 + assert len(model.manual_labels) == 1 + assert len(model.merges) == 1 + assert len(model.splits) == 1 + assert len(model.removed) == 1 + + # Test dictionary format for complete model + complete_model_dict = { + "format_version": "1", + "unit_ids": [1, 2, 3, 4, 5], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow"], exclusive=False), + }, + "manual_labels": [{"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst"]}}], + "merges": {6: [2, 3]}, + "splits": {4: [[0, 1], [2, 3]]}, + "removed": 
[5], + } + + model = CurationModel(**complete_model_dict) + assert model.format_version == "1" + assert len(model.unit_ids) == 5 + assert len(model.label_definitions) == 2 + assert len(model.manual_labels) == 1 + assert len(model.merges) == 1 + assert len(model.splits) == 1 + assert len(model.removed) == 1 From 82526b0dc40457c1e14154af701d5aeab26bffd3 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Mar 2025 13:46:35 -0400 Subject: [PATCH 05/11] Add merge list to tests --- src/spikeinterface/curation/tests/test_curation_model.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py index 29292c495c..1ebb36296d 100644 --- a/src/spikeinterface/curation/tests/test_curation_model.py +++ b/src/spikeinterface/curation/tests/test_curation_model.py @@ -109,6 +109,15 @@ def test_merge_units(): merge_new_ids = {merge.merge_new_unit_ids for merge in model.merges} assert merge_new_ids == {5, 6} + # Test list format + valid_merge_list = { + "format_version": "1", + "unit_ids": [1, 2, 3, 4], + "merges": [[1, 2], [3, 4]], # Merge each pair into a new unit + } + model = CurationModel(**valid_merge_list) + assert len(model.merges) == 2 + # Test invalid merge group (single unit) invalid_merge_group = { "format_version": "1", From 482f0be221177f31f341e4a70e7a1621f2850fdb Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Mar 2025 15:20:09 -0400 Subject: [PATCH 06/11] Simplify and centralize conversion and checks --- .../curation/curation_format.py | 122 ++--------- src/spikeinterface/curation/curation_model.py | 207 ++++++++++++------ .../curation/sortingview_curation.py | 143 ++---------- .../curation/tests/test_curation_format.py | 165 +++++++++----- .../curation/tests/test_curation_model.py | 46 ++-- 5 files changed, 320 insertions(+), 363 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py 
b/src/spikeinterface/curation/curation_format.py index 186bb34568..d634735344 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -1,9 +1,7 @@ from __future__ import annotations -import copy import numpy as np -from spikeinterface import curation from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting from spikeinterface.curation.curation_model import CurationModel @@ -23,69 +21,6 @@ def validate_curation_dict(curation_dict: dict): CurationModel(**curation_dict) -def convert_from_sortingview_curation_format_v0(sortingview_dict: dict, destination_format: str = "1"): - """ - Converts the old sortingview curation format (v0) into a curation dictionary new format (v1) - Couple of caveats: - * The list of units is not available in the original sortingview dictionary. We set it to None - * Labels can not be mutually exclusive. - * Labels have no category, so we regroup them under the "all_labels" category - - Parameters - ---------- - sortingview_dict : dict - Dictionary containing the curation information from sortingview - destination_format : str - Version of the format to use. 
- Default to "1" - - Returns - ------- - curation_dict : dict - A curation dictionary - """ - - assert destination_format == "1" - if "mergeGroups" not in sortingview_dict.keys(): - sortingview_dict["mergeGroups"] = [] - merge_groups = sortingview_dict["mergeGroups"] - - first_unit_id = next(iter(sortingview_dict["labelsByUnit"].keys())) - if str.isdigit(first_unit_id): - unit_id_type = int - else: - unit_id_type = str - - all_units = [] - all_labels = [] - manual_labels = [] - general_cat = "all_labels" - for unit_id_, l_labels in sortingview_dict["labelsByUnit"].items(): - all_labels.extend(l_labels) - # recorver the correct type for unit_id - unit_id = unit_id_type(unit_id_) - if unit_id not in all_units: - all_units.append(unit_id) - manual_labels.append({"unit_id": unit_id, general_cat: l_labels}) - labels_def = {"all_labels": {"name": "all_labels", "label_options": list(set(all_labels)), "exclusive": False}} - - for merge_group in merge_groups: - for unit_id in merge_group: - if unit_id not in all_units: - all_units.append(unit_id) - - curation_dict = { - "format_version": destination_format, - "unit_ids": all_units, - "label_definitions": labels_def, - "manual_labels": manual_labels, - "merge_unit_groups": merge_groups, - "removed_units": [], - } - - return curation_dict - - def curation_label_to_vectors(curation_dict_or_model: dict | CurationModel): """ Transform the curation dict into dict of vectors. @@ -135,29 +70,6 @@ def curation_label_to_vectors(curation_dict_or_model: dict | CurationModel): return labels -def clean_curation_dict(curation_dict: dict): - """ - In some cases the curation_dict can have inconsistencies (like in the sorting view format). - For instance, some unit_ids are both in 'merge_unit_groups' and 'removed_units'. - This is ambiguous! 
- - This cleaner helper function ensures units tagged as `removed_units` are removed from the `merge_unit_groups` - """ - curation_dict = copy.deepcopy(curation_dict) - - clean_merge_unit_groups = [] - for group in curation_dict["merge_unit_groups"]: - clean_group = [] - for unit_id in group: - if unit_id not in curation_dict["removed_units"]: - clean_group.append(unit_id) - if len(clean_group) > 1: - clean_merge_unit_groups.append(clean_group) - - curation_dict["merge_unit_groups"] = clean_merge_unit_groups - return curation_dict - - def curation_label_to_dataframe(curation_dict_or_model: dict | CurationModel): """ Transform the curation dict into a pandas dataframe. @@ -215,7 +127,8 @@ def apply_curation_labels( all_values[unit_ind] = values[ind] sorting.set_property(key, all_values) - for new_unit_id, old_group_ids in zip(new_unit_ids, curation_model.merge_unit_groups): + for new_unit_id, merge in zip(new_unit_ids, curation_model.merges): + old_group_ids = merge.merge_unit_group for label_key, label_def in curation_model.label_definitions.items(): if label_def.exclusive: group_values = [] @@ -253,8 +166,8 @@ def apply_curation( Apply curation dict to a Sorting or a SortingAnalyzer. Steps are done in this order: - 1. Apply removal using curation_dict["removed_units"] - 2. Apply merges using curation_dict["merge_unit_groups"] + 1. Apply removal using curation_dict["removed"] + 2. Apply merges using curation_dict["merges"] 3. Set labels using curation_dict["manual_labels"] A new Sorting or SortingAnalyzer (in memory) is returned. 
@@ -303,24 +216,27 @@ def apply_curation( if isinstance(sorting_or_analyzer, BaseSorting): sorting = sorting_or_analyzer - sorting = sorting.remove_units(curation_model.removed_units) - sorting, _, new_unit_ids = apply_merges_to_sorting( - sorting, - curation_model.merge_unit_groups, - censor_ms=censor_ms, - return_extra=True, - new_id_strategy=new_id_strategy, - ) + sorting = sorting.remove_units(curation_model.removed) + if len(curation_model.merges) > 0: + sorting, _, new_unit_ids = apply_merges_to_sorting( + sorting, + merge_unit_groups=[m.merge_unit_group for m in curation_model.merges], + censor_ms=censor_ms, + return_extra=True, + new_id_strategy=new_id_strategy, + ) + else: + new_unit_ids = [] apply_curation_labels(sorting, new_unit_ids, curation_model) return sorting elif isinstance(sorting_or_analyzer, SortingAnalyzer): analyzer = sorting_or_analyzer - if len(curation_model.removed_units) > 0: - analyzer = analyzer.remove_units(curation_model.removed_units) - if len(curation_model.merge_unit_groups) > 0: + if len(curation_model.removed) > 0: + analyzer = analyzer.remove_units(curation_model.removed) + if len(curation_model.removed) > 0: analyzer, new_unit_ids = analyzer.merge_units( - curation_model.merge_unit_groups, + merge_unit_groups=[m.merge_unit_group for m in curation_model.merges], censor_ms=censor_ms, merging_mode=merging_mode, sparsity_overlap=sparsity_overlap, diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index a595f32eac..787af88111 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -1,12 +1,9 @@ -from pydantic import BaseModel, Field, field_validator, model_validator -from typing import List, Dict, Union, Optional, Literal -from itertools import combinations, chain +from pydantic import BaseModel, Field, model_validator, field_validator +from typing import List, Dict, Union, Optional, Literal, Tuple +from itertools 
import chain, combinations import numpy as np -supported_curation_format_versions = {"1"} - - class LabelDefinition(BaseModel): name: str = Field(..., description="Name of the label") label_options: List[str] = Field(..., description="List of possible label options", min_length=2) @@ -42,6 +39,9 @@ class Split(BaseModel): class CurationModel(BaseModel): + supported_versions: Tuple[Literal["1"], Literal["2"]] = Field( + default=["1", "2"], description="Supported versions of the curation format" + ) format_version: str = Field(..., description="Version of the curation format") unit_ids: List[Union[int, str]] = Field(..., description="List of unit IDs") label_definitions: Optional[Dict[str, LabelDefinition]] = Field( @@ -52,78 +52,81 @@ class CurationModel(BaseModel): merges: Optional[List[Merge]] = Field(default=None, description="List of merges") splits: Optional[List[Split]] = Field(default=None, description="List of splits") - @field_validator("format_version") - def check_format_version(cls, v): - if v not in supported_curation_format_versions: - raise ValueError(f"Format version ({v}) not supported. 
Only {supported_curation_format_versions} are valid") - return v - - @model_validator(mode="before") - def add_label_definition_name(cls, values): - label_definitions = values.get("label_definitions") + @field_validator("label_definitions", mode="before") + def add_label_definition_name(cls, label_definitions): if label_definitions is None: - values["label_definitions"] = {} - return values - if isinstance(values["label_definitions"], dict): - if label_definitions is None: - label_definitions = {} - else: - for key in list(label_definitions.keys()): - if isinstance(label_definitions[key], dict): - label_definitions[key]["name"] = key - values["label_definitions"] = label_definitions - return values - - @model_validator(mode="before") + return {} + if isinstance(label_definitions, dict): + label_definitions = dict(label_definitions) + for key in list(label_definitions.keys()): + if isinstance(label_definitions[key], dict): + label_definitions[key] = dict(label_definitions[key]) + label_definitions[key]["name"] = key + return label_definitions + return label_definitions + + @classmethod def check_manual_labels(cls, values): - unit_ids = values["unit_ids"] + values = dict(values) + unit_ids = list(values["unit_ids"]) manual_labels = values.get("manual_labels") if manual_labels is None: values["manual_labels"] = [] else: - for manual_label in manual_labels: + manual_labels = list(manual_labels) + for i, manual_label in enumerate(manual_labels): + manual_label = dict(manual_label) unit_id = manual_label["unit_id"] labels = manual_label.get("labels") if labels is None: labels = set(manual_label.keys()) - {"unit_id"} manual_label["labels"] = {} + else: + manual_label["labels"] = {k: list(v) for k, v in labels.items()} for label in labels: if label not in values["label_definitions"]: raise ValueError(f"Manual label {unit_id} has an unknown label {label}") if label not in manual_label["labels"]: if label in manual_label: - manual_label["labels"][label] = manual_label[label] 
+ manual_label["labels"][label] = list(manual_label[label]) else: raise ValueError(f"Manual label {unit_id} has no value for label {label}") if unit_id not in unit_ids: raise ValueError(f"Manual label unit_id {unit_id} is not in the unit list") + manual_labels[i] = manual_label + values["manual_labels"] = manual_labels return values - @model_validator(mode="before") + @classmethod def check_merges(cls, values): - unit_ids = values["unit_ids"] + values = dict(values) + unit_ids = list(values["unit_ids"]) merges = values.get("merges") if merges is None: values["merges"] = [] return values - elif isinstance(merges, list): - # Convert list of lists to Merge objects - for i, merge in enumerate(merges): - if isinstance(merge, list): - merges[i] = Merge(merge_unit_group=merge) - elif isinstance(merges, dict): - # Convert dict format to list of Merge objects if needed + + if isinstance(merges, dict): + # Convert dict format to list of Merge objects merge_list = [] for merge_new_id, merge_group in merges.items(): - merge_list.append({"merge_unit_group": merge_group, "merge_new_unit_ids": merge_new_id}) + merge_list.append({"merge_unit_group": list(merge_group), "merge_new_unit_ids": merge_new_id}) merges = merge_list - values["merges"] = merges - # Convert dict items to Merge objects if needed + # Make a copy of the list + merges = list(merges) + + # Convert items to Merge objects for i, merge in enumerate(merges): + if isinstance(merge, list): + merge = {"merge_unit_group": list(merge)} if isinstance(merge, dict): + merge = dict(merge) + if "merge_unit_group" in merge: + merge["merge_unit_group"] = list(merge["merge_unit_group"]) merges[i] = Merge(**merge) + # Validate merges for merge in merges: # Check unit ids exist for unit_id in merge.merge_unit_group: @@ -139,46 +142,62 @@ def check_merges(cls, values): if merge.merge_new_unit_ids in unit_ids: raise ValueError(f"New unit ID {merge.merge_new_unit_ids} is already in the unit list") + values["merges"] = merges return 
values - @model_validator(mode="before") + @classmethod def check_splits(cls, values): - unit_ids = values["unit_ids"] + values = dict(values) + unit_ids = list(values["unit_ids"]) splits = values.get("splits") if splits is None: values["splits"] = [] return values - # Convert dict format to list of Split objects if needed + # Convert dict format to list format if isinstance(splits, dict): split_list = [] for unit_id, split_data in splits.items(): - # If split_data is a list of indices, assume indices mode if isinstance(split_data[0], (list, np.ndarray)) if split_data else False: - split_list.append({"unit_id": unit_id, "split_mode": "indices", "split_indices": split_data}) - # Otherwise assume it's a list of labels + split_list.append( + { + "unit_id": unit_id, + "split_mode": "indices", + "split_indices": [list(indices) for indices in split_data], + } + ) else: - split_list.append({"unit_id": unit_id, "split_mode": "labels", "split_labels": split_data}) + split_list.append({"unit_id": unit_id, "split_mode": "labels", "split_labels": list(split_data)}) splits = split_list - values["splits"] = splits - # Convert dict items to Split objects if needed + # Make a copy of the list + splits = list(splits) + + # Convert items to Split objects for i, split in enumerate(splits): if isinstance(split, dict): + split = dict(split) + if "split_indices" in split: + split["split_indices"] = [list(indices) for indices in split["split_indices"]] + if "split_labels" in split: + split["split_labels"] = list(split["split_labels"]) + if "split_new_unit_ids" in split: + split["split_new_unit_ids"] = list(split["split_new_unit_ids"]) splits[i] = Split(**split) + # Validate splits for split in splits: # Check unit exists if split.unit_id not in unit_ids: raise ValueError(f"Split unit_id {split.unit_id} is not in the unit list") - # Check split definition based on mode + # Validate based on mode if split.split_mode == "indices": if split.split_indices is None: raise ValueError(f"Split 
unit {split.unit_id} has no split_indices defined") if len(split.split_indices) < 1: raise ValueError(f"Split unit {split.unit_id} has empty split_indices") - # Check no duplicate indices across splits + # Check no duplicate indices all_indices = list(chain.from_iterable(split.split_indices)) if len(all_indices) != len(set(all_indices)): raise ValueError(f"Split unit {split.unit_id} has duplicate indices") @@ -189,7 +208,7 @@ def check_splits(cls, values): if len(split.split_labels) == 0: raise ValueError(f"Split unit {split.unit_id} has empty split_labels") - # Check new unit ids if provided + # Validate new unit IDs if split.split_new_unit_ids is not None: if split.split_mode == "indices": if len(split.split_new_unit_ids) != len(split.split_indices): @@ -202,30 +221,94 @@ def check_splits(cls, values): f"Number of new unit IDs does not match number of unique labels for unit {split.unit_id}" ) - # Check new ids not already used for new_id in split.split_new_unit_ids: if new_id in unit_ids: raise ValueError(f"New unit ID {new_id} is already in the unit list") + values["splits"] = splits return values - @model_validator(mode="before") + @classmethod def check_removed(cls, values): - unit_ids = values["unit_ids"] - removed = values.get("removed", []) + values = dict(values) + unit_ids = list(values["unit_ids"]) + removed = values.get("removed") if removed is None: - + values["removed"] = [] + else: + removed = list(removed) for unit_id in removed: if unit_id not in unit_ids: raise ValueError(f"Removed unit_id {unit_id} is not in the unit list") - - else: values["removed"] = removed + return values + + @classmethod + def convert_old_format(cls, values): + format_version = values.get("format_version", "0") + if format_version != "2": + values = dict(values) + if format_version == "0": + print("Conversion from format version v0 (sortingview) to v2") + if "mergeGroups" not in values.keys(): + values["mergeGroups"] = [] + merge_groups = values["mergeGroups"] + + 
first_unit_id = next(iter(values["labelsByUnit"].keys())) + if str.isdigit(first_unit_id): + unit_id_type = int + else: + unit_id_type = str + + all_units = [] + all_labels = [] + manual_labels = [] + general_cat = "all_labels" + for unit_id_, l_labels in values["labelsByUnit"].items(): + all_labels.extend(l_labels) + unit_id = unit_id_type(unit_id_) + if unit_id not in all_units: + all_units.append(unit_id) + manual_labels.append({"unit_id": unit_id, general_cat: list(l_labels)}) + labels_def = { + "all_labels": {"name": "all_labels", "label_options": list(set(all_labels)), "exclusive": False} + } + + values = { + "format_version": "2", + "unit_ids": values["unit_ids"], + "label_definitions": labels_def, + "manual_labels": list(manual_labels), + "merges": [{"merge_unit_group": list(group)} for group in merge_groups], + "splits": [], + "removed": [], + } + elif values["format_version"] == "1": + merge_unit_groups = values.get("merge_unit_groups") + if merge_unit_groups is not None: + values["merges"] = [{"merge_unit_group": list(group)} for group in merge_unit_groups] + removed_units = values.get("removed_units") + if removed_units is not None: + values["removed"] = list(removed_units) + return values + @model_validator(mode="before") + def validate_fields(cls, values): + values = dict(values) + values = cls.convert_old_format(values) + values = cls.check_manual_labels(values) + values = cls.check_merges(values) + values = cls.check_splits(values) + values = cls.check_removed(values) return values @model_validator(mode="after") def validate_curation_dict(cls, values): + if values.format_version not in values.supported_versions: + raise ValueError( + f"Format version {values.format_version} not supported. 
Only {values.supported_versions} are valid" + ) + labeled_unit_set = set([lbl.unit_id for lbl in values.manual_labels]) if values.manual_labels else set() merged_units_set = ( set(chain.from_iterable(merge.merge_unit_group for merge in values.merges)) if values.merges else set() diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index f33051309c..8970463831 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -7,13 +7,11 @@ import numpy as np from pathlib import Path -from .curationsorting import CurationSorting from .curation_format import ( - convert_from_sortingview_curation_format_v0, apply_curation, curation_label_to_vectors, - clean_curation_dict, ) +from .curation_model import CurationModel def get_kachery(): @@ -34,6 +32,7 @@ def get_kachery(): ) +# TODO: fix sortingview curation with new format def apply_sortingview_curation( sorting_or_analyzer, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=None ): @@ -82,15 +81,20 @@ def apply_sortingview_curation( except: raise Exception(f"Could not retrieve curation from SortingView uri: {uri_or_json}") - # convert to new format - if "format_version" not in curation_dict: - curation_dict = convert_from_sortingview_curation_format_v0(curation_dict) - unit_ids = sorting_or_analyzer.unit_ids + curation_dict["unit_ids"] = unit_ids + curation_model = CurationModel(**curation_dict) - # this is a hack because it was not in the old format - curation_dict["unit_ids"] = list(unit_ids) + if not skip_merge: + sorting_curated = apply_curation(sorting_or_analyzer, curation_model) + else: + sorting_curated = sorting_or_analyzer + # now remove units based on labels + curation_model.merges = [] + curation_model.unit_ids = sorting_curated.unit_ids + + # this is a hack because it was not in the old format if exclude_labels is not None: assert include_labels is 
None, "Use either `include_labels` or `exclude_labels` to filter units." manual_labels = curation_label_to_vectors(curation_dict) @@ -99,7 +103,7 @@ def apply_sortingview_curation( remove_mask = manual_labels[k] removed_units.extend(unit_ids[remove_mask]) removed_units = np.unique(removed_units) - curation_dict["removed_units"] = removed_units + curation_model.removed = removed_units if include_labels is not None: manual_labels = curation_label_to_vectors(curation_dict) @@ -108,122 +112,9 @@ def apply_sortingview_curation( remove_mask = ~manual_labels[k] removed_units.extend(unit_ids[remove_mask]) removed_units = np.unique(removed_units) - curation_dict["removed_units"] = removed_units + curation_model.removed = removed_units - if skip_merge: - curation_dict["merge_unit_groups"] = [] - - # cleaner to ensure validity - curation_dict = clean_curation_dict(curation_dict) - - # apply - sorting_curated = apply_curation(sorting_or_analyzer, curation_dict, new_id_strategy="join") + if len(curation_model.removed) > 0: + sorting_curated = apply_curation(sorting_curated, curation_model) return sorting_curated - - -# TODO @alessio you remove this after testing -def apply_sortingview_curation_legacy( - sorting, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=False -): - """ - Apply curation from SortingView manual curation. - First, merges (if present) are applied. Then labels are loaded and units - are optionally filtered based on exclude_labels and include_labels. - - Parameters - ---------- - sorting : BaseSorting - The sorting object to be curated - uri_or_json : str or Path - The URI curation link from SortingView or the path to the curation json file - exclude_labels : list, default: None - Optional list of labels to exclude (e.g. ["reject", "noise"]). - Mutually exclusive with include_labels - include_labels : list, default: None - Optional list of labels to include (e.g. ["accept"]). 
- Mutually exclusive with exclude_labels, by default None - skip_merge : bool, default: False - If True, merges are not applied (only labels) - verbose : bool, default: False - If True, output is verbose - - Returns - ------- - sorting_curated : BaseSorting - The curated sorting - """ - ka = get_kachery() - curation_sorting = CurationSorting(sorting, make_graph=False, properties_policy="keep") - - # get sorting view curation - if Path(uri_or_json).suffix == ".json" and not str(uri_or_json).startswith("gh://"): - with open(uri_or_json, "r") as f: - sortingview_curation_dict = json.load(f) - else: - try: - sortingview_curation_dict = ka.load_json(uri=uri_or_json) - except: - raise Exception(f"Could not retrieve curation from SortingView uri: {uri_or_json}") - - unit_ids_dtype = sorting.unit_ids.dtype - - # STEP 1: merge groups - labels_dict = sortingview_curation_dict["labelsByUnit"] - if "mergeGroups" in sortingview_curation_dict and not skip_merge: - merge_groups = sortingview_curation_dict["mergeGroups"] - for merge_group in merge_groups: - # Store labels of units that are about to be merged - labels_to_inherit = [] - for unit in merge_group: - labels_to_inherit.extend(labels_dict.get(str(unit), [])) - labels_to_inherit = list(set(labels_to_inherit)) # Remove duplicates - - if verbose: - print(f"Merging {merge_group}") - if unit_ids_dtype.kind in ("U", "S"): - merge_group = [str(unit) for unit in merge_group] - # if unit dtype is str, set new id as "{unit1}-{unit2}" - new_unit_id = "-".join(merge_group) - curation_sorting.merge(merge_group, new_unit_id=new_unit_id) - else: - # in this case, the CurationSorting takes care of finding a new unused int - curation_sorting.merge(merge_group, new_unit_id=None) - new_unit_id = curation_sorting.max_used_id # merged unit id - labels_dict[str(new_unit_id)] = labels_to_inherit - - # STEP 2: gather and apply sortingview curation labels - # In sortingview, a unit is not required to have all labels. 
- # For example, the first 3 units could be labeled as "accept". - # In this case, the first 3 values of the property "accept" will be True, the rest False - - # Initialize the properties dictionary - properties = { - label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) - for labels in labels_dict.values() - for label in labels - } - - # Populate the properties dictionary - for unit_index, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): - unit_id_str = str(unit_id) - if unit_id_str in labels_dict: - for label in labels_dict[unit_id_str]: - properties[label][unit_index] = True - - for prop_name, prop_values in properties.items(): - curation_sorting.current_sorting.set_property(prop_name, prop_values) - - if include_labels is not None or exclude_labels is not None: - units_to_remove = [] - unit_ids = curation_sorting.current_sorting.unit_ids - assert include_labels or exclude_labels, "Use either `include_labels` or `exclude_labels` to filter units." - if include_labels: - for include_label in include_labels: - units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(include_label) == False]) - if exclude_labels: - for exclude_label in exclude_labels: - units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(exclude_label) == True]) - units_to_remove = np.unique(units_to_remove) - curation_sorting.remove_units(units_to_remove) - return curation_sorting.current_sorting diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py index 0d9562f404..c3ed4a115f 100644 --- a/src/spikeinterface/curation/tests/test_curation_format.py +++ b/src/spikeinterface/curation/tests/test_curation_format.py @@ -1,8 +1,5 @@ import pytest -from pydantic import BaseModel, ValidationError, field_validator - - from pathlib import Path import json import numpy as np @@ -11,15 +8,34 @@ from spikeinterface.curation.curation_format import ( 
validate_curation_dict, - convert_from_sortingview_curation_format_v0, curation_label_to_vectors, curation_label_to_dataframe, apply_curation, ) +""" +v1 = { + 'format_version': '1', + 'unit_ids': List[int | str], + 'label_definitions': { + 'category_key1': + { + 'label_options': List[str], + 'exclusive': bool} + }, + 'manual_labels': [ + { + 'unit_id': str or int, + 'category_key1': List[str], + } + ], + 'removed_units': List[int | str] # Can not be in the merged_units + 'merge_unit_groups': List[List[int | str]], # one cell goes into at most one list +} -"""example = { - 'unit_ids': List[str, int], +v2 = { + 'format_version': '2', + 'unit_ids': List[int | str], 'label_definitions': { 'category_key1': { @@ -27,18 +43,38 @@ 'exclusive': bool} }, 'manual_labels': [ - {'unit_id': str or int, - category_key1': List[str], + { + 'unit_id': str | int, + 'category_key1': List[str], } ], - 'merge_unit_groups': List[List[unit_ids]], # one cell goes into at most one list - 'removed_units': List[unit_ids] # Can not be in the merged_units -} -""" + 'removed': List[unit_ids], # Can not be in the merged_units + 'merges': [ + { + 'merge_unit_group': List[unit_ids], + 'merge_new_unit_id': int | str (optional) + } + ], + 'splits': [ + { + 'unit_id': int | str + 'split_mode': 'indices' or 'labels', + 'split_indices': List[List[int]], + 'split_labels': List[int], + 'split_new_unit_ids': List[int | str] + } + ] + +sortingview_curation = { + 'mergeGroups': List[List[int | str]], + 'labelsByUnit': { + 'unit_id': List[str] + } +""" curation_ids_int = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], "label_definitions": { "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, @@ -51,19 +87,21 @@ {"unit_id": 1, "quality": ["good"]}, { "unit_id": 2, - "quality": [ - "noise", - ], + "quality": ["noise"], "putative_type": ["excitatory", "pyramidal"], }, {"unit_id": 3, "putative_type": ["inhibitory"]}, ], -
"merge_unit_groups": [[3, 6], [10, 14, 20]], # one cell goes into at most one list - "removed_units": [31, 42], # Can not be in the merged_units + "merges": [{"merge_unit_group": [3, 6]}, {"merge_unit_group": [10, 14, 20]}], + "splits": [], + "removed": [31, 42], } +# Test dictionary format for merges +curation_ids_int_dict = {**curation_ids_int, "merges": {50: [3, 6], 51: [10, 14, 20]}} + curation_ids_str = { - "format_version": "1", + "format_version": "2", "unit_ids": ["u1", "u2", "u3", "u6", "u10", "u14", "u20", "u31", "u42"], "label_definitions": { "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, @@ -76,40 +114,65 @@ {"unit_id": "u1", "quality": ["good"]}, { "unit_id": "u2", - "quality": [ - "noise", - ], + "quality": ["noise"], "putative_type": ["excitatory", "pyramidal"], }, {"unit_id": "u3", "putative_type": ["inhibitory"]}, ], - "merge_unit_groups": [["u3", "u6"], ["u10", "u14", "u20"]], # one cell goes into at most one list - "removed_units": ["u31", "u42"], # Can not be in the merged_units + "merges": [{"merge_unit_group": ["u3", "u6"]}, {"merge_unit_group": ["u10", "u14", "u20"]}], + "splits": [], + "removed": ["u31", "u42"], } -# This is a failure example with duplicated merge -duplicate_merge = curation_ids_int.copy() -duplicate_merge["merge_unit_groups"] = [[3, 6, 10], [10, 14, 20]] +# Test dictionary format for merges with string IDs +curation_ids_str_dict = {**curation_ids_str, "merges": {"u50": ["u3", "u6"], "u51": ["u10", "u14", "u20"]}} +# Test with splits +curation_with_splits = { + **curation_ids_int, + "splits": [ + {"unit_id": 2, "split_mode": "indices", "split_indices": [[0, 1, 2], [3, 4, 5]], "split_new_unit_ids": [50, 51]} + ], +} + +# Test dictionary format for splits +curation_with_splits_dict = {**curation_ids_int, "splits": {2: [[0, 1, 2], [3, 4, 5]]}} + +# This is a failure example with duplicated merge +duplicate_merge = {**curation_ids_int, "merges": [{"merge_unit_group": [3, 6, 10]}, 
{"merge_unit_group": [10, 14, 20]}]} # This is a failure example with unit 3 both in removed and merged -merged_and_removed = curation_ids_int.copy() -merged_and_removed["merge_unit_groups"] = [[3, 6], [10, 14, 20]] -merged_and_removed["removed_units"] = [3, 31, 42] +merged_and_removed = { + **curation_ids_int, + "merges": [{"merge_unit_group": [3, 6]}, {"merge_unit_group": [10, 14, 20]}], + "removed": [3, 31, 42], +} -# this is a failure because unit 99 is not in the initial list -unknown_merged_unit = curation_ids_int.copy() -unknown_merged_unit["merge_unit_groups"] = [[3, 6, 99], [10, 14, 20]] +# This is a failure because unit 99 is not in the initial list +unknown_merged_unit = { + **curation_ids_int, + "merges": [{"merge_unit_group": [3, 6, 99]}, {"merge_unit_group": [10, 14, 20]}], +} -# this is a failure because unit 99 is not in the initial list -unknown_removed_unit = curation_ids_int.copy() -unknown_removed_unit["removed_units"] = [31, 42, 99] +# This is a failure because unit 99 is not in the initial list +unknown_removed_unit = {**curation_ids_int, "removed": [31, 42, 99]} def test_curation_format_validation(): + # Test basic formats + print(curation_ids_int) validate_curation_dict(curation_ids_int) + print(curation_ids_int) validate_curation_dict(curation_ids_str) + # Test dictionary formats + validate_curation_dict(curation_ids_int_dict) + validate_curation_dict(curation_ids_str_dict) + + # Test splits + validate_curation_dict(curation_with_splits) + validate_curation_dict(curation_with_splits_dict) + with pytest.raises(ValueError): # Raised because duplicated merged units validate_curation_dict(duplicate_merge) @@ -125,13 +188,13 @@ def test_curation_format_validation(): def test_to_from_json(): - json.loads(json.dumps(curation_ids_int, indent=4)) json.loads(json.dumps(curation_ids_str, indent=4)) + json.loads(json.dumps(curation_ids_int_dict, indent=4)) + json.loads(json.dumps(curation_with_splits, indent=4)) def 
test_convert_from_sortingview_curation_format_v0(): - parent_folder = Path(__file__).parent for filename in ( "sv-sorting-curation.json", @@ -139,18 +202,13 @@ def test_convert_from_sortingview_curation_format_v0(): "sv-sorting-curation-str.json", "sv-sorting-curation-false-positive.json", ): - json_file = parent_folder / filename with open(json_file, "r") as f: curation_v0 = json.load(f) - # print(curation_v0) - curation_v1 = convert_from_sortingview_curation_format_v0(curation_v0) - # print(curation_v1) - validate_curation_dict(curation_v1) + validate_curation_dict(curation_v0) def test_curation_label_to_vectors(): - labels = curation_label_to_vectors(curation_ids_int) assert "quality" in labels assert "excitatory" in labels @@ -161,36 +219,45 @@ def test_curation_label_to_vectors(): def test_curation_label_to_dataframe(): - df = curation_label_to_dataframe(curation_ids_int) assert "quality" in df.columns assert "excitatory" in df.columns print(df) df = curation_label_to_dataframe(curation_ids_str) - # print(df) + print(df) def test_apply_curation(): recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205) - sorting._main_ids = np.array([1, 2, 3, 6, 10, 14, 20, 31, 42]) + sorting = sorting.rename_units([1, 2, 3, 6, 10, 14, 20, 31, 42]) analyzer = create_sorting_analyzer(sorting, recording, sparse=False) + # Test with list format sorting_curated = apply_curation(sorting, curation_ids_int) assert sorting_curated.get_property("quality", ids=[1])[0] == "good" assert sorting_curated.get_property("quality", ids=[2])[0] == "noise" assert sorting_curated.get_property("excitatory", ids=[2])[0] + # Test with dictionary format + sorting_curated = apply_curation(sorting, curation_ids_int_dict) + assert sorting_curated.get_property("quality", ids=[1])[0] == "good" + assert sorting_curated.get_property("quality", ids=[2])[0] == "noise" + assert sorting_curated.get_property("excitatory", ids=[2])[0] + + # Test with splits + sorting_curated 
= apply_curation(sorting, curation_with_splits) + assert sorting_curated.get_property("quality", ids=[1])[0] == "good" + + # Test analyzer analyzer_curated = apply_curation(analyzer, curation_ids_int) assert "quality" in analyzer_curated.sorting.get_property_keys() if __name__ == "__main__": - test_curation_format_validation() test_curation_format_validation() test_to_from_json() test_convert_from_sortingview_curation_format_v0() test_curation_label_to_vectors() test_curation_label_to_dataframe() - test_apply_curation() diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py index 1ebb36296d..3db904a480 100644 --- a/src/spikeinterface/curation/tests/test_curation_model.py +++ b/src/spikeinterface/curation/tests/test_curation_model.py @@ -13,15 +13,15 @@ def test_format_version(): # Invalid format version with pytest.raises(ValidationError): - CurationModel(format_version="2", unit_ids=[1, 2, 3]) + CurationModel(format_version="3", unit_ids=[1, 2, 3]) with pytest.raises(ValidationError): - CurationModel(format_version="0", unit_ids=[1, 2, 3]) + CurationModel(format_version="0.1", unit_ids=[1, 2, 3]) # Test data for label definitions def test_label_definitions(): valid_label_def = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "label_definitions": { "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), @@ -42,7 +42,7 @@ def test_label_definitions(): # Test manual labels def test_manual_labels(): valid_labels = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "label_definitions": { "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), @@ -59,7 +59,7 @@ def test_manual_labels(): # Test invalid unit ID invalid_unit = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "label_definitions": { "quality": LabelDefinition(name="quality", 
label_options=["good", "noise"], exclusive=True) @@ -71,7 +71,7 @@ def test_manual_labels(): # Test violation of exclusive label invalid_exclusive = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "label_definitions": { "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True) @@ -88,7 +88,7 @@ def test_manual_labels(): def test_merge_units(): # Test list format valid_merge = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": [ {"merge_unit_group": [1, 2], "merge_new_unit_ids": 5}, @@ -102,7 +102,7 @@ def test_merge_units(): assert model.merges[1].merge_new_unit_ids == 6 # Test dictionary format - valid_merge_dict = {"format_version": "1", "unit_ids": [1, 2, 3, 4], "merges": {5: [1, 2], 6: [3, 4]}} + valid_merge_dict = {"format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": {5: [1, 2], 6: [3, 4]}} model = CurationModel(**valid_merge_dict) assert len(model.merges) == 2 @@ -111,7 +111,7 @@ def test_merge_units(): # Test list format valid_merge_list = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": [[1, 2], [3, 4]], # Merge each pair into a new unit } @@ -120,7 +120,7 @@ def test_merge_units(): # Test invalid merge group (single unit) invalid_merge_group = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "merges": [{"merge_unit_group": [1], "merge_new_unit_ids": 4}], } @@ -129,7 +129,7 @@ def test_merge_units(): # Test overlapping merge groups invalid_overlap = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "merges": [ {"merge_unit_group": [1, 2], "merge_new_unit_ids": 4}, @@ -144,7 +144,7 @@ def test_merge_units(): def test_split_units(): # Test indices mode with list format valid_split_indices = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "splits": [ { @@ -163,7 +163,7 @@ def test_split_units(): # Test labels mode with list format 
valid_split_labels = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "splits": [ {"unit_id": 1, "split_mode": "labels", "split_labels": [0, 0, 1, 1, 0, 2], "split_new_unit_ids": [4, 5, 6]} @@ -177,7 +177,7 @@ def test_split_units(): # Test dictionary format with indices valid_split_dict = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "splits": { 1: [[0, 1, 2], [3, 4, 5]], # Split unit 1 into two parts @@ -191,7 +191,7 @@ def test_split_units(): # Test invalid unit ID invalid_unit_id = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "splits": [{"unit_id": 4, "split_mode": "indices", "split_indices": [[0, 1], [2, 3]]}], # Non-existent unit } @@ -200,7 +200,7 @@ def test_split_units(): # Test invalid new unit IDs count for indices mode invalid_new_ids = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "splits": [ { @@ -217,19 +217,19 @@ def test_split_units(): # Test removed units def test_removed_units(): - valid_remove = {"format_version": "1", "unit_ids": [1, 2, 3], "removed": [2]} + valid_remove = {"format_version": "2", "unit_ids": [1, 2, 3], "removed": [2]} model = CurationModel(**valid_remove) assert len(model.removed) == 1 # Test removing non-existent unit - invalid_remove = {"format_version": "1", "unit_ids": [1, 2, 3], "removed": [4]} # Non-existent unit + invalid_remove = {"format_version": "2", "unit_ids": [1, 2, 3], "removed": [4]} # Non-existent unit with pytest.raises(ValidationError): CurationModel(**invalid_remove) # Test conflict between merge and remove invalid_merge_remove = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3], "merges": [{"merge_unit_group": [1, 2], "merge_new_unit_ids": 4}], "removed": [1], # Unit is both merged and removed @@ -241,7 +241,7 @@ def test_removed_units(): # Test complete model with multiple operations def test_complete_model(): complete_model = { - "format_version": "1", + 
"format_version": "2", "unit_ids": [1, 2, 3, 4, 5], "label_definitions": { "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), @@ -256,7 +256,7 @@ def test_complete_model(): } model = CurationModel(**complete_model) - assert model.format_version == "1" + assert model.format_version == "2" assert len(model.unit_ids) == 5 assert len(model.label_definitions) == 2 assert len(model.manual_labels) == 1 @@ -266,7 +266,7 @@ def test_complete_model(): # Test dictionary format for complete model complete_model_dict = { - "format_version": "1", + "format_version": "2", "unit_ids": [1, 2, 3, 4, 5], "label_definitions": { "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), @@ -279,7 +279,7 @@ def test_complete_model(): } model = CurationModel(**complete_model_dict) - assert model.format_version == "1" + assert model.format_version == "2" assert len(model.unit_ids) == 5 assert len(model.label_definitions) == 2 assert len(model.manual_labels) == 1 From f122db720b8ec39ce1231747290d9b3fe05b723f Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Mar 2025 15:49:39 -0400 Subject: [PATCH 07/11] Fix sortingview tests --- .../curation/curation_format.py | 3 ++- .../curation/sortingview_curation.py | 27 +++++++++++-------- 2 files changed, 18 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index d634735344..ea1c19e719 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -182,11 +182,12 @@ def apply_curation( censor_ms : float | None, default: None When applying the merges, any consecutive spikes within the `censor_ms` are removed. This can be thought of as the desired refractory period. If `censor_ms=None`, no spikes are discarded. 
- new_id_strategy : "append" | "take_first", default: "append" + new_id_strategy : "append" | "take_first" | "join", default: "append" The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. * "append" : new_units_ids will be added at the end of max(sorting.unit_ids) * "take_first" : new_unit_ids will be the first unit_id of every list of merges + * "join" : new_unit_ids will be the concatenation of all unit_ids of every list of merges merging_mode : "soft" | "hard", default: "soft" How merges are performed for SortingAnalyzer. If the `merge_mode` is "soft" , merges will be approximated, with no reloading of the waveforms. This will lead to approximations. If `merge_mode` is "hard", recomputations are accurately diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index 8970463831..fe21b72263 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -11,7 +11,7 @@ apply_curation, curation_label_to_vectors, ) -from .curation_model import CurationModel +from .curation_model import CurationModel, Merge def get_kachery(): @@ -32,7 +32,6 @@ def get_kachery(): ) -# TODO: fix sortingview curation with new format def apply_sortingview_curation( sorting_or_analyzer, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=None ): @@ -85,14 +84,8 @@ def apply_sortingview_curation( curation_dict["unit_ids"] = unit_ids curation_model = CurationModel(**curation_dict) - if not skip_merge: - sorting_curated = apply_curation(sorting_or_analyzer, curation_model) - else: - sorting_curated = sorting_or_analyzer - - # now remove units based on labels - curation_model.merges = [] - curation_model.unit_ids = sorting_curated.unit_ids + if skip_merge: + curation_model.merges = [] # this is a hack because it was not in the old format if exclude_labels is not None: @@ -114,7 +107,19 @@ def 
apply_sortingview_curation( removed_units = np.unique(removed_units) curation_model.removed = removed_units + # make merges and removed units if len(curation_model.removed) > 0: - sorting_curated = apply_curation(sorting_curated, curation_model) + clean_merges = [] + for merge in curation_model.merges: + clean_merge = [] + for unit_id in merge.merge_unit_group: + if unit_id not in curation_model.removed: + clean_merge.append(unit_id) + if len(clean_merge) > 1: + clean_merges.append(Merge(merge_unit_group=clean_merge)) + curation_model.merges = clean_merges + + # apply curation + sorting_curated = apply_curation(sorting_or_analyzer, curation_model, new_id_strategy="join") return sorting_curated From 4f14e90295dc2dab228bf26355fb98449b46f983 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 27 Mar 2025 15:57:42 -0400 Subject: [PATCH 08/11] Fix sortingview conversion --- src/spikeinterface/curation/curation_model.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index 787af88111..94a277b3b7 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -273,10 +273,13 @@ def convert_old_format(cls, values): labels_def = { "all_labels": {"name": "all_labels", "label_options": list(set(all_labels)), "exclusive": False} } + for merge_group in merge_groups: + all_units.extend(merge_group) + all_units = list(set(all_units)) values = { "format_version": "2", - "unit_ids": values["unit_ids"], + "unit_ids": values.get("unit_ids", all_units), "label_definitions": labels_def, "manual_labels": list(manual_labels), "merges": [{"merge_unit_group": list(group)} for group in merge_groups], From 317f87c291a9b712664e5f951be8a4b2b711724a Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Fri, 28 Mar 2025 16:45:34 -0400 Subject: [PATCH 09/11] merge_new_unit_ids -> merge_new_unit_id --- 
src/spikeinterface/curation/curation_model.py | 20 +++++++++---------- .../curation/tests/test_curation_model.py | 20 +++++++++---------- 2 files changed, 19 insertions(+), 21 deletions(-) diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index 94a277b3b7..afcbef17d5 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -17,7 +17,7 @@ class ManualLabel(BaseModel): class Merge(BaseModel): merge_unit_group: List[Union[int, str]] = Field(..., description="List of groups of units to be merged") - merge_new_unit_ids: Optional[Union[int, str]] = Field(default=None, description="New unit IDs for the merge group") + merge_new_unit_id: Optional[Union[int, str]] = Field(default=None, description="New unit IDs for the merge group") class Split(BaseModel): @@ -67,7 +67,7 @@ def add_label_definition_name(cls, label_definitions): @classmethod def check_manual_labels(cls, values): - values = dict(values) + unit_ids = list(values["unit_ids"]) manual_labels = values.get("manual_labels") if manual_labels is None: @@ -99,7 +99,7 @@ def check_manual_labels(cls, values): @classmethod def check_merges(cls, values): - values = dict(values) + unit_ids = list(values["unit_ids"]) merges = values.get("merges") if merges is None: @@ -110,7 +110,7 @@ def check_merges(cls, values): # Convert dict format to list of Merge objects merge_list = [] for merge_new_id, merge_group in merges.items(): - merge_list.append({"merge_unit_group": list(merge_group), "merge_new_unit_ids": merge_new_id}) + merge_list.append({"merge_unit_group": list(merge_group), "merge_new_unit_id": merge_new_id}) merges = merge_list # Make a copy of the list @@ -138,16 +138,16 @@ def check_merges(cls, values): raise ValueError("Merge unit groups must have at least 2 elements") # Check new unit id not already used - if merge.merge_new_unit_ids is not None: - if merge.merge_new_unit_ids in unit_ids: - raise 
ValueError(f"New unit ID {merge.merge_new_unit_ids} is already in the unit list") + if merge.merge_new_unit_id is not None: + if merge.merge_new_unit_id in unit_ids: + raise ValueError(f"New unit ID {merge.merge_new_unit_id} is already in the unit list") values["merges"] = merges return values @classmethod def check_splits(cls, values): - values = dict(values) + unit_ids = list(values["unit_ids"]) splits = values.get("splits") if splits is None: @@ -230,7 +230,6 @@ def check_splits(cls, values): @classmethod def check_removed(cls, values): - values = dict(values) unit_ids = list(values["unit_ids"]) removed = values.get("removed") if removed is None: @@ -246,8 +245,6 @@ def check_removed(cls, values): @classmethod def convert_old_format(cls, values): format_version = values.get("format_version", "0") - if format_version != "2": - values = dict(values) if format_version == "0": print("Conversion from format version v0 (sortingview) to v2") if "mergeGroups" not in values.keys(): @@ -298,6 +295,7 @@ def convert_old_format(cls, values): @model_validator(mode="before") def validate_fields(cls, values): values = dict(values) + values["label_definitions"] = values.get("label_definitions", {}) values = cls.convert_old_format(values) values = cls.check_manual_labels(values) values = cls.check_merges(values) diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py index 3db904a480..7354ac1892 100644 --- a/src/spikeinterface/curation/tests/test_curation_model.py +++ b/src/spikeinterface/curation/tests/test_curation_model.py @@ -91,22 +91,22 @@ def test_merge_units(): "format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": [ - {"merge_unit_group": [1, 2], "merge_new_unit_ids": 5}, - {"merge_unit_group": [3, 4], "merge_new_unit_ids": 6}, + {"merge_unit_group": [1, 2], "merge_new_unit_id": 5}, + {"merge_unit_group": [3, 4], "merge_new_unit_id": 6}, ], } model = CurationModel(**valid_merge) assert 
len(model.merges) == 2 - assert model.merges[0].merge_new_unit_ids == 5 - assert model.merges[1].merge_new_unit_ids == 6 + assert model.merges[0].merge_new_unit_id == 5 + assert model.merges[1].merge_new_unit_id == 6 # Test dictionary format valid_merge_dict = {"format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": {5: [1, 2], 6: [3, 4]}} model = CurationModel(**valid_merge_dict) assert len(model.merges) == 2 - merge_new_ids = {merge.merge_new_unit_ids for merge in model.merges} + merge_new_ids = {merge.merge_new_unit_id for merge in model.merges} assert merge_new_ids == {5, 6} # Test list format @@ -122,7 +122,7 @@ def test_merge_units(): invalid_merge_group = { "format_version": "2", "unit_ids": [1, 2, 3], - "merges": [{"merge_unit_group": [1], "merge_new_unit_ids": 4}], + "merges": [{"merge_unit_group": [1], "merge_new_unit_id": 4}], } with pytest.raises(ValidationError): CurationModel(**invalid_merge_group) @@ -132,8 +132,8 @@ def test_merge_units(): "format_version": "2", "unit_ids": [1, 2, 3], "merges": [ - {"merge_unit_group": [1, 2], "merge_new_unit_ids": 4}, - {"merge_unit_group": [2, 3], "merge_new_unit_ids": 5}, + {"merge_unit_group": [1, 2], "merge_new_unit_id": 4}, + {"merge_unit_group": [2, 3], "merge_new_unit_id": 5}, ], } with pytest.raises(ValidationError): @@ -231,7 +231,7 @@ def test_removed_units(): invalid_merge_remove = { "format_version": "2", "unit_ids": [1, 2, 3], - "merges": [{"merge_unit_group": [1, 2], "merge_new_unit_ids": 4}], + "merges": [{"merge_unit_group": [1, 2], "merge_new_unit_id": 4}], "removed": [1], # Unit is both merged and removed } with pytest.raises(ValidationError): @@ -248,7 +248,7 @@ def test_complete_model(): "tags": LabelDefinition(name="tags", label_options=["burst", "slow"], exclusive=False), }, "manual_labels": [{"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst"]}}], - "merges": [{"merge_unit_group": [2, 3], "merge_new_unit_ids": 6}], + "merges": [{"merge_unit_group": [2, 3], "merge_new_unit_id": 
6}], "splits": [ {"unit_id": 4, "split_mode": "indices", "split_indices": [[0, 1], [2, 3]], "split_new_unit_ids": [7, 8]} ], From c64b8406998947fc6a7e69f5669fb6f6ae3f759e Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 12 Jun 2025 15:36:04 +0200 Subject: [PATCH 10/11] Implement feedback --- .../curation/curation_format.py | 6 +- src/spikeinterface/curation/curation_model.py | 102 ++++++++++-------- .../curation/sortingview_curation.py | 2 +- .../curation/tests/test_curation_model.py | 16 +-- 4 files changed, 72 insertions(+), 54 deletions(-) diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index ea1c19e719..bcd6bd9a4b 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -128,7 +128,7 @@ def apply_curation_labels( sorting.set_property(key, all_values) for new_unit_id, merge in zip(new_unit_ids, curation_model.merges): - old_group_ids = merge.merge_unit_group + old_group_ids = merge.unit_ids for label_key, label_def in curation_model.label_definitions.items(): if label_def.exclusive: group_values = [] @@ -221,7 +221,7 @@ def apply_curation( if len(curation_model.merges) > 0: sorting, _, new_unit_ids = apply_merges_to_sorting( sorting, - merge_unit_groups=[m.merge_unit_group for m in curation_model.merges], + merge_unit_groups=[m.unit_ids for m in curation_model.merges], censor_ms=censor_ms, return_extra=True, new_id_strategy=new_id_strategy, @@ -237,7 +237,7 @@ def apply_curation( analyzer = analyzer.remove_units(curation_model.removed) if len(curation_model.removed) > 0: analyzer, new_unit_ids = analyzer.merge_units( - merge_unit_groups=[m.merge_unit_group for m in curation_model.merges], + merge_unit_groups=[m.unit_ids for m in curation_model.merges], censor_ms=censor_ms, merging_mode=merging_mode, sparsity_overlap=sparsity_overlap, diff --git a/src/spikeinterface/curation/curation_model.py 
b/src/spikeinterface/curation/curation_model.py index afcbef17d5..6ece504ee3 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -16,24 +16,31 @@ class ManualLabel(BaseModel): class Merge(BaseModel): - merge_unit_group: List[Union[int, str]] = Field(..., description="List of groups of units to be merged") - merge_new_unit_id: Optional[Union[int, str]] = Field(default=None, description="New unit IDs for the merge group") + unit_ids: List[Union[int, str]] = Field(..., description="List of unit ids to be merged") + new_unit_id: Optional[Union[int, str]] = Field(default=None, description="New unit IDs for the merge group") class Split(BaseModel): unit_id: Union[int, str] = Field(..., description="ID of the unit") - split_mode: Literal["indices", "labels"] = Field( + mode: Literal["indices", "labels"] = Field( default="indices", description=( "Mode of the split. The split can be defined by indices or labels. " "If indices, the split is defined by the a list of lists of indices of spikes within spikes " - "belonging to the unit (`split_indices`). " - "If labels, the split is defined by a list of labels for each spike (`split_labels`). " + "belonging to the unit (`indices`). " + "If labels, the split is defined by a list of labels for each spike (`labels`). " ), ) - split_indices: Optional[Union[List[List[int]]]] = Field(default=None, description="List of indices for the split") - split_labels: Optional[List[int]] = Field(default=None, description="List of labels for the split") - split_new_unit_ids: Optional[List[Union[int, str]]] = Field( + indices: Optional[Union[List[int], List[List[int]]]] = Field( + default=None, + description=( + "List of indices for the split. If a list of indices, the unit is split in 2 (provided indices/others). 
" + "If a list of lists, the unit is split in multiple groups (one for each list of indices), plus an optional " + "extra if the spike train has more spikes than the sum of the indices in the lists." + ), + ) + labels: Optional[List[int]] = Field(default=None, description="List of labels for the split") + new_unit_ids: Optional[List[Union[int, str]]] = Field( default=None, description="List of new unit IDs for each split" ) @@ -129,25 +136,36 @@ def check_merges(cls, values): # Validate merges for merge in merges: # Check unit ids exist - for unit_id in merge.merge_unit_group: + for unit_id in merge.unit_ids: if unit_id not in unit_ids: raise ValueError(f"Merge unit group unit_id {unit_id} is not in the unit list") # Check minimum group size - if len(merge.merge_unit_group) < 2: + if len(merge.unit_ids) < 2: raise ValueError("Merge unit groups must have at least 2 elements") # Check new unit id not already used - if merge.merge_new_unit_id is not None: - if merge.merge_new_unit_id in unit_ids: - raise ValueError(f"New unit ID {merge.merge_new_unit_id} is already in the unit list") + if merge.new_unit_id is not None: + if merge.new_unit_id in unit_ids: + raise ValueError(f"New unit ID {merge.new_unit_id} is already in the unit list") values["merges"] = merges return values @classmethod def check_splits(cls, values): - + """ + Checks and validates the splits in the curation model. + If `splits` is a dictionary with unit_id as key and split data as values, + it converts it to a list of Split objects. + Each Split object is then validated: + - Checks if the unit_id exists in the unit_ids list. + - Validates the mode (indices or labels). + - If mode is indices, checks that indices are defined and not empty, and that there are no duplicate indices. + - If mode is labels, checks that labels are defined and not empty. + - Validates new unit IDs if provided, ensuring they are not already in the unit_ids list and match the + number of splits. 
+ """ unit_ids = list(values["unit_ids"]) splits = values.get("splits") if splits is None: @@ -162,12 +180,12 @@ def check_splits(cls, values): split_list.append( { "unit_id": unit_id, - "split_mode": "indices", - "split_indices": [list(indices) for indices in split_data], + "mode": "indices", + "indices": [list(indices) for indices in split_data], } ) else: - split_list.append({"unit_id": unit_id, "split_mode": "labels", "split_labels": list(split_data)}) + split_list.append({"unit_id": unit_id, "mode": "labels", "labels": list(split_data)}) splits = split_list # Make a copy of the list @@ -177,12 +195,12 @@ def check_splits(cls, values): for i, split in enumerate(splits): if isinstance(split, dict): split = dict(split) - if "split_indices" in split: - split["split_indices"] = [list(indices) for indices in split["split_indices"]] - if "split_labels" in split: - split["split_labels"] = list(split["split_labels"]) - if "split_new_unit_ids" in split: - split["split_new_unit_ids"] = list(split["split_new_unit_ids"]) + if "indices" in split: + split["indices"] = [list(indices) for indices in split["indices"]] + if "labels" in split: + split["labels"] = list(split["labels"]) + if "new_unit_ids" in split: + split["new_unit_ids"] = list(split["new_unit_ids"]) splits[i] = Split(**split) # Validate splits @@ -192,36 +210,36 @@ def check_splits(cls, values): raise ValueError(f"Split unit_id {split.unit_id} is not in the unit list") # Validate based on mode - if split.split_mode == "indices": - if split.split_indices is None: - raise ValueError(f"Split unit {split.unit_id} has no split_indices defined") - if len(split.split_indices) < 1: - raise ValueError(f"Split unit {split.unit_id} has empty split_indices") + if split.mode == "indices": + if split.indices is None: + raise ValueError(f"Split unit {split.unit_id} has no indices defined") + if len(split.indices) < 1: + raise ValueError(f"Split unit {split.unit_id} has empty indices") # Check no duplicate indices - all_indices 
= list(chain.from_iterable(split.split_indices)) + all_indices = list(chain.from_iterable(split.indices)) if len(all_indices) != len(set(all_indices)): raise ValueError(f"Split unit {split.unit_id} has duplicate indices") - elif split.split_mode == "labels": - if split.split_labels is None: - raise ValueError(f"Split unit {split.unit_id} has no split_labels defined") - if len(split.split_labels) == 0: - raise ValueError(f"Split unit {split.unit_id} has empty split_labels") + elif split.mode == "labels": + if split.labels is None: + raise ValueError(f"Split unit {split.unit_id} has no labels defined") + if len(split.labels) == 0: + raise ValueError(f"Split unit {split.unit_id} has empty labels") # Validate new unit IDs - if split.split_new_unit_ids is not None: - if split.split_mode == "indices": - if len(split.split_new_unit_ids) != len(split.split_indices): + if split.new_unit_ids is not None: + if split.mode == "indices": + if len(split.new_unit_ids) != len(split.indices): raise ValueError( f"Number of new unit IDs does not match number of splits for unit {split.unit_id}" ) - elif split.split_mode == "labels": - if len(split.split_new_unit_ids) != len(set(split.split_labels)): + elif split.mode == "labels": + if len(split.new_unit_ids) != len(set(split.labels)): raise ValueError( f"Number of new unit IDs does not match number of unique labels for unit {split.unit_id}" ) - for new_id in split.split_new_unit_ids: + for new_id in split.new_unit_ids: if new_id in unit_ids: raise ValueError(f"New unit ID {new_id} is already in the unit list") @@ -312,7 +330,7 @@ def validate_curation_dict(cls, values): labeled_unit_set = set([lbl.unit_id for lbl in values.manual_labels]) if values.manual_labels else set() merged_units_set = ( - set(chain.from_iterable(merge.merge_unit_group for merge in values.merges)) if values.merges else set() + set(chain.from_iterable(merge.unit_ids for merge in values.merges)) if values.merges else set() ) split_units_set = set(split.unit_id for 
split in values.splits) if values.splits else set() removed_set = set(values.removed) if values.removed else set() @@ -329,7 +347,7 @@ def validate_curation_dict(cls, values): raise ValueError("Curation format: some removed units are not in the unit list") # Check for units being merged multiple times - all_merging_groups = [set(merge.merge_unit_group) for merge in values.merges] if values.merges else [] + all_merging_groups = [set(merge.unit_ids) for merge in values.merges] if values.merges else [] for gp_1, gp_2 in combinations(all_merging_groups, 2): if len(gp_1.intersection(gp_2)) != 0: raise ValueError("Curation format: some units belong to multiple merge groups") diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index fe21b72263..b9d891c69f 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -112,7 +112,7 @@ def apply_sortingview_curation( clean_merges = [] for merge in curation_model.merges: clean_merge = [] - for unit_id in merge.merge_unit_group: + for unit_id in merge.unit_ids: if unit_id not in curation_model.removed: clean_merge.append(unit_id) if len(clean_merge) > 1: diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py index 7354ac1892..edfa29162a 100644 --- a/src/spikeinterface/curation/tests/test_curation_model.py +++ b/src/spikeinterface/curation/tests/test_curation_model.py @@ -98,15 +98,15 @@ def test_merge_units(): model = CurationModel(**valid_merge) assert len(model.merges) == 2 - assert model.merges[0].merge_new_unit_id == 5 - assert model.merges[1].merge_new_unit_id == 6 + assert model.merges[0].new_unit_id == 5 + assert model.merges[1].new_unit_id == 6 # Test dictionary format valid_merge_dict = {"format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": {5: [1, 2], 6: [3, 4]}} model = CurationModel(**valid_merge_dict) assert 
len(model.merges) == 2 - merge_new_ids = {merge.merge_new_unit_id for merge in model.merges} + merge_new_ids = {merge.new_unit_id for merge in model.merges} assert merge_new_ids == {5, 6} # Test list format @@ -158,8 +158,8 @@ def test_split_units(): model = CurationModel(**valid_split_indices) assert len(model.splits) == 1 - assert model.splits[0].split_mode == "indices" - assert len(model.splits[0].split_indices) == 2 + assert model.splits[0].mode == "indices" + assert len(model.splits[0].indices) == 2 # Test labels mode with list format valid_split_labels = { @@ -172,8 +172,8 @@ def test_split_units(): model = CurationModel(**valid_split_labels) assert len(model.splits) == 1 - assert model.splits[0].split_mode == "labels" - assert len(set(model.splits[0].split_labels)) == 3 + assert model.splits[0].mode == "labels" + assert len(set(model.splits[0].labels)) == 3 # Test dictionary format with indices valid_split_dict = { @@ -187,7 +187,7 @@ def test_split_units(): model = CurationModel(**valid_split_dict) assert len(model.splits) == 2 - assert all(split.split_mode == "indices" for split in model.splits) + assert all(split.mode == "indices" for split in model.splits) # Test invalid unit ID invalid_unit_id = { From dbdac133604be6850a9b9b08b5344a275bcc9da1 Mon Sep 17 00:00:00 2001 From: Alessio Buccino Date: Thu, 12 Jun 2025 15:40:04 +0200 Subject: [PATCH 11/11] Clearer explanation on split data --- src/spikeinterface/curation/curation_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py index 6ece504ee3..175164f81d 100644 --- a/src/spikeinterface/curation/curation_model.py +++ b/src/spikeinterface/curation/curation_model.py @@ -156,7 +156,7 @@ def check_merges(cls, values): def check_splits(cls, values): """ Checks and validates the splits in the curation model. 
- If `splits` is a dictionary with unit_id as key and split data as values, + If `splits` is a dictionary with unit_id as key and split indices as values, it converts it to a list of Split objects. Each Split object is then validated: - Checks if the unit_id exists in the unit_ids list.