diff --git a/.dockerignore b/.dockerignore index 3937fd07..945ab50a 100644 --- a/.dockerignore +++ b/.dockerignore @@ -5,3 +5,5 @@ tests/ third_party/ tools/ PKGBUILD + +!docs/requirements.txt diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a231258f..f109294c 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -10,6 +10,7 @@ on: jobs: build: + timeout-minutes: 30 strategy: fail-fast: true matrix: @@ -18,27 +19,33 @@ jobs: # We aim to support the versions on pytorch.org # as well as selected previous versions on # https://pytorch.org/get-started/previous-versions/ - torch-version: ["2.2.2", "2.4.0"] + torch-version: ["2.4.0", "2.6.0"] + sklearn-version: ["latest"] include: - os: windows-latest torch-version: 2.4.0 python-version: "3.10" + sklearn-version: "latest" + - os: ubuntu-latest + torch-version: 2.4.0 + python-version: "3.10" + sklearn-version: "legacy" runs-on: ${{ matrix.os }} steps: - name: Cache dependencies id: pip-cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/.cache/pip - key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }} + key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}-sklearn_${{ matrix.sklearn-version }} - name: Checkout code uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -48,6 +55,11 @@ jobs: python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu pip install '.[dev,datasets,integrations]' + - name: Check sklearn legacy version + if: matrix.sklearn-version == 'legacy' + run: | + pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]' + - name: Run the formatter run: | make format @@ -56,6 +68,10 @@ jobs: run: | make codespell + - name: Check the documentation coverage + run: | + make interrogate + - name: Check CITATION.cff validity run: | cffconvert --validate diff --git a/.github/workflows/doc-coverage.yml b/.github/workflows/doc-coverage.yml deleted file mode 100644 index 268cbee0..00000000 --- a/.github/workflows/doc-coverage.yml +++ /dev/null @@ -1,82 +0,0 @@ -name: PR Status -# Adapted from https://github.com/shift-happens-benchmark/icml-2022/blob/main/.github/workflows/pr-status.yml -# Apache 2.0 licensed - - -# NOTE(stes): Use pull_request_target instead of pull_request to allow -# to post comments on the current PR, even when an external contributor -# opens a PR. -# IMPORTANT: DO NOT EXPOSE REPOSITORY SECRETS WITHIN THIS PR! -on: - pull_request: - branches: - - main - - public - - dev - -permissions: - pull-requests: write - -jobs: - documentation-status: - runs-on: ubuntu-latest - strategy: - matrix: - python-version: ['3.8'] - - steps: - # NOTE(stes) currently not used, we check - # the entire codebase now by default. 
- #- uses: actions/checkout@v3 - # with: - # ref: main - - uses: actions/checkout@v3 - - uses: actions/cache@v1 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip - restore-keys: | - ${{ runner.os }}-pip - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 - with: - python-version: ${{ matrix.python-version }} - - name: Install package - run: | - python -m pip install --upgrade pip setuptools wheel - pip install interrogate==1.5.0 - - name: documentation - id: documentation - run: | - RESULT=$(make --silent interrogate) - RESULT=$(tail -n +3 <<< $RESULT) - STATUS=$(tail -n1 <<< $RESULT) - STATUS=$(sed 's/-//g' <<< $STATUS) - # trim - STATUS=$(echo $STATUS | xargs echo -n) - RESULT=$(head -n -1 <<< $RESULT) - # remove second to last line - RESULTA=$(head -n -2 <<< $RESULT) - RESULTB=$(tail -n1 <<< $RESULT) - NL=$'\n' - RESULT="$RESULTA${NL}||||||${NL}$RESULTB" - RESULT="$RESULT${NL}${NL}$STATUS" - RESULT="${RESULT//'%'/'%25'}" - RESULT="${RESULT//$'\n'/'%0A'}" - RESULT="${RESULT//$'\r'/'%0D'}" - echo "::set-output name=result::$RESULT" - continue-on-error: true - #- name: comment documentation result on PR - # uses: thollander/actions-comment-pull-request@v1 - # with: - # message: | - # ## Docstring Coverage Report - # ${{ steps.documentation.outputs.result }} - # comment_includes: '## Docstring Coverage Report' - # GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - # following snippet borrowed from - # https://stackoverflow.com/a/58003436 - # CC BY-SA 4.0, Peter Evans - - name: Fail on insufficient coverage - if: steps.documentation.outcome != 'success' - run: exit 1 diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 83c9d829..39f882b9 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -7,8 +7,6 @@ on: pull_request: branches: - main - - public - - dev jobs: build: @@ -17,7 +15,7 @@ jobs: steps: - name: Cache dependencies id: pip-cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip @@ -51,28 +49,33 @@ jobs: path: docs/source/demo_notebooks ref: main - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + - name: Set up Python 3.10 + uses: actions/setup-python@v5 with: - python-version: ${{ matrix.python-version }} + python-version: "3.10" - name: Install package run: | python -m pip install --upgrade pip setuptools wheel # NOTE(stes) Pandoc version must be at least (2.14.2) but less than (4.0.0). - # as of 29/10/23. Ubuntu 22.04 which is used for ubuntu-latest only has an + # as of 29/10/23. Ubuntu 22.04 which is used for ubuntu-latest only has an # old pandoc version (2.9.). We will hence install the latest version manually. # previou: sudo apt-get install -y pandoc - wget https://github.com/jgm/pandoc/releases/download/3.1.9/pandoc-3.1.9-1-amd64.deb - sudo dpkg -i pandoc-3.1.9-1-amd64.deb - rm pandoc-3.1.9-1-amd64.deb - pip install torch --extra-index-url https://download.pytorch.org/whl/cpu - pip install '.[docs]' + # NOTE(stes): Updated to latest version as of 17/04/2025, v3.6.4. 
+ wget -q https://github.com/jgm/pandoc/releases/download/3.6.4/pandoc-3.6.4-1-amd64.deb + sudo dpkg -i pandoc-3.6.4-1-amd64.deb + rm pandoc-3.6.4-1-amd64.deb + pip install -r docs/requirements.txt + + - name: Check software versions + run: | + sphinx-build --version + pandoc --version - name: Build docs run: | ls docs/source/cebra-figures - # later also add the -n option to check for broken links + export SPHINXBUILD="sphinx-build" export SPHINXOPTS="-W --keep-going -n" make docs diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml index d6950119..ac078fd9 100644 --- a/.github/workflows/release-pypi.yml +++ b/.github/workflows/release-pypi.yml @@ -23,11 +23,18 @@ jobs: steps: - name: Cache dependencies id: pip-cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip + - name: Install dependencies + run: | + pip install --upgrade pip + pip install wheel + # NOTE(stes) see https://github.com/pypa/twine/issues/1216#issuecomment-2629069669 + pip install "packaging>=24.2" + - name: Checkout code uses: actions/checkout@v3 diff --git a/.gitignore b/.gitignore index 30b65ee3..e30f5f43 100644 --- a/.gitignore +++ b/.gitignore @@ -7,6 +7,25 @@ experiments/sweeps exports/ demo_notebooks/ assets/ +.remove + +# demo run +.vscode/ +auxiliary_behavior_data.h5 +cebra_model.pt +data.npz +grid_search_models/ +neural_data.npz +saved_models/ + +# demo run +.vscode/ +auxiliary_behavior_data.h5 +cebra_model.pt +data.npz +grid_search_models/ +neural_data.npz +saved_models/ # Binary files *.png diff --git a/AUTHORS.md b/AUTHORS.md new file mode 100644 index 00000000..17db8887 --- /dev/null +++ b/AUTHORS.md @@ -0,0 +1,26 @@ + + + + +CEBRA was initially developed by **Mackenzie Mathis** and **Steffen Schneider** (2021+), who are co-inventors on the patent application [WO2023143843](https://infoscience.epfl.ch/entities/patent/0d9debed-4d22-47b7-bad1-f211e7010323). +**Jin Hwa Lee** contributed significantly to our first paper: + +> **Schneider, S., Lee, J.H., & Mathis, M.W.** +> [*Learnable latent embeddings for joint behavioural and neural analysis.*](https://doi.org/10.1038/s41586-023-06031-6) +> Nature 617, 360–368 (2023) + +CEBRA is actively developed by [**Mackenzie Mathis**](https://www.mackenziemathislab.org/) and [**Steffen Schneider**](https://dynamical-inference.ai/) and their labs. + +It is a publicly available tool that has benefited from contributions and suggestions from many individuals: [CEBRA/graphs/contributors](https://github.com/AdaptiveMotorControlLab/CEBRA/graphs/contributors). + +## CEBRA Extensions + +### 2023 +- **Steffen Schneider, Rodrigo González Laiz, Markus Frey, Mackenzie W. Mathis** + [*Identifiable attribution maps using regularized contrastive learning.*](https://sslneurips23.github.io/paper_pdfs/paper_80.pdf) + NeurIPS 4th Workshop on Self-Supervised Learning: Theory and Practice (2023) + +### 2025 +- **Steffen Schneider, Rodrigo González Laiz, Anastasiia Filippova, Markus Frey, Mackenzie W. Mathis** + [*Time-series attribution maps with regularized contrastive learning.*](https://openreview.net/forum?id=aGrCXoTB4P) + AISTATS (2025) diff --git a/Dockerfile b/Dockerfile index d734ee6f..1a280a30 100644 --- a/Dockerfile +++ b/Dockerfile @@ -40,7 +40,7 @@ RUN make dist FROM cebra-base # install the cebra wheel -ENV WHEEL=cebra-0.4.0-py2.py3-none-any.whl +ENV WHEEL=cebra-0.6.0a1-py3-none-any.whl WORKDIR /build COPY --from=wheel /build/dist/${WHEEL} . 
RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets]' diff --git a/Makefile b/Makefile index ca8c5480..a863a921 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -CEBRA_VERSION := 0.4.0 +CEBRA_VERSION := 0.6.0a1 dist: python3 -m pip install virtualenv @@ -55,7 +55,7 @@ interrogate: --ignore-private \ --ignore-magic \ --omit-covered-files \ - -f 90 \ + -f 80 \ cebra # Build documentation using sphinx diff --git a/NOTICE.yml b/NOTICE.yml index 3588b5e6..bf498e0f 100644 --- a/NOTICE.yml +++ b/NOTICE.yml @@ -35,3 +35,83 @@ - 'tests/**/*.py' - 'docs/**/*.py' - 'conda/**/*.yml' + +- header: | + CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables + © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) + Source code: + https://github.com/AdaptiveMotorControlLab/CEBRA + + Please see LICENSE.md for the full license document: + https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md + + Adapted from https://github.com/rpatrik96/nl-causal-representations/blob/master/care_nl_ica/dep_mat.py, + licensed under the following MIT License: + + MIT License + + Copyright (c) 2022 Patrik Reizinger + + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + include: + - 'cebra/attribution/jacobian.py' + + +- header: | + CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables + © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) + Source code: + https://github.com/AdaptiveMotorControlLab/CEBRA + + Please see LICENSE.md for the full license document: + https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md + + This file contains the PyTorch implementation of Jacobian regularization described in [1]. + Judy Hoffman, Daniel A. Roberts, and Sho Yaida, + "Robust Learning with Jacobian Regularization," 2019. + [arxiv:1908.02729](https://arxiv.org/abs/1908.02729) + + Adapted from https://github.com/facebookresearch/jacobian_regularizer/blob/main/jacobian/jacobian.py + licensed under the following MIT License: + + MIT License + + Copyright (c) Facebook, Inc. and its affiliates. 
+ + Permission is hereby granted, free of charge, to any person obtaining a copy + of this software and associated documentation files (the "Software"), to deal + in the Software without restriction, including without limitation the rights + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + copies of the Software, and to permit persons to whom the Software is + furnished to do so, subject to the following conditions: + + The above copyright notice and this permission notice shall be included in all + copies or substantial portions of the Software. + + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + SOFTWARE. + + include: + - 'cebra/models/jacobian_regularizer.py' diff --git a/PKGBUILD b/PKGBUILD index 07fa3a1d..48088dcb 100644 --- a/PKGBUILD +++ b/PKGBUILD @@ -1,7 +1,7 @@ # Maintainer: Steffen Schneider pkgname=python-cebra _pkgname=cebra -pkgver=0.4.0 +pkgver=0.6.0a1 pkgrel=1 pkgdesc="Consistent Embeddings of high-dimensional Recordings using Auxiliary variables" url="https://cebra.ai" @@ -40,7 +40,7 @@ build() { package() { cd $srcdir/${_pkgname}-${pkgver} - pip install --ignore-installed --no-deps --root="${pkgdir}" dist/${_pkgname}-${pkgver}-py2.py3-none-any.whl + pip install --ignore-installed --no-deps --root="${pkgdir}" dist/${_pkgname}-${pkgver}-py3-none-any.whl find ${pkgdir} -iname __pycache__ -exec rm -r {} \; 2>/dev/null || echo install -Dm 644 LICENSE.md $pkgdir/usr/share/licenses/${pkgname}/LICENSE } diff --git a/cebra/__init__.py b/cebra/__init__.py index 204cd2a2..cb2cbd06 100644 --- a/cebra/__init__.py +++ b/cebra/__init__.py @@ -66,7 +66,7 @@ import cebra.integrations.sklearn as sklearn -__version__ = "0.4.0" +__version__ = "0.6.0a1" __all__ = ["CEBRA"] __allow_lazy_imports = False __lazy_imports = {} diff --git a/cebra/attribution/__init__.py b/cebra/attribution/__init__.py new file mode 100644 index 00000000..e1d8306a --- /dev/null +++ b/cebra/attribution/__init__.py @@ -0,0 +1,38 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Attribution methods for CEBRA. + +This module was added in v0.6.0 and contains attribution methods described and benchmarked +in [Schneider2025]_. + + +.. [Schneider2025] Schneider, S., González Laiz, R., Filippova, A., Frey, M., & Mathis, M. W. (2025). 
+ Time-series attribution maps with regularized contrastive learning. + The 28th International Conference on Artificial Intelligence and Statistics. + https://openreview.net/forum?id=aGrCXoTB4P +""" +import cebra.registry + +cebra.registry.add_helper_functions(__name__) + +from cebra.attribution.attribution_models import * +from cebra.attribution.jacobian_attribution import * diff --git a/cebra/attribution/_jacobian.py b/cebra/attribution/_jacobian.py new file mode 100644 index 00000000..00102aeb --- /dev/null +++ b/cebra/attribution/_jacobian.py @@ -0,0 +1,142 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Adapted from https://github.com/rpatrik96/nl-causal-representations/blob/master/care_nl_ica/dep_mat.py, +# licensed under the following MIT License: +# +# MIT License +# +# Copyright (c) 2022 Patrik Reizinger +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# + +from typing import Union + +import numpy as np +import torch + + +def tensors_to_cpu_and_double(vars_: list[torch.Tensor]) -> list[torch.Tensor]: + """Convert a list of tensors to CPU and double precision. + + Args: + vars_: List of PyTorch tensors to convert + + Returns: + List of tensors converted to CPU and double precision + """ + cpu_vars = [] + for v in vars_: + if v.is_cuda: + v = v.to("cpu") + cpu_vars.append(v.double()) + return cpu_vars + + +def tensors_to_cuda(vars_: list[torch.Tensor], + cuda_device: str) -> list[torch.Tensor]: + """Convert a list of tensors to CUDA device. + + Args: + vars_: List of PyTorch tensors to convert + cuda_device: CUDA device to move tensors to + + Returns: + List of tensors moved to specified CUDA device + """ + cpu_vars = [] + for v in vars_: + if not v.is_cuda: + v = v.to(cuda_device) + cpu_vars.append(v) + return cpu_vars + + +def compute_jacobian( + model: torch.nn.Module, + input_vars: list[torch.Tensor], + mode: str = "autograd", + cuda_device: str = "cuda", + double_precision: bool = False, + convert_to_numpy: bool = True, + hybrid_solver: bool = False, +) -> Union[torch.Tensor, np.ndarray]: + """Compute the Jacobian matrix for a given model and input. + + This function computes the Jacobian matrix using PyTorch's autograd functionality. 
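To make the autograd strategy concrete, here is a minimal standalone sketch of the per-output-dimension gradient loop that `compute_jacobian` performs, on a toy encoder. The toy model, input shapes, and random data are illustrative assumptions, not part of the patch.

```python
import torch

# Toy encoder standing in for a CEBRA model (illustration only).
model = torch.nn.Sequential(torch.nn.Linear(10, 32), torch.nn.Tanh(),
                            torch.nn.Linear(32, 3))
x = torch.randn(5, 10, requires_grad=True)   # (num_samples, num_features)
y = model(x)                                  # (num_samples, output_dim)

rows = []
for i in range(y.shape[1]):
    # Gradient of output dimension i w.r.t. the inputs, one row of the Jacobian.
    grad, = torch.autograd.grad(y[:, i:i + 1],
                                x,
                                grad_outputs=torch.ones_like(y[:, i:i + 1]),
                                retain_graph=True)
    rows.append(grad)

jacobian = torch.stack(rows, dim=1)  # (num_samples, output_dim, num_features)
print(jacobian.shape)                # torch.Size([5, 3, 10])
```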
+ It supports both CPU and CUDA computation, as well as single and double precision. + + Args: + model: PyTorch model to compute Jacobian for + input_vars: List of input tensors + mode: Computation mode, currently only "autograd" is supported + cuda_device: Device to use for CUDA computation + double_precision: If True, use double precision + convert_to_numpy: If True, convert output to numpy array + hybrid_solver: If True, concatenate multiple outputs along dimension 1 + + Returns: + Jacobian matrix as either PyTorch tensor or numpy array + """ + if double_precision: + model = model.to("cpu").double() + input_vars = tensors_to_cpu_and_double(input_vars) + if hybrid_solver: + output = model(*input_vars) + output_vars = torch.cat(output, dim=1).to("cpu").double() + else: + output_vars = model(*input_vars).to("cpu").double() + else: + model = model.to(cuda_device).float() + input_vars = tensors_to_cuda(input_vars, cuda_device=cuda_device) + + if hybrid_solver: + output = model(*input_vars) + output_vars = torch.cat(output, dim=1) + else: + output_vars = model(*input_vars) + + if mode == "autograd": + jacob = [] + for i in range(output_vars.shape[1]): + grads = torch.autograd.grad( + output_vars[:, i:i + 1], + input_vars, + retain_graph=True, + create_graph=False, + grad_outputs=torch.ones(output_vars[:, i:i + 1].shape).to( + output_vars.device), + ) + jacob.append(torch.cat(grads, dim=1)) + + jacobian = torch.stack(jacob, dim=1) + + jacobian = jacobian.detach().cpu() + + if convert_to_numpy: + jacobian = jacobian.numpy() + + return jacobian diff --git a/cebra/attribution/attribution_models.py b/cebra/attribution/attribution_models.py new file mode 100644 index 00000000..ddbc7a37 --- /dev/null +++ b/cebra/attribution/attribution_models.py @@ -0,0 +1,720 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import dataclasses +import time + +import cvxpy as cp +import numpy as np +import scipy.linalg +import sklearn.metrics +import torch +import torch.nn as nn +import tqdm +from captum.attr import NeuronFeatureAblation +from captum.attr import NeuronGradient +from captum.attr import NeuronGradientShap +from captum.attr import NeuronIntegratedGradients + +import cebra +import cebra.attribution._jacobian +from cebra.attribution import register + + +@dataclasses.dataclass +class AttributionMap: + """Base class for computing attribution maps for CEBRA models. + + Args: + model: The trained CEBRA model to analyze + input_data: Input data tensor to compute attributions for + output_dimension: Output dimension to analyze. If ``None``, uses model's output dimension + num_samples: Number of samples to use for attribution. 
If ``None``, uses full dataset + seed: Random seed which is used to subsample the data. Only relevant if ``num_samples`` is not ``None``. + """ + + model: nn.Module + input_data: torch.Tensor + output_dimension: int = None + num_samples: int = None + seed: int = 9712341 + + def __post_init__(self): + if isinstance(self.model, cebra.models.ConvolutionalModelMixin): + data = cebra.data.TensorDataset(self.input_data, + continuous=torch.zeros( + len(self.input_data))) + data.configure_for(self.model) + offset = self.model.get_offset() + + #NOTE: explain, why do we do this again? + input_data = data[torch.arange(offset.left, + len(data) - offset.right + 1)].to( + self.input_data.device) + + # subsample the data + if self.num_samples is not None: + if self.num_samples > input_data.shape[0]: + raise ValueError( + f"You are using a bigger number of samples to " + f"subsample ({self.num_samples}) than the number " + f"of samples in the dataset ({input_data.shape[0]}).") + + random_generator = torch.Generator() + random_generator.manual_seed(self.seed) + num_elements = input_data.size(0) + random_indices = torch.randperm( + num_elements, generator=random_generator)[:self.num_samples] + input_data = input_data[random_indices] + + self.input_data = input_data + + def compute_attribution_map(self): + """Compute the attribution map for the model. + + Returns: + dict: Attribution maps and their variants + + Raises: + NotImplementedError: Must be implemented by subclasses + """ + raise NotImplementedError + + def compute_metrics(self, attribution_map, ground_truth_map): + """Compute metrics comparing attribution map to ground truth. + + This function computes various statistical metrics to compare the attribution values + between connected and non-connected neurons based on a ground truth connectivity map. + It separates the attribution values into two groups based on the binary ground truth, + and calculates summary statistics and differences between these groups. 
+ + Args: + attribution_map: Computed attribution values representing the strength of connections + between neurons + ground_truth_map: Binary ground truth connectivity map where True indicates a + connected neuron and False indicates a non-connected neuron + + Returns: + dict: Dictionary containing the following metrics: + - max/mean/min_nonconnected: Statistics for non-connected neurons + - max/mean/min_connected: Statistics for connected neurons + - gap_max: Difference between max connected and max non-connected values + - gap_mean: Difference between mean connected and mean non-connected values + - gap_min: Difference between min connected and min non-connected values + - gap_minmax: Difference between min connected and max non-connected values + - max/min_jacobian: Global max/min values across all neurons + """ + assert np.issubdtype(ground_truth_map.dtype, bool) + connected_neurons = attribution_map[np.where(ground_truth_map)] + non_connected_neurons = attribution_map[np.where(~ground_truth_map)] + assert connected_neurons.size == ground_truth_map.sum() + assert non_connected_neurons.size == ground_truth_map.size - ground_truth_map.sum( + ) + assert connected_neurons.size + non_connected_neurons.size == attribution_map.size == ground_truth_map.size + + max_connected = np.max(connected_neurons) + mean_connected = np.mean(connected_neurons) + min_connected = np.min(connected_neurons) + + max_nonconnected = np.max(non_connected_neurons) + mean_nonconnected = np.mean(non_connected_neurons) + min_nonconnected = np.min(non_connected_neurons) + + metrics = { + 'max_nonconnected': max_nonconnected, + 'mean_nonconnected': mean_nonconnected, + 'min_nonconnected': min_nonconnected, + 'max_connected': max_connected, + 'mean_connected': mean_connected, + 'min_connected': min_connected, + 'gap_max': max_connected - max_nonconnected, + 'gap_mean': mean_connected - mean_nonconnected, + 'gap_min': min_connected - min_nonconnected, + 'gap_minmax': min_connected - max_nonconnected, + 'max_jacobian': np.max(attribution_map), + 'min_jacobian': np.min(attribution_map), + } + return metrics + + def compute_attribution_score(self, attribution_map, ground_truth_map): + """Compute ROC AUC score between attribution map and ground truth. + + Args: + attribution_map: Computed attribution values + ground_truth_map: Binary ground truth connectivity map + + Returns: + float: ROC AUC score + """ + assert attribution_map.shape == ground_truth_map.shape + assert np.issubdtype(ground_truth_map.dtype, bool) + fpr, tpr, _ = sklearn.metrics.roc_curve( # noqa: codespell:ignore fpr, tpr + ground_truth_map.flatten(), attribution_map.flatten()) + auc = sklearn.metrics.auc(fpr, tpr) # noqa: codespell:ignore fpr, tpr + return auc + + @staticmethod + def _check_moores_penrose_conditions( + matrix: np.ndarray, matrix_inverse: np.ndarray) -> np.ndarray: + """Check Moore-Penrose conditions for a single matrix pair. 
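For reference, the two evaluations defined above (the connected/non-connected gap statistics and the ROC AUC score) reduce to a few lines of NumPy and scikit-learn. The random toy arrays below are assumptions for illustration only.

```python
import numpy as np
import sklearn.metrics

rng = np.random.default_rng(0)
ground_truth = rng.random((8, 120)) < 0.2           # True = connected neuron
attribution = rng.random((8, 120)) + ground_truth   # connected entries score higher

# gap_mean: mean attribution of connected minus non-connected entries.
gap_mean = attribution[ground_truth].mean() - attribution[~ground_truth].mean()

# ROC AUC of the flattened attribution map against the boolean ground truth.
fpr, tpr, _ = sklearn.metrics.roc_curve(ground_truth.flatten(),
                                        attribution.flatten())
print(gap_mean, sklearn.metrics.auc(fpr, tpr))       # both close to 1.0 here
```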
+ + Args: + matrix: Input matrix + matrix_inverse: Putative pseudoinverse matrix + + Returns: + np.ndarray: Boolean array indicating which conditions are satisfied + """ + matrix_inverse = matrix_inverse.T + condition_1 = np.allclose(matrix @ matrix_inverse @ matrix, matrix) + condition_2 = np.allclose(matrix_inverse @ matrix @ matrix_inverse, + matrix_inverse) + condition_3 = np.allclose((matrix @ matrix_inverse).T, + matrix @ matrix_inverse) + condition_4 = np.allclose((matrix_inverse @ matrix).T, + matrix_inverse @ matrix) + + return np.array([condition_1, condition_2, condition_3, condition_4]) + + def check_moores_penrose_conditions( + self, jacobian: np.ndarray, + jacobian_pseudoinverse: np.ndarray) -> np.ndarray: + """Check Moore-Penrose conditions for Jacobian matrices. + + Args: + jacobian: Jacobian matrices of shape (num samples, output_dim, num_neurons) + jacobian_pseudoinverse: Pseudoinverse matrices of shape (num samples, num_neurons, output_dim) + + Returns: + Boolean array of shape (num samples, 4) indicating satisfied conditions + """ + # check the four conditions + conditions = np.zeros((jacobian.shape[0], 4)) + for i, (matrix, inverse_matrix) in enumerate( + zip(jacobian, jacobian_pseudoinverse)): + conditions[i] = self._check_moores_penrose_conditions( + matrix, inverse_matrix) + return conditions + + def _inverse(self, jacobian, method="lsq"): + """Compute inverse/pseudoinverse of Jacobian matrices. + + Args: + jacobian: Input Jacobian matrices + method: Inversion method ('lsq_cvxpy', 'lsq', or 'svd') + + Returns: + (Inverse matrices, computation time) + """ + # NOTE(stes): Before we used "np.linalg.pinv" here, which + # is numerically not stable for the Jacobian matrices we + # need to compute. + start_time = time.time() + Jfinv = np.zeros_like(jacobian) + if method == "lsq_cvxpy": + for i in tqdm(range(len(jacobian))): + Jfinv[i] = self._inverse_lsq_cvxpy(jacobian[i]).T + elif method == "lsq": + for i in range(len(jacobian)): + Jfinv[i] = self._inverse_lsq_scipy(jacobian[i]).T + elif method == "svd": + for i in range(len(jacobian)): + Jfinv[i] = self._inverse_svd(jacobian[i]).T + else: + raise NotImplementedError(f"Method {method} not implemented.") + end_time = time.time() + return Jfinv, end_time - start_time + + @staticmethod + def _inverse_lsq_cvxpy(matrix: np.ndarray, + solver: str = 'SCS') -> np.ndarray: + """Compute least squares inverse using CVXPY. + + Args: + matrix: Input matrix + solver: CVXPY solver to use + + Returns: + np.ndarray: Least squares inverse matrix + """ + + matrix_param = cp.Parameter((matrix.shape[0], matrix.shape[1])) + matrix_param.value = matrix + + identity = np.eye(matrix.shape[0]) + matrix_inverse = cp.Variable((matrix.shape[1], matrix.shape[0])) + # noqa: codespell + objective = cp.Minimize( + cp.norm(matrix @ matrix_inverse - identity, + "fro")) # noqa: codespell:ignore fro + prob = cp.Problem(objective) + prob.solve(verbose=False, solver=solver) + + return matrix_inverse.value + + @staticmethod + def _inverse_lsq_scipy(jacobian): + """Compute least squares inverse using scipy.linalg.lstsq. + + Args: + jacobian: Input Jacobian matrix + + Returns: + np.ndarray: Least squares inverse matrix + """ + return scipy.linalg.lstsq(jacobian, np.eye(jacobian.shape[0]))[0] + + @staticmethod + def _inverse_svd(jacobian): + """Compute pseudoinverse using SVD. 
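The four Moore-Penrose conditions checked above can be verified directly with NumPy/SciPy; a small sketch on a random wide matrix (the shapes are placeholders chosen to resemble a Jacobian with more neurons than latent dimensions):

```python
import numpy as np
import scipy.linalg

A = np.random.randn(3, 10)      # (output_dim, num_neurons), wide like a Jacobian
A_pinv = scipy.linalg.pinv(A)   # (num_neurons, output_dim)

conditions = [
    np.allclose(A @ A_pinv @ A, A),                # 1. A A+ A = A
    np.allclose(A_pinv @ A @ A_pinv, A_pinv),      # 2. A+ A A+ = A+
    np.allclose((A @ A_pinv).T, A @ A_pinv),       # 3. A A+ symmetric
    np.allclose((A_pinv @ A).T, A_pinv @ A),       # 4. A+ A symmetric
]
print(conditions)  # [True, True, True, True]
```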
+ + Args: + jacobian: Input Jacobian matrix + + Returns: + np.ndarray: Pseudoinverse matrix + """ + return scipy.linalg.pinv(jacobian) + + def _reduce_attribution_map(self, attribution_maps): + """Reduce attribution maps by averaging across dimensions. + + Args: + attribution_maps: Dictionary of attribution maps to reduce + + Returns: + dict: Reduced attribution maps + """ + + def _reduce(full_jacobian): + if full_jacobian.ndim == 4: + jf_convabs = abs(full_jacobian).mean(-1) + jf = full_jacobian.mean(-1) + else: + jf_convabs = full_jacobian + jf = full_jacobian + return jf, jf_convabs + + result = {} + for key, value in attribution_maps.items(): + result[key], result[f'{key}-convabs'] = _reduce(value) + return result + + +@dataclasses.dataclass +@register("jacobian-based") +class JFMethodBased(AttributionMap): + """Compute the attribution map using the Jacobian of the model encoder.""" + + def _compute_jacobian(self, input_data): + return cebra.attribution._jacobian.compute_jacobian( + self.model, + input_vars=[input_data], + mode="autograd", + cuda_device=self.input_data.device, + double_precision=False, + convert_to_numpy=True, + hybrid_solver=False, + ) + + def compute_attribution_map(self): + + full_jacobian = self._compute_jacobian(self.input_data) + + result = {} + for key, value in self._reduce_attribution_map({ + 'jf': full_jacobian + }).items(): + result[key] = value + for method in ['lsq', 'svd']: + print(f"Computing inverse for {key} with method {method}") + result[f"{key}-inv-{method}"], result[ + f'time_inversion_{method}'] = self._inverse(value, + method=method) + # result[f"{key}-inv-{method}-conditions"] = self.check_moores_penrose_conditions(value, result[f"{key}-inv-{method}"]) + + return result + + +@dataclasses.dataclass +@register("jacobian-based-batched") +class JFMethodBasedBatched(JFMethodBased): + """Compute an attribution map based on the Jacobian using mini-batches. + + See also: + :py:class:`JFMethodBased` + """ + + def compute_attribution_map(self, batch_size=1024): + if batch_size > self.input_data.shape[0]: + raise ValueError( + f"Batch size ({batch_size}) is bigger than data ({self.input_data.shape[0]})" + ) + + input_data_batches = torch.split(self.input_data, batch_size) + full_jacobian = [] + for input_data_batch in input_data_batches: + jacobian_batch = self._compute_jacobian(input_data_batch) + full_jacobian.append(jacobian_batch) + full_jacobian = np.vstack(full_jacobian) + + result = {} + for key, value in self._reduce_attribution_map({ + 'jf': full_jacobian + }).items(): + result[key] = value + for method in ['lsq', 'svd']: + + result[f"{key}-inv-{method}"], result[ + f'time_inversion_{method}'] = self._inverse(value, + method=method) + + return result + + +@dataclasses.dataclass +@register("neuron-gradient") +class NeuronGradientMethod(AttributionMap): + """Compute the attribution map using the neuron gradient from Captum. + + Note: + This method is equivalent to Jacobian-based attributions, but + uses a different backend implementation. 
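For context, the offset-axis reduction performed by `_reduce_attribution_map` amounts to averaging a 4D Jacobian over its last (time-offset) dimension, once on the signed values and once on their absolute values. A toy sketch with placeholder shapes:

```python
import numpy as np

# (num_samples, output_dim, num_neurons, offset_length) -- toy array.
full_jacobian = np.random.randn(100, 3, 120, 10)

jf = full_jacobian.mean(-1)                 # signed average, (100, 3, 120)
jf_convabs = np.abs(full_jacobian).mean(-1) # absolute-value average, (100, 3, 120)
print(jf.shape, jf_convabs.shape)
```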
+ """ + + def __post_init__(self): + super().__post_init__() + self.captum_model = NeuronGradient(forward_func=self.model, + layer=self.model) + + def compute_attribution_map(self, attribute_to_neuron_input=False): + attribution_map = [] + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=self.input_data, + attribute_to_neuron_input=attribute_to_neuron_input, + neuron_selector=s) + + attribution_map.append(att.detach().cpu().numpy()) + + attribution_map = np.array(attribution_map) + attribution_map = np.swapaxes(attribution_map, 1, 0) + + result = {} + for key, value in self._reduce_attribution_map({ + 'neuron-gradient': attribution_map + }).items(): + result[key] = value + + for method in ['lsq', 'svd']: + result[f"{key}-inv-{method}"], result[ + f'time_inversion_{method}'] = self._inverse(value, + method=method) + # result[f"{key}-inv-{method}-conditions"] = self.check_moores_penrose_conditions(value, result[f"{key}-inv-{method}"]) + + return result + + +@dataclasses.dataclass +@register("neuron-gradient-batched") +class NeuronGradientMethodBatched(NeuronGradientMethod): + """As :py:class:`NeuronGradientMethod`, but using mini-batches. + + See also: + :py:class:`NeuronGradientMethod` + """ + + def compute_attribution_map(self, + attribute_to_neuron_input=False, + batch_size=1024): + input_data_batches = torch.split(self.input_data, batch_size) + + attribution_map = [] + for input_data_batch in input_data_batches: + attribution_map_batch = [] + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=input_data_batch, + attribute_to_neuron_input=attribute_to_neuron_input, + neuron_selector=s) + + attribution_map_batch.append(att.detach().cpu().numpy()) + + attribution_map_batch = np.array(attribution_map_batch) + attribution_map_batch = np.swapaxes(attribution_map_batch, 1, 0) + attribution_map.append(attribution_map_batch) + + attribution_map = np.vstack(attribution_map) + return self._reduce_attribution_map({ + 'neuron-gradient': attribution_map, + #'neuron-gradient-invsvd': self._inverse_svd(attribution_map) + }) + + +@dataclasses.dataclass +@register("feature-ablation") +class FeatureAblationMethod(AttributionMap): + """Compute the attribution map using the feature ablation method from Captum.""" + + def __post_init__(self): + super().__post_init__() + self.captum_model = NeuronFeatureAblation(forward_func=self.model, + layer=self.model) + + def compute_attribution_map(self, + baselines=None, + feature_mask=None, + perturbations_per_eval=1, + attribute_to_neuron_input=False): + attribution_map = [] + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=self.input_data, + neuron_selector=s, + baselines=baselines, + perturbations_per_eval=perturbations_per_eval, + feature_mask=feature_mask, + attribute_to_neuron_input=attribute_to_neuron_input) + + attribution_map.append(att.detach().cpu().numpy()) + + attribution_map = np.array(attribution_map) + attribution_map = np.swapaxes(attribution_map, 1, 0) + return self._reduce_attribution_map( + {'feature-ablation': attribution_map}) + + +@dataclasses.dataclass +@register("feature-ablation-batched") +class FeatureAblationMethodBAtched(FeatureAblationMethod): + """As :py:class:`FeatureAblationMethod`, but using mini-batches. 
+ + See also: + :py:class:`FeatureAblationMethod` + """ + + def compute_attribution_map(self, + baselines=None, + feature_mask=None, + perturbations_per_eval=1, + attribute_to_neuron_input=False, + batch_size=1024): + + input_data_batches = torch.split(self.input_data, batch_size) + attribution_map = [] + for input_data_batch in input_data_batches: + attribution_map_batch = [] + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=input_data_batch, + neuron_selector=s, + baselines=baselines, + perturbations_per_eval=perturbations_per_eval, + feature_mask=feature_mask, + attribute_to_neuron_input=attribute_to_neuron_input) + + attribution_map_batch.append(att.detach().cpu().numpy()) + + attribution_map_batch = np.array(attribution_map_batch) + attribution_map_batch = np.swapaxes(attribution_map_batch, 1, 0) + attribution_map.append(attribution_map_batch) + + attribution_map = np.vstack(attribution_map) + return self._reduce_attribution_map( + {'feature-ablation': attribution_map}) + + +@dataclasses.dataclass +@register("integrated-gradients") +class IntegratedGradientsMethod(AttributionMap): + """Compute the attribution map using the integrated gradients method from Captum.""" + + def __post_init__(self): + super().__post_init__() + self.captum_model = NeuronIntegratedGradients(forward_func=self.model, + layer=self.model) + + def compute_attribution_map(self, + n_steps=50, + method='gausslegendre', + internal_batch_size=None, + attribute_to_neuron_input=False, + baselines=None): + if internal_batch_size == "dataset": + internal_batch_size = len(self.input_data) + + attribution_map = [] + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=self.input_data, + neuron_selector=s, + n_steps=n_steps, + method=method, + internal_batch_size=internal_batch_size, + attribute_to_neuron_input=attribute_to_neuron_input, + baselines=baselines, + ) + attribution_map.append(att.detach().cpu().numpy()) + + attribution_map = np.array(attribution_map) + attribution_map = np.swapaxes(attribution_map, 1, 0) + return self._reduce_attribution_map( + {'integrated-gradients': attribution_map}) + + +@dataclasses.dataclass +@register("integrated-gradients-batched") +class IntegratedGradientsMethodBatched(IntegratedGradientsMethod): + """As :py:class:`IntegratedGradientsMethod`, but using mini-batches. 
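A hedged, standalone sketch of the Captum call pattern used by the class above, assuming Captum is installed. The toy encoder and the loop over latent dimensions mirror how the diff drives `NeuronIntegratedGradients` (including passing the whole model as `layer`), but this snippet is an illustration, not the patched code itself.

```python
import torch
import torch.nn as nn
from captum.attr import NeuronIntegratedGradients

model = nn.Sequential(nn.Linear(120, 32), nn.Tanh(), nn.Linear(32, 3))
inputs = torch.randn(64, 120)

ig = NeuronIntegratedGradients(forward_func=model, layer=model)
attribution = [
    # One attribution pass per output (latent) dimension s.
    ig.attribute(inputs=inputs, neuron_selector=s, n_steps=50).detach().numpy()
    for s in range(3)
]
```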
+ + See also: + :py:class:`IntegratedGradientsMethod` + """ + + def compute_attribution_map(self, + n_steps=50, + method='gausslegendre', + internal_batch_size=None, + attribute_to_neuron_input=False, + baselines=None, + batch_size=1024): + + input_data_batches = torch.split(self.input_data, batch_size) + attribution_map = [] + for input_data_batch in input_data_batches: + attribution_map_batch = [] + if internal_batch_size == "dataset": + internal_batch_size = len(input_data_batch) + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=input_data_batch, + neuron_selector=s, + n_steps=n_steps, + method=method, + internal_batch_size=internal_batch_size, + attribute_to_neuron_input=attribute_to_neuron_input, + baselines=baselines, + ) + attribution_map_batch.append(att.detach().cpu().numpy()) + + attribution_map_batch = np.array(attribution_map_batch) + attribution_map_batch = np.swapaxes(attribution_map_batch, 1, 0) + attribution_map.append(attribution_map_batch) + + attribution_map = np.vstack(attribution_map) + return self._reduce_attribution_map( + {'integrated-gradients': attribution_map}) + + +@dataclasses.dataclass +@register("neuron-gradient-shap") +class NeuronGradientShapMethod(AttributionMap): + """Compute the attribution map using the neuron gradient SHAP method from Captum.""" + + def __post_init__(self): + super().__post_init__() + self.captum_model = NeuronGradientShap(forward_func=self.model, + layer=self.model) + + def compute_attribution_map(self, + baselines: str, + n_samples=5, + stdevs=0.0, + attribute_to_neuron_input=False): + + if baselines == "zeros": + baselines = torch.zeros(size=(self.input_data.shape), + device=self.input_data.device) + elif baselines == "shuffle": + data = self.input_data.flatten() + data = data[torch.randperm(len(data))] + baselines = data.reshape(self.input_data.shape) + else: + raise NotImplementedError(f"Baseline {baselines} not implemented.") + + attribution_map = [] + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=self.input_data, + neuron_selector=s, + baselines=baselines, + n_samples=n_samples, + stdevs=stdevs, + attribute_to_neuron_input=attribute_to_neuron_input, + ) + + attribution_map.append(att.detach().cpu().numpy()) + + attribution_map = np.array(attribution_map) + attribution_map = np.swapaxes(attribution_map, 1, 0) + return self._reduce_attribution_map( + {'neuron-gradient-shap': attribution_map}) + + +@dataclasses.dataclass +@register("neuron-gradient-shap-batched") +class NeuronGradientShapMethodBatched(NeuronGradientShapMethod): + """As :py:class:`NeuronGradientShapMethod`, but using mini-batches. 
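The "shuffle" baseline used above is simply a global permutation of the input values, reshaped back to the original tensor shape; a toy sketch:

```python
import torch

input_data = torch.randn(64, 120, 10)   # toy (batch, neurons, offset) tensor
flat = input_data.flatten()
baselines = flat[torch.randperm(len(flat))].reshape(input_data.shape)
print(baselines.shape)  # same shape as the input, values globally shuffled
```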
+ + See also: + :py:class:`NeuronGradientShapMethod` + """ + + def compute_attribution_map(self, + baselines: str, + n_samples=5, + stdevs=0.0, + attribute_to_neuron_input=False, + batch_size=1024): + + if baselines == "zeros": + baselines = torch.zeros(size=(self.input_data.shape), + device=self.input_data.device) + elif baselines == "shuffle": + data = self.input_data.flatten() + data = data[torch.randperm(len(data))] + baselines = data.reshape(self.input_data.shape) + else: + raise NotImplementedError(f"Baseline {baselines} not implemented.") + + input_data_batches = torch.split(self.input_data, batch_size) + attribution_map = [] + for input_data_batch in input_data_batches: + attribution_map_batch = [] + for s in range(self.output_dimension): + att = self.captum_model.attribute( + inputs=input_data_batch, + neuron_selector=s, + baselines=baselines, + n_samples=n_samples, + stdevs=stdevs, + attribute_to_neuron_input=attribute_to_neuron_input, + ) + + attribution_map_batch.append(att.detach().cpu().numpy()) + + attribution_map_batch = np.array(attribution_map_batch) + attribution_map_batch = np.swapaxes(attribution_map_batch, 1, 0) + attribution_map.append(attribution_map_batch) + + attribution_map = np.vstack(attribution_map) + return self._reduce_attribution_map( + {'neuron-gradient-shap': attribution_map}) diff --git a/cebra/attribution/jacobian_attribution.py b/cebra/attribution/jacobian_attribution.py new file mode 100644 index 00000000..f8db8344 --- /dev/null +++ b/cebra/attribution/jacobian_attribution.py @@ -0,0 +1,95 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Tools for computing attribution maps.""" + +from typing import Literal + +import numpy as np +import torch +from torch import nn + +import cebra.attribution._jacobian + +__all__ = ["get_attribution_map"] + + +def _prepare_inputs(inputs): + if not isinstance(inputs, torch.Tensor): + inputs = torch.from_numpy(inputs) + inputs.requires_grad_(True) + return inputs + + +def _prepare_model(model): + for p in model.parameters(): + p.requires_grad_(False) + return model + + +def get_attribution_map( + model: nn.Module, + input_data: torch.Tensor, + double_precision: bool = True, + convert_to_numpy: bool = True, + aggregate: Literal["mean", "sum", "max"] = "mean", + transform: Literal["none", "abs"] = "none", + hybrid_solver: bool = False, +): + """Estimate attribution maps using the Jacobian pseudo-inverse. + + The function estimates Jacobian matrices for each point in the model, + computes the pseudo-inverse (for every sample) and then aggregates + the resulting matrices to compute an attribution map. + + Args: + model: The neural network model for which to compute attributions. 
+ input_data: Input tensor or numpy array to compute attributions for. + double_precision: If ``True``, use double precision for computation. + convert_to_numpy: If ``True``, convert the output to numpy arrays. + aggregate: Method to aggregate attribution values across samples. + Options are ``"mean"``, ``"sum"``, or ``"max"``. + transform: Transformation to apply to attribution values. + Options are ``"none"`` or ``"abs"``. + hybrid_solver: If ``True``, handle multi-objective models differently. + + Returns: + A tuple containing the Jacobian matrix of shape (num_samples, output_dim, input_dim) + and the pseudo-inverse of the Jacobian matrix. + + """ + assert aggregate in ["mean", "sum", "max"] + + input_data = _prepare_inputs(input_data) + model = _prepare_model(model) + + # compute jacobian CEBRA model + jf = cebra.attribution._jacobian.compute_jacobian( + model, + input_vars=[input_data], + mode="autograd", + double_precision=double_precision, + convert_to_numpy=convert_to_numpy, + hybrid_solver=hybrid_solver, + ) + + jhatg = np.linalg.pinv(jf) + return jf, jhatg diff --git a/cebra/data/__init__.py b/cebra/data/__init__.py index ec753f18..145ff835 100644 --- a/cebra/data/__init__.py +++ b/cebra/data/__init__.py @@ -46,10 +46,9 @@ # these imports will not be reordered by isort (see .isort.cfg) from cebra.data.base import * from cebra.data.datatypes import * - from cebra.data.single_session import * from cebra.data.multi_session import * - +from cebra.data.multiobjective import * from cebra.data.datasets import * - from cebra.data.helper import * +from cebra.data.masking import * diff --git a/cebra/data/assets.py b/cebra/data/assets.py index 86695482..adea8413 100644 --- a/cebra/data/assets.py +++ b/cebra/data/assets.py @@ -93,7 +93,7 @@ def download_file_with_progress_bar(url: str, ) # Create the directory and any necessary parent directories - location_path.mkdir(exist_ok=True) + location_path.mkdir(parents=True, exist_ok=True) filename = filename_match.group(1) file_path = location_path / filename diff --git a/cebra/data/base.py b/cebra/data/base.py index e35e20c5..51199cec 100644 --- a/cebra/data/base.py +++ b/cebra/data/base.py @@ -27,6 +27,7 @@ import torch import cebra.data.assets as cebra_data_assets +import cebra.data.masking as cebra_data_masking import cebra.distributions import cebra.io from cebra.data.datatypes import Batch @@ -36,7 +37,7 @@ __all__ = ["Dataset", "Loader"] -class Dataset(abc.ABC, cebra.io.HasDevice): +class Dataset(abc.ABC, cebra.io.HasDevice, cebra_data_masking.MaskedMixin): """Abstract base class for implementing a dataset. The class attributes provide information about the shape of the data when @@ -193,7 +194,6 @@ def load_batch(self, index: BatchIndex) -> Batch: """ raise NotImplementedError() - @abc.abstractmethod def configure_for(self, model: "cebra.models.Model"): """Configure the dataset offset for the provided model. @@ -203,7 +203,7 @@ def configure_for(self, model: "cebra.models.Model"): Args: model: The model to configure the dataset for. 
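Stepping back to `get_attribution_map` defined above, a hedged end-to-end usage sketch with a toy encoder; the model architecture, sizes, and the final aggregation step are assumptions for illustration only.

```python
import numpy as np
import torch.nn as nn
from cebra.attribution.jacobian_attribution import get_attribution_map

# Toy encoder standing in for a trained CEBRA model.
model = nn.Sequential(nn.Linear(120, 32), nn.Tanh(), nn.Linear(32, 3))
neural = np.random.randn(500, 120)       # (num_samples, num_neurons)

jf, jhatg = get_attribution_map(model, neural)
print(jf.shape, jhatg.shape)             # (500, 3, 120) and (500, 120, 3)

# One possible aggregation across samples into a per-neuron attribution map.
attribution = np.abs(jhatg).mean(axis=0)
```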
""" - raise NotImplementedError + self.offset = model.get_offset() @dataclasses.dataclass diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index dbb2f1f5..59af8900 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -22,16 +22,23 @@ """Pre-defined datasets.""" import types -from typing import List, Literal, Optional, Tuple, Union +from typing import List, Literal, Optional, Tuple, TYPE_CHECKING, Union import numpy as np import numpy.typing as npt import torch import cebra.data as cebra_data +import cebra.data.masking as cebra_data_masking import cebra.helper as cebra_helper +import cebra.io as cebra_io +from cebra.data.datatypes import Batch +from cebra.data.datatypes import BatchIndex from cebra.data.datatypes import Offset +if TYPE_CHECKING: + from cebra.models import Model + class TensorDataset(cebra_data.SingleSessionDataset): """Discrete and/or continuously indexed dataset based on torch/numpy arrays. @@ -295,3 +302,205 @@ def _apply(self, func): def _iter_property(self, attr): return (getattr(data, attr) for data in self.iter_sessions()) + + +# TODO(stes): This should be a single session dataset? +class DatasetxCEBRA(cebra_io.HasDevice, cebra_data_masking.MaskedMixin): + """Dataset class for xCEBRA models. + + This class handles neural data and associated labels for xCEBRA models, providing + functionality for data loading and batch preparation. + + Attributes: + neural: Neural data as a torch.Tensor or numpy array + labels: Labels associated with the data + offset: Offset for the dataset + + Args: + neural: Neural data as a torch.Tensor or numpy array + device: Device to store the data on (default: "cpu") + **labels: Additional keyword arguments for labels associated with the data + """ + + def __init__( + self, + neural: Union[torch.Tensor, npt.NDArray], + device="cpu", + **labels, + ): + super().__init__(device) + self.neural = neural + self.labels = labels + self.offset = Offset(0, 1) + + @property + def input_dimension(self) -> int: + """Get the input dimension of the neural data. + + Returns: + The number of features in the neural data + """ + return self.neural.shape[1] + + def __len__(self): + """Get the length of the dataset. + + Returns: + Number of samples in the dataset + """ + return len(self.neural) + + def configure_for(self, model: "Model"): + """Configure the dataset offset for the provided model. + + Call this function before indexing the dataset. This sets the + :py:attr:`offset` attribute of the dataset. + + Args: + model: The model to configure the dataset for. + """ + self.offset = model.get_offset() + + def expand_index(self, index: torch.Tensor) -> torch.Tensor: + """Expand indices based on the configured offset. + + Args: + index: A one-dimensional tensor of type long containing indices + to select from the dataset. + + Returns: + An expanded index of shape ``(len(index), len(self.offset))`` where + the elements will be + ``expanded_index[i,j] = index[i] + j - self.offset.left`` for all ``j`` + in ``range(0, len(self.offset))``. + + Note: + Requires the :py:attr:`offset` to be set. + """ + offset = torch.arange(-self.offset.left, + self.offset.right, + device=index.device) + + index = torch.clamp(index, self.offset.left, + len(self) - self.offset.right) + + return index[:, None] + offset[None, :] + + def __getitem__(self, index): + """Get item(s) from the dataset at the specified index. 
+ + Args: + index: Index or indices to retrieve + + Returns: + The neural data at the specified indices, with dimensions transposed + """ + index = self.expand_index(index) + return self.neural[index].transpose(2, 1) + + def load_batch_supervised(self, index: Batch, + labels_supervised) -> torch.Tensor: + """Load a batch for supervised learning. + + Args: + index: Batch indices for reference data + labels_supervised: Labels to load for supervised learning + + Returns: + Batch containing reference data and corresponding labels + """ + assert index.negative is None + assert index.positive is None + labels = [ + self.labels[label].to(self.device) for label in labels_supervised + ] + + return Batch( + reference=self[index.reference], + positive=[label[index.reference] for label in labels], + negative=None, + ) + + def load_batch_contrastive(self, index: BatchIndex) -> Batch: + """Load a batch for contrastive learning. + + Args: + index: BatchIndex containing reference, positive and negative indices + + Returns: + Batch containing reference, positive and negative samples + """ + assert isinstance(index.positive, list) + return Batch( + reference=self[index.reference], + positive=[self[idx] for idx in index.positive], + negative=self[index.negative], + ) + + +class UnifiedDataset(DatasetCollection): + """Multi session dataset made up of a list of datasets, considered as a unique session. + + Considering the sessions as a unique session, or pseudo-session, is used to later train a single + model for all the sessions, even if they originally contain a variable number of neurons. + To do that, we sample ref/pos/neg for each session and concatenate them along the neurons axis. + + For instance, for a batch size ``batch_size``, we sample ``(batch_size, num_neurons(session), offset)`` for + each type of samples (ref/pos/neg) and then concatenate so that the final :py:class:`cebra.data.datatypes.Batch` + is of shape ``(batch_size, total_num_neurons, offset)``, with ``total_num_neurons`` is the sum of all the + ``num_neurons(session)``. + """ + + def __init__(self, *datasets: cebra_data.SingleSessionDataset): + super().__init__(*datasets) + + @property + def input_dimension(self) -> int: + """Returns the sum of the input dimension for each session.""" + return np.sum([ + self.get_input_dimension(session_id) + for session_id in range(self.num_sessions) + ]) + + def _get_batches(self, index): + """Return the data at the specified index location.""" + return [ + cebra_data.Batch( + reference=self.get_session(session_id)[ + index.reference[session_id]], + positive=self.get_session(session_id)[ + index.positive[session_id]], + negative=self.get_session(session_id)[ + index.negative[session_id]], + ) for session_id in range(self.num_sessions) + ] + + def load_batch(self, index: BatchIndex) -> Batch: + """Return the data at the specified index location. + + Concatenate batches for each sessions on the number of neurons axis. + + Args: + batches: List of :py:class:`cebra.data.datatypes.Batch` sampled for each session. An instance + :py:class:`cebra.data.datatypes.Batch` of the list is of shape ``(batch_size, num_neurons(session), offset)``. 
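The concatenation described here reduces to a `torch.cat` along the neuron axis; a toy sketch with two sessions of different sizes:

```python
import torch

batch_size, offset_len = 32, 10
session_a = torch.randn(batch_size, 120, offset_len)  # 120 neurons
session_b = torch.randn(batch_size, 80, offset_len)   # 80 neurons

reference = torch.cat([session_a, session_b], dim=1)
print(reference.shape)  # torch.Size([32, 200, 10]) -> total_num_neurons = 200
```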
+ + Returns: + A :py:class:`cebra.data.datatypes.Batch`, of shape ``(batch_size, total_num_neurons, offset)``, where + ``total_num_neurons`` is the sum of all the ``num_neurons(session)`` + """ + batches = self._get_batches(index) + + return cebra_data.Batch( + reference=self.apply_mask( + torch.cat([batch.reference for batch in batches], dim=1)), + positive=self.apply_mask( + torch.cat([batch.positive for batch in batches], dim=1)), + negative=self.apply_mask( + torch.cat([batch.negative for batch in batches], dim=1)), + ) + + def __getitem__(self, args) -> List[Batch]: + """Return a set of samples from all sessions.""" + + session_id, index = args + return self.get_session(session_id).__getitem__(index) diff --git a/cebra/data/load.py b/cebra/data/load.py index 6f1b86e5..02714ad0 100644 --- a/cebra/data/load.py +++ b/cebra/data/load.py @@ -275,11 +275,11 @@ def _is_dlc_df(h5_file: IO[bytes], df_keys: List[str]) -> bool: """ try: if ["_i_table", "table"] in df_keys: - df = pd.read_hdf(h5_file, key="table") + df = read_hdf(h5_file, key="table") else: - df = pd.read_hdf(h5_file, key=df_keys[0]) + df = read_hdf(h5_file, key=df_keys[0]) except KeyError: - df = pd.read_hdf(h5_file) + df = read_hdf(h5_file) return all(value in df.columns.names for value in ["scorer", "bodyparts", "coords"]) @@ -348,7 +348,7 @@ def load_from_h5(file: Union[pathlib.Path, str], key: str, Returns: A :py:func:`numpy.array` containing the data of interest extracted from the :py:class:`pandas.DataFrame`. """ - df = pd.read_hdf(file, key=key) + df = read_hdf(file, key=key) if columns is None: loaded_array = df.values elif isinstance(columns, list) and df.columns.nlevels == 1: @@ -716,3 +716,21 @@ def _get_loader(file_ending: str) -> _BaseLoader: if file_ending not in __loaders.keys() or file_ending == "": raise OSError(f"File ending {file_ending} not supported.") return __loaders[file_ending] + + +def read_hdf(filename, key=None): + """Read HDF5 file using pandas, with fallback to h5py if pandas fails. + + Args: + filename: Path to HDF5 file + key: Optional key to read from HDF5 file. If None, tries "df_with_missing" + then falls back to first available key. + + Returns: + pandas.DataFrame: The loaded data + + Raises: + RuntimeError: If both pandas and h5py fail to load the file + """ + + return pd.read_hdf(filename, key=key) diff --git a/cebra/data/mask.py b/cebra/data/mask.py new file mode 100644 index 00000000..946d97a4 --- /dev/null +++ b/cebra/data/mask.py @@ -0,0 +1,327 @@ +import abc +import random +from typing import List, Tuple, Union + +import numpy as np +import torch + +__all__ = [ + "Mask", "RandomNeuronMask", "RandomTimestepMask", "NeuronBlockMask", + "TimeBlockMask" +] + + +class Mask: + + def __init__(self, masking_value: Union[float, List[float], Tuple[float]]): + self._check_masking_parameters(masking_value) + + @abc.abstractmethod + def apply_mask(self, data: torch.Tensor) -> torch.Tensor: + raise NotImplementedError + + @abc.abstractmethod + def _select_masking_params(): + raise NotImplementedError + + def _check_masking_parameters(self, masking_value: Union[float, List[float], + Tuple[float]]): + """ + The masking values are the masking ratio to apply. + It can be a single ratio, a list of ratio that will be picked randomly or + a tuple of (min, max, step_size) that will be used to create a list of ratios + from which to sample randomly. 
+
+        """
+        if isinstance(masking_value, float):
+            assert 0.0 < masking_value < 1.0, (
+                f"Masking ratio {masking_value} for {self.__name__()} "
+                "should be between 0.0 and 1.0.")
+
+        elif isinstance(masking_value, list):
+            assert all(isinstance(ratio, float) for ratio in masking_value), (
+                f"Masking ratios {masking_value} for {self.__name__()} "
+                "should all be floats.")
+            assert all(0.0 < ratio < 1.0 for ratio in masking_value), (
+                f"Masking ratios {masking_value} for {self.__name__()} "
+                "should be between 0.0 and 1.0.")
+
+        elif isinstance(masking_value, tuple):
+            assert len(masking_value) == 3, (
+                f"Masking ratios {masking_value} for {self.__name__()} "
+                "should be a tuple of (min, max, step).")
+            assert 0.0 <= masking_value[0] < masking_value[1] <= 1.0, (
+                f"Masking ratios {masking_value} for {self.__name__()} "
+                "should be between 0.0 and 1.0.")
+            assert masking_value[2] < masking_value[1] - masking_value[0], (
+                f"Masking step {masking_value[2]} for {self.__name__()} "
+                "should be smaller than the difference between min "
+                f"({masking_value[0]}) and max ({masking_value[1]}).")
+
+        else:
+            raise ValueError(
+                f"Masking ratio {masking_value} for {self.__name__()} "
+                "should be a float, list of floats or a tuple of (min, max, step)."
+            )
+
+
+class RandomNeuronMask(Mask):
+
+    def __init__(self, masking_value: Union[float, List[float], Tuple[float]]):
+        super().__init__(masking_value)
+        self.mask_ratio = masking_value
+
+    def __name__(self):
+        return "RandomNeuronMask"
+
+    def apply_mask(self, data: torch.Tensor) -> torch.Tensor:
+        """Apply random masking on the neuron dimension.
+
+        Args:
+            data: Batch of size ``(batch_size, n_neurons, offset)``. The masking
+                ratio (proportion of neurons to mask) is chosen by
+                :py:meth:`_select_masking_params`; the default value of 0.3 comes
+                from the MtM paper: https://arxiv.org/pdf/2407.14668v2
+
+        Returns:
+            torch.Tensor: The mask, a tensor of the same size as the input data,
+            with masked neurons set to 0 and kept neurons set to 1.
+        """
+        batch_size, n_neurons, offset_length = data.shape
+        mask_ratio = self._select_masking_params()
+
+        # Random mask: shape [batch_size, n_neurons], different per batch and neuron
+        masked = torch.rand(batch_size, n_neurons,
+                            device=data.device) < mask_ratio
+        return (~masked).int().unsqueeze(2).expand(
+            -1, -1, offset_length)  # Expand to all timesteps
+
+    def _select_masking_params(self) -> float:
+        """Select the masking ratio to apply.
+
+        The configured value can be a single ratio, a list of ratios from which
+        one is picked randomly, or a tuple of (min, max, step_size) used to
+        create a list of ratios from which one is sampled randomly.
+        """
+        if isinstance(self.mask_ratio, float):
+            selected_value = self.mask_ratio
+
+        elif isinstance(self.mask_ratio, list):
+            selected_value = random.choice(self.mask_ratio)
+
+        elif isinstance(self.mask_ratio, tuple):
+            min_val, max_val, step_size = self.mask_ratio
+            selected_value = random.choice(
+                np.arange(min_val, max_val + step_size, step_size).tolist())
+
+        else:
+            raise ValueError(
+                f"Masking ratio {self.mask_ratio} for {self.__name__()} "
+                "should be a float, list of floats or a tuple of (min, max, step)."
+            )
+
+        return selected_value
+
+
+class RandomTimestepMask(Mask):
+
+    def __init__(self, masking_value: Union[float, List[float], Tuple[float]]):
+        super().__init__(masking_value)
+        self.mask_ratio = masking_value
+
+    def __name__(self):
+        return "RandomTimestepMask"
+
+    def apply_mask(self, data: torch.Tensor) -> torch.Tensor:
+        """Apply random masking on the time dimension.
+
+        Args:
+            data: Batch of size ``(batch_size, n_neurons, offset)``, with ``offset``
+                corresponding to the sequence length. The masking ratio (proportion
+                of timesteps to mask, not necessarily consecutive) is chosen by
+                :py:meth:`_select_masking_params`; the default value of 0.3 comes
+                from the MtM paper: https://arxiv.org/pdf/2407.14668v2
+
+        Returns:
+            torch.Tensor: The mask, a tensor of the same size as the input data,
+            with masked timesteps set to 0 and kept timesteps set to 1.
+
+        """
+        batch_size, n_neurons, offset_length = data.shape
+        mask_ratio = self._select_masking_params()
+
+        # Random mask: shape [batch_size, offset_length], different per batch and timestep
+        masked = torch.rand(batch_size, offset_length,
+                            device=data.device) < mask_ratio
+        return (~masked).int().unsqueeze(1).expand(-1, n_neurons,
+                                                   -1)  # Expand to all neurons
+
+    def _select_masking_params(self) -> float:
+        """Select the masking ratio to apply.
+
+        The configured value can be a single ratio, a list of ratios from which
+        one is picked randomly, or a tuple of (min, max, step_size) used to
+        create a list of ratios from which one is sampled randomly.
+        """
+        if isinstance(self.mask_ratio, float):
+            selected_value = self.mask_ratio
+
+        elif isinstance(self.mask_ratio, list):
+            selected_value = random.choice(self.mask_ratio)
+
+        elif isinstance(self.mask_ratio, tuple):
+            min_val, max_val, step_size = self.mask_ratio
+            selected_value = random.choice(
+                np.arange(min_val, max_val + step_size, step_size).tolist())
+
+        else:
+            raise ValueError(
+                f"Masking ratio {self.mask_ratio} for {self.__name__()} "
+                "should be a float, list of floats or a tuple of (min, max, step)."
+            )
+
+        return selected_value
+
+
+class NeuronBlockMask(Mask):
+
+    def __init__(self, masking_value: Union[float, List[float], Tuple[float]]):
+        super().__init__(masking_value)
+        self.mask_prop = masking_value
+
+    def __name__(self):
+        return "NeuronBlockMask"
+
+    def apply_mask(self, data: torch.Tensor) -> torch.Tensor:
+        """Apply masking to a contiguous block of neurons.
+
+        Args:
+            data: Batch of size ``(batch_size, n_neurons, offset)``. The proportion
+                of neurons to mask, as one contiguous block, is chosen by
+                :py:meth:`_select_masking_params`.
+
+        Returns:
+            torch.Tensor: The mask, a tensor of the same size as the input data,
+            with masked neurons set to 0 and kept neurons set to 1.
+        """
+        batch_size, n_neurons, offset_length = data.shape
+
+        mask_prop = self._select_masking_params()
+        num_mask = int(n_neurons * mask_prop)
+        mask = torch.ones((batch_size, n_neurons),
+                          dtype=torch.int,
+                          device=data.device)
+
+        if num_mask == 0:
+            return mask.unsqueeze(2)
+
+        for batch_idx in range(batch_size):  # Create a mask for each batch
+            # Randomly select the start index of the block of neurons to mask
+            start_idx = torch.randint(0, n_neurons - num_mask + 1, (1,)).item()
+            end_idx = min(start_idx + num_mask, n_neurons)
+            mask[batch_idx, start_idx:end_idx] = 0  # set masked neurons to 0
+
+        return mask.unsqueeze(2).expand(
+            -1, -1, offset_length)  # Expand to all timesteps
+
+    def _select_masking_params(self) -> float:
+        """Select the masking proportion to apply.
+
+        The configured value can be a single proportion, a list of proportions from
+        which one is picked randomly, or a tuple of (min, max, step_size) used to
+        create a list of proportions from which one is sampled randomly.
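+        For instance, a value of ``0.3`` masks a single contiguous block covering
+        30% of the neurons, with a random start index drawn per sample.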
+
+        """
+        if isinstance(self.mask_prop, float):
+            selected_value = self.mask_prop
+
+        elif isinstance(self.mask_prop, list):
+            selected_value = random.choice(self.mask_prop)
+
+        elif isinstance(self.mask_prop, tuple):
+            min_val, max_val, step_size = self.mask_prop
+            selected_value = random.choice(
+                np.arange(min_val, max_val + step_size, step_size).tolist())
+
+        else:
+            raise ValueError(
+                f"Masking ratio {self.mask_prop} for {self.__name__()} "
+                "should be a float, list of floats or a tuple of (min, max, step)."
+            )
+
+        return selected_value
+
+
+class TimeBlockMask(Mask):
+
+    def __init__(self, masking_value: Union[float, List[float], Tuple[float]]):
+        super().__init__(masking_value)
+        self.sampled_rate, self.masked_seq_len = masking_value
+
+    def __name__(self):
+        return "TimeBlockMask"
+
+    def apply_mask(self, data: torch.Tensor) -> torch.Tensor:
+        """Apply contiguous block masking on the time dimension.
+
+        When choosing which blocks of timesteps to mask, each timestep is considered
+        a candidate starting time-step with probability ``self.sampled_rate``, where
+        ``self.masked_seq_len`` is the length of each masked span starting from the
+        respective time step. Sampled starting time steps are expanded to length
+        ``self.masked_seq_len`` and spans can overlap. Inspired by the wav2vec 2.0
+        masking strategy.
+
+        Default values from the wav2vec paper: https://arxiv.org/abs/2006.11477.
+
+        Args:
+            data (torch.Tensor): The input tensor of shape
+                ``(batch_size, n_neurons, offset_length)``. The sampling rate
+                (probability of each time-step being a candidate for masking) and the
+                masked sequence length (length of each masked span) are taken from
+                ``self.sampled_rate`` and ``self.masked_seq_len``.
+
+        Returns:
+            torch.Tensor: The mask, a tensor of the same size as the input data,
+            with masked timesteps set to 0 and kept timesteps set to 1.
+        """
+        batch_size, n_neurons, offset_length = data.shape
+
+        sampled_rate, masked_seq_len = self._select_masking_params()
+
+        num_masked_starting_points = int(offset_length * sampled_rate)
+        mask = torch.ones((batch_size, offset_length),
+                          dtype=int,
+                          device=data.device)
+        for batch_idx in range(batch_size):
+            # Sample starting points for masking in the current batch
+            start_indices = torch.randperm(
+                offset_length, device=data.device)[:num_masked_starting_points]
+
+            # Apply masking spans
+            for start in start_indices:
+                end = min(start + masked_seq_len, offset_length)
+                mask[batch_idx, start:end] = 0  # set masked timesteps to 0
+
+        return mask.unsqueeze(1).expand(-1, n_neurons,
+                                        -1)  # Expand to all neurons
+
+    def _check_masking_parameters(self, masking_value: Union[float, List[float],
+                                                             Tuple[float]]):
+        """
+        The masking values are the parameters for the time-block masking.
+        They need to be a tuple of (sampled_rate, masked_seq_len):
+        sampled_rate: The probability of each time-step being a candidate for masking.
+        masked_seq_len: The length of each masked span starting from the sampled time-step.
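+        For instance, ``(0.065, 10)`` (the defaults reported in the wav2vec 2.0 paper)
+        marks each timestep as a span start with probability 0.065 and masks 10
+        consecutive timesteps from each sampled start.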
+ """ + assert isinstance(masking_value, tuple) and len(masking_value) == 2, ( + f"Masking parameters {masking_value} for {self.__name__()} " + "should be a tuple of (sampled_rate, masked_seq_len).") + assert 0.0 < masking_value[0] < 1.0 and isinstance( + masking_value[0], float), ( + f"Masking parameters {masking_value} for {self.__name__()} " + "should be between 0.0 and 1.0.") + assert masking_value[1] > 0 and isinstance(masking_value[1], int), ( + f"Masking parameters {masking_value} for {self.__name__()} " + "should be an integer greater than 0.") + + def _select_masking_params(self) -> float: + """ + The masking values are the masking ratio to apply. + It can be a single ratio, a list of ratio that will be picked randomly or + a tuple of (min, max, step_size) that will be used to create a list of ratios + from which to sample randomly. + """ + return self.sampled_rate, self.masked_seq_len diff --git a/cebra/data/masking.py b/cebra/data/masking.py new file mode 100644 index 00000000..2a9a5977 --- /dev/null +++ b/cebra/data/masking.py @@ -0,0 +1,86 @@ +import random +from typing import Dict, Optional + +import torch + +import cebra.data.mask as mask + + +class MaskedMixin: + """A mixin class for applying masking to data. + + Note: + This class is designed to be used as a mixin for other classes. + It provides functionality to apply masking to data. + The `set_masks` method should be called to set the masking types + and their corresponding probabilities. + """ + masks = [] # a list of Mask instances + + def set_masks(self, masking: Optional[Dict[str, float]] = None) -> None: + """Set the mask type and probability for the dataset. + + Args: + masking (Dict[str, float]): A dictionary of masking types and their + corresponding required masking values. The keys are the names + of the Mask instances. + + Note: + By default, no masks are applied. + """ + if masking is not None: + for mask_key in masking: + if mask_key in mask.__all__: + cls = getattr(mask, mask_key) + self.masks = [ + m for m in self.masks if not isinstance(m, cls) + ] + self.masks.append(cls(masking[mask_key])) + else: + raise ValueError( + f"Mask type {mask_key} not supported. Supported types are {masking.keys()}" + ) + + def apply_mask(self, + data: torch.Tensor, + chunk_size: int = 1000) -> torch.Tensor: + """Apply masking to the input data. + + Note: + - By default, no masking. Else apply masking on the input data. + - Only one masking type can be applied at a time, but multiple + masking types can be set so that it alternates between them + across iterations. + - Masking is applied to the data in chunks to avoid memory issues. + + Args: + data (torch.Tensor): batch of size (batch_size, num_neurons, offset). + chunk_size (int): Number of rows to process at a time. + + Returns: + torch.Tensor: The masked data. 
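+
+        Example:
+            A minimal sketch of the masking API defined in this module, using
+            ``RandomNeuronMask`` from :py:mod:`cebra.data.mask`:
+
+            >>> import torch
+            >>> from cebra.data.masking import MaskedMixin
+            >>> mixin = MaskedMixin()
+            >>> mixin.set_masks({"RandomNeuronMask": 0.5})
+            >>> data = torch.ones(2, 10, 4)  # (batch_size, num_neurons, offset)
+            >>> masked = mixin.apply_mask(data)
+            >>> masked.shape
+            torch.Size([2, 10, 4])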
+ """ + if data.dim() != 3: + raise ValueError( + f"Data must be a 3D tensor, but got {data.dim()}D tensor.") + if data.dtype != torch.float32: + raise ValueError( + f"Data must be a float32 tensor, but got {data.dtype}.") + + # If masks is empty, return the data as is + if not self.masks: + return data + + sampled_mask = random.choice(self.masks) + mask = sampled_mask.apply_mask(data) + + num_chunks = (data.shape[0] + chunk_size - + 1) // chunk_size # Compute number of chunks + + for i in range(num_chunks): + start, end = i * chunk_size, min((i + 1) * chunk_size, + data.shape[0]) + data[start:end].mul_( + mask[start:end]) # apply mask in-place to save memory + + return data diff --git a/cebra/data/multi_session.py b/cebra/data/multi_session.py index ebae8b6f..49e9f894 100644 --- a/cebra/data/multi_session.py +++ b/cebra/data/multi_session.py @@ -26,6 +26,7 @@ import literate_dataclasses as dataclasses import torch +import torch.nn as nn import cebra.data as cebra_data import cebra.distributions @@ -38,6 +39,7 @@ "ContinuousMultiSessionDataLoader", "DiscreteMultiSessionDataLoader", "MixedMultiSessionDataLoader", + "UnifiedLoader", ] @@ -108,13 +110,20 @@ def configure_for(self, model: "cebra.models.Model"): """Configure the dataset offset for the provided model. Call this function before indexing the dataset. This sets the - `offset` attribute of the dataset. + :py:attr:`~.Dataset.offset` attribute of the dataset. Args: model: The model to configure the dataset for. """ for i, session in enumerate(self.iter_sessions()): - session.configure_for(model[i]) + if isinstance(model, nn.ModuleList): + if len(model) != self.num_sessions: + raise ValueError( + f"The model must have {self.num_sessions} sessions, but got {len(model)}." + ) + session.configure_for(model[i]) + else: + session.configure_for(model) @dataclasses.dataclass @@ -178,3 +187,64 @@ def index(self): @dataclasses.dataclass class MixedMultiSessionDataLoader(MultiSessionLoader): pass + + +@dataclasses.dataclass +class UnifiedLoader(ContinuousMultiSessionDataLoader): + """Dataloader for multi-session datasets, considered as a single session. + + This class is used in pair with :py:class:`cebra.data.datasets.UnifiedDataset` + to sample from each session and train a single model on them, even if sessions have a + different number of neurons. + + To sample the reference and negative samples, a target session is randomly selected. Indexes + are unformly sampled in that first session. Then, indexes in the other sessions are samples + conditionally to the first session indexes, so that their corresponding auxiliary variables + are close. For the positive samples, they are sampled conditionally to the reference samples, + in their corresponding session only. + + Then, the ref/pos/neg samples are concatenated respectively, along the neurons axis (takes place + in the :py:class:`cebra.data.datasets.UnifiedDataset`). + + """ + + def __post_init__(self): + super().__post_init__() + self.sampler = cebra.distributions.UnifiedSampler( + self.dataset, self.time_offset) + + def get_indices(self, num_samples: int) -> BatchIndex: + """Sample and return the specified number of indices. + + The elements of the returned ``BatchIndex`` will be used to index the + ``dataset`` of this data loader. + + To sample the reference and negative samples, a target session is + randomly selected. Indexes are unformly sampled in that first + session. 
Then, indexes in the other sessions are samples conditionally + to the first session indexes, so that their corresponding auxiliary + variables are close. For the positive samples, they are sampled + conditionally to the reference samples, in their corresponding session + only. + + Args: + num_samples: The size of each of the reference, positive and + negative samples to sample. + + Returns: + Batch indices for the reference, positive and negative samples. + """ + ref_idx = self.sampler.sample_prior(self.batch_size) + neg_idx = self.sampler.sample_prior(self.batch_size) + + pos_idx = self.sampler.sample_conditional(ref_idx) + + ref_idx = torch.from_numpy(ref_idx).to(self.device) + neg_idx = torch.from_numpy(neg_idx).to(self.device) + pos_idx = torch.from_numpy(pos_idx).to(self.device) + + return BatchIndex( + reference=ref_idx, + positive=pos_idx, + negative=neg_idx, + ) diff --git a/cebra/data/multiobjective.py b/cebra/data/multiobjective.py new file mode 100644 index 00000000..f700d1c4 --- /dev/null +++ b/cebra/data/multiobjective.py @@ -0,0 +1,173 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import literate_dataclasses as dataclasses + +import cebra.data as cebra_data +import cebra.distributions +from cebra.data.datatypes import BatchIndex +from cebra.distributions.continuous import Prior + + +@dataclasses.dataclass +class MultiObjectiveLoader(cebra_data.Loader): + """Baseclass of Multiobjective Data Loader. Yields batches of the specified size from the given dataset object. + """ + dataset: int = dataclasses.field( + default=None, + doc="""A dataset instance specifying a ``__getitem__`` function.""", + ) + num_steps: int = dataclasses.field(default=None) + batch_size: int = dataclasses.field(default=None) + + def __post_init__(self): + super().__post_init__() + if self.batch_size > len(self.dataset.neural): + raise ValueError("Batch size can't be larger than data.") + self.prior = Prior(self.dataset.neural, device=self.device) + + def get_indices(self): + return NotImplementedError + + def __iter__(self): + return NotImplementedError + + def add_config(self, config): + raise NotImplementedError + + +@dataclasses.dataclass +class SupervisedMultiObjectiveLoader(MultiObjectiveLoader): + """Supervised Multiobjective data Loader. Yields batches of the specified size from the given dataset object. + """ + sampling_mode_supervised: str = dataclasses.field( + default="ref_shared", + doc="""Type of sampling performed, re whether reference are shared or not. + are shared. 
Options will be ref_shared, independent.""") + + def __post_init__(self): + super().__post_init__() + self.labels = [] + + def add_config(self, config): + self.labels.append(config['label']) + + def get_indices(self, num_samples: int): + if self.sampling_mode_supervised == "ref_shared": + reference_idx = self.prior.sample_prior(num_samples) + else: + raise ValueError( + f"Sampling mode {self.sampling_mode_supervised} is not implemented." + ) + + batch_index = BatchIndex( + reference=reference_idx, + positive=None, + negative=None, + ) + + return batch_index + + def __iter__(self): + for _ in range(len(self)): + index = self.get_indices(num_samples=self.batch_size) + yield self.dataset.load_batch_supervised(index, self.labels) + + +@dataclasses.dataclass +class ContrastiveMultiObjectiveLoader(MultiObjectiveLoader): + """Contrastive Multiobjective data Loader. Yields batches of the specified size from the given dataset object. + """ + + sampling_mode_contrastive: str = dataclasses.field( + default="refneg_shared", + doc= + """Type of sampling performed, re whether reference and negative samples + are shared. Options will be ref_shared, neg_shared and refneg_shared""" + ) + + def __post_init__(self): + super().__post_init__() + self.distributions = [] + + def add_config(self, config): + kwargs_distribution = config['kwargs'] + if config['distribution'] == "time": + distribution = cebra.distributions.TimeContrastive( + time_offset=kwargs_distribution['time_offset'], + num_samples=len(self.dataset.neural), + device=self.device, + ) + elif config['distribution'] == "time_delta": + distribution = cebra.distributions.TimedeltaDistribution( + continuous=self.dataset.labels[ + kwargs_distribution['label_name']], + time_delta=kwargs_distribution['time_delta'], + device=self.device) + elif config['distribution'] == "delta_normal": + distribution = cebra.distributions.DeltaNormalDistribution( + continuous=self.dataset.labels[ + kwargs_distribution['label_name']], + delta=kwargs_distribution['delta'], + device=self.device) + elif config['distribution'] == "delta_vmf": + distribution = cebra.distributions.DeltaVMFDistribution( + continuous=self.dataset.labels[ + kwargs_distribution['label_name']], + delta=kwargs_distribution['delta'], + device=self.device) + else: + raise NotImplementedError( + f"Distribution {config['distribution']} is not implemented yet." + ) + + self.distributions.append(distribution) + + def get_indices(self, num_samples: int): + """Sample and return the specified number of indices.""" + + if self.sampling_mode_contrastive == "refneg_shared": + ref_and_neg = self.prior.sample_prior(num_samples * 2) + reference_idx = ref_and_neg[:num_samples] + negative_idx = ref_and_neg[num_samples:] + + positives_idx = [] + for distribution in self.distributions: + idx = distribution.sample_conditional(reference_idx) + positives_idx.append(idx) + + batch_index = BatchIndex( + reference=reference_idx, + positive=positives_idx, + negative=negative_idx, + ) + else: + raise ValueError( + f"Sampling mode {self.sampling_mode_contrastive} is not implemented yet." 
+ ) + + return batch_index + + def __iter__(self): + for _ in range(len(self)): + index = self.get_indices(num_samples=self.batch_size) + yield self.dataset.load_batch_contrastive(index) diff --git a/cebra/data/single_session.py b/cebra/data/single_session.py index ab6c9729..6aaed3d2 100644 --- a/cebra/data/single_session.py +++ b/cebra/data/single_session.py @@ -64,22 +64,11 @@ def __len__(self): def load_batch(self, index: BatchIndex) -> Batch: """Return the data at the specified index location.""" return Batch( - positive=self[index.positive], - negative=self[index.negative], - reference=self[index.reference], + positive=self.apply_mask(self[index.positive]), + negative=self.apply_mask(self[index.negative]), + reference=self.apply_mask(self[index.reference]), ) - def configure_for(self, model: "cebra.models.Model"): - """Configure the dataset offset for the provided model. - - Call this function before indexing the dataset. This sets the - `offset` attribute of the dataset. - - Args: - model: The model to configure the dataset for. - """ - self.offset = model.get_offset() - @dataclasses.dataclass class DiscreteDataLoader(cebra_data.Loader): @@ -237,6 +226,13 @@ def _init_distribution(self): self.dataset.continuous_index, self.delta, device=self.device) + # TODO(stes): Add this distribution from internal xCEBRA codebase at a later point + # in time, currently not in use. + #elif self.conditional == "delta_vmf": + # self.distribution = cebra.distributions.DeltaVMFDistribution( + # self.dataset.continuous_index, + # self.delta, + # device=self.device) else: raise ValueError(self.conditional) @@ -343,6 +339,8 @@ class HybridDataLoader(cebra_data.Loader): """ conditional: str = dataclasses.field(default="time_delta") + time_distribution: str = dataclasses.field(default="time") + time_offset: int = dataclasses.field(default=10) delta: float = dataclasses.field(default=0.1) @property @@ -359,17 +357,59 @@ def __post_init__(self): # e.g. integrating the FAISS dataloader back in. 
super().__post_init__() - if self.conditional != "time_delta": - raise NotImplementedError( - "Hybrid training is currently only implemented using the ``time_delta`` " - "continual distribution.") - - self.time_distribution = cebra.distributions.TimeContrastive( - time_offset=self.time_offset, - num_samples=len(self.dataset.neural), - device=self.device) - self.behavior_distribution = cebra.distributions.TimedeltaDistribution( - self.dataset.continuous_index, self.time_offset, device=self.device) + self._init_behavior_distribution() + self._init_time_distribution() + + def _init_behavior_distribution(self): + if self.conditional == "time": + self.behavior_distribution = cebra.distributions.TimeContrastive( + time_offset=self.time_offset, + num_samples=len(self.dataset.neural), + device=self.device, + ) + + if self.conditional == "time_delta": + self.behavior_distribution = cebra.distributions.TimedeltaDistribution( + self.dataset.continuous_index, + self.time_offset, + device=self.device) + + elif self.conditional == "delta_normal": + self.behavior_distribution = cebra.distributions.DeltaNormalDistribution( + self.dataset.continuous_index, self.delta, device=self.device) + + elif self.conditional == "time": + self.behavior_distribution = cebra.distributions.TimeContrastive( + time_offset=self.time_offset, + num_samples=len(self.dataset.neural), + device=self.device, + ) + + def _init_time_distribution(self): + + if self.time_distribution == "time": + self.time_distribution = cebra.distributions.TimeContrastive( + time_offset=self.time_offset, + num_samples=len(self.dataset.neural), + device=self.device, + ) + + elif self.time_distribution == "time_delta": + self.time_distribution = cebra.distributions.TimedeltaDistribution( + self.dataset.continuous_index, + self.time_offset, + device=self.device) + + elif self.time_distribution == "delta_normal": + self.time_distribution = cebra.distributions.DeltaNormalDistribution( + self.dataset.continuous_index, self.delta, device=self.device) + + # TODO(stes): Add this distribution from internal xCEBRA codebase at a later point + #elif self.time_distribution == "delta_vmf": + # self.time_distribution = cebra.distributions.DeltaVMFDistribution( + # self.dataset.continuous_index, self.delta, device=self.device) + else: + raise ValueError def get_indices(self, num_samples: int) -> BatchIndex: """Samples indices for reference, positive and negative examples. diff --git a/cebra/datasets/__init__.py b/cebra/datasets/__init__.py index 5716e399..7a187489 100644 --- a/cebra/datasets/__init__.py +++ b/cebra/datasets/__init__.py @@ -97,6 +97,8 @@ def get_datapath(path: str = None) -> str: from cebra.datasets.hippocampus import * from cebra.datasets.monkey_reaching import * from cebra.datasets.synthetic_data import * + from cebra.datasets.perich import * + from cebra.datasets.nlb import * except ModuleNotFoundError as e: warnings.warn(f"Could not initialize one or more datasets: {e}. 
" f"For using the datasets, consider installing the " diff --git a/cebra/datasets/demo.py b/cebra/datasets/demo.py index 90ba5367..380a6526 100644 --- a/cebra/datasets/demo.py +++ b/cebra/datasets/demo.py @@ -32,7 +32,8 @@ import cebra.io from cebra.datasets import register -_DEFAULT_NUM_TIMEPOINTS = 100000 +_DEFAULT_NUM_TIMEPOINTS = 10_000 +NUMS_NEURAL = [3, 4, 5] class DemoDataset(cebra.data.SingleSessionDataset): @@ -117,7 +118,7 @@ class MultiDiscrete(cebra.data.DatasetCollection): def __init__( self, - nums_neural=[3, 4, 5], + nums_neural=NUMS_NEURAL, num_timepoints=_DEFAULT_NUM_TIMEPOINTS, ): super().__init__(*[ @@ -131,7 +132,7 @@ class MultiContinuous(cebra.data.DatasetCollection): def __init__( self, - nums_neural=[3, 4, 5], + nums_neural=NUMS_NEURAL, num_behavior=5, num_timepoints=_DEFAULT_NUM_TIMEPOINTS, ): @@ -146,8 +147,26 @@ def __init__( # @register("demo-mixed-multisession") class MultiMixed(cebra.data.DatasetCollection): - def __init__(self, nums_neural=[3, 4, 5], num_behavior=5): + def __init__(self, nums_neural=NUMS_NEURAL, num_behavior=5): super().__init__(*[ DemoDatasetMixed(_DEFAULT_NUM_TIMEPOINTS, num_neural, num_behavior) for num_neural in nums_neural ]) + + +@register("demo-continuous-unified") +class DemoDatasetUnified(cebra.data.UnifiedDataset): + + def __init__( + self, + nums_neural=NUMS_NEURAL, + num_behavior=5, + num_timepoints=_DEFAULT_NUM_TIMEPOINTS, + ): + super().__init__(*[ + DemoDatasetContinuous(num_timepoints, num_neural, num_behavior) + for num_neural in nums_neural + ]) + + self.num_timepoints = num_timepoints + self.nums_neural = nums_neural diff --git a/cebra/distributions/multisession.py b/cebra/distributions/multisession.py index 647044f2..1e0c48d4 100644 --- a/cebra/distributions/multisession.py +++ b/cebra/distributions/multisession.py @@ -21,7 +21,11 @@ # """Continuous variable multi-session sampling.""" +import random +from typing import Optional + import numpy as np +import numpy.typing as npt import torch import cebra.distributions as cebra_distr @@ -383,3 +387,202 @@ def __getitem__(self, pos_idx): for i in range(self.num_sessions): pos_samples[i] = self.data[i][pos_idx[i]] return pos_samples + + +class UnifiedSampler(MultisessionSampler): + """Multi-session sampling, considering them as a single session. + + Align embeddings across multiple sessions, using a set of + auxiliary variables, so that the samples in the different sessions + are sampled together based on how the corresponding auxiliary + variables are close from each other. + + Then, the reference, positive and negative can be concatenated on their + neurons axis to train a single model for all sessions. + + Example: + >>> import cebra.distributions.multisession as cebra_distributions_multisession + >>> import cebra.integrations.sklearn.dataset as cebra_sklearn_dataset + >>> import cebra.data + >>> import torch + >>> from torch import nn + >>> # Multisession training: one model per dataset (different input dimensions) + >>> session1 = torch.rand(100, 30) + >>> session2 = torch.rand(100, 50) + >>> index1 = torch.rand(100) + >>> index2 = torch.rand(100) + >>> num_features = 8 + >>> dataset = cebra.data.UnifiedDataset( + ... cebra_sklearn_dataset.SklearnDataset(session1, (index1, )), + ... cebra_sklearn_dataset.SklearnDataset(session2, (index2, ))) + >>> model = cebra.models.init( + ... name="offset1-model", + ... num_neurons=dataset.input_dimension, + ... num_units=32, + ... num_output=num_features, + ... 
).to("cpu") + >>> sampler = cebra_distributions_multisession.UnifiedSampler(dataset, time_offset=10) + + >>> # ref and pos samples from all datasets + >>> ref = sampler.sample_prior(100) + >>> pos = sampler.sample_conditional(ref) + >>> ref = torch.LongTensor(ref) + >>> pos = torch.LongTensor(pos) + >>> loss = (ref - pos)**2 + + Note: + This function does currently not support explicitly selected + discrete indices. They should be added as dimensions to the + continuous index. More weight can be added to the discrete + dimensions by using larger values in one-hot coding. + + """ + + def sample_all_uniform_prior(self, + num_samples: int) -> npt.NDArray[np.int64]: + """Returns uniformly sampled index for all sessions of the dataset. + + Args: + num_samples: Number of samples to sample in each session. + + Returns: + ``(N, num_samples)`` with ``N`` the number of sessions. Array of + samples, uniformly picked for each session. + """ + return super().sample_prior(num_samples=num_samples) + + def sample_prior(self, + num_samples: int, + session_id: Optional[int] = None) -> npt.NDArray[np.int64]: + """Return uniformly sampled indices for all sessions. + + First, the reference indexes in a reference session are uniformly sampled. + Then the reference indexes for the other sessions are sampled so that their + corresponding auxiliary variables are close to the reference indexes of the + reference session. + + Args: + num_samples: Number of samples to pick. + session_id: ID of the session to use as the reference session. If ``None``, + the session is randomly selected. + + Returns: + A :py:func:`numpy.array` containing the idx of the reference samples for all + sessions. + """ + + # Randomly pick the reference session + if session_id is None: + session_id = random.choice(list(range(self.num_sessions))) + + # Sample prior for all sessions + idx = self.sample_all_uniform_prior(num_samples=num_samples) + # Keep the idx for the reference session only + idx = torch.from_numpy(idx[session_id]) + + # Sample the references indexes in other sessions, based on their distance to the + # reference idx in the reference session. + return self.sample_all_sessions(idx, session_id).cpu().numpy() + + def _get_query(self, + reference_idx: torch.Tensor, + session_id: int, + aligned: bool = False) -> torch.Tensor: + """ + + Args: + aligned: If True, no time difference is added to the query. + """ + cum_idx = reference_idx + self.lengths[session_id] + if aligned: + query = self.all_data[cum_idx] + else: + diff_idx = torch.randint(len(self.time_difference), + (len(reference_idx),)) + query = self.all_data[cum_idx] + self.time_difference[diff_idx] + return torch.from_numpy(query).to(_device) + + def sample_all_sessions(self, ref_idx: torch.Tensor, + session_id: int) -> torch.Tensor: + """Sample sessions based on a reference session. + + Reference samples for the ``(session_id)``th session were first sampled uniformly, as in + the py:class:`~.MultisessionSampler`. Then, reference samples for the other sessions + are sampled so that they are as close as the corresponding auxiliary variables in + the reference session. + + Note: similar to ``sample_condiditonal`` but at the level of the sessions, sampling ref idx in each + session so that they are close to the ref idx in the reference session (``session_id``th session). + + Args: + ref_idx: Uniformly sampled ``idx`` for the reference session, ``(num_samples, )``, values + can be in ``[0, len(get_session[session_id])]``. 
+            session_id: Session ID of the reference session, whose ``idx`` are present in ``ref_idx``.
+
+        Returns:
+            The prior for all sessions, creating a "pseudo-animal", where the ``idx`` sampled in
+            different sessions correspond to points in the recordings where the auxiliary variables
+            are similar.
+
+        """
+        # Get the continuous data corresponding to the idx
+        # all_data: (sum(self.session_lengths), )
+        # ref_idx: (num_samples, ), values in [0, len(get_session[session_id])]
+        # self.lengths: (num_sessions, ), cumsum of the length of each session, providing the first
+        # element of a session in self.all_data.
+        # cum_ref_idx: (num_samples, ), values of ref_idx, shifted to correspond to the indexes of
+        # session_id in the flattened array self.all_data.
+        all_idx = torch.zeros(self.num_sessions, len(ref_idx),
+                              device=_device).long()
+        query = self._get_query(
+            reference_idx=ref_idx, session_id=session_id,
+            aligned=True)  # same query for all + no time diff added
+
+        for i in range(self.num_sessions):
+            # except for the session_id provided
+            if i == session_id:
+                continue
+            # different query for each. more robust to variance.
+            #query = self._get_query(reference_idx=ref_idx,
+            #                        session_id=session_id,
+            #                        aligned=False)
+
+            # get the idx of the datapoint that is the closest to the query
+            all_idx[i] = self.index[i].search(
+                query)  # search in the whole dataset
+
+            # all_idx[i] = self.index[i].search_or_mask(
+            #     query, threshold=self.distance_threshold[i])
+
+        all_idx[session_id] = ref_idx
+        return all_idx
+
+    def sample_conditional(
+            self, reference_idx: npt.NDArray[np.int64]) -> torch.Tensor:
+        """Sample from the conditional distribution.
+
+        Contrary to the :py:class:`MultisessionSampler`, the conditional distribution
+        is sampled so that the positive samples match their reference samples: each
+        positive is drawn from the same session as its reference ``idx`` only, rather
+        than across all sessions.
+
+        Args:
+            reference_idx: Reference indices, with dimension ``(session, batch)``.
+
+        Returns:
+            Positive indices, which will be grouped by
+            session and match the reference indices.
+            Returned shape is ``(session, batch)``.
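+            For instance, with 3 sessions and a batch of 512 reference indices per
+            session, both ``reference_idx`` and the returned array have shape ``(3, 512)``.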
+ + """ + + cond_idx = torch.zeros((reference_idx.shape[0], reference_idx.shape[1]), + dtype=torch.int, + device=_device).long() + + for session_id in range(self.num_sessions): + query = self._get_query(reference_idx=reference_idx[session_id], + session_id=session_id) + + cond_idx[session_id] = self.index[session_id].search(query) + + return cond_idx.cpu().numpy() diff --git a/cebra/integrations/matplotlib.py b/cebra/integrations/matplotlib.py index 30af7fd4..c2696d4a 100644 --- a/cebra/integrations/matplotlib.py +++ b/cebra/integrations/matplotlib.py @@ -684,7 +684,7 @@ def _to_heatmap_format( else: heatmap_values[i, j] = score_dict[label_i, label_j] - return np.minimum(heatmap_values * 100, 99) + return heatmap_values * 100 def _create_text(self): """Create the text to add in the confusion matrix grid and the title.""" diff --git a/cebra/integrations/plotly.py b/cebra/integrations/plotly.py index bbaa1de6..2cfc5ec9 100644 --- a/cebra/integrations/plotly.py +++ b/cebra/integrations/plotly.py @@ -27,6 +27,7 @@ import numpy as np import numpy.typing as npt import plotly.graph_objects +import plotly.graph_objects as go import torch from cebra.integrations.matplotlib import _EmbeddingPlot @@ -152,18 +153,19 @@ def _plot_3d(self, **kwargs) -> plotly.graph_objects.Figure: def plot_embedding_interactive( - embedding: Union[npt.NDArray, torch.Tensor], - embedding_labels: Optional[Union[npt.NDArray, torch.Tensor, str]] = "grey", - axis: Optional[plotly.graph_objects.Figure] = None, - markersize: float = 1, - idx_order: Optional[Tuple[int]] = None, - alpha: float = 0.4, - cmap: str = "cool", - title: str = "Embedding", - figsize: Tuple[int] = (5, 5), - dpi: int = 100, - **kwargs, -) -> plotly.graph_objects.Figure: + embedding: Union[npt.NDArray, torch.Tensor], + embedding_labels: Optional[Union[npt.NDArray, torch.Tensor, + str]] = "grey", + axis: Optional["go.Figure"] = None, + markersize: float = 1, + idx_order: Optional[Tuple[int]] = None, + alpha: float = 0.4, + cmap: str = "cool", + title: str = "Embedding", + figsize: Tuple[int] = (5, 5), + dpi: int = 100, + **kwargs, +) -> "go.Figure": """Plot embedding in a 3D dimensional space. This is supposing that the dimensions provided to ``idx_order`` are in the range of the number of diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index a340a392..0add3f06 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -27,11 +27,14 @@ import numpy as np import numpy.typing as npt +import packaging.version import pkg_resources +import sklearn import sklearn.utils.validation as sklearn_utils_validation import torch from sklearn.base import BaseEstimator from sklearn.base import TransformerMixin +from sklearn.utils.metaestimators import available_if from torch import nn import cebra.data @@ -41,6 +44,38 @@ import cebra.models import cebra.solver +# NOTE(stes): From torch 2.6 onwards, we need to specify the following list +# when loading CEBRA models to allow weights_only = True. 
+CEBRA_LOAD_SAFE_GLOBALS = [ + cebra.data.Offset, torch.torch_version.TorchVersion, np.dtype, + np.dtypes.Float64DType, np.dtypes.Int64DType +] + + +def check_version(estimator): + # NOTE(stes): required as a check for the old way of specifying tags + # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165 + return packaging.version.parse( + sklearn.__version__) < packaging.version.parse("1.6.dev") + + +def _safe_torch_load(filename, weights_only, **kwargs): + if weights_only is None: + if packaging.version.parse( + torch.__version__) >= packaging.version.parse("2.6.0"): + weights_only = True + else: + weights_only = False + + if not weights_only: + checkpoint = torch.load(filename, weights_only=False, **kwargs) + else: + # NOTE(stes): This is only supported for torch 2.6+ + with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS): + checkpoint = torch.load(filename, weights_only=True, **kwargs) + + return checkpoint + def _init_loader( is_cont: bool, @@ -94,7 +129,7 @@ def _init_loader( (not is_cont, not is_disc, is_multi), ] if any(all(combination) for combination in incompatible_combinations): - raise ValueError(f"Invalid index combination.\n" + raise ValueError("Invalid index combination.\n" f"Continuous: {is_cont},\n" f"Discrete: {is_disc},\n" f"Hybrid training: {is_hybrid},\n" @@ -258,7 +293,7 @@ def _require_arg(key): "single-session", ) - error_message = (f"Invalid index combination.\n" + error_message = ("Invalid index combination.\n" f"Continuous: {is_cont},\n" f"Discrete: {is_disc},\n" f"Hybrid training: {is_hybrid},\n" @@ -305,7 +340,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": if missing_keys: raise ValueError( f"Missing keys in data dictionary: {', '.join(missing_keys)}. " - f"You can try loading the CEBRA model with the torch backend.") + "You can try loading the CEBRA model with the torch backend.") args, state, state_dict = cebra_info['args'], cebra_info[ 'state'], cebra_info['state_dict'] @@ -364,7 +399,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": return cebra_ -class CEBRA(BaseEstimator, TransformerMixin): +class CEBRA(TransformerMixin, BaseEstimator): """CEBRA model defined as part of a ``scikit-learn``-like API. Attributes: @@ -462,6 +497,8 @@ class CEBRA(BaseEstimator, TransformerMixin): optimizer documentation in :py:mod:`torch.optim` for further information on how to format the arguments. |Default:| ``(('betas', (0.9, 0.999)), ('eps', 1e-08), ('weight_decay', 0), ('amsgrad', False))`` + masking_kwargs (dict): + TODO(celia) Example: @@ -535,6 +572,8 @@ def __init__( ("weight_decay", 0), ("amsgrad", False), ), + masking_kwargs: Dict[str, Union[float, List[float], Tuple[float, + ...]]] = None, ): self.__dict__.update(locals()) @@ -621,12 +660,12 @@ def _get_dataset_multi(X: List[Iterable], y: List[Iterable]): # TODO(celia): to make it work for multiple set of index. 
For now, y should be a tuple of one list only if isinstance(y, tuple) and len(y) > 1: raise NotImplementedError( - f"Support for multiple set of index is not implemented in multissesion training, " + "Support for multiple set of index is not implemented in multissesion training, " f"got {len(y)} sets of indexes.") if not _are_sessions_equal(X, y): raise ValueError( - f"Invalid number of sessions: number of sessions in X and y need to match, " + "Invalid number of sessions: number of sessions in X and y need to match, " f"got X:{len(X)} and y:{[len(y_i) for y_i in y]}.") for session in range(len(X)): @@ -650,8 +689,8 @@ def _get_dataset_multi(X: List[Iterable], y: List[Iterable]): else: if not _are_sessions_equal(X, y): raise ValueError( - f"Invalid number of samples or labels sessions: provide one session for single-session training, " - f"and make sure the number of samples in X and y need match, " + "Invalid number of samples or labels sessions: provide one session for single-session training, " + "and make sure the number of samples in X and y match, " f"got {len(X)} and {[len(y_i) for y_i in y]}.") is_multisession = False dataset = _get_dataset(X, y) @@ -813,7 +852,7 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None): # Check that same number of index if len(self.label_types_) != n_idx: raise ValueError( - f"Number of index invalid: labels must have the same number of index as for fitting," + "Number of index invalid: labels must have the same number of index as for fitting," f"expects {len(self.label_types_)}, got {n_idx} idx.") for i in range(len(self.label_types_)): # for each index @@ -826,12 +865,12 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None): > 1): # is there more than one feature in the index if label_types_idx[1][1] != y[i].shape[1]: raise ValueError( - f"Labels invalid: must have the same number of features as the ones used for fitting," + "Labels invalid: must have the same number of features as the ones used for fitting," f"expects {label_types_idx[1]}, got {y[i].shape}.") if label_types_idx[0] != y[i].dtype: raise ValueError( - f"Labels invalid: must have the same type of features as the ones used for fitting," + "Labels invalid: must have the same type of features as the ones used for fitting," f"expects {label_types_idx[0]}, got {y[i].dtype}.") def _prepare_fit( @@ -859,6 +898,8 @@ def _prepare_fit( self.offset_ = self._compute_offset() dataset, is_multisession = self._prepare_data(X, y) + dataset.set_masks(self.masking_kwargs) + loader, solver_name = self._prepare_loader( dataset, max_iterations=self.max_iterations, @@ -1018,14 +1059,12 @@ def _partial_fit( # Save variables of interest as semi-private attributes self.model_ = model - self.n_features_ = ([ - loader.dataset.get_input_dimension(session_id) - for session_id in range(loader.dataset.num_sessions) - ] if is_multisession else loader.dataset.input_dimension) + + self.n_features_ = solver.n_features + self.num_sessions_ = solver.num_sessions self.solver_ = solver self.n_features_in_ = ([model[n].num_input for n in range(len(model))] if is_multisession else model.num_input) - self.num_sessions_ = loader.dataset.num_sessions if is_multisession else None return self @@ -1194,7 +1233,7 @@ def transform(self, >>> cebra_model = cebra.CEBRA(max_iterations=10) >>> cebra_model.fit(dataset) CEBRA(max_iterations=10) - >>> embedding = cebra_model.transform(dataset) + >>> embedding = cebra_model.transform(dataset, batch_size=200) """ 
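+        >>> embedding.shape  # (n_samples, output_dimension), here the default of 8
+        (1000, 8)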
sklearn_utils_validation.check_is_fitted(self, "n_features_") @@ -1221,60 +1260,6 @@ def transform(self, return output.detach().cpu().numpy() - # Deprecated, kept for testing. - def transform_deprecated(self, - X: Union[npt.NDArray, torch.Tensor], - session_id: Optional[int] = None) -> npt.NDArray: - """Transform an input sequence and return the embedding. - - Args: - X: A numpy array or torch tensor of size ``time x dimension``. - session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for - multisession, set to ``None`` for single session. - - Returns: - A :py:func:`numpy.array` of size ``time x output_dimension``. - - Example: - - >>> import cebra - >>> import numpy as np - >>> dataset = np.random.uniform(0, 1, (1000, 30)) - >>> cebra_model = cebra.CEBRA(max_iterations=10) - >>> cebra_model.fit(dataset) - CEBRA(max_iterations=10) - >>> embedding = cebra_model.transform(dataset) - - """ - - sklearn_utils_validation.check_is_fitted(self, "n_features_") - model, offset = self._select_model(X, session_id) - - # Input validation - X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_)) - input_dtype = X.dtype - - with torch.no_grad(): - model.eval() - - if self.pad_before_transform: - X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)), - mode="edge") - X = torch.from_numpy(X).float().to(self.device_) - - if isinstance(model, cebra.models.ConvolutionalModelMixin): - # Fully convolutional evaluation, switch (T, C) -> (1, C, T) - X = X.transpose(1, 0).unsqueeze(0) - output = model(X).cpu().numpy().squeeze(0).transpose(1, 0) - else: - # Standard evaluation, (T, C, dt) - output = model(X).cpu().numpy() - - if input_dtype == "float64": - return output.astype(input_dtype) - - return output - def fit_transform( self, X: Union[npt.NDArray, torch.Tensor], @@ -1317,6 +1302,15 @@ def fit_transform( callback_frequency=callback_frequency) return self.transform(X) + def __sklearn_tags__(self): + # NOTE(stes): from 1.6.dev, this is the new way to specify tags + # https://scikit-learn.org/dev/developers/develop.html + # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165 + tags = super().__sklearn_tags__() + tags.non_deterministic = True + return tags + + @available_if(check_version) def _more_tags(self): # NOTE(stes): This tag is needed as seeding is not fully implemented in the # current version of CEBRA. @@ -1416,15 +1410,22 @@ def save(self, def load(cls, filename: str, backend: Literal["auto", "sklearn", "torch"] = "auto", + weights_only: bool = None, **kwargs) -> "CEBRA": """Load a model from disk. Args: filename: The path to the file in which to save the trained model. backend: A string identifying the used backend. + weights_only: Indicates whether unpickler should be restricted to loading only tensors, primitive types, + dictionaries and any types added via :py:func:`torch.serialization.add_safe_globals`. + See :py:func:`torch.load` with ``weights_only=True`` for more details. It it recommended to leave this + at the default value of ``None``, which sets the argument to ``False`` for torch<2.6, and ``True`` for + higher versions of torch. If you experience issues with loading custom models (specified outside + of the CEBRA package), you can try to set this to ``False`` if you trust the source of the model. kwargs: Optional keyword arguments passed directly to the loader. - Return: + Returns: The model to load. 
Note: @@ -1434,7 +1435,6 @@ def load(cls, For information about the file format please refer to :py:meth:`cebra.CEBRA.save`. Example: - >>> import cebra >>> import numpy as np >>> import tempfile @@ -1448,16 +1448,14 @@ def load(cls, >>> loaded_model = cebra.CEBRA.load(tmp_file) >>> embedding = loaded_model.transform(dataset) >>> tmp_file.unlink() - """ - supported_backends = ["auto", "sklearn", "torch"] if backend not in supported_backends: raise NotImplementedError( f"Unsupported backend: '{backend}'. Supported backends are: {', '.join(supported_backends)}" ) - checkpoint = torch.load(filename, **kwargs) + checkpoint = _safe_torch_load(filename, weights_only, **kwargs) if backend == "auto": backend = "sklearn" if isinstance(checkpoint, dict) else "torch" diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py index d07f9359..d8fd791d 100644 --- a/cebra/integrations/sklearn/metrics.py +++ b/cebra/integrations/sklearn/metrics.py @@ -109,6 +109,149 @@ def infonce_loss( return avg_loss +def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA, + X: Union[npt.NDArray, torch.Tensor], + *y, + session_id: Optional[int] = None, + num_batches: int = 500) -> float: + """Compute the goodness of fit score on a *single session* dataset on the model. + + This function uses the :func:`infonce_loss` function to compute the InfoNCE loss + for a given `cebra_model` and the :func:`infonce_to_goodness_of_fit` function + to derive the goodness of fit from the InfoNCE loss. + + Args: + cebra_model: The model to use to compute the InfoNCE loss on the samples. + X: A 2D data matrix, corresponding to a *single session* recording. + y: An arbitrary amount of continuous indices passed as 2D matrices, and up to one + discrete index passed as a 1D array. Each index has to match the length of ``X``. + session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`cebra.CEBRA.num_sessions` + for multisession, set to ``None`` for single session. + num_batches: The number of iterations to consider to evaluate the model on the new data. + Higher values will give a more accurate estimate. Set it to at least 500 iterations. + + Returns: + The average GoF score estimated over ``num_batches`` batches from the data distribution. + + Related: + :func:`infonce_to_goodness_of_fit` + + Example: + + >>> import cebra + >>> import numpy as np + >>> neural_data = np.random.uniform(0, 1, (1000, 20)) + >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512) + >>> cebra_model.fit(neural_data) + CEBRA(batch_size=512, max_iterations=10) + >>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data) + """ + loss = infonce_loss(cebra_model, + X, + *y, + session_id=session_id, + num_batches=num_batches, + correct_by_batchsize=False) + return infonce_to_goodness_of_fit(loss, cebra_model) + + +def goodness_of_fit_history(model: cebra_sklearn_cebra.CEBRA) -> np.ndarray: + """Return the history of the goodness of fit score. + + Args: + model: A trained CEBRA model. + + Returns: + A numpy array containing the goodness of fit values, measured in bits. 
+ + Related: + :func:`infonce_to_goodness_of_fit` + + Example: + + >>> import cebra + >>> import numpy as np + >>> neural_data = np.random.uniform(0, 1, (1000, 20)) + >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512) + >>> cebra_model.fit(neural_data) + CEBRA(batch_size=512, max_iterations=10) + >>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model) + """ + infonce = np.array(model.state_dict_["log"]["total"]) + return infonce_to_goodness_of_fit(infonce, model) + + +def infonce_to_goodness_of_fit( + infonce: Union[float, np.ndarray], + model: Optional[cebra_sklearn_cebra.CEBRA] = None, + batch_size: Optional[int] = None, + num_sessions: Optional[int] = None) -> Union[float, np.ndarray]: + """Given a trained CEBRA model, return goodness of fit metric. + + The goodness of fit ranges from 0 (lowest meaningful value) + to a positive number with the unit "bits", the higher the + better. + + Values lower than 0 bits are possible, but these only occur + due to numerical effects. A perfectly collapsed embedding + (e.g., because the data cannot be fit with the provided + auxiliary variables) will have a goodness of fit of 0. + + The conversion between the generalized InfoNCE metric that + CEBRA is trained with and the goodness of fit computed with this + function is + + .. math:: + + S = \\log N - \\text{InfoNCE} + + To use this function, either provide a trained CEBRA model or the + batch size and number of sessions. + + Args: + infonce: The InfoNCE loss, either a single value or an iterable of values. + model: The trained CEBRA model. + batch_size: The batch size used to train the model. + num_sessions: The number of sessions used to train the model. + + Returns: + Numpy array containing the goodness of fit values, measured in bits + + Raises: + RuntimeError: If the provided model is not fit to data. + ValueError: If both ``model`` and ``(batch_size, num_sessions)`` are provided. + """ + if model is not None: + if batch_size is not None or num_sessions is not None: + raise ValueError( + "batch_size and num_sessions should not be provided if model is provided." + ) + if not hasattr(model, "state_dict_"): + raise RuntimeError("Fit the CEBRA model first.") + if model.batch_size is None: + raise ValueError( + "Computing the goodness of fit is not yet supported for " + "models trained on the full dataset (batchsize = None). ") + batch_size = model.batch_size + num_sessions = model.num_sessions_ + if num_sessions is None: + num_sessions = 1 + + if model.batch_size is None: + raise ValueError( + "Computing the goodness of fit is not yet supported for " + "models trained on the full dataset (batchsize = None). 
") + else: + if batch_size is None or num_sessions is None: + raise ValueError( + f"batch_size ({batch_size}) and num_sessions ({num_sessions})" + f"should be provided if model is not provided.") + + nats_to_bits = np.log2(np.e) + chance_level = np.log(batch_size * num_sessions) + return (chance_level - infonce) * nats_to_bits + + def _consistency_scores( embeddings: List[Union[npt.NDArray, torch.Tensor]], datasets: List[Union[int, str]], diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index 0ec01aa1..be6f54ce 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -22,12 +22,26 @@ import warnings import numpy.typing as npt +import packaging +import sklearn import sklearn.utils.validation as sklearn_utils_validation import torch import cebra.helper +def _sklearn_check_array(array, **kwargs): + # NOTE(stes): See discussion in https://github.com/AdaptiveMotorControlLab/CEBRA/pull/206 + # https://scikit-learn.org/1.6/modules/generated/sklearn.utils.check_array.html + # force_all_finite was renamed to ensure_all_finite and will be removed in 1.8. + if packaging.version.parse( + sklearn.__version__) < packaging.version.parse("1.6"): + if "ensure_all_finite" in kwargs: + kwargs["force_all_finite"] = kwargs["ensure_all_finite"] + del kwargs["ensure_all_finite"] + return sklearn_utils_validation.check_array(array, **kwargs) + + def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple: """Handle deprecated arguments of a function until they are replaced. @@ -74,16 +88,16 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray: Returns: The converted and validated array. """ - return sklearn_utils_validation.check_array( + return _sklearn_check_array( X, accept_sparse=False, accept_large_sparse=False, - # NOTE: remove float16 because F.pad does not allow float16. + # NOTE(celia): remove float16 because F.pad does not allow float16. dtype=("float32", "float64"), order=None, copy=False, - force_all_finite=True, ensure_2d=True, + ensure_all_finite=True, allow_nd=False, ensure_min_samples=min_samples, ensure_min_features=1, @@ -106,15 +120,15 @@ def check_label_array(y: npt.NDArray, *, min_samples: int): Returns: The converted and validated labels. """ - return sklearn_utils_validation.check_array( + return _sklearn_check_array( y, accept_sparse=False, accept_large_sparse=False, dtype="numeric", order=None, copy=False, - force_all_finite=True, ensure_2d=False, + ensure_all_finite=True, allow_nd=False, ensure_min_samples=min_samples, ) diff --git a/cebra/models/__init__.py b/cebra/models/__init__.py index 4dfad333..2d170e24 100644 --- a/cebra/models/__init__.py +++ b/cebra/models/__init__.py @@ -36,5 +36,7 @@ from cebra.models.multiobjective import * from cebra.models.layers import * from cebra.models.criterions import * +from cebra.models.multicriterions import * +from cebra.models.jacobian_regularizer import * cebra.registry.add_docstring(__name__) diff --git a/cebra/models/criterions.py b/cebra/models/criterions.py index 47c2a87f..f78e298b 100644 --- a/cebra/models/criterions.py +++ b/cebra/models/criterions.py @@ -95,7 +95,7 @@ def infonce( Note: - The behavior of this function changed beginning in CEBRA 0.3.0. - The InfoNCE implementation is numerically stabilized. + The InfoNCE implementation is numerically stabilized. 
""" with torch.no_grad(): c, _ = neg_dist.max(dim=1, keepdim=True) diff --git a/cebra/models/decoders.py b/cebra/models/decoders.py new file mode 100644 index 00000000..ec7c3fca --- /dev/null +++ b/cebra/models/decoders.py @@ -0,0 +1,38 @@ +import torch.nn as nn + +from cebra.models import register + + +@register("one-layer-mlp-decoder") +class SingleLayerDecoder(nn.Module): + """Supervised module to predict behaviors. + + Note: + By default, the output dimension is 2, to predict x/y velocity + (Perich et al., 2018). + """ + + def __init__(self, input_dim, output_dim=2): + super(SingleLayerDecoder, self).__init__() + self.fc = nn.Linear(input_dim, output_dim) + + def forward(self, x): + return self.fc(x) + + +@register("two-layers-mlp-decoder") +class TwoLayersDecoder(nn.Module): + """Supervised module to predict behaviors. + + Note: + By default, the output dimension is 2, to predict x/y velocity + (Perich et al., 2018). + """ + + def __init__(self, input_dim, output_dim=2): + super(TwoLayersDecoder, self).__init__() + self.fc = nn.Sequential(nn.Linear(input_dim, 32), nn.GELU(), + nn.Linear(32, output_dim)) + + def forward(self, x): + return self.fc(x) diff --git a/cebra/models/jacobian_regularizer.py b/cebra/models/jacobian_regularizer.py new file mode 100644 index 00000000..a909a31b --- /dev/null +++ b/cebra/models/jacobian_regularizer.py @@ -0,0 +1,148 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# This file contains the PyTorch implementation of Jacobian regularization described in [1]. +# Judy Hoffman, Daniel A. Roberts, and Sho Yaida, +# "Robust Learning with Jacobian Regularization," 2019. +# [arxiv:1908.02729](https://arxiv.org/abs/1908.02729) +# +# Adapted from https://github.com/facebookresearch/jacobian_regularizer/blob/main/jacobian/jacobian.py +# licensed under the following MIT License: +# +# MIT License +# +# Copyright (c) Facebook, Inc. and its affiliates. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +"""Jacobian Regularization for CEBRA. + +This implementation is adapted from the Jacobian regularization described in [1]_. + +.. [1] Judy Hoffman, Daniel A. Roberts, and Sho Yaida, + "Robust Learning with Jacobian Regularization," 2019. 
+ `arxiv:1908.02729 `_ +""" + +from __future__ import division + +import torch +import torch.nn as nn + + +class JacobianReg(nn.Module): + """Loss criterion that computes the trace of the square of the Jacobian. + + Args: + n: Determines the number of random projections. If n=-1, then it is set to the dimension + of the output space and projection is non-random and orthonormal, yielding the exact + result. For any reasonable batch size, the default (n=1) should be sufficient. + |Default:| ``1`` + """ + + def __init__(self, n: int = 1): + assert n == -1 or n > 0 + self.n = n + super(JacobianReg, self).__init__() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + """Computes (1/2) tr \\|dy/dx\\|^2. + + Args: + x: Input tensor + y: Output tensor + + Returns: + The computed regularization term + """ + B, C = y.shape + if self.n == -1: + num_proj = C + else: + num_proj = self.n + J2 = 0 + for ii in range(num_proj): + if self.n == -1: + # orthonormal vector, sequentially spanned + v = torch.zeros(B, C) + v[:, ii] = 1 + else: + # random properly-normalized vector for each sample + v = self._random_vector(C=C, B=B) + if x.is_cuda: + v = v.cuda() + Jv = self._jacobian_vector_product(y, x, v, create_graph=True) + J2 += C * torch.norm(Jv)**2 / (num_proj * B) + R = (1 / 2) * J2 + return R + + def _random_vector(self, C: int, B: int) -> torch.Tensor: + """Creates a random vector of dimension C with a norm of C^(1/2). + + This is needed for the projection formula to work. + + Args: + C: Output dimension + B: Batch size + + Returns: + A random normalized vector + """ + if C == 1: + return torch.ones(B) + v = torch.randn(B, C) + arxilirary_zero = torch.zeros(B, C) + vnorm = torch.norm(v, 2, 1, True) + v = torch.addcdiv(arxilirary_zero, 1.0, v, vnorm) + return v + + def _jacobian_vector_product(self, + y: torch.Tensor, + x: torch.Tensor, + v: torch.Tensor, + create_graph: bool = False) -> torch.Tensor: + """Produce jacobian-vector product dy/dx dot v. + + Args: + y: Output tensor + x: Input tensor + v: Vector to compute product with + create_graph: If True, graph of the derivative will be constructed, allowing + to compute higher order derivative products. |Default:| ``False`` + + Returns: + The Jacobian-vector product + + Note: + If you want to differentiate the result, you need to make create_graph=True + """ + flat_y = y.reshape(-1) + flat_v = v.reshape(-1) + grad_x, = torch.autograd.grad(flat_y, + x, + flat_v, + retain_graph=True, + create_graph=create_graph) + return grad_x diff --git a/cebra/models/layers.py b/cebra/models/layers.py index 7c1c36e8..e8b8175e 100644 --- a/cebra/models/layers.py +++ b/cebra/models/layers.py @@ -97,3 +97,25 @@ def forward(self, inp: torch.Tensor) -> torch.Tensor: connect = self.layer(inp) downsampled = F.interpolate(inp, scale_factor=1 / self.downsample) return torch.cat([connect, downsampled[..., :connect.size(-1)]], dim=1) + + +class _SkipLinear(nn.Module): + """Add a skip connection to a linear module + Args: + module (torch.nn.Module): Module to add to the bottleneck + """ + + def __init__(self, module): + super().__init__() + self.module = module + assert isinstance(self.module, nn.Linear) + padding_size = self.module.out_features - self.module.in_features + self.padding_size = padding_size + + def forward(self, inp: torch.Tensor) -> torch.Tensor: + """Compute forward pass through the skip connection. 
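For the ``JacobianReg`` criterion added above in ``cebra/models/jacobian_regularizer.py``, a hedged usage sketch: the network, penalty weight, and shapes below are placeholders, and the only assumption about the class is the ``forward(x, y)`` signature shown in that file.

    import torch
    from torch import nn

    from cebra.models.jacobian_regularizer import JacobianReg

    reg = JacobianReg(n=-1)  # n=-1: exact trace over all output dimensions
    net = nn.Sequential(nn.Linear(20, 16), nn.GELU(), nn.Linear(16, 8))

    x = torch.randn(32, 20, requires_grad=True)  # inputs must require grad
    y = net(x)
    task_loss = y.pow(2).mean()                  # placeholder task loss
    loss = task_loss + 0.01 * reg(x, y)          # add the Jacobian penalty
    loss.backward()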
+ """ + inp_padded = F.pad(inp, (0, self.padding_size), + mode='constant', + value=0) + return inp_padded + self.module(inp) diff --git a/cebra/models/model.py b/cebra/models/model.py index 7631ba86..33cd2782 100644 --- a/cebra/models/model.py +++ b/cebra/models/model.py @@ -29,8 +29,11 @@ import cebra.data import cebra.data.datatypes import cebra.models.layers as cebra_layers +from cebra.models import parametrize from cebra.models import register +DROPOUT = 0.1 + def _check_torch_version(raise_error=False): current_version = tuple( @@ -223,6 +226,12 @@ def __init__(self, # the self.net self.normalize = normalize + def _make_layers(self, num_units, num_layers, kernel_size=3): + return [ + cebra_layers._Skip(nn.Conv1d(num_units, num_units, kernel_size), + nn.GELU()) for _ in range(num_layers) + ] + def forward(self, inp): """Compute the embedding given the input signal. @@ -265,9 +274,7 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): super().__init__( nn.Conv1d(num_neurons, num_units, 2), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=3), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -528,9 +535,7 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): cebra_layers._MeanAndConv(num_neurons, num_units, 4, stride=2), nn.Conv1d(num_neurons + num_units, num_units, 3, stride=2), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=3), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -675,22 +680,7 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): super().__init__( nn.Conv1d(num_neurons, num_units, 2), nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=16), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -717,24 +707,9 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): ) super().__init__( nn.Conv1d(num_neurons, num_units, 2), - torch.nn.Dropout1d(p=0.1), + torch.nn.Dropout1d(p=DROPOUT), 
nn.GELU(), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), - cebra_layers._Skip(nn.Conv1d(num_units, num_units, 3), nn.GELU()), + *self._make_layers(num_units, num_layers=16), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -768,9 +743,9 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): ) super().__init__( nn.Conv1d(num_neurons, num_units, 2), - torch.nn.Dropout1d(p=0.1), + torch.nn.Dropout1d(p=DROPOUT), nn.GELU(), - *self._make_layers(num_units, 0.1, 16), + *self._make_layers(num_units, p=DROPOUT, n=16), nn.Conv1d(num_units, num_output, 3), num_input=num_neurons, num_output=num_output, @@ -780,3 +755,295 @@ def __init__(self, num_neurons, num_units, num_output, normalize=True): def get_offset(self) -> cebra.data.datatypes.Offset: """See `:py:meth:Model.get_offset`""" return cebra.data.Offset(18, 18) + + +@register("offset40-model") +class Offset40(_OffsetModel, ConvolutionalModelMixin): + """CEBRA model with a 40 samples receptive field.""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." + ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + *self._make_layers(num_units, 18), + nn.Conv1d(num_units, num_output, 3), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See `:py:meth:Model.get_offset`""" + return cebra.data.Offset(20, 20) + + +@register("offset50-model") +class Offset50(_OffsetModel, ConvolutionalModelMixin): + """CEBRA model with a sample receptive field.""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." 
+ ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + *self._make_layers(num_units, 23), + nn.Conv1d(num_units, num_output, 3), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See `:py:meth:Model.get_offset`""" + return cebra.data.Offset(25, 25) + + +@register("offset15-model") +class Offset15Model(_OffsetModel, ConvolutionalModelMixin): + """CEBRA model with a 15 sample receptive field.""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." + ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + *self._make_layers(num_units, num_layers=6), + nn.Conv1d(num_units, num_output, 2), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See `:py:meth:Model.get_offset`""" + return cebra.data.Offset(7, 8) + + +@register("offset20-model") +class Offset20Model(_OffsetModel, ConvolutionalModelMixin): + """CEBRA model with a 15 sample receptive field.""" + + def __init__(self, num_neurons, num_units, num_output, normalize=True): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." + ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + *self._make_layers(num_units, num_layers=8), + nn.Conv1d(num_units, num_output, 3), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See `:py:meth:Model.get_offset`""" + return cebra.data.Offset(10, 10) + + +@register("offset10-model-mse-tanh") +class Offset10Model(_OffsetModel, ConvolutionalModelMixin): + """CEBRA model with a 10 sample receptive field.""" + + def __init__(self, num_neurons, num_units, num_output, normalize=False): + if num_units < 1: + raise ValueError( + f"Hidden dimension needs to be at least 1, but got {num_units}." 
+ ) + super().__init__( + nn.Conv1d(num_neurons, num_units, 2), + nn.GELU(), + *self._make_layers(num_units, num_layers=3), + nn.Conv1d(num_units, num_output, 3), + nn.Tanh(), # Added tanh activation function + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See :py:meth:`~.Model.get_offset`""" + return cebra.data.Offset(5, 5) + + +@register("offset1-model-mse-tanh") +class Offset0ModelMSETanH(_OffsetModel): + """CEBRA model with a single sample receptive field, without output normalization.""" + + def __init__(self, num_neurons, num_units, num_output, normalize=False): + super().__init__( + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear( + num_neurons, + num_output * 30, + ), + nn.GELU(), + nn.Linear(num_output * 30, num_output * 30), + nn.GELU(), + nn.Linear(num_output * 30, num_output * 10), + nn.GELU(), + nn.Linear(int(num_output * 10), num_output), + nn.Tanh(), # Added tanh activation function + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See :py:meth:`~.Model.get_offset`""" + return cebra.data.Offset(0, 1) + + +@parametrize("offset1-model-mse-clip-{clip_min}-{clip_max}", + clip_min=(1000, 100, 50, 25, 20, 15, 10, 5, 1), + clip_max=(1000, 100, 50, 25, 20, 15, 10, 5, 1)) +class Offset0ModelMSEClip(_OffsetModel): + """CEBRA model with a single sample receptive field, without output normalization.""" + + def __init__(self, + num_neurons, + num_units, + num_output, + clip_min=-1, + clip_max=1, + normalize=False): + super().__init__( + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear( + num_neurons, + num_output * 30, + ), + nn.GELU(), + nn.Linear(num_output * 30, num_output * 30), + nn.GELU(), + nn.Linear(num_output * 30, num_output * 10), + nn.GELU(), + nn.Linear(int(num_output * 10), num_output), + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + self.clamp = nn.Hardtanh(-clip_min, clip_max) + + def forward(self, inputs): + outputs = super().forward(inputs) + outputs = self.clamp(outputs) + return outputs + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See :py:meth:`~.Model.get_offset`""" + return cebra.data.Offset(0, 1) + + +@parametrize("offset1-model-mse-v2-{n_intermediate_layers}layers{tanh}", + n_intermediate_layers=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), + tanh=("-tanh", "")) +class Offset0ModelMSETanHv2(_OffsetModel): + """CEBRA model with a single sample receptive field, without output normalization.""" + + def __init__(self, + num_neurons, + num_units, + num_output, + tanh="", + n_intermediate_layers=1, + normalize=False): + if num_units < 2: + raise ValueError( + f"Number of hidden units needs to be at least 2, but got {num_units}." 
+ ) + + intermediate_layers = [ + nn.Linear(num_units, num_units), + nn.GELU(), + ] * n_intermediate_layers + + layers = [ + nn.Flatten(start_dim=1, end_dim=-1), + nn.Linear( + num_neurons, + num_units, + ), + nn.GELU(), + *intermediate_layers, + nn.Linear(num_units, int(num_units // 2)), + nn.GELU(), + nn.Linear(int(num_units // 2), num_output), + ] + + if tanh == "-tanh": + layers += [nn.Tanh()] + + super().__init__( + *layers, + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See :py:meth:`~.Model.get_offset`""" + return cebra.data.Offset(0, 1) + + +@parametrize("offset1-model-mse-resnet-{n_intermediate_layers}layers{tanh}", + n_intermediate_layers=(1, 2, 3, 4, 5, 6, 7, 8, 9, 10), + tanh=("-tanh", "")) +class Offset0ModelResNetTanH(_OffsetModel): + """CEBRA model with a single sample receptive field, without output normalization.""" + + def __init__(self, + num_neurons, + num_units, + num_output, + tanh="", + n_intermediate_layers=1, + normalize=False): + if num_units < 2: + raise ValueError( + f"Number of hidden units needs to be at least 2, but got {num_units}." + ) + + intermediate_layers = [ + cebra_layers._SkipLinear(nn.Linear(num_units, num_units)), + nn.GELU(), + ] * n_intermediate_layers + + layers = [ + nn.Flatten(start_dim=1, end_dim=-1), + cebra_layers._SkipLinear(nn.Linear( + num_neurons, + num_units, + )), + nn.GELU(), + *intermediate_layers, + cebra_layers._SkipLinear(nn.Linear(num_units, int(num_units // 2))), + nn.GELU(), + nn.Linear(int(num_units // 2), num_output), + ] + + if tanh == "-tanh": + layers += [nn.Tanh()] + + super().__init__( + *layers, + num_input=num_neurons, + num_output=num_output, + normalize=normalize, + ) + + def get_offset(self) -> cebra.data.datatypes.Offset: + """See :py:meth:`~.Model.get_offset`""" + return cebra.data.Offset(0, 1) diff --git a/cebra/models/multicriterions.py b/cebra/models/multicriterions.py new file mode 100644 index 00000000..2b02fc37 --- /dev/null +++ b/cebra/models/multicriterions.py @@ -0,0 +1,154 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Support for training CEBRA with multiple criteria. + +.. note:: + This module was introduced in CEBRA 0.6.0. + +""" +from typing import Tuple + +import torch +from torch import nn + +from cebra.data.datatypes import Batch + + +class MultiCriterions(nn.Module): + """A module for handling multiple loss functions with different criteria. + + This module allows combining multiple loss functions, each operating on specific + slices of the input data. It supports both supervised and contrastive learning modes. 
+ + Args: + losses: A list of dictionaries containing loss configurations. Each dictionary should have: + - 'indices': Tuple of (start, end) indices for the data slice + - 'supervised_loss': Dict with loss config for supervised mode + - 'contrastive_loss': Dict with loss config for contrastive mode + Loss configs should contain: + - 'name': Name of the loss function + - 'kwargs': Optional parameters for the loss function + mode: Either "supervised" or "contrastive" to specify the training mode + + The loss functions can be from torch.nn or custom implementations from cebra.models.criterions. + Each criterion is applied to its corresponding slice of the input data during forward pass. + + Example: + >>> import torch + >>> from cebra.data.datatypes import Batch + >>> # Define loss configurations for a hybrid model with both contrastive and supervised losses + >>> losses = [ + ... { + ... 'indices': (0, 10), # First 10 dimensions + ... 'contrastive_loss': { + ... 'name': 'InfoNCE', # Using CEBRA's InfoNCE loss + ... 'kwargs': {'temperature': 1.0} + ... }, + ... 'supervised_loss': { + ... 'name': 'nn.MSELoss', # Using PyTorch's MSE loss + ... 'kwargs': {} + ... } + ... }, + ... { + ... 'indices': (10, 20), # Next 10 dimensions + ... 'contrastive_loss': { + ... 'name': 'InfoNCE', # Using CEBRA's InfoNCE loss + ... 'kwargs': {'temperature': 0.5} + ... }, + ... 'supervised_loss': { + ... 'name': 'nn.L1Loss', # Using PyTorch's L1 loss + ... 'kwargs': {} + ... } + ... } + ... ] + >>> # Create sample predictions (2 batches of 32 samples each with 10 features) + >>> ref1 = torch.randn(32, 10) + >>> pos1 = torch.randn(32, 10) + >>> neg1 = torch.randn(32, 10) + >>> ref2 = torch.randn(32, 10) + >>> pos2 = torch.randn(32, 10) + >>> neg2 = torch.randn(32, 10) + >>> predictions = ( + ... Batch(reference=ref1, positive=pos1, negative=neg1), + ... Batch(reference=ref2, positive=pos2, negative=neg2) + ... 
) + >>> # Create multi-criterion module in contrastive mode + >>> multi_loss = MultiCriterions(losses, mode="contrastive") + >>> # Forward pass with multiple predictions + >>> losses = multi_loss(predictions) # Returns list of loss values + >>> assert len(losses) == 2 # One loss per criterion + """ + + def __init__(self, losses, mode): + super(MultiCriterions, self).__init__() + self.criterions = nn.ModuleList() + self.slices = [] + + for loss_info in losses: + slice_indices = loss_info['indices'] + + if mode == "supervised": + loss = loss_info['supervised_loss'] + elif mode == "contrastive": + loss = loss_info['contrastive_loss'] + else: + raise NotImplementedError + + loss_name = loss['name'] + loss_kwargs = loss.get('kwargs', {}) + + if loss_name.startswith("nn"): + name = loss_name.split(".")[-1] + criterion = getattr(torch.nn, name, None) + else: + import cebra.models + criterion = getattr(cebra.models.criterions, loss_name, None) + + if criterion is None: + raise ValueError(f"Loss {loss_name} not found.") + else: + criterion = criterion(**loss_kwargs) + + self.criterions.append(criterion) + self.slices.append(slice(*slice_indices)) + assert len(self.criterions) == len(self.slices) + + def forward(self, predictions: Tuple[Batch]): + + losses = [] + + for criterion, prediction in zip(self.criterions, predictions): + + if prediction.negative is None: + # supervised + #reference: data, positive: label + loss = criterion(prediction.reference, prediction.positive) + else: + #contrastive + loss, pos, neg = criterion(prediction.reference, + prediction.positive, + prediction.negative) + + losses.append(loss) + + assert len(self.criterions) == len(predictions) == len(losses) + return losses diff --git a/cebra/models/multiobjective.py b/cebra/models/multiobjective.py index d9393fdc..5dc4d247 100644 --- a/cebra/models/multiobjective.py +++ b/cebra/models/multiobjective.py @@ -19,19 +19,80 @@ # See the License for the specific language governing permissions and # limitations under the License. # -"""Wrappers for using models with multiobjective solvers. - -.. note:: - - Experimental as of Nov 06, 2022. 
-""" - -from typing import Tuple +import itertools +from typing import List, Tuple import torch from torch import nn import cebra.models +import cebra.models.model as cebra_models_base + + +def create_multiobjective_model(module, + **kwargs) -> "SubspaceMultiobjectiveModel": + assert isinstance(module, cebra_models_base.Model) + if isinstance(module, cebra.models.ConvolutionalModelMixin): + return SubspaceMultiobjectiveConvolutionalModel(module=module, **kwargs) + else: + return SubspaceMultiobjectiveModel(module=module, **kwargs) + + +def check_slices_for_gaps(slice_list): + slice_list = sorted(slice_list, key=lambda s: s.start) + for i in range(1, len(slice_list)): + if slice_list[i - 1].stop < slice_list[i].start: + raise ValueError( + f"There is a gap in the slices {slice_list[i-1]} and {slice_list[i]}" + ) + + +def check_overlapping_feature_ranges(slice_list): + for slice1, slice2 in itertools.combinations(slice_list, 2): + if slice1.start < slice2.stop and slice1.stop > slice2.start: + return True + return False + + +def compute_renormalize_ranges(feature_ranges, sort=True): + + max_slice_dim = max(s.stop for s in feature_ranges) + min_slice_dim = min(s.start for s in feature_ranges) + full_emb_slice = slice(min_slice_dim, max_slice_dim) + + n_full_emb_slices = sum(1 for s in feature_ranges if s == full_emb_slice) + + if n_full_emb_slices > 1: + raise ValueError( + "There are more than one slice that cover the full embedding.") + + if n_full_emb_slices == 0: + raise ValueError( + "There are overlapping slices but none of them cover the full embedding." + ) + + rest_of_slices = [s for s in feature_ranges if s != full_emb_slice] + max_slice_dim_rest = max(s.stop for s in rest_of_slices) + min_slice_dim_rest = min(s.start for s in rest_of_slices) + + remaining_slices = [] + if full_emb_slice.start < min_slice_dim_rest: + remaining_slices.append(slice(full_emb_slice.start, min_slice_dim_rest)) + + if full_emb_slice.stop > max_slice_dim_rest: + remaining_slices.append(slice(max_slice_dim_rest, full_emb_slice.stop)) + + if len(remaining_slices) == 0: + raise ValueError( + "The behavior slices and the time slices coincide completely.") + + final_slices = remaining_slices + rest_of_slices + + if sort: + final_slices = sorted(final_slices, key=lambda s: s.start) + + check_slices_for_gaps(final_slices) + return final_slices class _Norm(nn.Module): @@ -68,6 +129,13 @@ class MultiobjectiveModel(nn.Module): TODO: - Update nn.Module type annotation for ``module`` to cebra.models.Model + + Note: + This model will be deprecated in a future version. Please use the functionality in + :py:mod:`cebra.models.multiobjective` instead, which provides more versatile + multi-objective training capabilities. Instantiation of this model will raise a + deprecation warning. The new model is :py:class:`cebra.models.multiobjective.SubspaceMultiobjectiveModel` + which allows for unlimited subspaces and better configuration of the feature ranges. """ class Mode: @@ -178,3 +246,122 @@ def forward(self, inputs): if self.renormalize: outputs = (self._norm(output) for output in outputs) return tuple(outputs) + + +class SubspaceMultiobjectiveModel(nn.Module): + """Wrapper around contrastive learning models to all training with multiple objectives + + Multi-objective training splits the last layer's feature representation into multiple + chunks, which are then used for individual training objectives. 
+ + Args: + module: The module to wrap + dimensions: A tuple of dimension values to extract from the model's feature embedding. + renormalize: If True, the individual feature slices will be re-normalized before + getting returned---this option only makes sense in conjunction with a loss based + on the cosine distance or dot product. + TODO: + - Update nn.Module type annotation for ``module`` to cebra.models.Model + """ + + def __init__(self, + module: nn.Module, + feature_ranges: List[slice], + renormalize: bool, + split_outputs: bool = True): + super().__init__() + + if not isinstance(module, cebra.models.Model): + raise ValueError("Can only wrap models that are subclassing the " + "cebra.models.Model abstract base class. " + f"Got a model of type {type(module)}.") + + self.module = module + self.renormalize = renormalize + self._norm = _Norm() + self.feature_ranges = feature_ranges + self.split_outputs = split_outputs + + max_slice_dim = max(s.stop for s in self.feature_ranges) + min_slice_dim = min(s.start for s in self.feature_ranges) + if min_slice_dim != 0: + raise ValueError( + f"The first slice should start at 0, but it starts at {min_slice_dim}." + ) + + if max_slice_dim != self.num_output: + raise ValueError( + f"The dimension of output {self.num_output} is different than the highest dimension of the slices ({max_slice_dim})." + f"The output dimension and slice dimension need to have the same dimension." + ) + + check_slices_for_gaps(self.feature_ranges) + + if check_overlapping_feature_ranges(self.feature_ranges): + print("Computing renormalized ranges...") + self.renormalize_ranges = compute_renormalize_ranges( + self.feature_ranges, sort=True) + print("New ranges:", self.renormalize_ranges) + + def set_split_outputs(self, val): + assert isinstance(val, bool) + self.split_outputs = val + + @property + def get_offset(self): + """See :py:meth:`cebra.models.model.Model.get_offset`.""" + return self.module.get_offset + + @property + def num_output(self): + """See :py:attr:`cebra.models.model.Model.num_output`.""" + return self.module.num_output + + def forward(self, inputs): + """Compute multiple embeddings for a single signal input. + + Args: + inputs: The input tensor + + Returns: + A tuple of tensors which are sliced according to `self.feature_ranges` + if `renormalize` is set to true, each of the tensors will be normalized + across the first (feature) dimension. 
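As a standalone illustration of the per-slice renormalization and splitting that ``forward`` describes above (hypothetical dimensions; ``torch.nn.functional.normalize`` stands in for the internal ``_Norm`` module):

    import torch
    import torch.nn.functional as F

    feature_ranges = [slice(0, 6), slice(6, 10)]   # two subspaces of a 10-d embedding
    output = torch.randn(32, 10)                   # (batch, num_output)

    # renormalize=True: normalize each slice across the feature dimension.
    renorm = torch.cat([F.normalize(output[:, s], dim=1) for s in feature_ranges], dim=1)

    # split_outputs=True: return one tensor per configured feature range.
    splits = tuple(renorm[:, s] for s in feature_ranges)
    assert splits[0].shape == (32, 6) and splits[1].shape == (32, 4)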
+ """ + + output = self.module(inputs) + + if (not self.renormalize) and (not self.split_outputs): + return output + + if self.renormalize: + if hasattr(self, "renormalize_ranges"): + if not all(self.renormalize_ranges[i].start <= + self.renormalize_ranges[i + 1].start + for i in range(len(self.renormalize_ranges) - 1)): + raise ValueError( + "The renormalize_ranges must be sorted by start index.") + + output = [ + self._norm(output[:, slice_features]) + for slice_features in self.renormalize_ranges + ] + else: + output = [ + self._norm(output[:, slice_features]) + for slice_features in self.feature_ranges + ] + + output = torch.cat(output, dim=1) + + if self.split_outputs: + return tuple(output[:, slice_features] + for slice_features in self.feature_ranges) + else: + assert isinstance(output, torch.Tensor) + return output + + +class SubspaceMultiobjectiveConvolutionalModel( + SubspaceMultiobjectiveModel, cebra_models_base.ConvolutionalModelMixin): + pass diff --git a/cebra/solver/__init__.py b/cebra/solver/__init__.py index 12ad2f06..8bc63a42 100644 --- a/cebra/solver/__init__.py +++ b/cebra/solver/__init__.py @@ -37,7 +37,11 @@ # pylint: disable=wrong-import-position from cebra.solver.base import * from cebra.solver.multi_session import * +from cebra.solver.multiobjective import * +from cebra.solver.regularized import * +from cebra.solver.schedulers import * from cebra.solver.single_session import * from cebra.solver.supervised import * +from cebra.solver.unified_session import * cebra.registry.add_docstring(__name__) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index f1eab6ed..9e55e03d 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -32,6 +32,7 @@ import abc import os +import warnings from typing import Callable, Dict, List, Literal, Optional, Tuple, Union import literate_dataclasses as dataclasses @@ -51,11 +52,12 @@ def _check_indices(batch_start_idx: int, batch_end_idx: int, offset: cebra.data.Offset, num_samples: int): - """Check that indexes in a batch are in a correct range. + """Check that indices in a batch are in a correct range. - First and last index must be positive integers, smaller than the total length of inputs - in the dataset, the first index must be smaller than the last and the batch size cannot - be smaller than the offset of the model. + First and last index must be positive integers, smaller than + the total length of inputs in the dataset, the first index + must be smaller than the last and the batch size cannot be + smaller than the offset of the model. Args: batch_start_idx: Index of the first sample in the batch. @@ -82,7 +84,7 @@ def _check_indices(batch_start_idx: int, batch_end_idx: int, raise ValueError( f"The batch has length {batch_size_length} which " f"is smaller or equal than the required offset length {len(offset)}." - f"Either choose a model with smaller offset or the batch should contain more samples." + f"Either choose a model with smaller offset or the batch should contain 3 times more samples." ) @@ -125,7 +127,7 @@ def _get_batch(inputs: torch.Tensor, offset: Optional[cebra.data.Offset], inputs: Input data. offset: Model offset. batch_start_idx: Index of the first sample in the batch. - batch_end_idx: Index of the first sample in the batch. + batch_end_idx: Index of the last sample in the batch. pad_before_transform: If True zero-pad the batched data. Returns: @@ -192,7 +194,10 @@ def _transform( offset: Model offset. Returns: - The embedding. + torch.Tensor: The (potentially) padded data. 
+ + Raises: + ValueError: If add_padding is True and offset is not provided. """ if pad_before_transform: inputs = F.pad(inputs.T, (offset.left, offset.right - 1), 'replicate').T @@ -230,6 +235,11 @@ def __getitem__(self, idx): index_dataset = IndexDataset(inputs) index_dataloader = DataLoader(index_dataset, batch_size=batch_size) + if len(index_dataloader) < 2: + raise ValueError( + f"Number of batches must be greater than 1, you can use transform " + f"without batching instead, got {len(index_dataloader)}.") + output = [] for batch_idx, index_batch in enumerate(index_dataloader): # NOTE(celia): This is to prevent that adding the offset to the @@ -243,7 +253,11 @@ def __getitem__(self, idx): if batch_idx == (len(index_dataloader) - 1): # last batch, incomplete index_batch = torch.cat((last_batch, index_batch), dim=0) + assert index_batch[-1] + 1 == len(inputs), ( + f"Last batch index {index_batch[-1]} + 1 should be equal to the length of inputs {len(inputs)}." + ) + # Batch start and end so that `batch_size` size with the last batch including 2 batches batch_start_idx, batch_end_idx = index_batch[0], index_batch[-1] + 1 batched_data = _get_batch(inputs=inputs, offset=offset, @@ -254,7 +268,7 @@ def __getitem__(self, idx): output_batch = _inference_transform(model, batched_data) output.append(output_batch) - output = torch.cat(output) + output = torch.cat(output, dim=0) return output @@ -371,8 +385,43 @@ def num_parameters(self) -> int: @abc.abstractmethod def parameters(self, session_id: Optional[int] = None): + """Iterate over all parameters of the model. + + Args: + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Yields: + The parameters of the model. + """ raise NotImplementedError + def _compute_features( + self, + batch: cebra.data.Batch, + model: Optional[torch.nn.Module] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + """Compute the features of the reference, positive and negative samples. + Args: + batch: The input data, not necessarily aligned across the batch + dimension. This means that ``batch.index`` specifies the map + between reference/positive samples, if not equal ``None``. + model: The model to use for inference. + If not provided, the model of the solver is used. + + Returns: + Tuple of reference, positive and negative features. + """ + if model is None: + model = self.model + + batch.to(self.device) + ref = model(batch.reference) + pos = model(batch.positive) + neg = model(batch.negative) + return ref, pos, neg + def _get_loader(self, loader): return ProgressBar( loader, @@ -442,7 +491,11 @@ def fit( self.decoding(loader, valid_loader)) if save_hook is not None: save_hook(num_steps, self) - self.save(logdir, f"checkpoint_{num_steps:#07d}.pth") + if logdir is not None: + self.save(logdir, f"checkpoint_{num_steps:#07d}.pth") + + assert hasattr(self, "n_features") + assert hasattr(self, "num_sessions") def step(self, batch: cebra.data.Batch) -> dict: """Perform a single gradient update. @@ -559,14 +612,23 @@ def _select_model( """ raise NotImplementedError - @property - def is_fitted(self): - return hasattr(self, "n_features") + def _check_is_fitted(self): + """Check if the model is fitted. + + If the model is fitted, the solver should have a `n_features` attribute. + + Raises: + ValueError: If the model is not fitted. + """ + if not hasattr(self, "n_features"): + raise ValueError( + f"This {type(self).__name__} instance is not fitted yet. 
Call 'fit' with " + "appropriate arguments before using this estimator.") @torch.no_grad() def transform(self, inputs: Union[torch.Tensor, List[torch.Tensor], npt.NDArray], - pad_before_transform: bool = True, + pad_before_transform: Optional[bool] = True, session_id: Optional[int] = None, batch_size: Optional[int] = None) -> torch.Tensor: """Compute the embedding. @@ -575,45 +637,41 @@ def transform(self, of the given model, after switching it into eval mode. Args: - inputs: The input signal - pad_before_transform: If ``False``, no padding is applied to the input sequence. - and the output sequence will be smaller than the input sequence due to the - receptive field of the model. If the input sequence is ``n`` steps long, - and a model with receptive field ``m`` is used, the output sequence would - only be ``n-m+1`` steps long. + inputs: The input signal (T, N). + pad_before_transform: If ``False``, no padding is applied to the input + sequence and the output sequence will be smaller than the input + sequence due to the receptive field of the model. If the + input sequence is ``n`` steps long, and a model with receptive + field ``m`` is used, the output sequence would only be + ``n-m+1`` steps long. session_id: The session ID, an :py:class:`int` between 0 and the number of sessions -1 for multisession, and set to ``None`` for single session. - batch_size: If not None, batched inference will be applied. + batch_size: If not None, batched inference will not be applied. Returns: The output embedding. """ - if not self.is_fitted: - raise ValueError( - f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with " - "appropriate arguments before using this estimator.") - - if batch_size is not None and batch_size < 1: - raise ValueError( - f"Batch size should be at least 1, got {batch_size}") - if isinstance(inputs, list): raise ValueError( "Inputs to transform() should be the data for a single session, but received a list." ) - elif not isinstance(inputs, torch.Tensor): raise ValueError( f"Inputs should be a torch.Tensor, not {type(inputs)}.") + self._check_is_fitted() + model, offset = self._select_model(inputs, session_id) if len(offset) < 2 and pad_before_transform: pad_before_transform = False model.eval() - if batch_size is not None: + if batch_size is not None and inputs.shape[0] > int( + batch_size * 2) and not isinstance( + self.model, cebra.models.ResampleModelMixin): + # NOTE: resampling models are not supported for batched inference. output = _batched_transform( model=model, inputs=inputs, @@ -633,8 +691,6 @@ def transform(self, def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: """Given a batch of input examples, return the model outputs. - TODO: make this a public function? - Args: batch: The input data, not necessarily aligned across the batch dimension. This means that ``batch.index`` specifies the map @@ -647,12 +703,12 @@ def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: """ raise NotImplementedError - def load(self, logdir, filename="checkpoint.pth"): + def load(self, logdir: str, filename: str = "checkpoint.pth"): """Load the experiment from its checkpoint file. Args: - logdir: Log directory. - filename (str): Checkpoint name for loading the experiment. + logdir: Logging directory. + filename: Checkpoint name for loading the experiment. 
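Per the guard in ``transform`` above, batched inference is used only when it pays off: ``batch_size`` is set, the input is longer than two batches, and the model is not a resampling model. A compact restatement of that condition (the helper name is hypothetical):

    def uses_batched_inference(num_samples: int, batch_size) -> bool:
        # Mirrors the guard in Solver.transform; resampling models are excluded separately.
        return batch_size is not None and num_samples > int(batch_size * 2)

    assert uses_batched_inference(10_000, 512)
    assert not uses_batched_inference(600, 512)       # short input: single full pass
    assert not uses_batched_inference(10_000, None)   # batch_size=None: no batching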
""" savepath = os.path.join(logdir, filename) @@ -662,7 +718,12 @@ def load(self, logdir, filename="checkpoint.pth"): checkpoint = torch.load(savepath, map_location=self.device) self.load_state_dict(checkpoint, strict=True) - def save(self, logdir, filename="checkpoint_last.pth"): + n_features = self.n_features + self.n_features = ([ + session_n_features for session_n_features in n_features + ] if isinstance(n_features, list) else n_features) + + def save(self, logdir: str, filename: str = "checkpoint_last.pth"): """Save the model and optimizer params. Args: @@ -693,11 +754,19 @@ class MultiobjectiveSolver(Solver): for time contrastive learning. renormalize_features: If ``True``, normalize the behavior and time contrastive features individually before computing similarity scores. + ignore_deprecation_warning: If ``True``, suppress the deprecation warning. + + Note: + This solver will be deprecated in a future version. Please use the functionality in + :py:mod:`cebra.solver.multiobjective` instead, which provides more versatile + multi-objective training capabilities. Instantiation of this solver will raise a + deprecation warning. """ num_behavior_features: int = 3 renormalize_features: bool = False output_mode: Literal["overlapping", "separate"] = "overlapping" + ignore_deprecation_warning: bool = False @property def num_time_features(self): @@ -709,6 +778,13 @@ def num_total_features(self): def __post_init__(self): super().__post_init__() + if not self.ignore_deprecation_warning: + warnings.warn( + "MultiobjectiveSolver is deprecated since CEBRA 0.6.0 and will be removed in a future version. " + "Use the new functionality in cebra.solver.multiobjective instead, which is more versatile. " + "If you see this warning when using the scikit-learn interface, no action is required.", + DeprecationWarning, + stacklevel=2) self._check_dimensions() self.model = cebra.models.MultiobjectiveModel( self.model, diff --git a/cebra/solver/multi_session.py b/cebra/solver/multi_session.py index 2c2153c2..e14b4b8d 100644 --- a/cebra/solver/multi_session.py +++ b/cebra/solver/multi_session.py @@ -21,8 +21,10 @@ # """Solver implementations for multi-session datasetes.""" -from typing import List, Optional +import copy +from typing import List, Optional, Tuple, Union +import numpy.typing as npt import torch import cebra @@ -40,11 +42,21 @@ class MultiSessionSolver(abc_.Solver): _variant_name = "multi-session" def parameters(self, session_id: Optional[int] = None): - """Iterate over all parameters.""" + """Iterate over all parameters. + + Args: + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Yields: + The parameters of the model. + """ if session_id is not None: for parameter in self.model[session_id].parameters(): yield parameter + # If session_id is None, it can still iterate over the criterion for parameter in self.criterion.parameters(): yield parameter @@ -69,10 +81,11 @@ def _single_model_inference(self, batch: cebra.data.Batch, across the sample dimensions, the output data should be aligned and ``batch.index`` should be set to ``None``. 
""" - batch.to(self.device) - ref = torch.stack([model(batch.reference)], dim=0) - pos = torch.stack([model(batch.positive)], dim=0) - neg = torch.stack([model(batch.negative)], dim=0) + ref, pos, neg = self._compute_features(batch, model) + + ref = ref.unsqueeze(0) + pos = pos.unsqueeze(0) + neg = neg.unsqueeze(0) pos = self._mix(pos, batch.index_reversed) @@ -161,12 +174,12 @@ def _check_is_inputs_valid(self, inputs: torch.Tensor, def _check_is_session_id_valid(self, session_id: Optional[int]): """Check that the session ID provided is valid for the solver instance. - The session ID must be non-null and between 0 and the number session in the dataset. + The session ID must be non-null and between 0 and the number session + in the dataset. Args: session_id: The session ID to check. """ - if session_id is None: raise RuntimeError( "No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape." @@ -177,7 +190,7 @@ def _check_is_session_id_valid(self, session_id: Optional[int]): ) def _select_model(self, inputs: torch.Tensor, session_id: Optional[int]): - """ Select the model based on the input dimension and session ID. + """ Select the (trained) model based on the input dimension and session ID. Args: inputs: Data to infer using the selected model. @@ -189,6 +202,7 @@ def _select_model(self, inputs: torch.Tensor, session_id: Optional[int]): The model (first returns) and the offset of the model (second returns). """ self._check_is_session_id_valid(session_id=session_id) + self._check_is_fitted() self._check_is_inputs_valid(inputs, session_id=session_id) model = self.model[session_id] @@ -230,9 +244,31 @@ class MultiSessionAuxVariableSolver(MultiSessionSolver): """Multi session training, contrasting neural data against behavior.""" _variant_name = "multi-session-aux" - reference_model: torch.nn.Module + reference_model: torch.nn.Module = None + + def __post_init__(self): + super().__post_init__() + if self.reference_model is None: + # NOTE(stes): This should work, according to this thread + # https://discuss.pytorch.org/t/can-i-deepcopy-a-model/52192/19 + # and create a true copy of the model. + self.reference_model = copy.deepcopy(self.model) + self.reference_model.to(self.device) + + def _inference(self, batches: List[cebra.data.Batch]) -> cebra.data.Batch: + """Given batches of input examples, computes the feature representations/embeddings. - def _inference(self, batches): + Args: + batches: A list of input data, not necessarily aligned across the batch + dimension. This means that ``batch.index`` specifies the map + between reference/positive samples, if not equal ``None``. + + Returns: + Processed batch of data. While the input data might not be aligned + across the sample dimensions, the output data should be aligned and + ``batch.index`` should be set to ``None``. + + """ refs = [] poss = [] negs = [] @@ -252,3 +288,83 @@ def _inference(self, batches): positive=pos.view(-1, num_features), negative=neg.view(-1, num_features), ) + + def _select_model( + self, + inputs: Union[torch.Tensor, List[torch.Tensor]], + session_id: Optional[int] = None, + use_reference_model: bool = False, + ) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module], + cebra.data.datatypes.Offset]: + """ Select the model based on the input dimension and session ID. + + Args: + inputs: Data to infer using the selected model. 
+ session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Returns: + The model (first returns) and the offset of the model (second returns). + """ + self._check_is_inputs_valid(inputs, session_id=session_id) + self._check_is_session_id_valid(session_id=session_id) + + if use_reference_model: + model = self.reference_model[session_id] + else: + model = self.model[session_id] + offset = model.get_offset() + return model, offset + + @torch.no_grad() + def transform(self, + inputs: Union[torch.Tensor, List[torch.Tensor], npt.NDArray], + pad_before_transform: bool = True, + session_id: Optional[int] = None, + batch_size: Optional[int] = None, + use_reference_model: bool = False) -> torch.Tensor: + """Compute the embedding. + This function by default use ``model`` that was trained to encode the positive + and negative samples. To use ``reference_model`` instead of ``model`` + ``use_reference_model`` should be equal ``True``. + Args: + inputs: The input signal + use_reference_model: Flag for using ``reference_model`` + Returns: + The output embedding. + """ + if isinstance(inputs, list): + raise NotImplementedError( + "Inputs to transform() should be the data for a single session." + ) + elif not isinstance(inputs, torch.Tensor): + raise ValueError( + f"Inputs should be a torch.Tensor, not {type(inputs)}.") + + if not hasattr(self, "history") and len(self.history) > 0: + raise ValueError( + f"This {type(self).__name__} instance is not fitted yet. Call 'fit' with " + "appropriate arguments before using this estimator.") + model, offset = self._select_model( + inputs, session_id, use_reference_model=use_reference_model) + + if len(offset) < 2 and pad_before_transform: + pad_before_transform = False + + model.eval() + if batch_size is not None: + output = abc_._batched_transform( + model=model, + inputs=inputs, + offset=offset, + batch_size=batch_size, + pad_before_transform=pad_before_transform, + ) + else: + output = abc_._transform(model=model, + inputs=inputs, + offset=offset, + pad_before_transform=pad_before_transform) + + return output diff --git a/cebra/solver/multiobjective.py b/cebra/solver/multiobjective.py new file mode 100644 index 00000000..98587bd7 --- /dev/null +++ b/cebra/solver/multiobjective.py @@ -0,0 +1,508 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Multiobjective contrastive learning. + +Starting in CEBRA 0.6.0, we have added support for subspace contrastive learning. +This is a method for training models that are able to learn multiple subspaces of the +feature space simultaneously. 
+ +Subspace contrastive learning requires to use specialized models and criterions. +This module specifies a test of classes required for training CEBRA models with multiple objectives. +The objectives are defined by the wrapper class :py:class:`cebra.models.multicriterions.MultiCriterions`. + +Two solvers are currently implemented: + +- :py:class:`cebra.solver.multiobjective.ContrastiveMultiobjectiveSolverxCEBRA` +- :py:class:`cebra.solver.multiobjective.SupervisedMultiobjectiveSolverxCEBRA` + +See Also: + :py:class:`cebra.solver.multiobjective.SupervisedMultiobjectiveSolverxCEBRA` + :py:class:`cebra.solver.multiobjective.MultiObjectiveConfig` + :py:class:`cebra.models.multicriterions.MultiCriterions` +""" + +import logging +import time +import warnings +from typing import Callable, Dict, List, Optional, Tuple + +import literate_dataclasses as dataclasses +import numpy as np +import torch + +import cebra +import cebra.data +import cebra.io +import cebra.models +import cebra.solver.single_session as cebra_solver_single +from cebra.solver import register +from cebra.solver.base import Solver +from cebra.solver.schedulers import Scheduler +from cebra.solver.util import Meter + + +class MultiObjectiveConfig: + """Configuration class for setting up multi-objective learning with Cebra. + + + + Args: + loader: Data loader used for configurations. + """ + + def __init__(self, loader): + self.loader = loader + self.total_info = [] + self.current_info = {} + + def _check_overwriting_key(self, key): + if key in self.current_info: + warnings.warn( + "Configuration key already exists. Overwriting existing value. " + "If you don't want to overwrite you should call push() before.") + + def _check_pushed_status(self): + if "slice" not in self.current_info: + raise RuntimeError( + "Slice configuration is missing. Add it before pushing it.") + if "distributions" not in self.current_info: + raise RuntimeError( + "Distributions configuration is missing. Add it before pushing it." + ) + if "losses" not in self.current_info: + raise RuntimeError( + "Losses configuration is missing. Add it before pushing it.") + + def set_slice(self, start, end): + """Select the index range of the embedding. + + The configured loss will be applied to the ``start:end`` slice of the + embedding space. Make sure the selected dimensionality is appropriate + for the chosen loss function and distribution. + """ + self._check_overwriting_key("slice") + self.current_info['slice'] = (start, end) + + def set_loss(self, loss_name, **kwargs): + """Select the loss function to apply. + + Select a valid loss function from :py:mod:`cebra.models.criterions`. + Common choices are: + + - `FixedEuclideanInfoNCE` + - `FixedCosineInfoNCE` + + which can be passed as string values to ``loss_name``. The loss + will be applied to the range specified with ``set_slice``. + """ + self._check_overwriting_key("losses") + self.current_info["losses"] = {"name": loss_name, "kwargs": kwargs} + + def set_distribution(self, distribution_name, **kwargs): + """Select the distribution to sample from. + + The loss function specified in ``set_loss`` is applied to positive + and negative pairs sampled from the specified distribution. + """ + self._check_overwriting_key("distributions") + self.current_info["distributions"] = { + "name": distribution_name, + "kwargs": kwargs + } + + def push(self): + """Add a slice/loss/distribution setting to the config. 
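The methods above (``set_slice``, ``set_loss``, ``set_distribution``, ``push``, ``finalize``) are meant to be called in groups, one group per subspace. A hedged sketch of that workflow; the loader, slice bounds, and distribution names are placeholders, and the loss name is one of the options listed in ``set_loss``:

    from cebra.solver.multiobjective import MultiObjectiveConfig

    # Hypothetical two-subspace configuration for a 10-dimensional embedding.
    config = MultiObjectiveConfig(loader)   # `loader`: a multiobjective data loader (placeholder)

    config.set_slice(0, 6)
    config.set_loss("FixedCosineInfoNCE", temperature=1.0)
    config.set_distribution("time")         # placeholder distribution name
    config.push()

    config.set_slice(6, 10)
    config.set_loss("FixedCosineInfoNCE", temperature=0.5)
    config.set_distribution("time_delta")   # placeholder distribution name
    config.push()

    config.finalize()  # builds config.criterion (a MultiCriterions) and config.feature_ranges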
+ + After calling all of ``set_slice``, ``set_loss``, ``set_distribution``, + add this group to the config by calling this function. + + Once all configuration parts are pushed, call ``finalize`` to finish + the configuration. + """ + self._check_pushed_status() + print(f"Adding configuration for slice: {self.current_info['slice']}") + self.total_info.append(self.current_info) + self.current_info = {} + + def finalize(self): + """Finalize the multiobjective configuration.""" + self.losses = [] + self.feature_ranges = [] + self.feature_ranges_tuple = [] + + for info in self.total_info: + self._process_info(info) + + if len(set(self.feature_ranges_tuple)) != len( + self.feature_ranges_tuple): + raise RuntimeError( + "Feature ranges are not unique. Please check again and remove the duplicates. " + f"Feature ranges: {self.feature_ranges_tuple}") + + print("Creating MultiCriterion") + self.criterion = cebra.models.MultiCriterions(losses=self.losses, + mode="contrastive") + + def _process_info(self, info): + """ + Processes individual configuration info and updates the losses and feature ranges. + + Args: + info (dict): The configuration info to process. + """ + slice_info = info["slice"] + losses_info = info["losses"] + distributions_info = info["distributions"] + + self.losses.append( + dict(indices=(slice_info[0], slice_info[1]), + contrastive_loss=dict(name=losses_info['name'], + kwargs=losses_info['kwargs']))) + + self.feature_ranges.append(slice(slice_info[0], slice_info[1])) + self.feature_ranges_tuple.append((slice_info[0], slice_info[1])) + + print(f"Adding distribution of slice: {slice_info}") + self.loader.add_config( + dict(distribution=distributions_info["name"], + kwargs=distributions_info["kwargs"])) + + +@dataclasses.dataclass +class MultiobjectiveSolverBase(cebra_solver_single.SingleSessionSolver): + + feature_ranges: List[slice] = None + renormalize: bool = None + log: Dict[Tuple, + List[float]] = dataclasses.field(default_factory=lambda: ({})) + use_sam: bool = False + regularizer: torch.nn.Module = None + metadata: Dict = dataclasses.field(default_factory=lambda: ({ + "timestamp": None, + "batches_seen": None, + })) + + def __post_init__(self): + super().__post_init__() + + self.model = cebra.models.create_multiobjective_model( + module=self.model, + feature_ranges=self.feature_ranges, + renormalize=self.renormalize, + ) + + def parameters(self, session_id: Optional[int] = None): + """Iterate over all parameters.""" + super().parameters(session_id=session_id) + + for parameter in self.regularizer.parameters(): + yield parameter + + def fit(self, + loader: cebra.data.Loader, + valid_loader: cebra.data.Loader = None, + *, + valid_frequency: int = None, + log_frequency: int = None, + save_hook: Callable[[int, "Solver"], None] = None, + scheduler_regularizer: "Scheduler" = None, + scheduler_loss: "Scheduler" = None, + logger: logging.Logger = None): + """Train model for the specified number of steps. + + Args: + loader: Data loader, which is an iterator over `cebra.data.Batch` instances. + Each batch contains reference, positive and negative input samples. + valid_loader: Data loader used for validation of the model. + valid_frequency: The frequency for running validation on the ``valid_loader`` instance. + logdir: The logging directory for writing model checkpoints. The checkpoints + can be read again using the `solver.load` function, or manually via loading the + state dict. + save_hook: callback. It will be called when we run validation. 
+ log_frequency: how frequent we log things. + logger: logger to log progress. None by default. + + """ + + def _run_validation(): + stats_val = self.validation(valid_loader, logger=logger) + if save_hook is not None: + save_hook(solver=self, step=num_steps) + return stats_val + + self._set_fitted_params(loader) + self.to(loader.device) + + iterator = self._get_loader(loader, + logger=logger, + log_frequency=log_frequency) + self.model.train() + for num_steps, batch in iterator: + weights_regularizer = None + if scheduler_regularizer is not None: + weights_regularizer = scheduler_regularizer.get_weights( + step=num_steps) + # NOTE(stes): Both SAM and Jacobian regularization is not yet supported. + # For this, we need to re-implement the closure logic below (right now, + # the closure function applies the non-regularized loss in the second + # step, it is unclear if that is the correct behavior. + assert not self.use_sam + + weights_loss = None + if scheduler_loss is not None: + weights_loss = scheduler_loss.get_weights() + + stats = self.step(batch, + weights_regularizer=weights_regularizer, + weights_loss=weights_loss) + + self._update_metadata(num_steps) + iterator.set_description(stats) + run_validation = (valid_loader + is not None) and (num_steps % valid_frequency + == 0) + if run_validation: + _run_validation() + + #TODO + #_run_validation() + + def _get_loader(self, loader, **kwargs): + return super()._get_loader(loader) + + def _update_metadata(self, num_steps): + self.metadata["timestamp"] = time.time() + self.metadata["batches_seen"] = num_steps + + def compute_regularizer(self, predictions, inputs): + regularizer = [] + for prediction in predictions: + R = self.regularizer(inputs, prediction.reference) + regularizer.append(R) + + return regularizer + + def create_closure(self, batch, weights_loss): + + def inner_closure(): + predictions = self._inference(batch) + losses = self.criterion(predictions) + + if weights_loss is not None: + assert len(weights_loss) == len( + losses + ), "Number of weights should match the number of losses" + losses = [ + weight * loss for weight, loss in zip(weights_loss, losses) + ] + + loss = sum(losses) + loss.backward() + return loss + + return inner_closure + + def step(self, + batch: cebra.data.Batch, + weights_loss: Optional[List[float]] = None, + weights_regularizer: Optional[List[float]] = None) -> dict: + """Perform a single gradient update with multiple objectives.""" + + closure = None + if self.use_sam: + closure = self.create_closure(batch, weights_loss) + + if weights_regularizer is not None: + assert isinstance(batch.reference, torch.Tensor) + batch.reference.requires_grad_(True) + + predictions = self._inference(batch) + losses = self.criterion(predictions) + + for i, loss_value in enumerate(losses): + key = "loss_train", i + self.log.setdefault(key, []).append(loss_value.item()) + + if weights_loss is not None: + losses = [ + weight * loss for weight, loss in zip(weights_loss, losses) + ] + + loss = sum(losses) + + if weights_regularizer is not None: + regularizer = self.compute_regularizer(predictions=predictions, + inputs=batch.reference) + assert len(weights_regularizer) == len(regularizer) == len(losses) + loss = loss + sum( + weight * reg + for weight, reg in zip(weights_regularizer, regularizer)) + + loss.backward() + self.optimizer.step(closure) + self.optimizer.zero_grad() + + if weights_regularizer is not None: + for i, (weight, + reg) in enumerate(zip(weights_regularizer, regularizer)): + assert isinstance(weight, float) + 
self.log.setdefault(("regularizer", i), []).append(reg.item()) + self.log.setdefault(("regularizer_weight", i), + []).append(weight) + + if weights_loss is not None: + for i, weight in enumerate(weights_loss): + assert isinstance(weight, float) or isinstance(weight, int) + self.log.setdefault(("loss_weight", i), []).append(weight) + + # add sum_loss_train + self.log.setdefault(("sum_loss_train",), []).append(loss.item()) + return {"sum_loss_train": loss.item()} + + @torch.no_grad() + def _compute_metrics(self): + # NOTE: We set split_outputs = False when we compute + # validation metrics, otherwise it returns a tuple + # which led to a bug before. + embeddings = {} + self.model.set_split_outputs(False) + for split in self.metrics.splits: + embedding_tensor = self.transform( + self.metrics.datasets[split].neural) + embedding_np = embedding_tensor.cpu().numpy() + assert embedding_np.shape[1] == self.model.num_output + embeddings[split] = embedding_np + + self.model.set_split_outputs(True) + return self.metrics.compute_metrics(embeddings) + + @torch.no_grad() + def validation( + self, + loader: cebra.data.Loader, + logger=None, + weights_loss: Optional[List[float]] = None, + ): + loader.dataset.configure_for(self.model) + iterator = self._get_loader(loader) + + self.model.eval() + total_loss = Meter() + + losses_dict = {} + for _, batch in iterator: + predictions = self._inference(batch) + losses = self.criterion(predictions) + + if weights_loss is not None: + assert len(weights_loss) == len( + losses + ), "Number of weights should match the number of losses" + losses = [ + weight * loss for weight, loss in zip(weights_loss, losses) + ] + + total_loss.add(sum(losses).item()) + + for i, loss_value in enumerate(losses): + key = "loss_val", i + losses_dict.setdefault(key, []).append(loss_value.item()) + + losses_dict_mean = {k: np.mean(v) for k, v in losses_dict.items()} + stats_val = {**losses_dict_mean} + + if self.metrics is not None: + metrics = self._compute_metrics() + stats_val.update(metrics) + + for key, value in stats_val.items(): + self.log.setdefault(key, []).append(value) + + if logger is not None: + formatted_loss = ', '.join([ + f"{'_'.join(map(str, key))}:{value:.3f}" + for key, value in stats_val.items() + if key[0].startswith("loss") + ]) + formatted_r2 = ', '.join([ + f"{'_'.join(map(str, key))}:{value:.3f}" + for key, value in stats_val.items() + if key[0].startswith("r2") + ]) + logger.info(f"Val: {formatted_loss}") + logger.info(f"Val: {formatted_r2}") + + # add sum_loss_valid + sum_loss_valid = total_loss.average + self.log.setdefault(("sum_loss_val",), []).append(sum_loss_valid) + return stats_val + + +@register("supervised-solver-xcebra") +@dataclasses.dataclass +class SupervisedMultiobjectiveSolverxCEBRA(MultiobjectiveSolverBase): + """Supervised neural network training using the MSE loss. + + This solver can be used as a baseline variant instead of the contrastive solver, + :py:class:`cebra.solver.multiobjective.ContrastiveMultiobjectiveSolverxCEBRA`. 
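+
+    The configuration workflow is shared with the contrastive solver and is
+    handled by :py:class:`cebra.solver.multiobjective.MultiObjectiveConfig`.
+    A rough sketch of that workflow is shown below; the slice bounds, loss
+    name, distribution names and their keyword arguments are illustrative
+    placeholders only (note that ``finalize`` builds a contrastive criterion;
+    this supervised baseline needs a matching supervised criterion instead)::
+
+        config = MultiObjectiveConfig(loader)
+        config.set_slice(0, 6)
+        config.set_loss("FixedEuclideanInfoNCE", temperature=1.0)
+        config.set_distribution("time_delta", time_delta=1)
+        config.push()
+        config.set_slice(3, 6)
+        config.set_loss("FixedEuclideanInfoNCE", temperature=1.0)
+        config.set_distribution("time", time_offset=1)
+        config.push()
+        config.finalize()
+        # config.criterion and config.feature_ranges can then be passed
+        # on to the solver.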
+ """ + + _variant_name = "supervised-solver-xcebra" + + def _inference(self, batch): + """Compute predictions (discrete/continuous) for the batch.""" + pred_refs = self.model(batch.reference) + prediction_batches = [] + for i, label_data in enumerate(batch.positive): + prediction_batches.append( + cebra.data.Batch(reference=pred_refs[i], + positive=label_data, + negative=None)) + return prediction_batches + + +@register("multiobjective-solver") +@dataclasses.dataclass +class ContrastiveMultiobjectiveSolverxCEBRA(MultiobjectiveSolverBase): + """Multi-objective solver for CEBRA. + + This solver is used for training CEBRA models with multiple objectives. + + See Also: + :py:class:`cebra.solver.multiobjective.SupervisedMultiobjectiveSolverxCEBRA` + :py:class:`cebra.solver.multiobjective.MultiObjectiveConfig` + :py:class:`cebra.models.multicriterions.MultiCriterions` + """ + + _variant_name = "contrastive-solver-xcebra" + + def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: + pred_refs = self.model(batch.reference) + pred_negs = self.model(batch.negative) + + prediction_batches = [] + for i, positive in enumerate(batch.positive): + pred_pos = self.model(positive) + prediction_batches.append( + cebra.data.Batch(pred_refs[i], pred_pos[i], pred_negs[i])) + + return prediction_batches diff --git a/cebra/solver/regularized.py b/cebra/solver/regularized.py new file mode 100644 index 00000000..41284529 --- /dev/null +++ b/cebra/solver/regularized.py @@ -0,0 +1,105 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Regularized contrastive learning.""" + +from typing import Dict, Optional + +import literate_dataclasses as dataclasses +import torch + +import cebra +import cebra.data +import cebra.models +from cebra.solver import register +from cebra.solver.single_session import SingleSessionSolver + + +@register("regularized-solver") +@dataclasses.dataclass +class RegularizedSolver(SingleSessionSolver): + """Optimize a model using Jacobian Regularizer.""" + + _variant_name = "regularized-solver" + log: Dict = dataclasses.field(default_factory=lambda: ({ + "pos": [], + "neg": [], + "loss": [], + "loss_reg": [], + "temperature": [], + "reg": [], + "reg_lambda": [], + })) + + lambda_JR: Optional[float] = None + + def __post_init__(self): + super().__post_init__() + #TODO: rn we are using the full jacobian. Can be optimized later if needed. + self.jac_regularizer = cebra.models.JacobianReg(n=-1) + + def step(self, batch: cebra.data.Batch) -> dict: + """Perform a single gradient update using the jacobian regularizaiton!. + + Args: + batch: The input samples + + Returns: + Dictionary containing training metrics. 
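+
+        Note:
+            As a compact summary of the update performed below, the optimized
+            quantity combines the InfoNCE loss with the Jacobian penalty,
+            ``loss_reg = loss + lambda_JR * R``, where ``R`` is computed by
+            :py:class:`cebra.models.JacobianReg` on the reference samples.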
+ """ + + self.optimizer.zero_grad() + batch.reference.requires_grad = True + prediction = self._inference(batch) + R = self.jac_regularizer(batch.reference, prediction.reference) + + loss, align, uniform = self.criterion(prediction.reference, + prediction.positive, + prediction.negative) + loss_reg = loss + self.lambda_JR * R + + loss_reg.backward() + self.optimizer.step() + self.history.append(loss.item()) + stats = dict(pos=align.item(), + neg=uniform.item(), + loss=loss.item(), + loss_reg=loss_reg.item(), + reg=R.item(), + temperature=self.criterion.temperature, + reg_lambda=(self.lambda_JR * R).item()) + + for key, value in stats.items(): + self.log[key].append(value) + return stats + + +def _prepare_inputs(inputs): + if not isinstance(inputs, torch.Tensor): + inputs = torch.from_numpy(inputs) + inputs.requires_grad_(True) + return inputs + + +def _prepare_model(model): + for p in model.parameters(): + p.requires_grad_(False) + return model diff --git a/cebra/solver/schedulers.py b/cebra/solver/schedulers.py new file mode 100644 index 00000000..1da637af --- /dev/null +++ b/cebra/solver/schedulers.py @@ -0,0 +1,97 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import abc +import dataclasses +from typing import List + +import cebra.registry + +cebra.registry.add_helper_functions(__name__) + +__all__ = ["Scheduler", "ConstantScheduler", "LinearScheduler", "LinearRampUp"] + + +@dataclasses.dataclass +class Scheduler(abc.ABC): + + def __post_init__(self): + pass + + @abc.abstractmethod + def get_weights(self): + pass + + +@register("constant-weight") +@dataclasses.dataclass +class ConstantScheduler(Scheduler): + initial_weights: List[float] + + def __post_init__(self): + super().__post_init__() + + def get_weights(self): + weights = self.initial_weights + if len(weights) == 0: + weights = None + return weights + + +@register("linear-scheduler") +@dataclasses.dataclass +class LinearScheduler(Scheduler): + n_splits: int + step_to_switch_on_reg: int + step_to_switch_off_reg: int + start_weight: float + end_weight: float + stay_constant_after_switch_off: bool = False + + def __post_init__(self): + super().__post_init__() + assert self.step_to_switch_off_reg > self.step_to_switch_on_reg + + def get_weights(self, step): + if self.step_to_switch_on_reg is not None: + if step >= self.step_to_switch_on_reg and step <= self.step_to_switch_off_reg: + interpolation_factor = min( + 1.0, (step - self.step_to_switch_on_reg) / + (self.step_to_switch_off_reg - self.step_to_switch_on_reg)) + weight = self.start_weight + ( + self.end_weight - self.start_weight) * interpolation_factor + weights = [weight] * self.n_splits + elif self.stay_constant_after_switch_off and step > self.step_to_switch_off_reg: + weight = self.end_weight + weights = [weight] * self.n_splits + else: + weights = None + + return weights + + +@register("linear-ramp-up") +@dataclasses.dataclass +class LinearRampUp(LinearScheduler): + + def __post_init__(self): + super().__post_init__() + self.stay_constant_after_switch_off = True diff --git a/cebra/solver/single_session.py b/cebra/solver/single_session.py index 62570a57..e452269d 100644 --- a/cebra/solver/single_session.py +++ b/cebra/solver/single_session.py @@ -25,6 +25,7 @@ from typing import List, Optional, Tuple, Union import literate_dataclasses as dataclasses +import numpy.typing as npt import torch import cebra @@ -46,8 +47,18 @@ class SingleSessionSolver(abc_.Solver): _variant_name = "single-session" def parameters(self, session_id: Optional[int] = None): - """Iterate over all parameters.""" - self._check_is_session_id_valid(session_id=session_id) + """Iterate over all parameters. + + Args: + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Yields: + The parameters of the model. + """ + # If session_id is invalid, it doesn't matter, since we are + # using a single session solver. for parameter in self.model.parameters(): yield parameter @@ -103,7 +114,7 @@ def _select_model( List[torch.Tensor]], session_id: Optional[int] ) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module], cebra.data.datatypes.Offset]: - """ Select the model based on the input dimension and session ID. + """ Select the (trained) model based on the input dimension and session ID. Args: inputs: Data to infer using the selected model. @@ -114,8 +125,9 @@ def _select_model( Returns: The model (first returns) and the offset of the model (second returns). 
""" - self._check_is_inputs_valid(inputs, session_id=session_id) self._check_is_session_id_valid(session_id=session_id) + self._check_is_fitted() + self._check_is_inputs_valid(inputs, session_id=session_id) model = self.model offset = model.get_offset() @@ -134,10 +146,7 @@ def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: across the sample dimensions, the output data should be aligned and ``batch.index`` should be set to ``None``. """ - batch.to(self.device) - ref = self.model(batch.reference) - pos = self.model(batch.positive) - neg = self.model(batch.negative) + ref, pos, neg = self._compute_features(batch) return cebra.data.Batch(ref, pos, neg) def get_embedding(self, data: torch.Tensor) -> torch.Tensor: @@ -195,7 +204,118 @@ def __post_init__(self): self.reference_model = copy.deepcopy(self.model) self.reference_model.to(self.model.device) - def _inference(self, batch): + def _select_model( + self, + inputs: Union[torch.Tensor, List[torch.Tensor]], + session_id: Optional[int] = None, + use_reference_model: bool = False, + ) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module], + cebra.data.datatypes.Offset]: + """ Select the model based on the input dimension and session ID. + + Args: + inputs: Data to infer using the selected model. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + use_reference_model: Flag for using ``reference_model``. + + Returns: + The model (first returns) and the offset of the model (second returns). + """ + self._check_is_inputs_valid(inputs, session_id=session_id) + self._check_is_session_id_valid(session_id=session_id) + + if use_reference_model: + model = self.reference_model + else: + model = self.model + + if hasattr(model, 'get_offset'): + offset = model.get_offset() + else: + offset = None + return model, offset + + @torch.no_grad() + def transform(self, + inputs: Union[torch.Tensor, List[torch.Tensor], npt.NDArray], + pad_before_transform: bool = True, + session_id: Optional[int] = None, + batch_size: Optional[int] = None, + use_reference_model: bool = False) -> torch.Tensor: + """Compute the embedding. + This function by default use ``model`` that was trained to encode the positive + and negative samples. To use ``reference_model`` instead of ``model`` + ``use_reference_model`` should be equal ``True``. + Args: + inputs: The input signal + use_reference_model: Flag for using ``reference_model`` + Returns: + The output embedding. + """ + if isinstance(inputs, list): + raise NotImplementedError( + "Inputs to transform() should be the data for a single session." + ) + elif not isinstance(inputs, torch.Tensor): + raise ValueError( + f"Inputs should be a torch.Tensor, not {type(inputs)}.") + + if not hasattr(self, "history") and len(self.history) > 0: + raise ValueError( + f"This {type(self).__name__} instance is not fitted yet. 
Call 'fit' with " + "appropriate arguments before using this estimator.") + model, offset = self._select_model( + inputs, session_id, use_reference_model=use_reference_model) + + if len(offset) < 2 and pad_before_transform: + pad_before_transform = False + + model.eval() + if batch_size is not None: + output = abc_._batched_transform( + model=model, + inputs=inputs, + offset=offset, + batch_size=batch_size, + pad_before_transform=pad_before_transform, + ) + else: + output = abc_._transform(model=model, + inputs=inputs, + offset=offset, + pad_before_transform=pad_before_transform) + + return output + + def _compute_features( + self, + batch: cebra.data.Batch, + model: Optional[torch.nn.Module] = None + ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: + batch.to(self.device) + ref = self.reference_model(batch.reference) + pos = self.model(batch.positive) + neg = self.model(batch.negative) + return ref, pos, neg + + def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: + """Given a batch of input examples, computes the feature representation/embedding. + + The reference samples are processed with a different model than the + positive and negative samples. + + Args: + batch: The input data, not necessarily aligned across the batch + dimension. This means that ``batch.index`` specifies the map + between reference/positive samples, if not equal ``None``. + + Returns: + Processed batch of data. While the input data might not be aligned + across the sample dimensions, the output data should be aligned and + ``batch.index`` should be set to ``None``. + """ batch.to(self.device) ref = self.reference_model(batch.reference) pos = self.model(batch.positive) @@ -211,6 +331,21 @@ class SingleSessionHybridSolver(abc_.MultiobjectiveSolver, SingleSessionSolver): _variant_name = "single-session-hybrid" def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: + """Given a batch of input examples, computes the feature representation/embedding. + + The samples are processed with both a time-contrastive module and a + behavior-contrastive module, that are part of the same model. + + Args: + batch: The input data, not necessarily aligned across the batch + dimension. This means that ``batch.index`` specifies the map + between reference/positive samples, if not equal ``None``. + + Returns: + Processed batch of data. While the input data might not be aligned + across the sample dimensions, the output data should be aligned and + ``batch.index`` should be set to ``None``. + """ batch.to(self.device) behavior_ref = self.model(batch.reference)[0] behavior_pos = self.model(batch.positive[:int(len(batch.positive) // @@ -228,7 +363,7 @@ def _select_model( List[torch.Tensor]], session_id: Optional[int] ) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module], cebra.data.datatypes.Offset]: - """ Select the model based on the input dimension and session ID. + """ Select the (trained) model based on the input dimension and session ID. Args: inputs: Data to infer using the selected model. @@ -239,14 +374,12 @@ def _select_model( Returns: The model (first returns) and the offset of the model (second returns). 
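+
+        Note:
+            After fitting, embeddings from either network can be obtained with
+            :py:meth:`transform`: ``solver.transform(x)`` uses the trained
+            model, while ``solver.transform(x, use_reference_model=True)``
+            uses the reference model (``x`` being a single-session tensor;
+            the name is illustrative).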
""" - self._check_is_inputs_valid(inputs, session_id=session_id) self._check_is_session_id_valid(session_id=session_id) + self._check_is_fitted() + self._check_is_inputs_valid(inputs, session_id=session_id) model = self.model.module - if hasattr(model, 'get_offset'): - offset = model.get_offset() - else: - offset = None + offset = model.get_offset() return model, offset @@ -303,6 +436,18 @@ def get_embedding(self, data): return self.model(data[0].T) def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: + """Given a batch of input examples, computes the feature representation/embedding. + + Args: + batch: The input data, not necessarily aligned across the batch + dimension. This means that ``batch.index`` specifies the map + between reference/positive samples, if not equal ``None``. + + Returns: + Processed batch of data. While the input data might not be aligned + across the sample dimensions, the output data should be aligned and + ``batch.index`` should be set to ``None``. + """ outputs = self.get_embedding(self.neural) idc = batch.positive - self.offset.left >= len(outputs) batch.positive[idc] = batch.reference[idc] diff --git a/cebra/solver/unified_session.py b/cebra/solver/unified_session.py new file mode 100644 index 00000000..59945927 --- /dev/null +++ b/cebra/solver/unified_session.py @@ -0,0 +1,437 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +"""Solver implementations for unified-session datasets.""" + +from typing import List, Optional, Tuple, Union + +import literate_dataclasses as dataclasses +import numpy as np +import torch + +import cebra +import cebra.data +import cebra.distributions +import cebra.models +import cebra.solver.base as abc_ +from cebra.solver import register + + +@register("unified-session") +@dataclasses.dataclass +class UnifiedSolver(abc_.Solver): + """Multi session training, considering a single model for all sessions.""" + + _variant_name = "unified-session" + + def parameters(self, session_id: Optional[int] = None): # same as single + """Iterate over all parameters.""" + for parameter in self.model.parameters(): + yield parameter + + for parameter in self.criterion.parameters(): + yield parameter + + def _set_fitted_params(self, loader: cebra.data.Loader): # mix + """Set parameters once the solver is fitted. + + In single session solver, the number of session is set to None and the number of + features is set to the number of neurons in the dataset. + + Args: + loader: Loader used to fit the solver. 
+ """ + self.num_sessions = loader.dataset.num_sessions + self.n_features = loader.dataset.input_dimension + + def _check_is_inputs_valid(self, inputs: Union[torch.Tensor, + List[torch.Tensor]], + session_id: int): + """Check that the inputs can be infered using the selected model. + + Note: This method checks that the number of neurons in the input is + similar to the input dimension to the selected model. + + Args: + inputs: Data to infer using the selected model. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + """ + + if isinstance(inputs, list): + inputs_shape = 0 + for i in range(len(inputs)): + inputs_shape += inputs[i].shape[1] + elif isinstance(inputs, + torch.Tensor): #NOTE(celia): flexible input at training + raise NotImplementedError + else: + raise NotImplementedError + + if self.n_features != inputs_shape: + raise ValueError( + f"Invalid input shape: model requires an input of shape" + f"(n_samples, {self.n_features}), got (n_samples, {inputs.shape[1]})." + ) + + def _check_is_session_id_valid(self, + session_id: Optional[int] = None + ): # same as multi + """Check that the session ID provided is valid for the solver instance. + + The session ID must be non-null and between 0 and the number session in the dataset. + + Args: + session_id: The session ID to check. + """ + + if session_id is None: + raise RuntimeError( + "No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape." + ) + if session_id >= self.num_sessions or session_id < 0: + raise RuntimeError( + f"Invalid session_id {session_id}: session_id for the current multisession model must be between 0 and {self.num_sessions-1}." + ) + + def _select_model( + self, + inputs: Union[torch.Tensor, List[torch.Tensor]], + session_id: Optional[int] = None + ) -> Tuple[Union[List[torch.nn.Module], torch.nn.Module], + cebra.data.datatypes.Offset]: + """ Select the model based on the input dimension and session ID. + + Args: + inputs: Data to infer using the selected model. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1 for multisession, and set to + ``None`` for single session. + + Returns: + The model (first returns) and the offset of the model (second returns). + """ + model = self.model + offset = model.get_offset() + return model, offset + + def _single_model_inference(self, batch: cebra.data.Batch, + model: torch.nn.Module) -> cebra.data.Batch: + """Given a single batch of input examples, computes the feature representation/embedding. + + Args: + batch: The input data, not necessarily aligned across the batch + dimension. This means that ``batch.index`` specifies the map + between reference/positive samples, if not equal ``None``. + model: The model to use for inference. + + Returns: + Processed batch of data. While the input data might not be aligned + across the sample dimensions, the output data should be aligned and + ``batch.index`` should be set to ``None``. 
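+
+        Note:
+            The reference, positive and negative features are flattened to
+            shape ``(num_samples, num_features)`` before being passed on to
+            the criterion.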
+ """ + ref, pos, neg = self._compute_features(batch, model) + ref = ref.unsqueeze(0) + pos = pos.unsqueeze(0) + neg = neg.unsqueeze(0) + + num_features = neg.shape[2] + + return cebra.data.Batch( + reference=ref.view(-1, num_features), + positive=pos.view(-1, num_features), + negative=neg.view(-1, num_features), + ) + + def _inference(self, batch: cebra.data.Batch) -> cebra.data.Batch: + """Given batches of input examples, computes the feature representations/embeddings. + + Args: + batches: A list of input data, not necessarily aligned across the batch + dimension. This means that ``batch.index`` specifies the map + between reference/positive samples, if not equal ``None``. + + Returns: + Processed batch of data. While the input data might not be aligned + across the sample dimensions, the output data should be aligned and + ``batch.index`` should be set to ``None``. + """ + return self._single_model_inference(batch, self.model) + + @torch.no_grad() + def transform(self, + inputs: List[torch.Tensor], + labels: List[torch.Tensor], + pad_before_transform: bool = True, + session_id: Optional[int] = None, + batch_size: Optional[int] = 512) -> torch.Tensor: + """Compute the embedding for the `session_id`th session of the dataset. + + Note: + Compared to the other :py:class:`cebra.solver.base.Solver`, we need all the sessions of + the dataset to transform the data, as the sampling is across all the sessions. + + Args: + inputs: The input signal for all sessions. + labels: The auxiliary variables to use for sampling. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions -1. + batch_size: If not None, batched inference will be applied. + + Note: + The ``session_id`` is needed in order to sample the corresponding number of samples and + return an embedding of the expected shape. + + Note: + The batched inference will be required in most cases. Default is set to ``100`` for that reason. + + Returns: + The output embedding for the session corresponding to the provided ID `session_id`. The shape + is (num_samples(session_id), output_dimension)``. + + """ + if not isinstance(inputs, list): + raise ValueError( + f"Inputs to transform() should be a list, not {type(inputs)}.") + + self._check_is_fitted() + + if session_id is None: + raise ValueError("Session ID is required for multi-session models.") + + # Sampling according to session_id required + dataset = cebra.data.UnifiedDataset( + cebra.data.TensorDataset( + inputs[i], continuous=labels[i], offset=cebra.data.Offset(0, 1)) + for i in range(len(inputs))).to(self.device) + + # Only used to sample the reference samples + loader = cebra.data.UnifiedLoader(dataset, num_steps=1) + + # Sampling in batch + refs_data_batch_embeddings = [] + batch_range = range(0, len(dataset.get_session(session_id)), batch_size) + if len(batch_range) < 2: + raise ValueError( + "Not enough data to perform the batched transform. Please provide a larger dataset or reduce the batch size." 
+ ) + for batch_start in batch_range: + batch_end = min(batch_start + batch_size, + len(dataset.get_session(session_id))) + + if batch_start == batch_range[-2]: # one before last batch + last_start = batch_start + continue + if batch_start == batch_range[-1]: # last batch, likely uncomplete + batch_start = last_start + batch_end = len(dataset.get_session(session_id)) + + refs_idx_batch = loader.sampler.sample_all_sessions( + ref_idx=torch.arange(batch_start, batch_end), + session_id=session_id).to(self.device) + + refs_data_batch = [ + session[refs_idx_batch[session_id]] + for session_id, session in enumerate(dataset.iter_sessions()) + ] + refs_data_batch_embeddings.append(super().transform( + torch.cat(refs_data_batch, dim=1).squeeze(), + pad_before_transform=pad_before_transform)) + return torch.cat(refs_data_batch_embeddings, dim=0) + + @torch.no_grad() + def single_session_transform( + self, + inputs: Union[torch.Tensor, List[torch.Tensor]], + session_id: Optional[int] = None, + pad_before_transform: bool = True, + padding_mode: str = "zero", + batch_size: Optional[int] = 100) -> torch.Tensor: + """Compute the embedding for the `session_id`th session of the dataset without labels alignement. + + By padding the channels that don't correspond to the {session_id}th session, we can + use a single session solver without behavioral alignment. + + Note: The embedding will not benefit from the behavioral alignment, and consequently + from the information contained in the other sessions. We expect single session encoder + behavioral decoding performances. + + Args: + inputs: The input signal for all sessions. + session_id: The session ID, an :py:class:`int` between 0 and + the number of sessions. + pad_before_transform: If True, pads the input before applying the transform. + padding_mode: The mode to use for padding. Padding is done in the following + ways, either by padding all the other sessions to the length of the + {session_id}th session, or by resampling all sessions in a random way: + - `time`: pads the inputs that are not infered to the maximum length of + the session and then zeros so that the lenght is the same as the + {session_id}th session length. + - `zero`: pads the inputs that are not infered with zeros so that the + lenght is the same as the {session_id}th session length. + - `poisson`: pads the inputs that are not infered with a poisson distribution + so that the lenght is the same as the {session_id}th session length. + - `random`: pads all sessions with random values sampled from a normal + distribution. + - `random_poisson`: pads all sessions with random values sampled from a + poisson distribution. + + batch_size: If not None, batched inference will be applied. + + Returns: + The output embedding for the session corresponding to the provided ID `session_id`. The shape + is (num_samples(session_id), output_dimension)``. 
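+
+        Example:
+            A rough sketch of a call for the first session of a two-session
+            dataset (tensor names are illustrative)::
+
+                emb = solver.single_session_transform(
+                    inputs=[x_session0, x_session1],
+                    session_id=0,
+                    padding_mode="zero",
+                    batch_size=100,
+                )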
+ """ + inputs = [session.to(self.device) for session in inputs] + + zero_shape = inputs[session_id].shape[0] + + if padding_mode == "time" or padding_mode == "zero" or padding_mode == "poisson": + for i in range(len(inputs)): + if i != session_id: + if padding_mode == "time": + if inputs[i].shape[0] >= zero_shape: + inputs[i] = inputs[i][:zero_shape] + else: + inputs[i] = torch.cat( + (inputs[i], + torch.zeros( + (zero_shape - inputs[i].shape[0], + inputs[i].shape[1])).to(self.device))) + if padding_mode == "poisson": + inputs[i] = torch.poisson( + torch.ones((zero_shape, inputs[i].shape[1]))) + if padding_mode == "zero": + inputs[i] = torch.zeros( + (zero_shape, inputs[i].shape[1])) + padded_inputs = torch.cat( + [session.to(self.device) for session in inputs], dim=1) + + elif padding_mode == "random_poisson": + padded_inputs = torch.poisson( + torch.ones((zero_shape, self.n_features))) + elif padding_mode == "random": + padded_inputs = torch.normal( + torch.zeros((zero_shape, self.n_features)), + torch.ones((zero_shape, self.n_features))) + + else: + raise ValueError( + f"Invalid padding mode: {padding_mode}. " + "Choose from 'time', 'zero', 'poisson', 'random', or 'random_poisson'." + ) + + # Single session solver transform call + return super().transform(inputs=padded_inputs, + pad_before_transform=pad_before_transform, + batch_size=batch_size) + + @torch.no_grad() + def decoding(self, + train_loader: cebra.data.Loader, + valid_loader: Optional[cebra.data.Loader] = None, + decode: str = "ridge", + max_sessions: int = 5, + max_timesteps: int = 512) -> float: + # Sample a fixed number of sessions to compute the decoding score + # Sample a fixed number of timesteps to compute the decoding score (always the first ones) + if train_loader.dataset.num_sessions > max_sessions: + sessions = np.random.choice(np.arange( + train_loader.dataset.num_sessions), + size=max_sessions, + replace=False) + else: + sessions = np.arange(train_loader.dataset.num_sessions) + + train_scores, valid_scores = [], [] + for i in sessions: + if train_loader.dataset.get_session( + i).neural.shape[0] > max_timesteps: + train_end = max_timesteps + else: + train_end = -1 + train_x = self.transform([ + train_loader.dataset.get_session(j).neural[:train_end] + for j in range(train_loader.dataset.num_sessions) + ], [ + train_loader.dataset.get_session(j).continuous_index[:train_end] + if train_loader.dataset.get_session(j).continuous_index + is not None else + train_loader.dataset.get_session(j).discrete_index[:train_end] + for j in range(train_loader.dataset.num_sessions) + ], + session_id=i, + batch_size=128) + train_y = train_loader.dataset.get_session( + i + ).continuous_index[:train_end] if train_loader.dataset.get_session( + i + ).continuous_index is not None else train_loader.dataset.get_session( + i).discrete_index[:train_end] + + if valid_loader is not None: + if valid_loader.dataset.get_session( + i).neural.shape[0] > max_timesteps: + valid_end = max_timesteps + else: + valid_end = -1 + valid_x = self.transform([ + valid_loader.dataset.get_session(j).neural[:valid_end] + for j in range(valid_loader.dataset.num_sessions) + ], [ + valid_loader.dataset.get_session( + j).continuous_index[:valid_end] + if valid_loader.dataset.get_session(j).continuous_index + is not None else valid_loader.dataset.get_session( + j).discrete_index[:valid_end] + for j in range(valid_loader.dataset.num_sessions) + ], + session_id=i, + batch_size=128) + valid_y = valid_loader.dataset.get_session( + i + ).continuous_index[:valid_end] if 
valid_loader.dataset.get_session( + i + ).continuous_index is not None else valid_loader.dataset.get_session( + i).discrete_index[:valid_end] + + if decode == "knn": + decoder = cebra.KNNDecoder() + elif decode == "ridge": + decoder = cebra.RidgeRegressor() + else: + raise NotImplementedError(f"Decoder {decode} not implemented.") + + decoder.fit(train_x.cpu().numpy(), train_y.cpu().numpy()) + train_scores.append( + decoder.score(train_x.cpu().numpy(), + train_y.cpu().numpy())) + + if valid_loader is not None: + valid_scores.append( + decoder.score(valid_x.cpu().numpy(), + valid_y.cpu().numpy())) + + if valid_loader is None: + return np.array(train_scores).mean() + else: + return np.array(train_scores).mean(), np.array(valid_scores).mean() diff --git a/conda/cebra_paper.yml b/conda/cebra_paper.yml index e7537756..4b9e2b6e 100644 --- a/conda/cebra_paper.yml +++ b/conda/cebra_paper.yml @@ -39,7 +39,7 @@ dependencies: - "cebra[dev,integrations,datasets,demos]" - joblib - literate-dataclasses - - sklearn + - scikit-learn - scipy - torch - keras==2.3.1 diff --git a/conda/cebra_paper_m1.yml b/conda/cebra_paper_m1.yml index 32256758..3d8cd7b9 100644 --- a/conda/cebra_paper_m1.yml +++ b/conda/cebra_paper_m1.yml @@ -48,7 +48,7 @@ dependencies: - tensorflow-metal - joblib - literate-dataclasses - - sklearn + - scikit-learn - scipy - torch - umap-learn diff --git a/docs/.gitignore b/docs/.gitignore index a48ebfca..f7176a04 100644 --- a/docs/.gitignore +++ b/docs/.gitignore @@ -1,2 +1,3 @@ build/ page/ +root/static diff --git a/docs/Dockerfile b/docs/Dockerfile new file mode 100644 index 00000000..d96c24d2 --- /dev/null +++ b/docs/Dockerfile @@ -0,0 +1,14 @@ +FROM python:3.10 + +RUN apt-get update && apt-get install -y \ + git \ + make \ + pandoc \ + && rm -rf /var/lib/apt/lists/* + +COPY docs/requirements.txt . +RUN pip install -r requirements.txt + +#COPY setup.cfg . +#COPY pyproject.toml . +#COPY cebra/ . diff --git a/docs/Makefile b/docs/Makefile index 741d165e..9252ed72 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -3,7 +3,7 @@ # You can set these variables from the command line, and also # from the environment for the first two. -SPHINXOPTS ?= +SPHINXOPTS ?= -W --keep-going -n SPHINXBUILD ?= sphinx-build SOURCEDIR = source BUILDDIR = build @@ -18,6 +18,11 @@ help: html: PYTHONPATH=.. $(SPHINXBUILD) -M html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) +# Build multiple versions +html_versions: + for v in latest v0.2.0 v0.3.0 v0.4.0; do \ + PYTHONPATH=.. $(SPHINXBUILD) -b html "$(SOURCEDIR)" "$(BUILDDIR)/$$v"; \ + done # Remove the current temp folder and page build clean: rm -rf build @@ -26,14 +31,31 @@ clean: # Checkout the source repository for CEBRA figures. Note that this requires SSH access # and might prompt you for an SSH key. source/cebra-figures: - git clone --depth 1 git@github.com:AdaptiveMotorControlLab/cebra-figures.git source/cebra-figures + cd $(dir $(realpath $(firstword $(MAKEFILE_LIST)))) && git clone --depth 1 git@github.com:AdaptiveMotorControlLab/cebra-figures.git source/cebra-figures + +source/demo_notebooks: + cd $(dir $(realpath $(firstword $(MAKEFILE_LIST)))) && git clone --depth 1 git@github.com:AdaptiveMotorControlLab/cebra-demos.git source/demo_notebooks + +source/demo_notebooks: + git clone --depth 1 git@github.com:AdaptiveMotorControlLab/cebra-demos.git source/demo_notebooks # Update the figures. 
Note that this might prompt you for an SSH key figures: source/cebra-figures cd source/cebra-figures && git pull --ff-only origin main +demos: source/demo_notebooks + cd source/demo_notebooks && git pull --ff-only origin main + +source/assets: + cd $(dir $(realpath $(firstword $(MAKEFILE_LIST)))) && git clone --depth 1 git@github.com:AdaptiveMotorControlLab/cebra-assets.git source/assets + +assets: source/assets + cd source/assets && git pull --ff-only origin main + cp -r source/assets/docs/* . + #rm -rf source/assets + # Build the page with pre-built figures -page: source/cebra-figures html +page: source/cebra-figures source/demo_notebooks html mkdir -p page/ mkdir -p page/docs mkdir -p page/staging/docs diff --git a/docs/README.md b/docs/README.md new file mode 100644 index 00000000..495e156c --- /dev/null +++ b/docs/README.md @@ -0,0 +1,14 @@ +# CEBRA documentation + +This directory contains the documentation for CEBRA. + +To build the docs, head to *the root folder of the repository* and run: + +```bash +./tools/build_docs.sh +``` + +This will build the docker container in [Dockerfile](Dockerfile) and run the `make docs` command from the root repo. +The exact requirements for building the docs are now listed in [requirements.txt](requirements.txt). + +For easier local development, docs are not using `sphinx-autobuild` and will by default be served at [http://127.0.0.1:8000](http://127.0.0.1:8000). diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 00000000..880611c8 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,23 @@ +sphinx==7.4.7 +nbsphinx==0.9.6 +pydata-sphinx-theme==0.16.1 +pytest-sphinx==0.6.3 +sphinx-autobuild==2024.10.3 +sphinx-autodoc-typehints==1.19.0 +sphinx-copybutton==0.5.2 +sphinx-gallery==0.19.0 +sphinx-tabs==3.4.7 +sphinx-togglebutton==0.3.2 +sphinx_design==0.6.0 +sphinx_pydata_theme==0.1.0 +sphinxcontrib-applehelp==2.0.0 +sphinxcontrib-devhelp==2.0.0 +sphinxcontrib-htmlhelp==2.1.0 +sphinxcontrib-jsmath==1.0.1 +sphinxcontrib-qthelp==2.0.0 +sphinxcontrib-serializinghtml==2.0.0 + +literate_dataclasses +# For IPython.sphinxext.ipython_console_highlighting extension +ipython +numpy diff --git a/docs/root/index.html b/docs/root/index.html index 86015297..aa740039 100644 --- a/docs/root/index.html +++ b/docs/root/index.html @@ -7,21 +7,21 @@ - Learnable latent embeddings for joint behavioural and neural analysis - + CEBRA + - + - + @@ -36,7 +36,6 @@ CEBRA @@ -93,58 +116,26 @@
-

Learnable latent embeddings for joint behavioural and neural analysis

-
- - -
-
-
- Steffen Schneider*
- EPFL & IMPRS-IS - - -
-
- Jin Hwa Lee*
- EPFL - -
-
- Mackenzie Mathis
- EPFL - - -
+

CEBRA: a self-supervised learning algorithm for obtaining interpretable, Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables

-
-
-
+ -
+ - - -
- CEBRA is a machine-learning - method that can be used to - compress time series in a way - that reveals otherwise hidden - structures in the variability of - the data. It excels on behavioural - and neural data recorded - simultaneously, and it can - decode activity from the visual - cortex of the mouse brain to - reconstruct a viewed video. +
- +

Demo Applications

-

Application of CEBRA-Behavior to rat hippocampus data (Grosmark and Buzsáki, 2016), showing position/neural activity (left), overlayed with decoding obtained by CEBRA. The current point in embedding space is highlighted (right). CEBRA obtains a median absolute error of 5cm (total track length: 160cm; see pre-print for details). Video is played at 2x real-time speed.

+

Application of CEBRA-Behavior to rat hippocampus data (Grosmark and Buzsáki, 2016), showing position/neural activity (left), overlayed with decoding obtained by CEBRA. The current point in embedding space is highlighted (right). CEBRA obtains a median absolute error of 5cm (total track length: 160cm; see Schneider et al. 2023 for details). Video is played at 2x real-time speed.

+
+ +
+ +
+ +
+ +

Interactive visualization of the CEBRA embedding for the rat hippocampus data. This 3D plot shows how neural activity is mapped to a lower-dimensional space that correlates with the animal's position and movement direction. Open In Colaboratory

+
+
+ +
+
+ + +

CEBRA applied to mouse primary visual cortex, collected at the Allen Institute (de Vries et al. 2020, Siegle et al. 2021). 2-photon and Neuropixels recordings are embedded with CEBRA using DINO frame features as labels. + The embedding is used to decode the video frames using a kNN decoder on the CEBRA-Behavior embedding from the test set.

- + + -

CEBRA applied to mouse primary visual cortex, collected at the Allen Institute (de Vries et al. 2020, Siegle et al. 2021). 2-photon and Neuropixels recordings are embedded with CEBRA using DINO frame features as labels. - The embedding is used to decode the video frames using a kNN decoder on the CEBRA-Behavior embedding from the test set.

+

CEBRA applied to M1 and S1 neural data, demonstrating how neural activity from primary motor and somatosensory cortices can be effectively embedded and analyzed. See DeWolf et al. 2024 for details.

+
+
+
+

Publications

+ +
+
+
Learnable latent embeddings for joint behavioural and neural analysis
+

Steffen Schneider*, Jin Hwa Lee*, Mackenzie Weygandt Mathis. Nature 2023

+

A comprehensive introduction to CEBRA, demonstrating its capabilities in joint behavioral and neural analysis across various datasets and species.

+ Read Paper + Preprint +
+
+ +
+
+
Time-series attribution maps with regularized contrastive learning
+

Steffen Schneider, Rodrigo González Laiz, Anastasiia Filippova, Markus Frey, Mackenzie Weygandt Mathis. AISTATS 2025

+

An extension of CEBRA that provides attribution maps for time-series data using regularized contrastive learning.

+ Read Paper + Preprint + NeurIPS-W 2023 Version +
+
+ +
+

Patent Information

+
+
+
Patent Pending
+

Please note that EPFL has filed a patent titled "Dimensionality reduction of time-series data, and systems and devices that use the resultant embeddings". If this does not work for your non-academic use case, please contact the Tech Transfer Office at EPFL.

+
+

- Abstract + Overview

@@ -209,31 +257,6 @@

-
-

- - Pre-Print -

-
- -
-

- The pre-print is available on arxiv at arxiv.org/abs/2204.00673. -

- -
-

@@ -244,8 +267,7 @@

You can find our official implementation of the CEBRA algorithm on GitHub: Watch and Star the repository to be notified of future updates and releases. - You can also follow us on Twitter or subscribe to our - mailing list for updates on the project. + You can also follow us on Twitter for updates on the project.

If you are interested in collaborations, please contact us via @@ -258,13 +280,13 @@

BibTeX

-

Please cite our paper as follows:

+

Please cite our papers as follows:

@article{schneider2023cebra,
-   author={Schneider, Steffen and Lee, Jin Hwa and Mathis, Mackenzie Weygandt},
+   author={Steffen Schneider and Jin Hwa Lee and Mackenzie Weygandt Mathis},
  title={Learnable latent embeddings for joint behavioural and neural analysis},
  journal={Nature},
  year={2023},
@@ -278,6 +300,58 @@

+
+
+ + @inproceedings{schneider2025timeseries,
+   title={Time-series attribution maps with regularized contrastive learning},
+   author={Steffen Schneider and Rodrigo Gonz{\'a}lez Laiz and Anastasiia Filippova and Markus Frey and Mackenzie Weygandt Mathis},
+   booktitle={The 28th International Conference on Artificial Intelligence and Statistics},
+   year={2025},
+   url={https://openreview.net/forum?id=aGrCXoTB4P}
+ } +
+
+
+ +
+

+ + Impact & Citations +

+
+ +
+

+ CEBRA has been cited in numerous high-impact publications across neuroscience, machine learning, and related fields. Our work has influenced research in neural decoding, brain-computer interfaces, computational neuroscience, and machine learning methods for time-series analysis. +

+ + + +
+
+

Our research has been cited in proceedings and journals including Nature, Science, ICML, Nature Neuroscience, Neuron, NeurIPS, ICLR, and others.

+
+
+
+ +
+
+ + MLAI Logo + +
+ © 2021 - present | EPFL Mathis Laboratory +
+
+
+
Webpage designed using Bootstrap 5 and Fontawesome 5. diff --git a/docs/root/robots.txt b/docs/root/robots.txt index 43249ef2..bbbcdfe9 100644 --- a/docs/root/robots.txt +++ b/docs/root/robots.txt @@ -1,3 +1,2 @@ User-agent: * Disallow: /staging/ -Disallow: /docs/ diff --git a/docs/source/_static/css/custom.js b/docs/source/_static/css/custom.js new file mode 100644 index 00000000..f9afa170 --- /dev/null +++ b/docs/source/_static/css/custom.js @@ -0,0 +1,6 @@ +requirejs.config({ + paths: { + base: '/static/base', + plotly: 'https://cdn.plot.ly/plotly-2.12.1.min.js?noext', + }, +}); diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index 2994db97..0140a5cf 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -1,11 +1,15 @@ {% extends "pydata_sphinx_theme/layout.html" %} -{% block fonts %} +{% block extrahead %} + + + +{% endblock %} +{% block fonts %} - {% endblock %} {% block docs_sidebar %} diff --git a/docs/source/api.rst b/docs/source/api.rst index 8989337f..846602f1 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -38,6 +38,9 @@ these components in other contexts and research code bases. api/pytorch/distributions api/pytorch/models api/pytorch/helpers + api/pytorch/multiobjective + api/pytorch/regularized + api/pytorch/attribution .. toctree:: :hidden: diff --git a/docs/source/api/pytorch/attribution.rst b/docs/source/api/pytorch/attribution.rst new file mode 100644 index 00000000..6efb043f --- /dev/null +++ b/docs/source/api/pytorch/attribution.rst @@ -0,0 +1,21 @@ +=================== +Attribution Methods +=================== + +.. automodule:: cebra.attribution + :members: + :show-inheritance: + +Different attribution methods +----------------------------- + +.. automodule:: cebra.attribution.attribution_models + :members: + :show-inheritance: + +Jacobian-based attribution +-------------------------- + +.. automodule:: cebra.attribution.jacobian_attribution + :members: + :show-inheritance: diff --git a/docs/source/api/pytorch/models.rst b/docs/source/api/pytorch/models.rst index ee3455bc..3fe2219b 100644 --- a/docs/source/api/pytorch/models.rst +++ b/docs/source/api/pytorch/models.rst @@ -43,12 +43,8 @@ Layers and model building blocks :show-inheritance: Multi-objective models -~~~~~~~~~~~~~~~~~~~~~~~~ +~~~~~~~~~~~~~~~~~~~~~~ -.. automodule:: cebra.models.multiobjective - :members: - :private-members: - :show-inheritance: - -.. - - projector +The multi-objective interface was moved to a separate section beginning with CEBRA 0.6.0. +Please see the :doc:`Multi-objective models ` section +for all details, both on the old and new API interface. diff --git a/docs/source/api/pytorch/multiobjective.rst b/docs/source/api/pytorch/multiobjective.rst new file mode 100644 index 00000000..c959cfa1 --- /dev/null +++ b/docs/source/api/pytorch/multiobjective.rst @@ -0,0 +1,15 @@ +====================== +Multi-objective models +====================== + +.. automodule:: cebra.solver.multiobjective + :members: + :show-inheritance: + +.. automodule:: cebra.models.multicriterions + :members: + :show-inheritance: + +.. 
automodule:: cebra.models.multiobjective + :members: + :show-inheritance: diff --git a/docs/source/api/pytorch/regularized.rst b/docs/source/api/pytorch/regularized.rst new file mode 100644 index 00000000..7da94603 --- /dev/null +++ b/docs/source/api/pytorch/regularized.rst @@ -0,0 +1,24 @@ +================================ +Regularized Contrastive Learning +================================ + +Regularized solvers +-------------------- + +.. automodule:: cebra.solver.regularized + :members: + :show-inheritance: + +Schedulers +---------- + +.. automodule:: cebra.solver.schedulers + :members: + :show-inheritance: + +Jacobian Regularization +----------------------- + +.. automodule:: cebra.models.jacobian_regularizer + :members: + :show-inheritance: diff --git a/docs/source/conf.py b/docs/source/conf.py index 025a988b..4147e7c9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -26,21 +26,13 @@ # list see the documentation: # https://www.sphinx-doc.org/en/master/usage/configuration.html -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# +import datetime import os +import pathlib import sys sys.path.insert(0, os.path.abspath(".")) -import datetime - -import cebra - def get_years(start_year=2021): year = datetime.datetime.now().year @@ -52,16 +44,31 @@ def get_years(start_year=2021): # -- Project information ----------------------------------------------------- project = "cebra" -copyright = f"""{get_years(2021)}, Steffen Schneider, Jin H Lee, Mackenzie Mathis""" -author = "Steffen Schneider, Jin H Lee, Mackenzie Mathis" -# The full version, including alpha/beta/rc tags -release = cebra.__version__ +copyright = f"""{get_years(2021)}""" +author = "See AUTHORS.md" +version_file = pathlib.Path( + __file__).parent.parent.parent / "cebra" / "__init__.py" +assert version_file.exists(), f"Could not find version file: {version_file}" +with version_file.open("r") as f: + for line in f: + if line.startswith("__version__"): + version = line.split("=")[1].strip().strip('"').strip("'") + print("Building docs for version:", version) + break + else: + raise ValueError("Could not find version in __init__.py") # -- General configuration --------------------------------------------------- # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. + +#https://github.com/spatialaudio/nbsphinx/issues/128#issuecomment-1158712159 +html_js_files = [ + "https://cdn.plot.ly/plotly-latest.min.js", # Add Plotly.js +] + extensions = [ "sphinx.ext.autodoc", "sphinx.ext.napoleon", @@ -73,7 +80,6 @@ def get_years(start_year=2021): "sphinx_tabs.tabs", "sphinx.ext.mathjax", "IPython.sphinxext.ipython_console_highlighting", - # "sphinx_panels", # Note: package to avoid: no longer maintained. 
"sphinx_design", "sphinx_togglebutton", "sphinx.ext.doctest", @@ -121,7 +127,8 @@ def get_years(start_year=2021): autodoc_member_order = "bysource" autodoc_mock_imports = [ - "torch", "nlb_tools", "tqdm", "h5py", "pandas", "matplotlib", "plotly" + "torch", "nlb_tools", "tqdm", "h5py", "pandas", "matplotlib", "plotly", + "cvxpy", "captum", "joblib", "scikit-learn", "scipy", "requests", "sklearn" ] # autodoc_typehints = "none" @@ -132,8 +139,18 @@ def get_years(start_year=2021): # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path. exclude_patterns = [ - "**/todo", "**/src", "cebra-figures/figures.rst", "cebra-figures/*.rst", - "*/cebra-figures/*.rst", "demo_notebooks/README.rst" + "**/todo", + "**/src", + "cebra-figures/figures.rst", + "cebra-figures/*.rst", + "*/cebra-figures/*.rst", + "*/demo_notebooks/README.rst", + "demo_notebooks/README.rst", + # TODO(stes): Remove this from the assets repo, then remove here + "_static/figures_usage.ipynb", + "*/_static/figures_usage.ipynb", + "assets/**/*.ipynb", + "*/assets/**/*.ipynb" ] # -- Options for HTML output ------------------------------------------------- @@ -142,6 +159,23 @@ def get_years(start_year=2021): # a list of builtin themes. html_theme = "pydata_sphinx_theme" +html_context = { + "default_mode": "light", + "switcher": { + "version_match": + "latest", # Adjust this dynamically per version + "versions": [ + ("latest", "/latest/"), + ("v0.2.0", "/v0.2.0/"), + ("v0.3.0", "/v0.3.0/"), + ("v0.4.0", "/v0.4.0/"), + ("v0.5.0rc1", "/v0.5.0rc1/"), + ], + }, + "navbar_start": ["version-switcher", + "navbar-logo"], # Place the dropdown above the logo +} + # More info on theme options: # https://pydata-sphinx-theme.readthedocs.io/en/latest/user_guide/configuring.html html_theme_options = { @@ -156,11 +190,6 @@ def get_years(start_year=2021): "url": "https://twitter.com/cebraAI", "icon": "fab fa-twitter", }, - # { - # "name": "DockerHub", - # "url": "https://hub.docker.com/r/stffsc/cebra", - # "icon": "fab fa-docker", - # }, { "name": "PyPI", "url": "https://pypi.org/project/cebra/", @@ -172,23 +201,26 @@ def get_years(start_year=2021): "icon": "fas fa-graduation-cap", }, ], - "external_links": [ - # {"name": "Mathis Lab", "url": "http://www.mackenziemathislab.org/"}, - ], "collapse_navigation": False, - "navigation_depth": 4, - "show_nav_level": 2, + "navigation_depth": 1, + "show_nav_level": 1, "navbar_align": "content", "show_prev_next": False, + "navbar_end": ["theme-switcher", "navbar-icon-links.html"], + "navbar_persistent": [], + "header_links_before_dropdown": 7 } -html_context = {"default_mode": "dark"} +html_context = {"default_mode": "light"} html_favicon = "_static/img/logo_small.png" html_logo = "_static/img/logo_large.png" -# Remove the search field for now +# Replace with this configuration to enable "on this page" navigation html_sidebars = { - "**": ["search-field.html", "sidebar-nav-bs.html"], + "**": ["search-field.html", "sidebar-nav-bs", "page-toc.html"], + "demos": ["search-field.html", "sidebar-nav-bs"], + "api": ["search-field.html", "sidebar-nav-bs"], + "figures": ["search-field.html", "sidebar-nav-bs"], } # Disable links for embedded images @@ -207,6 +239,8 @@ def get_years(start_year=2021): ] nbsphinx_thumbnails = { + "demo_notebooks/CEBRA_best_practices": + "_static/thumbnails/cebra-best.png", "demo_notebooks/Demo_primate_reaching": "_static/thumbnails/ForelimbS1.png", "demo_notebooks/Demo_hippocampus": @@ -235,6 +269,8 @@ def 
get_years(start_year=2021): "_static/thumbnails/openScope_demo.png", "demo_notebooks/Demo_dandi_NeuroDataReHack_2023": "_static/thumbnails/dandi_demo_monkey.png", + "demo_notebooks/Demo_xCEBRA_RatInABox": + "_static/thumbnails/xCEBRA.png" } rst_prolog = r""" @@ -247,6 +283,9 @@ def get_years(start_year=2021): # Download link for the notebook, see # https://nbsphinx.readthedocs.io/en/0.3.0/prolog-and-epilog.html + +# fmt: off +# flake8: noqa: E501 nbsphinx_prolog = r""" .. only:: html @@ -269,3 +308,14 @@ def get_years(start_year=2021): ---- """ +# fmt: on +# flake8: enable=E501 + +# Configure nbsphinx to properly render Plotly plots +nbsphinx_execute = 'auto' +nbsphinx_allow_errors = True +nbsphinx_requirejs_path = 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.7/require.js' +nbsphinx_execute_arguments = [ + "--InlineBackend.figure_formats={'png', 'svg', 'pdf'}", + "--InlineBackend.rc=figure.dpi=96", +] diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index cc7ae0a8..7fcd16a1 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -155,13 +155,13 @@ Enter the build environment and build the package: host $ make interact docker $ make build # ... outputs ... - Successfully built cebra-X.X.XaX-py2.py3-none-any.whl + Successfully built cebra-X.X.XaX-py3-none-any.whl The built package can be found in ``dist/`` and can be installed locally with .. code:: bash - pip install dist/cebra-X.X.XaX-py2.py3-none-any.whl + pip install dist/cebra-X.X.XaX-py3-none-any.whl **Please do not distribute this package prior to the public release of the CEBRA repository, because it also contains parts of the source code.** diff --git a/docs/source/demos.rst b/docs/source/demos.rst deleted file mode 100644 index f0822386..00000000 --- a/docs/source/demos.rst +++ /dev/null @@ -1 +0,0 @@ -.. include:: demo_notebooks/README.rst diff --git a/docs/source/demos.rst b/docs/source/demos.rst new file mode 120000 index 00000000..edd57b74 --- /dev/null +++ b/docs/source/demos.rst @@ -0,0 +1 @@ +demo_notebooks/README.rst \ No newline at end of file diff --git a/docs/source/figures.rst b/docs/source/figures.rst index 24b1987e..a4101f4a 100644 --- a/docs/source/figures.rst +++ b/docs/source/figures.rst @@ -1,7 +1,7 @@ Figures ======= -CEBRA was introduced in `Schneider, Lee and Mathis (2022)`_ and applied to various datasets across +CEBRA was introduced in `Schneider, Lee and Mathis (2023)`_ and applied to various datasets across animals and recording modalities. In this section, we provide reference code for reproducing the figures and experiments. Since especially @@ -56,4 +56,4 @@ differ in minor typographic details. -.. _Schneider, Lee and Mathis (2022): https://arxiv.org/abs/2204.00673 +.. _Schneider, Lee and Mathis (2023): https://www.nature.com/articles/s41586-023-06031-6 diff --git a/docs/source/index.rst b/docs/source/index.rst index c8231746..1a6ce4d2 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -34,27 +34,18 @@ Please support the development of CEBRA by starring and/or watching the project Installation and Setup ---------------------- -Please see the dedicated :doc:`Installation Guide ` for information on installation options using ``conda``, ``pip`` and ``docker``. - -Have fun! 😁 +Please see the dedicated :doc:`Installation Guide ` for information on installation options using ``conda``, ``pip`` and ``docker``. Have fun! 😁 Usage ----- Please head over to the :doc:`Usage ` tab to find step-by-step instructions to use CEBRA on your data. 
For example use cases, see the :doc:`Demos ` tab. -Integrations ------------- - -CEBRA can be directly integrated with existing libraries commonly used in data analysis. The ``cebra.integrations`` module -is getting actively extended. Right now, we offer integrations for ``scikit-learn``-like usage of CEBRA, a package making use of ``matplotlib`` to plot the CEBRA model results, as well as the -possibility to compute CEBRA embeddings on DeepLabCut_ outputs directly. - Licensing --------- - -Since version 0.4.0, CEBRA is open source software under an Apache 2.0 license. +The ideas presented in our package are currently patent pending (Patent No. WO2023143843). +Since version 0.4.0, CEBRA's source code is licensed under an Apache 2.0 license. Prior versions 0.1.0 to 0.3.1 were released for academic use only. Please see the full license file on Github_ for further information. @@ -65,13 +56,19 @@ Contributing Please refer to the :doc:`Contributing ` tab to find our guidelines on contributions. -Code contributors ----------------- -The CEBRA code was originally developed by Steffen Schneider, Jin H. Lee, and Mackenzie Mathis (up to internal version 0.0.2). As of March 2023, it is being actively extended and maintained by `Steffen Schneider`_, `Célia Benquet`_, and `Mackenzie Mathis`_. +Code Contributors ----------------- +The CEBRA code was originally developed by Steffen Schneider, Jin H. Lee, and Mackenzie Mathis (up to internal version 0.0.2). Please see our AUTHORS file for more information. -References ---------- +Integrations +------------ + +CEBRA can be directly integrated with existing libraries commonly used in data analysis. Namely, we provide a ``scikit-learn`` style interface to use CEBRA. Additionally, we offer plotting utilities based on ``matplotlib`` and ``plotly`` to visualize CEBRA model results, as well as the possibility to compute CEBRA embeddings on DeepLabCut_ outputs directly. If you have another suggestion, please head over to Discussions_ on GitHub_! + + +Key References +-------------- .. code:: @article{schneider2023cebra, @@ -82,14 +79,22 @@ References year = {2023}, } + @article{xCEBRA2025, + author={Steffen Schneider and Rodrigo Gonz{\'a}lez Laiz and Anastasiia Filippova and Markus Frey and Mackenzie W Mathis}, + title = {Time-series attribution maps with regularized contrastive learning}, + journal = {AISTATS}, + url = {https://openreview.net/forum?id=aGrCXoTB4P}, + year = {2025}, + } + This documentation is based on the `PyData Theme`_. .. _`Twitter`: https://twitter.com/cebraAI .. _`PyData Theme`: https://github.com/pydata/pydata-sphinx-theme .. _`DeepLabCut`: https://deeplabcut.org +.. _`Discussions`: https://github.com/AdaptiveMotorControlLab/CEBRA/discussions .. _`Github`: https://github.com/AdaptiveMotorControlLab/cebra .. _`email`: mailto:mackenzie.mathis@epfl.ch .. _`Steffen Schneider`: https://github.com/stes -.. _`Célia Benquet`: https://github.com/CeliaBenquet .. _`Mackenzie Mathis`: https://github.com/MMathisLab diff --git a/docs/source/installation.rst b/docs/source/installation.rst index a9650452..1630cfe8 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -4,7 +4,7 @@ Installation Guide System Requirements ------------------- -CEBRA is written in Python (3.8+) and PyTorch. CEBRA is most effective when used with a GPU, but CPU-only support is provided. We provide instructions to run CEBRA on your system directly.
The instructions below were tested on different compute setups with Ubuntu 18.04 or 20.04, using Nvidia GTX 2080, A4000, and V100 cards. Other setups are possible (including Windows), as long as CUDA 10.2+ support is guaranteed. +CEBRA is written in Python (3.9+) and PyTorch. CEBRA is most effective when used with a GPU, but CPU-only support is provided. We provide instructions to run CEBRA on your system directly. The instructions below were tested on different compute setups with Ubuntu 18.04 or 20.04, using Nvidia GTX 2080, A4000, and V100 cards. Other setups are possible (including Windows), as long as CUDA 10.2+ support is guaranteed. - Software dependencies and operating systems: - Linux or MacOS @@ -93,11 +93,11 @@ we outline different options below. * 🚀 For more advanced users, CEBRA has different extra install options that you can select based on your usecase: - * ``[integrations]``: This will install (experimental) support for our streamlit and jupyter integrations. + * ``[integrations]``: This will install (experimental) support for integrations, such as plotly. * ``[docs]``: This will install additional dependencies for building the package documentation. * ``[dev]``: This will install additional dependencies for development, unit and integration testing, code formatting, etc. Install this extension if you want to work on a pull request. - * ``[demos]``: This will install additional dependencies for running our demo notebooks. + * ``[demos]``: This will install additional dependencies for running our demo notebooks in Jupyter. * ``[datasets]``: This extension will install additional dependencies to use the pre-installed datasets in ``cebra.datasets``. @@ -149,6 +149,13 @@ we outline different options below. Note that, similarly to that last command, you can select the specific install options of interest based on their description above and on your usecase. + .. tab:: Docker + + .. code:: bash + + $ docker pull mmathislab/cebra-cuda12.4-cudnn9 + + You can pull our container from DockerHub: https://hub.docker.com/u/mmathislab .. diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 334f1bbc..82e45a0b 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -1,7 +1,7 @@ Using CEBRA =========== -This page covers a standard CEBRA usage. We recommend checking out the :py:doc:`demos` for in-depth CEBRA usage examples as well. Here we present a quick overview on how to use CEBRA on various datasets. Note that we provide two ways to interact with the code: +This page covers a standard CEBRA usage. We recommend checking out the :py:doc:`demos` for CEBRA usage examples as well. Here we present a quick overview on how to use CEBRA on various datasets. Note that we provide two ways to interact with the code: * For regular usage, we recommend leveraging the **high-level interface**, adhering to ``scikit-learn`` formatting. * Upon specific needs, advanced users might consider diving into the **low-level interface** that adheres to ``PyTorch`` formatting. @@ -12,7 +12,7 @@ Firstly, why use CEBRA? CEBRA is primarily designed for producing robust, consistent extractions of latent factors from time-series data. It supports three modes, and is a self-supervised representation learning algorithm that uses our modified contrastive learning approach designed for multi-modal time-series data. In short, it is a type of non-linear dimensionality reduction, like `tSNE `_ and `UMAP `_. 
We show in our original paper that it outperforms tSNE and UMAP at producing closer-to-ground-truth latents and is more consistent. -That being said, CEBRA can be used on non-time-series data and it does not strictly require multi-modal data. In general, we recommend considering using CEBRA for measuring changes in consistency across conditions (brain areas, cells, animals), for hypothesis-guided decoding, and for topological exploration of the resulting embedding spaces. It can also be used for visualization and considering dynamics within the embedding space. For examples of how CEBRA can be used to map space, decode natural movies, and make hypotheses for neural coding of sensorimotor systems, see our paper (Schneider, Lee, Mathis, 2023). +That being said, CEBRA can be used on non-time-series data and it does not strictly require multi-modal data. In general, we recommend considering using CEBRA for measuring changes in consistency across conditions (brain areas, cells, animals), for hypothesis-guided decoding, and for topological exploration of the resulting embedding spaces. It can also be used for visualization and considering dynamics within the embedding space. For examples of how CEBRA can be used to map space, decode natural movies, and make hypotheses for neural coding of sensorimotor systems, see `Schneider, Lee, Mathis. Nature 2023 `_. The CEBRA workflow ------------------ @@ -22,7 +22,7 @@ We recommend to start with running CEBRA-Time (unsupervised) and look both at th (1) Use CEBRA-Time for unsupervised data exploration. (2) Consider running a hyperparameter sweep on the inputs to the model, such as :py:attr:`cebra.CEBRA.model_architecture`, :py:attr:`cebra.CEBRA.time_offsets`, :py:attr:`cebra.CEBRA.output_dimension`, and set :py:attr:`cebra.CEBRA.batch_size` to be as high as your GPU allows. You want to see clear structure in the 3D plot (the first 3 latents are shown by default). -(3) Use CEBRA-Behavior with many different labels and combinations, then look at the InfoNCE loss - the lower the loss value, the better the fit (see :py:doc:`cebra-figures/figures/ExtendedDataFigure5`), and visualize the embeddings. The goal is to understand which labels are contributing to the structure you see in CEBRA-Time, and improve this structure. Again, you should consider a hyperparameter sweep. +(3) Use CEBRA-Behavior with many different labels and combinations, then look at the InfoNCE loss - the lower the loss value, the better the fit (see :py:doc:`cebra-figures/figures/ExtendedDataFigure5`), and visualize the embeddings. The goal is to understand which labels are contributing to the structure you see in CEBRA-Time, and improve this structure. Again, you should consider a hyperparameter sweep (and avoid overfitting by performing a proper train/validation split; see Step 3 in our quick start guide below). (4) Interpretability: now you can use these latents in downstream tasks, such as measuring consistency, decoding, and determining the dimensionality of your data with topological data analysis. All the steps to do this are described below. Enjoy using CEBRA! 🔥🦓 @@ -179,7 +179,7 @@ We provide a set of pre-defined models. You can access (and search) a list of av Then, you can choose the one that fits best with your needs and provide it to the CEBRA model as the :py:attr:`~.CEBRA.model_architecture` parameter. -As an indication the table below presents the model architecture we used to train CEBRA on the datasets presented in our paper (Schneider, Lee, Mathis, 2022).
+As an indication, the table below presents the model architecture we used to train CEBRA on the datasets presented in our paper (Schneider, Lee, Mathis. Nature 2023). .. list-table:: :widths: 25 25 20 30 @@ -265,9 +265,8 @@ For standard usage we recommend the default values (i.e., ``InfoNCE`` and ``cosi .. rubric:: Temperature :py:attr:`~.CEBRA.temperature` -:py:attr:`~.CEBRA.temperature` has the largest effect on visualization of the embedding (see :py:doc:`cebra-figures/figures/ExtendedDataFigure2`). Hence, it is important that it is fitted to your specific data. +:py:attr:`~.CEBRA.temperature` has the largest effect on *visualization* of the embedding (see :py:doc:`cebra-figures/figures/ExtendedDataFigure2`). Hence, it is important that it is fitted to your specific data. Lower temperatures (e.g. around 0.1) will result in a more dispersed embedding; higher temperatures (larger than 1) will concentrate the embedding. -The simplest way to handle it is to use a *learnable temperature*. For that, set :py:attr:`~.CEBRA.temperature_mode` to ``auto``. :py:attr:`~.CEBRA.temperature` will be trained alongside the model. 🚀 For advanced usage, you might need to find the optimal :py:attr:`~.CEBRA.temperature`. For that, we recommend performing a grid search. @@ -307,7 +306,6 @@ Here is an example of a CEBRA model initialization: cebra_model = CEBRA( model_architecture = "offset10-model", batch_size = 1024, - temperature_mode="auto", learning_rate = 0.001, max_iterations = 10, time_offsets = 10, @@ -321,8 +319,7 @@ Here is an example of a CEBRA model initialization: .. testoutput:: CEBRA(batch_size=1024, learning_rate=0.001, max_iterations=10, - model_architecture='offset10-model', temperature_mode='auto', - time_offsets=10) + model_architecture='offset10-model', time_offsets=10) .. admonition:: See API docs :class: dropdown @@ -568,7 +565,8 @@ We provide a simple hyperparameters sweep to compare CEBRA models with different learning_rate = [0.001], time_offsets = 5, max_iterations = 5, - temperature_mode = "auto", + temperature_mode='constant', + temperature = 0.1, verbose = False) # 2. Define the datasets to iterate over @@ -820,7 +818,7 @@ It takes a CEBRA model and returns a 2D plot of the loss against the number of i Displaying the temperature """""""""""""""""""""""""" -:py:attr:`~.CEBRA.temperature` has the largest effect on the visualization of the embedding. Hence it might be interesting to check its evolution when ``temperature_mode=auto``. +:py:attr:`~.CEBRA.temperature` has the largest effect on the visualization of the embedding. Hence it might be interesting to check its evolution when ``temperature_mode=auto``. We recommend only using ``auto`` if you have first explored the ``constant`` setting. If you use the ``auto`` mode, please always check the evolution of the temperature over time alongside the loss curve. To that extent, you can use the function :py:func:`~.plot_temperature`. @@ -1186,9 +1184,10 @@ Improve model performance 🧐 Below is a (non-exhaustive) list of actions you can try if your embedding looks different from what you were expecting. #. Assess that your model `converged `_. For that, observe if the training loss stabilizes itself around the end of the training or still seems to be decreasing. Refer to `Visualize the training loss`_ for more details on how to display the training loss. -#. Increase the number of iterations. It should be at least 10,000. +#. Increase the number of iterations. It typically should be at least 10,000.
On small datasets, it can make sense to stop training earlier to avoid overfitting effects. #. Make sure the batch size is big enough. It should be at least 512. #. Fine-tune the model's hyperparameters, namely ``learning_rate``, ``output_dimension``, ``num_hidden_units`` and eventually ``temperature`` (by setting ``temperature_mode`` back to ``constant``). Refer to `Grid search`_ for more details on performing hyperparameters tuning. +#. Note that you should still be mindful of performing train/validation splits and shuffle controls to avoid `overfitting `_. @@ -1202,17 +1201,22 @@ Putting all previous snippet examples together, we obtain the following pipeline import cebra from numpy.random import uniform, randint from sklearn.model_selection import train_test_split + import os + import tempfile + from pathlib import Path # 1. Define a CEBRA model cebra_model = cebra.CEBRA( - model_architecture = "offset10-model", - batch_size = 512, - learning_rate = 1e-4, - max_iterations = 10, # TODO(user): to change to at least 10'000 - max_adapt_iterations = 10, # TODO(user): to change to ~100-500 - time_offsets = 10, - output_dimension = 8, - verbose = False + model_architecture = "offset10-model", + batch_size = 512, + learning_rate = 1e-4, + temperature_mode='constant', + temperature = 0.1, + max_iterations = 10, # TODO(user): to change to ~500-10000 depending on dataset size + #max_adapt_iterations = 10, # TODO(user): uncomment and change to ~100-500 if adapting + time_offsets = 10, + output_dimension = 8, + verbose = False ) # 2. Load example data @@ -1221,34 +1225,40 @@ Putting all previous snippet examples together, we obtain the following pipeline continuous_label = cebra.load_data(file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["continuous1", "continuous2", "continuous3"]) discrete_label = cebra.load_data(file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"]).flatten() + assert neural_data.shape == (100, 3) assert new_neural_data.shape == (100, 4) assert discrete_label.shape == (100, ) assert continuous_label.shape == (100, 3) - # 3. Split data and labels - ( - train_data, - valid_data, - train_discrete_label, - valid_discrete_label, - train_continuous_label, - valid_continuous_label, - ) = train_test_split(neural_data, - discrete_label, - continuous_label, - test_size=0.3) + # 3. Split data and labels into train/validation + from sklearn.model_selection import train_test_split + + split_idx = int(0.8 * len(neural_data)) + # suggestion: hold out 5%-20% depending on your dataset size; note that this splits the data + # into an early and late part, which might not be ideal for your data/experiment! + # As a more involved alternative, consider e.g. a nested time-series split. + + train_data = neural_data[:split_idx] + valid_data = neural_data[split_idx:] + + train_continuous_label = continuous_label[:split_idx] + valid_continuous_label = continuous_label[split_idx:] + + train_discrete_label = discrete_label[:split_idx] + valid_discrete_label = discrete_label[split_idx:] # 4. Fit the model # time contrastive learning cebra_model.fit(train_data) # discrete behavior contrastive learning - cebra_model.fit(train_data, train_discrete_label,) + cebra_model.fit(train_data, train_discrete_label) # continuous behavior contrastive learning cebra_model.fit(train_data, train_continuous_label) # mixed behavior contrastive learning cebra_model.fit(train_data, train_discrete_label, train_continuous_label) + # 5.
Save the model tmp_file = Path(tempfile.gettempdir(), 'cebra.pt') cebra_model.save(tmp_file) @@ -1257,15 +1267,15 @@ Putting all previous snippet examples together, we obtain the following pipeline cebra_model = cebra.CEBRA.load(tmp_file) train_embedding = cebra_model.transform(train_data) valid_embedding = cebra_model.transform(valid_data) - assert train_embedding.shape == (70, 8) - assert valid_embedding.shape == (30, 8) - # 7. Evaluate the model performances - goodness_of_fit = cebra.sklearn.metrics.infonce_loss(cebra_model, + assert train_embedding.shape == (80, 8) # TODO(user): change to split ratio & output dim + assert valid_embedding.shape == (20, 8) # TODO(user): change to split ratio & output dim + + # 7. Evaluate the model performance (you can also check the train_data) + goodness_of_fit = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, valid_data, valid_discrete_label, - valid_continuous_label, - num_batches=5) + valid_continuous_label) # 8. Adapt the model to a new session cebra_model.fit(new_neural_data, adapt = True) @@ -1274,7 +1284,9 @@ Putting all previous snippet examples together, we obtain the following pipeline decoder = cebra.KNNDecoder() decoder.fit(train_embedding, train_discrete_label) prediction = decoder.predict(valid_embedding) - assert prediction.shape == (30,) + assert prediction.shape == (20,) + + 👉 For further guidance on different/customized applications of CEBRA on your own data, refer to the ``examples/`` folder or to the full documentation folder ``docs/``. @@ -1424,17 +1436,14 @@ gets initialized which also allows the `prior` to be directly parametrized. solver.fit(loader=loader) # 7. Transform Embedding - train_batches = np.lib.stride_tricks.sliding_window_view( - neural_data, neural_model.get_offset().__len__(), axis=0 - ) - x_train_emb = solver.transform( - torch.from_numpy(train_batches[:]).type(torch.FloatTensor).to(device) - ).to(device) + torch.from_numpy(neural_data).type(torch.FloatTensor).to(device), + pad_before_transform=True, + batch_size=512).to(device) # 8. Plot Embedding cebra.plot_embedding( x_train_emb.cpu(), - discrete_label[neural_model.get_offset().__len__() - 1 :, 0], + discrete_label[:,0], markersize=10, ) diff --git a/pyproject.toml b/pyproject.toml index 4a927c6c..b64475e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,8 @@ [build-system] requires = [ "setuptools>=43", - "wheel" + "wheel", + "packaging>=24.2" ] build-backend = "setuptools.build_meta" diff --git a/reinstall.sh b/reinstall.sh index 778f98eb..ea8981b9 100755 --- a/reinstall.sh +++ b/reinstall.sh @@ -15,7 +15,7 @@ pip uninstall -y cebra # Get version info after uninstalling --- this will automatically get the # most recent version based on the source code in the current directory. # $(tools/get_cebra_version.sh) -VERSION=0.4.0 +VERSION=0.6.0a1 echo "Upgrading to CEBRA v${VERSION}" # Upgrade the build system (PEP517/518 compatible) @@ -24,4 +24,4 @@ python3 -m pip install --upgrade build python3 -m build --sdist --wheel . 
# Reinstall the package with most recent version -pip install --upgrade --no-cache-dir "dist/cebra-${VERSION}-py2.py3-none-any.whl[datasets,integrations]" +pip install --upgrade --no-cache-dir "dist/cebra-${VERSION}-py3-none-any.whl[datasets,integrations]" diff --git a/setup.cfg b/setup.cfg index 68263d73..7faff998 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,8 +1,8 @@ [metadata] name = cebra version = attr: cebra.__version__ -author = Steffen Schneider, Jin H Lee, Mackenzie W Mathis -author_email = stes@hey.com +author = file: AUTHORS.md +author_email = stes@hey.com, mackenzie@post.harvard.edu description = Consistent Embeddings of high-dimensional Recordings using Auxiliary variables long_description = file: README.md long_description_content_type = text/markdown @@ -31,13 +31,17 @@ where = python_requires = >=3.9 install_requires = joblib - numpy<2.0.0 + numpy<2.0;platform_system=="Windows" + numpy<2.0;platform_system!="Windows" and python_version<"3.10" + numpy;platform_system!="Windows" and python_version>="3.10" literate-dataclasses scikit-learn scipy - torch + torch>=2.4.0 tqdm - matplotlib + # NOTE(stes): Remove pin once https://github.com/AdaptiveMotorControlLab/CEBRA/issues/240 + # is resolved. + matplotlib<3.11 requests [options.extras_require] @@ -56,15 +60,18 @@ datasets = hdf5storage # for creating .mat files in new format openpyxl # for excel file format loading integrations = - jupyter pandas plotly + seaborn + captum + cvxpy + scikit-image docs = - sphinx==5.3 - sphinx-gallery==0.10.1 + sphinx + sphinx-gallery docutils - pydata-sphinx-theme==0.9.0 - sphinx_autodoc_typehints==1.19 + pydata-sphinx-theme + sphinx_autodoc_typehints sphinx_copybutton sphinx_tabs sphinx_design @@ -72,16 +79,14 @@ docs = nbsphinx nbconvert ipykernel - matplotlib<=3.5.2 + matplotlib pandas seaborn scikit-learn - numpy<2.0.0 demos = ipykernel jupyter nbconvert - seaborn # TODO(stes): Additional dependency for running # co-homology analysis # is ripser, which can be tricky to @@ -104,12 +109,10 @@ dev = pytest-sphinx tables licenseheaders + interrogate # TODO(stes) Add back once upstream issue # https://github.com/PyCQA/docformatter/issues/119 # is resolved. # docformatter[tomli] codespell cffconvert - -[bdist_wheel] -universal=1 diff --git a/tests/_build_legacy_model/.gitignore b/tests/_build_legacy_model/.gitignore new file mode 100644 index 00000000..4b6ebe5f --- /dev/null +++ b/tests/_build_legacy_model/.gitignore @@ -0,0 +1 @@ +*.pt diff --git a/tests/_build_legacy_model/Dockerfile b/tests/_build_legacy_model/Dockerfile new file mode 100644 index 00000000..ddbb0e61 --- /dev/null +++ b/tests/_build_legacy_model/Dockerfile @@ -0,0 +1,39 @@ +FROM python:3.12-slim AS base +RUN pip install torch --index-url https://download.pytorch.org/whl/cpu +RUN apt-get update && \ + apt-get install -y --no-install-recommends git && \ + rm -rf /var/lib/apt/lists/* + +FROM base AS cebra-0.4.0-scikit-learn-1.4 +RUN pip install cebra==0.4.0 "scikit-learn<1.5" +WORKDIR /app +COPY create_model.py . +RUN python create_model.py + +FROM base AS cebra-0.4.0-scikit-learn-1.6 +RUN pip install cebra==0.4.0 "scikit-learn>=1.6" +WORKDIR /app +COPY create_model.py . +RUN python create_model.py + +FROM base AS cebra-rc-scikit-learn-1.4 +# NOTE(stes): Commit where new scikit-learn tag logic was added to the CEBRA class. 
+# https://github.com/AdaptiveMotorControlLab/CEBRA/commit/5f46c3257952a08dfa9f9e1b149a85f7f12c1053 +RUN pip install git+https://github.com/AdaptiveMotorControlLab/CEBRA.git@5f46c3257952a08dfa9f9e1b149a85f7f12c1053 "scikit-learn<1.5" +WORKDIR /app +COPY create_model.py . +RUN python create_model.py + +FROM base AS cebra-rc-scikit-learn-1.6 +# NOTE(stes): Commit where new scikit-learn tag logic was added to the CEBRA class. +# https://github.com/AdaptiveMotorControlLab/CEBRA/commit/5f46c3257952a08dfa9f9e1b149a85f7f12c1053 +RUN pip install git+https://github.com/AdaptiveMotorControlLab/CEBRA.git@5f46c3257952a08dfa9f9e1b149a85f7f12c1053 "scikit-learn>=1.6" +WORKDIR /app +COPY create_model.py . +RUN python create_model.py + +FROM scratch +COPY --from=cebra-0.4.0-scikit-learn-1.4 /app/cebra_model.pt /cebra_model_cebra-0.4.0-scikit-learn-1.4.pt +COPY --from=cebra-0.4.0-scikit-learn-1.6 /app/cebra_model.pt /cebra_model_cebra-0.4.0-scikit-learn-1.6.pt +COPY --from=cebra-rc-scikit-learn-1.4 /app/cebra_model.pt /cebra_model_cebra-rc-scikit-learn-1.4.pt +COPY --from=cebra-rc-scikit-learn-1.6 /app/cebra_model.pt /cebra_model_cebra-rc-scikit-learn-1.6.pt diff --git a/tests/_build_legacy_model/README.md b/tests/_build_legacy_model/README.md new file mode 100644 index 00000000..4bcffa2b --- /dev/null +++ b/tests/_build_legacy_model/README.md @@ -0,0 +1,13 @@ +# Helper script to build CEBRA checkpoints + +This script builds CEBRA checkpoints for different versions of scikit-learn and CEBRA. +To build all models, run: + +```bash +./generate.sh +``` + +The models are currently also stored in git directly due to their small size. + +Related issue: https://github.com/AdaptiveMotorControlLab/CEBRA/issues/207 +Related test: tests/test_sklearn_legacy.py diff --git a/tests/_build_legacy_model/create_model.py b/tests/_build_legacy_model/create_model.py new file mode 100644 index 00000000..f308d296 --- /dev/null +++ b/tests/_build_legacy_model/create_model.py @@ -0,0 +1,15 @@ +import numpy as np + +import cebra + +neural_data = np.random.normal(0, 1, (1000, 30)) # 1000 samples, 30 features +cebra_model = cebra.CEBRA(model_architecture="offset10-model", + batch_size=512, + learning_rate=1e-4, + max_iterations=10, + time_offsets=10, + num_hidden_units=16, + output_dimension=8, + verbose=True) +cebra_model.fit(neural_data) +cebra_model.save("cebra_model.pt") diff --git a/tests/_build_legacy_model/generate.sh b/tests/_build_legacy_model/generate.sh new file mode 100755 index 00000000..749a0d32 --- /dev/null +++ b/tests/_build_legacy_model/generate.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +DOCKER_BUILDKIT=1 docker build --output type=local,dest=. . 
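For context on how these generated checkpoints are consumed: the README above points to tests/test_sklearn_legacy.py as the related test. Below is a minimal sketch of such a test, assuming the checkpoint filenames produced by the final stage of the Dockerfile above and the (1000, 30) input with ``output_dimension=8`` used in create_model.py; the test name, file locations, and exact assertions are illustrative assumptions, not the repository's actual implementation.

```python
# Hedged sketch of a legacy-checkpoint loading test (assumed filenames/paths;
# the real tests/test_sklearn_legacy.py may differ).
import numpy as np
import pytest

import cebra

CHECKPOINTS = [
    "cebra_model_cebra-0.4.0-scikit-learn-1.4.pt",
    "cebra_model_cebra-0.4.0-scikit-learn-1.6.pt",
    "cebra_model_cebra-rc-scikit-learn-1.4.pt",
    "cebra_model_cebra-rc-scikit-learn-1.6.pt",
]


@pytest.mark.parametrize("checkpoint", CHECKPOINTS)
def test_load_legacy_checkpoint(checkpoint):
    # Checkpoints were trained on (1000, 30) data with output_dimension=8
    # (see create_model.py above), so the embedding should be (n_samples, 8).
    model = cebra.CEBRA.load(checkpoint)
    X = np.random.normal(0, 1, (1000, 30))
    embedding = model.transform(X)
    assert embedding.shape == (1000, 8)
```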
diff --git a/tests/_util.py b/tests/_util.py index b4a0e07d..42dd54cb 100644 --- a/tests/_util.py +++ b/tests/_util.py @@ -74,3 +74,8 @@ def parametrize_with_checks_slow(fast_arguments, slow_arguments): slow_arg, generate_only=True))[0] for slow_arg in slow_arguments ] return parametrize_slow("estimator,check", fast_params, slow_params) + + +def parametrize_device(func): + _devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",) + return pytest.mark.parametrize("device", _devices)(func) diff --git a/tests/_utils_deprecated.py b/tests/_utils_deprecated.py new file mode 100644 index 00000000..bf412058 --- /dev/null +++ b/tests/_utils_deprecated.py @@ -0,0 +1,126 @@ +import warnings +from typing import Optional, Union + +import numpy as np +import numpy.typing as npt +import sklearn.utils.validation as sklearn_utils_validation +import torch + +import cebra +import cebra.integrations.sklearn.utils as sklearn_utils +import cebra.models + + +#NOTE: Deprecated: transform is now handled in the solver but the original +# method is kept here for testing. +def cebra_transform_deprecated(cebra_model, + X: Union[npt.NDArray, torch.Tensor], + session_id: Optional[int] = None) -> npt.NDArray: + """Transform an input sequence and return the embedding. + + Args: + cebra_model: The CEBRA model to use for the transform. + X: A numpy array or torch tensor of size ``time x dimension``. + session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for + multisession, set to ``None`` for single session. + + Returns: + A :py:func:`numpy.array` of size ``time x output_dimension``. + + Example: + + >>> import cebra + >>> import numpy as np + >>> dataset = np.random.uniform(0, 1, (1000, 30)) + >>> cebra_model = cebra.CEBRA(max_iterations=10) + >>> cebra_model.fit(dataset) + CEBRA(max_iterations=10) + >>> embedding = cebra_model.transform(dataset) + + """ + warnings.warn( + "The method is deprecated " + "but kept for testing puroposes." + "We recommend using `transform` instead.", + DeprecationWarning, + stacklevel=2) + + sklearn_utils_validation.check_is_fitted(cebra_model, "n_features_") + model, offset = cebra_model._select_model(X, session_id) + + # Input validation + X = sklearn_utils.check_input_array(X, min_samples=len(cebra_model.offset_)) + input_dtype = X.dtype + + with torch.no_grad(): + model.eval() + + if cebra_model.pad_before_transform: + X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)), + mode="edge") + X = torch.from_numpy(X).float().to(cebra_model.device_) + + if isinstance(model, cebra.models.ConvolutionalModelMixin): + # Fully convolutional evaluation, switch (T, C) -> (1, C, T) + X = X.transpose(1, 0).unsqueeze(0) + output = model(X).cpu().numpy().squeeze(0).transpose(1, 0) + else: + # Standard evaluation, (T, C, dt) + output = model(X).cpu().numpy() + + if input_dtype == "float64": + return output.astype(input_dtype) + + return output + + +# NOTE: Deprecated: batched transform can now be performed (more memory efficient) +# using the transform method of the model, and handling padding is implemented +# directly in the base Solver. This method is kept for testing purposes. +@torch.no_grad() +def multiobjective_transform_deprecated(solver: "cebra.solvers.Solver", + inputs: torch.Tensor) -> torch.Tensor: + """Transform the input data using the model. + + Args: + solver: The solver containing the model and device. + inputs: The input data to transform. + + Returns: + The transformed data. 
+ """ + + warnings.warn( + "The method is deprecated " + "but kept for testing puroposes." + "We recommend using `transform` instead.", + DeprecationWarning, + stacklevel=2) + + offset = solver.model.get_offset() + solver.model.eval() + X = inputs.cpu().numpy() + X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)), mode="edge") + X = torch.from_numpy(X).float().to(solver.device) + + if isinstance(solver.model.module, cebra.models.ConvolutionalModelMixin): + # Fully convolutional evaluation, switch (T, C) -> (1, C, T) + X = X.transpose(1, 0).unsqueeze(0) + outputs = solver.model(X) + + # switch back from (1, C, T) -> (T, C) + if isinstance(outputs, torch.Tensor): + assert outputs.dim() == 3 and outputs.shape[0] == 1 + outputs = outputs.squeeze(0).transpose(1, 0) + elif isinstance(outputs, tuple): + assert all(tensor.dim() == 3 and tensor.shape[0] == 1 + for tensor in outputs) + outputs = (output.squeeze(0).transpose(1, 0) for output in outputs) + outputs = tuple(outputs) + else: + raise ValueError("Invalid condition in solver.transform") + else: + # Standard evaluation, (T, C, dt) + outputs = solver.model(X) + + return outputs diff --git a/tests/test_api.py b/tests/test_api.py index bc279cbd..4e514429 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -21,6 +21,5 @@ # def test_api(): import cebra.distributions - from cebra.distributions import TimedeltaDistribution cebra.distributions.TimedeltaDistribution diff --git a/tests/test_attribution.py b/tests/test_attribution.py new file mode 100644 index 00000000..cfb8ad7a --- /dev/null +++ b/tests/test_attribution.py @@ -0,0 +1,214 @@ +import numpy as np +import pytest +import torch + +import cebra.attribution._jacobian +import cebra.attribution.jacobian_attribution as jacobian_attribution +from cebra.attribution import attribution_models +from cebra.models import Model + + +class DummyModel(Model): + + def __init__(self): + super().__init__(num_input=10, num_output=5) + self.linear = torch.nn.Linear(10, 5) + + def forward(self, x): + return self.linear(x) + + def get_offset(self): + return None + + +@pytest.fixture +def model(): + return DummyModel() + + +@pytest.fixture +def input_data(): + return torch.randn(100, 10) + + +def test_neuron_gradient_method(model, input_data): + attribution = attribution_models.NeuronGradientMethod(model=model, + input_data=input_data, + output_dimension=5) + + result = attribution.compute_attribution_map() + + assert 'neuron-gradient' in result + assert 'neuron-gradient-convabs' in result + assert result['neuron-gradient'].shape == (100, 5, 10) + + +def test_neuron_gradient_shap_method(model, input_data): + attribution = attribution_models.NeuronGradientShapMethod( + model=model, input_data=input_data, output_dimension=5) + + result = attribution.compute_attribution_map(baselines="zeros") + + assert 'neuron-gradient-shap' in result + assert 'neuron-gradient-shap-convabs' in result + assert result['neuron-gradient-shap'].shape == (100, 5, 10) + + with pytest.raises(NotImplementedError): + attribution.compute_attribution_map(baselines="invalid") + + +def test_feature_ablation_method(model, input_data): + attribution = attribution_models.FeatureAblationMethod( + model=model, input_data=input_data, output_dimension=5) + + result = attribution.compute_attribution_map() + + assert 'feature-ablation' in result + assert 'feature-ablation-convabs' in result + assert result['feature-ablation'].shape == (100, 5, 10) + + +def test_integrated_gradients_method(model, input_data): + attribution = 
attribution_models.IntegratedGradientsMethod( + model=model, input_data=input_data, output_dimension=5) + + result = attribution.compute_attribution_map() + + assert 'integrated-gradients' in result + assert 'integrated-gradients-convabs' in result + assert result['integrated-gradients'].shape == (100, 5, 10) + + +def test_batched_methods(model, input_data): + # Test batched version of NeuronGradientMethod + attribution = attribution_models.NeuronGradientMethodBatched( + model=model, input_data=input_data, output_dimension=5) + + result = attribution.compute_attribution_map(batch_size=32) + assert 'neuron-gradient' in result + assert result['neuron-gradient'].shape == (100, 5, 10) + + # Test batched version of IntegratedGradientsMethod + attribution = attribution_models.IntegratedGradientsMethodBatched( + model=model, input_data=input_data, output_dimension=5) + + result = attribution.compute_attribution_map(batch_size=32) + assert 'integrated-gradients' in result + assert result['integrated-gradients'].shape == (100, 5, 10) + + +def test_compute_metrics(): + attribution = attribution_models.AttributionMap(model=None, input_data=None) + + attribution_map = np.array([0.1, 0.8, 0.3, 0.9, 0.2]) + ground_truth = np.array([False, True, False, True, False]) + + metrics = attribution.compute_metrics(attribution_map, ground_truth) + + assert 'max_connected' in metrics + assert 'mean_connected' in metrics + assert 'min_connected' in metrics + assert 'max_nonconnected' in metrics + assert 'mean_nonconnected' in metrics + assert 'min_nonconnected' in metrics + assert 'gap_max' in metrics + assert 'gap_mean' in metrics + assert 'gap_min' in metrics + assert 'gap_minmax' in metrics + assert 'max_jacobian' in metrics + assert 'min_jacobian' in metrics + + +def test_compute_attribution_score(): + attribution = attribution_models.AttributionMap(model=None, input_data=None) + + attribution_map = np.array([0.1, 0.8, 0.3, 0.9, 0.2]) + ground_truth = np.array([False, True, False, True, False]) + + score = attribution.compute_attribution_score(attribution_map, ground_truth) + assert isinstance(score, float) + assert 0 <= score <= 1 + + +def test_jacobian_computation(): + # Create a simple model and input for testing + model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), + torch.nn.Linear(5, 3)) + input_data = torch.randn(100, 10, requires_grad=True) + + # Test basic Jacobian computation + jf, jhatg = jacobian_attribution.get_attribution_map(model=model, + input_data=input_data, + double_precision=True, + convert_to_numpy=True) + + # Check shapes + assert jf.shape == (100, 3, 10) # (batch_size, output_dim, input_dim) + assert jhatg.shape == (100, 10, 3) # (batch_size, input_dim, output_dim) + + +def test_tensor_conversion(): + # Test CPU and double precision conversion + test_tensors = [torch.randn(10, 5), torch.randn(5, 3)] + + converted = cebra.attribution._jacobian.tensors_to_cpu_and_double( + test_tensors) + + for tensor in converted: + assert tensor.device.type == "cpu" + assert tensor.dtype == torch.float64 + + # Only test CUDA conversion if CUDA is available + if torch.cuda.is_available(): + cuda_tensors = cebra.attribution._jacobian.tensors_to_cuda( + test_tensors, cuda_device="cuda") + for tensor in cuda_tensors: + assert tensor.is_cuda + else: + # Skip CUDA test with a message + pytest.skip("CUDA not available - skipping CUDA conversion test") + + +def test_jacobian_with_hybrid_solver(): + # Test Jacobian computation with hybrid solver + class HybridModel(torch.nn.Module): + + def 
__init__(self): + super().__init__() + self.fc1 = torch.nn.Linear(10, 5) + self.fc2 = torch.nn.Linear(10, 3) + + def forward(self, x): + return self.fc1(x), self.fc2(x) + + model = HybridModel() + # Move model to CPU to ensure test works everywhere + model = model.cpu() + input_data = torch.randn(50, 10, requires_grad=True) + + # Ensure input is on CPU + input_data = input_data.cpu() + + jacobian = cebra.attribution._jacobian.compute_jacobian( + model=model, + input_vars=[input_data], + hybrid_solver=True, + convert_to_numpy=True, + cuda_device=None # Explicitly set to None to use CPU + ) + + # Check shape (batch_size, output_dim, input_dim) + assert jacobian.shape == (50, 8, 10) # 8 = 5 + 3 concatenated outputs + + +def test_attribution_map_transforms(): + model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(), + torch.nn.Linear(5, 3)) + input_data = torch.randn(100, 10) + + # Test different aggregation methods + for aggregate in ["mean", "sum", "max"]: + jf, jhatg = jacobian_attribution.get_attribution_map( + model=model, input_data=input_data, aggregate=aggregate) + assert isinstance(jf, np.ndarray) + assert isinstance(jhatg, np.ndarray) diff --git a/tests/test_cli.py b/tests/test_cli.py index 41e67f42..8e49cc35 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -19,6 +19,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import argparse - -import pytest diff --git a/tests/test_criterions.py b/tests/test_criterions.py index 93a3b846..0d6f8ff2 100644 --- a/tests/test_criterions.py +++ b/tests/test_criterions.py @@ -19,7 +19,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import numpy as np import pytest import torch from torch import nn @@ -294,7 +293,7 @@ def _sample_dist_matrices(seed): @pytest.mark.parametrize("seed", [42, 4242, 424242]) -def test_infonce(seed): +def test_infonce_check_output_parts(seed): pos_dist, neg_dist = _sample_dist_matrices(seed) ref_loss, ref_align, ref_uniform = _reference_infonce(pos_dist, neg_dist) diff --git a/tests/test_data_masking.py b/tests/test_data_masking.py new file mode 100644 index 00000000..1b4976af --- /dev/null +++ b/tests/test_data_masking.py @@ -0,0 +1,206 @@ +import copy + +import pytest +import torch + +import cebra.data.mask +from cebra.data.masking import MaskedMixin + +#### Tests for Mask class #### + + +@pytest.mark.parametrize("mask", [ + cebra.data.mask.RandomNeuronMask, + cebra.data.mask.RandomTimestepMask, + cebra.data.mask.NeuronBlockMask, +]) +def test_random_mask(mask: cebra.data.mask.Mask): + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + mask = mask(masking_value=0.5) + masked_data = mask.apply_mask(copy.deepcopy(data)) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert (masked_data <= 1).all() and ( + masked_data >= 0).all(), "Masked data should only contain values 0 or 1" + assert torch.sum(masked_data) < torch.sum( + data), "Masked data should have fewer active neurons than original data" + + +def test_timeblock_mask(): + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + mask = cebra.data.mask.TimeBlockMask(masking_value=(0.035, 10)) + masked_data = mask.apply_mask(copy.deepcopy(data)) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert (masked_data <= 1).all() and ( + masked_data >= 
0).all(), "Masked data should only contain values 0 or 1" + assert torch.sum(masked_data) < torch.sum( + data), "Masked data should have fewer active neurons than original data" + + +#### Tests for MaskedMixin class #### + + +def test_masked_mixin_no_masks(): + mixin = MaskedMixin() + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + masked_data = mixin.apply_mask(copy.deepcopy(data)) + + assert torch.equal( + data, + masked_data), "Data should remain unchanged when no masks are applied" + + +@pytest.mark.parametrize( + "mask", ["RandomNeuronMask", "RandomTimestepMask", "NeuronBlockMask"]) +def test_masked_mixin_random_mask(mask): + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + + mixin = MaskedMixin() + assert mixin.masks == [], "Masks should be empty initially" + + mixin.set_masks({mask: 0.5}) + assert len(mixin.masks) == 1, "One mask should be set" + assert isinstance(mixin.masks[0], + getattr(cebra.data.mask, + mask)), f"Mask should be of type {mask}" + if isinstance(mixin.masks[0], cebra.data.mask.NeuronBlockMask): + assert mixin.masks[ + 0].mask_prop == 0.5, "Masking value should be set correctly" + else: + assert mixin.masks[ + 0].mask_ratio == 0.5, "Masking value should be set correctly" + + masked_data = mixin.apply_mask(copy.deepcopy(data)) + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified when a mask is applied" + + mixin.set_masks({mask: [0.5, 0.1]}) + assert len(mixin.masks) == 1, "One mask should be set" + assert isinstance(mixin.masks[0], + getattr(cebra.data.mask, + mask)), f"Mask should be of type {mask}" + masked_data = mixin.apply_mask(copy.deepcopy(data)) + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified when a mask is applied" + + mixin.set_masks({mask: (0.3, 0.9, 0.05)}) + assert len(mixin.masks) == 1, "One mask should be set" + assert isinstance(mixin.masks[0], + getattr(cebra.data.mask, + mask)), f"Mask should be of type {mask}" + masked_data = mixin.apply_mask(copy.deepcopy(data)) + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified when a mask is applied" + + +def test_apply_mask_with_time_block_mask(): + mixin = MaskedMixin() + + with pytest.raises(AssertionError, match="sampled_rate.*masked_seq_len"): + mixin.set_masks({"TimeBlockMask": 0.2}) + + with pytest.raises(AssertionError, match="(sampled_rate.*masked_seq_len)"): + mixin.set_masks({"TimeBlockMask": [0.2, 10]}) + + with pytest.raises(AssertionError, match="between.*0.0.*1.0"): + mixin.set_masks({"TimeBlockMask": (-2, 10)}) + + with pytest.raises(AssertionError, match="between.*0.0.*1.0"): + mixin.set_masks({"TimeBlockMask": (2, 10)}) + + with pytest.raises(AssertionError, match="integer.*greater"): + mixin.set_masks({"TimeBlockMask": (0.2, -10)}) + + with pytest.raises(AssertionError, match="integer.*greater"): + mixin.set_masks({"TimeBlockMask": (0.2, 5.5)}) + + mixin.set_masks({"TimeBlockMask": (0.035, 10)}) # Correct usage + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + masked_data = mixin.apply_mask(copy.deepcopy(data)) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + 
assert not torch.equal( + data, masked_data), "Data should be modified when a mask is applied" + + +def test_multiple_masks_mixin(): + mixin = MaskedMixin() + mixin.set_masks({"RandomNeuronMask": 0.5, "RandomTimestepMask": 0.3}) + data = torch.ones( + (10, 20, + 30)) # Example tensor with shape (batch_size, n_neurons, offset) + + masked_data = mixin.apply_mask(copy.deepcopy(data)) + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, + masked_data), "Data should be modified when multiple masks are applied" + + masked_data2 = mixin.apply_mask(copy.deepcopy(masked_data)) + assert masked_data2.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, + masked_data2), "Data should be modified when multiple masks are applied" + assert not torch.equal( + masked_data, masked_data2 + ), "Masked data should be different for different iterations" + + +def test_single_dim_input(): + mixin = MaskedMixin() + mixin.set_masks({"RandomNeuronMask": 0.5}) + data = torch.ones((10, 1, 30)) # Single neuron + masked_data = mixin.apply_mask(copy.deepcopy(data)) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified even with a single neuron" + + mixin = MaskedMixin() + mixin.set_masks({"RandomTimestepMask": 0.5}) + data = torch.ones((10, 20, 1)) # Single timestep + masked_data = mixin.apply_mask(copy.deepcopy(data)) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, + masked_data), "Data should be modified even with a single timestep" + + +def test_apply_mask_with_invalid_input(): + mixin = MaskedMixin() + mixin.set_masks({"RandomNeuronMask": 0.5}) + + with pytest.raises(ValueError, match="Data must be a 3D tensor"): + data = torch.ones( + (10, 20)) # Invalid tensor shape (missing offset dimension) + mixin.apply_mask(data) + + with pytest.raises(ValueError, match="Data must be a float32 tensor"): + data = torch.ones((10, 20, 30), dtype=torch.int32) + mixin.apply_mask(data) + + +def test_apply_mask_with_chunk_size(): + mixin = MaskedMixin() + mixin.set_masks({"RandomNeuronMask": 0.5}) + data = torch.ones((10000, 20, 30)) # Large tensor to test chunking + masked_data = mixin.apply_mask(copy.deepcopy(data), chunk_size=1000) + + assert masked_data.shape == data.shape, "Masked data shape should match input data shape" + assert not torch.equal( + data, masked_data), "Data should be modified when a mask is applied" diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 6a7f9319..e8e03ff0 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -68,9 +68,7 @@ def test_demo(): @pytest.mark.requires_dataset def test_hippocampus(): - pytest.skip("Outdated") - dataset = cebra.datasets.init("rat-hippocampus-single") loader = cebra.data.ContinuousDataLoader( dataset=dataset, @@ -99,7 +97,6 @@ def test_hippocampus(): @pytest.mark.requires_dataset def test_monkey(): - dataset = cebra.datasets.init( "area2-bump-pos-active-passive", path=pathlib.Path(_DEFAULT_DATADIR) / "monkey_reaching_preload_smth_40", @@ -110,7 +107,6 @@ def test_monkey(): @pytest.mark.requires_dataset def test_allen(): - pytest.skip("Test takes too long") ca_dataset = cebra.datasets.init("allen-movie-one-ca-VISp-100-train-10-111") diff --git a/tests/test_demo.py b/tests/test_demo.py index 4f0f146c..ce555db3 100644 --- 
a/tests/test_demo.py +++ b/tests/test_demo.py @@ -21,7 +21,6 @@ # import glob import re -import sys import pytest diff --git a/tests/test_distributions.py b/tests/test_distributions.py index d7151fd1..656559bb 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -43,7 +43,7 @@ def prepare(N=1000, n=128, d=5, probs=[0.3, 0.1, 0.6], device="cpu"): continuous = torch.randn(N, d).to(device) rand = torch.from_numpy(np.random.randint(0, N, (n,))).to(device) - qidx = discrete[rand].to(device) + _ = discrete[rand].to(device) query = continuous[rand] + 0.1 * torch.randn(n, d).to(device) query = query.to(device) @@ -173,7 +173,7 @@ def test_mixed(): discrete, continuous) reference_idx = distribution.sample_prior(10) - positive_idx = distribution.sample_conditional(reference_idx) + _ = distribution.sample_conditional(reference_idx) # The conditional distribution p(· | disc, cont) should yield # samples where the label exactly matches the reference sample. @@ -193,7 +193,7 @@ def test_continuous(benchmark): def _test_distribution(dist): distribution = dist(continuous) reference_idx = distribution.sample_prior(10) - positive_idx = distribution.sample_conditional(reference_idx) + _ = distribution.sample_conditional(reference_idx) return distribution distribution = _test_distribution( @@ -411,3 +411,16 @@ def test_new_delta_normal_with_multidimensional_index(delta, numerical_check): pytest.skip( "multivariate delta distribution can not accurately sample with the " "given parameters. TODO: Add a warning message for these cases.") + + +@pytest.mark.parametrize("time_offset", [1, 5, 10]) +def test_unified_distribution(time_offset): + dataset = cebra_datasets.init("demo-continuous-unified") + sampler = cebra_distr.UnifiedSampler(dataset, time_offset=time_offset) + + num_samples = 5 + sample = sampler.sample_prior(num_samples) + assert sample.shape == (dataset.num_sessions, num_samples) + + positive = sampler.sample_conditional(sample) + assert positive.shape == (dataset.num_sessions, num_samples) diff --git a/tests/test_dlc.py b/tests/test_dlc.py index a19fe593..8ab29abd 100644 --- a/tests/test_dlc.py +++ b/tests/test_dlc.py @@ -29,6 +29,7 @@ import cebra.integrations.deeplabcut as cebra_dlc from cebra import CEBRA from cebra import load_data +from cebra.data.load import read_hdf # NOTE(stes): The original data URL is # https://github.com/DeepLabCut/DeepLabCut/blob/main/examples @@ -54,11 +55,7 @@ def test_imports(): def _load_dlc_dataframe(filename): - try: - df = pd.read_hdf(filename, "df_with_missing") - except KeyError: - df = pd.read_hdf(filename) - return df + return read_hdf(filename) def _get_annotated_data(url, keypoints): diff --git a/tests/test_grid_search.py b/tests/test_grid_search.py index 3f88ba12..c774ea02 100644 --- a/tests/test_grid_search.py +++ b/tests/test_grid_search.py @@ -20,7 +20,6 @@ # limitations under the License. # import numpy as np -import pytest import cebra import cebra.grid_search diff --git a/tests/test_integration_train.py b/tests/test_integration_train.py index 06e6da40..238bbea7 100644 --- a/tests/test_integration_train.py +++ b/tests/test_integration_train.py @@ -20,7 +20,6 @@ # limitations under the License. 
# import itertools -from typing import List import pytest import torch diff --git a/tests/test_integration_xcebra.py b/tests/test_integration_xcebra.py new file mode 100644 index 00000000..760e26ef --- /dev/null +++ b/tests/test_integration_xcebra.py @@ -0,0 +1,190 @@ +import pickle + +import _utils_deprecated +import numpy as np +import pytest +import torch + +import cebra +import cebra.attribution +import cebra.data +import cebra.models +import cebra.solver +from cebra.data import ContrastiveMultiObjectiveLoader +from cebra.data import DatasetxCEBRA +from cebra.solver import MultiObjectiveConfig +from cebra.solver.schedulers import LinearRampUp + + +@pytest.fixture +def synthetic_data(): + import tempfile + import urllib.request + from pathlib import Path + + url = "https://cebra.fra1.digitaloceanspaces.com/xcebra_synthetic_data.pkl" + + # Create a persistent temp directory specific to this test + temp_dir = Path(tempfile.gettempdir()) / "cebra_test_data" + temp_dir.mkdir(exist_ok=True) + filepath = temp_dir / "synthetic_data.pkl" + + if not filepath.exists(): + urllib.request.urlretrieve(url, filepath) + + with filepath.open('rb') as file: + return pickle.load(file) + + +@pytest.fixture +def device(): + return "cuda" if torch.cuda.is_available() else "cpu" + + +def test_synthetic_data_training(synthetic_data, device): + # Setup data + neurons = synthetic_data['neurons'] + latents = synthetic_data['latents'] + n_latents = latents.shape[1] + Z1 = synthetic_data['Z1'] + Z2 = synthetic_data['Z2'] + gt_attribution_map = synthetic_data['gt_attribution_map'] + data = DatasetxCEBRA(neurons, Z1=Z1, Z2=Z2) + + # Configure training with reduced steps + TOTAL_STEPS = 50 # Reduced from 2000 for faster testing + loader = ContrastiveMultiObjectiveLoader(dataset=data, + num_steps=TOTAL_STEPS, + batch_size=512).to(device) + + config = MultiObjectiveConfig(loader) + config.set_slice(0, 6) + config.set_loss("FixedEuclideanInfoNCE", temperature=1.) + config.set_distribution("time", time_offset=1) + config.push() + + config.set_slice(3, 6) + config.set_loss("FixedEuclideanInfoNCE", temperature=1.) 
+ config.set_distribution("time_delta", time_delta=1, label_name="Z2") + config.push() + + config.finalize() + + # Initialize model and solver + neural_model = cebra.models.init( + name="offset1-model-mse-clip-5-5", + num_neurons=data.neural.shape[1], + num_units=256, + num_output=n_latents, + ).to(device) + + data.configure_for(neural_model) + + opt = torch.optim.Adam( + list(neural_model.parameters()) + list(config.criterion.parameters()), + lr=3e-4, + weight_decay=0, + ) + + regularizer = cebra.models.jacobian_regularizer.JacobianReg() + + solver = cebra.solver.init( + name="multiobjective-solver", + model=neural_model, + feature_ranges=config.feature_ranges, + regularizer=regularizer, + renormalize=False, + use_sam=False, + criterion=config.criterion, + optimizer=opt, + tqdm_on=False, + ).to(device) + + # Train model with reduced steps for regularizer + weight_scheduler = LinearRampUp( + n_splits=2, + step_to_switch_on_reg=25, # Reduced from 2500 + step_to_switch_off_reg=40, # Reduced from 15000 + start_weight=0., + end_weight=0.01, + stay_constant_after_switch_off=True) + + solver.fit( + loader=loader, + valid_loader=None, + log_frequency=None, + scheduler_regularizer=weight_scheduler, + scheduler_loss=None, + ) + + # Basic test that model runs and produces output + solver.model.split_outputs = False + embedding = solver.model(data.neural.to(device)).detach().cpu() + + # Verify output dimensions + assert embedding.shape[1] == n_latents, "Incorrect embedding dimension" + assert not torch.isnan(embedding).any(), "NaN values in embedding" + + # Test attribution map functionality + data.neural.requires_grad_(True) + method = cebra.attribution.init(name="jacobian-based", + model=solver.model, + input_data=data.neural, + output_dimension=solver.model.num_output) + + result = method.compute_attribution_map() + jfinv = abs(result['jf-inv-lsq']).mean(0) + + # Verify attribution map output + assert not torch.isnan( + torch.tensor(jfinv)).any(), "NaN values in attribution map" + assert jfinv.shape == gt_attribution_map.shape, "Incorrect attribution map shape" + + # Test split outputs functionality + solver.model.split_outputs = True + embedding_split = solver.model(data.neural.to(device)) + Z1_hat = embedding_split[0].detach().cpu() + Z2_hat = embedding_split[1].detach().cpu() + + # TODO(stes): Right now, this results 6D output vs. 3D as expected. Need to double check + # the API docs on the desired behavior here, both could be fine... 
+ # assert Z1_hat.shape == Z1.shape, f"Incorrect Z1 embedding dimension: {Z1_hat.shape}" + assert Z2_hat.shape == Z2.shape, f"Incorrect Z2 embedding dimension: {Z2_hat.shape}" + assert not torch.isnan(Z1_hat).any(), "NaN values in Z1 embedding" + assert not torch.isnan(Z2_hat).any(), "NaN values in Z2 embedding" + + # Test the transform + solver.model.split_outputs = False + transform_embedding = solver.transform(data.neural.to(device)) + assert transform_embedding.shape[ + 1] == n_latents, "Incorrect embedding dimension" + assert not torch.isnan(transform_embedding).any(), "NaN values in embedding" + assert np.allclose(embedding, transform_embedding, rtol=1e-4, atol=1e-4) + + # Test the transform with batching + batched_embedding = solver.transform(data.neural.to(device), batch_size=512) + assert batched_embedding.shape[ + 1] == n_latents, "Incorrect embedding dimension" + assert not torch.isnan(batched_embedding).any(), "NaN values in embedding" + assert np.allclose(embedding, batched_embedding, rtol=1e-4, atol=1e-4) + + assert np.allclose(transform_embedding, + batched_embedding, + rtol=1e-4, + atol=1e-4) + + # Test and compare the previous transform (transform_deprecated) + deprecated_transform_embedding = _utils_deprecated.multiobjective_transform_deprecated( + solver, data.neural.to(device)) + assert np.allclose(embedding, + deprecated_transform_embedding, + rtol=1e-4, + atol=1e-4) + assert np.allclose(transform_embedding, + deprecated_transform_embedding, + rtol=1e-4, + atol=1e-4) + assert np.allclose(batched_embedding, + deprecated_transform_embedding, + rtol=1e-4, + atol=1e-4) diff --git a/tests/test_load.py b/tests/test_load.py index 6f62dc92..4524b29c 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -22,10 +22,7 @@ import itertools import pathlib import pickle -import platform import tempfile -import unittest -from unittest.mock import patch import h5py import hdf5storage @@ -125,7 +122,7 @@ def generate_numpy_confounder(filename, dtype): @register("npz") -def generate_numpy_path(filename, dtype): +def generate_numpy_path_2(filename, dtype): A = np.arange(1000, dtype=dtype).reshape(10, 100) np.savez(filename, array=A, other_data="test") loaded_A = cebra_load.load(pathlib.Path(filename)) @@ -251,7 +248,7 @@ def generate_h5_no_array(filename, dtype): def generate_h5_dataframe(filename, dtype): A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"]) - df_A.to_hdf(filename, "df_A") + df_A.to_hdf(filename, key="df_A") loaded_A = cebra_load.load(filename, key="df_A") return A, loaded_A @@ -261,7 +258,7 @@ def generate_h5_dataframe_columns(filename, dtype): A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) A_col = A[:, :2] df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"]) - df_A.to_hdf(filename, "df_A") + df_A.to_hdf(filename, key="df_A") loaded_A = cebra_load.load(filename, key="df_A", columns=["a", "b"]) return A_col, loaded_A @@ -272,8 +269,8 @@ def generate_h5_multi_dataframe(filename, dtype): B = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"]) df_B = pd.DataFrame(np.array(B), columns=["c", "d", "e"]) - df_A.to_hdf(filename, "df_A") - df_B.to_hdf(filename, "df_B") + df_A.to_hdf(filename, key="df_A") + df_B.to_hdf(filename, key="df_B") loaded_A = cebra_load.load(filename, key="df_A") return A, loaded_A @@ -282,7 +279,7 @@ def generate_h5_multi_dataframe(filename, dtype): def generate_h5_single_dataframe_no_key(filename, dtype): A = np.array([[1, 2, 3], [4, 5, 6], 
[7, 8, 9]]).astype(dtype) df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"]) - df_A.to_hdf(filename, "df_A") + df_A.to_hdf(filename, key="df_A") loaded_A = cebra_load.load(filename) return A, loaded_A @@ -293,8 +290,8 @@ def generate_h5_multi_dataframe_no_key(filename, dtype): B = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype) df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"]) df_B = pd.DataFrame(np.array(B), columns=["c", "d", "e"]) - df_A.to_hdf(filename, "df_A") - df_B.to_hdf(filename, "df_B") + df_A.to_hdf(filename, key="df_A") + df_B.to_hdf(filename, key="df_B") _ = cebra_load.load(filename) @@ -307,7 +304,7 @@ def generate_h5_multicol_dataframe(filename, dtype): df_A = pd.DataFrame(A, columns=pd.MultiIndex.from_product([animals, keypoints])) - df_A.to_hdf(filename, "df_A") + df_A.to_hdf(filename, key="df_A") loaded_A = cebra_load.load(filename, key="df_A") return A, loaded_A @@ -316,7 +313,7 @@ def generate_h5_multicol_dataframe(filename, dtype): def generate_h5_dataframe_invalid_key(filename, dtype): A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype) df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"]) - df_A.to_hdf(filename, "df_A") + df_A.to_hdf(filename, key="df_A") _ = cebra_load.load(filename, key="df_B") @@ -324,7 +321,7 @@ def generate_h5_dataframe_invalid_key(filename, dtype): def generate_h5_dataframe_invalid_column(filename, dtype): A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype) df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"]) - df_A.to_hdf(filename, "df_A") + df_A.to_hdf(filename, key="df_A") _ = cebra_load.load(filename, key="df_A", columns=["d", "b"]) @@ -337,7 +334,7 @@ def generate_h5_multicol_dataframe_columns(filename, dtype): df_A = pd.DataFrame(A, columns=pd.MultiIndex.from_product([animals, keypoints])) - df_A.to_hdf(filename, "df_A") + df_A.to_hdf(filename, key="df_A") _ = cebra_load.load(filename, key="df_A", columns=["a", "b"]) @@ -418,7 +415,7 @@ def generate_csv_path(filename, dtype): @register_error("csv") def generate_csv_empty_file(filename, dtype): - with open(filename, "w") as creating_new_csv_file: + with open(filename, "w") as _: pass _ = cebra_load.load(filename) @@ -619,7 +616,6 @@ def generate_pickle_invalid_key(filename, dtype): @register_error("pkl", "p") def generate_pickle_no_array(filename, dtype): - A = np.arange(1000, dtype=dtype).reshape(10, 100) with open(filename, "wb") as f: pickle.dump({"A": "test_1", "B": "test_2"}, f) _ = cebra_load.load(filename) diff --git a/tests/test_loader.py b/tests/test_loader.py index 562f64a7..cb6be9a7 100644 --- a/tests/test_loader.py +++ b/tests/test_loader.py @@ -19,16 +19,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
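The hunk that follows drops the local parametrize_device helper from tests/test_loader.py in favour of a shared _util.parametrize_device. As an illustrative aside (not part of the patch), the shared helper in tests/_util.py presumably mirrors the removed lines; a minimal, self-contained sketch:

import pytest
import torch


def parametrize_device(func):
    """Parametrize the decorated test over "cpu", plus "cuda" when available."""
    _devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",)
    return pytest.mark.parametrize("device", _devices)(func)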
# +import _util +import numpy as np import pytest import torch import cebra.data import cebra.io - -def parametrize_device(func): - _devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",) - return pytest.mark.parametrize("device", _devices)(func) +BATCH_SIZE = 32 +NUMS_NEURAL = [3, 4, 5] class LoadSpeed: @@ -107,7 +107,11 @@ def _assert_dataset_on_correct_device(loader, device): assert hasattr(loader, "dataset") assert hasattr(loader, "device") assert isinstance(loader.dataset, cebra.io.HasDevice) - assert loader.dataset.neural.device.type == device + if isinstance(loader, cebra.data.SingleSessionDataset): + assert loader.dataset.neural.device.type == device + elif isinstance(loader, cebra.data.MultiSessionDataset): + for session in loader.dataset.iter_sessions(): + assert session.neural.device.type == device def test_demo_data(): @@ -130,13 +134,15 @@ def _to_str(val): assert _to_str(first) == _to_str(second) -@parametrize_device +@_util.parametrize_device @pytest.mark.parametrize( "data_name, loader_initfunc", [ ("demo-discrete", cebra.data.DiscreteDataLoader), ("demo-continuous", cebra.data.ContinuousDataLoader), ("demo-mixed", cebra.data.MixedDataLoader), + ("demo-continuous-multisession", cebra.data.MultiSessionLoader), + ("demo-continuous-unified", cebra.data.UnifiedLoader), ], ) def test_device(data_name, loader_initfunc, device): @@ -147,7 +153,7 @@ def test_device(data_name, loader_initfunc, device): other_device = swap.get(device) dataset = RandomDataset(N=100, device=other_device) - loader = loader_initfunc(dataset, num_steps=10, batch_size=32) + loader = loader_initfunc(dataset, num_steps=10, batch_size=BATCH_SIZE) loader.to(device) assert loader.dataset == dataset _assert_device(loader.device, device) @@ -156,7 +162,7 @@ def test_device(data_name, loader_initfunc, device): _assert_device(loader.get_indices(10).reference.device, device) -@parametrize_device +@_util.parametrize_device @pytest.mark.parametrize("prior", ("uniform", "empirical")) def test_discrete(prior, device, benchmark): dataset = RandomDataset(N=100, device=device) @@ -171,7 +177,7 @@ def test_discrete(prior, device, benchmark): benchmark(load_speed) -@parametrize_device +@_util.parametrize_device @pytest.mark.parametrize("conditional", ("time", "time_delta")) def test_continuous(conditional, device, benchmark): dataset = RandomDataset(N=100, d=5, device=device) @@ -199,7 +205,7 @@ def _check_attributes(obj, is_list=False): raise TypeError() -@parametrize_device +@_util.parametrize_device @pytest.mark.parametrize( "data_name, loader_initfunc", [ @@ -211,7 +217,7 @@ def _check_attributes(obj, is_list=False): def test_singlesession_loader(data_name, loader_initfunc, device): data = cebra.datasets.init(data_name) data.to(device) - loader = loader_initfunc(data, num_steps=10, batch_size=32) + loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE) _assert_dataset_on_correct_device(loader, device) index = loader.get_indices(100) @@ -219,25 +225,33 @@ def test_singlesession_loader(data_name, loader_initfunc, device): for batch in loader: _check_attributes(batch) - assert len(batch.positive) == 32 + assert len(batch.positive) == BATCH_SIZE -def test_multisession_cont_loader(): - data = cebra.datasets.MultiContinuous(nums_neural=[3, 4, 5], - num_behavior=5, - num_timepoints=100) - loader = cebra.data.ContinuousMultiSessionDataLoader( - data, - num_steps=10, - batch_size=32, - ) +@_util.parametrize_device +@pytest.mark.parametrize( + "data_name, loader_initfunc", + [ + 
("demo-continuous-multisession", + cebra.data.ContinuousMultiSessionDataLoader), + ("demo-discrete-multisession", + cebra.data.DiscreteMultiSessionDataLoader), + ], +) +def test_multisession_loader(data_name, loader_initfunc, device): + data = cebra.datasets.init(data_name) + data.to(device) + loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE) + + _assert_dataset_on_correct_device(loader, device) # Check the sampler assert hasattr(loader, "sampler") ref_idx = loader.sampler.sample_prior(1000) - assert len(ref_idx) == 3 # num_sessions - for session in range(3): - assert ref_idx[session].max() < 100 + assert len(ref_idx) == len(NUMS_NEURAL) + for session in range(len(NUMS_NEURAL)): + assert ref_idx[session].max( + ) < cebra.datasets.demo._DEFAULT_NUM_TIMEPOINTS pos_idx, idx, idx_rev = loader.sampler.sample_conditional(ref_idx) assert pos_idx is not None @@ -245,6 +259,8 @@ def test_multisession_cont_loader(): assert idx_rev is not None batch = next(iter(loader)) + for i, n_neurons in enumerate(NUMS_NEURAL): + assert batch[i].reference.shape == (BATCH_SIZE, n_neurons, 10) def _mix(array, idx): shape = array.shape @@ -259,82 +275,70 @@ def _process(batch, feature_dim=1): [b.reference.flatten(1).mean(dim=1, keepdims=True) for b in batch], dim=0).repeat(1, 1, feature_dim) - assert batch[0].reference.shape == (32, 3, 10) - assert batch[1].reference.shape == (32, 4, 10) - assert batch[2].reference.shape == (32, 5, 10) - dummy_prediction = _process(batch, feature_dim=6) - assert dummy_prediction.shape == (3, 32, 6) + assert dummy_prediction.shape == (3, BATCH_SIZE, 6) _mix(dummy_prediction, batch[0].index) + index = loader.get_indices(100) + #print(index[0]) + #print(type(index)) + _check_attributes(index, is_list=False) -def test_multisession_disc_loader(): - data = cebra.datasets.MultiDiscrete(nums_neural=[3, 4, 5], - num_timepoints=100) - loader = cebra.data.DiscreteMultiSessionDataLoader( - data, - num_steps=10, - batch_size=32, - ) - - # Check the sampler - assert hasattr(loader, "sampler") - ref_idx = loader.sampler.sample_prior(1000) - assert len(ref_idx) == 3 # num_sessions - - # Check sample points are in session length range - for session in range(3): - assert ref_idx[session].max() < loader.sampler.session_lengths[session] - pos_idx, idx, idx_rev = loader.sampler.sample_conditional(ref_idx) - - assert pos_idx is not None - assert idx is not None - assert idx_rev is not None - - batch = next(iter(loader)) - - def _mix(array, idx): - shape = array.shape - n, m = shape[:2] - mixed = array.reshape(n * m, -1)[idx] - print(mixed.shape, array.shape, idx.shape) - return mixed.reshape(shape) - - def _process(batch, feature_dim=1): - """Given list_i[(N,d_i)] batch, return (#session, N, feature_dim) tensor""" - return torch.stack( - [b.reference.flatten(1).mean(dim=1, keepdims=True) for b in batch], - dim=0).repeat(1, 1, feature_dim) - - assert batch[0].reference.shape == (32, 3, 10) - assert batch[1].reference.shape == (32, 4, 10) - assert batch[2].reference.shape == (32, 5, 10) - - dummy_prediction = _process(batch, feature_dim=6) - assert dummy_prediction.shape == (3, 32, 6) - _mix(dummy_prediction, batch[0].index) + for batch in loader: + _check_attributes(batch, is_list=True) + for session_batch in batch: + assert len(session_batch.positive) == BATCH_SIZE -@parametrize_device +@_util.parametrize_device @pytest.mark.parametrize( "data_name, loader_initfunc", - [('demo-discrete-multisession', cebra.data.DiscreteMultiSessionDataLoader), - ("demo-continuous-multisession", - 
cebra.data.ContinuousMultiSessionDataLoader)], + [ + ("demo-continuous-unified", cebra.data.UnifiedLoader), + ], ) -def test_multisession_loader(data_name, loader_initfunc, device): - # TODO change number of timepoints across the sessions - +def test_unified_loader(data_name, loader_initfunc, device): data = cebra.datasets.init(data_name) - kwargs = dict(num_steps=10, batch_size=32) - loader = loader_initfunc(data, **kwargs) + data.to(device) + loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE) + + _assert_dataset_on_correct_device(loader, device) + + # Check the sampler + num_samples = 100 + assert hasattr(loader, "sampler") + ref_idx = loader.sampler.sample_all_uniform_prior(num_samples) + assert ref_idx.shape == (len(NUMS_NEURAL), num_samples) + assert isinstance(ref_idx, np.ndarray) + + for session in range(len(NUMS_NEURAL)): + assert ref_idx[session].max( + ) < cebra.datasets.demo._DEFAULT_NUM_TIMEPOINTS + pos_idx = loader.sampler.sample_conditional(ref_idx) + assert pos_idx.shape == (len(NUMS_NEURAL), num_samples) + + for session in range(len(NUMS_NEURAL)): + ref_idx = torch.from_numpy( + loader.sampler.sample_all_uniform_prior( + num_samples=num_samples)[session]) + assert ref_idx.shape == (num_samples,) + all_ref_idx = loader.sampler.sample_all_sessions(ref_idx=ref_idx, + session_id=session) + assert all_ref_idx.shape == (len(NUMS_NEURAL), num_samples) + assert isinstance(all_ref_idx, torch.Tensor) + for i in range(len(all_ref_idx)): + assert all_ref_idx[i].max( + ) < cebra.datasets.demo._DEFAULT_NUM_TIMEPOINTS + + for i in range(len(all_ref_idx)): + pos_idx = loader.sampler.sample_conditional(all_ref_idx) + assert pos_idx.shape == (len(NUMS_NEURAL), num_samples) + + # Check the batch + batch = next(iter(loader)) + assert batch.reference.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10) + assert batch.positive.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10) + assert batch.negative.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10) index = loader.get_indices(100) - print(index[0]) - print(type(index)) _check_attributes(index, is_list=False) - - for batch in loader: - _check_attributes(batch, is_list=True) - for session_batch in batch: - assert len(session_batch.positive) == 32 diff --git a/tests/test_models.py b/tests/test_models.py index 2a6e4812..658cc467 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -90,6 +90,10 @@ def test_offset_models(model_name, batch_size, input_length): def test_multiobjective(): + # NOTE(stes): This test is deprecated and will be removed in a future version. + # As of CEBRA 0.6.0, the multi objective models are tested separately in + # test_multiobjective.py. + class TestModel(cebra.models.Model): def __init__(self): @@ -155,8 +159,8 @@ def test_version_check(version, raises): cebra.models.model._check_torch_version(raise_error=True) -def test_version_check(): - raises = not cebra.models.model._check_torch_version(raise_error=False) +def test_version_check_dropout_available(): + raises = cebra.models.model._check_torch_version(raise_error=False) if raises: assert len(cebra.models.get_options("*dropout*")) == 0 else: diff --git a/tests/test_multiobjective.py b/tests/test_multiobjective.py new file mode 100644 index 00000000..a4c601ac --- /dev/null +++ b/tests/test_multiobjective.py @@ -0,0 +1,145 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. 
Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import warnings + +import pytest +import torch + +import cebra +from cebra.data import ContrastiveMultiObjectiveLoader +from cebra.data import DatasetxCEBRA +from cebra.solver import MultiObjectiveConfig + + +@pytest.fixture +def config(): + neurons = torch.randn(100, 5) + behavior1 = torch.randn(100, 2) + behavior2 = torch.randn(100, 1) + data = DatasetxCEBRA(neurons, behavior1=behavior1, behavior2=behavior2) + loader = ContrastiveMultiObjectiveLoader(dataset=data, + num_steps=1, + batch_size=24) + return MultiObjectiveConfig(loader) + + +def test_imports(): + pass + + +def test_add_data(config): + config.set_slice(0, 10) + config.set_loss('loss_name', param1='value1') + config.set_distribution('distribution_name', param2='value2') + config.push() + + assert len(config.total_info) == 1 + assert config.total_info[0]['slice'] == (0, 10) + assert config.total_info[0]['losses'] == { + "name": 'loss_name', + "kwargs": { + 'param1': 'value1' + } + } + assert config.total_info[0]['distributions'] == { + "name": 'distribution_name', + "kwargs": { + 'param2': 'value2' + } + } + + +def test_overwriting_key_warning(config): + with warnings.catch_warnings(record=True) as w: + config.set_slice(0, 10) + config.set_slice(10, 20) + assert len(w) == 1 + assert issubclass(w[-1].category, UserWarning) + assert "Configuration key already exists" in str(w[-1].message) + + +def test_missing_slice_error(config): + with pytest.raises(RuntimeError, match="Slice configuration is missing"): + config.set_loss('loss_name', param1='value1') + config.set_distribution('distribution_name', param2='value2') + config.push() + + +def test_missing_distributions_error(config): + with pytest.raises(RuntimeError, + match="Distributions configuration is missing"): + config.set_slice(0, 10) + config.set_loss('loss_name', param1='value1') + config.push() + + +def test_missing_losses_error(config): + with pytest.raises(RuntimeError, match="Losses configuration is missing"): + config.set_slice(0, 10) + config.set_distribution('distribution_name', param2='value2') + config.push() + + +def test_finalize(config): + config.set_slice(0, 6) + config.set_loss("FixedEuclideanInfoNCE", temperature=1.) + config.set_distribution("time", time_offset=1) + config.push() + + config.set_slice(3, 6) + config.set_loss("FixedEuclideanInfoNCE", temperature=1.) 
+ config.set_distribution("time_delta", time_delta=3, label_name="behavior2") + config.push() + + config.finalize() + + assert len(config.losses) == 2 + assert config.losses[0]['indices'] == (0, 6) + assert config.losses[1]['indices'] == (3, 6) + + assert len(config.feature_ranges) == 2 + assert config.feature_ranges[0] == slice(0, 6) + assert config.feature_ranges[1] == slice(3, 6) + + assert len(config.loader.distributions) == 2 + assert isinstance(config.loader.distributions[0], + cebra.distributions.continuous.TimeContrastive) + assert config.loader.distributions[0].time_offset == 1 + + assert isinstance(config.loader.distributions[1], + cebra.distributions.continuous.TimedeltaDistribution) + assert config.loader.distributions[1].time_delta == 3 + + +def test_non_unique_feature_ranges_error(config): + config.set_slice(0, 10) + config.set_loss("FixedEuclideanInfoNCE", temperature=1.) + config.set_distribution("time", time_offset=1) + config.push() + + config.set_slice(0, 10) + config.set_loss("FixedEuclideanInfoNCE", temperature=1.) + config.set_distribution("time_delta", time_delta=3, label_name="behavior2") + config.push() + + with pytest.raises(RuntimeError, match="Feature ranges are not unique"): + config.finalize() diff --git a/tests/test_plot.py b/tests/test_plot.py index 3f44d887..1d94d310 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -72,8 +72,6 @@ def test_plot_imports(): def test_colormaps(): import matplotlib - import cebra - cmap = matplotlib.colormaps["cebra"] assert cmap is not None plt.scatter([1], [2], c=[2], cmap="cebra") @@ -241,7 +239,7 @@ def test_compare_models(): _ = cebra_plot.compare_models(models, labels=long_labels, ax=ax) with pytest.raises(ValueError, match="Invalid.*labels"): invalid_labels = copy.deepcopy(labels) - ele = invalid_labels.pop() + _ = invalid_labels.pop() invalid_labels.append(["a"]) _ = cebra_plot.compare_models(models, labels=invalid_labels, ax=ax) diff --git a/tests/test_registry.py b/tests/test_registry.py index 69e04f38..cd27344c 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -117,7 +117,7 @@ def test_override(): _Foo1 = test_module.register("foo")(Foo) assert _Foo1 == Foo assert _Foo1 != Bar - assert f"foo" in test_module.get_options() + assert "foo" in test_module.get_options() # Check that the class was actually added to the module assert ( @@ -137,7 +137,7 @@ def test_override(): _Foo2 = test_module.register("foo", override=True)(Bar) assert _Foo2 != Foo assert _Foo2 == Bar - assert f"foo" in test_module.get_options() + assert "foo" in test_module.get_options() def test_depreciation(): @@ -145,7 +145,7 @@ def test_depreciation(): Foo = _make_class() _Foo1 = test_module.register("foo")(Foo) assert _Foo1 == Foo - assert f"foo" in test_module.get_options() + assert "foo" in test_module.get_options() # Registering the same class under different names # also raises and error diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index 33df3caf..8c7cd0a1 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -24,6 +24,7 @@ import warnings import _util +import _utils_deprecated import numpy as np import pkg_resources import pytest @@ -276,7 +277,6 @@ def test_api(estimator, check): pytest.skip(f"Model architecture {estimator.model_architecture} " f"requires longer input sizes than 20 samples.") - success = True exception = None num_successful = 0 total_runs = 0 @@ -334,7 +334,6 @@ def test_sklearn(model_architecture, device): y_c1 = np.random.uniform(0, 1, (1000, 5)) y_c1_s2 = 
np.random.uniform(0, 1, (800, 5)) y_c2 = np.random.uniform(0, 1, (1000, 2)) - y_c2_s2 = np.random.uniform(0, 1, (800, 2)) y_d = np.random.randint(0, 10, (1000,)) y_d_s2 = np.random.randint(0, 10, (800,)) @@ -863,7 +862,6 @@ def test_sklearn_full(model_architecture, device, pad_before_transform): X = np.random.uniform(0, 1, (1000, 50)) y_c1 = np.random.uniform(0, 1, (1000, 5)) y_c2 = np.random.uniform(0, 1, (1000, 2)) - y_d = np.random.randint(0, 10, (1000,)) # time contrastive cebra_model.fit(X) @@ -931,7 +929,7 @@ def test_sklearn_resampling_model_not_yet_supported(model_architecture, device): with pytest.raises(ValueError): cebra_model.fit(X, y_c1) - output = cebra_model.transform(X) + _ = cebra_model.transform(X) def _iterate_actions(): @@ -1378,18 +1376,16 @@ def test_new_transform(model_architecture, device): # example dataset X = np.random.uniform(0, 1, (1000, 50)) X_s2 = np.random.uniform(0, 1, (800, 30)) - X_s3 = np.random.uniform(0, 1, (1000, 30)) y_c1 = np.random.uniform(0, 1, (1000, 5)) y_c1_s2 = np.random.uniform(0, 1, (800, 5)) y_c2 = np.random.uniform(0, 1, (1000, 2)) - y_c2_s2 = np.random.uniform(0, 1, (800, 2)) y_d = np.random.randint(0, 10, (1000,)) y_d_s2 = np.random.randint(0, 10, (800,)) # time contrastive cebra_model.fit(X) embedding1 = cebra_model.transform(X) - embedding2 = cebra_model.transform_deprecated(X) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" @@ -1398,17 +1394,20 @@ def test_new_transform(model_architecture, device): assert cebra_model.num_sessions is None embedding1 = cebra_model.transform(X) - embedding2 = cebra_model.transform_deprecated(X) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(torch.Tensor(X)) - embedding2 = cebra_model.transform_deprecated(torch.Tensor(X)) + embedding2 = _utils_deprecated.cebra_transform_deprecated( + cebra_model, torch.Tensor(X)) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0) - embedding2 = cebra_model.transform_deprecated(torch.Tensor(X), session_id=0) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + torch.Tensor(X), + session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" @@ -1418,14 +1417,14 @@ def test_new_transform(model_architecture, device): # discrete behavior contrastive cebra_model.fit(X, y_d) embedding1 = cebra_model.transform(X) - embedding2 = cebra_model.transform_deprecated(X) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" # mixed cebra_model.fit(X, y_c1, y_c2, y_d) embedding1 = cebra_model.transform(X) - embedding2 = cebra_model.transform_deprecated(X) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" @@ -1433,17 +1432,23 @@ def test_new_transform(model_architecture, device): cebra_model.fit([X, X_s2], [y_d, y_d_s2]) embedding1 = cebra_model.transform(X, session_id=0) - embedding2 = cebra_model.transform_deprecated(X, session_id=0) + embedding2 = 
_utils_deprecated.cebra_transform_deprecated(cebra_model, + X, + session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0) - embedding2 = cebra_model.transform_deprecated(torch.Tensor(X), session_id=0) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + torch.Tensor(X), + session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(X_s2, session_id=1) - embedding2 = cebra_model.transform_deprecated(X_s2, session_id=1) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + X_s2, + session_id=1) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" @@ -1451,12 +1456,16 @@ def test_new_transform(model_architecture, device): cebra_model.fit([X, X_s2], [y_c1, y_c1_s2]) embedding1 = cebra_model.transform(X, session_id=0) - embedding2 = cebra_model.transform_deprecated(X, session_id=0) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + X, + session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0) - embedding2 = cebra_model.transform_deprecated(torch.Tensor(X), session_id=0) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + torch.Tensor(X), + session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" @@ -1475,17 +1484,23 @@ def test_new_transform(model_architecture, device): cebra_model.fit([X, X_s2, X], [y_d, y_d_s2, y_d]) embedding1 = cebra_model.transform(X, session_id=0) - embedding2 = cebra_model.transform_deprecated(X, session_id=0) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + X, + session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(X_s2, session_id=1) - embedding2 = cebra_model.transform_deprecated(X_s2, session_id=1) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + X_s2, + session_id=1) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(X, session_id=2) - embedding2 = cebra_model.transform_deprecated(X, session_id=2) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + X, + session_id=2) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" @@ -1493,25 +1508,31 @@ def test_new_transform(model_architecture, device): cebra_model.fit([X, X_s2, X], [y_c1, y_c1_s2, y_c1]) embedding1 = cebra_model.transform(X, session_id=0) - embedding2 = cebra_model.transform_deprecated(X, session_id=0) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + X, + session_id=0) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(X_s2, session_id=1) - embedding2 = cebra_model.transform_deprecated(X_s2, session_id=1) + embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, + X_s2, + session_id=1) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" embedding1 = cebra_model.transform(X, session_id=2) - embedding2 = cebra_model.transform_deprecated(X, session_id=2) + embedding2 
= _utils_deprecated.cebra_transform_deprecated(cebra_model, + X, + session_id=2) assert np.allclose(embedding1, embedding2, rtol=1e-5, atol=1e-8), "Arrays are not close enough" def test_last_incomplete_batch_smaller_than_offset(): """ - When offset of the model is larger than the remaining samples in the - last batch, an error could happen. We merge the penultimate + When offset of the model is larger than the remaining samples in the + last batch, an error could happen. We merge the penultimate and last batches together to avoid this. """ train = cebra.data.TensorDataset(neural=np.random.rand(20111, 100), @@ -1522,4 +1543,4 @@ def test_last_incomplete_batch_smaller_than_offset(): device="cpu") model.fit(train.neural, train.continuous) - _ = model.transform(train.neural, batch_size=300) \ No newline at end of file + _ = model.transform(train.neural, batch_size=300) diff --git a/tests/test_sklearn_legacy.py b/tests/test_sklearn_legacy.py new file mode 100644 index 00000000..4d74515f --- /dev/null +++ b/tests/test_sklearn_legacy.py @@ -0,0 +1,41 @@ +import pathlib +import urllib.request + +import numpy as np +import pytest + +from cebra.integrations.sklearn.cebra import CEBRA + +MODEL_VARIANTS = [ + "cebra-0.4.0-scikit-learn-1.4", "cebra-0.4.0-scikit-learn-1.6", + "cebra-rc-scikit-learn-1.4", "cebra-rc-scikit-learn-1.6" +] + + +@pytest.mark.parametrize("model_variant", MODEL_VARIANTS) +def test_load_legacy_model(model_variant): + """Test loading a legacy CEBRA model.""" + + X = np.random.normal(0, 1, (1000, 30)) + + model_path = pathlib.Path( + __file__ + ).parent / "_build_legacy_model" / f"cebra_model_{model_variant}.pt" + + if not model_path.exists(): + url = f"https://cebra.fra1.digitaloceanspaces.com/cebra_model_{model_variant}.pt" + urllib.request.urlretrieve(url, model_path) + + loaded_model = CEBRA.load(model_path) + + assert loaded_model.model_architecture == "offset10-model" + assert loaded_model.output_dimension == 8 + assert loaded_model.num_hidden_units == 16 + assert loaded_model.time_offsets == 10 + + output = loaded_model.transform(X) + assert isinstance(output, np.ndarray) + assert output.shape[1] == loaded_model.output_dimension + + assert hasattr(loaded_model, "state_dict_") + assert hasattr(loaded_model, "n_features_") diff --git a/tests/test_sklearn_metrics.py b/tests/test_sklearn_metrics.py index 58e12010..4e765ba7 100644 --- a/tests/test_sklearn_metrics.py +++ b/tests/test_sklearn_metrics.py @@ -383,3 +383,132 @@ def test_sklearn_runs_consistency(): with pytest.raises(ValueError, match="Invalid.*embeddings"): _, _, _ = cebra_sklearn_metrics.consistency_score( invalid_embeddings_runs, between="runs") + + +@pytest.mark.parametrize("seed", [42, 24, 10]) +def test_goodness_of_fit_score(seed): + """ + Ensure that the GoF score is close to 0 for a model fit on random data. 
+ """ + cebra_model = cebra_sklearn_cebra.CEBRA( + model_architecture="offset1-model", + max_iterations=5, + batch_size=512, + ) + generator = torch.Generator().manual_seed(seed) + X = torch.rand(5000, 50, dtype=torch.float32, generator=generator) + y = torch.rand(5000, 5, dtype=torch.float32, generator=generator) + cebra_model.fit(X, y) + score = cebra_sklearn_metrics.goodness_of_fit_score(cebra_model, + X, + y, + session_id=0, + num_batches=500) + assert isinstance(score, float) + assert np.isclose(score, 0, atol=0.01) + + +@pytest.mark.parametrize("seed", [42, 24, 10]) +def test_goodness_of_fit_history(seed): + """ + Ensure that the GoF score is higher for a model fit on data with underlying + structure than for a model fit on random data. + """ + + # Generate data + generator = torch.Generator().manual_seed(seed) + X = torch.rand(1000, 50, dtype=torch.float32, generator=generator) + y_random = torch.rand(len(X), 5, dtype=torch.float32, generator=generator) + linear_map = torch.randn(50, 5, dtype=torch.float32, generator=generator) + y_linear = X @ linear_map + + def _fit_and_get_history(X, y): + cebra_model = cebra_sklearn_cebra.CEBRA( + model_architecture="offset1-model", + max_iterations=150, + batch_size=512, + device="cpu") + cebra_model.fit(X, y) + history = cebra_sklearn_metrics.goodness_of_fit_history(cebra_model) + # NOTE(stes): Ignore the first 5 iterations, they can have nonsensical values + # due to numerical issues. + return history[5:] + + history_random = _fit_and_get_history(X, y_random) + history_linear = _fit_and_get_history(X, y_linear) + + assert isinstance(history_random, np.ndarray) + assert history_random.shape[0] > 0 + # NOTE(stes): Ignore the first 5 iterations, they can have nonsensical values + # due to numerical issues. 
+ history_random_non_negative = history_random[history_random >= 0] + np.testing.assert_allclose(history_random_non_negative, 0, atol=0.075) + + assert isinstance(history_linear, np.ndarray) + assert history_linear.shape[0] > 0 + + assert np.all(history_linear[-20:] > history_random[-20:]) + + +@pytest.mark.parametrize("seed", [42, 24, 10]) +def test_infonce_to_goodness_of_fit(seed): + """Test the conversion from InfoNCE loss to goodness of fit metric.""" + # Test with model + cebra_model = cebra_sklearn_cebra.CEBRA( + model_architecture="offset10-model", + max_iterations=5, + batch_size=128, + ) + generator = torch.Generator().manual_seed(seed) + X = torch.rand(1000, 50, dtype=torch.float32, generator=generator) + cebra_model.fit(X) + + # Test single value + gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, + model=cebra_model) + assert isinstance(gof, float) + + # Test array of values + infonce_values = np.array([1.0, 2.0, 3.0]) + gof_array = cebra_sklearn_metrics.infonce_to_goodness_of_fit( + infonce_values, model=cebra_model) + assert isinstance(gof_array, np.ndarray) + assert gof_array.shape == infonce_values.shape + + # Test with explicit batch_size and num_sessions + gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, + batch_size=128, + num_sessions=1) + assert isinstance(gof, float) + + # Test error cases + with pytest.raises(ValueError, match="batch_size.*should not be provided"): + cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, + model=cebra_model, + batch_size=128) + + with pytest.raises(ValueError, match="batch_size.*should not be provided"): + cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, + model=cebra_model, + num_sessions=1) + + # Test with unfitted model + unfitted_model = cebra_sklearn_cebra.CEBRA(max_iterations=5) + with pytest.raises(RuntimeError, match="Fit the CEBRA model first"): + cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, + model=unfitted_model) + + # Test with model having batch_size=None + none_batch_model = cebra_sklearn_cebra.CEBRA(batch_size=None, + max_iterations=5) + none_batch_model.fit(X) + with pytest.raises(ValueError, match="Computing the goodness of fit"): + cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, + model=none_batch_model) + + # Test missing batch_size or num_sessions when model is None + with pytest.raises(ValueError, match="batch_size.*and num_sessions"): + cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, batch_size=128) + + with pytest.raises(ValueError, match="batch_size.*and num_sessions"): + cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, num_sessions=1) diff --git a/tests/test_solver.py b/tests/test_solver.py index 68e2a43e..5cbbbfb3 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -34,41 +34,12 @@ device = "cpu" -single_session_tests = [] -for args in [ - ("demo-discrete", cebra.data.DiscreteDataLoader, "offset10-model"), - ("demo-discrete", cebra.data.DiscreteDataLoader, "offset1-model"), - ("demo-discrete", cebra.data.DiscreteDataLoader, "offset1-model"), - ("demo-discrete", cebra.data.DiscreteDataLoader, "offset10-model"), - ("demo-continuous", cebra.data.ContinuousDataLoader, "offset10-model"), - ("demo-continuous", cebra.data.ContinuousDataLoader, "offset1-model"), - ("demo-mixed", cebra.data.MixedDataLoader, "offset10-model"), - ("demo-mixed", cebra.data.MixedDataLoader, "offset1-model"), -]: - single_session_tests.append((*args, cebra.solver.SingleSessionSolver)) - -single_session_hybrid_tests = [] -for args in [("demo-continuous", 
cebra.data.HybridDataLoader, "offset10-model"), - ("demo-continuous", cebra.data.HybridDataLoader, "offset1-model")]: - single_session_hybrid_tests.append( - (*args, cebra.solver.SingleSessionHybridSolver)) - -multi_session_tests = [] -for args in [ - ("demo-continuous-multisession", - cebra.data.ContinuousMultiSessionDataLoader, "offset1-model"), - ("demo-continuous-multisession", - cebra.data.ContinuousMultiSessionDataLoader, "offset10-model"), -]: - multi_session_tests.append((*args, cebra.solver.MultiSessionSolver)) - -# multi_session_tests.append((*args, cebra.solver.MultiSessionAuxVariableSolver)) - - -def _get_loader(data, loader_initfunc): - kwargs = dict(num_steps=5, batch_size=32) + +def _get_loader(data_name, loader_initfunc): + data = cebra.datasets.init(data_name) + kwargs = dict(num_steps=2, batch_size=32) loader = loader_initfunc(data, **kwargs) - return loader + return loader, data OUTPUT_DIMENSION = 3 @@ -84,12 +55,12 @@ def _make_model(dataset, model_architecture="offset10-model"): OUTPUT_DIMENSION) -# def _make_behavior_model(dataset): -# # TODO flexible input dimension -# return nn.Sequential( -# nn.Conv1d(dataset.input_dimension, 5, kernel_size=10), -# nn.Flatten(start_dim=1, end_dim=-1), -# ) +def _make_behavior_model(dataset): + # TODO flexible input dimension + return nn.Sequential( + nn.Conv1d(dataset.input_dimension, 5, kernel_size=10), + nn.Flatten(start_dim=1, end_dim=-1), + ) def _assert_same_state_dict(first, second): @@ -135,12 +106,16 @@ def _assert_equal(original_solver, loaded_solver): @pytest.mark.parametrize( - "data_name, loader_initfunc, model_architecture, solver_initfunc", - single_session_tests) + "data_name, model_architecture, loader_initfunc, solver_initfunc", + [(dataset, model, loader, cebra.solver.SingleSessionSolver) + for dataset, loader in [("demo-discrete", cebra.data.DiscreteDataLoader), + ("demo-continuous", cebra.data.ContinuousDataLoader + ), ("demo-mixed", cebra.data.MixedDataLoader)] + for model in + ["offset1-model", "offset10-model", "offset40-model-4x-subsample"]]) def test_single_session(data_name, loader_initfunc, model_architecture, solver_initfunc): - data = cebra.datasets.init(data_name) - loader = _get_loader(data, loader_initfunc) + loader, data = _get_loader(data_name, loader_initfunc) model = _make_model(data, model_architecture) data.configure_for(model) offset = model.get_offset() @@ -163,21 +138,84 @@ def test_single_session(data_name, loader_initfunc, model_architecture, solver.fit(loader) - assert solver.num_sessions == None + assert solver.num_sessions is None assert solver.n_features == X.shape[1] embedding = solver.transform(X) assert isinstance(embedding, torch.Tensor) - assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) + if isinstance(solver.model, cebra.models.ResampleModelMixin): + assert embedding.shape == (X.shape[0] // solver.model.resample_factor, + OUTPUT_DIMENSION) + else: + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) embedding = solver.transform(torch.Tensor(X)) assert isinstance(embedding, torch.Tensor) - assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) + if isinstance(solver.model, cebra.models.ResampleModelMixin): + assert embedding.shape == (X.shape[0] // solver.model.resample_factor, + OUTPUT_DIMENSION) + else: + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) embedding = solver.transform(X, session_id=0) assert isinstance(embedding, torch.Tensor) - assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) + if isinstance(solver.model, 
cebra.models.ResampleModelMixin): + assert embedding.shape == (X.shape[0] // solver.model.resample_factor, + OUTPUT_DIMENSION) + else: + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) embedding = solver.transform(X, pad_before_transform=False) assert isinstance(embedding, torch.Tensor) - assert embedding.shape == (X.shape[0] - len(offset) + 1, OUTPUT_DIMENSION) + if isinstance(solver.model, cebra.models.ResampleModelMixin): + assert embedding.shape == ( + (X.shape[0] - len(offset)) // solver.model.resample_factor + 1, + OUTPUT_DIMENSION) + else: + assert embedding.shape == (X.shape[0] - len(offset) + 1, + OUTPUT_DIMENSION) + + with pytest.raises(ValueError, match="torch.Tensor"): + solver.transform(X.numpy()) + with pytest.raises(RuntimeError, match="Invalid.*session_id"): + embedding = solver.transform(X, session_id=2) + + for param in solver.parameters(): + assert isinstance(param, torch.Tensor) + + fitted_solver = copy.deepcopy(solver) + with tempfile.TemporaryDirectory() as temp_dir: + solver.save(temp_dir) + solver.load(temp_dir) + _assert_equal(fitted_solver, solver) + + embedding = solver.transform(X) + assert isinstance(embedding, torch.Tensor) + if isinstance(solver.model, cebra.models.ResampleModelMixin): + assert embedding.shape == (X.shape[0] // solver.model.resample_factor, + OUTPUT_DIMENSION) + else: + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(torch.Tensor(X)) + assert isinstance(embedding, torch.Tensor) + if isinstance(solver.model, cebra.models.ResampleModelMixin): + assert embedding.shape == (X.shape[0] // solver.model.resample_factor, + OUTPUT_DIMENSION) + else: + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(X, session_id=0) + assert isinstance(embedding, torch.Tensor) + if isinstance(solver.model, cebra.models.ResampleModelMixin): + assert embedding.shape == (X.shape[0] // solver.model.resample_factor, + OUTPUT_DIMENSION) + else: + assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION) + embedding = solver.transform(X, pad_before_transform=False) + assert isinstance(embedding, torch.Tensor) + if isinstance(solver.model, cebra.models.ResampleModelMixin): + assert embedding.shape == ( + (X.shape[0] - len(offset)) // solver.model.resample_factor + 1, + OUTPUT_DIMENSION) + else: + assert embedding.shape == (X.shape[0] - len(offset) + 1, + OUTPUT_DIMENSION) with pytest.raises(ValueError, match="torch.Tensor"): solver.transform(X.numpy()) @@ -195,15 +233,21 @@ def test_single_session(data_name, loader_initfunc, model_architecture, @pytest.mark.parametrize( - "data_name, loader_initfunc, model_architecture, solver_initfunc", - single_session_tests) + "data_name, model_architecture, loader_initfunc, solver_initfunc", + [(dataset, model, loader, cebra.solver.SingleSessionSolver) + for dataset, loader in [("demo-discrete", cebra.data.DiscreteDataLoader), + ("demo-continuous", cebra.data.ContinuousDataLoader + ), ("demo-mixed", cebra.data.MixedDataLoader)] + for model in + ["offset1-model", "offset10-model", "offset40-model-4x-subsample"]]) def test_single_session_auxvar(data_name, loader_initfunc, model_architecture, solver_initfunc): - return # TODO + + pytest.skip("Not yet supported") loader = _get_loader(data_name, loader_initfunc) model = _make_model(loader.dataset) - behavior_model = _make_behavior_model(loader.dataset) + behavior_model = _make_behavior_model(loader.dataset) # noqa: F841 criterion = cebra.models.InfoNCE() optimizer = torch.optim.Adam(model.parameters(), 
lr=1e-3) @@ -223,12 +267,13 @@ def test_single_session_auxvar(data_name, loader_initfunc, model_architecture, @pytest.mark.parametrize( - "data_name, loader_initfunc, model_architecture, solver_initfunc", - single_session_hybrid_tests) + "data_name, model_architecture, loader_initfunc, solver_initfunc", + [("demo-continuous", model, cebra.data.HybridDataLoader, + cebra.solver.SingleSessionHybridSolver) + for model in ["offset1-model", "offset10-model"]]) def test_single_session_hybrid(data_name, loader_initfunc, model_architecture, solver_initfunc): - data = cebra.datasets.init(data_name) - loader = _get_loader(data, loader_initfunc) + loader, data = _get_loader(data_name, loader_initfunc) model = _make_model(data, model_architecture) data.configure_for(model) offset = model.get_offset() @@ -250,7 +295,7 @@ def test_single_session_hybrid(data_name, loader_initfunc, model_architecture, solver.fit(loader) - assert solver.num_sessions == None + assert solver.num_sessions is None assert solver.n_features == X.shape[1] embedding = solver.transform(X) @@ -282,17 +327,25 @@ def test_single_session_hybrid(data_name, loader_initfunc, model_architecture, @pytest.mark.parametrize( - "data_name, loader_initfunc, model_architecture, solver_initfunc", - multi_session_tests) + "data_name, model_architecture, loader_initfunc, solver_initfunc", + [(dataset, model, loader, cebra.solver.MultiSessionSolver) + for dataset, loader in [ + ("demo-discrete-multisession", + cebra.data.DiscreteMultiSessionDataLoader), + ("demo-continuous-multisession", + cebra.data.ContinuousMultiSessionDataLoader), + ] + for model in ["offset1-model", "offset10-model"]]) def test_multi_session(data_name, loader_initfunc, model_architecture, solver_initfunc): - data = cebra.datasets.init(data_name) - loader = _get_loader(data, loader_initfunc) + loader, data = _get_loader(data_name, loader_initfunc) model = nn.ModuleList([ _make_model(dataset, model_architecture) for dataset in data.iter_sessions() ]) data.configure_for(model) + offset_length = len(model[0].get_offset()) + criterion = cebra.models.InfoNCE() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) @@ -302,8 +355,9 @@ def test_multi_session(data_name, loader_initfunc, model_architecture, batch = next(iter(loader)) for session_id, dataset in enumerate(loader.dataset.iter_sessions()): - assert batch[session_id].reference.shape[:2] == ( - 32, dataset.input_dimension) + assert batch[session_id].reference.shape == (32, + dataset.input_dimension, + offset_length) assert batch[session_id].index is not None log = solver.step(batch) @@ -360,267 +414,28 @@ def test_multi_session(data_name, loader_initfunc, model_architecture, _assert_equal(fitted_solver, solver) -@pytest.mark.parametrize( - "inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output", - [ - # Test case 1: No padding - (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset( - 0, 1), 0, 2, torch.tensor([[1, 2], [3, 4]])), # first batch - (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset( - 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # last batch - (torch.tensor( - [[1, 2], [3, 4], [5, 6], [7, 8]]), False, cebra.data.Offset( - 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # middle batch - - # Test case 2: First batch with padding - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), - True, - cebra.data.Offset(0, 1), - 0, - 2, - torch.tensor([[1, 2, 3], [4, 5, 6]]), - ), - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), - True, - cebra.data.Offset(1, 1), - 0, - 
3, - torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]]), - ), - - # Test case 3: Last batch with padding - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), - True, - cebra.data.Offset(0, 1), - 1, - 3, - torch.tensor([[4, 5, 6], [7, 8, 9]]), - ), - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], - [13, 14, 15]]), - True, - cebra.data.Offset(1, 2), - 1, - 3, - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), - ), - - # Test case 4: Middle batch with padding - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), - True, - cebra.data.Offset(0, 1), - 1, - 3, - torch.tensor([[4, 5, 6], [7, 8, 9]]), - ), - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), - True, - cebra.data.Offset(1, 1), - 1, - 3, - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), - ), - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], - [13, 14, 15]]), - True, - cebra.data.Offset(0, 1), - 2, - 4, - torch.tensor([[7, 8, 9], [10, 11, 12]]), - ), - ( - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), - True, - cebra.data.Offset(0, 1), - 0, - 3, - torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), - ), - - # Examples that throw an error: - - # Padding without offset (should raise an error) - (torch.tensor([[1, 2]]), True, None, 0, 2, ValueError), - # Negative start_batch_idx or end_batch_idx (should raise an error) - (torch.tensor([[1, 2]]), False, cebra.data.Offset( - 0, 1), -1, 2, ValueError), - # out of bound indices because offset is too large - (torch.tensor([[1, 2], [3, 4]]), True, cebra.data.Offset( - 5, 5), 1, 2, ValueError), - # Batch length is smaller than offset. - (torch.tensor([[1, 2], [3, 4]]), False, cebra.data.Offset( - 0, 1), 0, 1, ValueError), # first batch - ], -) -def test_get_batch(inputs, add_padding, offset, start_batch_idx, end_batch_idx, - expected_output): - if expected_output == ValueError: - with pytest.raises(ValueError): - cebra.solver.base._get_batch(inputs, offset, start_batch_idx, - end_batch_idx, add_padding) - else: - result = cebra.solver.base._get_batch(inputs, offset, start_batch_idx, - end_batch_idx, add_padding) - assert torch.equal(result, expected_output) - - -def create_model(model_name, input_dimension): - return cebra.models.init(model_name, - num_neurons=input_dimension, - num_units=128, - num_output=OUTPUT_DIMENSION) - - -single_session_tests_select_model = [] -single_session_hybrid_tests_select_model = [] -for model_name in ["offset1-model", "offset10-model"]: - for session_id in [None, 0, 5]: - for args in [ - ("demo-discrete", model_name, session_id, - cebra.data.DiscreteDataLoader), - ("demo-continuous", model_name, session_id, - cebra.data.ContinuousDataLoader), - ("demo-mixed", model_name, session_id, cebra.data.MixedDataLoader), - ]: - single_session_tests_select_model.append( - (*args, cebra.solver.SingleSessionSolver)) - single_session_hybrid_tests_select_model.append( - (*args, cebra.solver.SingleSessionHybridSolver)) - -multi_session_tests_select_model = [] -for model_name in ["offset10-model"]: - for session_id in [None, 0, 1, 5, 2, 6, 4]: - for args in [("demo-continuous-multisession", model_name, session_id, - cebra.data.ContinuousMultiSessionDataLoader)]: - multi_session_tests_select_model.append( - (*args, cebra.solver.MultiSessionSolver)) +def _make_val_data(dataset): + if isinstance(dataset, cebra.datasets.demo.DemoDataset): + return dataset.neural + elif isinstance(dataset, cebra.datasets.demo.DemoDatasetUnified): + return [session.neural for session in 
dataset.iter_sessions()], [ + session.continuous_index for session in dataset.iter_sessions() + ] @pytest.mark.parametrize( - "data_name, model_name ,session_id, loader_initfunc, solver_initfunc", - single_session_tests_select_model + single_session_hybrid_tests_select_model -) -def test_select_model_single_session(data_name, model_name, session_id, - loader_initfunc, solver_initfunc): - dataset = cebra.datasets.init(data_name) - model = create_model(model_name, dataset.input_dimension) - dataset.configure_for(model) - loader = _get_loader(dataset, loader_initfunc=loader_initfunc) + "data_name, model_architecture, loader_initfunc, solver_initfunc", + [(dataset, model, loader, cebra.solver.UnifiedSolver) + for dataset, loader in [ + ("demo-continuous-unified", cebra.data.UnifiedLoader), + ] + for model in ["offset1-model", "offset10-model"]]) +def test_unified_session(data_name, model_architecture, loader_initfunc, + solver_initfunc): + loader, data = _get_loader(data_name, loader_initfunc) + model = _make_model(data, model_architecture) + data.configure_for(model) offset = model.get_offset() - solver = solver_initfunc(model=model, criterion=None, optimizer=None) - - with pytest.raises(ValueError): - solver.n_features = 1000 - solver._select_model(inputs=dataset.neural, session_id=0) - - solver.n_features = dataset.neural.shape[1] - if session_id is not None and session_id > 0: - with pytest.raises(RuntimeError): - solver._select_model(inputs=dataset.neural, session_id=session_id) - else: - model_, offset_ = solver._select_model(inputs=dataset.neural, - session_id=session_id) - assert offset.left == offset_.left and offset.right == offset_.right - assert model == model_ - - -@pytest.mark.parametrize( - "data_name, model_name, session_id, loader_initfunc, solver_initfunc", - multi_session_tests_select_model) -def test_select_model_multi_session(data_name, model_name, session_id, - loader_initfunc, solver_initfunc): - dataset = cebra.datasets.init(data_name) - model = nn.ModuleList([ - create_model(model_name, dataset.input_dimension) - for dataset in dataset.iter_sessions() - ]) - dataset.configure_for(model) - loader = _get_loader(dataset, loader_initfunc=loader_initfunc) - - offset = model[0].get_offset() - solver = solver_initfunc(model=model, - criterion=cebra.models.InfoNCE(), - optimizer=torch.optim.Adam(model.parameters(), - lr=1e-3)) - - loader_kwargs = dict(num_steps=10, batch_size=32) - loader = cebra.data.ContinuousMultiSessionDataLoader( - dataset, **loader_kwargs) - solver.fit(loader) - - for i, (model, dataset_) in enumerate(zip(model, dataset.iter_sessions())): - inputs = dataset_.neural - - if session_id is None or session_id >= dataset.num_sessions: - with pytest.raises(RuntimeError): - solver._select_model(inputs, session_id=session_id) - elif i != session_id: - with pytest.raises(ValueError): - solver._select_model(inputs, session_id=session_id) - else: - model_, offset_ = solver._select_model(inputs, - session_id=session_id) - assert offset.left == offset_.left and offset.right == offset_.right - assert model == model_ - - -models = [ - "offset1-model", - "offset10-model", - "offset40-model-4x-subsample", - "offset1-model", - "offset10-model", -] -batch_size_inference = [40_000, 99_990, 99_999] - -single_session_tests_transform = [] -for padding in [True, False]: - for model_name in models: - for batch_size in batch_size_inference: - for args in [ - ("demo-discrete", model_name, padding, batch_size, - cebra.data.DiscreteDataLoader), - ("demo-continuous", model_name, 
padding, batch_size, - cebra.data.ContinuousDataLoader), - ("demo-mixed", model_name, padding, batch_size, - cebra.data.MixedDataLoader), - ]: - single_session_tests_transform.append( - (*args, cebra.solver.SingleSessionSolver)) - -single_session_hybrid_tests_transform = [] -for padding in [True, False]: - for model_name in models: - for batch_size in batch_size_inference: - for args in [("demo-continuous", model_name, padding, batch_size, - cebra.data.HybridDataLoader)]: - single_session_hybrid_tests_transform.append( - (*args, cebra.solver.SingleSessionHybridSolver)) - - -@pytest.mark.parametrize( - "data_name, model_name, padding, batch_size_inference, loader_initfunc, solver_initfunc", - single_session_tests_transform + single_session_hybrid_tests_transform) -def test_batched_transform_single_session( - data_name, - model_name, - padding, - batch_size_inference, - loader_initfunc, - solver_initfunc, -): - dataset = cebra.datasets.init(data_name) - model = create_model(model_name, dataset.input_dimension) - dataset.offset = model.get_offset() - loader_kwargs = dict(num_steps=10, batch_size=32) - loader = loader_initfunc(dataset, **loader_kwargs) criterion = cebra.models.InfoNCE() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) @@ -628,95 +443,24 @@ def test_batched_transform_single_session( solver = solver_initfunc(model=model, criterion=criterion, optimizer=optimizer) - solver.fit(loader) - - smallest_batch_length = loader.dataset.neural.shape[0] - batch_size - offset_ = model.get_offset() - padding_left = offset_.left if padding else 0 - - if smallest_batch_length <= len(offset_): - with pytest.raises(ValueError): - solver.transform(inputs=loader.dataset.neural, - batch_size=batch_size, - pad_before_transform=padding) - else: - embedding_batched = solver.transform(inputs=loader.dataset.neural, - batch_size=batch_size, - pad_before_transform=padding) - - embedding = solver.transform(inputs=loader.dataset.neural, - pad_before_transform=padding) - - assert embedding_batched.shape == embedding.shape - assert np.allclose(embedding_batched, embedding, rtol=1e-02) - - -multi_session_tests_transform = [] -for padding in [True, False]: - for model_name in models: - for batch_size in batch_size_inference: - for args in [ - ("demo-continuous-multisession", model_name, padding, - batch_size, cebra.data.ContinuousMultiSessionDataLoader) - ]: - multi_session_tests_transform.append( - (*args, cebra.solver.MultiSessionSolver)) - - -@pytest.mark.parametrize( - "data_name, model_name,padding,batch_size_inference,loader_initfunc, solver_initfunc", - multi_session_tests_transform) -def test_batched_transform_multi_session(data_name, model_name, padding, - batch_size_inference, loader_initfunc, - solver_initfunc): - dataset = cebra.datasets.init(data_name) - model = nn.ModuleList([ - create_model(model_name, dataset.input_dimension) - for dataset in dataset.iter_sessions() - ]) - dataset.offset = model[0].get_offset() - - n_samples = dataset._datasets[0].neural.shape[0] - assert all( - d.neural.shape[0] == n_samples for d in dataset._datasets - ), "for this set all of the sessions need to have same number of samples." 
- - smallest_batch_length = n_samples - batch_size - offset_ = model[0].get_offset() - padding_left = offset_.left if padding else 0 - for d in dataset._datasets: - d.offset = offset_ - loader_kwargs = dict(num_steps=10, batch_size=32) - loader = loader_initfunc(dataset, **loader_kwargs) + batch = next(iter(loader)) + assert batch.reference.shape == (32, loader.dataset.input_dimension, + len(offset)) - criterion = cebra.models.InfoNCE() - optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + log = solver.step(batch) + assert isinstance(log, dict) - solver = solver_initfunc(model=model, - criterion=criterion, - optimizer=optimizer) solver.fit(loader) + data, labels = _make_val_data(loader.dataset) - # Transform each session with the right model, by providing the corresponding session ID - for i, inputs in enumerate(dataset.iter_sessions()): + assert solver.num_sessions == 3 + assert solver.n_features == sum( + [data[i].shape[1] for i in range(len(data))]) - if smallest_batch_length <= len(offset_): - with pytest.raises(ValueError): - solver.transform(inputs=inputs.neural, - batch_size=batch_size, - session_id=i, - pad_before_transform=padding) + for i in range(loader.dataset.num_sessions): + emb = solver.transform(data, labels, session_id=i) + assert emb.shape == (loader.dataset.num_timepoints, 3) - else: - model_ = model[i] - embedding = solver.transform(inputs=inputs.neural, - session_id=i, - pad_before_transform=padding) - embedding_batched = solver.transform(inputs=inputs.neural, - session_id=i, - pad_before_transform=padding, - batch_size=batch_size) - - assert embedding_batched.shape == embedding.shape - assert np.allclose(embedding_batched, embedding, rtol=1e-02) + emb = solver.transform(data, labels, session_id=i, batch_size=300) + assert emb.shape == (loader.dataset.num_timepoints, 3) diff --git a/tests/test_solver_batched.py b/tests/test_solver_batched.py new file mode 100644 index 00000000..8592aea2 --- /dev/null +++ b/tests/test_solver_batched.py @@ -0,0 +1,343 @@ +# +# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables +# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+) +# Source code: +# https://github.com/AdaptiveMotorControlLab/CEBRA +# +# Please see LICENSE.md for the full license document: +# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# +import numpy as np +import pytest +import torch +from torch import nn + +import cebra.data +import cebra.datasets +import cebra.models +import cebra.solver + +device = "cpu" + +NUM_STEPS = 10 +BATCHES = [25_000, 50_000, 75_000] +MODELS = ["offset1-model", "offset10-model", "offset40-model-4x-subsample"] + + +@pytest.mark.parametrize( + "inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output", + [ + # Test case 1: No padding + (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset( + 0, 1), 0, 2, torch.tensor([[1, 2], [3, 4]])), # first batch + (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset( + 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # last batch + (torch.tensor( + [[1, 2], [3, 4], [5, 6], [7, 8]]), False, cebra.data.Offset( + 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # middle batch + + # Test case 2: First batch with padding + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(0, 1), + 0, + 2, + torch.tensor([[1, 2, 3], [4, 5, 6]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(1, 1), + 0, + 3, + torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]]), + ), + + # Test case 3: Last batch with padding + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + True, + cebra.data.Offset(0, 1), + 1, + 3, + torch.tensor([[4, 5, 6], [7, 8, 9]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], + [13, 14, 15]]), + True, + cebra.data.Offset(1, 2), + 1, + 3, + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + ), + + # Test case 4: Middle batch with padding + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + True, + cebra.data.Offset(0, 1), + 1, + 3, + torch.tensor([[4, 5, 6], [7, 8, 9]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + True, + cebra.data.Offset(1, 1), + 1, + 3, + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], + [13, 14, 15]]), + True, + cebra.data.Offset(0, 1), + 2, + 4, + torch.tensor([[7, 8, 9], [10, 11, 12]]), + ), + ( + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]), + True, + cebra.data.Offset(0, 1), + 0, + 3, + torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]), + ), + # Padding without offset (should raise an error) + (torch.tensor([[1, 2]]), True, None, 0, 2, ValueError), + # Negative start_batch_idx or end_batch_idx (should raise an error) + (torch.tensor([[1, 2]]), False, cebra.data.Offset( + 0, 1), -1, 2, ValueError), + # out of bound indices because offset is too large + (torch.tensor([[1, 2], [3, 4]]), True, cebra.data.Offset( + 5, 5), 1, 2, ValueError), + # Batch length is smaller than offset. 
+ (torch.tensor([[1, 2], [3, 4]]), False, cebra.data.Offset( + 0, 1), 0, 1, ValueError), # first batch + ], +) +def test_get_batch(inputs, add_padding, offset, start_batch_idx, end_batch_idx, + expected_output): + if expected_output == ValueError: + with pytest.raises(ValueError): + cebra.solver.base._get_batch(inputs, offset, start_batch_idx, + end_batch_idx, add_padding) + else: + result = cebra.solver.base._get_batch(inputs, offset, start_batch_idx, + end_batch_idx, add_padding) + assert torch.equal(result, expected_output) + + +def create_model(model_name, input_dimension): + return cebra.models.init(model_name, + num_neurons=input_dimension, + num_units=128, + num_output=3) + + +@pytest.mark.parametrize( + "data_name, model_name, session_id, loader_initfunc, solver_initfunc", + [(dataset, model, session_id, loader, cebra.solver.SingleSessionSolver) + for dataset, loader in [("demo-discrete", cebra.data.DiscreteDataLoader), + ("demo-continuous", cebra.data.ContinuousDataLoader + ), ("demo-mixed", cebra.data.MixedDataLoader)] + for model in ["offset1-model", "offset10-model"] + for session_id in [None, 0, 5]] + + [(dataset, model, session_id, loader, + cebra.solver.SingleSessionHybridSolver) + for dataset, loader in [ + ("demo-continuous", cebra.data.HybridDataLoader), + ] + for model in ["offset1-model", "offset10-model"] + for session_id in [None, 0, 5]]) +def test_select_model_single_session(data_name, model_name, session_id, + loader_initfunc, solver_initfunc): + dataset = cebra.datasets.init(data_name) + model = create_model(model_name, dataset.input_dimension) + dataset.configure_for(model) + offset = model.get_offset() + solver = solver_initfunc(model=model, criterion=None, optimizer=None) + + with pytest.raises(ValueError): + solver.n_features = 1000 + solver._select_model(inputs=dataset.neural, session_id=0) + + solver.n_features = dataset.neural.shape[1] + if session_id is not None and session_id > 0: + with pytest.raises(RuntimeError): + solver._select_model(inputs=dataset.neural, session_id=session_id) + else: + model_, offset_ = solver._select_model(inputs=dataset.neural, + session_id=session_id) + assert offset.left == offset_.left and offset.right == offset_.right + assert model == model_ + + +@pytest.mark.parametrize( + "data_name, model_name, session_id, loader_initfunc, solver_initfunc", + [(dataset, model, session_id, loader, cebra.solver.MultiSessionSolver) + for dataset, loader in [ + ("demo-continuous-multisession", + cebra.data.ContinuousMultiSessionDataLoader), + ] + for model in ["offset1-model", "offset10-model"] + for session_id in [None, 0, 1, 5, 2, 6, 4]]) +def test_select_model_multi_session(data_name, model_name, session_id, + loader_initfunc, solver_initfunc): + + dataset = cebra.datasets.init(data_name) + kwargs = dict(num_steps=NUM_STEPS, batch_size=32) + loader = loader_initfunc(dataset, **kwargs) + + model = nn.ModuleList([ + create_model(model_name, dataset.input_dimension) + for dataset in dataset.iter_sessions() + ]) + dataset.configure_for(model) + + offset = model[0].get_offset() + solver = solver_initfunc(model=model, + criterion=cebra.models.InfoNCE(), + optimizer=torch.optim.Adam(model.parameters(), + lr=1e-3)) + + loader_kwargs = dict(num_steps=NUM_STEPS, batch_size=32) + loader = cebra.data.ContinuousMultiSessionDataLoader( + dataset, **loader_kwargs) + solver.fit(loader) + + for i, (model, dataset_) in enumerate(zip(model, dataset.iter_sessions())): + inputs = dataset_.neural + + if session_id is None or session_id >= dataset.num_sessions: 
+ with pytest.raises(RuntimeError): + solver._select_model(inputs, session_id=session_id) + elif i != session_id: + with pytest.raises(ValueError): + solver._select_model(inputs, session_id=session_id) + else: + model_, offset_ = solver._select_model(inputs, + session_id=session_id) + assert offset.left == offset_.left and offset.right == offset_.right + assert model == model_ + + +@pytest.mark.parametrize( + "data_name, model_name, padding, batch_size_inference, loader_initfunc, solver_initfunc", + [(dataset, model, padding, batch_size, loader, + cebra.solver.SingleSessionSolver) + for dataset, loader in [("demo-discrete", cebra.data.DiscreteDataLoader), + ("demo-continuous", cebra.data.ContinuousDataLoader + ), ("demo-mixed", cebra.data.MixedDataLoader)] + for model in + ["offset1-model", "offset10-model", "offset40-model-4x-subsample"] + for padding in [True, False] + for batch_size in BATCHES] + + [(dataset, model, padding, batch_size, loader, + cebra.solver.SingleSessionHybridSolver) + for dataset, loader in [ + ("demo-continuous", cebra.data.HybridDataLoader), + ] + for model in MODELS + for padding in [True, False] + for batch_size in BATCHES]) +def test_batched_transform_single_session( + data_name, + model_name, + padding, + batch_size_inference, + loader_initfunc, + solver_initfunc, +): + dataset = cebra.datasets.init(data_name) + model = create_model(model_name, dataset.input_dimension) + dataset.configure_for(model) + loader_kwargs = dict(num_steps=NUM_STEPS, batch_size=32) + loader = loader_initfunc(dataset, **loader_kwargs) + + criterion = cebra.models.InfoNCE() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + solver = solver_initfunc(model=model, + criterion=criterion, + optimizer=optimizer) + solver.fit(loader) + + embedding_batched = solver.transform(inputs=loader.dataset.neural, + batch_size=batch_size_inference, + pad_before_transform=padding) + + embedding = solver.transform(inputs=loader.dataset.neural, + pad_before_transform=padding) + + assert embedding_batched.shape == embedding.shape + assert np.allclose(embedding_batched, embedding, rtol=1e-4, atol=1e-4) + + +@pytest.mark.parametrize( + "data_name, model_name,padding,batch_size_inference,loader_initfunc, solver_initfunc", + [(dataset, model, padding, batch_size, loader, + cebra.solver.MultiSessionSolver) + for dataset, loader in [ + ("demo-continuous-multisession", + cebra.data.ContinuousMultiSessionDataLoader), + ] + for model in + ["offset1-model", "offset10-model", "offset40-model-4x-subsample"] + for padding in [True, False] + for batch_size in BATCHES]) +def test_batched_transform_multi_session(data_name, model_name, padding, + batch_size_inference, loader_initfunc, + solver_initfunc): + dataset = cebra.datasets.init(data_name) + model = nn.ModuleList([ + create_model(model_name, dataset.input_dimension) + for dataset in dataset.iter_sessions() + ]) + dataset.configure_for(model) + + n_samples = dataset._datasets[0].neural.shape[0] + assert all( + d.neural.shape[0] == n_samples for d in dataset._datasets + ), "for this set all of the sessions need to have same number of samples." 
+ + loader_kwargs = dict(num_steps=NUM_STEPS, batch_size=32) + loader = loader_initfunc(dataset, **loader_kwargs) + + criterion = cebra.models.InfoNCE() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) + + solver = solver_initfunc(model=model, + criterion=criterion, + optimizer=optimizer) + solver.fit(loader) + + # Transform each session with the right model, by providing + # the corresponding session ID + for i, inputs in enumerate(dataset.iter_sessions()): + embedding = solver.transform(inputs=inputs.neural, + session_id=i, + pad_before_transform=padding) + embedding_batched = solver.transform(inputs=inputs.neural, + session_id=i, + pad_before_transform=padding, + batch_size=batch_size_inference) + + assert embedding_batched.shape == embedding.shape + assert np.allclose(embedding_batched, embedding, rtol=1e-4, atol=1e-4) diff --git a/tests/test_usecases.py b/tests/test_usecases.py index 22195bd8..f0cc308a 100644 --- a/tests/test_usecases.py +++ b/tests/test_usecases.py @@ -29,7 +29,6 @@ """ import itertools -import pickle import numpy as np import pytest diff --git a/tools/build_docker.sh b/tools/build_docker.sh index 76aa8228..cec031a0 100755 --- a/tools/build_docker.sh +++ b/tools/build_docker.sh @@ -3,6 +3,21 @@ set -e +# Parse command line arguments +RUN_FULL_TESTS=false +while [[ $# -gt 0 ]]; do + case $1 in + --full-tests) + RUN_FULL_TESTS=true + shift + ;; + *) + echo "Unknown option: $1" + exit 1 + ;; + esac +done + if [[ -z $(git status --porcelain) ]]; then TAG=$(git rev-parse --short HEAD) else @@ -23,13 +38,20 @@ docker build \ -t $DOCKERNAME . docker tag $DOCKERNAME $LATEST +# Determine whether to run full tests or not +if [[ "$RUN_FULL_TESTS" == "true" ]]; then + echo "Running full test suite including tests that require datasets" +else + echo "Running tests that don't require datasets" +fi + docker run \ --gpus 2 \ ${extra_kwargs[@]} \ -v ${CEBRA_DATADIR:-./data}:/data \ --env CEBRA_DATADIR=/data \ --network host \ - -it $DOCKERNAME python -m pytest --ff -x -m "not requires_dataset" --doctest-modules ./docs/source/usage.rst tests cebra + -it $DOCKERNAME python -m pytest --ff -x $([ "$RUN_FULL_TESTS" != "true" ] && echo '-m "not requires_dataset"') --doctest-modules ./docs/source/usage.rst tests cebra #docker push $DOCKERNAME #docker push $LATEST diff --git a/tools/build_docs.sh b/tools/build_docs.sh index 3f5f36cd..119272ed 100755 --- a/tools/build_docs.sh +++ b/tools/build_docs.sh @@ -1,79 +1,17 @@ #!/bin/bash -# Locally build the documentation and display it in a webserver. -set -xe - -git_checkout_or_pull() { - local repo=$1 - local target_dir=$2 - # TODO(stes): theoretically we could also auto-update the repo, - # I commented this out for now to avoid interference with local - # dev/changes - #if [ -d "$target_dir" ]; then - # cd "$target_dir" - # git pull --ff-only origin main - # cd - - #else - if [ ! 
-d "$target_dir" ]; then - git clone "$repo" "$target_dir" - fi -} - -checkout_cebra_figures() { - git_checkout_or_pull git@github.com:AdaptiveMotorControlLab/cebra-figures.git docs/source/cebra-figures -} - -checkout_assets() { - git_checkout_or_pull git@github.com:AdaptiveMotorControlLab/cebra-assets.git assets -} - -checkout_cebra_demos() { - git_checkout_or_pull git@github.com:AdaptiveMotorControlLab/cebra-demos.git docs/source/demo_notebooks -} - -setup_python() { - python -m pip install --upgrade pip setuptools wheel - sudo apt-get install -y pandoc - pip install torch --extra-index-url=https://download.pytorch.org/whl/cpu - pip install '.[docs]' -} - -build_docs() { - cp -r assets/* . - export SPHINXOPTS="-W --keep-going -n" - (cd docs && PYTHONPATH=.. make page) -} - -serve() { - python -m http.server 8080 --b 0.0.0.0 -d docs/build/html -} - -main() { - build_docs - serve -} - -if [[ "$1" == "--build" ]]; then - main -fi - -docker build -t cebra-docs -f - . << "EOF" -FROM python:3.9 -RUN python -m pip install --upgrade pip setuptools wheel \ - && apt-get update -y && apt-get install -y pandoc git -RUN pip install torch --extra-index-url=https://download.pytorch.org/whl/cpu -COPY dist/cebra-0.4.0-py2.py3-none-any.whl . -RUN pip install 'cebra-0.4.0-py2.py3-none-any.whl[docs]' -EOF - -checkout_cebra_figures -checkout_assets -checkout_cebra_demos - -docker run \ - -p 127.0.0.1:8080:8080 \ - -u $(id -u):$(id -g) \ - -v .:/app -w /app \ - --tmpfs /.config --tmpfs /.cache \ - -it cebra-docs \ - ./tools/build_docs.sh --build +docker build -t cebra-docs -f docs/Dockerfile . + +docker run -u $(id -u):$(id -g) \ + -p 127.0.0.1:8000:8000 \ + -v $(pwd):/app \ + -v /tmp/.cache/pip:/.cache/pip \ + -v /tmp/.cache/sphinx:/.cache/sphinx \ + -v /tmp/.cache/matplotlib:/.cache/matplotlib \ + -v /tmp/.cache/fontconfig:/.cache/fontconfig \ + -e MPLCONFIGDIR=/tmp/.cache/matplotlib \ + -w /app \ + --env SPHINXBUILD="sphinx-autobuild" \ + --env SPHINXOPTS="-W --keep-going -n --port 8000 --host 0.0.0.0" \ + -it cebra-docs \ + make docs diff --git a/tools/bump_version.sh b/tools/bump_version.sh index fbc161b1..17142f7e 100755 --- a/tools/bump_version.sh +++ b/tools/bump_version.sh @@ -1,7 +1,7 @@ #!/bin/bash # Bump the CEBRA version to the specified value. # Edits all relevant files at once. -# +# # Usage: # tools/bump_version.sh 0.3.1rc1 @@ -10,24 +10,40 @@ if [ -z ${version} ]; then >&1 echo "Specify a version number." >&1 echo "Usage:" >&1 echo "tools/bump_version.sh " + exit 1 +fi + +# Determine the correct sed command based on the OS +# On macOS, the `sed` command requires an empty string argument after `-i` for in-place editing. +# On Linux and other Unix-like systems, the `sed` command only requires `-i` for in-place editing. 
+if [[ "$OSTYPE" == "darwin"* ]]; then + # macOS + SED_CMD="sed -i .bkp -e" +else + # Linux and other Unix-like systems + SED_CMD="sed -i -e" fi # python cebra version -sed -i "s/__version__ = .*/__version__ = \"${version}\"/" \ - cebra/__init__.py +$SED_CMD "s/__version__ = .*/__version__ = \"${version}\"/" cebra/__init__.py # reinstall script in root -sed -i "s/VERSION=.*/VERSION=${version}/" \ - reinstall.sh +$SED_CMD "s/VERSION=.*/VERSION=${version}/" reinstall.sh # Makefile -sed -i "s/CEBRA_VERSION := .*/CEBRA_VERSION := ${version}/" \ - Makefile +$SED_CMD "s/CEBRA_VERSION := .*/CEBRA_VERSION := ${version}/" Makefile -# Arch linux PKGBUILD -sed -i "s/pkgver=.*/pkgver=${version}/" \ - PKGBUILD +# Arch linux PKGBUILD +$SED_CMD "s/pkgver=.*/pkgver=${version}/" PKGBUILD # Dockerfile -sed -i "s/ENV WHEEL=cebra-.*\.whl/ENV WHEEL=cebra-${version}-py2.py3-none-any.whl/" \ - Dockerfile +$SED_CMD "s/ENV WHEEL=cebra-.*\.whl/ENV WHEEL=cebra-${version}-py3-none-any.whl/" Dockerfile + +# build_docs.sh +$SED_CMD "s/COPY dist\/cebra-.*-py3-none-any\.whl/COPY dist\/cebra-${version}-py3-none-any.whl/" tools/build_docs.sh +$SED_CMD "s/RUN pip install 'cebra-.*-py3-none-any\.whl/RUN pip install 'cebra-${version}-py3-none-any.whl/" tools/build_docs.sh + +# Remove backup files +if [[ "$OSTYPE" == "darwin"* ]]; then + rm cebra/__init__.py.bkp reinstall.sh.bkp Makefile.bkp PKGBUILD.bkp Dockerfile.bkp tools/build_docs.sh.bkp +fi