diff --git a/pyproject.toml b/pyproject.toml index e2c7b58d65..5a882b3a00 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,6 +28,7 @@ dependencies = [ "neo>=0.14.0", "probeinterface>=0.2.23", "packaging", + "pydantic", ] [build-system] diff --git a/src/spikeinterface/core/__init__.py b/src/spikeinterface/core/__init__.py index fb2e173b3e..fdcfc73c27 100644 --- a/src/spikeinterface/core/__init__.py +++ b/src/spikeinterface/core/__init__.py @@ -109,7 +109,12 @@ get_chunk_with_margin, order_channels_by_depth, ) -from .sorting_tools import spike_vector_to_spike_trains, random_spikes_selection, apply_merges_to_sorting +from .sorting_tools import ( + spike_vector_to_spike_trains, + random_spikes_selection, + apply_merges_to_sorting, + apply_splits_to_sorting, +) from .waveform_tools import extract_waveforms_to_buffers, estimate_templates, estimate_templates_with_accumulator from .snippets_tools import snippets_from_sorting diff --git a/src/spikeinterface/core/analyzer_extension_core.py b/src/spikeinterface/core/analyzer_extension_core.py index 447bbe562e..4834e864d5 100644 --- a/src/spikeinterface/core/analyzer_extension_core.py +++ b/src/spikeinterface/core/analyzer_extension_core.py @@ -94,6 +94,11 @@ def _merge_extension_data( new_data["random_spikes_indices"] = np.flatnonzero(selected_mask[keep_mask]) return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + new_data = dict() + new_data["random_spikes_indices"] = self.data["random_spikes_indices"].copy() + return new_data + def _get_data(self): return self.data["random_spikes_indices"] @@ -245,8 +250,6 @@ def _select_extension_data(self, unit_ids): def _merge_extension_data( self, merge_unit_groups, new_unit_ids, new_sorting_analyzer, keep_mask=None, verbose=False, **job_kwargs ): - new_data = dict() - waveforms = self.data["waveforms"] some_spikes = self.sorting_analyzer.get_extension("random_spikes").get_random_spikes() if keep_mask is not None: @@ -277,6 +280,11 @@ def _merge_extension_data( return dict(waveforms=waveforms) + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # splitting only affects random spikes, not waveforms + new_data = dict(waveforms=self.data["waveforms"].copy()) + return new_data + def get_waveforms_one_unit(self, unit_id, force_dense: bool = False): """ Returns the waveforms of a unit id. @@ -556,6 +564,42 @@ def _merge_extension_data( return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + new_data = dict() + for operator, arr in self.data.items(): + # we first copy the unsplit units + new_array = np.zeros((len(new_sorting_analyzer.unit_ids), arr.shape[1], arr.shape[2]), dtype=arr.dtype) + new_analyzer_unit_ids = list(new_sorting_analyzer.unit_ids) + unsplit_unit_ids = [unit_id for unit_id in self.sorting_analyzer.unit_ids if unit_id not in split_units] + new_indices = np.array([new_analyzer_unit_ids.index(unit_id) for unit_id in unsplit_unit_ids]) + old_indices = self.sorting_analyzer.sorting.ids_to_indices(unsplit_unit_ids) + new_array[new_indices, ...] = arr[old_indices, ...] 
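+            # units that are not split keep their previous rows through the old-to-new index mapping above;
+            # the rows of the new split units are filled below, recomputed from the "waveforms" extension
+            # when it is available, otherwise tiled from the template of the parent unit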
+
+            for split_unit_id, new_splits in zip(split_units, new_unit_ids):
+                if new_sorting_analyzer.has_extension("waveforms"):
+                    for new_unit_id in new_splits:
+                        split_unit_index = new_sorting_analyzer.sorting.id_to_index(new_unit_id)
+                        wfs = new_sorting_analyzer.get_extension("waveforms").get_waveforms_one_unit(
+                            new_unit_id, force_dense=True
+                        )
+
+                        if operator == "average":
+                            arr = np.average(wfs, axis=0)
+                        elif operator == "std":
+                            arr = np.std(wfs, axis=0)
+                        elif operator == "median":
+                            arr = np.median(wfs, axis=0)
+                        elif "percentile" in operator:
+                            _, percentile = operator.split("_")
+                            arr = np.percentile(wfs, float(percentile), axis=0)
+                        new_array[split_unit_index, ...] = arr
+                else:
+                    old_template = arr[self.sorting_analyzer.sorting.ids_to_indices([split_unit_id])[0], ...]
+                    new_indices = np.array([new_unit_ids.index(unit_id) for unit_id in new_splits])
+                    new_array[new_indices, ...] = np.tile(old_template, (len(new_splits), 1, 1))
+            new_data[operator] = new_array
+        return new_data
+
     def _get_data(self, operator="average", percentile=None, outputs="numpy"):
         if operator != "percentile":
             key = operator
@@ -729,6 +773,10 @@ def _merge_extension_data(
         # this does not depend on units
         return self.data.copy()

+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        # this does not depend on units
+        return self.data.copy()
+
     def _run(self, verbose=False):
         self.data["noise_levels"] = get_noise_levels(
             self.sorting_analyzer.recording, return_scaled=self.sorting_analyzer.return_scaled, **self.params
diff --git a/src/spikeinterface/core/sorting_tools.py b/src/spikeinterface/core/sorting_tools.py
index 47ce8cf848..24718d4e5d 100644
--- a/src/spikeinterface/core/sorting_tools.py
+++ b/src/spikeinterface/core/sorting_tools.py
@@ -231,9 +231,15 @@ def random_spikes_selection(
     return random_spikes_indices


+### MERGING ZONE ###
 def apply_merges_to_sorting(
-    sorting, merge_unit_groups, new_unit_ids=None, censor_ms=None, return_extra=False, new_id_strategy="append"
-):
+    sorting: BaseSorting,
+    merge_unit_groups: list[list[int | str]] | list[tuple[int | str]],
+    new_unit_ids: list[int | str] | None = None,
+    censor_ms: float | None = None,
+    return_extra: bool = False,
+    new_id_strategy: str = "append",
+) -> NumpySorting | tuple[NumpySorting, np.ndarray, list[int | str]]:
     """
     Apply a resolved representation of the merges to a sorting object.

@@ -245,9 +251,9 @@

     Parameters
     ----------
-    sorting : Sorting
+    sorting : BaseSorting
         The Sorting object to apply merges.
-    merge_unit_groups : list/tuple of lists/tuples
+    merge_unit_groups : list of lists/tuples
         A list of lists for every merge group. Each element needs to have at least two elements
         (two units to merge), but it can also have more (merge multiple units at once).
     new_unit_ids : list | None, default: None
@@ -440,3 +446,176 @@ def generate_unit_ids_for_merge_group(old_unit_ids, merge_unit_groups, new_unit_
         raise ValueError("wrong new_id_strategy")

     return new_unit_ids
+
+
+### SPLITTING ZONE ###
+def apply_splits_to_sorting(sorting, unit_splits, new_unit_ids=None, return_extra=False, new_id_strategy="append"):
+    spikes = sorting.to_spike_vector().copy()
+
+    # here we assume that the split_indices in unit_splits are already full.
+    # this is true when running via apply_curation
+
+    new_unit_ids = generate_unit_ids_for_split(
+        sorting.unit_ids, unit_splits, new_unit_ids=new_unit_ids, new_id_strategy=new_id_strategy
+    )
+    all_unit_ids = _get_ids_after_splitting(sorting.unit_ids, unit_splits, new_unit_ids)
+    all_unit_ids = list(all_unit_ids)
+
+    num_seg = sorting.get_num_segments()
+    seg_lims = np.searchsorted(spikes["segment_index"], np.arange(0, num_seg + 2))
+    segment_slices = [(seg_lims[i], seg_lims[i + 1]) for i in range(num_seg)]
+
+    # using this function avoids the mask approach and simplifies the algorithm a lot
+    spike_vector_list = [spikes[s0:s1] for s0, s1 in segment_slices]
+    spike_indices = spike_vector_to_indices(spike_vector_list, sorting.unit_ids, absolute_index=True)
+
+    for unit_id in sorting.unit_ids:
+        if unit_id in unit_splits:
+            split_indices = unit_splits[unit_id]
+            new_split_ids = new_unit_ids[list(unit_splits.keys()).index(unit_id)]
+
+            for split, new_unit_id in zip(split_indices, new_split_ids):
+                new_unit_index = all_unit_ids.index(new_unit_id)
+                # split_indices are a concatenation across segments with absolute indices
+                # so we need to concatenate the spike indices across segments
+                spike_indices_unit = np.concatenate(
+                    [spike_indices[segment_index][unit_id] for segment_index in range(num_seg)]
+                )
+                spikes["unit_index"][spike_indices_unit[split]] = new_unit_index
+        else:
+            new_unit_index = all_unit_ids.index(unit_id)
+            for segment_index in range(num_seg):
+                spike_inds = spike_indices[segment_index][unit_id]
+                spikes["unit_index"][spike_inds] = new_unit_index
+    sorting = NumpySorting(spikes, sorting.sampling_frequency, all_unit_ids)
+
+    if return_extra:
+        return sorting, new_unit_ids
+    else:
+        return sorting
+
+
+def generate_unit_ids_for_split(old_unit_ids, unit_splits, new_unit_ids=None, new_id_strategy="append"):
+    """
+    Function to generate new unit ids during a splitting procedure. If new_unit_ids
+    are provided, it will return these unit ids, checking that they have the same
+    length as the values of `unit_splits`.
+
+    Parameters
+    ----------
+    old_unit_ids : np.array
+        The old unit_ids.
+    unit_splits : dict
+        A dictionary with the unit ids to split as keys and the lists of split indices as values.
+    new_unit_ids : list | None, default: None
+        Optional new unit_ids for the split units. If given, it needs to have the same length as the
+        values of `unit_splits`. If None, new ids will be generated.
+    new_id_strategy : "append" | "split", default: "append"
+        The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids.
+
+            * "append" : new_unit_ids will be added at the end of max(sorting.unit_ids)
+            * "split" : new_unit_ids will be the original unit_id with a "-{index}" suffix.
+              Only works if unit_ids are str, otherwise it falls back to "append"
+
+    Returns
+    -------
+    new_unit_ids : list of lists
+        The new unit_ids associated with the splits.
+ """ + assert new_id_strategy in ["append", "split"], "new_id_strategy should be 'append' or 'split'" + old_unit_ids = np.asarray(old_unit_ids) + + if new_unit_ids is not None: + for split_unit, new_split_ids in zip(unit_splits.values(), new_unit_ids): + # then only doing a consistency check + assert len(split_unit) == len(new_split_ids), "new_unit_ids should have the same len as unit_splits.values" + # new_unit_ids can also be part of old_unit_ids only inside the same group: + assert all( + new_split_id not in old_unit_ids for new_split_id in new_split_ids + ), "new_unit_ids already exists but outside the split groups" + else: + dtype = old_unit_ids.dtype + new_unit_ids = [] + for unit_to_split, split_indices in unit_splits.items(): + num_splits = len(split_indices) + # select new_unit_ids greater that the max id, event greater than the numerical str ids + if new_id_strategy == "append": + if np.issubdtype(dtype, np.character): + # dtype str + if all(p.isdigit() for p in old_unit_ids): + # All str are digit : we can generate a max + m = max(int(p) for p in old_unit_ids) + 1 + new_unit_ids.append([str(m + i) for i in range(num_splits)]) + else: + # we cannot automatically find new names + new_unit_ids.append([f"split{i}" for i in range(num_splits)]) + else: + # dtype int + new_unit_ids.append(list(max(old_unit_ids) + 1 + np.arange(num_splits, dtype=dtype))) + old_unit_ids = np.concatenate([old_unit_ids, new_unit_ids[-1]]) + elif new_id_strategy == "split": + if np.issubdtype(dtype, np.character): + new_unit_ids.append([f"{unit_to_split}-{i}" for i in np.arange(len(split_indices))]) + else: + # dtype int + new_unit_ids.append(list(max(old_unit_ids) + 1 + np.arange(num_splits, dtype=dtype))) + old_unit_ids = np.concatenate([old_unit_ids, new_unit_ids[-1]]) + + return new_unit_ids + + +def _get_full_unit_splits(unit_splits, sorting): + # take care of single-list splits + full_unit_splits = {} + num_spikes = sorting.count_num_spikes_per_unit() + for unit_id, split_indices in unit_splits.items(): + if not isinstance(split_indices[0], (list, np.ndarray)): + split_2 = np.arange(num_spikes[unit_id]) + split_2 = split_2[~np.isin(split_2, split_indices)] + new_split_indices = [split_indices, split_2] + else: + new_split_indices = split_indices + full_unit_splits[unit_id] = new_split_indices + return full_unit_splits + + +def _get_ids_after_splitting(old_unit_ids, split_units, new_unit_ids): + """ + Function to get the list of unique unit_ids after some splits, with given new_units_ids would + be provided. + + Every new unit_id will be added at the end if not already present. + + Parameters + ---------- + old_unit_ids : np.array + The old unit_ids. + split_units : dict + A dict of split units. Each element needs to have at least two elements (two units to split). + new_unit_ids : list | None + A new unit_ids for split units. If given, it needs to have the same length as `split_units` values. 
+ + Returns + ------- + + all_unit_ids : The unit ids in the split sorting + The units_ids that will be present after splits + + """ + old_unit_ids = np.asarray(old_unit_ids) + dtype = old_unit_ids.dtype + if dtype.kind == "U": + # the new dtype can be longer + dtype = "U" + + assert len(new_unit_ids) == len(split_units), "new_unit_ids should have the same len as merge_unit_groups" + for new_unit_in_split, unit_to_split in zip(new_unit_ids, split_units.keys()): + assert len(new_unit_in_split) == len( + split_units[unit_to_split] + ), "new_unit_ids should have the same len as split_units values" + + all_unit_ids = list(old_unit_ids.copy()) + for split_unit, split_new_units in zip(split_units, new_unit_ids): + all_unit_ids.remove(split_unit) + all_unit_ids.extend(split_new_units) + return np.array(all_unit_ids, dtype=dtype) diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index bb4ee4db1c..d7b96b32b3 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -31,7 +31,13 @@ is_path_remote, clean_zarr_folder_name, ) -from .sorting_tools import generate_unit_ids_for_merge_group, _get_ids_after_merging +from .sorting_tools import ( + generate_unit_ids_for_merge_group, + _get_ids_after_merging, + generate_unit_ids_for_split, + _get_ids_after_splitting, + _get_full_unit_splits, +) from .job_tools import split_job_kwargs from .numpyextractors import NumpySorting from .sparsity import ChannelSparsity, estimate_sparsity @@ -867,17 +873,19 @@ def are_units_mergeable( else: return mergeable - def _save_or_select_or_merge( + def _save_or_select_or_merge_or_split( self, format="binary_folder", folder=None, unit_ids=None, merge_unit_groups=None, + split_units=None, censor_ms=None, merging_mode="soft", sparsity_overlap=0.75, verbose=False, - new_unit_ids=None, + merge_new_unit_ids=None, + split_new_unit_ids=None, backend_options=None, **job_kwargs, ) -> "SortingAnalyzer": @@ -896,6 +904,8 @@ def _save_or_select_or_merge( merge_unit_groups : list/tuple of lists/tuples or None, default: None A list of lists for every merge group. Each element needs to have at least two elements (two units to merge). If `merge_unit_groups` is not None, `new_unit_ids` must be given. + split_units : dict or None, default: None + A dictionary with the keys being the unit ids to split and the values being the split indices. censor_ms : None or float, default: None When merging units, any spikes violating this refractory period will be discarded. merging_mode : "soft" | "hard", default: "soft" @@ -904,8 +914,10 @@ def _save_or_select_or_merge( sparsity_overlap : float, default 0.75 The percentage of overlap that units should share in order to accept merges. If this criteria is not achieved, soft merging will not be performed. - new_unit_ids : list or None, default: None + merge_new_unit_ids : list or None, default: None The new unit ids for merged units. Required if `merge_unit_groups` is not None. + split_new_unit_ids : list or None, default: None + The new unit ids for split units. Required if `split_units` is not None. verbose : bool, default: False If True, output is verbose. 
backend_options : dict | None, default: None @@ -928,36 +940,63 @@ def _save_or_select_or_merge( else: recording = None - if self.sparsity is not None and unit_ids is None and merge_unit_groups is None: - sparsity = self.sparsity - elif self.sparsity is not None and unit_ids is not None and merge_unit_groups is None: - sparsity_mask = self.sparsity.mask[np.isin(self.unit_ids, unit_ids), :] - sparsity = ChannelSparsity(sparsity_mask, unit_ids, self.channel_ids) - elif self.sparsity is not None and merge_unit_groups is not None: - all_unit_ids = unit_ids - sparsity_mask = np.zeros((len(all_unit_ids), self.sparsity.mask.shape[1]), dtype=bool) - mergeable, masks = self.are_units_mergeable( - merge_unit_groups, - sparsity_overlap=sparsity_overlap, - return_masks=True, - ) + has_removed = unit_ids is not None + has_merges = merge_unit_groups is not None + has_splits = split_units is not None + assert not has_merges if has_splits else True, "Cannot merge and split at the same time" + + if self.sparsity is not None: + if not has_removed and not has_merges and not has_splits: + # no changes in units + sparsity = self.sparsity + elif has_removed and not has_merges and not has_splits: + # remove units + sparsity_mask = self.sparsity.mask[np.isin(self.unit_ids, unit_ids), :] + sparsity = ChannelSparsity(sparsity_mask, unit_ids, self.channel_ids) + elif has_merges: + # merge units + all_unit_ids = unit_ids + sparsity_mask = np.zeros((len(all_unit_ids), self.sparsity.mask.shape[1]), dtype=bool) + mergeable, masks = self.are_units_mergeable( + merge_unit_groups, + sparsity_overlap=sparsity_overlap, + return_masks=True, + ) - for unit_index, unit_id in enumerate(all_unit_ids): - if unit_id in new_unit_ids: - merge_unit_group = tuple(merge_unit_groups[new_unit_ids.index(unit_id)]) - if not mergeable[merge_unit_group]: - raise Exception( - f"The sparsity of {merge_unit_group} do not overlap enough for a soft merge using " - f"a sparsity threshold of {sparsity_overlap}. You can either lower the threshold or use " - "a hard merge." - ) + for unit_index, unit_id in enumerate(all_unit_ids): + if unit_id in merge_new_unit_ids: + merge_unit_group = tuple(merge_unit_groups[merge_new_unit_ids.index(unit_id)]) + if not mergeable[merge_unit_group]: + raise Exception( + f"The sparsity of {merge_unit_group} do not overlap enough for a soft merge using " + f"a sparsity threshold of {sparsity_overlap}. You can either lower the threshold or use " + "a hard merge." 
+ ) + else: + sparsity_mask[unit_index] = masks[merge_unit_group] else: - sparsity_mask[unit_index] = masks[merge_unit_group] - else: - # This means that the unit is already in the previous sorting - index = self.sorting.id_to_index(unit_id) - sparsity_mask[unit_index] = self.sparsity.mask[index] - sparsity = ChannelSparsity(sparsity_mask, list(all_unit_ids), self.channel_ids) + # This means that the unit is already in the previous sorting + index = self.sorting.id_to_index(unit_id) + sparsity_mask[unit_index] = self.sparsity.mask[index] + sparsity = ChannelSparsity(sparsity_mask, list(all_unit_ids), self.channel_ids) + elif has_splits: + # split units + all_unit_ids = unit_ids + original_unit_ids = self.unit_ids + sparsity_mask = np.zeros((len(all_unit_ids), self.sparsity.mask.shape[1]), dtype=bool) + for unit_index, unit_id in enumerate(all_unit_ids): + if unit_id not in original_unit_ids: + # then it is a new unit + # we assign the original sparsity + for split_unit, new_unit_ids in zip(split_units, split_new_unit_ids): + if unit_id in new_unit_ids: + original_unit_index = self.sorting.id_to_index(split_unit) + sparsity_mask[unit_index] = self.sparsity.mask[original_unit_index] + break + else: + original_unit_index = self.sorting.id_to_index(unit_id) + sparsity_mask[unit_index] = self.sparsity.mask[original_unit_index] + sparsity = ChannelSparsity(sparsity_mask, list(all_unit_ids), self.channel_ids) else: sparsity = None @@ -967,25 +1006,35 @@ def _save_or_select_or_merge( # if the original sorting object is not available anymore (kilosort folder deleted, ....), take the copy sorting_provenance = self.sorting - if merge_unit_groups is None: + if merge_unit_groups is None and split_units is None: # when only some unit_ids then the sorting must be sliced # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! sorting_provenance = sorting_provenance.select_units(unit_ids) - else: + elif merge_unit_groups is not None: + assert split_units is None, "split_units must be None when merge_unit_groups is None" from spikeinterface.core.sorting_tools import apply_merges_to_sorting sorting_provenance, keep_mask, _ = apply_merges_to_sorting( sorting=sorting_provenance, merge_unit_groups=merge_unit_groups, - new_unit_ids=new_unit_ids, + new_unit_ids=merge_new_unit_ids, censor_ms=censor_ms, return_extra=True, ) if censor_ms is None: # in this case having keep_mask None is faster instead of having a vector of ones keep_mask = None - # TODO: sam/pierre would create a curation field / curation.json with the applied merges. - # What do you think? + elif split_units is not None: + assert merge_unit_groups is None, "merge_unit_groups must be None when split_units is not None" + from spikeinterface.core.sorting_tools import apply_splits_to_sorting + + sorting_provenance = apply_splits_to_sorting( + sorting=sorting_provenance, + unit_splits=split_units, + new_unit_ids=split_new_unit_ids, + ) + # TODO: sam/pierre would create a curation field / curation.json with the applied merges. + # What do you think? 
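+            # note: unlike merges, a split only re-assigns spikes to new unit ids and no spikes are
+            # discarded, so no keep_mask / censor_ms handling is needed here.
+            # expected inputs at this point (hypothetical ids): split_units={5: [indices_a, indices_b]}
+            # together with split_new_unit_ids=[[8, 9]] re-labels the spikes of unit 5 as units 8 and 9.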
backend_options = {} if backend_options is None else backend_options @@ -1034,26 +1083,34 @@ def _save_or_select_or_merge( recompute_dict = {} for extension_name, extension in sorted_extensions.items(): - if merge_unit_groups is None: + if merge_unit_groups is None and split_units is None: # copy full or select new_sorting_analyzer.extensions[extension_name] = extension.copy( new_sorting_analyzer, unit_ids=unit_ids ) - else: + elif merge_unit_groups is not None: # merge if merging_mode == "soft": new_sorting_analyzer.extensions[extension_name] = extension.merge( new_sorting_analyzer, merge_unit_groups=merge_unit_groups, - new_unit_ids=new_unit_ids, + new_unit_ids=merge_new_unit_ids, keep_mask=keep_mask, verbose=verbose, **job_kwargs, ) elif merging_mode == "hard": recompute_dict[extension_name] = extension.params + else: + # split + try: + new_sorting_analyzer.extensions[extension_name] = extension.split( + new_sorting_analyzer, split_units=split_units, new_unit_ids=split_new_unit_ids, verbose=verbose + ) + except NotImplementedError: + recompute_dict[extension_name] = extension.params - if merge_unit_groups is not None and merging_mode == "hard" and len(recompute_dict) > 0: + if len(recompute_dict) > 0: new_sorting_analyzer.compute_several_extensions(recompute_dict, save=True, verbose=verbose, **job_kwargs) return new_sorting_analyzer @@ -1081,7 +1138,7 @@ def save_as(self, format="memory", folder=None, backend_options=None) -> "Sortin """ if format == "zarr": folder = clean_zarr_folder_name(folder) - return self._save_or_select_or_merge(format=format, folder=folder, backend_options=backend_options) + return self._save_or_select_or_merge_or_split(format=format, folder=folder, backend_options=backend_options) def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyzer": """ @@ -1108,7 +1165,7 @@ def select_units(self, unit_ids, format="memory", folder=None) -> "SortingAnalyz # TODO check that unit_ids are in same order otherwise many extension do handle it properly!!!! 
if format == "zarr": folder = clean_zarr_folder_name(folder) - return self._save_or_select_or_merge(format=format, folder=folder, unit_ids=unit_ids) + return self._save_or_select_or_merge_or_split(format=format, folder=folder, unit_ids=unit_ids) def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "SortingAnalyzer": """ @@ -1136,22 +1193,22 @@ def remove_units(self, remove_unit_ids, format="memory", folder=None) -> "Sortin unit_ids = self.unit_ids[~np.isin(self.unit_ids, remove_unit_ids)] if format == "zarr": folder = clean_zarr_folder_name(folder) - return self._save_or_select_or_merge(format=format, folder=folder, unit_ids=unit_ids) + return self._save_or_select_or_merge_or_split(format=format, folder=folder, unit_ids=unit_ids) def merge_units( self, - merge_unit_groups, - new_unit_ids=None, - censor_ms=None, - merging_mode="soft", - sparsity_overlap=0.75, - new_id_strategy="append", - return_new_unit_ids=False, - format="memory", - folder=None, - verbose=False, + merge_unit_groups: list[list[str | int]] | list[tuple[str | int]], + new_unit_ids: list[int | str] | None = None, + censor_ms: float | None = None, + merging_mode: str = "soft", + sparsity_overlap: float = 0.75, + new_id_strategy: str = "append", + return_new_unit_ids: bool = False, + format: str = "memory", + folder: Path | str | None = None, + verbose: bool = False, **job_kwargs, - ) -> "SortingAnalyzer": + ) -> "SortingAnalyzer | tuple[SortingAnalyzer, list[int | str]]": """ This method is equivalent to `save_as()` but with a list of merges that have to be achieved. Merges units by creating a new SortingAnalyzer object with the appropriate merges @@ -1222,7 +1279,7 @@ def merge_units( ) all_unit_ids = _get_ids_after_merging(self.unit_ids, merge_unit_groups, new_unit_ids=new_unit_ids) - new_analyzer = self._save_or_select_or_merge( + new_analyzer = self._save_or_select_or_merge_or_split( format=format, folder=folder, merge_unit_groups=merge_unit_groups, @@ -1231,7 +1288,81 @@ def merge_units( merging_mode=merging_mode, sparsity_overlap=sparsity_overlap, verbose=verbose, - new_unit_ids=new_unit_ids, + merge_new_unit_ids=new_unit_ids, + **job_kwargs, + ) + if return_new_unit_ids: + return new_analyzer, new_unit_ids + else: + return new_analyzer + + def split_units( + self, + split_units: dict[list[str | int], list[int] | list[list[int]]], + new_unit_ids: list[list[int | str]] | None = None, + new_id_strategy: str = "append", + return_new_unit_ids: bool = False, + format: str = "memory", + folder: Path | str | None = None, + verbose: bool = False, + **job_kwargs, + ) -> "SortingAnalyzer | tuple[SortingAnalyzer, list[int | str]]": + """ + This method is equivalent to `save_as()` but with a list of splits that have to be achieved. + Split units by creating a new SortingAnalyzer object with the appropriate splits + + Extensions are also updated to display the split `unit_ids`. + + Parameters + ---------- + split_units : dict + A dictionary with the keys being the unit ids to split and the values being the split indices. + new_unit_ids : None | list, default: None + A new unit_ids for split units. If given, it needs to have the same length as `merge_unit_groups`. If None, + merged units will have the first unit_id of every lists of merges + new_id_strategy : "append" | "split", default: "append" + The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. 
+ + * "append" : new_units_ids will be added at the end of max(sorting.unit_ids) + * "split" : new_unit_ids will be the original unit_id to split with -{subsplit} + return_new_unit_ids : bool, default False + Alse return new_unit_ids which are the ids of the new units. + folder : Path | None, default: None + The new folder where the analyzer with merged units is copied for `format` "binary_folder" or "zarr" + format : "memory" | "binary_folder" | "zarr", default: "memory" + The format of SortingAnalyzer + verbose : bool, default: False + Whether to display calculations (such as sparsity estimation) + + Returns + ------- + analyzer : SortingAnalyzer + The newly create `SortingAnalyzer` with the selected units + """ + + if format == "zarr": + folder = clean_zarr_folder_name(folder) + + if len(split_units) == 0: + # TODO I think we should raise an error or at least make a copy and not return itself + if return_new_unit_ids: + return self, [] + else: + return self + + # TODO: add some checks + split_units = _get_full_unit_splits(split_units, self.sorting) + + new_unit_ids = generate_unit_ids_for_split(self.unit_ids, split_units, new_unit_ids, new_id_strategy) + all_unit_ids = _get_ids_after_splitting(self.unit_ids, split_units, new_unit_ids=new_unit_ids) + + new_analyzer = self._save_or_select_or_merge_or_split( + format=format, + folder=folder, + split_units=split_units, + unit_ids=all_unit_ids, + verbose=verbose, + split_new_unit_ids=new_unit_ids, **job_kwargs, ) if return_new_unit_ids: @@ -1243,7 +1374,7 @@ def copy(self): """ Create a a copy of SortingAnalyzer with format "memory". """ - return self._save_or_select_or_merge(format="memory", folder=None) + return self._save_or_select_or_merge_or_split(format="memory", folder=None) def is_read_only(self) -> bool: if self.format == "memory": @@ -2048,6 +2179,10 @@ def _merge_extension_data( # must be implemented in subclass raise NotImplementedError + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # must be implemented in subclass + raise NotImplementedError + def _get_pipeline_nodes(self): # must be implemented in subclass only if use_nodepipeline=True raise NotImplementedError @@ -2283,6 +2418,23 @@ def merge( new_extension.save() return new_extension + def split( + self, + new_sorting_analyzer, + split_units, + new_unit_ids, + verbose=False, + **job_kwargs, + ): + new_extension = self.__class__(new_sorting_analyzer) + new_extension.params = self.params.copy() + new_extension.data = self._split_extension_data( + split_units, new_unit_ids, new_sorting_analyzer, verbose=verbose, **job_kwargs + ) + new_extension.run_info = copy(self.run_info) + new_extension.save() + return new_extension + def run(self, save=True, **kwargs): if save and not self.sorting_analyzer.is_read_only(): # NB: this call to _save_params() also resets the folder or zarr group diff --git a/src/spikeinterface/curation/curation_format.py b/src/spikeinterface/curation/curation_format.py index 80f251ca43..6bf4900077 100644 --- a/src/spikeinterface/curation/curation_format.py +++ b/src/spikeinterface/curation/curation_format.py @@ -1,14 +1,13 @@ -from itertools import combinations +from __future__ import annotations import numpy as np +from itertools import chain -from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting -import copy +from spikeinterface.core import BaseSorting, SortingAnalyzer, apply_merges_to_sorting, apply_splits_to_sorting +from 
spikeinterface.curation.curation_model import CurationModel -supported_curation_format_versions = {"1"} - -def validate_curation_dict(curation_dict): +def validate_curation_dict(curation_dict: dict): """ Validate that the curation dictionary given as parameter complies with the format @@ -19,119 +18,11 @@ def validate_curation_dict(curation_dict): curation_dict : dict """ + # this will validate the format of the curation_dict + CurationModel(**curation_dict) - # format - if "format_version" not in curation_dict: - raise ValueError("No version_format") - - if curation_dict["format_version"] not in supported_curation_format_versions: - raise ValueError( - f"Format version ({curation_dict['format_version']}) not supported. " - f"Only {supported_curation_format_versions} are valid" - ) - - # unit_ids - labeled_unit_set = set([lbl["unit_id"] for lbl in curation_dict["manual_labels"]]) - merged_units_set = set(sum(curation_dict["merge_unit_groups"], [])) - removed_units_set = set(curation_dict["removed_units"]) - - if curation_dict["unit_ids"] is not None: - # old format v0 did not contain unit_ids so this can contains None - unit_set = set(curation_dict["unit_ids"]) - if not labeled_unit_set.issubset(unit_set): - raise ValueError("Curation format: some labeled units are not in the unit list") - if not merged_units_set.issubset(unit_set): - raise ValueError("Curation format: some merged units are not in the unit list") - if not removed_units_set.issubset(unit_set): - raise ValueError("Curation format: some removed units are not in the unit list") - - for group in curation_dict["merge_unit_groups"]: - if len(group) < 2: - raise ValueError("Curation format: 'merge_unit_groups' must be list of list with at least 2 elements") - - all_merging_groups = [set(group) for group in curation_dict["merge_unit_groups"]] - for gp_1, gp_2 in combinations(all_merging_groups, 2): - if len(gp_1.intersection(gp_2)) != 0: - raise ValueError("Curation format: some units belong to multiple merge groups") - if len(removed_units_set.intersection(merged_units_set)) != 0: - raise ValueError("Curation format: some units were merged and deleted") - - # Check the labels exclusivity - for lbl in curation_dict["manual_labels"]: - for label_key in curation_dict["label_definitions"].keys(): - if label_key in lbl: - unit_id = lbl["unit_id"] - label_value = lbl[label_key] - if not isinstance(label_value, list): - raise ValueError(f"Curation format: manual_labels {unit_id} is invalid shoudl be a list") - - is_exclusive = curation_dict["label_definitions"][label_key]["exclusive"] - - if is_exclusive and not len(label_value) <= 1: - raise ValueError( - f"Curation format: manual_labels {unit_id} {label_key} are exclusive labels. {label_value} is invalid" - ) - - -def convert_from_sortingview_curation_format_v0(sortingview_dict, destination_format="1"): - """ - Converts the old sortingview curation format (v0) into a curation dictionary new format (v1) - Couple of caveats: - * The list of units is not available in the original sortingview dictionary. We set it to None - * Labels can not be mutually exclusive. - * Labels have no category, so we regroup them under the "all_labels" category - - Parameters - ---------- - sortingview_dict : dict - Dictionary containing the curation information from sortingview - destination_format : str - Version of the format to use. 
- Default to "1" - Returns - ------- - curation_dict : dict - A curation dictionary - """ - - assert destination_format == "1" - if "mergeGroups" not in sortingview_dict.keys(): - sortingview_dict["mergeGroups"] = [] - merge_groups = sortingview_dict["mergeGroups"] - merged_units = sum(merge_groups, []) - - first_unit_id = next(iter(sortingview_dict["labelsByUnit"].keys())) - if str.isdigit(first_unit_id): - unit_id_type = int - else: - unit_id_type = str - - all_units = [] - all_labels = [] - manual_labels = [] - general_cat = "all_labels" - for unit_id_, l_labels in sortingview_dict["labelsByUnit"].items(): - all_labels.extend(l_labels) - # recorver the correct type for unit_id - unit_id = unit_id_type(unit_id_) - all_units.append(unit_id) - manual_labels.append({"unit_id": unit_id, general_cat: l_labels}) - labels_def = {"all_labels": {"name": "all_labels", "label_options": list(set(all_labels)), "exclusive": False}} - - curation_dict = { - "format_version": destination_format, - "unit_ids": None, - "label_definitions": labels_def, - "manual_labels": manual_labels, - "merge_unit_groups": merge_groups, - "removed_units": [], - } - - return curation_dict - - -def curation_label_to_vectors(curation_dict): +def curation_label_to_vectors(curation_dict_or_model: dict | CurationModel): """ Transform the curation dict into dict of vectors. For label category with exclusive=True : a column is created and values are the unique label. @@ -141,66 +32,46 @@ def curation_label_to_vectors(curation_dict): Parameters ---------- - curation_dict : dict - A curation dictionary + curation_dict : dict or CurationModel + A curation dictionary or model Returns ------- labels : dict of numpy vector """ - unit_ids = list(curation_dict["unit_ids"]) + if isinstance(curation_dict_or_model, dict): + curation_model = CurationModel(**curation_dict_or_model) + else: + curation_model = curation_dict_or_model + unit_ids = list(curation_model.unit_ids) n = len(unit_ids) labels = {} - for label_key, label_def in curation_dict["label_definitions"].items(): - if label_def["exclusive"]: + for label_key, label_def in curation_model.label_definitions.items(): + if label_def.exclusive: assert label_key not in labels, f"{label_key} is already a key" labels[label_key] = [""] * n - for lbl in curation_dict["manual_labels"]: - value = lbl.get(label_key, []) - if len(value) == 1: - unit_index = unit_ids.index(lbl["unit_id"]) - labels[label_key][unit_index] = value[0] + for manual_label in curation_model.manual_labels: + values = manual_label.labels.get(label_key, []) + if len(values) == 1: + unit_index = unit_ids.index(manual_label.unit_id) + labels[label_key][unit_index] = values[0] labels[label_key] = np.array(labels[label_key]) else: - for label_opt in label_def["label_options"]: + for label_opt in label_def.label_options: assert label_opt not in labels, f"{label_opt} is already a key" labels[label_opt] = np.zeros(n, dtype=bool) - for lbl in curation_dict["manual_labels"]: - values = lbl.get(label_key, []) + for manual_label in curation_model.manual_labels: + values = manual_label.labels.get(label_key, []) for value in values: - unit_index = unit_ids.index(lbl["unit_id"]) + unit_index = unit_ids.index(manual_label.unit_id) labels[value][unit_index] = True - return labels -def clean_curation_dict(curation_dict): - """ - In some cases the curation_dict can have inconsistencies (like in the sorting view format). - For instance, some unit_ids are both in 'merge_unit_groups' and 'removed_units'. - This is ambiguous! 
- - This cleaner helper function ensures units tagged as `removed_units` are removed from the `merge_unit_groups` - """ - curation_dict = copy.deepcopy(curation_dict) - - clean_merge_unit_groups = [] - for group in curation_dict["merge_unit_groups"]: - clean_group = [] - for unit_id in group: - if unit_id not in curation_dict["removed_units"]: - clean_group.append(unit_id) - if len(clean_group) > 1: - clean_merge_unit_groups.append(clean_group) - - curation_dict["merge_unit_groups"] = clean_merge_unit_groups - return curation_dict - - -def curation_label_to_dataframe(curation_dict): +def curation_label_to_dataframe(curation_dict_or_model: dict | CurationModel): """ Transform the curation dict into a pandas dataframe. For label category with exclusive=True : a column is created and values are the unique label. @@ -220,38 +91,63 @@ def curation_label_to_dataframe(curation_dict): """ import pandas as pd - labels = pd.DataFrame(curation_label_to_vectors(curation_dict), index=curation_dict["unit_ids"]) + if isinstance(curation_dict_or_model, dict): + curation_model = CurationModel(**curation_dict_or_model) + else: + curation_model = curation_dict_or_model + + labels = pd.DataFrame(curation_label_to_vectors(curation_model), index=curation_model.unit_ids) return labels -def apply_curation_labels(sorting, new_unit_ids, curation_dict): +def apply_curation_labels( + sorting_or_analyzer: BaseSorting | SortingAnalyzer, curation_dict_or_model: dict | CurationModel +): """ - Apply manual labels after merges. + Apply manual labels after merges/splits. Rules: - * label for non merge is applied first + * label for non merged units is applied first * for merged group, when exclusive=True, if all have the same label then this label is applied * for merged group, when exclusive=False, if one unit has the label then the new one have also it + * for split units, the original label is applied to all split units """ + if isinstance(curation_dict_or_model, dict): + curation_model = CurationModel(**curation_dict_or_model) + else: + curation_model = curation_dict_or_model + + if isinstance(sorting_or_analyzer, BaseSorting): + sorting = sorting_or_analyzer + else: + sorting = sorting_or_analyzer.sorting # Please note that manual_labels is done on the unit_ids before the merge!!! 
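+    # (the same holds for splits): the `.index(unit_id)` lookups below are therefore done against
+    # curation_model.unit_ids, i.e. the unit ids as they were before any merge or split was applied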
- manual_labels = curation_label_to_vectors(curation_dict) + manual_labels = curation_label_to_vectors(curation_model) - # apply on non merged + # apply on non merged / split + merge_new_unit_ids = [m.merge_new_unit_id for m in curation_model.merges] + split_new_unit_ids = [m.split_new_unit_ids for m in curation_model.splits] + split_new_unit_ids = list(chain(*split_new_unit_ids)) + + merged_split_units = merge_new_unit_ids + split_new_unit_ids for key, values in manual_labels.items(): all_values = np.zeros(sorting.unit_ids.size, dtype=values.dtype) for unit_ind, unit_id in enumerate(sorting.unit_ids): - if unit_id not in new_unit_ids: - ind = list(curation_dict["unit_ids"]).index(unit_id) + if unit_id not in merged_split_units: + ind = list(curation_model.unit_ids).index(unit_id) all_values[unit_ind] = values[ind] sorting.set_property(key, all_values) - for new_unit_id, old_group_ids in zip(new_unit_ids, curation_dict["merge_unit_groups"]): - for label_key, label_def in curation_dict["label_definitions"].items(): - if label_def["exclusive"]: + # merges + for merge in curation_model.merges: + new_unit_id = merge.merge_new_unit_id + old_group_ids = merge.merge_unit_group + for label_key, label_def in curation_model.label_definitions.items(): + if label_def.exclusive: group_values = [] for unit_id in old_group_ids: - ind = curation_dict["unit_ids"].index(unit_id) + ind = list(curation_model.unit_ids).index(unit_id) value = manual_labels[label_key][ind] if value != "": group_values.append(value) @@ -259,25 +155,41 @@ def apply_curation_labels(sorting, new_unit_ids, curation_dict): # all group has the same label or empty sorting.set_property(key, values=group_values[:1], ids=[new_unit_id]) else: - - for key in label_def["label_options"]: + for key in label_def.label_options: group_values = [] for unit_id in old_group_ids: - ind = curation_dict["unit_ids"].index(unit_id) + ind = list(curation_model.unit_ids).index(unit_id) value = manual_labels[key][ind] group_values.append(value) new_value = np.any(group_values) sorting.set_property(key, values=[new_value], ids=[new_unit_id]) + # splits + for split in curation_model.splits: + # propagate property of splut unit to new units + old_unit = split.unit_id + new_unit_ids = split.split_new_unit_ids + for label_key, label_def in curation_model.label_definitions.items(): + if label_def.exclusive: + ind = list(curation_model.unit_ids).index(old_unit) + value = manual_labels[label_key][ind] + if value != "": + sorting.set_property(label_key, values=[value] * len(new_unit_ids), ids=new_unit_ids) + else: + for key in label_def.label_options: + ind = list(curation_model.unit_ids).index(old_unit) + value = manual_labels[key][ind] + sorting.set_property(key, values=[value] * len(new_unit_ids), ids=new_unit_ids) + def apply_curation( - sorting_or_analyzer, - curation_dict, - censor_ms=None, - new_id_strategy="append", - merging_mode="soft", - sparsity_overlap=0.75, - verbose=False, + sorting_or_analyzer: BaseSorting | SortingAnalyzer, + curation_dict_or_model: dict | CurationModel, + censor_ms: float | None = None, + new_id_strategy: str = "append", + merging_mode: str = "soft", + sparsity_overlap: float = 0.75, + verbose: bool = False, **job_kwargs, ): """ @@ -286,7 +198,8 @@ def apply_curation( Steps are done in this order: 1. Apply removal using curation_dict["removed_units"] 2. Apply merges using curation_dict["merge_unit_groups"] - 3. Set labels using curation_dict["manual_labels"] + 3. Apply splits using curation_dict["split_units"] + 4. 
Set labels using curation_dict["manual_labels"] A new Sorting or SortingAnalyzer (in memory) is returned. The user (an adult) has the responsability to save it somewhere (or not). @@ -294,17 +207,18 @@ def apply_curation( Parameters ---------- sorting_or_analyzer : Sorting | SortingAnalyzer - The Sorting object to apply merges. - curation_dict : dict - The curation dict. + The Sorting or SortingAnalyzer object to apply merges. + curation_dict : dict or CurationModel + The curation dict or model. censor_ms : float | None, default: None When applying the merges, any consecutive spikes within the `censor_ms` are removed. This can be thought of as the desired refractory period. If `censor_ms=None`, no spikes are discarded. - new_id_strategy : "append" | "take_first", default: "append" + new_id_strategy : "append" | "take_first" | "join", default: "append" The strategy that should be used, if `new_unit_ids` is None, to create new unit_ids. * "append" : new_units_ids will be added at the end of max(sorting.unit_ids) * "take_first" : new_unit_ids will be the first unit_id of every list of merges + * "join" : new_unit_ids will be the concatenation of all unit_ids of every list of merges merging_mode : "soft" | "hard", default: "soft" How merges are performed for SortingAnalyzer. If the `merge_mode` is "soft" , merges will be approximated, with no reloading of the waveforms. This will lead to approximations. If `merge_mode` is "hard", recomputations are accurately @@ -319,35 +233,46 @@ def apply_curation( Returns ------- - sorting_or_analyzer : Sorting | SortingAnalyzer + curated_sorting_or_analyzer : Sorting | SortingAnalyzer The curated object. - - """ - validate_curation_dict(curation_dict) - if not np.array_equal(np.asarray(curation_dict["unit_ids"]), sorting_or_analyzer.unit_ids): + assert isinstance( + sorting_or_analyzer, (BaseSorting, SortingAnalyzer) + ), f"`sorting_or_analyzer` must be a Sorting or a SortingAnalyzer, not an object of type {type(sorting_or_analyzer)}" + assert isinstance( + curation_dict_or_model, (dict, CurationModel) + ), f"`curation_dict_or_model` must be a dict or a CurationModel, not an object of type {type(curation_dict_or_model)}" + if isinstance(curation_dict_or_model, dict): + curation_model = CurationModel(**curation_dict_or_model) + else: + curation_model = curation_dict_or_model + + if not np.array_equal(np.asarray(curation_model.unit_ids), sorting_or_analyzer.unit_ids): raise ValueError("unit_ids from the curation_dict do not match the one from Sorting or SortingAnalyzer") - if isinstance(sorting_or_analyzer, BaseSorting): - sorting = sorting_or_analyzer - sorting = sorting.remove_units(curation_dict["removed_units"]) - sorting, _, new_unit_ids = apply_merges_to_sorting( - sorting, - curation_dict["merge_unit_groups"], - censor_ms=censor_ms, - return_extra=True, - new_id_strategy=new_id_strategy, - ) - apply_curation_labels(sorting, new_unit_ids, curation_dict) - return sorting - - elif isinstance(sorting_or_analyzer, SortingAnalyzer): - analyzer = sorting_or_analyzer - if len(curation_dict["removed_units"]) > 0: - analyzer = analyzer.remove_units(curation_dict["removed_units"]) - if len(curation_dict["merge_unit_groups"]) > 0: - analyzer, new_unit_ids = analyzer.merge_units( - curation_dict["merge_unit_groups"], + # 1. Remove units + if len(curation_model.removed) > 0: + curated_sorting_or_analyzer = sorting_or_analyzer.remove_units(curation_model.removed) + else: + curated_sorting_or_analyzer = sorting_or_analyzer + + # 2. 
Merge units + if len(curation_model.merges) > 0: + merge_unit_groups = [m.merge_unit_group for m in curation_model.merges] + merge_new_unit_ids = [m.merge_new_unit_id for m in curation_model.merges if m.merge_new_unit_id is not None] + if len(merge_new_unit_ids) == 0: + merge_new_unit_ids = None + if isinstance(sorting_or_analyzer, BaseSorting): + curated_sorting_or_analyzer, _, new_unit_ids = apply_merges_to_sorting( + curated_sorting_or_analyzer, + merge_unit_groups=merge_unit_groups, + censor_ms=censor_ms, + new_id_strategy=new_id_strategy, + return_extra=True, + ) + else: + curated_sorting_or_analyzer, new_unit_ids = curated_sorting_or_analyzer.merge_units( + merge_unit_groups=merge_unit_groups, censor_ms=censor_ms, merging_mode=merging_mode, sparsity_overlap=sparsity_overlap, @@ -357,11 +282,43 @@ def apply_curation( verbose=verbose, **job_kwargs, ) + for i, merge_unit_id in enumerate(new_unit_ids): + curation_model.merges[i].merge_new_unit_id = merge_unit_id + + # 3. Split units + if len(curation_model.splits) > 0: + split_units = {} + for split in curation_model.splits: + sorting = ( + curated_sorting_or_analyzer + if isinstance(sorting_or_analyzer, BaseSorting) + else sorting_or_analyzer.sorting + ) + split_units[split.unit_id] = split.get_full_spike_indices(sorting) + split_new_unit_ids = [s.split_new_unit_ids for s in curation_model.splits if s.split_new_unit_ids is not None] + if len(split_new_unit_ids) == 0: + split_new_unit_ids = None + if isinstance(sorting_or_analyzer, BaseSorting): + curated_sorting_or_analyzer, new_unit_ids = apply_splits_to_sorting( + curated_sorting_or_analyzer, + split_units, + new_unit_ids=split_new_unit_ids, + new_id_strategy=new_id_strategy, + return_extra=True, + ) else: - new_unit_ids = [] - apply_curation_labels(analyzer.sorting, new_unit_ids, curation_dict) - return analyzer - else: - raise TypeError( - f"`sorting_or_analyzer` must be a Sorting or a SortingAnalyzer, not an object of type {type(sorting_or_analyzer)}" - ) + curated_sorting_or_analyzer, new_unit_ids = curated_sorting_or_analyzer.split_units( + split_units, + new_id_strategy=new_id_strategy, + return_new_unit_ids=True, + new_unit_ids=split_new_unit_ids, + format="memory", + verbose=verbose, + ) + for i, split_unit_ids in enumerate(new_unit_ids): + curation_model.splits[i].split_new_unit_ids = split_unit_ids + + # 4. 
Apply labels + apply_curation_labels(curated_sorting_or_analyzer, curation_model) + + return curated_sorting_or_analyzer diff --git a/src/spikeinterface/curation/curation_model.py b/src/spikeinterface/curation/curation_model.py new file mode 100644 index 0000000000..55b97bdb71 --- /dev/null +++ b/src/spikeinterface/curation/curation_model.py @@ -0,0 +1,388 @@ +from pydantic import BaseModel, Field, model_validator, field_validator +from typing import List, Dict, Union, Optional, Literal, Tuple +from itertools import chain, combinations +import numpy as np + +from spikeinterface import BaseSorting + + +class LabelDefinition(BaseModel): + name: str = Field(..., description="Name of the label") + label_options: List[str] = Field(..., description="List of possible label options", min_length=2) + exclusive: bool = Field(..., description="Whether the label is exclusive") + + +class ManualLabel(BaseModel): + unit_id: Union[int, str] = Field(..., description="ID of the unit") + labels: Dict[str, List[str]] = Field(..., description="Dictionary of labels for the unit") + + +class Merge(BaseModel): + merge_unit_group: List[Union[int, str]] = Field(..., description="List of groups of units to be merged") + merge_new_unit_id: Optional[Union[int, str]] = Field(default=None, description="New unit IDs for the merge group") + + +class Split(BaseModel): + unit_id: Union[int, str] = Field(..., description="ID of the unit") + split_mode: Literal["indices", "labels"] = Field( + default="indices", + description=( + "Mode of the split. The split can be defined by indices or labels. " + "If indices, the split is defined by the a list of lists of indices of spikes within spikes " + "belonging to the unit (`split_indices`). " + "If labels, the split is defined by a list of labels for each spike (`split_labels`). " + ), + ) + split_indices: Optional[Union[List[List[int]]]] = Field(default=None, description="List of indices for the split") + split_labels: Optional[List[int]] = Field(default=None, description="List of labels for the split") + split_new_unit_ids: Optional[List[Union[int, str]]] = Field( + default=None, description="List of new unit IDs for each split" + ) + + def get_full_spike_indices(self, sorting: BaseSorting): + """ + Get the full indices of the spikes in the split for different split modes. 
+ """ + num_spikes = sorting.count_num_spikes_per_unit()[self.unit_id] + if self.split_mode == "indices": + # check the sum of split_indices is equal to num_spikes + num_spikes_in_split = sum(len(indices) for indices in self.split_indices) + if num_spikes_in_split != num_spikes: + # add remaining spike indices + full_spike_indices = list(self.split_indices) + existing_indices = np.concatenate(self.split_indices) + remaining_indices = np.setdiff1d(np.arange(num_spikes), existing_indices) + full_spike_indices.append(remaining_indices) + else: + full_spike_indices = self.split_indices + elif self.split_mode == "labels": + assert len(self.split_labels) == num_spikes, ( + f"In 'labels' mode, the number of split_labels ({len(self.split_labels)}) " + f"must match the number of spikes in the unit ({num_spikes})" + ) + # convert to spike indices + full_spike_indices = [] + for label in np.unique(self.split_labels): + label_indices = np.where(self.split_labels == label)[0] + full_spike_indices.append(label_indices) + + return full_spike_indices + + +class CurationModel(BaseModel): + supported_versions: Tuple[Literal["1"], Literal["2"]] = Field( + default=["1", "2"], description="Supported versions of the curation format" + ) + format_version: str = Field(..., description="Version of the curation format") + unit_ids: List[Union[int, str]] = Field(..., description="List of unit IDs") + label_definitions: Optional[Dict[str, LabelDefinition]] = Field( + default=None, description="Dictionary of label definitions" + ) + manual_labels: Optional[List[ManualLabel]] = Field(default=None, description="List of manual labels") + removed: Optional[List[Union[int, str]]] = Field(default=None, description="List of removed unit IDs") + merges: Optional[List[Merge]] = Field(default=None, description="List of merges") + splits: Optional[List[Split]] = Field(default=None, description="List of splits") + + @field_validator("label_definitions", mode="before") + def add_label_definition_name(cls, label_definitions): + if label_definitions is None: + return {} + if isinstance(label_definitions, dict): + label_definitions = dict(label_definitions) + for key in list(label_definitions.keys()): + if isinstance(label_definitions[key], dict): + label_definitions[key] = dict(label_definitions[key]) + label_definitions[key]["name"] = key + return label_definitions + return label_definitions + + @classmethod + def check_manual_labels(cls, values): + unit_ids = list(values["unit_ids"]) + manual_labels = values.get("manual_labels") + if manual_labels is None: + values["manual_labels"] = [] + else: + manual_labels = list(manual_labels) + for i, manual_label in enumerate(manual_labels): + manual_label = dict(manual_label) + unit_id = manual_label["unit_id"] + labels = manual_label.get("labels") + if labels is None: + labels = set(manual_label.keys()) - {"unit_id"} + manual_label["labels"] = {} + else: + manual_label["labels"] = {k: list(v) for k, v in labels.items()} + for label in labels: + if label not in values["label_definitions"]: + raise ValueError(f"Manual label {unit_id} has an unknown label {label}") + if label not in manual_label["labels"]: + if label in manual_label: + manual_label["labels"][label] = list(manual_label[label]) + else: + raise ValueError(f"Manual label {unit_id} has no value for label {label}") + if unit_id not in unit_ids: + raise ValueError(f"Manual label unit_id {unit_id} is not in the unit list") + manual_labels[i] = manual_label + values["manual_labels"] = manual_labels + return values + + @classmethod + 
def check_merges(cls, values): + unit_ids = list(values["unit_ids"]) + merges = values.get("merges") + if merges is None: + values["merges"] = [] + return values + + if isinstance(merges, dict): + # Convert dict format to list of Merge objects + merge_list = [] + for merge_new_id, merge_group in merges.items(): + merge_list.append({"merge_unit_group": list(merge_group), "merge_new_unit_id": merge_new_id}) + merges = merge_list + + # Make a copy of the list + merges = list(merges) + + # Convert items to Merge objects + for i, merge in enumerate(merges): + if isinstance(merge, list): + merge = {"merge_unit_group": list(merge)} + if isinstance(merge, dict): + merge = dict(merge) + if "merge_unit_group" in merge: + merge["merge_unit_group"] = list(merge["merge_unit_group"]) + merges[i] = Merge(**merge) + + # Validate merges + for merge in merges: + # Check unit ids exist + for unit_id in merge.merge_unit_group: + if unit_id not in unit_ids: + raise ValueError(f"Merge unit group unit_id {unit_id} is not in the unit list") + + # Check minimum group size + if len(merge.merge_unit_group) < 2: + raise ValueError("Merge unit groups must have at least 2 elements") + + # Check new unit id not already used + if merge.merge_new_unit_id is not None: + if merge.merge_new_unit_id in unit_ids: + raise ValueError(f"New unit ID {merge.merge_new_unit_id} is already in the unit list") + + values["merges"] = merges + return values + + @classmethod + def check_splits(cls, values): + unit_ids = list(values["unit_ids"]) + splits = values.get("splits") + if splits is None: + values["splits"] = [] + return values + + # Convert dict format to list format + if isinstance(splits, dict): + split_list = [] + for unit_id, split_data in splits.items(): + if isinstance(split_data[0], (list, np.ndarray)) if split_data else False: + split_list.append( + { + "unit_id": unit_id, + "split_mode": "indices", + "split_indices": [list(indices) for indices in split_data], + } + ) + else: + split_list.append({"unit_id": unit_id, "split_mode": "labels", "split_labels": list(split_data)}) + splits = split_list + + # Make a copy of the list + splits = list(splits) + + # Convert items to Split objects + for i, split in enumerate(splits): + if isinstance(split, dict): + split = dict(split) + if "split_indices" in split: + split["split_indices"] = [list(indices) for indices in split["split_indices"]] + if "split_labels" in split: + split["split_labels"] = list(split["split_labels"]) + if "split_new_unit_ids" in split: + split["split_new_unit_ids"] = list(split["split_new_unit_ids"]) + splits[i] = Split(**split) + + # Validate splits + for split in splits: + # Check unit exists + if split.unit_id not in unit_ids: + raise ValueError(f"Split unit_id {split.unit_id} is not in the unit list") + + # Validate based on mode + if split.split_mode == "indices": + if split.split_indices is None: + raise ValueError(f"Split unit {split.unit_id} has no split_indices defined") + if len(split.split_indices) < 1: + raise ValueError(f"Split unit {split.unit_id} has empty split_indices") + # Check no duplicate indices + all_indices = list(chain.from_iterable(split.split_indices)) + if len(all_indices) != len(set(all_indices)): + raise ValueError(f"Split unit {split.unit_id} has duplicate indices") + + elif split.split_mode == "labels": + if split.split_labels is None: + raise ValueError(f"Split unit {split.unit_id} has no split_labels defined") + if len(split.split_labels) == 0: + raise ValueError(f"Split unit {split.unit_id} has empty split_labels") + + # 
Validate new unit IDs + if split.split_new_unit_ids is not None: + if split.split_mode == "indices": + if len(split.split_new_unit_ids) != len(split.split_indices): + raise ValueError( + f"Number of new unit IDs does not match number of splits for unit {split.unit_id}" + ) + elif split.split_mode == "labels": + if len(split.split_new_unit_ids) != len(set(split.split_labels)): + raise ValueError( + f"Number of new unit IDs does not match number of unique labels for unit {split.unit_id}" + ) + + for new_id in split.split_new_unit_ids: + if new_id in unit_ids: + raise ValueError(f"New unit ID {new_id} is already in the unit list") + + values["splits"] = splits + return values + + @classmethod + def check_removed(cls, values): + unit_ids = list(values["unit_ids"]) + removed = values.get("removed") + if removed is None: + values["removed"] = [] + else: + removed = list(removed) + for unit_id in removed: + if unit_id not in unit_ids: + raise ValueError(f"Removed unit_id {unit_id} is not in the unit list") + values["removed"] = removed + return values + + @classmethod + def convert_old_format(cls, values): + format_version = values.get("format_version", "0") + if format_version == "0": + print("Conversion from format version v0 (sortingview) to v2") + if "mergeGroups" not in values.keys(): + values["mergeGroups"] = [] + merge_groups = values["mergeGroups"] + + first_unit_id = next(iter(values["labelsByUnit"].keys())) + if str.isdigit(first_unit_id): + unit_id_type = int + else: + unit_id_type = str + + all_units = [] + all_labels = [] + manual_labels = [] + general_cat = "all_labels" + for unit_id_, l_labels in values["labelsByUnit"].items(): + all_labels.extend(l_labels) + unit_id = unit_id_type(unit_id_) + if unit_id not in all_units: + all_units.append(unit_id) + manual_labels.append({"unit_id": unit_id, general_cat: list(l_labels)}) + labels_def = { + "all_labels": {"name": "all_labels", "label_options": list(set(all_labels)), "exclusive": False} + } + for merge_group in merge_groups: + all_units.extend(merge_group) + all_units = list(set(all_units)) + + values = { + "format_version": "2", + "unit_ids": values.get("unit_ids", all_units), + "label_definitions": labels_def, + "manual_labels": list(manual_labels), + "merges": [{"merge_unit_group": list(group)} for group in merge_groups], + "splits": [], + "removed": [], + } + elif values["format_version"] == "1": + merge_unit_groups = values.get("merge_unit_groups") + if merge_unit_groups is not None: + values["merges"] = [{"merge_unit_group": list(group)} for group in merge_unit_groups] + removed_units = values.get("removed_units") + if removed_units is not None: + values["removed"] = list(removed_units) + return values + + @model_validator(mode="before") + def validate_fields(cls, values): + values = dict(values) + values["label_definitions"] = values.get("label_definitions", {}) + values = cls.convert_old_format(values) + values = cls.check_manual_labels(values) + values = cls.check_merges(values) + values = cls.check_splits(values) + values = cls.check_removed(values) + return values + + @model_validator(mode="after") + def validate_curation_dict(cls, values): + if values.format_version not in values.supported_versions: + raise ValueError( + f"Format version {values.format_version} not supported. 
Only {values.supported_versions} are valid" + ) + + labeled_unit_set = set([lbl.unit_id for lbl in values.manual_labels]) if values.manual_labels else set() + merged_units_set = ( + set(chain.from_iterable(merge.merge_unit_group for merge in values.merges)) if values.merges else set() + ) + split_units_set = set(split.unit_id for split in values.splits) if values.splits else set() + removed_set = set(values.removed) if values.removed else set() + unit_ids = values.unit_ids + + unit_set = set(unit_ids) + if not labeled_unit_set.issubset(unit_set): + raise ValueError("Curation format: some labeled units are not in the unit list") + if not merged_units_set.issubset(unit_set): + raise ValueError("Curation format: some merged units are not in the unit list") + if not split_units_set.issubset(unit_set): + raise ValueError("Curation format: some split units are not in the unit list") + if not removed_set.issubset(unit_set): + raise ValueError("Curation format: some removed units are not in the unit list") + + # Check for units being merged multiple times + all_merging_groups = [set(merge.merge_unit_group) for merge in values.merges] if values.merges else [] + for gp_1, gp_2 in combinations(all_merging_groups, 2): + if len(gp_1.intersection(gp_2)) != 0: + raise ValueError("Curation format: some units belong to multiple merge groups") + + # Check no overlaps between operations + if len(removed_set.intersection(merged_units_set)) != 0: + raise ValueError("Curation format: some units were merged and deleted") + if len(removed_set.intersection(split_units_set)) != 0: + raise ValueError("Curation format: some units were split and deleted") + if len(merged_units_set.intersection(split_units_set)) != 0: + raise ValueError("Curation format: some units were both merged and split") + + for manual_label in values.manual_labels: + for label_key in values.label_definitions.keys(): + if label_key in manual_label.labels: + unit_id = manual_label.unit_id + label_value = manual_label.labels[label_key] + if not isinstance(label_value, list): + raise ValueError(f"Curation format: manual_labels {unit_id} is invalid should be a list") + + is_exclusive = values.label_definitions[label_key].exclusive + + if is_exclusive and not len(label_value) <= 1: + raise ValueError( + f"Curation format: manual_labels {unit_id} {label_key} are exclusive labels. 
{label_value} is invalid" + ) + + return values diff --git a/src/spikeinterface/curation/sortingview_curation.py b/src/spikeinterface/curation/sortingview_curation.py index f33051309c..fe21b72263 100644 --- a/src/spikeinterface/curation/sortingview_curation.py +++ b/src/spikeinterface/curation/sortingview_curation.py @@ -7,13 +7,11 @@ import numpy as np from pathlib import Path -from .curationsorting import CurationSorting from .curation_format import ( - convert_from_sortingview_curation_format_v0, apply_curation, curation_label_to_vectors, - clean_curation_dict, ) +from .curation_model import CurationModel, Merge def get_kachery(): @@ -82,15 +80,14 @@ def apply_sortingview_curation( except: raise Exception(f"Could not retrieve curation from SortingView uri: {uri_or_json}") - # convert to new format - if "format_version" not in curation_dict: - curation_dict = convert_from_sortingview_curation_format_v0(curation_dict) - unit_ids = sorting_or_analyzer.unit_ids + curation_dict["unit_ids"] = unit_ids + curation_model = CurationModel(**curation_dict) - # this is a hack because it was not in the old format - curation_dict["unit_ids"] = list(unit_ids) + if skip_merge: + curation_model.merges = [] + # this is a hack because it was not in the old format if exclude_labels is not None: assert include_labels is None, "Use either `include_labels` or `exclude_labels` to filter units." manual_labels = curation_label_to_vectors(curation_dict) @@ -99,7 +96,7 @@ def apply_sortingview_curation( remove_mask = manual_labels[k] removed_units.extend(unit_ids[remove_mask]) removed_units = np.unique(removed_units) - curation_dict["removed_units"] = removed_units + curation_model.removed = removed_units if include_labels is not None: manual_labels = curation_label_to_vectors(curation_dict) @@ -108,122 +105,21 @@ def apply_sortingview_curation( remove_mask = ~manual_labels[k] removed_units.extend(unit_ids[remove_mask]) removed_units = np.unique(removed_units) - curation_dict["removed_units"] = removed_units - - if skip_merge: - curation_dict["merge_unit_groups"] = [] - - # cleaner to ensure validity - curation_dict = clean_curation_dict(curation_dict) - - # apply - sorting_curated = apply_curation(sorting_or_analyzer, curation_dict, new_id_strategy="join") + curation_model.removed = removed_units + + # make merges and removed units + if len(curation_model.removed) > 0: + clean_merges = [] + for merge in curation_model.merges: + clean_merge = [] + for unit_id in merge.merge_unit_group: + if unit_id not in curation_model.removed: + clean_merge.append(unit_id) + if len(clean_merge) > 1: + clean_merges.append(Merge(merge_unit_group=clean_merge)) + curation_model.merges = clean_merges + + # apply curation + sorting_curated = apply_curation(sorting_or_analyzer, curation_model, new_id_strategy="join") return sorting_curated - - -# TODO @alessio you remove this after testing -def apply_sortingview_curation_legacy( - sorting, uri_or_json, exclude_labels=None, include_labels=None, skip_merge=False, verbose=False -): - """ - Apply curation from SortingView manual curation. - First, merges (if present) are applied. Then labels are loaded and units - are optionally filtered based on exclude_labels and include_labels. - - Parameters - ---------- - sorting : BaseSorting - The sorting object to be curated - uri_or_json : str or Path - The URI curation link from SortingView or the path to the curation json file - exclude_labels : list, default: None - Optional list of labels to exclude (e.g. ["reject", "noise"]). 
- Mutually exclusive with include_labels - include_labels : list, default: None - Optional list of labels to include (e.g. ["accept"]). - Mutually exclusive with exclude_labels, by default None - skip_merge : bool, default: False - If True, merges are not applied (only labels) - verbose : bool, default: False - If True, output is verbose - - Returns - ------- - sorting_curated : BaseSorting - The curated sorting - """ - ka = get_kachery() - curation_sorting = CurationSorting(sorting, make_graph=False, properties_policy="keep") - - # get sorting view curation - if Path(uri_or_json).suffix == ".json" and not str(uri_or_json).startswith("gh://"): - with open(uri_or_json, "r") as f: - sortingview_curation_dict = json.load(f) - else: - try: - sortingview_curation_dict = ka.load_json(uri=uri_or_json) - except: - raise Exception(f"Could not retrieve curation from SortingView uri: {uri_or_json}") - - unit_ids_dtype = sorting.unit_ids.dtype - - # STEP 1: merge groups - labels_dict = sortingview_curation_dict["labelsByUnit"] - if "mergeGroups" in sortingview_curation_dict and not skip_merge: - merge_groups = sortingview_curation_dict["mergeGroups"] - for merge_group in merge_groups: - # Store labels of units that are about to be merged - labels_to_inherit = [] - for unit in merge_group: - labels_to_inherit.extend(labels_dict.get(str(unit), [])) - labels_to_inherit = list(set(labels_to_inherit)) # Remove duplicates - - if verbose: - print(f"Merging {merge_group}") - if unit_ids_dtype.kind in ("U", "S"): - merge_group = [str(unit) for unit in merge_group] - # if unit dtype is str, set new id as "{unit1}-{unit2}" - new_unit_id = "-".join(merge_group) - curation_sorting.merge(merge_group, new_unit_id=new_unit_id) - else: - # in this case, the CurationSorting takes care of finding a new unused int - curation_sorting.merge(merge_group, new_unit_id=None) - new_unit_id = curation_sorting.max_used_id # merged unit id - labels_dict[str(new_unit_id)] = labels_to_inherit - - # STEP 2: gather and apply sortingview curation labels - # In sortingview, a unit is not required to have all labels. - # For example, the first 3 units could be labeled as "accept". - # In this case, the first 3 values of the property "accept" will be True, the rest False - - # Initialize the properties dictionary - properties = { - label: np.zeros(len(curation_sorting.current_sorting.unit_ids), dtype=bool) - for labels in labels_dict.values() - for label in labels - } - - # Populate the properties dictionary - for unit_index, unit_id in enumerate(curation_sorting.current_sorting.unit_ids): - unit_id_str = str(unit_id) - if unit_id_str in labels_dict: - for label in labels_dict[unit_id_str]: - properties[label][unit_index] = True - - for prop_name, prop_values in properties.items(): - curation_sorting.current_sorting.set_property(prop_name, prop_values) - - if include_labels is not None or exclude_labels is not None: - units_to_remove = [] - unit_ids = curation_sorting.current_sorting.unit_ids - assert include_labels or exclude_labels, "Use either `include_labels` or `exclude_labels` to filter units." 
-        if include_labels:
-            for include_label in include_labels:
-                units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(include_label) == False])
-        if exclude_labels:
-            for exclude_label in exclude_labels:
-                units_to_remove.extend(unit_ids[curation_sorting.current_sorting.get_property(exclude_label) == True])
-        units_to_remove = np.unique(units_to_remove)
-        curation_sorting.remove_units(units_to_remove)
-    return curation_sorting.current_sorting
diff --git a/src/spikeinterface/curation/tests/test_curation_format.py b/src/spikeinterface/curation/tests/test_curation_format.py
index af9d8e1eac..d0126d5460 100644
--- a/src/spikeinterface/curation/tests/test_curation_format.py
+++ b/src/spikeinterface/curation/tests/test_curation_format.py
@@ -8,15 +8,34 @@
 from spikeinterface.curation.curation_format import (
     validate_curation_dict,
-    convert_from_sortingview_curation_format_v0,
     curation_label_to_vectors,
     curation_label_to_dataframe,
     apply_curation,
 )
+"""
+v1 = {
+    'format_version': '1',
+    'unit_ids': List[int | str],
+    'label_definitions': {
+        'category_key1':
+            {
+                'label_options': List[str],
+                'exclusive': bool}
+    },
+    'manual_labels': [
+        {
+            'unit_id': str or int,
+            'category_key1': List[str],
+        }
+    ],
+    'removed_units': List[int | str],  # Can not be in the merged_units
+    'merge_unit_groups': List[List[int | str]],  # one cell goes into at most one list
+}
-"""example = {
-    'unit_ids': List[str, int],
+v2 = {
+    'format_version': '2',
+    'unit_ids': List[int | str],
     'label_definitions': {
         'category_key1':
             {
@@ -24,18 +43,38 @@
                 'exclusive': bool}
     },
     'manual_labels': [
-        {'unit_id': str or int,
-         category_key1': List[str],
+        {
+            'unit_id': str | int,
+            'category_key1': List[str],
         }
     ],
-    'merge_unit_groups': List[List[unit_ids]],  # one cell goes into at most one list
-    'removed_units': List[unit_ids]  # Can not be in the merged_units
-}
-"""
+    'removed': List[unit_ids],  # Can not be in the merged_units
+    'merges': [
+        {
+            'merge_unit_group': List[unit_ids],
+            'merge_new_unit_id': int | str (optional)
+        }
+    ],
+    'splits': [
+        {
+            'unit_id': int | str,
+            'split_mode': 'indices' or 'labels',
+            'split_indices': List[List[int]],
+            'split_labels': List[int],
+            'split_new_unit_ids': List[int | str]
+        }
+    ]
+}
+
+sortingview_curation = {
+    'mergeGroups': List[List[int | str]],
+    'labelsByUnit': {
+        'unit_id': List[str]
+    }
+}
+"""

 curation_ids_int = {
-    "format_version": "1",
+    "format_version": "2",
     "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42],
     "label_definitions": {
         "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True},
@@ -48,19 +87,21 @@
         {"unit_id": 1, "quality": ["good"]},
         {
             "unit_id": 2,
-            "quality": [
-                "noise",
-            ],
+            "quality": ["noise"],
             "putative_type": ["excitatory", "pyramidal"],
         },
         {"unit_id": 3, "putative_type": ["inhibitory"]},
     ],
-    "merge_unit_groups": [[3, 6], [10, 14, 20]],  # one cell goes into at most one list
-    "removed_units": [31, 42],  # Can not be in the merged_units
+    "merges": [{"merge_unit_group": [3, 6]}, {"merge_unit_group": [10, 14, 20]}],
+    "splits": [],
+    "removed": [31, 42],
 }

+# Test dictionary format for merges
+curation_ids_int_dict = {**curation_ids_int, "merges": {50: [3, 6], 51: [10, 14, 20]}}
+
 curation_ids_str = {
-    "format_version": "1",
+    "format_version": "2",
     "unit_ids": ["u1", "u2", "u3", "u6", "u10", "u14", "u20", "u31", "u42"],
     "label_definitions": {
         "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True},
@@ -73,40 +114,78 @@
         {"unit_id": "u1", "quality": ["good"]},
         {
             "unit_id":
"u2", - "quality": [ - "noise", - ], + "quality": ["noise"], "putative_type": ["excitatory", "pyramidal"], }, {"unit_id": "u3", "putative_type": ["inhibitory"]}, ], - "merge_unit_groups": [["u3", "u6"], ["u10", "u14", "u20"]], # one cell goes into at most one list - "removed_units": ["u31", "u42"], # Can not be in the merged_units + "merges": [{"merge_unit_group": ["u3", "u6"]}, {"merge_unit_group": ["u10", "u14", "u20"]}], + "splits": [], + "removed": ["u31", "u42"], } +# Test dictionary format for merges with string IDs +curation_ids_str_dict = {**curation_ids_str, "merges": {"u50": ["u3", "u6"], "u51": ["u10", "u14", "u20"]}} + # This is a failure example with duplicated merge duplicate_merge = curation_ids_int.copy() duplicate_merge["merge_unit_groups"] = [[3, 6, 10], [10, 14, 20]] +# Test with splits +curation_with_splits = { + "format_version": "2", + "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], + "label_definitions": { + "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, + "putative_type": { + "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral"], + "exclusive": False, + }, + }, + "manual_labels": [ + {"unit_id": 2, "quality": ["good"], "putative_type": ["excitatory", "pyramidal"]}, + ], + "splits": [{"unit_id": 2, "split_mode": "indices", "split_indices": [[0, 1, 2], [3, 4, 5]]}], +} + +# Test dictionary format for splits +curation_with_splits_dict = {**curation_ids_int, "splits": {2: [[0, 1, 2], [3, 4, 5]]}} + +# This is a failure example with duplicated merge +duplicate_merge = {**curation_ids_int, "merges": [{"merge_unit_group": [3, 6, 10]}, {"merge_unit_group": [10, 14, 20]}]} # This is a failure example with unit 3 both in removed and merged -merged_and_removed = curation_ids_int.copy() -merged_and_removed["merge_unit_groups"] = [[3, 6], [10, 14, 20]] -merged_and_removed["removed_units"] = [3, 31, 42] +merged_and_removed = { + **curation_ids_int, + "merges": [{"merge_unit_group": [3, 6]}, {"merge_unit_group": [10, 14, 20]}], + "removed": [3, 31, 42], +} -# this is a failure because unit 99 is not in the initial list -unknown_merged_unit = curation_ids_int.copy() -unknown_merged_unit["merge_unit_groups"] = [[3, 6, 99], [10, 14, 20]] +# This is a failure because unit 99 is not in the initial list +unknown_merged_unit = { + **curation_ids_int, + "merges": [{"merge_unit_group": [3, 6, 99]}, {"merge_unit_group": [10, 14, 20]}], +} -# this is a failure because unit 99 is not in the initial list -unknown_removed_unit = curation_ids_int.copy() -unknown_removed_unit["removed_units"] = [31, 42, 99] +# This is a failure because unit 99 is not in the initial list +unknown_removed_unit = {**curation_ids_int, "removed": [31, 42, 99]} def test_curation_format_validation(): + # Test basic formats + print(curation_ids_int) validate_curation_dict(curation_ids_int) + print(curation_ids_int) validate_curation_dict(curation_ids_str) + # Test dictionary formats + validate_curation_dict(curation_ids_int_dict) + validate_curation_dict(curation_ids_str_dict) + + # Test splits + validate_curation_dict(curation_with_splits) + validate_curation_dict(curation_with_splits_dict) + with pytest.raises(ValueError): # Raised because duplicated merged units validate_curation_dict(duplicate_merge) @@ -122,13 +201,13 @@ def test_curation_format_validation(): def test_to_from_json(): - json.loads(json.dumps(curation_ids_int, indent=4)) json.loads(json.dumps(curation_ids_str, indent=4)) + json.loads(json.dumps(curation_ids_int_dict, indent=4)) + 
json.loads(json.dumps(curation_with_splits, indent=4)) def test_convert_from_sortingview_curation_format_v0(): - parent_folder = Path(__file__).parent for filename in ( "sv-sorting-curation.json", @@ -136,18 +215,13 @@ def test_convert_from_sortingview_curation_format_v0(): "sv-sorting-curation-str.json", "sv-sorting-curation-false-positive.json", ): - json_file = parent_folder / filename with open(json_file, "r") as f: curation_v0 = json.load(f) - # print(curation_v0) - curation_v1 = convert_from_sortingview_curation_format_v0(curation_v0) - # print(curation_v1) - validate_curation_dict(curation_v1) + validate_curation_dict(curation_v0) def test_curation_label_to_vectors(): - labels = curation_label_to_vectors(curation_ids_int) assert "quality" in labels assert "excitatory" in labels @@ -158,35 +232,176 @@ def test_curation_label_to_vectors(): def test_curation_label_to_dataframe(): - df = curation_label_to_dataframe(curation_ids_int) assert "quality" in df.columns assert "excitatory" in df.columns print(df) df = curation_label_to_dataframe(curation_ids_str) - # print(df) + print(df) def test_apply_curation(): recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205) - sorting._main_ids = np.array([1, 2, 3, 6, 10, 14, 20, 31, 42]) + sorting = sorting.rename_units([1, 2, 3, 6, 10, 14, 20, 31, 42]) analyzer = create_sorting_analyzer(sorting, recording, sparse=False) + # Test with list format sorting_curated = apply_curation(sorting, curation_ids_int) assert sorting_curated.get_property("quality", ids=[1])[0] == "good" assert sorting_curated.get_property("quality", ids=[2])[0] == "noise" assert sorting_curated.get_property("excitatory", ids=[2])[0] + # Test with dictionary format + sorting_curated = apply_curation(sorting, curation_ids_int_dict) + assert sorting_curated.get_property("quality", ids=[1])[0] == "good" + assert sorting_curated.get_property("quality", ids=[2])[0] == "noise" + assert sorting_curated.get_property("excitatory", ids=[2])[0] + + # Test analyzer analyzer_curated = apply_curation(analyzer, curation_ids_int) assert "quality" in analyzer_curated.sorting.get_property_keys() -if __name__ == "__main__": - # test_curation_format_validation() - # test_to_from_json() - # test_convert_from_sortingview_curation_format_v0() - # test_curation_label_to_vectors() - # test_curation_label_to_dataframe() +def test_apply_curation_with_split(): + recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205) + sorting = sorting.rename_units(np.array([1, 2, 3, 6, 10, 14, 20, 31, 42])) + analyzer = create_sorting_analyzer(sorting, recording, sparse=False) + + sorting_curated = apply_curation(sorting, curation_with_splits) + # the split indices are not complete, so an extra unit is added + assert len(sorting_curated.unit_ids) == len(sorting.unit_ids) + 2 + + assert 2 not in sorting_curated.unit_ids + split_unit_ids = [43, 44, 45] + for unit_id in split_unit_ids: + assert unit_id in sorting_curated.unit_ids + assert sorting_curated.get_property("quality", ids=[unit_id])[0] == "good" + assert sorting_curated.get_property("excitatory", ids=[unit_id])[0] + assert sorting_curated.get_property("pyramidal", ids=[unit_id])[0] + + analyzer_curated = apply_curation(analyzer, curation_with_splits) + assert len(analyzer_curated.sorting.unit_ids) == len(analyzer.sorting.unit_ids) + 2 + + assert 2 not in analyzer_curated.unit_ids + for unit_id in split_unit_ids: + assert unit_id in analyzer_curated.unit_ids + assert 
analyzer_curated.sorting.get_property("quality", ids=[unit_id])[0] == "good" + assert analyzer_curated.sorting.get_property("excitatory", ids=[unit_id])[0] + assert analyzer_curated.sorting.get_property("pyramidal", ids=[unit_id])[0] + +def test_apply_curation_with_split_multi_segment(): + recording, sorting = generate_ground_truth_recording(durations=[10.0, 10.0], num_units=9, seed=2205) + sorting = sorting.rename_units(np.array([1, 2, 3, 6, 10, 14, 20, 31, 42])) + analyzer = create_sorting_analyzer(sorting, recording, sparse=False) + num_segments = sorting.get_num_segments() + + curation_with_splits_multi_segment = curation_with_splits.copy() + + # we make a split so that each subsplit will have all spikes from different segments + split_unit_id = curation_with_splits_multi_segment["splits"][0]["unit_id"] + sv = sorting.to_spike_vector() + unit_index = sorting.id_to_index(split_unit_id) + spikes_from_split_unit = sv[sv["unit_index"] == unit_index] + + split_indices = [] + cum_spikes = 0 + for segment_index in range(num_segments): + spikes_in_segment = spikes_from_split_unit[spikes_from_split_unit["segment_index"] == segment_index] + split_indices.append(np.arange(0, len(spikes_in_segment)) + cum_spikes) + cum_spikes += len(spikes_in_segment) + + curation_with_splits_multi_segment["splits"][0]["split_indices"] = split_indices + + sorting_curated = apply_curation(sorting, curation_with_splits_multi_segment) + + assert len(sorting_curated.unit_ids) == len(sorting.unit_ids) + 1 + assert 2 not in sorting_curated.unit_ids + assert 43 in sorting_curated.unit_ids + assert 44 in sorting_curated.unit_ids + + # check that spike trains are correctly split across segments + for seg_index in range(num_segments): + st_43 = sorting_curated.get_unit_spike_train(43, segment_index=seg_index) + st_44 = sorting_curated.get_unit_spike_train(44, segment_index=seg_index) + if seg_index == 0: + assert len(st_43) > 0 + assert len(st_44) == 0 + else: + assert len(st_43) == 0 + assert len(st_44) > 0 + + +def test_apply_curation_splits_with_mask(): + recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205) + sorting = sorting.rename_units(np.array([1, 2, 3, 6, 10, 14, 20, 31, 42])) + analyzer = create_sorting_analyzer(sorting, recording, sparse=False) + + # Get number of spikes for unit 2 + num_spikes = sorting.count_num_spikes_per_unit()[2] + + # Create split labels that assign spikes to 3 different clusters + split_labels = np.zeros(num_spikes, dtype=int) + split_labels[: num_spikes // 3] = 0 # First third to cluster 0 + split_labels[num_spikes // 3 : 2 * num_spikes // 3] = 1 # Second third to cluster 1 + split_labels[2 * num_spikes // 3 :] = 2 # Last third to cluster 2 + + curation_with_mask_split = { + "format_version": "2", + "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42], + "label_definitions": { + "quality": {"label_options": ["good", "noise", "MUA", "artifact"], "exclusive": True}, + "putative_type": { + "label_options": ["excitatory", "inhibitory", "pyramidal", "mitral"], + "exclusive": False, + }, + }, + "manual_labels": [ + {"unit_id": 2, "quality": ["good"], "putative_type": ["excitatory", "pyramidal"]}, + ], + "splits": [ + { + "unit_id": 2, + "split_mode": "labels", + "split_labels": split_labels.tolist(), + "split_new_unit_ids": [43, 44, 45], + } + ], + } + + sorting_curated = apply_curation(sorting, curation_with_mask_split) + + # Check results + assert len(sorting_curated.unit_ids) == len(sorting.unit_ids) + 2 # Original units - 1 (split) + 3 (new) + assert 2 not 
in sorting_curated.unit_ids # Original unit should be removed + + # Check new split units + split_unit_ids = [43, 44, 45] + for unit_id in split_unit_ids: + assert unit_id in sorting_curated.unit_ids + # Check properties are propagated + assert sorting_curated.get_property("quality", ids=[unit_id])[0] == "good" + assert sorting_curated.get_property("excitatory", ids=[unit_id])[0] + assert sorting_curated.get_property("pyramidal", ids=[unit_id])[0] + + # Check analyzer + analyzer_curated = apply_curation(analyzer, curation_with_mask_split) + assert len(analyzer_curated.sorting.unit_ids) == len(analyzer.sorting.unit_ids) + 2 + + # Verify split sizes + spike_counts = analyzer_curated.sorting.count_num_spikes_per_unit() + assert spike_counts[43] == num_spikes // 3 # First third + assert spike_counts[44] == num_spikes // 3 # Second third + assert spike_counts[45] == num_spikes - 2 * (num_spikes // 3) # Remainder + + +if __name__ == "__main__": + test_curation_format_validation() + test_to_from_json() + test_convert_from_sortingview_curation_format_v0() + test_curation_label_to_vectors() + test_curation_label_to_dataframe() test_apply_curation() + test_apply_curation_with_split_multi_segment() + test_apply_curation_splits_with_mask() diff --git a/src/spikeinterface/curation/tests/test_curation_model.py b/src/spikeinterface/curation/tests/test_curation_model.py new file mode 100644 index 0000000000..7354ac1892 --- /dev/null +++ b/src/spikeinterface/curation/tests/test_curation_model.py @@ -0,0 +1,288 @@ +import pytest + +from pydantic import ValidationError +import numpy as np + +from spikeinterface.curation.curation_model import CurationModel, LabelDefinition + + +# Test data for format version +def test_format_version(): + # Valid format version + CurationModel(format_version="1", unit_ids=[1, 2, 3]) + + # Invalid format version + with pytest.raises(ValidationError): + CurationModel(format_version="3", unit_ids=[1, 2, 3]) + with pytest.raises(ValidationError): + CurationModel(format_version="0.1", unit_ids=[1, 2, 3]) + + +# Test data for label definitions +def test_label_definitions(): + valid_label_def = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow", "fast"], exclusive=False), + }, + } + + model = CurationModel(**valid_label_def) + assert "quality" in model.label_definitions + assert model.label_definitions["quality"].name == "quality" + assert model.label_definitions["quality"].exclusive is True + + # Test invalid label definition + with pytest.raises(ValidationError): + LabelDefinition(name="quality", label_options=[], exclusive=True) # Empty options should be invalid + + +# Test manual labels +def test_manual_labels(): + valid_labels = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow", "fast"], exclusive=False), + }, + "manual_labels": [ + {"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst", "fast"]}}, + {"unit_id": 2, "labels": {"quality": ["noise"]}}, + ], + } + + model = CurationModel(**valid_labels) + assert len(model.manual_labels) == 2 + + # Test invalid unit ID + invalid_unit = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": 
LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True) + }, + "manual_labels": [{"unit_id": 4, "labels": {"quality": ["good"]}}], # Non-existent unit + } + with pytest.raises(ValidationError): + CurationModel(**invalid_unit) + + # Test violation of exclusive label + invalid_exclusive = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True) + }, + "manual_labels": [ + {"unit_id": 1, "labels": {"quality": ["good", "noise"]}} # Multiple values for exclusive label + ], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_exclusive) + + +# Test merge functionality +def test_merge_units(): + # Test list format + valid_merge = { + "format_version": "2", + "unit_ids": [1, 2, 3, 4], + "merges": [ + {"merge_unit_group": [1, 2], "merge_new_unit_id": 5}, + {"merge_unit_group": [3, 4], "merge_new_unit_id": 6}, + ], + } + + model = CurationModel(**valid_merge) + assert len(model.merges) == 2 + assert model.merges[0].merge_new_unit_id == 5 + assert model.merges[1].merge_new_unit_id == 6 + + # Test dictionary format + valid_merge_dict = {"format_version": "2", "unit_ids": [1, 2, 3, 4], "merges": {5: [1, 2], 6: [3, 4]}} + + model = CurationModel(**valid_merge_dict) + assert len(model.merges) == 2 + merge_new_ids = {merge.merge_new_unit_id for merge in model.merges} + assert merge_new_ids == {5, 6} + + # Test list format + valid_merge_list = { + "format_version": "2", + "unit_ids": [1, 2, 3, 4], + "merges": [[1, 2], [3, 4]], # Merge each pair into a new unit + } + model = CurationModel(**valid_merge_list) + assert len(model.merges) == 2 + + # Test invalid merge group (single unit) + invalid_merge_group = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "merges": [{"merge_unit_group": [1], "merge_new_unit_id": 4}], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_merge_group) + + # Test overlapping merge groups + invalid_overlap = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "merges": [ + {"merge_unit_group": [1, 2], "merge_new_unit_id": 4}, + {"merge_unit_group": [2, 3], "merge_new_unit_id": 5}, + ], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_overlap) + + +# Test split functionality +def test_split_units(): + # Test indices mode with list format + valid_split_indices = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "splits": [ + { + "unit_id": 1, + "split_mode": "indices", + "split_indices": [[0, 1, 2], [3, 4, 5]], + "split_new_unit_ids": [4, 5], + } + ], + } + + model = CurationModel(**valid_split_indices) + assert len(model.splits) == 1 + assert model.splits[0].split_mode == "indices" + assert len(model.splits[0].split_indices) == 2 + + # Test labels mode with list format + valid_split_labels = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "splits": [ + {"unit_id": 1, "split_mode": "labels", "split_labels": [0, 0, 1, 1, 0, 2], "split_new_unit_ids": [4, 5, 6]} + ], + } + + model = CurationModel(**valid_split_labels) + assert len(model.splits) == 1 + assert model.splits[0].split_mode == "labels" + assert len(set(model.splits[0].split_labels)) == 3 + + # Test dictionary format with indices + valid_split_dict = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "splits": { + 1: [[0, 1, 2], [3, 4, 5]], # Split unit 1 into two parts + 2: [[0, 1], [2, 3], [4, 5]], # Split unit 2 into three parts + }, + } + + model = CurationModel(**valid_split_dict) + assert len(model.splits) == 2 
+ assert all(split.split_mode == "indices" for split in model.splits) + + # Test invalid unit ID + invalid_unit_id = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "splits": [{"unit_id": 4, "split_mode": "indices", "split_indices": [[0, 1], [2, 3]]}], # Non-existent unit + } + with pytest.raises(ValidationError): + CurationModel(**invalid_unit_id) + + # Test invalid new unit IDs count for indices mode + invalid_new_ids = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "splits": [ + { + "unit_id": 1, + "split_mode": "indices", + "split_indices": [[0, 1], [2, 3]], + "split_new_unit_ids": [4], # Should have 2 new IDs for 2 splits + } + ], + } + with pytest.raises(ValidationError): + CurationModel(**invalid_new_ids) + + +# Test removed units +def test_removed_units(): + valid_remove = {"format_version": "2", "unit_ids": [1, 2, 3], "removed": [2]} + + model = CurationModel(**valid_remove) + assert len(model.removed) == 1 + + # Test removing non-existent unit + invalid_remove = {"format_version": "2", "unit_ids": [1, 2, 3], "removed": [4]} # Non-existent unit + with pytest.raises(ValidationError): + CurationModel(**invalid_remove) + + # Test conflict between merge and remove + invalid_merge_remove = { + "format_version": "2", + "unit_ids": [1, 2, 3], + "merges": [{"merge_unit_group": [1, 2], "merge_new_unit_id": 4}], + "removed": [1], # Unit is both merged and removed + } + with pytest.raises(ValidationError): + CurationModel(**invalid_merge_remove) + + +# Test complete model with multiple operations +def test_complete_model(): + complete_model = { + "format_version": "2", + "unit_ids": [1, 2, 3, 4, 5], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow"], exclusive=False), + }, + "manual_labels": [{"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst"]}}], + "merges": [{"merge_unit_group": [2, 3], "merge_new_unit_id": 6}], + "splits": [ + {"unit_id": 4, "split_mode": "indices", "split_indices": [[0, 1], [2, 3]], "split_new_unit_ids": [7, 8]} + ], + "removed": [5], + } + + model = CurationModel(**complete_model) + assert model.format_version == "2" + assert len(model.unit_ids) == 5 + assert len(model.label_definitions) == 2 + assert len(model.manual_labels) == 1 + assert len(model.merges) == 1 + assert len(model.splits) == 1 + assert len(model.removed) == 1 + + # Test dictionary format for complete model + complete_model_dict = { + "format_version": "2", + "unit_ids": [1, 2, 3, 4, 5], + "label_definitions": { + "quality": LabelDefinition(name="quality", label_options=["good", "noise"], exclusive=True), + "tags": LabelDefinition(name="tags", label_options=["burst", "slow"], exclusive=False), + }, + "manual_labels": [{"unit_id": 1, "labels": {"quality": ["good"], "tags": ["burst"]}}], + "merges": {6: [2, 3]}, + "splits": {4: [[0, 1], [2, 3]]}, + "removed": [5], + } + + model = CurationModel(**complete_model_dict) + assert model.format_version == "2" + assert len(model.unit_ids) == 5 + assert len(model.label_definitions) == 2 + assert len(model.manual_labels) == 1 + assert len(model.merges) == 1 + assert len(model.splits) == 1 + assert len(model.removed) == 1 diff --git a/src/spikeinterface/postprocessing/amplitude_scalings.py b/src/spikeinterface/postprocessing/amplitude_scalings.py index 278151a930..ff926a998d 100644 --- a/src/spikeinterface/postprocessing/amplitude_scalings.py +++ b/src/spikeinterface/postprocessing/amplitude_scalings.py 
@@ -128,6 +128,9 @@ def _merge_extension_data(
         return new_data

+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        return self.data.copy()
+
     def _get_pipeline_nodes(self):
         recording = self.sorting_analyzer.recording

diff --git a/src/spikeinterface/postprocessing/correlograms.py b/src/spikeinterface/postprocessing/correlograms.py
index 5e30d7c68b..d41beb595f 100644
--- a/src/spikeinterface/postprocessing/correlograms.py
+++ b/src/spikeinterface/postprocessing/correlograms.py
@@ -154,9 +154,6 @@ def _merge_extension_data(
         if unit_involved_in_merge is False:
             old_to_new_unit_index_map[old_unit_index] = new_sorting_analyzer.sorting.id_to_index(old_unit)

-        need_to_append = False
-        delete_from = 1
-
     correlograms, new_bins = deepcopy(self.get_data())

     for new_unit_id, merge_unit_group in zip(new_unit_ids, merge_unit_groups):
@@ -188,6 +185,12 @@ def _merge_extension_data(
         new_data = dict(ccgs=new_correlograms, bins=new_bins)
         return new_data

+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        # TODO: smarter in-place update; for now we recompute the correlograms from scratch on the split sorting
+        new_ccgs, new_bins = _compute_correlograms_on_sorting(new_sorting_analyzer.sorting, **self.params)
+        new_data = dict(ccgs=new_ccgs, bins=new_bins)
+        return new_data
+
     def _run(self, verbose=False):
         ccgs, bins = _compute_correlograms_on_sorting(self.sorting_analyzer.sorting, **self.params)
         self.data["ccgs"] = ccgs
diff --git a/src/spikeinterface/postprocessing/isi.py b/src/spikeinterface/postprocessing/isi.py
index 542f829f21..03bd9d71a8 100644
--- a/src/spikeinterface/postprocessing/isi.py
+++ b/src/spikeinterface/postprocessing/isi.py
@@ -1,6 +1,7 @@
 from __future__ import annotations

 import numpy as np
+from itertools import chain

 from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension

@@ -80,6 +81,29 @@ def _merge_extension_data(
         new_extension_data = dict(isi_histograms=new_isi_hists, bins=new_bins)
         return new_extension_data

+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        new_bins = self.data["bins"]
+        arr = self.data["isi_histograms"]
+        num_dims = arr.shape[1]
+        all_new_units = new_sorting_analyzer.unit_ids
+        new_isi_hists = np.zeros((len(all_new_units), num_dims), dtype=arr.dtype)
+
+        # compute the ISI histograms of all newly created units in one pass
+        new_unit_ids_f = list(chain(*new_unit_ids))
+        new_sorting = new_sorting_analyzer.sorting.select_units(new_unit_ids_f)
+        only_new_hist, _ = _compute_isi_histograms(new_sorting, **self.params)
+
+        for unit_ind, unit_id in enumerate(all_new_units):
+            if unit_id not in new_unit_ids_f:
+                keep_unit_index = self.sorting_analyzer.sorting.id_to_index(unit_id)
+                new_isi_hists[unit_ind, :] = arr[keep_unit_index, :]
+            else:
+                new_unit_index = new_sorting.id_to_index(unit_id)
+                new_isi_hists[unit_ind, :] = only_new_hist[new_unit_index, :]
+
+        new_extension_data = dict(isi_histograms=new_isi_hists, bins=new_bins)
+        return new_extension_data
+
     def _run(self, verbose=False):
         isi_histograms, bins = _compute_isi_histograms(self.sorting_analyzer.sorting, **self.params)
         self.data["isi_histograms"] = isi_histograms
diff --git a/src/spikeinterface/postprocessing/principal_component.py b/src/spikeinterface/postprocessing/principal_component.py
index dd3a8febd7..c340b7ff50 100644
--- a/src/spikeinterface/postprocessing/principal_component.py
+++ b/src/spikeinterface/postprocessing/principal_component.py
@@ -149,6 +149,10 @@ def _merge_extension_data(
         new_data[k]
= v return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # splitting only changes random spikes assignments + return self.data.copy() + def get_pca_model(self): """ Returns the scikit-learn PCA model objects. diff --git a/src/spikeinterface/postprocessing/spike_amplitudes.py b/src/spikeinterface/postprocessing/spike_amplitudes.py index 577dc948c3..4b7a4e8eae 100644 --- a/src/spikeinterface/postprocessing/spike_amplitudes.py +++ b/src/spikeinterface/postprocessing/spike_amplitudes.py @@ -92,6 +92,10 @@ def _merge_extension_data( return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # splitting only changes random spikes assignments + return self.data.copy() + def _get_pipeline_nodes(self): recording = self.sorting_analyzer.recording diff --git a/src/spikeinterface/postprocessing/spike_locations.py b/src/spikeinterface/postprocessing/spike_locations.py index 6995fc04da..c33b9bb8aa 100644 --- a/src/spikeinterface/postprocessing/spike_locations.py +++ b/src/spikeinterface/postprocessing/spike_locations.py @@ -105,6 +105,10 @@ def _merge_extension_data( ### in a merged could be different. Should be discussed return dict(spike_locations=new_spike_locations) + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + # splitting only changes random spikes assignments + return self.data.copy() + def _get_pipeline_nodes(self): from spikeinterface.sortingcomponents.peak_localization import get_localization_pipeline_nodes diff --git a/src/spikeinterface/postprocessing/template_metrics.py b/src/spikeinterface/postprocessing/template_metrics.py index d78b1e3809..e077dab482 100644 --- a/src/spikeinterface/postprocessing/template_metrics.py +++ b/src/spikeinterface/postprocessing/template_metrics.py @@ -8,6 +8,7 @@ import numpy as np import warnings +from itertools import chain from copy import deepcopy from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension @@ -195,6 +196,26 @@ def _merge_extension_data( new_data = dict(metrics=metrics) return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + import pandas as pd + + metric_names = self.params["metric_names"] + old_metrics = self.data["metrics"] + + all_unit_ids = new_sorting_analyzer.unit_ids + new_unit_ids_f = list(chain(*new_unit_ids)) + not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids_f)] + + metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) + + metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] + metrics.loc[new_unit_ids_f, :] = self._compute_metrics( + new_sorting_analyzer, new_unit_ids_f, verbose, metric_names, **job_kwargs + ) + + new_data = dict(metrics=metrics) + return new_data + def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): """ Compute template metrics. 
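Taken together, the `_split_extension_data` implementations above allow an already-computed analyzer to be split without recomputing every extension. The sketch below shows how this is driven from the user side; it is assembled from the calls made in `test_multi_extensions.py` further down, so the `split_units` signature, the handling of a partial index array (remaining spikes forming a second new unit), and the import paths are assumptions taken from those tests rather than a documented API.

import numpy as np
from spikeinterface.core import create_sorting_analyzer, generate_ground_truth_recording

recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=5, seed=0)
analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=False)
analyzer.compute(["random_spikes", "waveforms", "templates", "unit_locations"])

# split the first unit: the listed spike indices go to one new unit and, following the
# behaviour exercised in the tests, the remaining spikes form a second new unit
unit_to_split = sorting.unit_ids[0]
num_spikes = sorting.count_num_spikes_per_unit()[unit_to_split]
splits = {unit_to_split: np.arange(num_spikes // 2)}

analyzer_split, new_unit_ids = analyzer.split_units(split_units=splits, return_new_unit_ids=True)
# each computed extension is remapped through its _split_extension_data() method,
# so e.g. unit_locations now has rows for the two new units
print(analyzer_split.unit_ids, new_unit_ids)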
diff --git a/src/spikeinterface/postprocessing/template_similarity.py b/src/spikeinterface/postprocessing/template_similarity.py
index 1928e12edc..5469c7fe5a 100644
--- a/src/spikeinterface/postprocessing/template_similarity.py
+++ b/src/spikeinterface/postprocessing/template_similarity.py
@@ -2,6 +2,7 @@
 import numpy as np
 import warnings
+from itertools import chain

 from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension
 from spikeinterface.core.template_tools import get_dense_templates_array
@@ -128,6 +129,58 @@ def _merge_extension_data(

         return dict(similarity=similarity)

+    def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs):
+        num_shifts = int(self.params["max_lag_ms"] * self.sorting_analyzer.sampling_frequency / 1000)
+        all_templates_array = get_dense_templates_array(
+            new_sorting_analyzer, return_scaled=self.sorting_analyzer.return_scaled
+        )
+
+        new_unit_ids_f = list(chain(*new_unit_ids))
+        keep = np.isin(new_sorting_analyzer.unit_ids, new_unit_ids_f)
+        new_templates_array = all_templates_array[keep, :, :]
+        if new_sorting_analyzer.sparsity is None:
+            new_sparsity = None
+        else:
+            new_sparsity = ChannelSparsity(
+                new_sorting_analyzer.sparsity.mask[keep, :], new_unit_ids_f, new_sorting_analyzer.channel_ids
+            )
+
+        new_similarity = compute_similarity_with_templates_array(
+            new_templates_array,
+            all_templates_array,
+            method=self.params["method"],
+            num_shifts=num_shifts,
+            support=self.params["support"],
+            sparsity=new_sparsity,
+            other_sparsity=new_sorting_analyzer.sparsity,
+        )
+
+        old_similarity = self.data["similarity"]
+
+        all_new_unit_ids = new_sorting_analyzer.unit_ids
+        n = all_new_unit_ids.size
+        similarity = np.zeros((n, n), dtype=old_similarity.dtype)
+
+        # copy old similarity
+        for unit_ind1, unit_id1 in enumerate(all_new_unit_ids):
+            if unit_id1 not in new_unit_ids_f:
+                old_ind1 = self.sorting_analyzer.sorting.id_to_index(unit_id1)
+                for unit_ind2, unit_id2 in enumerate(all_new_unit_ids):
+                    if unit_id2 not in new_unit_ids_f:
+                        old_ind2 = self.sorting_analyzer.sorting.id_to_index(unit_id2)
+                        s = self.data["similarity"][old_ind1, old_ind2]
+                        similarity[unit_ind1, unit_ind2] = s
+                        similarity[unit_ind2, unit_ind1] = s
+
+        # insert the new similarity in both directions
+        for unit_ind, unit_id in enumerate(all_new_unit_ids):
+            if unit_id in new_unit_ids_f:
+                new_index = list(new_unit_ids_f).index(unit_id)
+                similarity[unit_ind, :] = new_similarity[new_index, :]
+                similarity[:, unit_ind] = new_similarity[new_index, :]
+
+        return dict(similarity=similarity)
+
     def _run(self, verbose=False):
         num_shifts = int(self.params["max_lag_ms"] * self.sorting_analyzer.sampling_frequency / 1000)
         templates_array = get_dense_templates_array(
diff --git a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py
index be0070d94a..8c512c2109 100644
--- a/src/spikeinterface/postprocessing/tests/test_multi_extensions.py
+++ b/src/spikeinterface/postprocessing/tests/test_multi_extensions.py
@@ -11,8 +11,59 @@
 )
 from spikeinterface.core.generate import inject_some_split_units

+# even if this is in postprocessing, we make an extension for quality metrics
+extension_dict = {
+    "noise_levels": dict(),
+    "random_spikes": dict(),
+    "waveforms": dict(),
+    "templates": dict(),
+    "principal_components": dict(),
+    "spike_amplitudes": dict(),
+    "template_similarity": dict(),
+    "correlograms": dict(),
+    "isi_histograms": dict(),
+    "amplitude_scalings":
dict(handle_collisions=False),  # otherwise hard mode could fail due to dropped spikes
+    "spike_locations": dict(method="center_of_mass"),  # trick to avoid UserWarning
+    "unit_locations": dict(),
+    "template_metrics": dict(),
+    "quality_metrics": dict(metric_names=["firing_rate", "isi_violation", "snr"]),
+}
+extension_data_type = {
+    "noise_levels": None,
+    "templates": "unit",
+    "isi_histograms": "unit",
+    "unit_locations": "unit",
+    "spike_amplitudes": "spike",
+    "amplitude_scalings": "spike",
+    "spike_locations": "spike",
+    "quality_metrics": "pandas",
+    "template_metrics": "pandas",
+    "correlograms": "matrix",
+    "template_similarity": "matrix",
+    "principal_components": "random",
+    "waveforms": "random",
+    "random_spikes": "random_spikes",
+}
+data_with_miltiple_returns = ["isi_histograms", "correlograms"]
+# due to incremental PCA, a hard recomputation can give different results for PCA
+# (the fitted model is always different)
+random_computation = ["principal_components"]
+# for some extensions (templates, amplitude_scalings), since the templates slightly change for merges/splits,
+# we allow a relative tolerance
+# (amplitude_scalings are the most sensitive!)
+extensions_with_rel_tolerance_merge = {
+    "amplitude_scalings": 1e-1,
+    "templates": 1e-3,
+    "template_similarity": 1e-3,
+    "unit_locations": 1e-3,
+    "template_metrics": 1e-3,
+    "quality_metrics": 1e-3,
+}
+extensions_with_rel_tolerance_splits = {"amplitude_scalings": 1e-1}

-def get_dataset():
+
+def get_dataset_to_merge():
+    # generate a dataset with some split units to minimize merge errors
     recording, sorting = generate_ground_truth_recording(
         durations=[30.0],
         sampling_frequency=16000.0,
@@ -20,6 +71,7 @@ def get_dataset():
         num_units=10,
         generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
         noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
+        generate_unit_locations_kwargs=dict(margin_um=10.0, minimum_z=2.0, maximum_z=15.0, minimum_distance=20),
         seed=2205,
     )
@@ -36,70 +88,65 @@
     sort_by_amp = np.argsort(list(get_template_extremum_amplitude(analyzer_raw).values()))[::-1]
     split_ids = sorting.unit_ids[sort_by_amp][:3]

-    sorting_with_splits, other_ids = inject_some_split_units(
+    sorting_with_splits, split_unit_ids = inject_some_split_units(
         sorting, num_split=3, split_ids=split_ids, output_ids=True, seed=0
     )
-    return recording, sorting_with_splits, other_ids
+    return recording, sorting_with_splits, split_unit_ids
+
+
+def get_dataset_to_split():
+    # generate a dataset and return the largest units to split, to minimize split errors
+    recording, sorting = generate_ground_truth_recording(
+        durations=[30.0],
+        sampling_frequency=16000.0,
+        num_channels=10,
+        num_units=10,
+        generate_sorting_kwargs=dict(firing_rates=10.0, refractory_period_ms=4.0),
+        noise_kwargs=dict(noise_levels=5.0, strategy="tile_pregenerated"),
+        seed=2205,
+    )
+
+    channel_ids_as_integers = [id for id in range(recording.get_num_channels())]
+    unit_ids_as_integers = [id for id in range(sorting.get_num_units())]
+    recording = recording.rename_channels(new_channel_ids=channel_ids_as_integers)
+    sorting = sorting.rename_units(new_unit_ids=unit_ids_as_integers)
+
+    # since templates are going to be averaged and this might be a problem for amplitude scaling,
+    # we select the units with the largest templates to split
+    analyzer_raw = create_sorting_analyzer(sorting, recording, format="memory", sparse=False)
+    analyzer_raw.compute(["random_spikes", "templates"])
+    # select the largest templates to split
+    sort_by_amp =
np.argsort(list(get_template_extremum_amplitude(analyzer_raw).values()))[::-1] + large_units = sorting.unit_ids[sort_by_amp][:2] + + return recording, sorting, large_units @pytest.fixture(scope="module") -def dataset(): - return get_dataset() +def dataset_to_merge(): + return get_dataset_to_merge() + + +@pytest.fixture(scope="module") +def dataset_to_split(): + return get_dataset_to_split() @pytest.mark.parametrize("sparse", [False, True]) -def test_SortingAnalyzer_merge_all_extensions(dataset, sparse): +def test_SortingAnalyzer_merge_all_extensions(dataset_to_merge, sparse): set_global_job_kwargs(n_jobs=1) - recording, sorting, other_ids = dataset + recording, sorting, other_ids = dataset_to_merge sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=sparse) + extension_dict_merge = extension_dict.copy() # we apply the merges according to the artificial splits merges = [list(v) for v in other_ids.values()] split_unit_ids = np.ravel(merges) unmerged_unit_ids = sorting_analyzer.unit_ids[~np.isin(sorting_analyzer.unit_ids, split_unit_ids)] - # even if this is in postprocessing, we make an extension for quality metrics - extension_dict = { - "noise_levels": dict(), - "random_spikes": dict(), - "waveforms": dict(), - "templates": dict(), - "principal_components": dict(), - "spike_amplitudes": dict(), - "template_similarity": dict(), - "correlograms": dict(), - "isi_histograms": dict(), - "amplitude_scalings": dict(handle_collisions=False), # otherwise hard mode could fail due to dropped spikes - "spike_locations": dict(method="center_of_mass"), # trick to avoid UserWarning - "unit_locations": dict(), - "template_metrics": dict(), - "quality_metrics": dict(metric_names=["firing_rate", "isi_violation", "snr"]), - } - extension_data_type = { - "noise_levels": None, - "templates": "unit", - "isi_histograms": "unit", - "unit_locations": "unit", - "spike_amplitudes": "spike", - "amplitude_scalings": "spike", - "spike_locations": "spike", - "quality_metrics": "pandas", - "template_metrics": "pandas", - "correlograms": "matrix", - "template_similarity": "matrix", - "principal_components": "random", - "waveforms": "random", - "random_spikes": "random_spikes", - } - data_with_miltiple_returns = ["isi_histograms", "correlograms"] - - # due to incremental PCA, hard computation could result in different results for PCA - # the model is differents always - random_computation = ["principal_components"] - - sorting_analyzer.compute(extension_dict, n_jobs=1) + sorting_analyzer.compute(extension_dict_merge, n_jobs=1) # TODO: still some UserWarnings for n_jobs, where from? 
 t0 = time.perf_counter()
@@ -155,14 +202,95 @@ def test_SortingAnalyzer_merge_all_extensions(dataset, sparse):
         )
         if ext not in random_computation:
+            if ext in extensions_with_rel_tolerance_merge:
+                rtol = extensions_with_rel_tolerance_merge[ext]
+            else:
+                rtol = 0
             if extension_data_type[ext] == "pandas":
                 data_hard_merged = data_hard_merged.dropna().to_numpy().astype("float")
                 data_soft_merged = data_soft_merged.dropna().to_numpy().astype("float")
             if data_hard_merged.dtype.fields is None:
-                assert np.allclose(data_hard_merged, data_soft_merged, rtol=0.1)
+                if not np.allclose(data_hard_merged, data_soft_merged, rtol=rtol):
+                    max_error = np.max(np.abs(data_hard_merged - data_soft_merged))
+                    raise Exception(f"Failed for {ext} - max error {max_error}")
             else:
                 for f in data_hard_merged.dtype.fields:
-                    assert np.allclose(data_hard_merged[f], data_soft_merged[f], rtol=0.1)
+                    if not np.allclose(data_hard_merged[f], data_soft_merged[f], rtol=rtol):
+                        max_error = np.max(np.abs(data_hard_merged[f] - data_soft_merged[f]))
+                        raise Exception(f"Failed for {ext} - field {f} - max error {max_error}")
+
+
+@pytest.mark.parametrize("sparse", [False, True])
+def test_SortingAnalyzer_split_all_extensions(dataset_to_split, sparse):
+    set_global_job_kwargs(n_jobs=1)
+
+    recording, sorting, units_to_split = dataset_to_split
+
+    sorting_analyzer = create_sorting_analyzer(sorting, recording, format="memory", sparse=sparse)
+    extension_dict_split = extension_dict.copy()
+    sorting_analyzer.compute(extension_dict, n_jobs=1)
+
+    # we apply splits (each unit is split at the mid-point of its spike train)
+    num_spikes = sorting.count_num_spikes_per_unit()
+
+    unsplit_unit_ids = sorting_analyzer.unit_ids[~np.isin(sorting_analyzer.unit_ids, units_to_split)]
+    splits = {}
+    for unit in units_to_split:
+        splits[unit] = np.arange(num_spikes[unit] // 2)
+
+    analyzer_split, split_unit_ids = sorting_analyzer.split_units(split_units=splits, return_new_unit_ids=True)
+    split_unit_ids = list(np.concatenate(split_unit_ids))
+
+    # also do a full recompute
+    analyzer_hard = create_sorting_analyzer(analyzer_split.sorting, recording, format="memory", sparse=sparse)
+    # we propagate the random spikes so that they are not recomputed
+    extension_dict_ = extension_dict_split.copy()
+    extension_dict_.pop("random_spikes")
+    analyzer_hard.extensions["random_spikes"] = analyzer_split.extensions["random_spikes"]
+    analyzer_hard.compute(extension_dict_, n_jobs=1)
+
+    for ext in extension_dict:
+        # 1. check that data are exactly the same for unchanged units between original/split
+        data_original = sorting_analyzer.get_extension(ext).get_data()
+        data_split = analyzer_split.get_extension(ext).get_data()
+        data_recompute = analyzer_hard.get_extension(ext).get_data()
+        if ext in data_with_miltiple_returns:
+            data_original = data_original[0]
+            data_split = data_split[0]
+            data_recompute = data_recompute[0]
+        data_original_unsplit = get_extension_data_for_units(
+            sorting_analyzer, data_original, unsplit_unit_ids, extension_data_type[ext]
+        )
+        data_split_unsplit = get_extension_data_for_units(
+            analyzer_split, data_split, unsplit_unit_ids, extension_data_type[ext]
+        )
+
+        np.testing.assert_array_equal(data_original_unsplit, data_split_unsplit)
+
+        # 2.
check that split data are the same for extension split and recompute + data_split_soft = get_extension_data_for_units( + analyzer_split, data_split, split_unit_ids, extension_data_type[ext] + ) + data_split_hard = get_extension_data_for_units( + analyzer_hard, data_recompute, split_unit_ids, extension_data_type[ext] + ) + if ext not in random_computation: + if ext in extensions_with_rel_tolerance_splits: + rtol = extensions_with_rel_tolerance_splits[ext] + else: + rtol = 0 + if extension_data_type[ext] == "pandas": + data_split_soft = data_split_soft.dropna().to_numpy().astype("float") + data_split_hard = data_split_hard.dropna().to_numpy().astype("float") + if data_split_hard.dtype.fields is None: + if not np.allclose(data_split_hard, data_split_soft, rtol=rtol): + max_error = np.max(np.abs(data_split_hard - data_split_soft)) + raise Exception(f"Failed for {ext} - max error {max_error}") + else: + for f in data_split_hard.dtype.fields: + if not np.allclose(data_split_hard[f], data_split_soft[f], rtol=rtol): + max_error = np.max(np.abs(data_split_hard[f] - data_split_soft[f])) + raise Exception(f"Failed for {ext} - field {f} - max error {max_error}") def get_extension_data_for_units(sorting_analyzer, data, unit_ids, ext_data_type): @@ -191,5 +319,5 @@ def get_extension_data_for_units(sorting_analyzer, data, unit_ids, ext_data_type if __name__ == "__main__": - dataset = get_dataset() + dataset = get_dataset_to_merge() test_SortingAnalyzer_merge_all_extensions(dataset, False) diff --git a/src/spikeinterface/postprocessing/unit_locations.py b/src/spikeinterface/postprocessing/unit_locations.py index 5618499770..ea297f7b6c 100644 --- a/src/spikeinterface/postprocessing/unit_locations.py +++ b/src/spikeinterface/postprocessing/unit_locations.py @@ -1,7 +1,7 @@ from __future__ import annotations import numpy as np -import warnings +from itertools import chain from spikeinterface.core.sortinganalyzer import register_result_extension, AnalyzerExtension from .localization_tools import _unit_location_methods @@ -88,6 +88,30 @@ def _merge_extension_data( return dict(unit_locations=unit_location) + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + old_unit_locations = self.data["unit_locations"] + num_dims = old_unit_locations.shape[1] + + method = self.params.get("method") + method_kwargs = self.params.copy() + method_kwargs.pop("method") + func = _unit_location_methods[method] + new_unit_ids_f = list(chain(*new_unit_ids)) + new_unit_locations = func(new_sorting_analyzer, unit_ids=new_unit_ids_f, **method_kwargs) + assert new_unit_locations.shape[0] == len(new_unit_ids_f) + + all_new_unit_ids = new_sorting_analyzer.unit_ids + unit_location = np.zeros((len(all_new_unit_ids), num_dims), dtype=old_unit_locations.dtype) + for unit_index, unit_id in enumerate(all_new_unit_ids): + if unit_id not in new_unit_ids_f: + old_index = self.sorting_analyzer.sorting.id_to_index(unit_id) + unit_location[unit_index] = old_unit_locations[old_index] + else: + new_index = list(new_unit_ids_f).index(unit_id) + unit_location[unit_index] = new_unit_locations[new_index] + + return dict(unit_locations=unit_location) + def _run(self, verbose=False): method = self.params.get("method") method_kwargs = self.params.copy() diff --git a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py index 134849e70f..055fefc78c 100644 --- a/src/spikeinterface/qualitymetrics/quality_metric_calculator.py +++ 
b/src/spikeinterface/qualitymetrics/quality_metric_calculator.py @@ -3,6 +3,7 @@ from __future__ import annotations import warnings +from itertools import chain from copy import deepcopy import numpy as np @@ -158,6 +159,33 @@ def _merge_extension_data( new_data = dict(metrics=metrics) return new_data + def _split_extension_data(self, split_units, new_unit_ids, new_sorting_analyzer, verbose=False, **job_kwargs): + import pandas as pd + + metric_names = self.params["metric_names"] + old_metrics = self.data["metrics"] + + all_unit_ids = new_sorting_analyzer.unit_ids + new_unit_ids_f = list(chain(*new_unit_ids)) + not_new_ids = all_unit_ids[~np.isin(all_unit_ids, new_unit_ids_f)] + + # this creates a new metrics dictionary, but the dtype for everything will be + # object. So we will need to fix this later after computing metrics + metrics = pd.DataFrame(index=all_unit_ids, columns=old_metrics.columns) + metrics.loc[not_new_ids, :] = old_metrics.loc[not_new_ids, :] + metrics.loc[new_unit_ids_f, :] = self._compute_metrics( + new_sorting_analyzer, new_unit_ids_f, verbose, metric_names, **job_kwargs + ) + + # we need to fix the dtypes after we compute everything because we have nans + # we can iterate through the columns and convert them back to the dtype + # of the original quality dataframe. + for column in old_metrics.columns: + metrics[column] = metrics[column].astype(old_metrics[column].dtype) + + new_data = dict(metrics=metrics) + return new_data + def _compute_metrics(self, sorting_analyzer, unit_ids=None, verbose=False, metric_names=None, **job_kwargs): """ Compute quality metrics.
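To close, here is a compact end-to-end sketch of the v2 curation format with a split, applied through `apply_curation`. It is assembled from the dictionaries and assertions in `test_curation_format.py` above; field names follow the documented v2 schema, the import paths are assumed, and the ids of the newly created units depend on the chosen `new_id_strategy`.

import numpy as np
from spikeinterface.core import generate_ground_truth_recording
from spikeinterface.curation.curation_format import apply_curation

recording, sorting = generate_ground_truth_recording(durations=[10.0], num_units=9, seed=2205)
sorting = sorting.rename_units(np.array([1, 2, 3, 6, 10, 14, 20, 31, 42]))

curation = {
    "format_version": "2",
    "unit_ids": [1, 2, 3, 6, 10, 14, 20, 31, 42],
    "label_definitions": {"quality": {"label_options": ["good", "noise"], "exclusive": True}},
    "manual_labels": [{"unit_id": 2, "quality": ["good"]}],
    "merges": [{"merge_unit_group": [3, 6]}],
    # two explicit spike-index groups; the remaining spikes of unit 2 form a third new unit
    "splits": [{"unit_id": 2, "split_mode": "indices", "split_indices": [[0, 1, 2], [3, 4, 5]]}],
    "removed": [31],
}

sorting_curated = apply_curation(sorting, curation)
# unit 2 is replaced by the split units, units 3 and 6 are merged, unit 31 is removed,
# and the manual "quality" label is propagated to the units produced by the split
print(sorting_curated.unit_ids)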