Webpage designed using Bootstrap 5 and Fontawesome 5.
diff --git a/docs/root/robots.txt b/docs/root/robots.txt
index 43249ef2..bbbcdfe9 100644
--- a/docs/root/robots.txt
+++ b/docs/root/robots.txt
@@ -1,3 +1,2 @@
User-agent: *
Disallow: /staging/
-Disallow: /docs/
diff --git a/docs/source/_static/css/custom.js b/docs/source/_static/css/custom.js
new file mode 100644
index 00000000..f9afa170
--- /dev/null
+++ b/docs/source/_static/css/custom.js
@@ -0,0 +1,6 @@
+requirejs.config({
+ paths: {
+ base: '/static/base',
+ plotly: 'https://cdn.plot.ly/plotly-2.12.1.min.js?noext',
+ },
+});
diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html
index 2994db97..0140a5cf 100644
--- a/docs/source/_templates/layout.html
+++ b/docs/source/_templates/layout.html
@@ -1,11 +1,15 @@
{% extends "pydata_sphinx_theme/layout.html" %}
-{% block fonts %}
+{% block extrahead %}
+
+
+
+{% endblock %}
+{% block fonts %}
-
{% endblock %}
{% block docs_sidebar %}
diff --git a/docs/source/api.rst b/docs/source/api.rst
index 8989337f..846602f1 100644
--- a/docs/source/api.rst
+++ b/docs/source/api.rst
@@ -38,6 +38,9 @@ these components in other contexts and research code bases.
api/pytorch/distributions
api/pytorch/models
api/pytorch/helpers
+ api/pytorch/multiobjective
+ api/pytorch/regularized
+ api/pytorch/attribution
.. toctree::
:hidden:
diff --git a/docs/source/api/pytorch/attribution.rst b/docs/source/api/pytorch/attribution.rst
new file mode 100644
index 00000000..6efb043f
--- /dev/null
+++ b/docs/source/api/pytorch/attribution.rst
@@ -0,0 +1,21 @@
+===================
+Attribution Methods
+===================
+
+.. automodule:: cebra.attribution
+ :members:
+ :show-inheritance:
+
+Different attribution methods
+-----------------------------
+
+.. automodule:: cebra.attribution.attribution_models
+ :members:
+ :show-inheritance:
+
+Jacobian-based attribution
+--------------------------
+
+.. automodule:: cebra.attribution.jacobian_attribution
+ :members:
+ :show-inheritance:
diff --git a/docs/source/api/pytorch/models.rst b/docs/source/api/pytorch/models.rst
index ee3455bc..3fe2219b 100644
--- a/docs/source/api/pytorch/models.rst
+++ b/docs/source/api/pytorch/models.rst
@@ -43,12 +43,8 @@ Layers and model building blocks
:show-inheritance:
Multi-objective models
-~~~~~~~~~~~~~~~~~~~~~~~~
+~~~~~~~~~~~~~~~~~~~~~~
-.. automodule:: cebra.models.multiobjective
- :members:
- :private-members:
- :show-inheritance:
-
-..
- - projector
+The multi-objective interface was moved to a separate section beginning with CEBRA 0.6.0.
+Please see the :doc:`Multi-objective models <multiobjective>` section
+for all details, both on the old and new API interface.
diff --git a/docs/source/api/pytorch/multiobjective.rst b/docs/source/api/pytorch/multiobjective.rst
new file mode 100644
index 00000000..c959cfa1
--- /dev/null
+++ b/docs/source/api/pytorch/multiobjective.rst
@@ -0,0 +1,15 @@
+======================
+Multi-objective models
+======================
+
+.. automodule:: cebra.solver.multiobjective
+ :members:
+ :show-inheritance:
+
+.. automodule:: cebra.models.multicriterions
+ :members:
+ :show-inheritance:
+
+.. automodule:: cebra.models.multiobjective
+ :members:
+ :show-inheritance:
diff --git a/docs/source/api/pytorch/regularized.rst b/docs/source/api/pytorch/regularized.rst
new file mode 100644
index 00000000..7da94603
--- /dev/null
+++ b/docs/source/api/pytorch/regularized.rst
@@ -0,0 +1,24 @@
+================================
+Regularized Contrastive Learning
+================================
+
+Regularized solvers
+--------------------
+
+.. automodule:: cebra.solver.regularized
+ :members:
+ :show-inheritance:
+
+Schedulers
+----------
+
+.. automodule:: cebra.solver.schedulers
+ :members:
+ :show-inheritance:
+
+Jacobian Regularization
+-----------------------
+
+.. automodule:: cebra.models.jacobian_regularizer
+ :members:
+ :show-inheritance:
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 025a988b..4147e7c9 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -26,21 +26,13 @@
# list see the documentation:
# https://www.sphinx-doc.org/en/master/usage/configuration.html
-# -- Path setup --------------------------------------------------------------
-
-# If extensions (or modules to document with autodoc) are in another directory,
-# add these directories to sys.path here. If the directory is relative to the
-# documentation root, use os.path.abspath to make it absolute, like shown here.
-#
+import datetime
import os
+import pathlib
import sys
sys.path.insert(0, os.path.abspath("."))
-import datetime
-
-import cebra
-
def get_years(start_year=2021):
year = datetime.datetime.now().year
@@ -52,16 +44,31 @@ def get_years(start_year=2021):
# -- Project information -----------------------------------------------------
project = "cebra"
-copyright = f"""{get_years(2021)}, Steffen Schneider, Jin H Lee, Mackenzie Mathis"""
-author = "Steffen Schneider, Jin H Lee, Mackenzie Mathis"
-# The full version, including alpha/beta/rc tags
-release = cebra.__version__
+copyright = f"""{get_years(2021)}"""
+author = "See AUTHORS.md"
+version_file = pathlib.Path(
+ __file__).parent.parent.parent / "cebra" / "__init__.py"
+assert version_file.exists(), f"Could not find version file: {version_file}"
+with version_file.open("r") as f:
+ for line in f:
+ if line.startswith("__version__"):
+ version = line.split("=")[1].strip().strip('"').strip("'")
+ print("Building docs for version:", version)
+ break
+ else:
+ raise ValueError("Could not find version in __init__.py")
# -- General configuration ---------------------------------------------------
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
+
+#https://github.com/spatialaudio/nbsphinx/issues/128#issuecomment-1158712159
+html_js_files = [
+ "https://cdn.plot.ly/plotly-latest.min.js", # Add Plotly.js
+]
+
extensions = [
"sphinx.ext.autodoc",
"sphinx.ext.napoleon",
@@ -73,7 +80,6 @@ def get_years(start_year=2021):
"sphinx_tabs.tabs",
"sphinx.ext.mathjax",
"IPython.sphinxext.ipython_console_highlighting",
- # "sphinx_panels", # Note: package to avoid: no longer maintained.
"sphinx_design",
"sphinx_togglebutton",
"sphinx.ext.doctest",
@@ -121,7 +127,8 @@ def get_years(start_year=2021):
autodoc_member_order = "bysource"
autodoc_mock_imports = [
- "torch", "nlb_tools", "tqdm", "h5py", "pandas", "matplotlib", "plotly"
+ "torch", "nlb_tools", "tqdm", "h5py", "pandas", "matplotlib", "plotly",
+ "cvxpy", "captum", "joblib", "scikit-learn", "scipy", "requests", "sklearn"
]
# autodoc_typehints = "none"
@@ -132,8 +139,18 @@ def get_years(start_year=2021):
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path.
exclude_patterns = [
- "**/todo", "**/src", "cebra-figures/figures.rst", "cebra-figures/*.rst",
- "*/cebra-figures/*.rst", "demo_notebooks/README.rst"
+ "**/todo",
+ "**/src",
+ "cebra-figures/figures.rst",
+ "cebra-figures/*.rst",
+ "*/cebra-figures/*.rst",
+ "*/demo_notebooks/README.rst",
+ "demo_notebooks/README.rst",
+ # TODO(stes): Remove this from the assets repo, then remove here
+ "_static/figures_usage.ipynb",
+ "*/_static/figures_usage.ipynb",
+ "assets/**/*.ipynb",
+ "*/assets/**/*.ipynb"
]
# -- Options for HTML output -------------------------------------------------
@@ -142,6 +159,23 @@ def get_years(start_year=2021):
# a list of builtin themes.
html_theme = "pydata_sphinx_theme"
+html_context = {
+ "default_mode": "light",
+ "switcher": {
+ "version_match":
+ "latest", # Adjust this dynamically per version
+ "versions": [
+ ("latest", "/latest/"),
+ ("v0.2.0", "/v0.2.0/"),
+ ("v0.3.0", "/v0.3.0/"),
+ ("v0.4.0", "/v0.4.0/"),
+ ("v0.5.0rc1", "/v0.5.0rc1/"),
+ ],
+ },
+ "navbar_start": ["version-switcher",
+ "navbar-logo"], # Place the dropdown above the logo
+}
+
# More info on theme options:
# https://pydata-sphinx-theme.readthedocs.io/en/latest/user_guide/configuring.html
html_theme_options = {
@@ -156,11 +190,6 @@ def get_years(start_year=2021):
"url": "https://twitter.com/cebraAI",
"icon": "fab fa-twitter",
},
- # {
- # "name": "DockerHub",
- # "url": "https://hub.docker.com/r/stffsc/cebra",
- # "icon": "fab fa-docker",
- # },
{
"name": "PyPI",
"url": "https://pypi.org/project/cebra/",
@@ -172,23 +201,26 @@ def get_years(start_year=2021):
"icon": "fas fa-graduation-cap",
},
],
- "external_links": [
- # {"name": "Mathis Lab", "url": "http://www.mackenziemathislab.org/"},
- ],
"collapse_navigation": False,
- "navigation_depth": 4,
- "show_nav_level": 2,
+ "navigation_depth": 1,
+ "show_nav_level": 1,
"navbar_align": "content",
"show_prev_next": False,
+ "navbar_end": ["theme-switcher", "navbar-icon-links.html"],
+ "navbar_persistent": [],
+ "header_links_before_dropdown": 7
}
-html_context = {"default_mode": "dark"}
+html_context = {"default_mode": "light"}
html_favicon = "_static/img/logo_small.png"
html_logo = "_static/img/logo_large.png"
-# Remove the search field for now
+# Replace with this configuration to enable "on this page" navigation
html_sidebars = {
- "**": ["search-field.html", "sidebar-nav-bs.html"],
+ "**": ["search-field.html", "sidebar-nav-bs", "page-toc.html"],
+ "demos": ["search-field.html", "sidebar-nav-bs"],
+ "api": ["search-field.html", "sidebar-nav-bs"],
+ "figures": ["search-field.html", "sidebar-nav-bs"],
}
# Disable links for embedded images
@@ -207,6 +239,8 @@ def get_years(start_year=2021):
]
nbsphinx_thumbnails = {
+ "demo_notebooks/CEBRA_best_practices":
+ "_static/thumbnails/cebra-best.png",
"demo_notebooks/Demo_primate_reaching":
"_static/thumbnails/ForelimbS1.png",
"demo_notebooks/Demo_hippocampus":
@@ -235,6 +269,8 @@ def get_years(start_year=2021):
"_static/thumbnails/openScope_demo.png",
"demo_notebooks/Demo_dandi_NeuroDataReHack_2023":
"_static/thumbnails/dandi_demo_monkey.png",
+ "demo_notebooks/Demo_xCEBRA_RatInABox":
+ "_static/thumbnails/xCEBRA.png"
}
rst_prolog = r"""
@@ -247,6 +283,9 @@ def get_years(start_year=2021):
# Download link for the notebook, see
# https://nbsphinx.readthedocs.io/en/0.3.0/prolog-and-epilog.html
+
+# fmt: off
+# flake8: noqa: E501
nbsphinx_prolog = r"""
.. only:: html
@@ -269,3 +308,14 @@ def get_years(start_year=2021):
----
"""
+# fmt: on
+# flake8: enable=E501
+
+# Configure nbsphinx to properly render Plotly plots
+nbsphinx_execute = 'auto'
+nbsphinx_allow_errors = True
+nbsphinx_requirejs_path = 'https://cdnjs.cloudflare.com/ajax/libs/require.js/2.3.7/require.js'
+nbsphinx_execute_arguments = [
+ "--InlineBackend.figure_formats={'png', 'svg', 'pdf'}",
+ "--InlineBackend.rc=figure.dpi=96",
+]
diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst
index cc7ae0a8..7fcd16a1 100644
--- a/docs/source/contributing.rst
+++ b/docs/source/contributing.rst
@@ -155,13 +155,13 @@ Enter the build environment and build the package:
host $ make interact
docker $ make build
# ... outputs ...
- Successfully built cebra-X.X.XaX-py2.py3-none-any.whl
+ Successfully built cebra-X.X.XaX-py3-none-any.whl
The built package can be found in ``dist/`` and can be installed locally with
.. code:: bash
- pip install dist/cebra-X.X.XaX-py2.py3-none-any.whl
+ pip install dist/cebra-X.X.XaX-py3-none-any.whl
**Please do not distribute this package prior to the public release of the CEBRA repository, because it also
contains parts of the source code.**
diff --git a/docs/source/demos.rst b/docs/source/demos.rst
deleted file mode 100644
index f0822386..00000000
--- a/docs/source/demos.rst
+++ /dev/null
@@ -1 +0,0 @@
-.. include:: demo_notebooks/README.rst
diff --git a/docs/source/demos.rst b/docs/source/demos.rst
new file mode 120000
index 00000000..edd57b74
--- /dev/null
+++ b/docs/source/demos.rst
@@ -0,0 +1 @@
+demo_notebooks/README.rst
\ No newline at end of file
diff --git a/docs/source/figures.rst b/docs/source/figures.rst
index 24b1987e..a4101f4a 100644
--- a/docs/source/figures.rst
+++ b/docs/source/figures.rst
@@ -1,7 +1,7 @@
Figures
=======
-CEBRA was introduced in `Schneider, Lee and Mathis (2022)`_ and applied to various datasets across
+CEBRA was introduced in `Schneider, Lee and Mathis (2023)`_ and applied to various datasets across
animals and recording modalities.
In this section, we provide reference code for reproducing the figures and experiments. Since especially
@@ -56,4 +56,4 @@ differ in minor typographic details.
-.. _Schneider, Lee and Mathis (2022): https://arxiv.org/abs/2204.00673
+.. _Schneider, Lee and Mathis (2023): https://www.nature.com/articles/s41586-023-06031-6
diff --git a/docs/source/index.rst b/docs/source/index.rst
index c8231746..1a6ce4d2 100644
--- a/docs/source/index.rst
+++ b/docs/source/index.rst
@@ -34,27 +34,18 @@ Please support the development of CEBRA by starring and/or watching the project
Installation and Setup
----------------------
-Please see the dedicated :doc:`Installation Guide ` for information on installation options using ``conda``, ``pip`` and ``docker``.
-
-Have fun! 😁
+Please see the dedicated :doc:`Installation Guide <installation>` for information on installation options using ``conda``, ``pip`` and ``docker``. Have fun! 😁
Usage
-----
Please head over to the :doc:`Usage ` tab to find step-by-step instructions to use CEBRA on your data. For example use cases, see the :doc:`Demos ` tab.
-Integrations
-------------
-
-CEBRA can be directly integrated with existing libraries commonly used in data analysis. The ``cebra.integrations`` module
-is getting actively extended. Right now, we offer integrations for ``scikit-learn``-like usage of CEBRA, a package making use of ``matplotlib`` to plot the CEBRA model results, as well as the
-possibility to compute CEBRA embeddings on DeepLabCut_ outputs directly.
-
Licensing
---------
-
-Since version 0.4.0, CEBRA is open source software under an Apache 2.0 license.
+The ideas presented in our package are currently patent pending (Patent No. WO2023143843).
+Since version 0.4.0, CEBRA's source is licensed under an Apache 2.0 license.
Prior versions 0.1.0 to 0.3.1 were released for academic use only.
Please see the full license file on Github_ for further information.
@@ -65,13 +56,19 @@ Contributing
Please refer to the :doc:`Contributing ` tab to find our guidelines on contributions.
-Code contributors
+Code Contributors
-----------------
-The CEBRA code was originally developed by Steffen Schneider, Jin H. Lee, and Mackenzie Mathis (up to internal version 0.0.2). As of March 2023, it is being actively extended and maintained by `Steffen Schneider`_, `Célia Benquet`_, and `Mackenzie Mathis`_.
+The CEBRA code was originally developed by Steffen Schneider, Jin H. Lee, and Mackenzie Mathis (up to internal version 0.0.2). Please see our AUTHORS file for more information.
-References
-----------
+Integrations
+------------
+
+CEBRA can be directly integrated with existing libraries commonly used in data analysis. Namely, we provide a ``scikit-learn`` style interface to use CEBRA. Additionally, we offer a plotting package making use of ``matplotlib`` and ``plotly`` to plot the CEBRA model results, as well as the possibility to compute CEBRA embeddings on DeepLabCut_ outputs directly. If you have another suggestion, please head over to Discussions_ on GitHub_!
+
+
+Key References
+--------------
.. code::
@article{schneider2023cebra,
@@ -82,14 +79,22 @@ References
year = {2023},
}
+ @article{xCEBRA2025,
+ author={Steffen Schneider and Rodrigo Gonz{\'a}lez Laiz and Anastasiia Filippova and Markus Frey and Mackenzie W Mathis},
+ title = {Time-series attribution maps with regularized contrastive learning},
+ journal = {AISTATS},
+ url = {https://openreview.net/forum?id=aGrCXoTB4P},
+ year = {2025},
+ }
+
This documentation is based on the `PyData Theme`_.
.. _`Twitter`: https://twitter.com/cebraAI
.. _`PyData Theme`: https://github.com/pydata/pydata-sphinx-theme
.. _`DeepLabCut`: https://deeplabcut.org
+.. _`Discussions`: https://github.com/AdaptiveMotorControlLab/CEBRA/discussions
.. _`Github`: https://github.com/AdaptiveMotorControlLab/cebra
.. _`email`: mailto:mackenzie.mathis@epfl.ch
.. _`Steffen Schneider`: https://github.com/stes
-.. _`Célia Benquet`: https://github.com/CeliaBenquet
.. _`Mackenzie Mathis`: https://github.com/MMathisLab
diff --git a/docs/source/installation.rst b/docs/source/installation.rst
index a9650452..1630cfe8 100644
--- a/docs/source/installation.rst
+++ b/docs/source/installation.rst
@@ -4,7 +4,7 @@ Installation Guide
System Requirements
-------------------
-CEBRA is written in Python (3.8+) and PyTorch. CEBRA is most effective when used with a GPU, but CPU-only support is provided. We provide instructions to run CEBRA on your system directly. The instructions below were tested on different compute setups with Ubuntu 18.04 or 20.04, using Nvidia GTX 2080, A4000, and V100 cards. Other setups are possible (including Windows), as long as CUDA 10.2+ support is guaranteed.
+CEBRA is written in Python (3.9+) and PyTorch. CEBRA is most effective when used with a GPU, but CPU-only support is provided. We provide instructions to run CEBRA on your system directly. The instructions below were tested on different compute setups with Ubuntu 18.04 or 20.04, using Nvidia GTX 2080, A4000, and V100 cards. Other setups are possible (including Windows), as long as CUDA 10.2+ support is guaranteed.
- Software dependencies and operating systems:
- Linux or MacOS
@@ -93,11 +93,11 @@ we outline different options below.
* 🚀 For more advanced users, CEBRA has different extra install options that you can select based on your usecase:
- * ``[integrations]``: This will install (experimental) support for our streamlit and jupyter integrations.
+ * ``[integrations]``: This will install (experimental) support for integrations, such as plotly.
* ``[docs]``: This will install additional dependencies for building the package documentation.
* ``[dev]``: This will install additional dependencies for development, unit and integration testing,
code formatting, etc. Install this extension if you want to work on a pull request.
- * ``[demos]``: This will install additional dependencies for running our demo notebooks.
+ * ``[demos]``: This will install additional dependencies for running our demo notebooks in Jupyter.
* ``[datasets]``: This extension will install additional dependencies to use the pre-installed datasets
in ``cebra.datasets``.
@@ -149,6 +149,13 @@ we outline different options below.
Note that, similarly to that last command, you can select the specific install options of interest based on their description above and on your usecase.
+ .. tab:: Docker
+
+ .. code:: bash
+
+ $ docker pull mmathislab/cebra-cuda12.4-cudnn9
+
+ You can pull our container from DockerHub: https://hub.docker.com/u/mmathislab
..
diff --git a/docs/source/usage.rst b/docs/source/usage.rst
index 334f1bbc..82e45a0b 100644
--- a/docs/source/usage.rst
+++ b/docs/source/usage.rst
@@ -1,7 +1,7 @@
Using CEBRA
===========
-This page covers a standard CEBRA usage. We recommend checking out the :py:doc:`demos` for in-depth CEBRA usage examples as well. Here we present a quick overview on how to use CEBRA on various datasets. Note that we provide two ways to interact with the code:
+This page covers a standard CEBRA usage. We recommend checking out the :py:doc:`demos` for CEBRA usage examples as well. Here we present a quick overview on how to use CEBRA on various datasets. Note that we provide two ways to interact with the code:
* For regular usage, we recommend leveraging the **high-level interface**, adhering to ``scikit-learn`` formatting.
* Upon specific needs, advanced users might consider diving into the **low-level interface** that adheres to ``PyTorch`` formatting.
@@ -12,7 +12,7 @@ Firstly, why use CEBRA?
CEBRA is primarily designed for producing robust, consistent extractions of latent factors from time-series data. It supports three modes, and is a self-supervised representation learning algorithm that uses our modified contrastive learning approach designed for multi-modal time-series data. In short, it is a type of non-linear dimensionality reduction, like `tSNE
`_ and `UMAP `_. We show in our original paper that it outperforms tSNE and UMAP at producing closer-to-ground-truth latents and is more consistent.
-That being said, CEBRA can be used on non-time-series data and it does not strictly require multi-modal data. In general, we recommend considering using CEBRA for measuring changes in consistency across conditions (brain areas, cells, animals), for hypothesis-guided decoding, and for topological exploration of the resulting embedding spaces. It can also be used for visualization and considering dynamics within the embedding space. For examples of how CEBRA can be used to map space, decode natural movies, and make hypotheses for neural coding of sensorimotor systems, see our paper (Schneider, Lee, Mathis, 2023).
+That being said, CEBRA can be used on non-time-series data and it does not strictly require multi-modal data. In general, we recommend considering using CEBRA for measuring changes in consistency across conditions (brain areas, cells, animals), for hypothesis-guided decoding, and for topological exploration of the resulting embedding spaces. It can also be used for visualization and considering dynamics within the embedding space. For examples of how CEBRA can be used to map space, decode natural movies, and make hypotheses for neural coding of sensorimotor systems, see `Schneider, Lee, Mathis. Nature 2023 `_.
The CEBRA workflow
------------------
@@ -22,7 +22,7 @@ We recommend to start with running CEBRA-Time (unsupervised) and look both at th
(1) Use CEBRA-Time for unsupervised data exploration.
(2) Consider running a hyperparameter sweep on the inputs to the model, such as :py:attr:`cebra.CEBRA.model_architecture`, :py:attr:`cebra.CEBRA.time_offsets`, :py:attr:`cebra.CEBRA.output_dimension`, and set :py:attr:`cebra.CEBRA.batch_size` to be as high as your GPU allows. You want to see clear structure in the 3D plot (the first 3 latents are shown by default).
-(3) Use CEBRA-Behavior with many different labels and combinations, then look at the InfoNCE loss - the lower the loss value, the better the fit (see :py:doc:`cebra-figures/figures/ExtendedDataFigure5`), and visualize the embeddings. The goal is to understand which labels are contributing to the structure you see in CEBRA-Time, and improve this structure. Again, you should consider a hyperparameter sweep.
+(3) Use CEBRA-Behavior with many different labels and combinations, then look at the InfoNCE loss - the lower the loss value, the better the fit (see :py:doc:`cebra-figures/figures/ExtendedDataFigure5`), and visualize the embeddings. The goal is to understand which labels are contributing to the structure you see in CEBRA-Time, and improve this structure. Again, you should consider a hyperparameter sweep (and avoid overfitting by performing a proper train/validation split; see Step 3 in our quick start guide below).
(4) Interpretability: now you can use these latents in downstream tasks, such as measuring consistency, decoding, and determining the dimensionality of your data with topological data analysis.
All the steps to do this are described below. Enjoy using CEBRA! 🔥🦓
@@ -179,7 +179,7 @@ We provide a set of pre-defined models. You can access (and search) a list of av
Then, you can choose the one that fits best with your needs and provide it to the CEBRA model as the :py:attr:`~.CEBRA.model_architecture` parameter.
-As an indication the table below presents the model architecture we used to train CEBRA on the datasets presented in our paper (Schneider, Lee, Mathis, 2022).
+As an indication the table below presents the model architecture we used to train CEBRA on the datasets presented in our paper (Schneider, Lee, Mathis. Nature 2023).
.. list-table::
:widths: 25 25 20 30
@@ -265,9 +265,8 @@ For standard usage we recommend the default values (i.e., ``InfoNCE`` and ``cosi
.. rubric:: Temperature :py:attr:`~.CEBRA.temperature`
-:py:attr:`~.CEBRA.temperature` has the largest effect on visualization of the embedding (see :py:doc:`cebra-figures/figures/ExtendedDataFigure2`). Hence, it is important that it is fitted to your specific data.
+:py:attr:`~.CEBRA.temperature` has the largest effect on *visualization* of the embedding (see :py:doc:`cebra-figures/figures/ExtendedDataFigure2`). Hence, it is important that it is fitted to your specific data. Lower temperatures (e.g. around 0.1) will result in a more dispersed embedding, higher temperatures (larger than 1) will concentrate the embedding.
-The simplest way to handle it is to use a *learnable temperature*. For that, set :py:attr:`~.CEBRA.temperature_mode` to ``auto``. :py:attr:`~.CEBRA.temperature` will be trained alongside the model.
🚀 For advance usage, you might need to find the optimal :py:attr:`~.CEBRA.temperature`. For that we recommend to perform a grid-search.
@@ -307,7 +306,6 @@ Here is an example of a CEBRA model initialization:
cebra_model = CEBRA(
model_architecture = "offset10-model",
batch_size = 1024,
- temperature_mode="auto",
learning_rate = 0.001,
max_iterations = 10,
time_offsets = 10,
@@ -321,8 +319,7 @@ Here is an example of a CEBRA model initialization:
.. testoutput::
CEBRA(batch_size=1024, learning_rate=0.001, max_iterations=10,
- model_architecture='offset10-model', temperature_mode='auto',
- time_offsets=10)
+ model_architecture='offset10-model', time_offsets=10)
.. admonition:: See API docs
:class: dropdown
@@ -568,7 +565,8 @@ We provide a simple hyperparameters sweep to compare CEBRA models with different
learning_rate = [0.001],
time_offsets = 5,
max_iterations = 5,
- temperature_mode = "auto",
+ temperature_mode='constant',
+ temperature = 0.1,
verbose = False)
# 2. Define the datasets to iterate over
@@ -820,7 +818,7 @@ It takes a CEBRA model and returns a 2D plot of the loss against the number of i
Displaying the temperature
""""""""""""""""""""""""""
-:py:attr:`~.CEBRA.temperature` has the largest effect on the visualization of the embedding. Hence it might be interesting to check its evolution when ``temperature_mode=auto``.
+:py:attr:`~.CEBRA.temperature` has the largest effect on the visualization of the embedding. Hence it might be interesting to check its evolution when ``temperature_mode=auto``. We recommend only using ``auto`` if you have first explored the ``constant`` setting. If you use the ``auto`` mode, please always check the evolution of the temperature over time alongside the loss curve.
To that extend, you can use the function :py:func:`~.plot_temperature`.
@@ -1186,9 +1184,10 @@ Improve model performance
🧐 Below is a (non-exhaustive) list of actions you can try if your embedding looks different from what you were expecting.
#. Assess that your model `converged `_. For that, observe if the training loss stabilizes itself around the end of the training or still seems to be decreasing. Refer to `Visualize the training loss`_ for more details on how to display the training loss.
-#. Increase the number of iterations. It should be at least 10,000.
+#. Increase the number of iterations. It typically should be at least 10,000. On small datasets, it can make sense to stop training earlier to avoid overfitting effects.
#. Make sure the batch size is big enough. It should be at least 512.
#. Fine-tune the model's hyperparameters, namely ``learning_rate``, ``output_dimension``, ``num_hidden_units`` and eventually ``temperature`` (by setting ``temperature_mode`` back to ``constant``). Refer to `Grid search`_ for more details on performing hyperparameters tuning.
+#. To note, you should still be mindful of performing train/validation splits and shuffle controls to avoid `overfitting `_.
@@ -1202,17 +1201,22 @@ Putting all previous snippet examples together, we obtain the following pipeline
import cebra
from numpy.random import uniform, randint
from sklearn.model_selection import train_test_split
+ import os
+ import tempfile
+ from pathlib import Path
# 1. Define a CEBRA model
cebra_model = cebra.CEBRA(
- model_architecture = "offset10-model",
- batch_size = 512,
- learning_rate = 1e-4,
- max_iterations = 10, # TODO(user): to change to at least 10'000
- max_adapt_iterations = 10, # TODO(user): to change to ~100-500
- time_offsets = 10,
- output_dimension = 8,
- verbose = False
+ model_architecture = "offset10-model",
+ batch_size = 512,
+ learning_rate = 1e-4,
+ temperature_mode='constant',
+ temperature = 0.1,
+ max_iterations = 10, # TODO(user): to change to ~500-10000 depending on dataset size
+ #max_adapt_iterations = 10, # TODO(user): use and to change to ~100-500 if adapting
+ time_offsets = 10,
+ output_dimension = 8,
+ verbose = False
)
# 2. Load example data
@@ -1221,34 +1225,40 @@ Putting all previous snippet examples together, we obtain the following pipeline
continuous_label = cebra.load_data(file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["continuous1", "continuous2", "continuous3"])
discrete_label = cebra.load_data(file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"]).flatten()
+
assert neural_data.shape == (100, 3)
assert new_neural_data.shape == (100, 4)
assert discrete_label.shape == (100, )
assert continuous_label.shape == (100, 3)
- # 3. Split data and labels
- (
- train_data,
- valid_data,
- train_discrete_label,
- valid_discrete_label,
- train_continuous_label,
- valid_continuous_label,
- ) = train_test_split(neural_data,
- discrete_label,
- continuous_label,
- test_size=0.3)
+ # 3. Split data and labels into train/validation
+ from sklearn.model_selection import train_test_split
+
+ split_idx = int(0.8 * len(neural_data))
+ # suggestion: 5%-20% depending on your dataset size; note that this splits the
+ # data into an early and late part, which might not be ideal for your data/experiment!
+ # As a more involved alternative, consider e.g. a nested time-series split.
+
+ train_data = neural_data[:split_idx]
+ valid_data = neural_data[split_idx:]
+
+ train_continuous_label = continuous_label[:split_idx]
+ valid_continuous_label = continuous_label[split_idx:]
+
+ train_discrete_label = discrete_label[:split_idx]
+ valid_discrete_label = discrete_label[split_idx:]
# 4. Fit the model
# time contrastive learning
cebra_model.fit(train_data)
# discrete behavior contrastive learning
- cebra_model.fit(train_data, train_discrete_label,)
+ cebra_model.fit(train_data, train_discrete_label)
# continuous behavior contrastive learning
cebra_model.fit(train_data, train_continuous_label)
# mixed behavior contrastive learning
cebra_model.fit(train_data, train_discrete_label, train_continuous_label)
+
# 5. Save the model
tmp_file = Path(tempfile.gettempdir(), 'cebra.pt')
cebra_model.save(tmp_file)
@@ -1257,15 +1267,15 @@ Putting all previous snippet examples together, we obtain the following pipeline
cebra_model = cebra.CEBRA.load(tmp_file)
train_embedding = cebra_model.transform(train_data)
valid_embedding = cebra_model.transform(valid_data)
- assert train_embedding.shape == (70, 8)
- assert valid_embedding.shape == (30, 8)
- # 7. Evaluate the model performances
- goodness_of_fit = cebra.sklearn.metrics.infonce_loss(cebra_model,
+ assert train_embedding.shape == (80, 8) # TODO(user): change to split ratio & output dim
+ assert valid_embedding.shape == (20, 8) # TODO(user): change to split ratio & output dim
+
+ # 7. Evaluate the model performance (you can also check the train_data)
+ goodness_of_fit = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model,
valid_data,
valid_discrete_label,
- valid_continuous_label,
- num_batches=5)
+ valid_continuous_label)
# 8. Adapt the model to a new session
cebra_model.fit(new_neural_data, adapt = True)
@@ -1274,7 +1284,9 @@ Putting all previous snippet examples together, we obtain the following pipeline
decoder = cebra.KNNDecoder()
decoder.fit(train_embedding, train_discrete_label)
prediction = decoder.predict(valid_embedding)
- assert prediction.shape == (30,)
+ assert prediction.shape == (20,)
+
+
👉 For further guidance on different/customized applications of CEBRA on your own data, refer to the ``examples/`` folder or to the full documentation folder ``docs/``.
@@ -1424,17 +1436,14 @@ gets initialized which also allows the `prior` to be directly parametrized.
solver.fit(loader=loader)
# 7. Transform Embedding
- train_batches = np.lib.stride_tricks.sliding_window_view(
- neural_data, neural_model.get_offset().__len__(), axis=0
- )
-
x_train_emb = solver.transform(
- torch.from_numpy(train_batches[:]).type(torch.FloatTensor).to(device)
- ).to(device)
+ torch.from_numpy(neural_data).type(torch.FloatTensor).to(device),
+ pad_before_transform=True,
+ batch_size=512).to(device)
# 8. Plot Embedding
cebra.plot_embedding(
x_train_emb.cpu(),
- discrete_label[neural_model.get_offset().__len__() - 1 :, 0],
+ discrete_label[:,0],
markersize=10,
)
diff --git a/pyproject.toml b/pyproject.toml
index 4a927c6c..b64475e7 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,7 +1,8 @@
[build-system]
requires = [
"setuptools>=43",
- "wheel"
+ "wheel",
+ "packaging>=24.2"
]
build-backend = "setuptools.build_meta"
diff --git a/reinstall.sh b/reinstall.sh
index 778f98eb..ea8981b9 100755
--- a/reinstall.sh
+++ b/reinstall.sh
@@ -15,7 +15,7 @@ pip uninstall -y cebra
# Get version info after uninstalling --- this will automatically get the
# most recent version based on the source code in the current directory.
# $(tools/get_cebra_version.sh)
-VERSION=0.4.0
+VERSION=0.6.0a1
echo "Upgrading to CEBRA v${VERSION}"
# Upgrade the build system (PEP517/518 compatible)
@@ -24,4 +24,4 @@ python3 -m pip install --upgrade build
python3 -m build --sdist --wheel .
# Reinstall the package with most recent version
-pip install --upgrade --no-cache-dir "dist/cebra-${VERSION}-py2.py3-none-any.whl[datasets,integrations]"
+pip install --upgrade --no-cache-dir "dist/cebra-${VERSION}-py3-none-any.whl[datasets,integrations]"
diff --git a/setup.cfg b/setup.cfg
index 68263d73..7faff998 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -1,8 +1,8 @@
[metadata]
name = cebra
version = attr: cebra.__version__
-author = Steffen Schneider, Jin H Lee, Mackenzie W Mathis
-author_email = stes@hey.com
+author = file: AUTHORS.md
+author_email = stes@hey.com, mackenzie@post.harvard.edu
description = Consistent Embeddings of high-dimensional Recordings using Auxiliary variables
long_description = file: README.md
long_description_content_type = text/markdown
@@ -31,13 +31,17 @@ where =
python_requires = >=3.9
install_requires =
joblib
- numpy<2.0.0
+ numpy<2.0;platform_system=="Windows"
+ numpy<2.0;platform_system!="Windows" and python_version<"3.10"
+ numpy;platform_system!="Windows" and python_version>="3.10"
literate-dataclasses
scikit-learn
scipy
- torch
+ torch>=2.4.0
tqdm
- matplotlib
+ # NOTE(stes): Remove pin once https://github.com/AdaptiveMotorControlLab/CEBRA/issues/240
+ # is resolved.
+ matplotlib<3.11
requests
[options.extras_require]
@@ -56,15 +60,18 @@ datasets =
hdf5storage # for creating .mat files in new format
openpyxl # for excel file format loading
integrations =
- jupyter
pandas
plotly
+ seaborn
+ captum
+ cvxpy
+ scikit-image
docs =
- sphinx==5.3
- sphinx-gallery==0.10.1
+ sphinx
+ sphinx-gallery
docutils
- pydata-sphinx-theme==0.9.0
- sphinx_autodoc_typehints==1.19
+ pydata-sphinx-theme
+ sphinx_autodoc_typehints
sphinx_copybutton
sphinx_tabs
sphinx_design
@@ -72,16 +79,14 @@ docs =
nbsphinx
nbconvert
ipykernel
- matplotlib<=3.5.2
+ matplotlib
pandas
seaborn
scikit-learn
- numpy<2.0.0
demos =
ipykernel
jupyter
nbconvert
- seaborn
# TODO(stes): Additional dependency for running
# co-homology analysis
# is ripser, which can be tricky to
@@ -104,12 +109,10 @@ dev =
pytest-sphinx
tables
licenseheaders
+ interrogate
# TODO(stes) Add back once upstream issue
# https://github.com/PyCQA/docformatter/issues/119
# is resolved.
# docformatter[tomli]
codespell
cffconvert
-
-[bdist_wheel]
-universal=1
diff --git a/tests/_build_legacy_model/.gitignore b/tests/_build_legacy_model/.gitignore
new file mode 100644
index 00000000..4b6ebe5f
--- /dev/null
+++ b/tests/_build_legacy_model/.gitignore
@@ -0,0 +1 @@
+*.pt
diff --git a/tests/_build_legacy_model/Dockerfile b/tests/_build_legacy_model/Dockerfile
new file mode 100644
index 00000000..ddbb0e61
--- /dev/null
+++ b/tests/_build_legacy_model/Dockerfile
@@ -0,0 +1,39 @@
+FROM python:3.12-slim AS base
+RUN pip install torch --index-url https://download.pytorch.org/whl/cpu
+RUN apt-get update && \
+ apt-get install -y --no-install-recommends git && \
+ rm -rf /var/lib/apt/lists/*
+
+FROM base AS cebra-0.4.0-scikit-learn-1.4
+RUN pip install cebra==0.4.0 "scikit-learn<1.5"
+WORKDIR /app
+COPY create_model.py .
+RUN python create_model.py
+
+FROM base AS cebra-0.4.0-scikit-learn-1.6
+RUN pip install cebra==0.4.0 "scikit-learn>=1.6"
+WORKDIR /app
+COPY create_model.py .
+RUN python create_model.py
+
+FROM base AS cebra-rc-scikit-learn-1.4
+# NOTE(stes): Commit where new scikit-learn tag logic was added to the CEBRA class.
+# https://github.com/AdaptiveMotorControlLab/CEBRA/commit/5f46c3257952a08dfa9f9e1b149a85f7f12c1053
+RUN pip install git+https://github.com/AdaptiveMotorControlLab/CEBRA.git@5f46c3257952a08dfa9f9e1b149a85f7f12c1053 "scikit-learn<1.5"
+WORKDIR /app
+COPY create_model.py .
+RUN python create_model.py
+
+FROM base AS cebra-rc-scikit-learn-1.6
+# NOTE(stes): Commit where new scikit-learn tag logic was added to the CEBRA class.
+# https://github.com/AdaptiveMotorControlLab/CEBRA/commit/5f46c3257952a08dfa9f9e1b149a85f7f12c1053
+RUN pip install git+https://github.com/AdaptiveMotorControlLab/CEBRA.git@5f46c3257952a08dfa9f9e1b149a85f7f12c1053 "scikit-learn>=1.6"
+WORKDIR /app
+COPY create_model.py .
+RUN python create_model.py
+
+FROM scratch
+COPY --from=cebra-0.4.0-scikit-learn-1.4 /app/cebra_model.pt /cebra_model_cebra-0.4.0-scikit-learn-1.4.pt
+COPY --from=cebra-0.4.0-scikit-learn-1.6 /app/cebra_model.pt /cebra_model_cebra-0.4.0-scikit-learn-1.6.pt
+COPY --from=cebra-rc-scikit-learn-1.4 /app/cebra_model.pt /cebra_model_cebra-rc-scikit-learn-1.4.pt
+COPY --from=cebra-rc-scikit-learn-1.6 /app/cebra_model.pt /cebra_model_cebra-rc-scikit-learn-1.6.pt
diff --git a/tests/_build_legacy_model/README.md b/tests/_build_legacy_model/README.md
new file mode 100644
index 00000000..4bcffa2b
--- /dev/null
+++ b/tests/_build_legacy_model/README.md
@@ -0,0 +1,13 @@
+# Helper script to build CEBRA checkpoints
+
+This script builds CEBRA checkpoints for different versions of scikit-learn and CEBRA.
+To build all models, run:
+
+```bash
+./generate.sh
+```
+
+The models are currently also stored in git directly due to their small size.
+
+Related issue: https://github.com/AdaptiveMotorControlLab/CEBRA/issues/207
+Related test: tests/test_sklearn_legacy.py
diff --git a/tests/_build_legacy_model/create_model.py b/tests/_build_legacy_model/create_model.py
new file mode 100644
index 00000000..f308d296
--- /dev/null
+++ b/tests/_build_legacy_model/create_model.py
@@ -0,0 +1,15 @@
+import numpy as np
+
+import cebra
+
+neural_data = np.random.normal(0, 1, (1000, 30)) # 1000 samples, 30 features
+cebra_model = cebra.CEBRA(model_architecture="offset10-model",
+ batch_size=512,
+ learning_rate=1e-4,
+ max_iterations=10,
+ time_offsets=10,
+ num_hidden_units=16,
+ output_dimension=8,
+ verbose=True)
+cebra_model.fit(neural_data)
+cebra_model.save("cebra_model.pt")
diff --git a/tests/_build_legacy_model/generate.sh b/tests/_build_legacy_model/generate.sh
new file mode 100755
index 00000000..749a0d32
--- /dev/null
+++ b/tests/_build_legacy_model/generate.sh
@@ -0,0 +1,3 @@
+#!/bin/bash
+
+DOCKER_BUILDKIT=1 docker build --output type=local,dest=. .
diff --git a/tests/_util.py b/tests/_util.py
index b4a0e07d..42dd54cb 100644
--- a/tests/_util.py
+++ b/tests/_util.py
@@ -74,3 +74,8 @@ def parametrize_with_checks_slow(fast_arguments, slow_arguments):
slow_arg, generate_only=True))[0] for slow_arg in slow_arguments
]
return parametrize_slow("estimator,check", fast_params, slow_params)
+
+
+def parametrize_device(func):
+ _devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",)
+ return pytest.mark.parametrize("device", _devices)(func)
diff --git a/tests/_utils_deprecated.py b/tests/_utils_deprecated.py
new file mode 100644
index 00000000..bf412058
--- /dev/null
+++ b/tests/_utils_deprecated.py
@@ -0,0 +1,126 @@
+import warnings
+from typing import Optional, Union
+
+import numpy as np
+import numpy.typing as npt
+import sklearn.utils.validation as sklearn_utils_validation
+import torch
+
+import cebra
+import cebra.integrations.sklearn.utils as sklearn_utils
+import cebra.models
+
+
+#NOTE: Deprecated: transform is now handled in the solver but the original
+# method is kept here for testing.
+def cebra_transform_deprecated(cebra_model,
+ X: Union[npt.NDArray, torch.Tensor],
+ session_id: Optional[int] = None) -> npt.NDArray:
+ """Transform an input sequence and return the embedding.
+
+ Args:
+ cebra_model: The CEBRA model to use for the transform.
+ X: A numpy array or torch tensor of size ``time x dimension``.
+ session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for
+ multisession, set to ``None`` for single session.
+
+ Returns:
+ A :py:func:`numpy.array` of size ``time x output_dimension``.
+
+ Example:
+
+ >>> import cebra
+ >>> import numpy as np
+ >>> dataset = np.random.uniform(0, 1, (1000, 30))
+ >>> cebra_model = cebra.CEBRA(max_iterations=10)
+ >>> cebra_model.fit(dataset)
+ CEBRA(max_iterations=10)
+ >>> embedding = cebra_model.transform(dataset)
+
+ """
+ warnings.warn(
+ "The method is deprecated "
+ "but kept for testing purposes. "
+ "We recommend using `transform` instead.",
+ DeprecationWarning,
+ stacklevel=2)
+
+ sklearn_utils_validation.check_is_fitted(cebra_model, "n_features_")
+ model, offset = cebra_model._select_model(X, session_id)
+
+ # Input validation
+ X = sklearn_utils.check_input_array(X, min_samples=len(cebra_model.offset_))
+ input_dtype = X.dtype
+
+ with torch.no_grad():
+ model.eval()
+
+ if cebra_model.pad_before_transform:
+ X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)),
+ mode="edge")
+ X = torch.from_numpy(X).float().to(cebra_model.device_)
+
+ if isinstance(model, cebra.models.ConvolutionalModelMixin):
+ # Fully convolutional evaluation, switch (T, C) -> (1, C, T)
+ X = X.transpose(1, 0).unsqueeze(0)
+ output = model(X).cpu().numpy().squeeze(0).transpose(1, 0)
+ else:
+ # Standard evaluation, (T, C, dt)
+ output = model(X).cpu().numpy()
+
+ if input_dtype == "float64":
+ return output.astype(input_dtype)
+
+ return output
+
+
+# NOTE: Deprecated: batched transform can now be performed (more memory efficient)
+# using the transform method of the model, and handling padding is implemented
+# directly in the base Solver. This method is kept for testing purposes.
+@torch.no_grad()
+def multiobjective_transform_deprecated(solver: "cebra.solvers.Solver",
+ inputs: torch.Tensor) -> torch.Tensor:
+ """Transform the input data using the model.
+
+ Args:
+ solver: The solver containing the model and device.
+ inputs: The input data to transform.
+
+ Returns:
+ The transformed data.
+ """
+
+ warnings.warn(
+ "The method is deprecated "
+ "but kept for testing purposes. "
+ "We recommend using `transform` instead.",
+ DeprecationWarning,
+ stacklevel=2)
+
+ offset = solver.model.get_offset()
+ solver.model.eval()
+ X = inputs.cpu().numpy()
+ X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)), mode="edge")
+ X = torch.from_numpy(X).float().to(solver.device)
+
+ if isinstance(solver.model.module, cebra.models.ConvolutionalModelMixin):
+ # Fully convolutional evaluation, switch (T, C) -> (1, C, T)
+ X = X.transpose(1, 0).unsqueeze(0)
+ outputs = solver.model(X)
+
+ # switch back from (1, C, T) -> (T, C)
+ if isinstance(outputs, torch.Tensor):
+ assert outputs.dim() == 3 and outputs.shape[0] == 1
+ outputs = outputs.squeeze(0).transpose(1, 0)
+ elif isinstance(outputs, tuple):
+ assert all(tensor.dim() == 3 and tensor.shape[0] == 1
+ for tensor in outputs)
+ outputs = (output.squeeze(0).transpose(1, 0) for output in outputs)
+ outputs = tuple(outputs)
+ else:
+ raise ValueError("Invalid condition in solver.transform")
+ else:
+ # Standard evaluation, (T, C, dt)
+ outputs = solver.model(X)
+
+ return outputs
diff --git a/tests/test_api.py b/tests/test_api.py
index bc279cbd..4e514429 100644
--- a/tests/test_api.py
+++ b/tests/test_api.py
@@ -21,6 +21,5 @@
#
def test_api():
import cebra.distributions
- from cebra.distributions import TimedeltaDistribution
cebra.distributions.TimedeltaDistribution
diff --git a/tests/test_attribution.py b/tests/test_attribution.py
new file mode 100644
index 00000000..cfb8ad7a
--- /dev/null
+++ b/tests/test_attribution.py
@@ -0,0 +1,214 @@
+import numpy as np
+import pytest
+import torch
+
+import cebra.attribution._jacobian
+import cebra.attribution.jacobian_attribution as jacobian_attribution
+from cebra.attribution import attribution_models
+from cebra.models import Model
+
+
+class DummyModel(Model):
+
+ def __init__(self):
+ super().__init__(num_input=10, num_output=5)
+ self.linear = torch.nn.Linear(10, 5)
+
+ def forward(self, x):
+ return self.linear(x)
+
+ def get_offset(self):
+ return None
+
+
+@pytest.fixture
+def model():
+ return DummyModel()
+
+
+@pytest.fixture
+def input_data():
+ return torch.randn(100, 10)
+
+
+def test_neuron_gradient_method(model, input_data):
+ attribution = attribution_models.NeuronGradientMethod(model=model,
+ input_data=input_data,
+ output_dimension=5)
+
+ result = attribution.compute_attribution_map()
+
+ assert 'neuron-gradient' in result
+ assert 'neuron-gradient-convabs' in result
+ assert result['neuron-gradient'].shape == (100, 5, 10)
+
+
+def test_neuron_gradient_shap_method(model, input_data):
+ attribution = attribution_models.NeuronGradientShapMethod(
+ model=model, input_data=input_data, output_dimension=5)
+
+ result = attribution.compute_attribution_map(baselines="zeros")
+
+ assert 'neuron-gradient-shap' in result
+ assert 'neuron-gradient-shap-convabs' in result
+ assert result['neuron-gradient-shap'].shape == (100, 5, 10)
+
+ with pytest.raises(NotImplementedError):
+ attribution.compute_attribution_map(baselines="invalid")
+
+
+def test_feature_ablation_method(model, input_data):
+ attribution = attribution_models.FeatureAblationMethod(
+ model=model, input_data=input_data, output_dimension=5)
+
+ result = attribution.compute_attribution_map()
+
+ assert 'feature-ablation' in result
+ assert 'feature-ablation-convabs' in result
+ assert result['feature-ablation'].shape == (100, 5, 10)
+
+
+def test_integrated_gradients_method(model, input_data):
+ attribution = attribution_models.IntegratedGradientsMethod(
+ model=model, input_data=input_data, output_dimension=5)
+
+ result = attribution.compute_attribution_map()
+
+ assert 'integrated-gradients' in result
+ assert 'integrated-gradients-convabs' in result
+ assert result['integrated-gradients'].shape == (100, 5, 10)
+
+
+def test_batched_methods(model, input_data):
+ # Test batched version of NeuronGradientMethod
+ attribution = attribution_models.NeuronGradientMethodBatched(
+ model=model, input_data=input_data, output_dimension=5)
+
+ result = attribution.compute_attribution_map(batch_size=32)
+ assert 'neuron-gradient' in result
+ assert result['neuron-gradient'].shape == (100, 5, 10)
+
+ # Test batched version of IntegratedGradientsMethod
+ attribution = attribution_models.IntegratedGradientsMethodBatched(
+ model=model, input_data=input_data, output_dimension=5)
+
+ result = attribution.compute_attribution_map(batch_size=32)
+ assert 'integrated-gradients' in result
+ assert result['integrated-gradients'].shape == (100, 5, 10)
+
+
+def test_compute_metrics():
+ attribution = attribution_models.AttributionMap(model=None, input_data=None)
+
+ attribution_map = np.array([0.1, 0.8, 0.3, 0.9, 0.2])
+ ground_truth = np.array([False, True, False, True, False])
+
+ metrics = attribution.compute_metrics(attribution_map, ground_truth)
+
+ assert 'max_connected' in metrics
+ assert 'mean_connected' in metrics
+ assert 'min_connected' in metrics
+ assert 'max_nonconnected' in metrics
+ assert 'mean_nonconnected' in metrics
+ assert 'min_nonconnected' in metrics
+ assert 'gap_max' in metrics
+ assert 'gap_mean' in metrics
+ assert 'gap_min' in metrics
+ assert 'gap_minmax' in metrics
+ assert 'max_jacobian' in metrics
+ assert 'min_jacobian' in metrics
+
+
+def test_compute_attribution_score():
+ attribution = attribution_models.AttributionMap(model=None, input_data=None)
+
+ attribution_map = np.array([0.1, 0.8, 0.3, 0.9, 0.2])
+ ground_truth = np.array([False, True, False, True, False])
+
+ score = attribution.compute_attribution_score(attribution_map, ground_truth)
+ assert isinstance(score, float)
+ assert 0 <= score <= 1
+
+
+def test_jacobian_computation():
+ # Create a simple model and input for testing
+ model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(),
+ torch.nn.Linear(5, 3))
+ input_data = torch.randn(100, 10, requires_grad=True)
+
+ # Test basic Jacobian computation
+ jf, jhatg = jacobian_attribution.get_attribution_map(model=model,
+ input_data=input_data,
+ double_precision=True,
+ convert_to_numpy=True)
+
+ # Check shapes
+ assert jf.shape == (100, 3, 10) # (batch_size, output_dim, input_dim)
+ assert jhatg.shape == (100, 10, 3) # (batch_size, input_dim, output_dim)
+
+
+def test_tensor_conversion():
+ # Test CPU and double precision conversion
+ test_tensors = [torch.randn(10, 5), torch.randn(5, 3)]
+
+ converted = cebra.attribution._jacobian.tensors_to_cpu_and_double(
+ test_tensors)
+
+ for tensor in converted:
+ assert tensor.device.type == "cpu"
+ assert tensor.dtype == torch.float64
+
+ # Only test CUDA conversion if CUDA is available
+ if torch.cuda.is_available():
+ cuda_tensors = cebra.attribution._jacobian.tensors_to_cuda(
+ test_tensors, cuda_device="cuda")
+ for tensor in cuda_tensors:
+ assert tensor.is_cuda
+ else:
+ # Skip CUDA test with a message
+ pytest.skip("CUDA not available - skipping CUDA conversion test")
+
+
+def test_jacobian_with_hybrid_solver():
+ # Test Jacobian computation with hybrid solver
+ class HybridModel(torch.nn.Module):
+
+ def __init__(self):
+ super().__init__()
+ self.fc1 = torch.nn.Linear(10, 5)
+ self.fc2 = torch.nn.Linear(10, 3)
+
+ def forward(self, x):
+ return self.fc1(x), self.fc2(x)
+
+ model = HybridModel()
+ # Move model to CPU to ensure test works everywhere
+ model = model.cpu()
+ input_data = torch.randn(50, 10, requires_grad=True)
+
+ # Ensure input is on CPU
+ input_data = input_data.cpu()
+
+ jacobian = cebra.attribution._jacobian.compute_jacobian(
+ model=model,
+ input_vars=[input_data],
+ hybrid_solver=True,
+ convert_to_numpy=True,
+ cuda_device=None # Explicitly set to None to use CPU
+ )
+
+ # Check shape (batch_size, output_dim, input_dim)
+ assert jacobian.shape == (50, 8, 10) # 8 = 5 + 3 concatenated outputs
+
+
+def test_attribution_map_transforms():
+ model = torch.nn.Sequential(torch.nn.Linear(10, 5), torch.nn.ReLU(),
+ torch.nn.Linear(5, 3))
+ input_data = torch.randn(100, 10)
+
+ # Test different aggregation methods
+ for aggregate in ["mean", "sum", "max"]:
+ jf, jhatg = jacobian_attribution.get_attribution_map(
+ model=model, input_data=input_data, aggregate=aggregate)
+ assert isinstance(jf, np.ndarray)
+ assert isinstance(jhatg, np.ndarray)
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 41e67f42..8e49cc35 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -19,6 +19,3 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import argparse
-
-import pytest
diff --git a/tests/test_criterions.py b/tests/test_criterions.py
index 93a3b846..0d6f8ff2 100644
--- a/tests/test_criterions.py
+++ b/tests/test_criterions.py
@@ -19,7 +19,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
-import numpy as np
import pytest
import torch
from torch import nn
@@ -294,7 +293,7 @@ def _sample_dist_matrices(seed):
@pytest.mark.parametrize("seed", [42, 4242, 424242])
-def test_infonce(seed):
+def test_infonce_check_output_parts(seed):
pos_dist, neg_dist = _sample_dist_matrices(seed)
ref_loss, ref_align, ref_uniform = _reference_infonce(pos_dist, neg_dist)
diff --git a/tests/test_data_masking.py b/tests/test_data_masking.py
new file mode 100644
index 00000000..1b4976af
--- /dev/null
+++ b/tests/test_data_masking.py
@@ -0,0 +1,206 @@
+import copy
+
+import pytest
+import torch
+
+import cebra.data.mask
+from cebra.data.masking import MaskedMixin
+
+#### Tests for Mask class ####
+
+
+@pytest.mark.parametrize("mask", [
+ cebra.data.mask.RandomNeuronMask,
+ cebra.data.mask.RandomTimestepMask,
+ cebra.data.mask.NeuronBlockMask,
+])
+def test_random_mask(mask: cebra.data.mask.Mask):
+ data = torch.ones(
+ (10, 20,
+ 30)) # Example tensor with shape (batch_size, n_neurons, offset)
+ mask = mask(masking_value=0.5)
+ masked_data = mask.apply_mask(copy.deepcopy(data))
+
+ assert masked_data.shape == data.shape, "Masked data shape should match input data shape"
+ assert (masked_data <= 1).all() and (
+ masked_data >= 0).all(), "Masked data should only contain values 0 or 1"
+ assert torch.sum(masked_data) < torch.sum(
+ data), "Masked data should have fewer active neurons than original data"
+
+
+def test_timeblock_mask():
+ data = torch.ones(
+ (10, 20,
+ 30)) # Example tensor with shape (batch_size, n_neurons, offset)
+ mask = cebra.data.mask.TimeBlockMask(masking_value=(0.035, 10))
+ masked_data = mask.apply_mask(copy.deepcopy(data))
+
+ assert masked_data.shape == data.shape, "Masked data shape should match input data shape"
+ assert (masked_data <= 1).all() and (
+ masked_data >= 0).all(), "Masked data should only contain values 0 or 1"
+ assert torch.sum(masked_data) < torch.sum(
+ data), "Masked data should have fewer active neurons than original data"
+
+
+#### Tests for MaskedMixin class ####
+
+
+def test_masked_mixin_no_masks():
+ mixin = MaskedMixin()
+ data = torch.ones(
+ (10, 20,
+ 30)) # Example tensor with shape (batch_size, n_neurons, offset)
+ masked_data = mixin.apply_mask(copy.deepcopy(data))
+
+ assert torch.equal(
+ data,
+ masked_data), "Data should remain unchanged when no masks are applied"
+
+
+@pytest.mark.parametrize(
+ "mask", ["RandomNeuronMask", "RandomTimestepMask", "NeuronBlockMask"])
+def test_masked_mixin_random_mask(mask):
+ data = torch.ones(
+ (10, 20,
+ 30)) # Example tensor with shape (batch_size, n_neurons, offset)
+
+ mixin = MaskedMixin()
+ assert mixin.masks == [], "Masks should be empty initially"
+
+ mixin.set_masks({mask: 0.5})
+ assert len(mixin.masks) == 1, "One mask should be set"
+ assert isinstance(mixin.masks[0],
+ getattr(cebra.data.mask,
+ mask)), f"Mask should be of type {mask}"
+ if isinstance(mixin.masks[0], cebra.data.mask.NeuronBlockMask):
+ assert mixin.masks[
+ 0].mask_prop == 0.5, "Masking value should be set correctly"
+ else:
+ assert mixin.masks[
+ 0].mask_ratio == 0.5, "Masking value should be set correctly"
+
+ masked_data = mixin.apply_mask(copy.deepcopy(data))
+ assert masked_data.shape == data.shape, "Masked data shape should match input data shape"
+ assert not torch.equal(
+ data, masked_data), "Data should be modified when a mask is applied"
+
+ mixin.set_masks({mask: [0.5, 0.1]})
+ assert len(mixin.masks) == 1, "One mask should be set"
+ assert isinstance(mixin.masks[0],
+ getattr(cebra.data.mask,
+ mask)), f"Mask should be of type {mask}"
+ masked_data = mixin.apply_mask(copy.deepcopy(data))
+ assert masked_data.shape == data.shape, "Masked data shape should match input data shape"
+ assert not torch.equal(
+ data, masked_data), "Data should be modified when a mask is applied"
+
+ mixin.set_masks({mask: (0.3, 0.9, 0.05)})
+ assert len(mixin.masks) == 1, "One mask should be set"
+ assert isinstance(mixin.masks[0],
+ getattr(cebra.data.mask,
+ mask)), f"Mask should be of type {mask}"
+ masked_data = mixin.apply_mask(copy.deepcopy(data))
+ assert masked_data.shape == data.shape, "Masked data shape should match input data shape"
+ assert not torch.equal(
+ data, masked_data), "Data should be modified when a mask is applied"
+
+
+def test_apply_mask_with_time_block_mask():
+ mixin = MaskedMixin()
+
+ with pytest.raises(AssertionError, match="sampled_rate.*masked_seq_len"):
+ mixin.set_masks({"TimeBlockMask": 0.2})
+
+ with pytest.raises(AssertionError, match="(sampled_rate.*masked_seq_len)"):
+ mixin.set_masks({"TimeBlockMask": [0.2, 10]})
+
+ with pytest.raises(AssertionError, match="between.*0.0.*1.0"):
+ mixin.set_masks({"TimeBlockMask": (-2, 10)})
+
+ with pytest.raises(AssertionError, match="between.*0.0.*1.0"):
+ mixin.set_masks({"TimeBlockMask": (2, 10)})
+
+ with pytest.raises(AssertionError, match="integer.*greater"):
+ mixin.set_masks({"TimeBlockMask": (0.2, -10)})
+
+ with pytest.raises(AssertionError, match="integer.*greater"):
+ mixin.set_masks({"TimeBlockMask": (0.2, 5.5)})
+
+ mixin.set_masks({"TimeBlockMask": (0.035, 10)}) # Correct usage
+ data = torch.ones(
+ (10, 20,
+ 30)) # Example tensor with shape (batch_size, n_neurons, offset)
+ masked_data = mixin.apply_mask(copy.deepcopy(data))
+
+ assert masked_data.shape == data.shape, "Masked data shape should match input data shape"
+ assert not torch.equal(
+ data, masked_data), "Data should be modified when a mask is applied"
+
+
+def test_multiple_masks_mixin():
+ mixin = MaskedMixin()
+ mixin.set_masks({"RandomNeuronMask": 0.5, "RandomTimestepMask": 0.3})
+ data = torch.ones(
+ (10, 20,
+ 30)) # Example tensor with shape (batch_size, n_neurons, offset)
+
+ masked_data = mixin.apply_mask(copy.deepcopy(data))
+ assert masked_data.shape == data.shape, "Masked data shape should match input data shape"
+ assert not torch.equal(
+ data,
+ masked_data), "Data should be modified when multiple masks are applied"
+
+ masked_data2 = mixin.apply_mask(copy.deepcopy(masked_data))
+ assert masked_data2.shape == data.shape, "Masked data shape should match input data shape"
+ assert not torch.equal(
+ data,
+ masked_data2), "Data should be modified when multiple masks are applied"
+ assert not torch.equal(
+ masked_data, masked_data2
+ ), "Masked data should be different for different iterations"
+
+
+def test_single_dim_input():
+ mixin = MaskedMixin()
+ mixin.set_masks({"RandomNeuronMask": 0.5})
+ data = torch.ones((10, 1, 30)) # Single neuron
+ masked_data = mixin.apply_mask(copy.deepcopy(data))
+
+ assert masked_data.shape == data.shape, "Masked data shape should match input data shape"
+ assert not torch.equal(
+ data, masked_data), "Data should be modified even with a single neuron"
+
+ mixin = MaskedMixin()
+ mixin.set_masks({"RandomTimestepMask": 0.5})
+ data = torch.ones((10, 20, 1)) # Single timestep
+ masked_data = mixin.apply_mask(copy.deepcopy(data))
+
+ assert masked_data.shape == data.shape, "Masked data shape should match input data shape"
+ assert not torch.equal(
+ data,
+ masked_data), "Data should be modified even with a single timestep"
+
+
+def test_apply_mask_with_invalid_input():
+ mixin = MaskedMixin()
+ mixin.set_masks({"RandomNeuronMask": 0.5})
+
+ with pytest.raises(ValueError, match="Data must be a 3D tensor"):
+ data = torch.ones(
+ (10, 20)) # Invalid tensor shape (missing offset dimension)
+ mixin.apply_mask(data)
+
+ with pytest.raises(ValueError, match="Data must be a float32 tensor"):
+ data = torch.ones((10, 20, 30), dtype=torch.int32)
+ mixin.apply_mask(data)
+
+
+def test_apply_mask_with_chunk_size():
+ mixin = MaskedMixin()
+ mixin.set_masks({"RandomNeuronMask": 0.5})
+ data = torch.ones((10000, 20, 30)) # Large tensor to test chunking
+ masked_data = mixin.apply_mask(copy.deepcopy(data), chunk_size=1000)
+
+ assert masked_data.shape == data.shape, "Masked data shape should match input data shape"
+ assert not torch.equal(
+ data, masked_data), "Data should be modified when a mask is applied"
diff --git a/tests/test_datasets.py b/tests/test_datasets.py
index 6a7f9319..e8e03ff0 100644
--- a/tests/test_datasets.py
+++ b/tests/test_datasets.py
@@ -68,9 +68,7 @@ def test_demo():
@pytest.mark.requires_dataset
def test_hippocampus():
-
pytest.skip("Outdated")
-
dataset = cebra.datasets.init("rat-hippocampus-single")
loader = cebra.data.ContinuousDataLoader(
dataset=dataset,
@@ -99,7 +97,6 @@ def test_hippocampus():
@pytest.mark.requires_dataset
def test_monkey():
-
dataset = cebra.datasets.init(
"area2-bump-pos-active-passive",
path=pathlib.Path(_DEFAULT_DATADIR) / "monkey_reaching_preload_smth_40",
@@ -110,7 +107,6 @@ def test_monkey():
@pytest.mark.requires_dataset
def test_allen():
-
pytest.skip("Test takes too long")
ca_dataset = cebra.datasets.init("allen-movie-one-ca-VISp-100-train-10-111")
diff --git a/tests/test_demo.py b/tests/test_demo.py
index 4f0f146c..ce555db3 100644
--- a/tests/test_demo.py
+++ b/tests/test_demo.py
@@ -21,7 +21,6 @@
#
import glob
import re
-import sys
import pytest
diff --git a/tests/test_distributions.py b/tests/test_distributions.py
index d7151fd1..656559bb 100644
--- a/tests/test_distributions.py
+++ b/tests/test_distributions.py
@@ -43,7 +43,7 @@ def prepare(N=1000, n=128, d=5, probs=[0.3, 0.1, 0.6], device="cpu"):
continuous = torch.randn(N, d).to(device)
rand = torch.from_numpy(np.random.randint(0, N, (n,))).to(device)
- qidx = discrete[rand].to(device)
+ _ = discrete[rand].to(device)
query = continuous[rand] + 0.1 * torch.randn(n, d).to(device)
query = query.to(device)
@@ -173,7 +173,7 @@ def test_mixed():
discrete, continuous)
reference_idx = distribution.sample_prior(10)
- positive_idx = distribution.sample_conditional(reference_idx)
+ _ = distribution.sample_conditional(reference_idx)
# The conditional distribution p(· | disc, cont) should yield
# samples where the label exactly matches the reference sample.
@@ -193,7 +193,7 @@ def test_continuous(benchmark):
def _test_distribution(dist):
distribution = dist(continuous)
reference_idx = distribution.sample_prior(10)
- positive_idx = distribution.sample_conditional(reference_idx)
+ _ = distribution.sample_conditional(reference_idx)
return distribution
distribution = _test_distribution(
@@ -411,3 +411,16 @@ def test_new_delta_normal_with_multidimensional_index(delta, numerical_check):
pytest.skip(
"multivariate delta distribution can not accurately sample with the "
"given parameters. TODO: Add a warning message for these cases.")
+
+
+@pytest.mark.parametrize("time_offset", [1, 5, 10])
+def test_unified_distribution(time_offset):
+ dataset = cebra_datasets.init("demo-continuous-unified")
+ sampler = cebra_distr.UnifiedSampler(dataset, time_offset=time_offset)
+
+ num_samples = 5
+ sample = sampler.sample_prior(num_samples)
+ assert sample.shape == (dataset.num_sessions, num_samples)
+
+ positive = sampler.sample_conditional(sample)
+ assert positive.shape == (dataset.num_sessions, num_samples)
diff --git a/tests/test_dlc.py b/tests/test_dlc.py
index a19fe593..8ab29abd 100644
--- a/tests/test_dlc.py
+++ b/tests/test_dlc.py
@@ -29,6 +29,7 @@
import cebra.integrations.deeplabcut as cebra_dlc
from cebra import CEBRA
from cebra import load_data
+from cebra.data.load import read_hdf
# NOTE(stes): The original data URL is
# https://github.com/DeepLabCut/DeepLabCut/blob/main/examples
@@ -54,11 +55,7 @@ def test_imports():
def _load_dlc_dataframe(filename):
- try:
- df = pd.read_hdf(filename, "df_with_missing")
- except KeyError:
- df = pd.read_hdf(filename)
- return df
+ return read_hdf(filename)
def _get_annotated_data(url, keypoints):
diff --git a/tests/test_grid_search.py b/tests/test_grid_search.py
index 3f88ba12..c774ea02 100644
--- a/tests/test_grid_search.py
+++ b/tests/test_grid_search.py
@@ -20,7 +20,6 @@
# limitations under the License.
#
import numpy as np
-import pytest
import cebra
import cebra.grid_search
diff --git a/tests/test_integration_train.py b/tests/test_integration_train.py
index 06e6da40..238bbea7 100644
--- a/tests/test_integration_train.py
+++ b/tests/test_integration_train.py
@@ -20,7 +20,6 @@
# limitations under the License.
#
import itertools
-from typing import List
import pytest
import torch
diff --git a/tests/test_integration_xcebra.py b/tests/test_integration_xcebra.py
new file mode 100644
index 00000000..760e26ef
--- /dev/null
+++ b/tests/test_integration_xcebra.py
@@ -0,0 +1,190 @@
+import pickle
+
+import _utils_deprecated
+import numpy as np
+import pytest
+import torch
+
+import cebra
+import cebra.attribution
+import cebra.data
+import cebra.models
+import cebra.solver
+from cebra.data import ContrastiveMultiObjectiveLoader
+from cebra.data import DatasetxCEBRA
+from cebra.solver import MultiObjectiveConfig
+from cebra.solver.schedulers import LinearRampUp
+
+
+@pytest.fixture
+def synthetic_data():
+ import tempfile
+ import urllib.request
+ from pathlib import Path
+
+ url = "https://cebra.fra1.digitaloceanspaces.com/xcebra_synthetic_data.pkl"
+
+ # Create a persistent temp directory specific to this test
+ temp_dir = Path(tempfile.gettempdir()) / "cebra_test_data"
+ temp_dir.mkdir(exist_ok=True)
+ filepath = temp_dir / "synthetic_data.pkl"
+
+ if not filepath.exists():
+ urllib.request.urlretrieve(url, filepath)
+
+ with filepath.open('rb') as file:
+ return pickle.load(file)
+
+
+@pytest.fixture
+def device():
+ return "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def test_synthetic_data_training(synthetic_data, device):
+ # Setup data
+ neurons = synthetic_data['neurons']
+ latents = synthetic_data['latents']
+ n_latents = latents.shape[1]
+ Z1 = synthetic_data['Z1']
+ Z2 = synthetic_data['Z2']
+ gt_attribution_map = synthetic_data['gt_attribution_map']
+ data = DatasetxCEBRA(neurons, Z1=Z1, Z2=Z2)
+
+ # Configure training with reduced steps
+ TOTAL_STEPS = 50 # Reduced from 2000 for faster testing
+ loader = ContrastiveMultiObjectiveLoader(dataset=data,
+ num_steps=TOTAL_STEPS,
+ batch_size=512).to(device)
+
+ config = MultiObjectiveConfig(loader)
+ config.set_slice(0, 6)
+ config.set_loss("FixedEuclideanInfoNCE", temperature=1.)
+ config.set_distribution("time", time_offset=1)
+ config.push()
+
+ config.set_slice(3, 6)
+ config.set_loss("FixedEuclideanInfoNCE", temperature=1.)
+ config.set_distribution("time_delta", time_delta=1, label_name="Z2")
+ config.push()
+
+ config.finalize()
+
+ # Initialize model and solver
+ neural_model = cebra.models.init(
+ name="offset1-model-mse-clip-5-5",
+ num_neurons=data.neural.shape[1],
+ num_units=256,
+ num_output=n_latents,
+ ).to(device)
+
+ data.configure_for(neural_model)
+
+ opt = torch.optim.Adam(
+ list(neural_model.parameters()) + list(config.criterion.parameters()),
+ lr=3e-4,
+ weight_decay=0,
+ )
+
+ regularizer = cebra.models.jacobian_regularizer.JacobianReg()
+
+ solver = cebra.solver.init(
+ name="multiobjective-solver",
+ model=neural_model,
+ feature_ranges=config.feature_ranges,
+ regularizer=regularizer,
+ renormalize=False,
+ use_sam=False,
+ criterion=config.criterion,
+ optimizer=opt,
+ tqdm_on=False,
+ ).to(device)
+
+ # Train model with reduced steps for regularizer
+ weight_scheduler = LinearRampUp(
+ n_splits=2,
+ step_to_switch_on_reg=25, # Reduced from 2500
+ step_to_switch_off_reg=40, # Reduced from 15000
+ start_weight=0.,
+ end_weight=0.01,
+ stay_constant_after_switch_off=True)
+
+ solver.fit(
+ loader=loader,
+ valid_loader=None,
+ log_frequency=None,
+ scheduler_regularizer=weight_scheduler,
+ scheduler_loss=None,
+ )
+
+ # Basic test that model runs and produces output
+ solver.model.split_outputs = False
+ embedding = solver.model(data.neural.to(device)).detach().cpu()
+
+ # Verify output dimensions
+ assert embedding.shape[1] == n_latents, "Incorrect embedding dimension"
+ assert not torch.isnan(embedding).any(), "NaN values in embedding"
+
+ # Test attribution map functionality
+ data.neural.requires_grad_(True)
+ method = cebra.attribution.init(name="jacobian-based",
+ model=solver.model,
+ input_data=data.neural,
+ output_dimension=solver.model.num_output)
+
+ result = method.compute_attribution_map()
+ jfinv = abs(result['jf-inv-lsq']).mean(0)
+
+ # Verify attribution map output
+ assert not torch.isnan(
+ torch.tensor(jfinv)).any(), "NaN values in attribution map"
+ assert jfinv.shape == gt_attribution_map.shape, "Incorrect attribution map shape"
+
+ # Test split outputs functionality
+ solver.model.split_outputs = True
+ embedding_split = solver.model(data.neural.to(device))
+ Z1_hat = embedding_split[0].detach().cpu()
+ Z2_hat = embedding_split[1].detach().cpu()
+
+ # TODO(stes): Right now, this results in 6D output vs. 3D as expected. Need to double check
+ # the API docs on the desired behavior here, both could be fine...
+ # assert Z1_hat.shape == Z1.shape, f"Incorrect Z1 embedding dimension: {Z1_hat.shape}"
+ assert Z2_hat.shape == Z2.shape, f"Incorrect Z2 embedding dimension: {Z2_hat.shape}"
+ assert not torch.isnan(Z1_hat).any(), "NaN values in Z1 embedding"
+ assert not torch.isnan(Z2_hat).any(), "NaN values in Z2 embedding"
+
+ # Test the transform
+ solver.model.split_outputs = False
+ transform_embedding = solver.transform(data.neural.to(device))
+ assert transform_embedding.shape[
+ 1] == n_latents, "Incorrect embedding dimension"
+ assert not torch.isnan(transform_embedding).any(), "NaN values in embedding"
+ assert np.allclose(embedding, transform_embedding, rtol=1e-4, atol=1e-4)
+
+ # Test the transform with batching
+ batched_embedding = solver.transform(data.neural.to(device), batch_size=512)
+ assert batched_embedding.shape[
+ 1] == n_latents, "Incorrect embedding dimension"
+ assert not torch.isnan(batched_embedding).any(), "NaN values in embedding"
+ assert np.allclose(embedding, batched_embedding, rtol=1e-4, atol=1e-4)
+
+ assert np.allclose(transform_embedding,
+ batched_embedding,
+ rtol=1e-4,
+ atol=1e-4)
+
+ # Test and compare the previous transform (transform_deprecated)
+ deprecated_transform_embedding = _utils_deprecated.multiobjective_transform_deprecated(
+ solver, data.neural.to(device))
+ assert np.allclose(embedding,
+ deprecated_transform_embedding,
+ rtol=1e-4,
+ atol=1e-4)
+ assert np.allclose(transform_embedding,
+ deprecated_transform_embedding,
+ rtol=1e-4,
+ atol=1e-4)
+ assert np.allclose(batched_embedding,
+ deprecated_transform_embedding,
+ rtol=1e-4,
+ atol=1e-4)
diff --git a/tests/test_load.py b/tests/test_load.py
index 6f62dc92..4524b29c 100644
--- a/tests/test_load.py
+++ b/tests/test_load.py
@@ -22,10 +22,7 @@
import itertools
import pathlib
import pickle
-import platform
import tempfile
-import unittest
-from unittest.mock import patch
import h5py
import hdf5storage
@@ -125,7 +122,7 @@ def generate_numpy_confounder(filename, dtype):
@register("npz")
-def generate_numpy_path(filename, dtype):
+def generate_numpy_path_2(filename, dtype):
A = np.arange(1000, dtype=dtype).reshape(10, 100)
np.savez(filename, array=A, other_data="test")
loaded_A = cebra_load.load(pathlib.Path(filename))
@@ -251,7 +248,7 @@ def generate_h5_no_array(filename, dtype):
def generate_h5_dataframe(filename, dtype):
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
- df_A.to_hdf(filename, "df_A")
+ df_A.to_hdf(filename, key="df_A")
loaded_A = cebra_load.load(filename, key="df_A")
return A, loaded_A
@@ -261,7 +258,7 @@ def generate_h5_dataframe_columns(filename, dtype):
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
A_col = A[:, :2]
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
- df_A.to_hdf(filename, "df_A")
+ df_A.to_hdf(filename, key="df_A")
loaded_A = cebra_load.load(filename, key="df_A", columns=["a", "b"])
return A_col, loaded_A
@@ -272,8 +269,8 @@ def generate_h5_multi_dataframe(filename, dtype):
B = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
df_B = pd.DataFrame(np.array(B), columns=["c", "d", "e"])
- df_A.to_hdf(filename, "df_A")
- df_B.to_hdf(filename, "df_B")
+ df_A.to_hdf(filename, key="df_A")
+ df_B.to_hdf(filename, key="df_B")
loaded_A = cebra_load.load(filename, key="df_A")
return A, loaded_A
@@ -282,7 +279,7 @@ def generate_h5_multi_dataframe(filename, dtype):
def generate_h5_single_dataframe_no_key(filename, dtype):
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
- df_A.to_hdf(filename, "df_A")
+ df_A.to_hdf(filename, key="df_A")
loaded_A = cebra_load.load(filename)
return A, loaded_A
@@ -293,8 +290,8 @@ def generate_h5_multi_dataframe_no_key(filename, dtype):
B = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
df_B = pd.DataFrame(np.array(B), columns=["c", "d", "e"])
- df_A.to_hdf(filename, "df_A")
- df_B.to_hdf(filename, "df_B")
+ df_A.to_hdf(filename, key="df_A")
+ df_B.to_hdf(filename, key="df_B")
_ = cebra_load.load(filename)
@@ -307,7 +304,7 @@ def generate_h5_multicol_dataframe(filename, dtype):
df_A = pd.DataFrame(A,
columns=pd.MultiIndex.from_product([animals,
keypoints]))
- df_A.to_hdf(filename, "df_A")
+ df_A.to_hdf(filename, key="df_A")
loaded_A = cebra_load.load(filename, key="df_A")
return A, loaded_A
@@ -316,7 +313,7 @@ def generate_h5_multicol_dataframe(filename, dtype):
def generate_h5_dataframe_invalid_key(filename, dtype):
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
- df_A.to_hdf(filename, "df_A")
+ df_A.to_hdf(filename, key="df_A")
_ = cebra_load.load(filename, key="df_B")
@@ -324,7 +321,7 @@ def generate_h5_dataframe_invalid_key(filename, dtype):
def generate_h5_dataframe_invalid_column(filename, dtype):
A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype)
df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"])
- df_A.to_hdf(filename, "df_A")
+ df_A.to_hdf(filename, key="df_A")
_ = cebra_load.load(filename, key="df_A", columns=["d", "b"])
@@ -337,7 +334,7 @@ def generate_h5_multicol_dataframe_columns(filename, dtype):
df_A = pd.DataFrame(A,
columns=pd.MultiIndex.from_product([animals,
keypoints]))
- df_A.to_hdf(filename, "df_A")
+ df_A.to_hdf(filename, key="df_A")
_ = cebra_load.load(filename, key="df_A", columns=["a", "b"])
@@ -418,7 +415,7 @@ def generate_csv_path(filename, dtype):
@register_error("csv")
def generate_csv_empty_file(filename, dtype):
- with open(filename, "w") as creating_new_csv_file:
+ with open(filename, "w") as _:
pass
_ = cebra_load.load(filename)
@@ -619,7 +616,6 @@ def generate_pickle_invalid_key(filename, dtype):
@register_error("pkl", "p")
def generate_pickle_no_array(filename, dtype):
- A = np.arange(1000, dtype=dtype).reshape(10, 100)
with open(filename, "wb") as f:
pickle.dump({"A": "test_1", "B": "test_2"}, f)
_ = cebra_load.load(filename)
diff --git a/tests/test_loader.py b/tests/test_loader.py
index 562f64a7..cb6be9a7 100644
--- a/tests/test_loader.py
+++ b/tests/test_loader.py
@@ -19,16 +19,16 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+import _util
+import numpy as np
import pytest
import torch
import cebra.data
import cebra.io
-
-def parametrize_device(func):
- _devices = ("cpu", "cuda") if torch.cuda.is_available() else ("cpu",)
- return pytest.mark.parametrize("device", _devices)(func)
+BATCH_SIZE = 32
+NUMS_NEURAL = [3, 4, 5]
class LoadSpeed:
@@ -107,7 +107,11 @@ def _assert_dataset_on_correct_device(loader, device):
assert hasattr(loader, "dataset")
assert hasattr(loader, "device")
assert isinstance(loader.dataset, cebra.io.HasDevice)
- assert loader.dataset.neural.device.type == device
+ if isinstance(loader, cebra.data.SingleSessionDataset):
+ assert loader.dataset.neural.device.type == device
+ elif isinstance(loader, cebra.data.MultiSessionDataset):
+ for session in loader.dataset.iter_sessions():
+ assert session.neural.device.type == device
def test_demo_data():
@@ -130,13 +134,15 @@ def _to_str(val):
assert _to_str(first) == _to_str(second)
-@parametrize_device
+@_util.parametrize_device
@pytest.mark.parametrize(
"data_name, loader_initfunc",
[
("demo-discrete", cebra.data.DiscreteDataLoader),
("demo-continuous", cebra.data.ContinuousDataLoader),
("demo-mixed", cebra.data.MixedDataLoader),
+ ("demo-continuous-multisession", cebra.data.MultiSessionLoader),
+ ("demo-continuous-unified", cebra.data.UnifiedLoader),
],
)
def test_device(data_name, loader_initfunc, device):
@@ -147,7 +153,7 @@ def test_device(data_name, loader_initfunc, device):
other_device = swap.get(device)
dataset = RandomDataset(N=100, device=other_device)
- loader = loader_initfunc(dataset, num_steps=10, batch_size=32)
+ loader = loader_initfunc(dataset, num_steps=10, batch_size=BATCH_SIZE)
loader.to(device)
assert loader.dataset == dataset
_assert_device(loader.device, device)
@@ -156,7 +162,7 @@ def test_device(data_name, loader_initfunc, device):
_assert_device(loader.get_indices(10).reference.device, device)
-@parametrize_device
+@_util.parametrize_device
@pytest.mark.parametrize("prior", ("uniform", "empirical"))
def test_discrete(prior, device, benchmark):
dataset = RandomDataset(N=100, device=device)
@@ -171,7 +177,7 @@ def test_discrete(prior, device, benchmark):
benchmark(load_speed)
-@parametrize_device
+@_util.parametrize_device
@pytest.mark.parametrize("conditional", ("time", "time_delta"))
def test_continuous(conditional, device, benchmark):
dataset = RandomDataset(N=100, d=5, device=device)
@@ -199,7 +205,7 @@ def _check_attributes(obj, is_list=False):
raise TypeError()
-@parametrize_device
+@_util.parametrize_device
@pytest.mark.parametrize(
"data_name, loader_initfunc",
[
@@ -211,7 +217,7 @@ def _check_attributes(obj, is_list=False):
def test_singlesession_loader(data_name, loader_initfunc, device):
data = cebra.datasets.init(data_name)
data.to(device)
- loader = loader_initfunc(data, num_steps=10, batch_size=32)
+ loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE)
_assert_dataset_on_correct_device(loader, device)
index = loader.get_indices(100)
@@ -219,25 +225,33 @@ def test_singlesession_loader(data_name, loader_initfunc, device):
for batch in loader:
_check_attributes(batch)
- assert len(batch.positive) == 32
+ assert len(batch.positive) == BATCH_SIZE
-def test_multisession_cont_loader():
- data = cebra.datasets.MultiContinuous(nums_neural=[3, 4, 5],
- num_behavior=5,
- num_timepoints=100)
- loader = cebra.data.ContinuousMultiSessionDataLoader(
- data,
- num_steps=10,
- batch_size=32,
- )
+@_util.parametrize_device
+@pytest.mark.parametrize(
+ "data_name, loader_initfunc",
+ [
+ ("demo-continuous-multisession",
+ cebra.data.ContinuousMultiSessionDataLoader),
+ ("demo-discrete-multisession",
+ cebra.data.DiscreteMultiSessionDataLoader),
+ ],
+)
+def test_multisession_loader(data_name, loader_initfunc, device):
+ data = cebra.datasets.init(data_name)
+ data.to(device)
+ loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE)
+
+ _assert_dataset_on_correct_device(loader, device)
# Check the sampler
assert hasattr(loader, "sampler")
ref_idx = loader.sampler.sample_prior(1000)
- assert len(ref_idx) == 3 # num_sessions
- for session in range(3):
- assert ref_idx[session].max() < 100
+ assert len(ref_idx) == len(NUMS_NEURAL)
+ for session in range(len(NUMS_NEURAL)):
+ assert ref_idx[session].max(
+ ) < cebra.datasets.demo._DEFAULT_NUM_TIMEPOINTS
pos_idx, idx, idx_rev = loader.sampler.sample_conditional(ref_idx)
assert pos_idx is not None
@@ -245,6 +259,8 @@ def test_multisession_cont_loader():
assert idx_rev is not None
batch = next(iter(loader))
+ for i, n_neurons in enumerate(NUMS_NEURAL):
+ assert batch[i].reference.shape == (BATCH_SIZE, n_neurons, 10)
def _mix(array, idx):
shape = array.shape
@@ -259,82 +275,70 @@ def _process(batch, feature_dim=1):
[b.reference.flatten(1).mean(dim=1, keepdims=True) for b in batch],
dim=0).repeat(1, 1, feature_dim)
- assert batch[0].reference.shape == (32, 3, 10)
- assert batch[1].reference.shape == (32, 4, 10)
- assert batch[2].reference.shape == (32, 5, 10)
-
dummy_prediction = _process(batch, feature_dim=6)
- assert dummy_prediction.shape == (3, 32, 6)
+ assert dummy_prediction.shape == (3, BATCH_SIZE, 6)
_mix(dummy_prediction, batch[0].index)
+ index = loader.get_indices(100)
+ #print(index[0])
+ #print(type(index))
+ _check_attributes(index, is_list=False)
-def test_multisession_disc_loader():
- data = cebra.datasets.MultiDiscrete(nums_neural=[3, 4, 5],
- num_timepoints=100)
- loader = cebra.data.DiscreteMultiSessionDataLoader(
- data,
- num_steps=10,
- batch_size=32,
- )
-
- # Check the sampler
- assert hasattr(loader, "sampler")
- ref_idx = loader.sampler.sample_prior(1000)
- assert len(ref_idx) == 3 # num_sessions
-
- # Check sample points are in session length range
- for session in range(3):
- assert ref_idx[session].max() < loader.sampler.session_lengths[session]
- pos_idx, idx, idx_rev = loader.sampler.sample_conditional(ref_idx)
-
- assert pos_idx is not None
- assert idx is not None
- assert idx_rev is not None
-
- batch = next(iter(loader))
-
- def _mix(array, idx):
- shape = array.shape
- n, m = shape[:2]
- mixed = array.reshape(n * m, -1)[idx]
- print(mixed.shape, array.shape, idx.shape)
- return mixed.reshape(shape)
-
- def _process(batch, feature_dim=1):
- """Given list_i[(N,d_i)] batch, return (#session, N, feature_dim) tensor"""
- return torch.stack(
- [b.reference.flatten(1).mean(dim=1, keepdims=True) for b in batch],
- dim=0).repeat(1, 1, feature_dim)
-
- assert batch[0].reference.shape == (32, 3, 10)
- assert batch[1].reference.shape == (32, 4, 10)
- assert batch[2].reference.shape == (32, 5, 10)
-
- dummy_prediction = _process(batch, feature_dim=6)
- assert dummy_prediction.shape == (3, 32, 6)
- _mix(dummy_prediction, batch[0].index)
+ for batch in loader:
+ _check_attributes(batch, is_list=True)
+ for session_batch in batch:
+ assert len(session_batch.positive) == BATCH_SIZE
-@parametrize_device
+@_util.parametrize_device
@pytest.mark.parametrize(
"data_name, loader_initfunc",
- [('demo-discrete-multisession', cebra.data.DiscreteMultiSessionDataLoader),
- ("demo-continuous-multisession",
- cebra.data.ContinuousMultiSessionDataLoader)],
+ [
+ ("demo-continuous-unified", cebra.data.UnifiedLoader),
+ ],
)
-def test_multisession_loader(data_name, loader_initfunc, device):
- # TODO change number of timepoints across the sessions
-
+def test_unified_loader(data_name, loader_initfunc, device):
data = cebra.datasets.init(data_name)
- kwargs = dict(num_steps=10, batch_size=32)
- loader = loader_initfunc(data, **kwargs)
+ data.to(device)
+ loader = loader_initfunc(data, num_steps=10, batch_size=BATCH_SIZE)
+
+ _assert_dataset_on_correct_device(loader, device)
+
+ # Check the sampler
+ num_samples = 100
+ assert hasattr(loader, "sampler")
+ ref_idx = loader.sampler.sample_all_uniform_prior(num_samples)
+ assert ref_idx.shape == (len(NUMS_NEURAL), num_samples)
+ assert isinstance(ref_idx, np.ndarray)
+
+ for session in range(len(NUMS_NEURAL)):
+ assert ref_idx[session].max(
+ ) < cebra.datasets.demo._DEFAULT_NUM_TIMEPOINTS
+ pos_idx = loader.sampler.sample_conditional(ref_idx)
+ assert pos_idx.shape == (len(NUMS_NEURAL), num_samples)
+
+ for session in range(len(NUMS_NEURAL)):
+ ref_idx = torch.from_numpy(
+ loader.sampler.sample_all_uniform_prior(
+ num_samples=num_samples)[session])
+ assert ref_idx.shape == (num_samples,)
+ all_ref_idx = loader.sampler.sample_all_sessions(ref_idx=ref_idx,
+ session_id=session)
+ assert all_ref_idx.shape == (len(NUMS_NEURAL), num_samples)
+ assert isinstance(all_ref_idx, torch.Tensor)
+ for i in range(len(all_ref_idx)):
+ assert all_ref_idx[i].max(
+ ) < cebra.datasets.demo._DEFAULT_NUM_TIMEPOINTS
+
+ for i in range(len(all_ref_idx)):
+ pos_idx = loader.sampler.sample_conditional(all_ref_idx)
+ assert pos_idx.shape == (len(NUMS_NEURAL), num_samples)
+
+ # Check the batch
+ batch = next(iter(loader))
+ assert batch.reference.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10)
+ assert batch.positive.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10)
+ assert batch.negative.shape == (BATCH_SIZE, sum(NUMS_NEURAL), 10)
index = loader.get_indices(100)
- print(index[0])
- print(type(index))
_check_attributes(index, is_list=False)
-
- for batch in loader:
- _check_attributes(batch, is_list=True)
- for session_batch in batch:
- assert len(session_batch.positive) == 32
diff --git a/tests/test_models.py b/tests/test_models.py
index 2a6e4812..658cc467 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -90,6 +90,10 @@ def test_offset_models(model_name, batch_size, input_length):
def test_multiobjective():
+ # NOTE(stes): This test is deprecated and will be removed in a future version.
+ # As of CEBRA 0.6.0, the multi objective models are tested separately in
+ # test_multiobjective.py.
+
class TestModel(cebra.models.Model):
def __init__(self):
@@ -155,8 +159,8 @@ def test_version_check(version, raises):
cebra.models.model._check_torch_version(raise_error=True)
-def test_version_check():
- raises = not cebra.models.model._check_torch_version(raise_error=False)
+def test_version_check_dropout_available():
+ raises = cebra.models.model._check_torch_version(raise_error=False)
if raises:
assert len(cebra.models.get_options("*dropout*")) == 0
else:
diff --git a/tests/test_multiobjective.py b/tests/test_multiobjective.py
new file mode 100644
index 00000000..a4c601ac
--- /dev/null
+++ b/tests/test_multiobjective.py
@@ -0,0 +1,145 @@
+#
+# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables
+# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+)
+# Source code:
+# https://github.com/AdaptiveMotorControlLab/CEBRA
+#
+# Please see LICENSE.md for the full license document:
+# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import warnings
+
+import pytest
+import torch
+
+import cebra
+from cebra.data import ContrastiveMultiObjectiveLoader
+from cebra.data import DatasetxCEBRA
+from cebra.solver import MultiObjectiveConfig
+
+
+@pytest.fixture
+def config():
+ neurons = torch.randn(100, 5)
+ behavior1 = torch.randn(100, 2)
+ behavior2 = torch.randn(100, 1)
+ data = DatasetxCEBRA(neurons, behavior1=behavior1, behavior2=behavior2)
+ loader = ContrastiveMultiObjectiveLoader(dataset=data,
+ num_steps=1,
+ batch_size=24)
+ return MultiObjectiveConfig(loader)
+
+
+def test_imports():
+ pass
+
+
+def test_add_data(config):
+ config.set_slice(0, 10)
+ config.set_loss('loss_name', param1='value1')
+ config.set_distribution('distribution_name', param2='value2')
+ config.push()
+
+ assert len(config.total_info) == 1
+ assert config.total_info[0]['slice'] == (0, 10)
+ assert config.total_info[0]['losses'] == {
+ "name": 'loss_name',
+ "kwargs": {
+ 'param1': 'value1'
+ }
+ }
+ assert config.total_info[0]['distributions'] == {
+ "name": 'distribution_name',
+ "kwargs": {
+ 'param2': 'value2'
+ }
+ }
+
+
+def test_overwriting_key_warning(config):
+ with warnings.catch_warnings(record=True) as w:
+ config.set_slice(0, 10)
+ config.set_slice(10, 20)
+ assert len(w) == 1
+ assert issubclass(w[-1].category, UserWarning)
+ assert "Configuration key already exists" in str(w[-1].message)
+
+
+def test_missing_slice_error(config):
+ with pytest.raises(RuntimeError, match="Slice configuration is missing"):
+ config.set_loss('loss_name', param1='value1')
+ config.set_distribution('distribution_name', param2='value2')
+ config.push()
+
+
+def test_missing_distributions_error(config):
+ with pytest.raises(RuntimeError,
+ match="Distributions configuration is missing"):
+ config.set_slice(0, 10)
+ config.set_loss('loss_name', param1='value1')
+ config.push()
+
+
+def test_missing_losses_error(config):
+ with pytest.raises(RuntimeError, match="Losses configuration is missing"):
+ config.set_slice(0, 10)
+ config.set_distribution('distribution_name', param2='value2')
+ config.push()
+
+
+def test_finalize(config):
+ config.set_slice(0, 6)
+ config.set_loss("FixedEuclideanInfoNCE", temperature=1.)
+ config.set_distribution("time", time_offset=1)
+ config.push()
+
+ config.set_slice(3, 6)
+ config.set_loss("FixedEuclideanInfoNCE", temperature=1.)
+ config.set_distribution("time_delta", time_delta=3, label_name="behavior2")
+ config.push()
+
+ config.finalize()
+
+ assert len(config.losses) == 2
+ assert config.losses[0]['indices'] == (0, 6)
+ assert config.losses[1]['indices'] == (3, 6)
+
+ assert len(config.feature_ranges) == 2
+ assert config.feature_ranges[0] == slice(0, 6)
+ assert config.feature_ranges[1] == slice(3, 6)
+
+ assert len(config.loader.distributions) == 2
+ assert isinstance(config.loader.distributions[0],
+ cebra.distributions.continuous.TimeContrastive)
+ assert config.loader.distributions[0].time_offset == 1
+
+ assert isinstance(config.loader.distributions[1],
+ cebra.distributions.continuous.TimedeltaDistribution)
+ assert config.loader.distributions[1].time_delta == 3
+
+
+def test_non_unique_feature_ranges_error(config):
+ config.set_slice(0, 10)
+ config.set_loss("FixedEuclideanInfoNCE", temperature=1.)
+ config.set_distribution("time", time_offset=1)
+ config.push()
+
+ config.set_slice(0, 10)
+ config.set_loss("FixedEuclideanInfoNCE", temperature=1.)
+ config.set_distribution("time_delta", time_delta=3, label_name="behavior2")
+ config.push()
+
+ with pytest.raises(RuntimeError, match="Feature ranges are not unique"):
+ config.finalize()
diff --git a/tests/test_plot.py b/tests/test_plot.py
index 3f44d887..1d94d310 100644
--- a/tests/test_plot.py
+++ b/tests/test_plot.py
@@ -72,8 +72,6 @@ def test_plot_imports():
def test_colormaps():
import matplotlib
- import cebra
-
cmap = matplotlib.colormaps["cebra"]
assert cmap is not None
plt.scatter([1], [2], c=[2], cmap="cebra")
@@ -241,7 +239,7 @@ def test_compare_models():
_ = cebra_plot.compare_models(models, labels=long_labels, ax=ax)
with pytest.raises(ValueError, match="Invalid.*labels"):
invalid_labels = copy.deepcopy(labels)
- ele = invalid_labels.pop()
+ _ = invalid_labels.pop()
invalid_labels.append(["a"])
_ = cebra_plot.compare_models(models, labels=invalid_labels, ax=ax)
diff --git a/tests/test_registry.py b/tests/test_registry.py
index 69e04f38..cd27344c 100644
--- a/tests/test_registry.py
+++ b/tests/test_registry.py
@@ -117,7 +117,7 @@ def test_override():
_Foo1 = test_module.register("foo")(Foo)
assert _Foo1 == Foo
assert _Foo1 != Bar
- assert f"foo" in test_module.get_options()
+ assert "foo" in test_module.get_options()
# Check that the class was actually added to the module
assert (
@@ -137,7 +137,7 @@ def test_override():
_Foo2 = test_module.register("foo", override=True)(Bar)
assert _Foo2 != Foo
assert _Foo2 == Bar
- assert f"foo" in test_module.get_options()
+ assert "foo" in test_module.get_options()
def test_depreciation():
@@ -145,7 +145,7 @@ def test_depreciation():
Foo = _make_class()
_Foo1 = test_module.register("foo")(Foo)
assert _Foo1 == Foo
- assert f"foo" in test_module.get_options()
+ assert "foo" in test_module.get_options()
# Registering the same class under different names
# also raises and error
diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py
index 33df3caf..8c7cd0a1 100644
--- a/tests/test_sklearn.py
+++ b/tests/test_sklearn.py
@@ -24,6 +24,7 @@
import warnings
import _util
+import _utils_deprecated
import numpy as np
import pkg_resources
import pytest
@@ -276,7 +277,6 @@ def test_api(estimator, check):
pytest.skip(f"Model architecture {estimator.model_architecture} "
f"requires longer input sizes than 20 samples.")
- success = True
exception = None
num_successful = 0
total_runs = 0
@@ -334,7 +334,6 @@ def test_sklearn(model_architecture, device):
y_c1 = np.random.uniform(0, 1, (1000, 5))
y_c1_s2 = np.random.uniform(0, 1, (800, 5))
y_c2 = np.random.uniform(0, 1, (1000, 2))
- y_c2_s2 = np.random.uniform(0, 1, (800, 2))
y_d = np.random.randint(0, 10, (1000,))
y_d_s2 = np.random.randint(0, 10, (800,))
@@ -863,7 +862,6 @@ def test_sklearn_full(model_architecture, device, pad_before_transform):
X = np.random.uniform(0, 1, (1000, 50))
y_c1 = np.random.uniform(0, 1, (1000, 5))
y_c2 = np.random.uniform(0, 1, (1000, 2))
- y_d = np.random.randint(0, 10, (1000,))
# time contrastive
cebra_model.fit(X)
@@ -931,7 +929,7 @@ def test_sklearn_resampling_model_not_yet_supported(model_architecture, device):
with pytest.raises(ValueError):
cebra_model.fit(X, y_c1)
- output = cebra_model.transform(X)
+ _ = cebra_model.transform(X)
def _iterate_actions():
@@ -1378,18 +1376,16 @@ def test_new_transform(model_architecture, device):
# example dataset
X = np.random.uniform(0, 1, (1000, 50))
X_s2 = np.random.uniform(0, 1, (800, 30))
- X_s3 = np.random.uniform(0, 1, (1000, 30))
y_c1 = np.random.uniform(0, 1, (1000, 5))
y_c1_s2 = np.random.uniform(0, 1, (800, 5))
y_c2 = np.random.uniform(0, 1, (1000, 2))
- y_c2_s2 = np.random.uniform(0, 1, (800, 2))
y_d = np.random.randint(0, 10, (1000,))
y_d_s2 = np.random.randint(0, 10, (800,))
# time contrastive
cebra_model.fit(X)
embedding1 = cebra_model.transform(X)
- embedding2 = cebra_model.transform_deprecated(X)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
@@ -1398,17 +1394,20 @@ def test_new_transform(model_architecture, device):
assert cebra_model.num_sessions is None
embedding1 = cebra_model.transform(X)
- embedding2 = cebra_model.transform_deprecated(X)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
embedding1 = cebra_model.transform(torch.Tensor(X))
- embedding2 = cebra_model.transform_deprecated(torch.Tensor(X))
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(
+ cebra_model, torch.Tensor(X))
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0)
- embedding2 = cebra_model.transform_deprecated(torch.Tensor(X), session_id=0)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model,
+ torch.Tensor(X),
+ session_id=0)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
@@ -1418,14 +1417,14 @@ def test_new_transform(model_architecture, device):
# discrete behavior contrastive
cebra_model.fit(X, y_d)
embedding1 = cebra_model.transform(X)
- embedding2 = cebra_model.transform_deprecated(X)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
# mixed
cebra_model.fit(X, y_c1, y_c2, y_d)
embedding1 = cebra_model.transform(X)
- embedding2 = cebra_model.transform_deprecated(X)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model, X)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
@@ -1433,17 +1432,23 @@ def test_new_transform(model_architecture, device):
cebra_model.fit([X, X_s2], [y_d, y_d_s2])
embedding1 = cebra_model.transform(X, session_id=0)
- embedding2 = cebra_model.transform_deprecated(X, session_id=0)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model,
+ X,
+ session_id=0)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0)
- embedding2 = cebra_model.transform_deprecated(torch.Tensor(X), session_id=0)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model,
+ torch.Tensor(X),
+ session_id=0)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
embedding1 = cebra_model.transform(X_s2, session_id=1)
- embedding2 = cebra_model.transform_deprecated(X_s2, session_id=1)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model,
+ X_s2,
+ session_id=1)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
@@ -1451,12 +1456,16 @@ def test_new_transform(model_architecture, device):
cebra_model.fit([X, X_s2], [y_c1, y_c1_s2])
embedding1 = cebra_model.transform(X, session_id=0)
- embedding2 = cebra_model.transform_deprecated(X, session_id=0)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model,
+ X,
+ session_id=0)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
embedding1 = cebra_model.transform(torch.Tensor(X), session_id=0)
- embedding2 = cebra_model.transform_deprecated(torch.Tensor(X), session_id=0)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model,
+ torch.Tensor(X),
+ session_id=0)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
@@ -1475,17 +1484,23 @@ def test_new_transform(model_architecture, device):
cebra_model.fit([X, X_s2, X], [y_d, y_d_s2, y_d])
embedding1 = cebra_model.transform(X, session_id=0)
- embedding2 = cebra_model.transform_deprecated(X, session_id=0)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model,
+ X,
+ session_id=0)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
embedding1 = cebra_model.transform(X_s2, session_id=1)
- embedding2 = cebra_model.transform_deprecated(X_s2, session_id=1)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model,
+ X_s2,
+ session_id=1)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
embedding1 = cebra_model.transform(X, session_id=2)
- embedding2 = cebra_model.transform_deprecated(X, session_id=2)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model,
+ X,
+ session_id=2)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
@@ -1493,25 +1508,31 @@ def test_new_transform(model_architecture, device):
cebra_model.fit([X, X_s2, X], [y_c1, y_c1_s2, y_c1])
embedding1 = cebra_model.transform(X, session_id=0)
- embedding2 = cebra_model.transform_deprecated(X, session_id=0)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model,
+ X,
+ session_id=0)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
embedding1 = cebra_model.transform(X_s2, session_id=1)
- embedding2 = cebra_model.transform_deprecated(X_s2, session_id=1)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model,
+ X_s2,
+ session_id=1)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
embedding1 = cebra_model.transform(X, session_id=2)
- embedding2 = cebra_model.transform_deprecated(X, session_id=2)
+ embedding2 = _utils_deprecated.cebra_transform_deprecated(cebra_model,
+ X,
+ session_id=2)
assert np.allclose(embedding1, embedding2, rtol=1e-5,
atol=1e-8), "Arrays are not close enough"
def test_last_incomplete_batch_smaller_than_offset():
"""
- When offset of the model is larger than the remaining samples in the
- last batch, an error could happen. We merge the penultimate
+ When offset of the model is larger than the remaining samples in the
+ last batch, an error could happen. We merge the penultimate
and last batches together to avoid this.
"""
train = cebra.data.TensorDataset(neural=np.random.rand(20111, 100),
@@ -1522,4 +1543,4 @@ def test_last_incomplete_batch_smaller_than_offset():
device="cpu")
model.fit(train.neural, train.continuous)
- _ = model.transform(train.neural, batch_size=300)
\ No newline at end of file
+ _ = model.transform(train.neural, batch_size=300)
diff --git a/tests/test_sklearn_legacy.py b/tests/test_sklearn_legacy.py
new file mode 100644
index 00000000..4d74515f
--- /dev/null
+++ b/tests/test_sklearn_legacy.py
@@ -0,0 +1,41 @@
+import pathlib
+import urllib.request
+
+import numpy as np
+import pytest
+
+from cebra.integrations.sklearn.cebra import CEBRA
+
+MODEL_VARIANTS = [
+ "cebra-0.4.0-scikit-learn-1.4", "cebra-0.4.0-scikit-learn-1.6",
+ "cebra-rc-scikit-learn-1.4", "cebra-rc-scikit-learn-1.6"
+]
+
+
+@pytest.mark.parametrize("model_variant", MODEL_VARIANTS)
+def test_load_legacy_model(model_variant):
+ """Test loading a legacy CEBRA model."""
+
+ X = np.random.normal(0, 1, (1000, 30))
+
+ model_path = pathlib.Path(
+ __file__
+ ).parent / "_build_legacy_model" / f"cebra_model_{model_variant}.pt"
+
+ if not model_path.exists():
+ url = f"https://cebra.fra1.digitaloceanspaces.com/cebra_model_{model_variant}.pt"
+ urllib.request.urlretrieve(url, model_path)
+
+ loaded_model = CEBRA.load(model_path)
+
+ assert loaded_model.model_architecture == "offset10-model"
+ assert loaded_model.output_dimension == 8
+ assert loaded_model.num_hidden_units == 16
+ assert loaded_model.time_offsets == 10
+
+ output = loaded_model.transform(X)
+ assert isinstance(output, np.ndarray)
+ assert output.shape[1] == loaded_model.output_dimension
+
+ assert hasattr(loaded_model, "state_dict_")
+ assert hasattr(loaded_model, "n_features_")
diff --git a/tests/test_sklearn_metrics.py b/tests/test_sklearn_metrics.py
index 58e12010..4e765ba7 100644
--- a/tests/test_sklearn_metrics.py
+++ b/tests/test_sklearn_metrics.py
@@ -383,3 +383,132 @@ def test_sklearn_runs_consistency():
with pytest.raises(ValueError, match="Invalid.*embeddings"):
_, _, _ = cebra_sklearn_metrics.consistency_score(
invalid_embeddings_runs, between="runs")
+
+
+@pytest.mark.parametrize("seed", [42, 24, 10])
+def test_goodness_of_fit_score(seed):
+ """
+ Ensure that the GoF score is close to 0 for a model fit on random data.
+ """
+ cebra_model = cebra_sklearn_cebra.CEBRA(
+ model_architecture="offset1-model",
+ max_iterations=5,
+ batch_size=512,
+ )
+ generator = torch.Generator().manual_seed(seed)
+ X = torch.rand(5000, 50, dtype=torch.float32, generator=generator)
+ y = torch.rand(5000, 5, dtype=torch.float32, generator=generator)
+ cebra_model.fit(X, y)
+ score = cebra_sklearn_metrics.goodness_of_fit_score(cebra_model,
+ X,
+ y,
+ session_id=0,
+ num_batches=500)
+ assert isinstance(score, float)
+ assert np.isclose(score, 0, atol=0.01)
+
+
+@pytest.mark.parametrize("seed", [42, 24, 10])
+def test_goodness_of_fit_history(seed):
+ """
+ Ensure that the GoF score is higher for a model fit on data with underlying
+ structure than for a model fit on random data.
+ """
+
+ # Generate data
+ generator = torch.Generator().manual_seed(seed)
+ X = torch.rand(1000, 50, dtype=torch.float32, generator=generator)
+ y_random = torch.rand(len(X), 5, dtype=torch.float32, generator=generator)
+ linear_map = torch.randn(50, 5, dtype=torch.float32, generator=generator)
+ y_linear = X @ linear_map
+
+ def _fit_and_get_history(X, y):
+ cebra_model = cebra_sklearn_cebra.CEBRA(
+ model_architecture="offset1-model",
+ max_iterations=150,
+ batch_size=512,
+ device="cpu")
+ cebra_model.fit(X, y)
+ history = cebra_sklearn_metrics.goodness_of_fit_history(cebra_model)
+ # NOTE(stes): Ignore the first 5 iterations, they can have nonsensical values
+ # due to numerical issues.
+ return history[5:]
+
+ history_random = _fit_and_get_history(X, y_random)
+ history_linear = _fit_and_get_history(X, y_linear)
+
+ assert isinstance(history_random, np.ndarray)
+ assert history_random.shape[0] > 0
+ # NOTE(stes): Ignore the first 5 iterations, they can have nonsensical values
+ # due to numerical issues.
+ history_random_non_negative = history_random[history_random >= 0]
+ np.testing.assert_allclose(history_random_non_negative, 0, atol=0.075)
+
+ assert isinstance(history_linear, np.ndarray)
+ assert history_linear.shape[0] > 0
+
+ assert np.all(history_linear[-20:] > history_random[-20:])
+
+
+@pytest.mark.parametrize("seed", [42, 24, 10])
+def test_infonce_to_goodness_of_fit(seed):
+ """Test the conversion from InfoNCE loss to goodness of fit metric."""
+ # Test with model
+ cebra_model = cebra_sklearn_cebra.CEBRA(
+ model_architecture="offset10-model",
+ max_iterations=5,
+ batch_size=128,
+ )
+ generator = torch.Generator().manual_seed(seed)
+ X = torch.rand(1000, 50, dtype=torch.float32, generator=generator)
+ cebra_model.fit(X)
+
+ # Test single value
+ gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+ model=cebra_model)
+ assert isinstance(gof, float)
+
+ # Test array of values
+ infonce_values = np.array([1.0, 2.0, 3.0])
+ gof_array = cebra_sklearn_metrics.infonce_to_goodness_of_fit(
+ infonce_values, model=cebra_model)
+ assert isinstance(gof_array, np.ndarray)
+ assert gof_array.shape == infonce_values.shape
+
+ # Test with explicit batch_size and num_sessions
+ gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+ batch_size=128,
+ num_sessions=1)
+ assert isinstance(gof, float)
+
+ # Test error cases
+ with pytest.raises(ValueError, match="batch_size.*should not be provided"):
+ cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+ model=cebra_model,
+ batch_size=128)
+
+ with pytest.raises(ValueError, match="batch_size.*should not be provided"):
+ cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+ model=cebra_model,
+ num_sessions=1)
+
+ # Test with unfitted model
+ unfitted_model = cebra_sklearn_cebra.CEBRA(max_iterations=5)
+ with pytest.raises(RuntimeError, match="Fit the CEBRA model first"):
+ cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+ model=unfitted_model)
+
+ # Test with model having batch_size=None
+ none_batch_model = cebra_sklearn_cebra.CEBRA(batch_size=None,
+ max_iterations=5)
+ none_batch_model.fit(X)
+ with pytest.raises(ValueError, match="Computing the goodness of fit"):
+ cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0,
+ model=none_batch_model)
+
+ # Test missing batch_size or num_sessions when model is None
+ with pytest.raises(ValueError, match="batch_size.*and num_sessions"):
+ cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, batch_size=128)
+
+ with pytest.raises(ValueError, match="batch_size.*and num_sessions"):
+ cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, num_sessions=1)
diff --git a/tests/test_solver.py b/tests/test_solver.py
index 68e2a43e..5cbbbfb3 100644
--- a/tests/test_solver.py
+++ b/tests/test_solver.py
@@ -34,41 +34,12 @@
device = "cpu"
-single_session_tests = []
-for args in [
- ("demo-discrete", cebra.data.DiscreteDataLoader, "offset10-model"),
- ("demo-discrete", cebra.data.DiscreteDataLoader, "offset1-model"),
- ("demo-discrete", cebra.data.DiscreteDataLoader, "offset1-model"),
- ("demo-discrete", cebra.data.DiscreteDataLoader, "offset10-model"),
- ("demo-continuous", cebra.data.ContinuousDataLoader, "offset10-model"),
- ("demo-continuous", cebra.data.ContinuousDataLoader, "offset1-model"),
- ("demo-mixed", cebra.data.MixedDataLoader, "offset10-model"),
- ("demo-mixed", cebra.data.MixedDataLoader, "offset1-model"),
-]:
- single_session_tests.append((*args, cebra.solver.SingleSessionSolver))
-
-single_session_hybrid_tests = []
-for args in [("demo-continuous", cebra.data.HybridDataLoader, "offset10-model"),
- ("demo-continuous", cebra.data.HybridDataLoader, "offset1-model")]:
- single_session_hybrid_tests.append(
- (*args, cebra.solver.SingleSessionHybridSolver))
-
-multi_session_tests = []
-for args in [
- ("demo-continuous-multisession",
- cebra.data.ContinuousMultiSessionDataLoader, "offset1-model"),
- ("demo-continuous-multisession",
- cebra.data.ContinuousMultiSessionDataLoader, "offset10-model"),
-]:
- multi_session_tests.append((*args, cebra.solver.MultiSessionSolver))
-
-# multi_session_tests.append((*args, cebra.solver.MultiSessionAuxVariableSolver))
-
-
-def _get_loader(data, loader_initfunc):
- kwargs = dict(num_steps=5, batch_size=32)
+
+def _get_loader(data_name, loader_initfunc):
+ data = cebra.datasets.init(data_name)
+ kwargs = dict(num_steps=2, batch_size=32)
loader = loader_initfunc(data, **kwargs)
- return loader
+ return loader, data
OUTPUT_DIMENSION = 3
@@ -84,12 +55,12 @@ def _make_model(dataset, model_architecture="offset10-model"):
OUTPUT_DIMENSION)
-# def _make_behavior_model(dataset):
-# # TODO flexible input dimension
-# return nn.Sequential(
-# nn.Conv1d(dataset.input_dimension, 5, kernel_size=10),
-# nn.Flatten(start_dim=1, end_dim=-1),
-# )
+def _make_behavior_model(dataset):
+ # TODO flexible input dimension
+ return nn.Sequential(
+ nn.Conv1d(dataset.input_dimension, 5, kernel_size=10),
+ nn.Flatten(start_dim=1, end_dim=-1),
+ )
def _assert_same_state_dict(first, second):
@@ -135,12 +106,16 @@ def _assert_equal(original_solver, loaded_solver):
@pytest.mark.parametrize(
- "data_name, loader_initfunc, model_architecture, solver_initfunc",
- single_session_tests)
+ "data_name, model_architecture, loader_initfunc, solver_initfunc",
+ [(dataset, model, loader, cebra.solver.SingleSessionSolver)
+ for dataset, loader in [("demo-discrete", cebra.data.DiscreteDataLoader),
+ ("demo-continuous", cebra.data.ContinuousDataLoader
+ ), ("demo-mixed", cebra.data.MixedDataLoader)]
+ for model in
+ ["offset1-model", "offset10-model", "offset40-model-4x-subsample"]])
def test_single_session(data_name, loader_initfunc, model_architecture,
solver_initfunc):
- data = cebra.datasets.init(data_name)
- loader = _get_loader(data, loader_initfunc)
+ loader, data = _get_loader(data_name, loader_initfunc)
model = _make_model(data, model_architecture)
data.configure_for(model)
offset = model.get_offset()
@@ -163,21 +138,84 @@ def test_single_session(data_name, loader_initfunc, model_architecture,
solver.fit(loader)
- assert solver.num_sessions == None
+ assert solver.num_sessions is None
assert solver.n_features == X.shape[1]
embedding = solver.transform(X)
assert isinstance(embedding, torch.Tensor)
- assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION)
+ if isinstance(solver.model, cebra.models.ResampleModelMixin):
+ assert embedding.shape == (X.shape[0] // solver.model.resample_factor,
+ OUTPUT_DIMENSION)
+ else:
+ assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION)
embedding = solver.transform(torch.Tensor(X))
assert isinstance(embedding, torch.Tensor)
- assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION)
+ if isinstance(solver.model, cebra.models.ResampleModelMixin):
+ assert embedding.shape == (X.shape[0] // solver.model.resample_factor,
+ OUTPUT_DIMENSION)
+ else:
+ assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION)
embedding = solver.transform(X, session_id=0)
assert isinstance(embedding, torch.Tensor)
- assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION)
+ if isinstance(solver.model, cebra.models.ResampleModelMixin):
+ assert embedding.shape == (X.shape[0] // solver.model.resample_factor,
+ OUTPUT_DIMENSION)
+ else:
+ assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION)
embedding = solver.transform(X, pad_before_transform=False)
assert isinstance(embedding, torch.Tensor)
- assert embedding.shape == (X.shape[0] - len(offset) + 1, OUTPUT_DIMENSION)
+ if isinstance(solver.model, cebra.models.ResampleModelMixin):
+ assert embedding.shape == (
+ (X.shape[0] - len(offset)) // solver.model.resample_factor + 1,
+ OUTPUT_DIMENSION)
+ else:
+ assert embedding.shape == (X.shape[0] - len(offset) + 1,
+ OUTPUT_DIMENSION)
+
+ with pytest.raises(ValueError, match="torch.Tensor"):
+ solver.transform(X.numpy())
+ with pytest.raises(RuntimeError, match="Invalid.*session_id"):
+ embedding = solver.transform(X, session_id=2)
+
+ for param in solver.parameters():
+ assert isinstance(param, torch.Tensor)
+
+ fitted_solver = copy.deepcopy(solver)
+ with tempfile.TemporaryDirectory() as temp_dir:
+ solver.save(temp_dir)
+ solver.load(temp_dir)
+ _assert_equal(fitted_solver, solver)
+
+ embedding = solver.transform(X)
+ assert isinstance(embedding, torch.Tensor)
+ if isinstance(solver.model, cebra.models.ResampleModelMixin):
+ assert embedding.shape == (X.shape[0] // solver.model.resample_factor,
+ OUTPUT_DIMENSION)
+ else:
+ assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION)
+ embedding = solver.transform(torch.Tensor(X))
+ assert isinstance(embedding, torch.Tensor)
+ if isinstance(solver.model, cebra.models.ResampleModelMixin):
+ assert embedding.shape == (X.shape[0] // solver.model.resample_factor,
+ OUTPUT_DIMENSION)
+ else:
+ assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION)
+ embedding = solver.transform(X, session_id=0)
+ assert isinstance(embedding, torch.Tensor)
+ if isinstance(solver.model, cebra.models.ResampleModelMixin):
+ assert embedding.shape == (X.shape[0] // solver.model.resample_factor,
+ OUTPUT_DIMENSION)
+ else:
+ assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION)
+ embedding = solver.transform(X, pad_before_transform=False)
+ assert isinstance(embedding, torch.Tensor)
+ if isinstance(solver.model, cebra.models.ResampleModelMixin):
+ assert embedding.shape == (
+ (X.shape[0] - len(offset)) // solver.model.resample_factor + 1,
+ OUTPUT_DIMENSION)
+ else:
+ assert embedding.shape == (X.shape[0] - len(offset) + 1,
+ OUTPUT_DIMENSION)
with pytest.raises(ValueError, match="torch.Tensor"):
solver.transform(X.numpy())
@@ -195,15 +233,21 @@ def test_single_session(data_name, loader_initfunc, model_architecture,
@pytest.mark.parametrize(
- "data_name, loader_initfunc, model_architecture, solver_initfunc",
- single_session_tests)
+ "data_name, model_architecture, loader_initfunc, solver_initfunc",
+ [(dataset, model, loader, cebra.solver.SingleSessionSolver)
+ for dataset, loader in [("demo-discrete", cebra.data.DiscreteDataLoader),
+ ("demo-continuous", cebra.data.ContinuousDataLoader
+ ), ("demo-mixed", cebra.data.MixedDataLoader)]
+ for model in
+ ["offset1-model", "offset10-model", "offset40-model-4x-subsample"]])
def test_single_session_auxvar(data_name, loader_initfunc, model_architecture,
solver_initfunc):
- return # TODO
+
+ pytest.skip("Not yet supported")
loader = _get_loader(data_name, loader_initfunc)
model = _make_model(loader.dataset)
- behavior_model = _make_behavior_model(loader.dataset)
+ behavior_model = _make_behavior_model(loader.dataset) # noqa: F841
criterion = cebra.models.InfoNCE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
@@ -223,12 +267,13 @@ def test_single_session_auxvar(data_name, loader_initfunc, model_architecture,
@pytest.mark.parametrize(
- "data_name, loader_initfunc, model_architecture, solver_initfunc",
- single_session_hybrid_tests)
+ "data_name, model_architecture, loader_initfunc, solver_initfunc",
+ [("demo-continuous", model, cebra.data.HybridDataLoader,
+ cebra.solver.SingleSessionHybridSolver)
+ for model in ["offset1-model", "offset10-model"]])
def test_single_session_hybrid(data_name, loader_initfunc, model_architecture,
solver_initfunc):
- data = cebra.datasets.init(data_name)
- loader = _get_loader(data, loader_initfunc)
+ loader, data = _get_loader(data_name, loader_initfunc)
model = _make_model(data, model_architecture)
data.configure_for(model)
offset = model.get_offset()
@@ -250,7 +295,7 @@ def test_single_session_hybrid(data_name, loader_initfunc, model_architecture,
solver.fit(loader)
- assert solver.num_sessions == None
+ assert solver.num_sessions is None
assert solver.n_features == X.shape[1]
embedding = solver.transform(X)
@@ -282,17 +327,25 @@ def test_single_session_hybrid(data_name, loader_initfunc, model_architecture,
@pytest.mark.parametrize(
- "data_name, loader_initfunc, model_architecture, solver_initfunc",
- multi_session_tests)
+ "data_name, model_architecture, loader_initfunc, solver_initfunc",
+ [(dataset, model, loader, cebra.solver.MultiSessionSolver)
+ for dataset, loader in [
+ ("demo-discrete-multisession",
+ cebra.data.DiscreteMultiSessionDataLoader),
+ ("demo-continuous-multisession",
+ cebra.data.ContinuousMultiSessionDataLoader),
+ ]
+ for model in ["offset1-model", "offset10-model"]])
def test_multi_session(data_name, loader_initfunc, model_architecture,
solver_initfunc):
- data = cebra.datasets.init(data_name)
- loader = _get_loader(data, loader_initfunc)
+ loader, data = _get_loader(data_name, loader_initfunc)
model = nn.ModuleList([
_make_model(dataset, model_architecture)
for dataset in data.iter_sessions()
])
data.configure_for(model)
+ offset_length = len(model[0].get_offset())
+
criterion = cebra.models.InfoNCE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
@@ -302,8 +355,9 @@ def test_multi_session(data_name, loader_initfunc, model_architecture,
batch = next(iter(loader))
for session_id, dataset in enumerate(loader.dataset.iter_sessions()):
- assert batch[session_id].reference.shape[:2] == (
- 32, dataset.input_dimension)
+ assert batch[session_id].reference.shape == (32,
+ dataset.input_dimension,
+ offset_length)
assert batch[session_id].index is not None
log = solver.step(batch)
@@ -360,267 +414,28 @@ def test_multi_session(data_name, loader_initfunc, model_architecture,
_assert_equal(fitted_solver, solver)
-@pytest.mark.parametrize(
- "inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output",
- [
- # Test case 1: No padding
- (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset(
- 0, 1), 0, 2, torch.tensor([[1, 2], [3, 4]])), # first batch
- (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset(
- 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # last batch
- (torch.tensor(
- [[1, 2], [3, 4], [5, 6], [7, 8]]), False, cebra.data.Offset(
- 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # middle batch
-
- # Test case 2: First batch with padding
- (
- torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
- True,
- cebra.data.Offset(0, 1),
- 0,
- 2,
- torch.tensor([[1, 2, 3], [4, 5, 6]]),
- ),
- (
- torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
- True,
- cebra.data.Offset(1, 1),
- 0,
- 3,
- torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]]),
- ),
-
- # Test case 3: Last batch with padding
- (
- torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
- True,
- cebra.data.Offset(0, 1),
- 1,
- 3,
- torch.tensor([[4, 5, 6], [7, 8, 9]]),
- ),
- (
- torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
- [13, 14, 15]]),
- True,
- cebra.data.Offset(1, 2),
- 1,
- 3,
- torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
- ),
-
- # Test case 4: Middle batch with padding
- (
- torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
- True,
- cebra.data.Offset(0, 1),
- 1,
- 3,
- torch.tensor([[4, 5, 6], [7, 8, 9]]),
- ),
- (
- torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
- True,
- cebra.data.Offset(1, 1),
- 1,
- 3,
- torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
- ),
- (
- torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
- [13, 14, 15]]),
- True,
- cebra.data.Offset(0, 1),
- 2,
- 4,
- torch.tensor([[7, 8, 9], [10, 11, 12]]),
- ),
- (
- torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
- True,
- cebra.data.Offset(0, 1),
- 0,
- 3,
- torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
- ),
-
- # Examples that throw an error:
-
- # Padding without offset (should raise an error)
- (torch.tensor([[1, 2]]), True, None, 0, 2, ValueError),
- # Negative start_batch_idx or end_batch_idx (should raise an error)
- (torch.tensor([[1, 2]]), False, cebra.data.Offset(
- 0, 1), -1, 2, ValueError),
- # out of bound indices because offset is too large
- (torch.tensor([[1, 2], [3, 4]]), True, cebra.data.Offset(
- 5, 5), 1, 2, ValueError),
- # Batch length is smaller than offset.
- (torch.tensor([[1, 2], [3, 4]]), False, cebra.data.Offset(
- 0, 1), 0, 1, ValueError), # first batch
- ],
-)
-def test_get_batch(inputs, add_padding, offset, start_batch_idx, end_batch_idx,
- expected_output):
- if expected_output == ValueError:
- with pytest.raises(ValueError):
- cebra.solver.base._get_batch(inputs, offset, start_batch_idx,
- end_batch_idx, add_padding)
- else:
- result = cebra.solver.base._get_batch(inputs, offset, start_batch_idx,
- end_batch_idx, add_padding)
- assert torch.equal(result, expected_output)
-
-
-def create_model(model_name, input_dimension):
- return cebra.models.init(model_name,
- num_neurons=input_dimension,
- num_units=128,
- num_output=OUTPUT_DIMENSION)
-
-
-single_session_tests_select_model = []
-single_session_hybrid_tests_select_model = []
-for model_name in ["offset1-model", "offset10-model"]:
- for session_id in [None, 0, 5]:
- for args in [
- ("demo-discrete", model_name, session_id,
- cebra.data.DiscreteDataLoader),
- ("demo-continuous", model_name, session_id,
- cebra.data.ContinuousDataLoader),
- ("demo-mixed", model_name, session_id, cebra.data.MixedDataLoader),
- ]:
- single_session_tests_select_model.append(
- (*args, cebra.solver.SingleSessionSolver))
- single_session_hybrid_tests_select_model.append(
- (*args, cebra.solver.SingleSessionHybridSolver))
-
-multi_session_tests_select_model = []
-for model_name in ["offset10-model"]:
- for session_id in [None, 0, 1, 5, 2, 6, 4]:
- for args in [("demo-continuous-multisession", model_name, session_id,
- cebra.data.ContinuousMultiSessionDataLoader)]:
- multi_session_tests_select_model.append(
- (*args, cebra.solver.MultiSessionSolver))
+def _make_val_data(dataset):
+ if isinstance(dataset, cebra.datasets.demo.DemoDataset):
+ return dataset.neural
+ elif isinstance(dataset, cebra.datasets.demo.DemoDatasetUnified):
+ return [session.neural for session in dataset.iter_sessions()], [
+ session.continuous_index for session in dataset.iter_sessions()
+ ]
@pytest.mark.parametrize(
- "data_name, model_name ,session_id, loader_initfunc, solver_initfunc",
- single_session_tests_select_model + single_session_hybrid_tests_select_model
-)
-def test_select_model_single_session(data_name, model_name, session_id,
- loader_initfunc, solver_initfunc):
- dataset = cebra.datasets.init(data_name)
- model = create_model(model_name, dataset.input_dimension)
- dataset.configure_for(model)
- loader = _get_loader(dataset, loader_initfunc=loader_initfunc)
+ "data_name, model_architecture, loader_initfunc, solver_initfunc",
+ [(dataset, model, loader, cebra.solver.UnifiedSolver)
+ for dataset, loader in [
+ ("demo-continuous-unified", cebra.data.UnifiedLoader),
+ ]
+ for model in ["offset1-model", "offset10-model"]])
+def test_unified_session(data_name, model_architecture, loader_initfunc,
+ solver_initfunc):
+ loader, data = _get_loader(data_name, loader_initfunc)
+ model = _make_model(data, model_architecture)
+ data.configure_for(model)
offset = model.get_offset()
- solver = solver_initfunc(model=model, criterion=None, optimizer=None)
-
- with pytest.raises(ValueError):
- solver.n_features = 1000
- solver._select_model(inputs=dataset.neural, session_id=0)
-
- solver.n_features = dataset.neural.shape[1]
- if session_id is not None and session_id > 0:
- with pytest.raises(RuntimeError):
- solver._select_model(inputs=dataset.neural, session_id=session_id)
- else:
- model_, offset_ = solver._select_model(inputs=dataset.neural,
- session_id=session_id)
- assert offset.left == offset_.left and offset.right == offset_.right
- assert model == model_
-
-
-@pytest.mark.parametrize(
- "data_name, model_name, session_id, loader_initfunc, solver_initfunc",
- multi_session_tests_select_model)
-def test_select_model_multi_session(data_name, model_name, session_id,
- loader_initfunc, solver_initfunc):
- dataset = cebra.datasets.init(data_name)
- model = nn.ModuleList([
- create_model(model_name, dataset.input_dimension)
- for dataset in dataset.iter_sessions()
- ])
- dataset.configure_for(model)
- loader = _get_loader(dataset, loader_initfunc=loader_initfunc)
-
- offset = model[0].get_offset()
- solver = solver_initfunc(model=model,
- criterion=cebra.models.InfoNCE(),
- optimizer=torch.optim.Adam(model.parameters(),
- lr=1e-3))
-
- loader_kwargs = dict(num_steps=10, batch_size=32)
- loader = cebra.data.ContinuousMultiSessionDataLoader(
- dataset, **loader_kwargs)
- solver.fit(loader)
-
- for i, (model, dataset_) in enumerate(zip(model, dataset.iter_sessions())):
- inputs = dataset_.neural
-
- if session_id is None or session_id >= dataset.num_sessions:
- with pytest.raises(RuntimeError):
- solver._select_model(inputs, session_id=session_id)
- elif i != session_id:
- with pytest.raises(ValueError):
- solver._select_model(inputs, session_id=session_id)
- else:
- model_, offset_ = solver._select_model(inputs,
- session_id=session_id)
- assert offset.left == offset_.left and offset.right == offset_.right
- assert model == model_
-
-
-models = [
- "offset1-model",
- "offset10-model",
- "offset40-model-4x-subsample",
- "offset1-model",
- "offset10-model",
-]
-batch_size_inference = [40_000, 99_990, 99_999]
-
-single_session_tests_transform = []
-for padding in [True, False]:
- for model_name in models:
- for batch_size in batch_size_inference:
- for args in [
- ("demo-discrete", model_name, padding, batch_size,
- cebra.data.DiscreteDataLoader),
- ("demo-continuous", model_name, padding, batch_size,
- cebra.data.ContinuousDataLoader),
- ("demo-mixed", model_name, padding, batch_size,
- cebra.data.MixedDataLoader),
- ]:
- single_session_tests_transform.append(
- (*args, cebra.solver.SingleSessionSolver))
-
-single_session_hybrid_tests_transform = []
-for padding in [True, False]:
- for model_name in models:
- for batch_size in batch_size_inference:
- for args in [("demo-continuous", model_name, padding, batch_size,
- cebra.data.HybridDataLoader)]:
- single_session_hybrid_tests_transform.append(
- (*args, cebra.solver.SingleSessionHybridSolver))
-
-
-@pytest.mark.parametrize(
- "data_name, model_name, padding, batch_size_inference, loader_initfunc, solver_initfunc",
- single_session_tests_transform + single_session_hybrid_tests_transform)
-def test_batched_transform_single_session(
- data_name,
- model_name,
- padding,
- batch_size_inference,
- loader_initfunc,
- solver_initfunc,
-):
- dataset = cebra.datasets.init(data_name)
- model = create_model(model_name, dataset.input_dimension)
- dataset.offset = model.get_offset()
- loader_kwargs = dict(num_steps=10, batch_size=32)
- loader = loader_initfunc(dataset, **loader_kwargs)
criterion = cebra.models.InfoNCE()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
@@ -628,95 +443,24 @@ def test_batched_transform_single_session(
solver = solver_initfunc(model=model,
criterion=criterion,
optimizer=optimizer)
- solver.fit(loader)
-
- smallest_batch_length = loader.dataset.neural.shape[0] - batch_size
- offset_ = model.get_offset()
- padding_left = offset_.left if padding else 0
-
- if smallest_batch_length <= len(offset_):
- with pytest.raises(ValueError):
- solver.transform(inputs=loader.dataset.neural,
- batch_size=batch_size,
- pad_before_transform=padding)
- else:
- embedding_batched = solver.transform(inputs=loader.dataset.neural,
- batch_size=batch_size,
- pad_before_transform=padding)
-
- embedding = solver.transform(inputs=loader.dataset.neural,
- pad_before_transform=padding)
-
- assert embedding_batched.shape == embedding.shape
- assert np.allclose(embedding_batched, embedding, rtol=1e-02)
-
-
-multi_session_tests_transform = []
-for padding in [True, False]:
- for model_name in models:
- for batch_size in batch_size_inference:
- for args in [
- ("demo-continuous-multisession", model_name, padding,
- batch_size, cebra.data.ContinuousMultiSessionDataLoader)
- ]:
- multi_session_tests_transform.append(
- (*args, cebra.solver.MultiSessionSolver))
-
-
-@pytest.mark.parametrize(
- "data_name, model_name,padding,batch_size_inference,loader_initfunc, solver_initfunc",
- multi_session_tests_transform)
-def test_batched_transform_multi_session(data_name, model_name, padding,
- batch_size_inference, loader_initfunc,
- solver_initfunc):
- dataset = cebra.datasets.init(data_name)
- model = nn.ModuleList([
- create_model(model_name, dataset.input_dimension)
- for dataset in dataset.iter_sessions()
- ])
- dataset.offset = model[0].get_offset()
-
- n_samples = dataset._datasets[0].neural.shape[0]
- assert all(
- d.neural.shape[0] == n_samples for d in dataset._datasets
- ), "for this set all of the sessions need to have same number of samples."
-
- smallest_batch_length = n_samples - batch_size
- offset_ = model[0].get_offset()
- padding_left = offset_.left if padding else 0
- for d in dataset._datasets:
- d.offset = offset_
- loader_kwargs = dict(num_steps=10, batch_size=32)
- loader = loader_initfunc(dataset, **loader_kwargs)
+ batch = next(iter(loader))
+ assert batch.reference.shape == (32, loader.dataset.input_dimension,
+ len(offset))
- criterion = cebra.models.InfoNCE()
- optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+ log = solver.step(batch)
+ assert isinstance(log, dict)
- solver = solver_initfunc(model=model,
- criterion=criterion,
- optimizer=optimizer)
solver.fit(loader)
+ data, labels = _make_val_data(loader.dataset)
- # Transform each session with the right model, by providing the corresponding session ID
- for i, inputs in enumerate(dataset.iter_sessions()):
+ assert solver.num_sessions == 3
+ assert solver.n_features == sum(
+ [data[i].shape[1] for i in range(len(data))])
- if smallest_batch_length <= len(offset_):
- with pytest.raises(ValueError):
- solver.transform(inputs=inputs.neural,
- batch_size=batch_size,
- session_id=i,
- pad_before_transform=padding)
+ for i in range(loader.dataset.num_sessions):
+ emb = solver.transform(data, labels, session_id=i)
+ assert emb.shape == (loader.dataset.num_timepoints, 3)
- else:
- model_ = model[i]
- embedding = solver.transform(inputs=inputs.neural,
- session_id=i,
- pad_before_transform=padding)
- embedding_batched = solver.transform(inputs=inputs.neural,
- session_id=i,
- pad_before_transform=padding,
- batch_size=batch_size)
-
- assert embedding_batched.shape == embedding.shape
- assert np.allclose(embedding_batched, embedding, rtol=1e-02)
+ emb = solver.transform(data, labels, session_id=i, batch_size=300)
+ assert emb.shape == (loader.dataset.num_timepoints, 3)
diff --git a/tests/test_solver_batched.py b/tests/test_solver_batched.py
new file mode 100644
index 00000000..8592aea2
--- /dev/null
+++ b/tests/test_solver_batched.py
@@ -0,0 +1,343 @@
+#
+# CEBRA: Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables
+# © Mackenzie W. Mathis & Steffen Schneider (v0.4.0+)
+# Source code:
+# https://github.com/AdaptiveMotorControlLab/CEBRA
+#
+# Please see LICENSE.md for the full license document:
+# https://github.com/AdaptiveMotorControlLab/CEBRA/blob/main/LICENSE.md
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import numpy as np
+import pytest
+import torch
+from torch import nn
+
+import cebra.data
+import cebra.datasets
+import cebra.models
+import cebra.solver
+
+device = "cpu"
+
+NUM_STEPS = 10
+BATCHES = [25_000, 50_000, 75_000]
+MODELS = ["offset1-model", "offset10-model", "offset40-model-4x-subsample"]
+
+
+@pytest.mark.parametrize(
+ "inputs, add_padding, offset, start_batch_idx, end_batch_idx, expected_output",
+ [
+ # Test case 1: No padding
+ (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset(
+ 0, 1), 0, 2, torch.tensor([[1, 2], [3, 4]])), # first batch
+ (torch.tensor([[1, 2], [3, 4], [5, 6]]), False, cebra.data.Offset(
+ 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # last batch
+ (torch.tensor(
+ [[1, 2], [3, 4], [5, 6], [7, 8]]), False, cebra.data.Offset(
+ 0, 1), 1, 3, torch.tensor([[3, 4], [5, 6]])), # middle batch
+
+ # Test case 2: First batch with padding
+ (
+ torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
+ True,
+ cebra.data.Offset(0, 1),
+ 0,
+ 2,
+ torch.tensor([[1, 2, 3], [4, 5, 6]]),
+ ),
+ (
+ torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
+ True,
+ cebra.data.Offset(1, 1),
+ 0,
+ 3,
+ torch.tensor([[1, 2, 3], [1, 2, 3], [4, 5, 6], [7, 8, 9]]),
+ ),
+
+ # Test case 3: Last batch with padding
+ (
+ torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
+ True,
+ cebra.data.Offset(0, 1),
+ 1,
+ 3,
+ torch.tensor([[4, 5, 6], [7, 8, 9]]),
+ ),
+ (
+ torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
+ [13, 14, 15]]),
+ True,
+ cebra.data.Offset(1, 2),
+ 1,
+ 3,
+ torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
+ ),
+
+ # Test case 4: Middle batch with padding
+ (
+ torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
+ True,
+ cebra.data.Offset(0, 1),
+ 1,
+ 3,
+ torch.tensor([[4, 5, 6], [7, 8, 9]]),
+ ),
+ (
+ torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
+ True,
+ cebra.data.Offset(1, 1),
+ 1,
+ 3,
+ torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
+ ),
+ (
+ torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12],
+ [13, 14, 15]]),
+ True,
+ cebra.data.Offset(0, 1),
+ 2,
+ 4,
+ torch.tensor([[7, 8, 9], [10, 11, 12]]),
+ ),
+ (
+ torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]]),
+ True,
+ cebra.data.Offset(0, 1),
+ 0,
+ 3,
+ torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]]),
+ ),
+ # Padding without offset (should raise an error)
+ (torch.tensor([[1, 2]]), True, None, 0, 2, ValueError),
+ # Negative start_batch_idx or end_batch_idx (should raise an error)
+ (torch.tensor([[1, 2]]), False, cebra.data.Offset(
+ 0, 1), -1, 2, ValueError),
+ # out of bound indices because offset is too large
+ (torch.tensor([[1, 2], [3, 4]]), True, cebra.data.Offset(
+ 5, 5), 1, 2, ValueError),
+ # Batch length is smaller than offset.
+ (torch.tensor([[1, 2], [3, 4]]), False, cebra.data.Offset(
+ 0, 1), 0, 1, ValueError), # first batch
+ ],
+)
+def test_get_batch(inputs, add_padding, offset, start_batch_idx, end_batch_idx,
+ expected_output):
+ if expected_output == ValueError:
+ with pytest.raises(ValueError):
+ cebra.solver.base._get_batch(inputs, offset, start_batch_idx,
+ end_batch_idx, add_padding)
+ else:
+ result = cebra.solver.base._get_batch(inputs, offset, start_batch_idx,
+ end_batch_idx, add_padding)
+ assert torch.equal(result, expected_output)
+
+
+def create_model(model_name, input_dimension):
+ return cebra.models.init(model_name,
+ num_neurons=input_dimension,
+ num_units=128,
+ num_output=3)
+
+
+@pytest.mark.parametrize(
+ "data_name, model_name, session_id, loader_initfunc, solver_initfunc",
+ [(dataset, model, session_id, loader, cebra.solver.SingleSessionSolver)
+ for dataset, loader in [("demo-discrete", cebra.data.DiscreteDataLoader),
+ ("demo-continuous", cebra.data.ContinuousDataLoader
+ ), ("demo-mixed", cebra.data.MixedDataLoader)]
+ for model in ["offset1-model", "offset10-model"]
+ for session_id in [None, 0, 5]] +
+ [(dataset, model, session_id, loader,
+ cebra.solver.SingleSessionHybridSolver)
+ for dataset, loader in [
+ ("demo-continuous", cebra.data.HybridDataLoader),
+ ]
+ for model in ["offset1-model", "offset10-model"]
+ for session_id in [None, 0, 5]])
+def test_select_model_single_session(data_name, model_name, session_id,
+ loader_initfunc, solver_initfunc):
+ dataset = cebra.datasets.init(data_name)
+ model = create_model(model_name, dataset.input_dimension)
+ dataset.configure_for(model)
+ offset = model.get_offset()
+ solver = solver_initfunc(model=model, criterion=None, optimizer=None)
+
+ with pytest.raises(ValueError):
+ solver.n_features = 1000
+ solver._select_model(inputs=dataset.neural, session_id=0)
+
+ solver.n_features = dataset.neural.shape[1]
+ if session_id is not None and session_id > 0:
+ with pytest.raises(RuntimeError):
+ solver._select_model(inputs=dataset.neural, session_id=session_id)
+ else:
+ model_, offset_ = solver._select_model(inputs=dataset.neural,
+ session_id=session_id)
+ assert offset.left == offset_.left and offset.right == offset_.right
+ assert model == model_
+
+
+@pytest.mark.parametrize(
+ "data_name, model_name, session_id, loader_initfunc, solver_initfunc",
+ [(dataset, model, session_id, loader, cebra.solver.MultiSessionSolver)
+ for dataset, loader in [
+ ("demo-continuous-multisession",
+ cebra.data.ContinuousMultiSessionDataLoader),
+ ]
+ for model in ["offset1-model", "offset10-model"]
+ for session_id in [None, 0, 1, 5, 2, 6, 4]])
+def test_select_model_multi_session(data_name, model_name, session_id,
+ loader_initfunc, solver_initfunc):
+
+ dataset = cebra.datasets.init(data_name)
+ kwargs = dict(num_steps=NUM_STEPS, batch_size=32)
+ loader = loader_initfunc(dataset, **kwargs)
+
+ model = nn.ModuleList([
+ create_model(model_name, dataset.input_dimension)
+ for dataset in dataset.iter_sessions()
+ ])
+ dataset.configure_for(model)
+
+ offset = model[0].get_offset()
+ solver = solver_initfunc(model=model,
+ criterion=cebra.models.InfoNCE(),
+ optimizer=torch.optim.Adam(model.parameters(),
+ lr=1e-3))
+
+ loader_kwargs = dict(num_steps=NUM_STEPS, batch_size=32)
+ loader = cebra.data.ContinuousMultiSessionDataLoader(
+ dataset, **loader_kwargs)
+ solver.fit(loader)
+
+ for i, (model, dataset_) in enumerate(zip(model, dataset.iter_sessions())):
+ inputs = dataset_.neural
+
+ if session_id is None or session_id >= dataset.num_sessions:
+ with pytest.raises(RuntimeError):
+ solver._select_model(inputs, session_id=session_id)
+ elif i != session_id:
+ with pytest.raises(ValueError):
+ solver._select_model(inputs, session_id=session_id)
+ else:
+ model_, offset_ = solver._select_model(inputs,
+ session_id=session_id)
+ assert offset.left == offset_.left and offset.right == offset_.right
+ assert model == model_
+
+
+@pytest.mark.parametrize(
+ "data_name, model_name, padding, batch_size_inference, loader_initfunc, solver_initfunc",
+ [(dataset, model, padding, batch_size, loader,
+ cebra.solver.SingleSessionSolver)
+ for dataset, loader in [("demo-discrete", cebra.data.DiscreteDataLoader),
+ ("demo-continuous", cebra.data.ContinuousDataLoader
+ ), ("demo-mixed", cebra.data.MixedDataLoader)]
+ for model in
+ ["offset1-model", "offset10-model", "offset40-model-4x-subsample"]
+ for padding in [True, False]
+ for batch_size in BATCHES] +
+ [(dataset, model, padding, batch_size, loader,
+ cebra.solver.SingleSessionHybridSolver)
+ for dataset, loader in [
+ ("demo-continuous", cebra.data.HybridDataLoader),
+ ]
+ for model in MODELS
+ for padding in [True, False]
+ for batch_size in BATCHES])
+def test_batched_transform_single_session(
+ data_name,
+ model_name,
+ padding,
+ batch_size_inference,
+ loader_initfunc,
+ solver_initfunc,
+):
+ dataset = cebra.datasets.init(data_name)
+ model = create_model(model_name, dataset.input_dimension)
+ dataset.configure_for(model)
+ loader_kwargs = dict(num_steps=NUM_STEPS, batch_size=32)
+ loader = loader_initfunc(dataset, **loader_kwargs)
+
+ criterion = cebra.models.InfoNCE()
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+ solver = solver_initfunc(model=model,
+ criterion=criterion,
+ optimizer=optimizer)
+ solver.fit(loader)
+
+ embedding_batched = solver.transform(inputs=loader.dataset.neural,
+ batch_size=batch_size_inference,
+ pad_before_transform=padding)
+
+ embedding = solver.transform(inputs=loader.dataset.neural,
+ pad_before_transform=padding)
+
+ assert embedding_batched.shape == embedding.shape
+ assert np.allclose(embedding_batched, embedding, rtol=1e-4, atol=1e-4)
+
+
+@pytest.mark.parametrize(
+ "data_name, model_name,padding,batch_size_inference,loader_initfunc, solver_initfunc",
+ [(dataset, model, padding, batch_size, loader,
+ cebra.solver.MultiSessionSolver)
+ for dataset, loader in [
+ ("demo-continuous-multisession",
+ cebra.data.ContinuousMultiSessionDataLoader),
+ ]
+ for model in
+ ["offset1-model", "offset10-model", "offset40-model-4x-subsample"]
+ for padding in [True, False]
+ for batch_size in BATCHES])
+def test_batched_transform_multi_session(data_name, model_name, padding,
+ batch_size_inference, loader_initfunc,
+ solver_initfunc):
+ dataset = cebra.datasets.init(data_name)
+ model = nn.ModuleList([
+ create_model(model_name, dataset.input_dimension)
+ for dataset in dataset.iter_sessions()
+ ])
+ dataset.configure_for(model)
+
+ n_samples = dataset._datasets[0].neural.shape[0]
+ assert all(
+ d.neural.shape[0] == n_samples for d in dataset._datasets
+ ), "for this set all of the sessions need to have same number of samples."
+
+ loader_kwargs = dict(num_steps=NUM_STEPS, batch_size=32)
+ loader = loader_initfunc(dataset, **loader_kwargs)
+
+ criterion = cebra.models.InfoNCE()
+ optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+ solver = solver_initfunc(model=model,
+ criterion=criterion,
+ optimizer=optimizer)
+ solver.fit(loader)
+
+ # Transform each session with the right model, by providing
+ # the corresponding session ID
+ for i, inputs in enumerate(dataset.iter_sessions()):
+ embedding = solver.transform(inputs=inputs.neural,
+ session_id=i,
+ pad_before_transform=padding)
+ embedding_batched = solver.transform(inputs=inputs.neural,
+ session_id=i,
+ pad_before_transform=padding,
+ batch_size=batch_size_inference)
+
+ assert embedding_batched.shape == embedding.shape
+ assert np.allclose(embedding_batched, embedding, rtol=1e-4, atol=1e-4)
diff --git a/tests/test_usecases.py b/tests/test_usecases.py
index 22195bd8..f0cc308a 100644
--- a/tests/test_usecases.py
+++ b/tests/test_usecases.py
@@ -29,7 +29,6 @@
"""
import itertools
-import pickle
import numpy as np
import pytest
diff --git a/tools/build_docker.sh b/tools/build_docker.sh
index 76aa8228..cec031a0 100755
--- a/tools/build_docker.sh
+++ b/tools/build_docker.sh
@@ -3,6 +3,21 @@
set -e
+# Parse command line arguments
+RUN_FULL_TESTS=false
+while [[ $# -gt 0 ]]; do
+ case $1 in
+ --full-tests)
+ RUN_FULL_TESTS=true
+ shift
+ ;;
+ *)
+ echo "Unknown option: $1"
+ exit 1
+ ;;
+ esac
+done
+
if [[ -z $(git status --porcelain) ]]; then
TAG=$(git rev-parse --short HEAD)
else
@@ -23,13 +38,20 @@ docker build \
-t $DOCKERNAME .
docker tag $DOCKERNAME $LATEST
+# Determine whether to run full tests or not
+if [[ "$RUN_FULL_TESTS" == "true" ]]; then
+ echo "Running full test suite including tests that require datasets"
+else
+ echo "Running tests that don't require datasets"
+fi
+
docker run \
--gpus 2 \
${extra_kwargs[@]} \
-v ${CEBRA_DATADIR:-./data}:/data \
--env CEBRA_DATADIR=/data \
--network host \
- -it $DOCKERNAME python -m pytest --ff -x -m "not requires_dataset" --doctest-modules ./docs/source/usage.rst tests cebra
+ -it $DOCKERNAME python -m pytest --ff -x $([ "$RUN_FULL_TESTS" != "true" ] && echo '-m "not requires_dataset"') --doctest-modules ./docs/source/usage.rst tests cebra
#docker push $DOCKERNAME
#docker push $LATEST
diff --git a/tools/build_docs.sh b/tools/build_docs.sh
index 3f5f36cd..119272ed 100755
--- a/tools/build_docs.sh
+++ b/tools/build_docs.sh
@@ -1,79 +1,17 @@
#!/bin/bash
-# Locally build the documentation and display it in a webserver.
-set -xe
-
-git_checkout_or_pull() {
- local repo=$1
- local target_dir=$2
- # TODO(stes): theoretically we could also auto-update the repo,
- # I commented this out for now to avoid interference with local
- # dev/changes
- #if [ -d "$target_dir" ]; then
- # cd "$target_dir"
- # git pull --ff-only origin main
- # cd -
- #else
- if [ ! -d "$target_dir" ]; then
- git clone "$repo" "$target_dir"
- fi
-}
-
-checkout_cebra_figures() {
- git_checkout_or_pull git@github.com:AdaptiveMotorControlLab/cebra-figures.git docs/source/cebra-figures
-}
-
-checkout_assets() {
- git_checkout_or_pull git@github.com:AdaptiveMotorControlLab/cebra-assets.git assets
-}
-
-checkout_cebra_demos() {
- git_checkout_or_pull git@github.com:AdaptiveMotorControlLab/cebra-demos.git docs/source/demo_notebooks
-}
-
-setup_python() {
- python -m pip install --upgrade pip setuptools wheel
- sudo apt-get install -y pandoc
- pip install torch --extra-index-url=https://download.pytorch.org/whl/cpu
- pip install '.[docs]'
-}
-
-build_docs() {
- cp -r assets/* .
- export SPHINXOPTS="-W --keep-going -n"
- (cd docs && PYTHONPATH=.. make page)
-}
-
-serve() {
- python -m http.server 8080 --b 0.0.0.0 -d docs/build/html
-}
-
-main() {
- build_docs
- serve
-}
-
-if [[ "$1" == "--build" ]]; then
- main
-fi
-
-docker build -t cebra-docs -f - . << "EOF"
-FROM python:3.9
-RUN python -m pip install --upgrade pip setuptools wheel \
- && apt-get update -y && apt-get install -y pandoc git
-RUN pip install torch --extra-index-url=https://download.pytorch.org/whl/cpu
-COPY dist/cebra-0.4.0-py2.py3-none-any.whl .
-RUN pip install 'cebra-0.4.0-py2.py3-none-any.whl[docs]'
-EOF
-
-checkout_cebra_figures
-checkout_assets
-checkout_cebra_demos
-
-docker run \
- -p 127.0.0.1:8080:8080 \
- -u $(id -u):$(id -g) \
- -v .:/app -w /app \
- --tmpfs /.config --tmpfs /.cache \
- -it cebra-docs \
- ./tools/build_docs.sh --build
+docker build -t cebra-docs -f docs/Dockerfile .
+
+docker run -u $(id -u):$(id -g) \
+ -p 127.0.0.1:8000:8000 \
+ -v $(pwd):/app \
+ -v /tmp/.cache/pip:/.cache/pip \
+ -v /tmp/.cache/sphinx:/.cache/sphinx \
+ -v /tmp/.cache/matplotlib:/.cache/matplotlib \
+ -v /tmp/.cache/fontconfig:/.cache/fontconfig \
+ -e MPLCONFIGDIR=/tmp/.cache/matplotlib \
+ -w /app \
+ --env SPHINXBUILD="sphinx-autobuild" \
+ --env SPHINXOPTS="-W --keep-going -n --port 8000 --host 0.0.0.0" \
+ -it cebra-docs \
+ make docs
diff --git a/tools/bump_version.sh b/tools/bump_version.sh
index fbc161b1..17142f7e 100755
--- a/tools/bump_version.sh
+++ b/tools/bump_version.sh
@@ -1,7 +1,7 @@
#!/bin/bash
# Bump the CEBRA version to the specified value.
# Edits all relevant files at once.
-#
+#
# Usage:
# tools/bump_version.sh 0.3.1rc1
@@ -10,24 +10,40 @@ if [ -z ${version} ]; then
>&1 echo "Specify a version number."
>&1 echo "Usage:"
>&1 echo "tools/bump_version.sh "
+ exit 1
+fi
+
+# Determine the correct sed command based on the OS
+# On macOS, the `sed` command requires an empty string argument after `-i` for in-place editing.
+# On Linux and other Unix-like systems, the `sed` command only requires `-i` for in-place editing.
+if [[ "$OSTYPE" == "darwin"* ]]; then
+ # macOS
+ SED_CMD="sed -i .bkp -e"
+else
+ # Linux and other Unix-like systems
+ SED_CMD="sed -i -e"
fi
# python cebra version
-sed -i "s/__version__ = .*/__version__ = \"${version}\"/" \
- cebra/__init__.py
+$SED_CMD "s/__version__ = .*/__version__ = \"${version}\"/" cebra/__init__.py
# reinstall script in root
-sed -i "s/VERSION=.*/VERSION=${version}/" \
- reinstall.sh
+$SED_CMD "s/VERSION=.*/VERSION=${version}/" reinstall.sh
# Makefile
-sed -i "s/CEBRA_VERSION := .*/CEBRA_VERSION := ${version}/" \
- Makefile
+$SED_CMD "s/CEBRA_VERSION := .*/CEBRA_VERSION := ${version}/" Makefile
-# Arch linux PKGBUILD
-sed -i "s/pkgver=.*/pkgver=${version}/" \
- PKGBUILD
+# Arch linux PKGBUILD
+$SED_CMD "s/pkgver=.*/pkgver=${version}/" PKGBUILD
# Dockerfile
-sed -i "s/ENV WHEEL=cebra-.*\.whl/ENV WHEEL=cebra-${version}-py2.py3-none-any.whl/" \
- Dockerfile
+$SED_CMD "s/ENV WHEEL=cebra-.*\.whl/ENV WHEEL=cebra-${version}-py3-none-any.whl/" Dockerfile
+
+# build_docs.sh
+$SED_CMD "s/COPY dist\/cebra-.*-py3-none-any\.whl/COPY dist\/cebra-${version}-py3-none-any.whl/" tools/build_docs.sh
+$SED_CMD "s/RUN pip install 'cebra-.*-py3-none-any\.whl/RUN pip install 'cebra-${version}-py3-none-any.whl/" tools/build_docs.sh
+
+# Remove backup files
+if [[ "$OSTYPE" == "darwin"* ]]; then
+ rm cebra/__init__.py.bkp reinstall.sh.bkp Makefile.bkp PKGBUILD.bkp Dockerfile.bkp tools/build_docs.sh.bkp
+fi