diff --git a/cebra/data/__init__.py b/cebra/data/__init__.py index 697801ed..145ff835 100644 --- a/cebra/data/__init__.py +++ b/cebra/data/__init__.py @@ -51,3 +51,4 @@ from cebra.data.multiobjective import * from cebra.data.datasets import * from cebra.data.helper import * +from cebra.data.masking import * diff --git a/cebra/data/base.py b/cebra/data/base.py index e71b05c4..51199cec 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -27,6 +27,7 @@ import torch import cebra.data.assets as cebra_data_assets +import cebra.data.masking as cebra_data_masking import cebra.distributions import cebra.io from cebra.data.datatypes import Batch @@ -36,7 +37,7 @@ __all__ = ["Dataset", "Loader"] -class Dataset(abc.ABC, cebra.io.HasDevice): +class Dataset(abc.ABC, cebra.io.HasDevice, cebra_data_masking.MaskedMixin): """Abstract base class for implementing a dataset. The class attributes provide information about the shape of the data when diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index 24735f47..8549000f 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -28,7 +28,9 @@ import numpy.typing as npt import torch +import cebra import cebra.data as cebra_data +import cebra.data.masking as cebra_data_masking import cebra.helper as cebra_helper import cebra.io as cebra_io from cebra.data.datatypes import Batch @@ -304,7 +306,7 @@ def _iter_property(self, attr): # TODO(stes): This should be a single session dataset? -class DatasetxCEBRA(cebra_io.HasDevice): +class DatasetxCEBRA(cebra_io.HasDevice, cebra_data_masking.MaskedMixin): """Dataset class for xCEBRA models. This class handles neural data and associated labels for xCEBRA models, providing @@ -435,3 +437,95 @@ def load_batch_contrastive(self, index: BatchIndex) -> Batch: positive=[self[idx] for idx in index.positive], negative=self[index.negative], ) + + +class UnifiedDataset(DatasetCollection): + """Multi session dataset made up of a list of datasets, considered as a unique session. + + Considering the sessions as a unique session, or pseudo-session, is used to later train a single + model for all the sessions, even if they originally contain a variable number of neurons. + To do that, we sample ref/pos/neg for each session and concatenate them along the neurons axis. + + For instance, for a batch size ``batch_size``, we sample ``(batch_size, num_neurons(session), offset)`` for + each type of samples (ref/pos/neg) and then concatenate so that the final :py:class:`cebra.data.datatypes.Batch` + is of shape ``(batch_size, total_num_neurons, offset)``, with ``total_num_neurons`` is the sum of all the + ``num_neurons(session)``. + """ + + def __init__(self, *datasets: cebra_data.SingleSessionDataset): + super().__init__(*datasets) + + @property + def input_dimension(self) -> int: + """Returns the sum of the input dimension for each session.""" + return np.sum([ + self.get_input_dimension(session_id) + for session_id in range(self.num_sessions) + ]) + + def _get_batches(self, index): + """Return the data at the specified index location.""" + return [ + cebra_data.Batch( + reference=self.get_session(session_id)[ + index.reference[session_id]], + positive=self.get_session(session_id)[ + index.positive[session_id]], + negative=self.get_session(session_id)[ + index.negative[session_id]], + ) for session_id in range(self.num_sessions) + ] + + def configure_for(self, model: "cebra.models.Model"): + """Configure the dataset offset for the provided model. + + Call this function before indexing the dataset. 
This sets the + :py:attr:`~.Dataset.offset` attribute of the dataset. + + Args: + model: The model to configure the dataset for. + """ + for i, session in enumerate(self.iter_sessions()): + session.configure_for(model) + + def load_batch(self, index: BatchIndex) -> Batch: + """Return the data at the specified index location. + + Concatenate batches for each sessions on the number of neurons axis. + + Args: + batches: List of :py:class:`cebra.data.datatypes.Batch` sampled for each session. An instance + :py:class:`cebra.data.datatypes.Batch` of the list is of shape ``(batch_size, num_neurons(session), offset)``. + + Returns: + A :py:class:`cebra.data.datatypes.Batch`, of shape ``(batch_size, total_num_neurons, offset)``, where + ``total_num_neurons`` is the sum of all the ``num_neurons(session)`` + """ + batches = self._get_batches(index) + + if hasattr(self, "apply_mask"): + # If the dataset has a mask, apply it to the data. + batch = cebra_data.Batch( + reference=self.apply_mask( + torch.cat([batch.reference for batch in batches], dim=1)), + positive=self.apply_mask( + torch.cat([batch.positive for batch in batches], dim=1)), + negative=self.apply_mask( + torch.cat([batch.negative for batch in batches], dim=1)), + ) + else: + batch = cebra_data.Batch( + reference=torch.cat([batch.reference for batch in batches], + dim=1), + positive=torch.cat([batch.positive for batch in batches], + dim=1), + negative=torch.cat([batch.negative for batch in batches], + dim=1), + ) + return batch + + def __getitem__(self, args) -> List[Batch]: + """Return a set of samples from all sessions.""" + + session_id, index = args + return self.get_session(session_id).__getitem__(index) diff --git a/cebra/data/masking.py b/cebra/data/masking.py new file mode 100644 index 00000000..90faf51e --- /dev/null +++ b/cebra/data/masking.py @@ -0,0 +1,425 @@ +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import abc +import random +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch + +__all__ = [ + "MaskedMixin", "Mask", "RandomNeuronMask", "RandomTimestepMask", + "NeuronBlockMask", "TimeBlockMask" +] + + +class MaskedMixin: + """A mixin class for applying masking to data. + + Note: + This class is designed to be used as a mixin for other classes. + It provides functionality to apply masking to data. + The `set_masks` method should be called to set the masking types + and their corresponding probabilities. + """ + _masks = [] # a list of Mask instances + + def set_masks(self, masking: Optional[Dict[str, float]] = None) -> None: + """Set the mask type and probability for the dataset. + + Args: + masking (Dict[str, float]): A dictionary of masking types and their + corresponding required masking values. The keys are the names + of the Mask instances. + + Note: + By default, no masks are applied. 
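A minimal usage sketch of the mixin introduced above, assuming this patch is installed; `ToyDataset` is a hypothetical stand-in for any class inheriting `MaskedMixin`, and the dictionary keys must match the `Mask` subclass names defined later in this file.

```python
import torch

from cebra.data.masking import MaskedMixin


class ToyDataset(MaskedMixin):
    """Hypothetical stand-in for a dataset class that inherits the mixin."""


dataset = ToyDataset()
# Keys are resolved against the Mask subclasses in this module (RandomNeuronMask, ...).
dataset.set_masks({"RandomNeuronMask": 0.3, "RandomTimestepMask": [0.2, 0.4]})

data = torch.rand(8, 120, 10)              # (batch_size, num_neurons, offset)
masked = dataset.apply_mask(data.clone())  # one of the configured mask types is sampled per call
assert masked.shape == data.shape
```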
+ """ + if masking is not None: + for mask_key in masking: + if mask_key in globals(): + cls = globals()[mask_key] + self._masks = [ + m for m in self._masks if not isinstance(m, cls) + ] + self._masks.append(cls(masking[mask_key])) + else: + raise ValueError( + f"Mask type {mask_key} not supported. Supported types are {masking.keys()}" + ) + + def apply_mask(self, + data: torch.Tensor, + chunk_size: int = 1000) -> torch.Tensor: + """Apply masking to the input data. + + Note: + - By default, no masking. Else apply masking on the input data. + - Only one masking type can be applied at a time, but multiple + masking types can be set so that it alternates between them + across iterations. + - Masking is applied to the data in chunks to avoid memory issues. + + Args: + data (torch.Tensor): batch of size (batch_size, num_neurons, offset). + chunk_size (int): Number of rows to process at a time. + + Returns: + torch.Tensor: The masked data. + """ + if data.dim() == 2: + data = data.unsqueeze(0) + + if data.dim() != 3: + raise ValueError( + f"Data must be a 3D tensor, but got {data.dim()}D tensor.") + if data.dtype != torch.float32: + raise ValueError( + f"Data must be a float32 tensor, but got {data.dtype}.") + + # If masks is empty, return the data as is + if not self._masks: + return data + + sampled_mask = random.choice(self._masks) + mask = sampled_mask.apply_mask(data) + + num_chunks = (data.shape[0] + chunk_size - + 1) // chunk_size # Compute number of chunks + + for i in range(num_chunks): + start, end = i * chunk_size, min((i + 1) * chunk_size, + data.shape[0]) + data[start:end].mul_( + mask[start:end]) # apply mask in-place to save memory + + return data + + +class Mask: + + def __init__(self, masking_value: Union[float, List[float], Tuple[float]]): + self._check_masking_parameters(masking_value) + + @abc.abstractmethod + def apply_mask(self, data: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @abc.abstractmethod + def _select_masking_params(): + raise NotImplementedError + + def _check_masking_parameters(self, masking_value: Union[float, List[float], + Tuple[float]]): + """ + The masking values are the masking ratio to apply. + It can be a single ratio, a list of ratio that will be picked randomly or + a tuple of (min, max, step_size) that will be used to create a list of ratios + from which to sample randomly. 
+ """ + if isinstance(masking_value, float): + assert 0.0 < masking_value < 1.0, ( + f"Masking ratio {masking_value} for {self.__name__()} " + "should be between 0.0 and 1.0.") + + elif isinstance(masking_value, list): + assert all(isinstance(ratio, float) for ratio in masking_value), ( + f"Masking ratios {masking_value} for {self.__name__()} " + "should be between 0.0 and 1.0.") + assert all(0.0 < ratio < 1.0 for ratio in masking_value), ( + f"Masking ratios {masking_value} for {self.__name__()} " + "should be between 0.0 and 1.0.") + + elif isinstance(masking_value, tuple): + assert len(masking_value) == 3, ( + f"Masking ratios {masking_value} for {self.__name__()} " + "should be a tuple of (min, max, step).") + assert 0.0 <= masking_value[0] < masking_value[1] <= 1.0, ( + f"Masking ratios {masking_value} for {self.__name__()} " + "should be between 0.0 and 1.0.") + assert masking_value[2] < masking_value[1] - masking_value[0], ( + f"Masking step {masking_value[2]} for {self.__name__()} " + "should be between smaller than the diff between min " + f"({masking_value[0]}) and max ({masking_value[1]}).") + + else: + raise ValueError( + f"Masking ratio {masking_value} for {self.__name__()} " + "should be a float, list of floats or a tuple of (min, max, step)." + ) + + +class RandomNeuronMask(Mask): + + def __init__(self, masking_value: Union[float, List[float], Tuple[float]]): + super().__init__(masking_value) + self.mask_ratio = masking_value + + def __name__(self): + return "RandomNeuronMask" + + def apply_mask(self, data: torch.Tensor) -> torch.Tensor: + """ Apply random masking on the neuron dimension. + + Args: + data: batch of size (batch_size, n_neurons, offset). + mask_ratio: Proportion of neurons to mask. Default value 0.3 comes + from the MtM paper: https://arxiv.org/pdf/2407.14668v2 + + Returns: + torch.Tensor: The mask, a tensor of the same size as the input data with the + masked neurons set to 1. + """ + batch_size, n_neurons, offset_length = data.shape + mask_ratio = self._select_masking_params() + + # Random mask: shape [batch_size, n_neurons], different per batch and neurons + masked = torch.rand(batch_size, n_neurons, + device=data.device) < mask_ratio + return (~masked).int().unsqueeze(2).expand( + -1, -1, offset_length) # Expand to all timesteps + + def _select_masking_params(self) -> float: + """ + The masking values are the masking ratio to apply. + It can be a single ratio, a list of ratio that will be picked randomly or + a tuple of (min, max, step_size) that will be used to create a list of ratios + from which to sample randomly. + """ + if isinstance(self.mask_ratio, float): + selected_value = self.mask_ratio + + elif isinstance(self.mask_ratio, list): + selected_value = random.choice(self.mask_ratio) + + elif isinstance(self.mask_ratio, tuple): + min_val, max_val, step_size = self.mask_ratio + selected_value = random.choice( + np.arange(min_val, max_val + step_size, step_size).tolist()) + + else: + raise ValueError( + f"Masking ratio {self.mask_ratio} for {self.__name__()} " + "should be a float, list of floats or a tuple of (min, max, step)." + ) + + return selected_value + + +class RandomTimestepMask(Mask): + + def __init__(self, masking_value: Union[float, List[float], Tuple[float]]): + super().__init__(masking_value) + self.mask_ratio = masking_value + + def __name__(self): + return "RandomTimestepMask" + + def apply_mask(self, data: torch.Tensor) -> torch.Tensor: + """ Apply random masking on the time dimension. 
+ + Args: + data: batch of size (batch_idx, feature_dim, seq_len). With seq_len + corresponding to the offset. + mask_ratio: Proportion of timesteps masked. Not necessarliy consecutive. + Default value 0.3 comes from the MtM paper: https://arxiv.org/pdf/2407.14668v2 + + Returns: + torch.Tensor: The mask, a tensor of the same size as the input data with the + masked neurons set to 1. + + """ + batch_idx, n_neurons, offset_length = data.shape + mask_ratio = self._select_masking_params() + + # Random mask: shape [batbatch_idxch_size, offset_length], different per batch and timestamp + masked = torch.rand(batch_idx, offset_length, + device=data.device) < mask_ratio + return (~masked).int().unsqueeze(1).expand(-1, n_neurons, + -1) # Expand to all neurons + + def _select_masking_params(self) -> float: + """ + The masking values are the masking ratio to apply. + It can be a single ratio, a list of ratio that will be picked randomly or + a tuple of (min, max, step_size) that will be used to create a list of ratios + from which to sample randomly. + """ + if isinstance(self.mask_ratio, float): + selected_value = self.mask_ratio + + elif isinstance(self.mask_ratio, list): + selected_value = random.choice(self.mask_ratio) + + elif isinstance(self.mask_ratio, tuple): + min_val, max_val, step_size = self.mask_ratio + selected_value = random.choice( + np.arange(min_val, max_val + step_size, step_size).tolist()) + + else: + raise ValueError( + f"Masking ratio {self.mask_ratio} for {self.__name__()} " + "should be a float, list of floats or a tuple of (min, max, step)." + ) + + return selected_value + + +class NeuronBlockMask(Mask): + + def __init__(self, masking_value: Union[float, List[float], Tuple[float]]): + super().__init__(masking_value) + self.mask_prop = masking_value + + def __name__(self): + return "NeuronBlockMask" + + def apply_mask(self, data: torch.Tensor) -> torch.Tensor: + """ Apply masking to a contiguous block of neurons. + + Args: + data: batch of size (batch_size, n_neurons, offset). + self.mask_prop: Proportion of neurons to mask. The neurons are masked in a + contiguous block. + + Returns: + torch.Tensor: The mask, a tensor of the same size as the input data with the + masked neurons set to 1. + """ + batch_size, n_neurons, offset_length = data.shape + + mask_prop = self._select_masking_params() + num_mask = int(n_neurons * mask_prop) + mask = torch.ones((batch_size, n_neurons), + dtype=torch.int, + device=data.device) + + if num_mask == 0: + return mask.unsqueeze(2) + + for batch_idx in range(batch_size): # Create a mask for each batch + # Select random the start index for the block of neurons to mask + start_idx = torch.randint(0, n_neurons - num_mask + 1, (1,)).item() + end_idx = min(start_idx + num_mask, n_neurons) + mask[batch_idx, start_idx:end_idx] = 0 # set masked neurons to 0 + + return mask.unsqueeze(2).expand( + -1, -1, offset_length) # Expand to all timesteps + + def _select_masking_params(self) -> float: + """ + The masking values are the masking ratio to apply. + It can be a single ratio, a list of ratio that will be picked randomly or + a tuple of (min, max, step_size) that will be used to create a list of ratios + from which to sample randomly. 
+ """ + if isinstance(self.mask_prop, float): + selected_value = self.mask_prop + + elif isinstance(self.mask_prop, list): + selected_value = random.choice(self.mask_prop) + + elif isinstance(self.mask_prop, tuple): + min_val, max_val, step_size = self.mask_prop + selected_value = random.choice( + np.arange(min_val, max_val + step_size, step_size).tolist()) + + else: + raise ValueError( + f"Masking ratio {self.mask_prop} for {self.__name__()} " + "should be a float, list of floats or a tuple of (min, max, step)." + ) + + return selected_value + + +class TimeBlockMask(Mask): + + def __init__(self, masking_value: Union[float, List[float], Tuple[float]]): + super().__init__(masking_value) + self.sampled_rate, self.masked_seq_len = masking_value + + def __name__(self): + return "TimeBlockMask" + + def apply_mask(self, data: torch.Tensor) -> torch.Tensor: + """ Apply contiguous block masking on the time dimension. + + When choosing which block of timesteps to mask, each timestep is considered + a candidate starting time-step with probability ``self.sampled_rate`` where + ``self.masked_seq_len`` is the length of each masked span starting from the respective + time step. Sampled starting time steps are expanded to length ``self.masked_seq_len`` + and spans can overlap. Inspirede by the wav2vec 2.0 masking strategy. + + Default values from the wav2vec paper: https://arxiv.org/abs/2006.11477. + + Args: + data (torch.Tensor): The input tensor of shape (batch_size, seq_len, feature_dim). + self.sampled_rate (float): The probability of each time-step being a candidate for masking. + self.masked_seq_len (int): The length of each masked span starting from the sampled time-step. + + Returns: + torch.Tensor: A boolean mask of shape (batch_size, seq_len) where True + indicates masked positions. + """ + batch_size, n_neurons, offset_length = data.shape + + sampled_rate, masked_seq_len = self._select_masking_params() + + num_masked_starting_points = int(offset_length * sampled_rate) + mask = torch.ones((batch_size, offset_length), + dtype=int, + device=data.device) + for batch_idx in range(batch_size): + # Sample starting points for masking in the current batch + start_indices = torch.randperm( + offset_length, device=data.device)[:num_masked_starting_points] + + # Apply masking spans + for start in start_indices: + end = min(start + masked_seq_len, offset_length) + mask[batch_idx, start:end] = 0 # set masked timesteps to 0 + + return mask.unsqueeze(1).expand(-1, n_neurons, + -1) # Expand to all neurons + + def _check_masking_parameters(self, masking_value: Union[float, List[float], + Tuple[float]]): + """ + The masking values are the parameters for the timeblock masking. + It needs to be a tuple of (sampled_rate, masked_seq_len) + sampled_rate: The probability of each time-step being a candidate for masking. + masked_seq_len: The length of each masked span starting from the sampled time-step. 
+ """ + assert isinstance(masking_value, tuple) and len(masking_value) == 2, ( + f"Masking parameters {masking_value} for {self.__name__()} " + "should be a tuple of (sampled_rate, masked_seq_len).") + assert 0.0 < masking_value[0] < 1.0 and isinstance( + masking_value[0], float), ( + f"Masking parameters {masking_value} for {self.__name__()} " + "should be between 0.0 and 1.0.") + assert masking_value[1] > 0 and isinstance(masking_value[1], int), ( + f"Masking parameters {masking_value} for {self.__name__()} " + "should be an integer greater than 0.") + + def _select_masking_params(self) -> float: + """ + The masking values are the masking ratio to apply. + It can be a single ratio, a list of ratio that will be picked randomly or + a tuple of (min, max, step_size) that will be used to create a list of ratios + from which to sample randomly. + """ + return self.sampled_rate, self.masked_seq_len diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index ad987e63..f33ad6ec 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -39,6 +39,7 @@ "ContinuousMultiSessionDataLoader", "DiscreteMultiSessionDataLoader", "MixedMultiSessionDataLoader", + "UnifiedLoader", ] @@ -95,15 +96,31 @@ def __getitem__(self, args) -> List[Batch]: def load_batch(self, index: BatchIndex) -> List[Batch]: """Return the data at the specified index location.""" - return [ - cebra_data.Batch( - reference=session[index.reference[session_id]], - positive=session[index.positive[session_id]], - negative=session[index.negative[session_id]], - index=index.index, - index_reversed=index.index_reversed, - ) for session_id, session in enumerate(self.iter_sessions()) - ] + + if hasattr(self, "apply_mask"): + batch = [ + cebra_data.Batch( + reference=self.apply_mask( + session[index.reference[session_id]]), + positive=self.apply_mask( + session[index.positive[session_id]]), + negative=self.apply_mask( + session[index.negative[session_id]]), + index=index.index, + index_reversed=index.index_reversed, + ) for session_id, session in enumerate(self.iter_sessions()) + ] + else: + batch = [ + cebra_data.Batch( + reference=session[index.reference[session_id]], + positive=session[index.positive[session_id]], + negative=session[index.negative[session_id]], + index=index.index, + index_reversed=index.index_reversed, + ) for session_id, session in enumerate(self.iter_sessions()) + ] + return batch def configure_for(self, model: "cebra.models.Model"): """Configure the dataset offset for the provided model. @@ -114,16 +131,15 @@ def configure_for(self, model: "cebra.models.Model"): Args: model: The model to configure the dataset for. """ - if not isinstance(model, nn.ModuleList): - raise ValueError( - "The model must be a nn.ModuleList to configure the dataset.") - if len(model) != self.num_sessions: - raise ValueError( - f"The model must have {self.num_sessions} sessions, but got {len(model)}." - ) - for i, session in enumerate(self.iter_sessions()): - session.configure_for(model[i]) + if isinstance(model, nn.ModuleList): + if len(model) != self.num_sessions: + raise ValueError( + f"The model must have {self.num_sessions} sessions, but got {len(model)}." 
+ ) + session.configure_for(model[i]) + else: + session.configure_for(model) @dataclasses.dataclass @@ -187,3 +203,64 @@ def index(self): @dataclasses.dataclass class MixedMultiSessionDataLoader(MultiSessionLoader): pass + + +@dataclasses.dataclass +class UnifiedLoader(ContinuousMultiSessionDataLoader): + """Dataloader for multi-session datasets, considered as a single session. + + This class is used in pair with :py:class:`cebra.data.datasets.UnifiedDataset` + to sample from each session and train a single model on them, even if sessions have a + different number of neurons. + + To sample the reference and negative samples, a target session is randomly selected. Indexes + are unformly sampled in that first session. Then, indexes in the other sessions are samples + conditionally to the first session indexes, so that their corresponding auxiliary variables + are close. For the positive samples, they are sampled conditionally to the reference samples, + in their corresponding session only. + + Then, the ref/pos/neg samples are concatenated respectively, along the neurons axis (takes place + in the :py:class:`cebra.data.datasets.UnifiedDataset`). + + """ + + def __post_init__(self): + super().__post_init__() + self.sampler = cebra.distributions.UnifiedSampler( + self.dataset, self.time_offset) + + def get_indices(self, num_samples: int) -> BatchIndex: + """Sample and return the specified number of indices. + + The elements of the returned ``BatchIndex`` will be used to index the + ``dataset`` of this data loader. + + To sample the reference and negative samples, a target session is + randomly selected. Indexes are unformly sampled in that first + session. Then, indexes in the other sessions are samples conditionally + to the first session indexes, so that their corresponding auxiliary + variables are close. For the positive samples, they are sampled + conditionally to the reference samples, in their corresponding session + only. + + Args: + num_samples: The size of each of the reference, positive and + negative samples to sample. + + Returns: + Batch indices for the reference, positive and negative samples. + """ + ref_idx = self.sampler.sample_prior(self.batch_size) + neg_idx = self.sampler.sample_prior(self.batch_size) + + pos_idx = self.sampler.sample_conditional(ref_idx) + + ref_idx = torch.from_numpy(ref_idx).to(self.device) + neg_idx = torch.from_numpy(neg_idx).to(self.device) + pos_idx = torch.from_numpy(pos_idx).to(self.device) + + return BatchIndex( + reference=ref_idx, + positive=pos_idx, + negative=neg_idx, + ) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index e3946089..7e4ad2fd 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -63,11 +63,21 @@ def __len__(self): def load_batch(self, index: BatchIndex) -> Batch: """Return the data at the specified index location.""" - return Batch( - positive=self[index.positive], - negative=self[index.negative], - reference=self[index.reference], - ) + + if hasattr(self, "apply_mask"): + # If the dataset has a mask, apply it to the data. 
+ batch = Batch( + positive=self.apply_mask(self[index.positive]), + negative=self.apply_mask(self[index.negative]), + reference=self.apply_mask(self[index.reference]), + ) + else: + batch = Batch( + positive=self[index.positive], + negative=self[index.negative], + reference=self[index.reference], + ) + return batch @dataclasses.dataclass diff --git a/cebra/datasets/demo.py b/cebra/datasets/demo.py index a28f9fa1..8924412e 100644 --- a/cebra/datasets/demo.py +++ b/cebra/datasets/demo.py @@ -33,6 +33,7 @@ from cebra.datasets import register _DEFAULT_NUM_TIMEPOINTS = 1_000 +_NUMS_NEURAL = [3, 4, 5] class DemoDataset(cebra.data.SingleSessionDataset): @@ -117,7 +118,7 @@ class MultiDiscrete(cebra.data.DatasetCollection): def __init__( self, - nums_neural=[3, 4, 5], + nums_neural=_NUMS_NEURAL, num_timepoints=_DEFAULT_NUM_TIMEPOINTS, ): super().__init__(*[ @@ -131,7 +132,7 @@ class MultiContinuous(cebra.data.DatasetCollection): def __init__( self, - nums_neural=[3, 4, 5], + nums_neural=_NUMS_NEURAL, num_behavior=5, num_timepoints=_DEFAULT_NUM_TIMEPOINTS, ): @@ -146,8 +147,26 @@ def __init__( # @register("demo-mixed-multisession") class MultiMixed(cebra.data.DatasetCollection): - def __init__(self, nums_neural=[3, 4, 5], num_behavior=5): + def __init__(self, nums_neural=_NUMS_NEURAL, num_behavior=5): super().__init__(*[ DemoDatasetMixed(_DEFAULT_NUM_TIMEPOINTS, num_neural, num_behavior) for num_neural in nums_neural ]) + + +@register("demo-continuous-unified") +class DemoDatasetUnified(cebra.data.UnifiedDataset): + + def __init__( + self, + nums_neural=_NUMS_NEURAL, + num_behavior=5, + num_timepoints=_DEFAULT_NUM_TIMEPOINTS, + ): + super().__init__(*[ + DemoDatasetContinuous(num_timepoints, num_neural, num_behavior) + for num_neural in nums_neural + ]) + + self.num_timepoints = num_timepoints + self.nums_neural = nums_neural diff --git a/cebra/distributions/multisession.py b/cebra/distributions/multisession.py index 647044f2..9fe303a6 100644 --- a/cebra/distributions/multisession.py +++ b/cebra/distributions/multisession.py @@ -21,7 +21,11 @@ # """Continuous variable multi-session sampling.""" +import random +from typing import Optional + import numpy as np +import numpy.typing as npt import torch import cebra.distributions as cebra_distr @@ -174,7 +178,7 @@ def num_sessions(self) -> int: """The number of sessions in the index.""" return len(self.lengths) - def mix(self, array: np.ndarray, idx: np.ndarray): + def mix(self, array: npt.NDArray, idx: npt.NDArray): """Re-order array elements according to the given index mapping. The given array should be of the shape ``(session, batch, ...)`` and the @@ -383,3 +387,204 @@ def __getitem__(self, pos_idx): for i in range(self.num_sessions): pos_samples[i] = self.data[i][pos_idx[i]] return pos_samples + + +class UnifiedSampler(MultisessionSampler): + """Multi-session sampling, considering them as a single session. + + Align embeddings across multiple sessions, using a set of + auxiliary variables, so that the samples in the different sessions + are sampled together based on how the corresponding auxiliary + variables are close from each other. + + Then, the reference, positive and negative can be concatenated on their + neurons axis to train a single model for all sessions. 
+ + Example: + >>> import cebra.distributions.multisession as cebra_distributions_multisession + >>> import cebra.integrations.sklearn.dataset as cebra_sklearn_dataset + >>> import cebra.data + >>> import torch + >>> from torch import nn + >>> # Multisession training: one model per dataset (different input dimensions) + >>> session1 = torch.rand(100, 30) + >>> session2 = torch.rand(100, 50) + >>> index1 = torch.rand(100) + >>> index2 = torch.rand(100) + >>> num_features = 8 + >>> dataset = cebra.data.UnifiedDataset( + ... cebra_sklearn_dataset.SklearnDataset(session1, (index1, )), + ... cebra_sklearn_dataset.SklearnDataset(session2, (index2, ))) + >>> model = cebra.models.init( + ... name="offset1-model", + ... num_neurons=dataset.input_dimension, + ... num_units=32, + ... num_output=num_features, + ... ).to("cpu") + >>> sampler = cebra_distributions_multisession.UnifiedSampler(dataset, time_offset=10) + + >>> # ref and pos samples from all datasets + >>> ref = sampler.sample_prior(100) + >>> pos = sampler.sample_conditional(ref) + >>> ref = torch.LongTensor(ref) + >>> pos = torch.LongTensor(pos) + >>> loss = (ref - pos)**2 + + Note: + This function does currently not support explicitly selected + discrete indices. They should be added as dimensions to the + continuous index. More weight can be added to the discrete + dimensions by using larger values in one-hot coding. + + """ + + def sample_all_uniform_prior(self, num_samples: int) -> npt.NDArray: + """Returns uniformly sampled index for all sessions of the dataset. + + Args: + num_samples: Number of samples to sample in each session. + + Returns: + ``(N, num_samples)`` with ``N`` the number of sessions. Array of + samples, uniformly picked for each session. + """ + return super().sample_prior(num_samples=num_samples) + + def sample_prior(self, + num_samples: int, + session_id: Optional[int] = None) -> npt.NDArray: + """Return uniformly sampled indices for all sessions. + + First, the reference indexes in a reference session are uniformly sampled. + Then the reference indexes for the other sessions are sampled so that their + corresponding auxiliary variables are close to the reference indexes of the + reference session. + + Args: + num_samples: Number of samples to pick. + session_id: ID of the session to use as the reference session. If ``None``, + the session is randomly selected. + + Returns: + A :py:func:`numpy.array` containing the idx of the reference samples for all + sessions. + """ + + # Randomly pick the reference session + if session_id is None: + session_id = random.choice(list(range(self.num_sessions))) + + # Sample prior for all sessions + idx = self.sample_all_uniform_prior(num_samples=num_samples) + # Keep the idx for the reference session only + idx = torch.from_numpy(idx[session_id]) + + # Sample the references indexes in other sessions, based on their distance to the + # reference idx in the reference session. + return self.sample_all_sessions(idx, session_id).cpu().numpy() + + def _get_query(self, + reference_idx: torch.Tensor, + session_id: int, + aligned: bool = False) -> torch.Tensor: + """Get the query for the given reference indices. + + Args: + reference_idx: Reference indices, with dimension ``(num_samples, )``. + session_id: Session ID of the reference session, whose ``idx`` are present in + ``reference_idx``. + aligned: If True, no time difference is added to the query. Otherwise, a + random time difference is added to the query. 
+ """ + cum_idx = reference_idx + self.lengths[session_id] + if aligned: + query = self.all_data[cum_idx] + else: + diff_idx = torch.randint(len(self.time_difference), + (len(reference_idx),)) + query = self.all_data[cum_idx] + self.time_difference[diff_idx] + return torch.from_numpy(query).to(_device) + + def sample_all_sessions(self, ref_idx: torch.Tensor, + session_id: int) -> torch.Tensor: + """Sample sessions based on a reference session. + + Reference samples for the ``session_id`` th session were first sampled uniformly, as in + the py:class:`~.MultisessionSampler`. Then, reference samples for the other sessions + are sampled so that they are as close as the corresponding auxiliary variables in + the reference session. + + Note: similar to ``sample_condiditonal`` but at the level of the sessions, sampling ref idx in each + session so that they are close to the ref idx in the reference session (``session_id`` th session). + + Args: + ref_idx: Uniformly sampled ``idx`` for the reference session, ``(num_samples, )``, values + can be in ``[0, len(session)]``. + session_id: Session ID of the reference session, whose ``idx`` are present in ``ref_idx``. + + Returns: + The prior for all sessions, creating a "pseudo-animal", where ``idx`` sampled in different + sessions correspond to points in the recordings where the auxiliary variables are similar. + + """ + # Get the continuous data corresponding to the idx + # all_data: (sum(self.session_lengths), ) + # ref_idx: (num_samples, ), values in [O, len(get_session[session_id])] + # self.lengths: (num_sessions, ), cumsum of the length of each session, providing the first + # element of a session in self.all_data. + # cum_ref_idx: (num_samples, ), values of ref_idx, switched to correspond to the indexes in + # of session_id, in the flatten array self.all_data. + all_idx = torch.zeros(self.num_sessions, len(ref_idx), + device=_device).long() + query = self._get_query( + reference_idx=ref_idx, session_id=session_id, + aligned=True) # same query for all + no time diff added + + for i in range(self.num_sessions): + # except for the session_id provided + if i == session_id: + continue + # NOTE(celia): this is code to get different query for each. Unused for now. + #query = self._get_query(reference_idx=ref_idx, + # session_id=session_id, + # aligned=False) + + # get the idx of the datapoint that is the closest to the query + all_idx[i] = self.index[i].search( + query) # search in the whole dataset + + # all_idx[i] = self.index[i].search_or_mask( + # query, threshold=self.distance_threshold[i]) + + all_idx[session_id] = ref_idx + return all_idx + + def sample_conditional(self, reference_idx: npt.NDArray) -> torch.Tensor: + """Sample from the conditional distribution. + + Contrary to the :py:class:`MultisessionSampler`, conditional distribution + is sampled so that the samples match the reference samples. They are sampled + from the same session as each reference idx only, rather than across all + sessions. + + Args: + reference_idx: Reference indices, with dimension ``(session, batch)``. + + Returns: + Positive indices, which will be grouped by + session and match the reference indices. + Returned shape is ``(session, batch)``. 
+ + """ + + cond_idx = torch.zeros((reference_idx.shape[0], reference_idx.shape[1]), + dtype=torch.int, + device=_device).long() + + for session_id in range(self.num_sessions): + query = self._get_query(reference_idx=reference_idx[session_id], + session_id=session_id) + + cond_idx[session_id] = self.index[session_id].search(query) + + return cond_idx.cpu().numpy() diff --git a/cebra/integrations/decoders.py b/cebra/integrations/decoders.py new file mode 100644 index 00000000..1d5314e3 --- /dev/null +++ b/cebra/integrations/decoders.py @@ -0,0 +1,220 @@ +from typing import List, Optional, Tuple, Union + +import numpy as np +import sklearn.metrics +import torch +import torch.nn as nn +import torch.optim as optim +from sklearn.linear_model import Ridge +from sklearn.model_selection import GridSearchCV +from torch.utils.data import DataLoader +from torch.utils.data import TensorDataset + + +def ridge_decoding( + embedding_train: Union[torch.Tensor, dict], + embedding_valid: Union[torch.Tensor, dict], + label_train: Union[torch.Tensor, dict], + label_valid: Union[torch.Tensor, dict], + n_run: Optional[int] = None, +) -> Tuple[List[float], List[float], np.ndarray]: + """ + Perform ridge regression decoding on training and validation embeddings. + + Args: + embedding_train (Union[torch.Tensor, dict]): Training embeddings. + embedding_valid (Union[torch.Tensor, dict]): Validation embeddings. + label_train (Union[torch.Tensor, dict]): Training labels. + label_valid (Union[torch.Tensor, dict]): Validation labels. + n_run (Optional[int]): Optional run number for dataset definition. + + Returns: + Training R2 scores, validation R2 scores, and validation predictions. + """ + if isinstance(embedding_train, dict): # only on run 1 + if n_run is None: + raise ValueError(f"n_run must be specified, got {n_run}.") + + all_train_embeddings = np.concatenate( + [ + embedding_train[i][n_run].cpu().numpy() + for i in range(len(embedding_train)) + ], + axis=0, + ) + train = np.concatenate( + [ + label_train[i].continuous.cpu().numpy() + for i in range(len(label_train)) + ], + axis=0, + ) + all_val_embeddings = np.concatenate( + [ + embedding_valid[i][n_run].cpu().numpy() + for i in range(len(embedding_valid)) + ], + axis=0, + ) + valid = np.concatenate( + [ + label_valid[i].continuous.cpu().numpy() + for i in range(len(label_valid)) + ], + axis=0, + ) + else: + all_train_embeddings = embedding_train.cpu().numpy() + train = label_train.cpu().numpy() + all_val_embeddings = embedding_valid.cpu().numpy() + valid = label_valid.cpu().numpy() + + decoder = GridSearchCV(Ridge(), {"alpha": np.logspace(-4, 0, 9)}) + decoder.fit(all_train_embeddings, train) + + train_prediction = decoder.predict(all_train_embeddings) + train_scores = sklearn.metrics.r2_score(train, + train_prediction, + multioutput="raw_values").tolist() + valid_prediction = decoder.predict(all_val_embeddings) + valid_scores = sklearn.metrics.r2_score(valid, + valid_prediction, + multioutput="raw_values").tolist() + + return train_scores, valid_scores, valid_prediction + + +class SingleLayerDecoder(nn.Module): + """Supervised module to predict behaviors. + + Note: + By default, the output dimension is 2, to predict x/y velocity + (Perich et al., 2018). + """ + + def __init__(self, input_dim, output_dim=2): + super(SingleLayerDecoder, self).__init__() + self.fc = nn.Linear(input_dim, output_dim) + + def forward(self, x): + return self.fc(x) + + +class TwoLayersDecoder(nn.Module): + """Supervised module to predict behaviors. 
+ + Note: + By default, the output dimension is 2, to predict x/y velocity + (Perich et al., 2018). + """ + + def __init__(self, input_dim, output_dim=2): + super(TwoLayersDecoder, self).__init__() + self.fc = nn.Sequential(nn.Linear(input_dim, 32), nn.GELU(), + nn.Linear(32, output_dim)) + + def forward(self, x): + return self.fc(x) + + +def mlp_decoding( + embedding_train: Union[dict, torch.Tensor], + embedding_valid: Union[dict, torch.Tensor], + label_train: Union[dict, torch.Tensor], + label_valid: Union[dict, torch.Tensor], + num_epochs: int = 20, + lr: float = 0.001, + batch_size: int = 500, + device: str = "cuda", + model_type: str = "SingleLayerMLP", + n_run: Optional[int] = None, +): + """ Perform MLP decoding on training and validation embeddings. + + Args: + embedding_train (Union[dict, torch.Tensor]): Training embeddings. + embedding_valid (Union[dict, torch.Tensor]): Validation embeddings. + label_train (Union[dict, torch.Tensor]): Training labels. + label_valid (Union[dict, torch.Tensor]): Validation labels. + num_epochs (int): Number of training epochs. + lr (float): Learning rate for the optimizer. + batch_size (int): Batch size for training. + device (str): Device to run the model on ('cuda' or 'cpu'). + model_type (str): Type of MLP model to use ('SingleLayerMLP' or 'TwoLayersMLP'). + n_run (Optional[int]): Optional run number for dataset definition. + + Returns: + Training R2 scores, validation R2 scores, and validation predictions. + """ + if len(label_train.shape) == 1: + label_train = label_train[:, None] + label_valid = label_valid[:, None] + + if isinstance(embedding_train, dict): # only on run 1 + if n_run is None: + raise ValueError(f"n_run must be specified, got {n_run}.") + + all_train_embeddings = torch.cat( + [embedding_train[i][n_run] for i in range(len(embedding_train))], + axis=0) + train = torch.cat( + [label_train[i].continuous for i in range(len(label_train))], + axis=0) + all_val_embeddings = torch.cat( + [embedding_valid[i][n_run] for i in range(len(embedding_valid))], + axis=0) + valid = torch.cat( + [label_valid[i].continuous for i in range(len(label_valid))], + axis=0) + else: + all_train_embeddings = embedding_train + train = label_train + all_val_embeddings = embedding_valid + valid = label_valid + + dataset = TensorDataset(all_train_embeddings.to(device), train.to(device)) + train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + + input_dim = all_train_embeddings.shape[1] + output_dim = train.shape[1] + if model_type == "SingleLayerMLP": + model = SingleLayerDecoder(input_dim=input_dim, output_dim=output_dim) + elif model_type == "TwoLayersMLP": + model = TwoLayersDecoder(input_dim=input_dim, output_dim=output_dim) + else: + raise NotImplementedError() + model.to(device) + + criterion = nn.MSELoss() + optimizer = optim.Adam(model.parameters(), lr=lr) + + for epoch in range(num_epochs): + model.train() + running_loss = 0.0 + + for inputs, labels in train_loader: + inputs, labels = inputs.to(device), labels.to(device) + + optimizer.zero_grad() + outputs = model(inputs) + loss = criterion(outputs, labels) + loss.backward() + optimizer.step() + running_loss += loss.item() + + model.eval() + train_pred = model(all_train_embeddings.to(device)) + train_r2 = sklearn.metrics.r2_score( + y_true=train.cpu().numpy(), + y_pred=train_pred.cpu().detach().numpy(), + multioutput="raw_values", + ).tolist() + + valid_pred = model(all_val_embeddings.to(device)) + valid_r2 = sklearn.metrics.r2_score( + y_true=valid.cpu().numpy(), + 
y_pred=valid_pred.cpu().detach().numpy(), + multioutput="raw_values", + ).tolist() + + return train_r2, valid_r2, valid_pred diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 0dee4c3f..98e56747 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -491,12 +491,16 @@ class CEBRA(TransformerMixin, BaseEstimator): hybrid (bool): If ``True``, the model will be trained using both the time-contrastive and the selected behavior-constrastive loss functions. |Default:| ``False``. - optimizer_kwargs (dict): + optimizer_kwargs (tuple): Additional optimization parameters. These have the form ``((key, value), (key, value))`` and are passed to the PyTorch optimizer specified through the ``optimizer`` argument. Refer to the optimizer documentation in :py:mod:`torch.optim` for further information on how to format the arguments. |Default:| ``(('betas', (0.9, 0.999)), ('eps', 1e-08), ('weight_decay', 0), ('amsgrad', False))`` + masking_kwargs (tuple): + A Tuple of masking types and their corresponding required masking values. The keys are the + names of the Mask instances and formatting should be ``((key, value), (key, value))``. + |Default:| ``None``. Example: @@ -570,6 +574,8 @@ def __init__( ("weight_decay", 0), ("amsgrad", False), ), + masking_kwargs: Tuple[Tuple[str, Union[float, List[float], + Tuple[float, ...]]], ...] = None, ): self.__dict__.update(locals()) @@ -896,6 +902,9 @@ def _prepare_fit( self.offset_ = self._compute_offset() dataset, is_multisession = self._prepare_data(X, y) + if self.masking_kwargs: + dataset.set_masks(dict(self.masking_kwargs)) + loader, solver_name = self._prepare_loader( dataset, max_iterations=self.max_iterations, diff --git a/cebra/models/model.py b/cebra/models/model.py index a74b0229..77423532 100644 --- a/cebra/models/model.py +++ b/cebra/models/model.py @@ -224,6 +224,12 @@ def __init__(self, # the self.net self.normalize = normalize + def _make_layers(self, num_units, num_layers, kernel_size=3): + return [ + cebra_layers._Skip(nn.Conv1d(num_units, num_units, kernel_size), + nn.GELU()) for _ in range(num_layers) + ] + def forward(self, inp): """Compute the embedding given the input signal. 
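The hunks below replace the repeated `_Skip(Conv1d(...), GELU())` lines with the `_make_layers` helper added above; a quick sketch, with illustrative shapes, to confirm a refactored model still builds and runs the same way.

```python
import torch

import cebra.models

model = cebra.models.init(
    name="offset10-model",  # now built via *self._make_layers(num_units, num_layers=3)
    num_neurons=100,
    num_units=32,
    num_output=8,
)
x = torch.rand(2, 100, 10)  # 10-sample receptive field
print(model(x).shape)       # expected: torch.Size([2, 8])
```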
@@ -266,9 +272,7 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): super().__init__( nn.Conv1d(num_neurons, num_units, 2), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=3), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -529,9 +533,7 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): cebra_layers._MeanAndConv(num_neurons, num_units, 4, stride=2), nn.Conv1d(num_neurons + num_units, num_units, 3, stride=2), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=3), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -676,22 +678,7 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): super().__init__( nn.Conv1d(num_neurons, num_units, 2), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=16), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -711,31 +698,22 @@ class Offset36Dropout(_OffsetModel, ConvolutionalModelMixin): Requires ``torch>=1.12``. """ - def __init__(self, num_neurons, num_units, num_output, normalize=True): + def __init__(self, + num_neurons, + num_units, + num_output, + normalize=True, + dropout=0.1): if num_units < 1: raise ValueError( f"Hidden dimension needs to be at least 1, but got {num_units}." 
) + dropout = dropout super().__init__( nn.Conv1d(num_neurons, num_units, 2), - torch.nn.Dropout1d(p=0.1), + torch.nn.Dropout1d(p=dropout), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=16), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -762,16 +740,22 @@ def _make_layers(self, num_units, p, n): for _ in range(n) ] - def __init__(self, num_neurons, num_units, num_output, normalize=True): + def __init__(self, + num_neurons, + num_units, + num_output, + normalize=True, + dropout=0.1): if num_units < 1: raise ValueError( f"Hidden dimension needs to be at least 1, but got {num_units}." ) + dropout = dropout super().__init__( nn.Conv1d(num_neurons, num_units, 2), - torch.nn.Dropout1d(p=0.1), + torch.nn.Dropout1d(p=dropout), nn.GELU(), - *self._make_layers(num_units, 0.1, 16), + *self._make_layers(num_units, p=dropout, n=16), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -783,6 +767,54 @@ def get_offset(self) -> cebra.data.datatypes.Offset: return cebra.data.Offset(18, 18) +@register("offset40-model") +class Offset40(_OffsetModel, ConvolutionalModelMixin): + """CEBRA model with a 40 samples receptive field.""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." + ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + *self._make_layers(num_units, 18), + nn.Conv1d(num_units, num_output, 3), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See `:py:meth:Model.get_offset`""" + return cebra.data.Offset(20, 20) + + +@register("offset50-model") +class Offset50(_OffsetModel, ConvolutionalModelMixin): + """CEBRA model with a sample receptive field.""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." 
+ ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + *self._make_layers(num_units, 23), + nn.Conv1d(num_units, num_output, 3), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See `:py:meth:Model.get_offset`""" + return cebra.data.Offset(25, 25) + + @register("offset15-model") class Offset15Model(_OffsetModel, ConvolutionalModelMixin): """CEBRA model with a 15 sample receptive field.""" @@ -795,12 +827,7 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): super().__init__( nn.Conv1d(num_neurons, num_units, 2), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=6), nn.Conv1d(num_units, num_output, 2), num_input=num_neurons, num_output=num_output, @@ -824,14 +851,7 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): super().__init__( nn.Conv1d(num_neurons, num_units, 2), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=8), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -855,9 +875,7 @@ def __init__(self, num_neurons, num_units, num_output, normalize=False): super().__init__( nn.Conv1d(num_neurons, num_units, 2), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=3), nn.Conv1d(num_units, num_output, 3), nn.Tanh(), # Added tanh activation function num_input=num_neurons, diff --git a/cebra/solver/__init__.py b/cebra/solver/__init__.py index 965c16c8..8bc63a42 100644 --- a/cebra/solver/__init__.py +++ b/cebra/solver/__init__.py @@ -42,5 +42,6 @@ from cebra.solver.schedulers import * from cebra.solver.single_session import * from cebra.solver.supervised import * +from cebra.solver.unified_session import * cebra.registry.add_docstring(__name__) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index 66d90786..c04c3398 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -393,7 +393,6 @@ def num_parameters(self) -> int: """Total number of parameters in the encoder and criterion.""" return sum(p.numel() for p in self.parameters()) - @abc.abstractmethod def parameters(self, session_id: Optional[int] = None): """Iterate over all parameters of the model. @@ -405,7 +404,11 @@ def parameters(self, session_id: Optional[int] = None): Yields: The parameters of the model. 
""" - raise NotImplementedError + for parameter in self.model.parameters(): + yield parameter + + for parameter in self.criterion.parameters(): + yield parameter def _get_loader(self, loader): return ProgressBar( diff --git a/cebra/solver/multi_session.py b/cebra/solver/multi_session.py index dfd71921..b26718bf 100644 --- a/cebra/solver/multi_session.py +++ b/cebra/solver/multi_session.py @@ -81,9 +81,13 @@ def _single_model_inference(self, batch: cebra.data.Batch, ``batch.index`` should be set to ``None``. """ batch.to(self.device) - ref = torch.stack([model(batch.reference)], dim=0) - pos = torch.stack([model(batch.positive)], dim=0) - neg = torch.stack([model(batch.negative)], dim=0) + ref = model(batch.reference) + pos = model(batch.positive) + neg = model(batch.negative) + + ref = ref.unsqueeze(0) + pos = pos.unsqueeze(0) + neg = neg.unsqueeze(0) pos = self._mix(pos, batch.index_reversed) diff --git a/cebra/solver/single_session.py b/cebra/solver/single_session.py index c74eb704..61b880e4 100644 --- a/cebra/solver/single_session.py +++ b/cebra/solver/single_session.py @@ -45,25 +45,6 @@ class SingleSessionSolver(abc_.Solver): _variant_name = "single-session" - def parameters(self, session_id: Optional[int] = None): - """Iterate over all parameters. - - Args: - session_id: The session ID, an :py:class:`int` between 0 and - the number of sessions -1 for multisession, and set to - ``None`` for single session. - - Yields: - The parameters of the model. - """ - # If session_id is invalid, it doesn't matter, since we are - # using a single session solver. - for parameter in self.model.parameters(): - yield parameter - - for parameter in self.criterion.parameters(): - yield parameter - def _set_fitted_params(self, loader: cebra.data.Loader): """Set parameters once the solver is fitted. diff --git a/cebra/solver/unified_session.py b/cebra/solver/unified_session.py new file mode 100644 index 00000000..3524700f --- /dev/null +++ b/cebra/solver/unified_session.py @@ -0,0 +1,382 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Unified session solver for multi-session contrastive learning. + +We added support for training contrastive models on unified-session datasets. +This allows users to align and embed multiple sessions into a common latent +space using a single shared model. + +This module implements the :py:class:`~cebra.solver.unified.UnifiedSolver`, which +is designed for training a single embedding model across multiple recording sessions. +Unlike the standard multi-session solvers, the unified session approach uses +a global model that requires session-specific information for sampling but maintains +a shared representation across all data. 
+
+Features:
+- Single model inference across all sessions.
+- Batched transform.
+- Compatibility with :py:class:`~cebra.data.UnifiedDataset` and :py:class:`~cebra.data.UnifiedLoader`.
+
+See Also:
+    :py:class:`~cebra.solver.base.Solver`
+    :py:class:`~cebra.data.UnifiedDataset`
+    :py:class:`~cebra.data.UnifiedLoader`
+"""
+
+from typing import List, Optional, Union
+
+import literate_dataclasses as dataclasses
+import numpy as np
+import torch
+
+import cebra
+import cebra.data
+import cebra.distributions
+import cebra.models
+import cebra.solver.base as abc_
+from cebra.solver import register
+
+
+@register("unified-session")
+@dataclasses.dataclass
+class UnifiedSolver(abc_.Solver):
+    """Multi session training, considering a single model for all sessions."""
+
+    _variant_name = "unified-session"
+
+    def _set_fitted_params(self, loader: cebra.data.Loader):  # mix
+        """Set parameters once the solver is fitted.
+
+        For the unified solver, the number of sessions is taken from the dataset and
+        the number of features is the total number of neurons, summed across sessions.
+
+        Args:
+            loader: Loader used to fit the solver.
+        """
+        self.num_sessions = loader.dataset.num_sessions
+        self.n_features = loader.dataset.input_dimension
+
+    def _check_is_inputs_valid(self, inputs: Union[torch.Tensor,
+                                                   List[torch.Tensor]],
+                               session_id: int):
+        """Check that the inputs can be inferred using the selected model.
+
+        Note: This method checks that the total number of neurons in the input
+        matches the input dimension of the selected model.
+
+        Args:
+            inputs: Data to infer using the selected model.
+            session_id: The session ID, an :py:class:`int` between 0 and
+                the number of sessions -1 for multisession, and set to
+                ``None`` for single session.
+        """
+
+        if isinstance(inputs, list):
+            inputs_shape = 0
+            for i in range(len(inputs)):
+                inputs_shape += inputs[i].shape[1]
+        elif isinstance(inputs,
+                        torch.Tensor):  #NOTE(celia): flexible input at training
+            raise NotImplementedError
+        else:
+            raise NotImplementedError
+
+        if self.n_features != inputs_shape:
+            raise ValueError(
+                f"Invalid input shape: model requires an input of shape "
+                f"(n_samples, {self.n_features}), got (n_samples, {inputs_shape})."
+            )
+
+    def _check_is_session_id_valid(
+        self,
+        session_id: Optional[int] = None,
+    ):  # same as multi
+        """Check that the session ID provided is valid for the solver instance.
+
+        The session ID must be non-null and between 0 and the number of sessions - 1.
+
+        Args:
+            session_id: The session ID to check.
+        """
+
+        if session_id is None:
+            raise RuntimeError(
+                "No session_id provided: unified model requires a session_id as the target session to use to align the sessions."
+            )
+        if session_id >= self.num_sessions or session_id < 0:
+            raise RuntimeError(
+                f"Invalid session_id {session_id}: session_id for the current unified model must be between 0 and {self.num_sessions-1}."
+            )
+
+    def _get_model(self, session_id: Optional[int] = None):
+        """Get the model for the given session ID.
+
+        Args:
+            session_id: The session ID, an :py:class:`int` between 0 and
+                the number of sessions -1 for multisession, and set to
+                ``None`` for single session.
+
+        Returns:
+            The model for the given session ID.
+        """
+        self._check_is_session_id_valid(session_id=session_id)
+        self._check_is_fitted()
+        return self.model
+
+    def _single_model_inference(self, batch: cebra.data.Batch,
+                                model: torch.nn.Module) -> cebra.data.Batch:
+        """Given a single batch of input examples, computes the feature representation/embedding.
+
+        Args:
+            batch: The input data, not necessarily aligned across the batch
+                dimension. This means that ``batch.index`` specifies the map
+                between reference/positive samples, if not equal ``None``.
+            model: The model to use for inference.
+
+        Returns:
+            Processed batch of data. While the input data might not be aligned
+            across the sample dimensions, the output data should be aligned and
+            ``batch.index`` should be set to ``None``.
+        """
+        batch.to(self.device)
+        ref = model(batch.reference)
+        pos = model(batch.positive)
+        neg = model(batch.negative)
+
+        ref = ref.unsqueeze(0)
+        pos = pos.unsqueeze(0)
+        neg = neg.unsqueeze(0)
+
+        num_features = neg.shape[2]
+
+        return cebra.data.Batch(
+            reference=ref.view(-1, num_features),
+            positive=pos.view(-1, num_features),
+            negative=neg.view(-1, num_features),
+        )
+
+    def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch:
+        """Given a batch of input examples, computes the feature representations/embeddings.
+
+        Args:
+            batch: The input data, not necessarily aligned across the batch
+                dimension. This means that ``batch.index`` specifies the map
+                between reference/positive samples, if not equal ``None``.
+
+        Returns:
+            Processed batch of data. While the input data might not be aligned
+            across the sample dimensions, the output data should be aligned and
+            ``batch.index`` should be set to ``None``.
+        """
+        return self._single_model_inference(batch, self.model)
+
+    @torch.no_grad()
+    def transform(self,
+                  inputs: List[torch.Tensor],
+                  labels: List[torch.Tensor],
+                  pad_before_transform: bool = True,
+                  session_id: Optional[int] = None,
+                  batch_size: Optional[int] = 512) -> torch.Tensor:
+        """Compute the embedding for the ``session_id``-th session of the dataset.
+
+        Note:
+            Compared to the other :py:class:`cebra.solver.base.Solver` implementations, all
+            sessions of the dataset are needed to transform the data, as sampling is performed
+            across all the sessions.
+
+        Args:
+            inputs: The input signal for all sessions.
+            labels: The auxiliary variables to use for sampling.
+            pad_before_transform: If ``True``, pad the input before inference so that the
+                output embedding has as many samples as the input.
+            session_id: The session ID, an :py:class:`int` between 0 and
+                the number of sessions -1.
+            batch_size: If not None, batched inference will be applied.
+
+        Note:
+            The ``session_id`` is needed in order to sample the corresponding number of samples and
+            return an embedding of the expected shape.
+
+        Note:
+            Batched inference is required in most cases; the default ``batch_size`` is set to
+            ``512`` for that reason.
+
+        Returns:
+            The output embedding for the session corresponding to the provided ``session_id``. The
+            shape is ``(num_samples(session_id), output_dimension)``.
+
+        """
+        if not isinstance(inputs, list):
+            raise ValueError(
+                f"Inputs to transform() should be a list, not {type(inputs)}.")
+
+        self._check_is_fitted()
+
+        if session_id is None:
+            raise ValueError("Session ID is required for multi-session models.")
+
+        # Sampling according to session_id is required.
+        dataset = cebra.data.UnifiedDataset(
+            cebra.data.TensorDataset(
+                inputs[i], continuous=labels[i], offset=cebra.data.Offset(0, 1))
+            for i in range(len(inputs))).to(self.device)
+
+        # Only used to sample the reference samples.
+        loader = cebra.data.UnifiedLoader(dataset, num_steps=1)
+
+        # Sampling in batches.
+        refs_data_batch_embeddings = []
+        batch_range = range(0, len(dataset.get_session(session_id)), batch_size)
+        if len(batch_range) < 2:
+            raise ValueError(
+                "Not enough data to perform the batched transform. Please provide a larger dataset or reduce the batch size."
+ ) + for batch_start in batch_range: + batch_end = min(batch_start + batch_size, + len(dataset.get_session(session_id))) + + if batch_start == batch_range[-2]: # one before last batch + last_start = batch_start + continue + if batch_start == batch_range[-1]: # last batch, likely incomplete + batch_start = last_start + batch_end = len(dataset.get_session(session_id)) + + refs_idx_batch = loader.sampler.sample_all_sessions( + ref_idx=torch.arange(batch_start, batch_end), + session_id=session_id).to(self.device) + + refs_data_batch = torch.cat([ + session[refs_idx_batch[session_id]] + for session_id, session in enumerate(dataset.iter_sessions()) + ], + dim=1).squeeze() + + if len(self.model.get_offset()) < 2 and pad_before_transform: + pad_before_transform = False + + self.model.eval() + refs_data_batch_embeddings.append( + self._transform(model=self.model, + inputs=refs_data_batch, + pad_before_transform=pad_before_transform, + offset=self.model.get_offset(), + batch_size=batch_size)) + + return torch.cat(refs_data_batch_embeddings, dim=0) + + @torch.no_grad() + def decoding(self, + train_loader: cebra.data.Loader, + valid_loader: Optional[cebra.data.Loader] = None, + decode: str = "ridge", + max_sessions: int = 5, + max_timesteps: int = 512) -> float: + """Sample a fixed number of sessions to compute the decoding score. + + Args: + train_loader: The training data loader. + valid_loader: The validation data loader, if available. + decode: The type of decoder to use, either "knn" or "ridge". + max_sessions: The maximum number of sessions to sample for decoding. + max_timesteps: The maximum number of timesteps to consider for each session. + Returns: + The average decoding score across the sampled sessions. If a validation loader is provided, + returns a tuple of the average training score and the average validation score. 
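+
+        Example (a minimal sketch; ``train_loader`` is the loader used for fitting)::
+
+            score = solver.decoding(train_loader, decode="ridge")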
+ """ + if train_loader.dataset.num_sessions > max_sessions: + sessions = np.random.choice(np.arange( + train_loader.dataset.num_sessions), + size=max_sessions, + replace=False) + else: + sessions = np.arange(train_loader.dataset.num_sessions) + + train_scores, valid_scores = [], [] + for i in sessions: + if train_loader.dataset.get_session( + i).neural.shape[0] > max_timesteps: + train_end = max_timesteps + else: + train_end = -1 + train_x = self.transform([ + train_loader.dataset.get_session(j).neural[:train_end] + for j in range(train_loader.dataset.num_sessions) + ], [ + train_loader.dataset.get_session(j).continuous_index[:train_end] + if train_loader.dataset.get_session(j).continuous_index + is not None else + train_loader.dataset.get_session(j).discrete_index[:train_end] + for j in range(train_loader.dataset.num_sessions) + ], + session_id=i, + batch_size=128) + train_y = train_loader.dataset.get_session( + i + ).continuous_index[:train_end] if train_loader.dataset.get_session( + i + ).continuous_index is not None else train_loader.dataset.get_session( + i).discrete_index[:train_end] + + if valid_loader is not None: + if valid_loader.dataset.get_session( + i).neural.shape[0] > max_timesteps: + valid_end = max_timesteps + else: + valid_end = -1 + valid_x = self.transform([ + valid_loader.dataset.get_session(j).neural[:valid_end] + for j in range(valid_loader.dataset.num_sessions) + ], [ + valid_loader.dataset.get_session( + j).continuous_index[:valid_end] + if valid_loader.dataset.get_session(j).continuous_index + is not None else valid_loader.dataset.get_session( + j).discrete_index[:valid_end] + for j in range(valid_loader.dataset.num_sessions) + ], + session_id=i, + batch_size=128) + valid_y = valid_loader.dataset.get_session( + i + ).continuous_index[:valid_end] if valid_loader.dataset.get_session( + i + ).continuous_index is not None else valid_loader.dataset.get_session( + i).discrete_index[:valid_end] + + if decode == "knn": + decoder = cebra.KNNDecoder() + elif decode == "ridge": + decoder = cebra.RidgeRegressor() + else: + raise NotImplementedError(f"Decoder {decode} not implemented.") + + decoder.fit(train_x.cpu().numpy(), train_y.cpu().numpy()) + train_scores.append( + decoder.score(train_x.cpu().numpy(), + train_y.cpu().numpy())) + + if valid_loader is not None: + valid_scores.append( + decoder.score(valid_x.cpu().numpy(), + valid_y.cpu().numpy())) + + if valid_loader is None: + return np.array(train_scores).mean() + else: + return np.array(train_scores).mean(), np.array(valid_scores).mean() diff --git a/docs/source/api/pytorch/helpers.rst b/docs/source/api/pytorch/helpers.rst index 6615e6f9..6e206521 100644 --- a/docs/source/api/pytorch/helpers.rst +++ b/docs/source/api/pytorch/helpers.rst @@ -35,3 +35,11 @@ Data helpers .. automodule:: cebra.data.helper :members: :show-inheritance: + + +Masking helpers +---------------- + +.. 
automodule:: cebra.data.masking + :members: + :show-inheritance: diff --git a/tests/_util.py b/tests/_util.py index b4a0e07d..42dd54cb 100644 --- a/tests/_util.py +++ b/tests/_util.py @@ -74,3 +74,8 @@ def parametrize_with_checks_slow(fast_arguments, slow_arguments): slow_arg, generate_only=True))[0] for slow_arg in slow_arguments ] return parametrize_slow("estimator,check", fast_params, slow_params) + + +def parametrize_device(func): + _devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",) + return pytest.mark.parametrize("device", _devices)(func) diff --git a/tests/test_data_masking.py b/tests/test_data_masking.py new file mode 100644 index 00000000..78448300 --- /dev/null +++ b/tests/test_data_masking.py @@ -0,0 +1,215 @@ +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import copy + +import pytest +import torch + +import cebra.data.masking + + +@pytest.mark.parametrize("mask", [ + cebra.data.masking.RandomNeuronMask, + cebra.data.masking.RandomTimestepMask, + cebra.data.masking.NeuronBlockMask, +]) +def test_random_mask(mask: cebra.data.masking.Mask): + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + mask = mask(masking_value=0.5) + masked_data = mask.apply_mask(copy.deepcopy(data)) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert (masked_data <= 1).all() and ( + masked_data >= 0).all(), "Masked data should only contain values 0 or 1" + assert torch.sum(masked_data) < torch.sum( + data), "Masked data should have fewer active neurons than original data" + + +def test_timeblock_mask(): + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + mask = cebra.data.masking.TimeBlockMask(masking_value=(0.035, 10)) + masked_data = mask.apply_mask(copy.deepcopy(data)) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert (masked_data <= 1).all() and ( + masked_data >= 0).all(), "Masked data should only contain values 0 or 1" + assert torch.sum(masked_data) < torch.sum( + data), "Masked data should have fewer active neurons than original data" + + +def test_masked_mixin_no_masks(): + mixin = cebra.data.masking.MaskedMixin() + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + masked_data = mixin.apply_mask(copy.deepcopy(data)) + + assert torch.equal( + data, + masked_data), "Data should remain unchanged when no masks are applied" + + +@pytest.mark.parametrize( + "mask", ["RandomNeuronMask", "RandomTimestepMask", "NeuronBlockMask"]) +def test_masked_mixin_random_mask(mask): + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + + mixin = cebra.data.masking.MaskedMixin() + assert mixin._masks == [], "Masks should be empty initially" + + mixin.set_masks({mask: 0.5}) + 
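+    # A single float configures one mask at that masking value; list and tuple
+    # values for the same mask type are exercised further below in this test.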
assert len(mixin._masks) == 1, "One mask should be set" + assert isinstance(mixin._masks[0], + getattr(cebra.data.masking, + mask)), f"Mask should be of type {mask}" + if isinstance(mixin._masks[0], cebra.data.masking.NeuronBlockMask): + assert mixin._masks[ + 0].mask_prop == 0.5, "Masking value should be set correctly" + else: + assert mixin._masks[ + 0].mask_ratio == 0.5, "Masking value should be set correctly" + + masked_data = mixin.apply_mask(copy.deepcopy(data)) + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified when a mask is applied" + + mixin.set_masks({mask: [0.5, 0.1]}) + assert len(mixin._masks) == 1, "One mask should be set" + assert isinstance(mixin._masks[0], + getattr(cebra.data.masking, + mask)), f"Mask should be of type {mask}" + masked_data = mixin.apply_mask(copy.deepcopy(data)) + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified when a mask is applied" + + mixin.set_masks({mask: (0.3, 0.9, 0.05)}) + assert len(mixin._masks) == 1, "One mask should be set" + assert isinstance(mixin._masks[0], + getattr(cebra.data.masking, + mask)), f"Mask should be of type {mask}" + masked_data = mixin.apply_mask(copy.deepcopy(data)) + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified when a mask is applied" + + +def test_apply_mask_with_time_block_mask(): + mixin = cebra.data.masking.MaskedMixin() + + with pytest.raises(AssertionError, match="sampled_rate.*masked_seq_len"): + mixin.set_masks({"TimeBlockMask": 0.2}) + + with pytest.raises(AssertionError, match="(sampled_rate.*masked_seq_len)"): + mixin.set_masks({"TimeBlockMask": [0.2, 10]}) + + with pytest.raises(AssertionError, match="between.*0.0.*1.0"): + mixin.set_masks({"TimeBlockMask": (-2, 10)}) + + with pytest.raises(AssertionError, match="between.*0.0.*1.0"): + mixin.set_masks({"TimeBlockMask": (2, 10)}) + + with pytest.raises(AssertionError, match="integer.*greater"): + mixin.set_masks({"TimeBlockMask": (0.2, -10)}) + + with pytest.raises(AssertionError, match="integer.*greater"): + mixin.set_masks({"TimeBlockMask": (0.2, 5.5)}) + + mixin.set_masks({"TimeBlockMask": (0.035, 10)}) # Correct usage + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + masked_data = mixin.apply_mask(copy.deepcopy(data)) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified when a mask is applied" + + +def test_multiple_masks_mixin(): + mixin = cebra.data.masking.MaskedMixin() + mixin.set_masks({"RandomNeuronMask": 0.5, "RandomTimestepMask": 0.3}) + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + + masked_data = mixin.apply_mask(copy.deepcopy(data)) + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, + masked_data), "Data should be modified when multiple masks are applied" + + masked_data2 = mixin.apply_mask(copy.deepcopy(masked_data)) + assert masked_data2.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, + masked_data2), "Data should be modified when multiple masks are applied" + 
assert not torch.equal( + masked_data, masked_data2 + ), "Masked data should be different for different iterations" + + +def test_single_dim_input(): + mixin = cebra.data.masking.MaskedMixin() + mixin.set_masks({"RandomNeuronMask": 0.5}) + data = torch.ones((10, 1, 30)) # Single neuron + masked_data = mixin.apply_mask(copy.deepcopy(data)) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified even with a single neuron" + + mixin = cebra.data.masking.MaskedMixin() + mixin.set_masks({"RandomTimestepMask": 0.5}) + data = torch.ones((10, 20, 1)) # Single timestep + masked_data = mixin.apply_mask(copy.deepcopy(data)) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, + masked_data), "Data should be modified even with a single timestep" + + +def test_apply_mask_with_invalid_input(): + mixin = cebra.data.masking.MaskedMixin() + mixin.set_masks({"RandomNeuronMask": 0.5}) + + with pytest.raises(ValueError, match="Data must be a 3D tensor"): + data = torch.ones( + (10, 20, 30, 40)) # Invalid tensor shape (extra dimension) + mixin.apply_mask(data) + + with pytest.raises(ValueError, match="Data must be a float32 tensor"): + data = torch.ones((10, 20, 30), dtype=torch.int32) + mixin.apply_mask(data) + + +def test_apply_mask_with_chunk_size(): + mixin = cebra.data.masking.MaskedMixin() + mixin.set_masks({"RandomNeuronMask": 0.5}) + data = torch.ones((10000, 20, 30)) # Large tensor to test chunking + masked_data = mixin.apply_mask(copy.deepcopy(data), chunk_size=1000) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified when a mask is applied" diff --git a/tests/test_distributions.py b/tests/test_distributions.py index 2b704391..656559bb 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -411,3 +411,16 @@ def test_new_delta_normal_with_multidimensional_index(delta, numerical_check): pytest.skip( "multivariate delta distribution can not accurately sample with the " "given parameters. TODO: Add a warning message for these cases.") + + +@pytest.mark.parametrize("time_offset", [1, 5, 10]) +def test_unified_distribution(time_offset): + dataset = cebra_datasets.init("demo-continuous-unified") + sampler = cebra_distr.UnifiedSampler(dataset, time_offset=time_offset) + + num_samples = 5 + sample = sampler.sample_prior(num_samples) + assert sample.shape == (dataset.num_sessions, num_samples) + + positive = sampler.sample_conditional(sample) + assert positive.shape == (dataset.num_sessions, num_samples) diff --git a/tests/test_loader.py b/tests/test_loader.py index 562f64a7..cb6be9a7 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -19,16 +19,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# +import _util +import numpy as np import pytest import torch import cebra.data import cebra.io - -def parametrize_device(func): - _devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",) - return pytest.mark.parametrize("device", _devices)(func) +BATCH_SIZE = 32 +NUMS_NEURAL = [3, 4, 5] class LoadSpeed: @@ -107,7 +107,11 @@ def _assert_dataset_on_correct_device(loader, device): assert hasattr(loader, "dataset") assert hasattr(loader, "device") assert isinstance(loader.dataset, cebra.io.HasDevice) - assert loader.dataset.neural.device.type == device + if isinstance(loader, cebra.data.SingleSessionDataset): + assert loader.dataset.neural.device.type == device + elif isinstance(loader, cebra.data.MultiSessionDataset): + for session in loader.dataset.iter_sessions(): + assert session.neural.device.type == device def test_demo_data(): @@ -130,13 +134,15 @@ def _to_str(val): assert _to_str(first) == _to_str(second) -@parametrize_device +@_util.parametrize_device @pytest.mark.parametrize( "data_name, loader_initfunc", [ ("demo-discrete", cebra.data.DiscreteDataLoader), ("demo-continuous", cebra.data.ContinuousDataLoader), ("demo-mixed", cebra.data.MixedDataLoader), + ("demo-continuous-multisession", cebra.data.MultiSessionLoader), + ("demo-continuous-unified", cebra.data.UnifiedLoader), ], ) def test_device(data_name, loader_initfunc, device): @@ -147,7 +153,7 @@ def test_device(data_name, loader_initfunc, device): other_device = swap.get(device) dataset = RandomDataset(N=100, device=other_device) - loader = loader_initfunc(dataset, num_steps=10, batch_size=32) + loader = loader_initfunc(dataset, num_steps=10, batch_size=BATCH_SIZE) loader.to(device) assert loader.dataset == dataset _assert_device(loader.device, device) @@ -156,7 +162,7 @@ def test_device(data_name, loader_initfunc, device): _assert_device(loader.get_indices(10).reference.device, device) -@parametrize_device +@_util.parametrize_device @pytest.mark.parametrize("prior", ("uniform", "empirical")) def test_discrete(prior, device, benchmark): dataset = RandomDataset(N=100, device=device) @@ -171,7 +177,7 @@ def test_discrete(prior, device, benchmark): benchmark(load_speed) -@parametrize_device +@_util.parametrize_device @pytest.mark.parametrize("conditional", ("time", "time_delta")) def test_continuous(conditional, device, benchmark): dataset = RandomDataset(N=100, d=5, device=device) @@ -199,7 +205,7 @@ def _check_attributes(obj, is_list=False): raise TypeError() -@parametrize_device +@_util.parametrize_device @pytest.mark.parametrize( "data_name, loader_initfunc", [ @@ -211,7 +217,7 @@ def _check_attributes(obj, is_list=False): def test_singlesession_loader(data_name, loader_initfunc, device): data = cebra.datasets.init(data_name) data.to(device) - loader = loader_initfunc(data, num_steps=10, batch_size=32) + loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE) _assert_dataset_on_correct_device(loader, device) index = loader.get_indices(100) @@ -219,25 +225,33 @@ def test_singlesession_loader(data_name, loader_initfunc, device): for batch in loader: _check_attributes(batch) - assert len(batch.positive) == 32 + assert len(batch.positive) == BATCH_SIZE -def test_multisession_cont_loader(): - data = cebra.datasets.MultiContinuous(nums_neural=[3, 4, 5], - num_behavior=5, - num_timepoints=100) - loader = cebra.data.ContinuousMultiSessionDataLoader( - data, - num_steps=10, - batch_size=32, - ) +@_util.parametrize_device +@pytest.mark.parametrize( + "data_name, loader_initfunc", + [ + 
("demo-continuous-multisession", + cebra.data.ContinuousMultiSessionDataLoader), + ("demo-discrete-multisession", + cebra.data.DiscreteMultiSessionDataLoader), + ], +) +def test_multisession_loader(data_name, loader_initfunc, device): + data = cebra.datasets.init(data_name) + data.to(device) + loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE) + + _assert_dataset_on_correct_device(loader, device) # Check the sampler assert hasattr(loader, "sampler") ref_idx = loader.sampler.sample_prior(1000) - assert len(ref_idx) == 3 # num_sessions - for session in range(3): - assert ref_idx[session].max() < 100 + assert len(ref_idx) == len(NUMS_NEURAL) + for session in range(len(NUMS_NEURAL)): + assert ref_idx[session].max( + ) < cebra.datasets.demo._DEFAULT_NUM_TIMEPOINTS pos_idx, idx, idx_rev = loader.sampler.sample_conditional(ref_idx) assert pos_idx is not None @@ -245,6 +259,8 @@ def test_multisession_cont_loader(): assert idx_rev is not None batch = next(iter(loader)) + for i, n_neurons in enumerate(NUMS_NEURAL): + assert batch[i].reference.shape == (BATCH_SIZE, n_neurons, 10) def _mix(array, idx): shape = array.shape @@ -259,82 +275,70 @@ def _process(batch, feature_dim=1): [b.reference.flatten(1).mean(dim=1, keepdims=True) for b in batch], dim=0).repeat(1, 1, feature_dim) - assert batch[0].reference.shape == (32, 3, 10) - assert batch[1].reference.shape == (32, 4, 10) - assert batch[2].reference.shape == (32, 5, 10) - dummy_prediction = _process(batch, feature_dim=6) - assert dummy_prediction.shape == (3, 32, 6) + assert dummy_prediction.shape == (3, BATCH_SIZE, 6) _mix(dummy_prediction, batch[0].index) + index = loader.get_indices(100) + #print(index[0]) + #print(type(index)) + _check_attributes(index, is_list=False) -def test_multisession_disc_loader(): - data = cebra.datasets.MultiDiscrete(nums_neural=[3, 4, 5], - num_timepoints=100) - loader = cebra.data.DiscreteMultiSessionDataLoader( - data, - num_steps=10, - batch_size=32, - ) - - # Check the sampler - assert hasattr(loader, "sampler") - ref_idx = loader.sampler.sample_prior(1000) - assert len(ref_idx) == 3 # num_sessions - - # Check sample points are in session length range - for session in range(3): - assert ref_idx[session].max() < loader.sampler.session_lengths[session] - pos_idx, idx, idx_rev = loader.sampler.sample_conditional(ref_idx) - - assert pos_idx is not None - assert idx is not None - assert idx_rev is not None - - batch = next(iter(loader)) - - def _mix(array, idx): - shape = array.shape - n, m = shape[:2] - mixed = array.reshape(n * m, -1)[idx] - print(mixed.shape, array.shape, idx.shape) - return mixed.reshape(shape) - - def _process(batch, feature_dim=1): - """Given list_i[(N,d_i)] batch, return (#session, N, feature_dim) tensor""" - return torch.stack( - [b.reference.flatten(1).mean(dim=1, keepdims=True) for b in batch], - dim=0).repeat(1, 1, feature_dim) - - assert batch[0].reference.shape == (32, 3, 10) - assert batch[1].reference.shape == (32, 4, 10) - assert batch[2].reference.shape == (32, 5, 10) - - dummy_prediction = _process(batch, feature_dim=6) - assert dummy_prediction.shape == (3, 32, 6) - _mix(dummy_prediction, batch[0].index) + for batch in loader: + _check_attributes(batch, is_list=True) + for session_batch in batch: + assert len(session_batch.positive) == BATCH_SIZE -@parametrize_device +@_util.parametrize_device @pytest.mark.parametrize( "data_name, loader_initfunc", - [('demo-discrete-multisession', cebra.data.DiscreteMultiSessionDataLoader), - ("demo-continuous-multisession", - 
cebra.data.ContinuousMultiSessionDataLoader)], + [ + ("demo-continuous-unified", cebra.data.UnifiedLoader), + ], ) -def test_multisession_loader(data_name, loader_initfunc, device): - # TODO change number of timepoints across the sessions - +def test_unified_loader(data_name, loader_initfunc, device): data = cebra.datasets.init(data_name) - kwargs = dict(num_steps=10, batch_size=32) - loader = loader_initfunc(data, **kwargs) + data.to(device) + loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE) + + _assert_dataset_on_correct_device(loader, device) + + # Check the sampler + num_samples = 100 + assert hasattr(loader, "sampler") + ref_idx = loader.sampler.sample_all_uniform_prior(num_samples) + assert ref_idx.shape == (len(NUMS_NEURAL), num_samples) + assert isinstance(ref_idx, np.ndarray) + + for session in range(len(NUMS_NEURAL)): + assert ref_idx[session].max( + ) < cebra.datasets.demo._DEFAULT_NUM_TIMEPOINTS + pos_idx = loader.sampler.sample_conditional(ref_idx) + assert pos_idx.shape == (len(NUMS_NEURAL), num_samples) + + for session in range(len(NUMS_NEURAL)): + ref_idx = torch.from_numpy( + loader.sampler.sample_all_uniform_prior( + num_samples=num_samples)[session]) + assert ref_idx.shape == (num_samples,) + all_ref_idx = loader.sampler.sample_all_sessions(ref_idx=ref_idx, + session_id=session) + assert all_ref_idx.shape == (len(NUMS_NEURAL), num_samples) + assert isinstance(all_ref_idx, torch.Tensor) + for i in range(len(all_ref_idx)): + assert all_ref_idx[i].max( + ) < cebra.datasets.demo._DEFAULT_NUM_TIMEPOINTS + + for i in range(len(all_ref_idx)): + pos_idx = loader.sampler.sample_conditional(all_ref_idx) + assert pos_idx.shape == (len(NUMS_NEURAL), num_samples) + + # Check the batch + batch = next(iter(loader)) + assert batch.reference.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10) + assert batch.positive.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10) + assert batch.negative.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10) index = loader.get_indices(100) - print(index[0]) - print(type(index)) _check_attributes(index, is_list=False) - - for batch in loader: - _check_attributes(batch, is_list=True) - for session_batch in batch: - assert len(session_batch.positive) == 32 diff --git a/tests/test_solver.py b/tests/test_solver.py index 3be54c2e..8a54aaa9 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -41,10 +41,8 @@ def _get_loader(data_name, loader_initfunc): loader = loader_initfunc(data, **kwargs) return loader, data - OUTPUT_DIMENSION = 3 - def _make_model(dataset, model_architecture="offset10-model"): # TODO flexible input dimension # return nn.Sequential( @@ -349,3 +347,55 @@ def test_multi_session(data_name, loader_initfunc, model_architecture, solver.save(temp_dir) solver.load(temp_dir) _assert_equal(fitted_solver, solver) + + +def _make_val_data(dataset): + if isinstance(dataset, cebra.datasets.demo.DemoDataset): + return dataset.neural + elif isinstance(dataset, cebra.datasets.demo.DemoDatasetUnified): + return [session.neural for session in dataset.iter_sessions()], [ + session.continuous_index for session in dataset.iter_sessions() + ] + + +@pytest.mark.parametrize( + "data_name, model_architecture, loader_initfunc, solver_initfunc", + [(dataset, model, loader, cebra.solver.UnifiedSolver) + for dataset, loader in [ + ("demo-continuous-unified", cebra.data.UnifiedLoader), + ] + for model in ["offset1-model", "offset10-model"]]) +def test_unified_session(data_name, model_architecture, loader_initfunc, + solver_initfunc): + loader, data = 
_get_loader(data_name, loader_initfunc) + model = _make_model(data, model_architecture) + data.configure_for(model) + offset = model.get_offset() + + criterion = cebra.models.InfoNCE() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + solver = solver_initfunc(model=model, + criterion=criterion, + optimizer=optimizer) + + batch = next(iter(loader)) + assert batch.reference.shape == (32, loader.dataset.input_dimension, + len(offset)) + + log = solver.step(batch) + assert isinstance(log, dict) + + solver.fit(loader) + data, labels = _make_val_data(loader.dataset) + + assert solver.num_sessions == 3 + assert solver.n_features == sum( + [data[i].shape[1] for i in range(len(data))]) + + for i in range(loader.dataset.num_sessions): + emb = solver.transform(data, labels, session_id=i) + assert emb.shape == (loader.dataset.num_timepoints, 3) + + emb = solver.transform(data, labels, session_id=i, batch_size=300) + assert emb.shape == (loader.dataset.num_timepoints, 3) \ No newline at end of file
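
A minimal sketch of the masking API that the tests above exercise (the mask names,
set_masks, apply_mask, and the chunk_size argument are the ones used in
cebra/data/masking.py and tests/test_data_masking.py in this diff; the tensor shapes
are placeholders):

    import copy

    import torch

    import cebra.data.masking

    mixin = cebra.data.masking.MaskedMixin()
    mixin.set_masks({"RandomNeuronMask": 0.3, "RandomTimestepMask": 0.2})

    # apply_mask expects a float32 tensor shaped (batch, n_neurons, offset);
    # large inputs can be processed in chunks via chunk_size. The input is
    # copied first, mirroring the tests, since apply_mask may modify it in place.
    data = torch.ones((64, 20, 10), dtype=torch.float32)
    masked = mixin.apply_mask(copy.deepcopy(data), chunk_size=1000)
    assert masked.shape == data.shape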