From 9898850a560357f176441b37e422eb6d95a0c475 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 27 Oct 2024 19:23:19 +0100 Subject: [PATCH 01/86] Fix linting errors in tests (#188) * apply auto-fixes * Fix linting errors in tests/ * Fix version check --- tests/test_api.py | 1 - tests/test_cli.py | 3 --- tests/test_criterions.py | 3 +-- tests/test_datasets.py | 6 ------ tests/test_demo.py | 1 - tests/test_distributions.py | 6 +++--- tests/test_grid_search.py | 1 - tests/test_integration_train.py | 1 - tests/test_load.py | 8 ++------ tests/test_models.py | 4 ++-- tests/test_plot.py | 4 +--- tests/test_registry.py | 6 +++--- tests/test_sklearn.py | 7 ++----- tests/test_solver.py | 8 ++++---- tests/test_usecases.py | 1 - 15 files changed, 18 insertions(+), 42 deletions(-) diff --git a/tests/test_api.py b/tests/test_api.py index bc279cbd..4e514429 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -21,6 +21,5 @@ # def test_api(): import cebra.distributions - from cebra.distributions import TimedeltaDistribution cebra.distributions.TimedeltaDistribution diff --git a/tests/test_cli.py b/tests/test_cli.py index 41e67f42..8e49cc35 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -19,6 +19,3 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import argparse - -import pytest diff --git a/tests/test_criterions.py b/tests/test_criterions.py index 93a3b846..0d6f8ff2 100644 --- a/tests/test_criterions.py +++ b/tests/test_criterions.py @@ -19,7 +19,6 @@ # See the License for the specific language governing permissions and # limitations under the License. # -import numpy as np import pytest import torch from torch import nn @@ -294,7 +293,7 @@ def _sample_dist_matrices(seed): @pytest.mark.parametrize("seed", [42, 4242, 424242]) -def test_infonce(seed): +def test_infonce_check_output_parts(seed): pos_dist, neg_dist = _sample_dist_matrices(seed) ref_loss, ref_align, ref_uniform = _reference_infonce(pos_dist, neg_dist) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 4bea0cf0..e8e03ff0 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -69,8 +69,6 @@ def test_demo(): @pytest.mark.requires_dataset def test_hippocampus(): pytest.skip("Outdated") - - from cebra.datasets import hippocampus # noqa: F401 dataset = cebra.datasets.init("rat-hippocampus-single") loader = cebra.data.ContinuousDataLoader( dataset=dataset, @@ -99,8 +97,6 @@ def test_hippocampus(): @pytest.mark.requires_dataset def test_monkey(): - from cebra.datasets import monkey_reaching # noqa: F401 - dataset = cebra.datasets.init( "area2-bump-pos-active-passive", path=pathlib.Path(_DEFAULT_DATADIR) / "monkey_reaching_preload_smth_40", @@ -111,8 +107,6 @@ def test_monkey(): @pytest.mark.requires_dataset def test_allen(): - from cebra.datasets import allen # noqa: F401 - pytest.skip("Test takes too long") ca_dataset = cebra.datasets.init("allen-movie-one-ca-VISp-100-train-10-111") diff --git a/tests/test_demo.py b/tests/test_demo.py index 4f0f146c..ce555db3 100644 --- a/tests/test_demo.py +++ b/tests/test_demo.py @@ -21,7 +21,6 @@ # import glob import re -import sys import pytest diff --git a/tests/test_distributions.py b/tests/test_distributions.py index d7151fd1..2b704391 100644 --- a/tests/test_distributions.py +++ b/tests/test_distributions.py @@ -43,7 +43,7 @@ def prepare(N=1000, n=128, d=5, probs=[0.3, 0.1, 0.6], device="cpu"): continuous = torch.randn(N, d).to(device) rand = torch.from_numpy(np.random.randint(0, N, 
(n,))).to(device) - qidx = discrete[rand].to(device) + _ = discrete[rand].to(device) query = continuous[rand] + 0.1 * torch.randn(n, d).to(device) query = query.to(device) @@ -173,7 +173,7 @@ def test_mixed(): discrete, continuous) reference_idx = distribution.sample_prior(10) - positive_idx = distribution.sample_conditional(reference_idx) + _ = distribution.sample_conditional(reference_idx) # The conditional distribution p(· | disc, cont) should yield # samples where the label exactly matches the reference sample. @@ -193,7 +193,7 @@ def test_continuous(benchmark): def _test_distribution(dist): distribution = dist(continuous) reference_idx = distribution.sample_prior(10) - positive_idx = distribution.sample_conditional(reference_idx) + _ = distribution.sample_conditional(reference_idx) return distribution distribution = _test_distribution( diff --git a/tests/test_grid_search.py b/tests/test_grid_search.py index 3f88ba12..c774ea02 100644 --- a/tests/test_grid_search.py +++ b/tests/test_grid_search.py @@ -20,7 +20,6 @@ # limitations under the License. # import numpy as np -import pytest import cebra import cebra.grid_search diff --git a/tests/test_integration_train.py b/tests/test_integration_train.py index 06e6da40..238bbea7 100644 --- a/tests/test_integration_train.py +++ b/tests/test_integration_train.py @@ -20,7 +20,6 @@ # limitations under the License. # import itertools -from typing import List import pytest import torch diff --git a/tests/test_load.py b/tests/test_load.py index 6f62dc92..2a9ef3b5 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -22,10 +22,7 @@ import itertools import pathlib import pickle -import platform import tempfile -import unittest -from unittest.mock import patch import h5py import hdf5storage @@ -125,7 +122,7 @@ def generate_numpy_confounder(filename, dtype): @register("npz") -def generate_numpy_path(filename, dtype): +def generate_numpy_path_2(filename, dtype): A = np.arange(1000, dtype=dtype).reshape(10, 100) np.savez(filename, array=A, other_data="test") loaded_A = cebra_load.load(pathlib.Path(filename)) @@ -418,7 +415,7 @@ def generate_csv_path(filename, dtype): @register_error("csv") def generate_csv_empty_file(filename, dtype): - with open(filename, "w") as creating_new_csv_file: + with open(filename, "w") as _: pass _ = cebra_load.load(filename) @@ -619,7 +616,6 @@ def generate_pickle_invalid_key(filename, dtype): @register_error("pkl", "p") def generate_pickle_no_array(filename, dtype): - A = np.arange(1000, dtype=dtype).reshape(10, 100) with open(filename, "wb") as f: pickle.dump({"A": "test_1", "B": "test_2"}, f) _ = cebra_load.load(filename) diff --git a/tests/test_models.py b/tests/test_models.py index 2a6e4812..d41dc7ab 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -155,8 +155,8 @@ def test_version_check(version, raises): cebra.models.model._check_torch_version(raise_error=True) -def test_version_check(): - raises = not cebra.models.model._check_torch_version(raise_error=False) +def test_version_check_dropout_available(): + raises = cebra.models.model._check_torch_version(raise_error=False) if raises: assert len(cebra.models.get_options("*dropout*")) == 0 else: diff --git a/tests/test_plot.py b/tests/test_plot.py index 3f44d887..1d94d310 100644 --- a/tests/test_plot.py +++ b/tests/test_plot.py @@ -72,8 +72,6 @@ def test_plot_imports(): def test_colormaps(): import matplotlib - import cebra - cmap = matplotlib.colormaps["cebra"] assert cmap is not None plt.scatter([1], [2], c=[2], cmap="cebra") @@ -241,7 +239,7 @@ def 
test_compare_models(): _ = cebra_plot.compare_models(models, labels=long_labels, ax=ax) with pytest.raises(ValueError, match="Invalid.*labels"): invalid_labels = copy.deepcopy(labels) - ele = invalid_labels.pop() + _ = invalid_labels.pop() invalid_labels.append(["a"]) _ = cebra_plot.compare_models(models, labels=invalid_labels, ax=ax) diff --git a/tests/test_registry.py b/tests/test_registry.py index 69e04f38..cd27344c 100644 --- a/tests/test_registry.py +++ b/tests/test_registry.py @@ -117,7 +117,7 @@ def test_override(): _Foo1 = test_module.register("foo")(Foo) assert _Foo1 == Foo assert _Foo1 != Bar - assert f"foo" in test_module.get_options() + assert "foo" in test_module.get_options() # Check that the class was actually added to the module assert ( @@ -137,7 +137,7 @@ def test_override(): _Foo2 = test_module.register("foo", override=True)(Bar) assert _Foo2 != Foo assert _Foo2 == Bar - assert f"foo" in test_module.get_options() + assert "foo" in test_module.get_options() def test_depreciation(): @@ -145,7 +145,7 @@ def test_depreciation(): Foo = _make_class() _Foo1 = test_module.register("foo")(Foo) assert _Foo1 == Foo - assert f"foo" in test_module.get_options() + assert "foo" in test_module.get_options() # Registering the same class under different names # also raises and error diff --git a/tests/test_sklearn.py b/tests/test_sklearn.py index e409c0e3..3b9d309b 100644 --- a/tests/test_sklearn.py +++ b/tests/test_sklearn.py @@ -276,7 +276,6 @@ def test_api(estimator, check): pytest.skip(f"Model architecture {estimator.model_architecture} " f"requires longer input sizes than 20 samples.") - success = True exception = None num_successful = 0 total_runs = 0 @@ -334,7 +333,6 @@ def test_sklearn(model_architecture, device): y_c1 = np.random.uniform(0, 1, (1000, 5)) y_c1_s2 = np.random.uniform(0, 1, (800, 5)) y_c2 = np.random.uniform(0, 1, (1000, 2)) - y_c2_s2 = np.random.uniform(0, 1, (800, 2)) y_d = np.random.randint(0, 10, (1000,)) y_d_s2 = np.random.randint(0, 10, (800,)) @@ -817,7 +815,6 @@ def test_sklearn_full(model_architecture, device, pad_before_transform): X = np.random.uniform(0, 1, (1000, 50)) y_c1 = np.random.uniform(0, 1, (1000, 5)) y_c2 = np.random.uniform(0, 1, (1000, 2)) - y_d = np.random.randint(0, 10, (1000,)) # time contrastive cebra_model.fit(X) @@ -883,7 +880,7 @@ def test_sklearn_resampling_model_not_yet_supported(model_architecture, device): with pytest.raises(ValueError): cebra_model.fit(X, y_c1) - output = cebra_model.transform(X) + _ = cebra_model.transform(X) def _iterate_actions(): @@ -1097,7 +1094,7 @@ def test_move_cpu_to_cuda_device(device): def test_move_cpu_to_mps_device(device): if not cebra.helper._is_mps_availabe(torch): - pytest.skip(f"MPS device is not available") + pytest.skip("MPS device is not available") X = np.random.uniform(0, 1, (10, 5)) cebra_model = cebra_sklearn_cebra.CEBRA(model_architecture="offset1-model", diff --git a/tests/test_solver.py b/tests/test_solver.py index 3107be30..65f49f71 100644 --- a/tests/test_solver.py +++ b/tests/test_solver.py @@ -19,7 +19,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
# -import itertools import pytest import torch @@ -100,11 +99,12 @@ def test_single_session(data_name, loader_initfunc, solver_initfunc): @pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc", single_session_tests) def test_single_session_auxvar(data_name, loader_initfunc, solver_initfunc): - return # TODO + + pytest.skip("Not yet supported") loader = _get_loader(data_name, loader_initfunc) model = _make_model(loader.dataset) - behavior_model = _make_behavior_model(loader.dataset) + behavior_model = _make_behavior_model(loader.dataset) # noqa: F841 criterion = cebra.models.InfoNCE() optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) @@ -172,7 +172,7 @@ def test_multi_session(data_name, loader_initfunc, solver_initfunc): @pytest.mark.parametrize("data_name, loader_initfunc, solver_initfunc", multi_session_tests) -def test_multi_session(data_name, loader_initfunc, solver_initfunc): +def test_multi_session_2(data_name, loader_initfunc, solver_initfunc): loader = _get_loader(data_name, loader_initfunc) criterion = cebra.models.InfoNCE() model = nn.ModuleList( diff --git a/tests/test_usecases.py b/tests/test_usecases.py index 22195bd8..f0cc308a 100644 --- a/tests/test_usecases.py +++ b/tests/test_usecases.py @@ -29,7 +29,6 @@ """ import itertools -import pickle import numpy as np import pytest From 36a91c7afe95c7b9b24d6dfb3bc7734f6ed68303 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Fri, 8 Nov 2024 07:33:23 +0000 Subject: [PATCH 02/86] Fix `scikit-learn` reference in conda environment files (#195) --- conda/cebra_paper.yml | 2 +- conda/cebra_paper_m1.yml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/conda/cebra_paper.yml b/conda/cebra_paper.yml index e7537756..4b9e2b6e 100644 --- a/conda/cebra_paper.yml +++ b/conda/cebra_paper.yml @@ -39,7 +39,7 @@ dependencies: - "cebra[dev,integrations,datasets,demos]" - joblib - literate-dataclasses - - sklearn + - scikit-learn - scipy - torch - keras==2.3.1 diff --git a/conda/cebra_paper_m1.yml b/conda/cebra_paper_m1.yml index 32256758..3d8cd7b9 100644 --- a/conda/cebra_paper_m1.yml +++ b/conda/cebra_paper_m1.yml @@ -48,7 +48,7 @@ dependencies: - tensorflow-metal - joblib - literate-dataclasses - - sklearn + - scikit-learn - scipy - torch - umap-learn From 5f46c3257952a08dfa9f9e1b149a85f7f12c1053 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Mon, 16 Dec 2024 20:32:47 +0100 Subject: [PATCH 03/86] Add support for new __sklearn_tags__ (#205) * Add support for new __sklearn_tags__ * fix inheritance order * Add more tests * fix added test --- .github/workflows/build.yml | 13 ++++++++++++- cebra/integrations/sklearn/cebra.py | 18 +++++++++++++++++- 2 files changed, 29 insertions(+), 2 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index a231258f..ef9e1777 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -19,10 +19,16 @@ jobs: # as well as selected previous versions on # https://pytorch.org/get-started/previous-versions/ torch-version: ["2.2.2", "2.4.0"] + sklearn-version: ["latest"] include: - os: windows-latest torch-version: 2.4.0 python-version: "3.10" + sklearn-version: "latest" + - os: ubuntu-latest + torch-version: 2.4.0 + python-version: "3.10" + sklearn-version: "legacy" runs-on: ${{ matrix.os }} @@ -32,7 +38,7 @@ jobs: uses: actions/cache@v3 with: path: ~/.cache/pip - key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }} + key: pip-os_${{ runner.os }}-python_${{ 
matrix.python-version }}-torch_${{ matrix.torch-version }}-sklearn_${{ matrix.sklearn-version }} - name: Checkout code uses: actions/checkout@v2 @@ -48,6 +54,11 @@ jobs: python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu pip install '.[dev,datasets,integrations]' + - name: Check sklearn legacy version + if: matrix.sklearn-version == 'legacy' + run: | + pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]' + - name: Run the formatter run: | make format diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 046d3344..9a74eeb6 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -30,8 +30,10 @@ import pkg_resources import sklearn.utils.validation as sklearn_utils_validation import torch +import sklearn from sklearn.base import BaseEstimator from sklearn.base import TransformerMixin +from sklearn.utils.metaestimators import available_if from torch import nn import cebra.data @@ -41,6 +43,11 @@ import cebra.models import cebra.solver +def check_version(estimator): + # NOTE(stes): required as a check for the old way of specifying tags + # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165 + from packaging import version + return version.parse(sklearn.__version__) < version.parse("1.6.dev") def _init_loader( is_cont: bool, @@ -364,7 +371,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA": return cebra_ -class CEBRA(BaseEstimator, TransformerMixin): +class CEBRA(TransformerMixin, BaseEstimator): """CEBRA model defined as part of a ``scikit-learn``-like API. Attributes: @@ -1294,6 +1301,15 @@ def fit_transform( callback_frequency=callback_frequency) return self.transform(X) + def __sklearn_tags__(self): + # NOTE(stes): from 1.6.dev, this is the new way to specify tags + # https://scikit-learn.org/dev/developers/develop.html + # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165 + tags = super().__sklearn_tags__() + tags.non_deterministic = True + return tags + + @available_if(check_version) def _more_tags(self): # NOTE(stes): This tag is needed as seeding is not fully implemented in the # current version of CEBRA. 
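Context for the tag logic in the patch above: scikit-learn >= 1.6 reads estimator
capabilities from __sklearn_tags__, while older releases still consult _more_tags,
so the patch keeps both hooks and uses available_if to hide the legacy one on new
versions. A minimal sketch of the same pattern on a toy estimator (the
DemoTransformer name and the chosen tag are illustrative, not CEBRA's actual API):

    import packaging.version
    import sklearn
    from sklearn.base import BaseEstimator, TransformerMixin
    from sklearn.utils.metaestimators import available_if

    def _sklearn_pre_1_6(estimator):
        # Only expose the legacy tag hook on scikit-learn < 1.6.
        return packaging.version.parse(
            sklearn.__version__) < packaging.version.parse("1.6.dev")

    class DemoTransformer(TransformerMixin, BaseEstimator):
        def __sklearn_tags__(self):
            # New-style tags, consulted by scikit-learn >= 1.6.
            tags = super().__sklearn_tags__()
            tags.non_deterministic = True
            return tags

        @available_if(_sklearn_pre_1_6)
        def _more_tags(self):
            # Old-style tags, only reachable on scikit-learn < 1.6.
            return {"non_deterministic": True}

Note that the mixin is listed before BaseEstimator, matching the
"fix inheritance order" change in this patch.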
From 7a4d3fcaefd4fde76717423a2a4a5e96a0f0a26d Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Wed, 22 Jan 2025 00:11:39 +0100 Subject: [PATCH 04/86] Update workflows to actions/setup-python@v5, actions/cache@v4 (#212) --- .github/workflows/build.yml | 8 ++++---- .github/workflows/doc-coverage.yml | 6 +++--- .github/workflows/docs.yml | 12 ++++++------ .github/workflows/release-pypi.yml | 2 +- 4 files changed, 14 insertions(+), 14 deletions(-) diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index ef9e1777..3c4f68dd 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -25,7 +25,7 @@ jobs: torch-version: 2.4.0 python-version: "3.10" sklearn-version: "latest" - - os: ubuntu-latest + - os: ubuntu-latest torch-version: 2.4.0 python-version: "3.10" sklearn-version: "legacy" @@ -35,7 +35,7 @@ jobs: steps: - name: Cache dependencies id: pip-cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/.cache/pip key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}-sklearn_${{ matrix.sklearn-version }} @@ -44,7 +44,7 @@ jobs: uses: actions/checkout@v2 - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -54,7 +54,7 @@ jobs: python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu pip install '.[dev,datasets,integrations]' - - name: Check sklearn legacy version + - name: Check sklearn legacy version if: matrix.sklearn-version == 'legacy' run: | pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]' diff --git a/.github/workflows/doc-coverage.yml b/.github/workflows/doc-coverage.yml index 268cbee0..8d7f0522 100644 --- a/.github/workflows/doc-coverage.yml +++ b/.github/workflows/doc-coverage.yml @@ -22,7 +22,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ['3.8'] + python-version: ['3.9'] steps: # NOTE(stes) currently not used, we check @@ -31,14 +31,14 @@ jobs: # with: # ref: main - uses: actions/checkout@v3 - - uses: actions/cache@v1 + - uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip restore-keys: | ${{ runner.os }}-pip - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v1 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - name: Install package diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 83c9d829..47b5862d 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -17,7 +17,7 @@ jobs: steps: - name: Cache dependencies id: pip-cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip @@ -52,7 +52,7 @@ jobs: ref: main - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} @@ -60,12 +60,12 @@ jobs: run: | python -m pip install --upgrade pip setuptools wheel # NOTE(stes) Pandoc version must be at least (2.14.2) but less than (4.0.0). - # as of 29/10/23. Ubuntu 22.04 which is used for ubuntu-latest only has an + # as of 29/10/23. Ubuntu 22.04 which is used for ubuntu-latest only has an # old pandoc version (2.9.). We will hence install the latest version manually. 
# previou: sudo apt-get install -y pandoc - wget https://github.com/jgm/pandoc/releases/download/3.1.9/pandoc-3.1.9-1-amd64.deb - sudo dpkg -i pandoc-3.1.9-1-amd64.deb - rm pandoc-3.1.9-1-amd64.deb + wget https://github.com/jgm/pandoc/releases/download/3.1.9/pandoc-3.1.9-1-amd64.deb + sudo dpkg -i pandoc-3.1.9-1-amd64.deb + rm pandoc-3.1.9-1-amd64.deb pip install torch --extra-index-url https://download.pytorch.org/whl/cpu pip install '.[docs]' diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml index d6950119..fc6d5c8e 100644 --- a/.github/workflows/release-pypi.yml +++ b/.github/workflows/release-pypi.yml @@ -23,7 +23,7 @@ jobs: steps: - name: Cache dependencies id: pip-cache - uses: actions/cache@v3 + uses: actions/cache@v4 with: path: ~/.cache/pip key: ${{ runner.os }}-pip From a79c2dee465bb144077690345446bdfc41f4bf73 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=8Dcaro?= Date: Wed, 22 Jan 2025 07:52:19 +0100 Subject: [PATCH 05/86] Fix deprecation warning force_all_finite -> ensure_all_finite for sklearn>=1.6 (#206) --- cebra/integrations/sklearn/utils.py | 22 ++++++++++++++++++---- 1 file changed, 18 insertions(+), 4 deletions(-) diff --git a/cebra/integrations/sklearn/utils.py b/cebra/integrations/sklearn/utils.py index 455213a3..d9bb3083 100644 --- a/cebra/integrations/sklearn/utils.py +++ b/cebra/integrations/sklearn/utils.py @@ -22,12 +22,26 @@ import warnings import numpy.typing as npt +import packaging +import sklearn import sklearn.utils.validation as sklearn_utils_validation import torch import cebra.helper +def _sklearn_check_array(array, **kwargs): + # NOTE(stes): See discussion in https://github.com/AdaptiveMotorControlLab/CEBRA/pull/206 + # https://scikit-learn.org/1.6/modules/generated/sklearn.utils.check_array.html + # force_all_finite was renamed to ensure_all_finite and will be removed in 1.8. + if packaging.version.parse( + sklearn.__version__) < packaging.version.parse("1.6"): + if "ensure_all_finite" in kwargs: + kwargs["force_all_finite"] = kwargs["ensure_all_finite"] + del kwargs["ensure_all_finite"] + return sklearn_utils_validation.check_array(array, **kwargs) + + def update_old_param(old: dict, new: dict, kwargs: dict, default) -> tuple: """Handle deprecated arguments of a function until they are replaced. @@ -74,15 +88,15 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray: Returns: The converted and validated array. """ - return sklearn_utils_validation.check_array( + return _sklearn_check_array( X, accept_sparse=False, accept_large_sparse=False, dtype=("float16", "float32", "float64"), order=None, copy=False, - force_all_finite=True, ensure_2d=True, + ensure_all_finite=True, allow_nd=False, ensure_min_samples=min_samples, ensure_min_features=1, @@ -105,15 +119,15 @@ def check_label_array(y: npt.NDArray, *, min_samples: int): Returns: The converted and validated labels. 
""" - return sklearn_utils_validation.check_array( + return _sklearn_check_array( y, accept_sparse=False, accept_large_sparse=False, dtype="numeric", order=None, copy=False, - force_all_finite=True, ensure_2d=False, + ensure_all_finite=True, allow_nd=False, ensure_min_samples=min_samples, ) From 7e74edac88b7d54eb19ae218caa2bf2bfed38d36 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Wed, 29 Jan 2025 14:27:15 -0500 Subject: [PATCH 06/86] Add tests to check legacy model loading (#214) --- tests/_build_legacy_model/.gitignore | 1 + tests/_build_legacy_model/Dockerfile | 39 +++++++++++++++++++++ tests/_build_legacy_model/README.md | 13 +++++++ tests/_build_legacy_model/create_model.py | 15 +++++++++ tests/_build_legacy_model/generate.sh | 3 ++ tests/test_sklearn_legacy.py | 41 +++++++++++++++++++++++ 6 files changed, 112 insertions(+) create mode 100644 tests/_build_legacy_model/.gitignore create mode 100644 tests/_build_legacy_model/Dockerfile create mode 100644 tests/_build_legacy_model/README.md create mode 100644 tests/_build_legacy_model/create_model.py create mode 100755 tests/_build_legacy_model/generate.sh create mode 100644 tests/test_sklearn_legacy.py diff --git a/tests/_build_legacy_model/.gitignore b/tests/_build_legacy_model/.gitignore new file mode 100644 index 00000000..4b6ebe5f --- /dev/null +++ b/tests/_build_legacy_model/.gitignore @@ -0,0 +1 @@ +*.pt diff --git a/tests/_build_legacy_model/Dockerfile b/tests/_build_legacy_model/Dockerfile new file mode 100644 index 00000000..ddbb0e61 --- /dev/null +++ b/tests/_build_legacy_model/Dockerfile @@ -0,0 +1,39 @@ +FROM python:3.12-slim AS base +RUN pip install torch --index-url https://download.pytorch.org/whl/cpu +RUN apt-get update && \ + apt-get install -y --no-install-recommends git && \ + rm -rf /var/lib/apt/lists/* + +FROM base AS cebra-0.4.0-scikit-learn-1.4 +RUN pip install cebra==0.4.0 "scikit-learn<1.5" +WORKDIR /app +COPY create_model.py . +RUN python create_model.py + +FROM base AS cebra-0.4.0-scikit-learn-1.6 +RUN pip install cebra==0.4.0 "scikit-learn>=1.6" +WORKDIR /app +COPY create_model.py . +RUN python create_model.py + +FROM base AS cebra-rc-scikit-learn-1.4 +# NOTE(stes): Commit where new scikit-learn tag logic was added to the CEBRA class. +# https://github.com/AdaptiveMotorControlLab/CEBRA/commit/5f46c3257952a08dfa9f9e1b149a85f7f12c1053 +RUN pip install git+https://github.com/AdaptiveMotorControlLab/CEBRA.git@5f46c3257952a08dfa9f9e1b149a85f7f12c1053 "scikit-learn<1.5" +WORKDIR /app +COPY create_model.py . +RUN python create_model.py + +FROM base AS cebra-rc-scikit-learn-1.6 +# NOTE(stes): Commit where new scikit-learn tag logic was added to the CEBRA class. +# https://github.com/AdaptiveMotorControlLab/CEBRA/commit/5f46c3257952a08dfa9f9e1b149a85f7f12c1053 +RUN pip install git+https://github.com/AdaptiveMotorControlLab/CEBRA.git@5f46c3257952a08dfa9f9e1b149a85f7f12c1053 "scikit-learn>=1.6" +WORKDIR /app +COPY create_model.py . 
+RUN python create_model.py + +FROM scratch +COPY --from=cebra-0.4.0-scikit-learn-1.4 /app/cebra_model.pt /cebra_model_cebra-0.4.0-scikit-learn-1.4.pt +COPY --from=cebra-0.4.0-scikit-learn-1.6 /app/cebra_model.pt /cebra_model_cebra-0.4.0-scikit-learn-1.6.pt +COPY --from=cebra-rc-scikit-learn-1.4 /app/cebra_model.pt /cebra_model_cebra-rc-scikit-learn-1.4.pt +COPY --from=cebra-rc-scikit-learn-1.6 /app/cebra_model.pt /cebra_model_cebra-rc-scikit-learn-1.6.pt diff --git a/tests/_build_legacy_model/README.md b/tests/_build_legacy_model/README.md new file mode 100644 index 00000000..4bcffa2b --- /dev/null +++ b/tests/_build_legacy_model/README.md @@ -0,0 +1,13 @@ +# Helper script to build CEBRA checkpoints + +This script builds CEBRA checkpoints for different versions of scikit-learn and CEBRA. +To build all models, run: + +```bash +./generate.sh +``` + +The models are currently also stored in git directly due to their small size. + +Related issue: https://github.com/AdaptiveMotorControlLab/CEBRA/issues/207 +Related test: tests/test_sklearn_legacy.py diff --git a/tests/_build_legacy_model/create_model.py b/tests/_build_legacy_model/create_model.py new file mode 100644 index 00000000..f308d296 --- /dev/null +++ b/tests/_build_legacy_model/create_model.py @@ -0,0 +1,15 @@ +import numpy as np + +import cebra + +neural_data = np.random.normal(0, 1, (1000, 30)) # 1000 samples, 30 features +cebra_model = cebra.CEBRA(model_architecture="offset10-model", + batch_size=512, + learning_rate=1e-4, + max_iterations=10, + time_offsets=10, + num_hidden_units=16, + output_dimension=8, + verbose=True) +cebra_model.fit(neural_data) +cebra_model.save("cebra_model.pt") diff --git a/tests/_build_legacy_model/generate.sh b/tests/_build_legacy_model/generate.sh new file mode 100755 index 00000000..749a0d32 --- /dev/null +++ b/tests/_build_legacy_model/generate.sh @@ -0,0 +1,3 @@ +#!/bin/bash + +DOCKER_BUILDKIT=1 docker build --output type=local,dest=. . 
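Because the final stage of the Dockerfile above is FROM scratch and contains only
the four checkpoints, the BuildKit local exporter invoked in generate.sh writes
exactly those .pt files next to the Dockerfile. Loading one back is then a
one-liner; the test added in the following diff exercises exactly this path. A
hedged sketch, assuming the checkpoints were generated (or downloaded) as in that
test:

    import cebra

    # File name follows the build stage names in the Dockerfile above.
    model = cebra.CEBRA.load("cebra_model_cebra-0.4.0-scikit-learn-1.4.pt")
    print(model.model_architecture)  # "offset10-model", as set in create_model.py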
diff --git a/tests/test_sklearn_legacy.py b/tests/test_sklearn_legacy.py new file mode 100644 index 00000000..4d74515f --- /dev/null +++ b/tests/test_sklearn_legacy.py @@ -0,0 +1,41 @@ +import pathlib +import urllib.request + +import numpy as np +import pytest + +from cebra.integrations.sklearn.cebra import CEBRA + +MODEL_VARIANTS = [ + "cebra-0.4.0-scikit-learn-1.4", "cebra-0.4.0-scikit-learn-1.6", + "cebra-rc-scikit-learn-1.4", "cebra-rc-scikit-learn-1.6" +] + + +@pytest.mark.parametrize("model_variant", MODEL_VARIANTS) +def test_load_legacy_model(model_variant): + """Test loading a legacy CEBRA model.""" + + X = np.random.normal(0, 1, (1000, 30)) + + model_path = pathlib.Path( + __file__ + ).parent / "_build_legacy_model" / f"cebra_model_{model_variant}.pt" + + if not model_path.exists(): + url = f"https://cebra.fra1.digitaloceanspaces.com/cebra_model_{model_variant}.pt" + urllib.request.urlretrieve(url, model_path) + + loaded_model = CEBRA.load(model_path) + + assert loaded_model.model_architecture == "offset10-model" + assert loaded_model.output_dimension == 8 + assert loaded_model.num_hidden_units == 16 + assert loaded_model.time_offsets == 10 + + output = loaded_model.transform(X) + assert isinstance(output, np.ndarray) + assert output.shape[1] == loaded_model.output_dimension + + assert hasattr(loaded_model, "state_dict_") + assert hasattr(loaded_model, "n_features_") From 4e3266147bf6c9e08ec4cef210129c38ae94d6fe Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 2 Feb 2025 11:59:12 -0500 Subject: [PATCH 07/86] Add improved goodness of fit implementation (#190) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Started implementing improved goodness of fit implementation * add tests and improve implementation * Fix examples * Fix docstring error * Handle batch size = None for goodness of fit computation * adapt GoF implementation * Fix docstring tests * Update docstring for goodness_of_fit_score Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com> * add annotations to goodness_of_fit_history Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com> * fix typo Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com> * improve err message Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com> * make numerical test less conversative * Add tests for exception handling * fix tests --------- Co-authored-by: Célia Benquet <32598028+CeliaBenquet@users.noreply.github.com> --- cebra/integrations/sklearn/metrics.py | 143 ++++++++++++++++++++++++++ tests/test_sklearn_metrics.py | 129 +++++++++++++++++++++++ 2 files changed, 272 insertions(+) diff --git a/cebra/integrations/sklearn/metrics.py b/cebra/integrations/sklearn/metrics.py index ccecaa11..0af44ecb 100644 --- a/cebra/integrations/sklearn/metrics.py +++ b/cebra/integrations/sklearn/metrics.py @@ -108,6 +108,149 @@ def infonce_loss( return avg_loss +def goodness_of_fit_score(cebra_model: cebra_sklearn_cebra.CEBRA, + X: Union[npt.NDArray, torch.Tensor], + *y, + session_id: Optional[int] = None, + num_batches: int = 500) -> float: + """Compute the goodness of fit score on a *single session* dataset on the model. + + This function uses the :func:`infonce_loss` function to compute the InfoNCE loss + for a given `cebra_model` and the :func:`infonce_to_goodness_of_fit` function + to derive the goodness of fit from the InfoNCE loss. 
+ + Args: + cebra_model: The model to use to compute the InfoNCE loss on the samples. + X: A 2D data matrix, corresponding to a *single session* recording. + y: An arbitrary amount of continuous indices passed as 2D matrices, and up to one + discrete index passed as a 1D array. Each index has to match the length of ``X``. + session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`cebra.CEBRA.num_sessions` + for multisession, set to ``None`` for single session. + num_batches: The number of iterations to consider to evaluate the model on the new data. + Higher values will give a more accurate estimate. Set it to at least 500 iterations. + + Returns: + The average GoF score estimated over ``num_batches`` batches from the data distribution. + + Related: + :func:`infonce_to_goodness_of_fit` + + Example: + + >>> import cebra + >>> import numpy as np + >>> neural_data = np.random.uniform(0, 1, (1000, 20)) + >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512) + >>> cebra_model.fit(neural_data) + CEBRA(batch_size=512, max_iterations=10) + >>> gof = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, neural_data) + """ + loss = infonce_loss(cebra_model, + X, + *y, + session_id=session_id, + num_batches=num_batches, + correct_by_batchsize=False) + return infonce_to_goodness_of_fit(loss, cebra_model) + + +def goodness_of_fit_history(model: cebra_sklearn_cebra.CEBRA) -> np.ndarray: + """Return the history of the goodness of fit score. + + Args: + model: A trained CEBRA model. + + Returns: + A numpy array containing the goodness of fit values, measured in bits. + + Related: + :func:`infonce_to_goodness_of_fit` + + Example: + + >>> import cebra + >>> import numpy as np + >>> neural_data = np.random.uniform(0, 1, (1000, 20)) + >>> cebra_model = cebra.CEBRA(max_iterations=10, batch_size = 512) + >>> cebra_model.fit(neural_data) + CEBRA(batch_size=512, max_iterations=10) + >>> gof_history = cebra.sklearn.metrics.goodness_of_fit_history(cebra_model) + """ + infonce = np.array(model.state_dict_["log"]["total"]) + return infonce_to_goodness_of_fit(infonce, model) + + +def infonce_to_goodness_of_fit( + infonce: Union[float, np.ndarray], + model: Optional[cebra_sklearn_cebra.CEBRA] = None, + batch_size: Optional[int] = None, + num_sessions: Optional[int] = None) -> Union[float, np.ndarray]: + """Given a trained CEBRA model, return goodness of fit metric. + + The goodness of fit ranges from 0 (lowest meaningful value) + to a positive number with the unit "bits", the higher the + better. + + Values lower than 0 bits are possible, but these only occur + due to numerical effects. A perfectly collapsed embedding + (e.g., because the data cannot be fit with the provided + auxiliary variables) will have a goodness of fit of 0. + + The conversion between the generalized InfoNCE metric that + CEBRA is trained with and the goodness of fit computed with this + function is + + .. math:: + + S = \\log N - \\text{InfoNCE} + + To use this function, either provide a trained CEBRA model or the + batch size and number of sessions. + + Args: + infonce: The InfoNCE loss, either a single value or an iterable of values. + model: The trained CEBRA model. + batch_size: The batch size used to train the model. + num_sessions: The number of sessions used to train the model. + + Returns: + Numpy array containing the goodness of fit values, measured in bits + + Raises: + RuntimeError: If the provided model is not fit to data. + ValueError: If both ``model`` and ``(batch_size, num_sessions)`` are provided. 
+ """ + if model is not None: + if batch_size is not None or num_sessions is not None: + raise ValueError( + "batch_size and num_sessions should not be provided if model is provided." + ) + if not hasattr(model, "state_dict_"): + raise RuntimeError("Fit the CEBRA model first.") + if model.batch_size is None: + raise ValueError( + "Computing the goodness of fit is not yet supported for " + "models trained on the full dataset (batchsize = None). ") + batch_size = model.batch_size + num_sessions = model.num_sessions_ + if num_sessions is None: + num_sessions = 1 + + if model.batch_size is None: + raise ValueError( + "Computing the goodness of fit is not yet supported for " + "models trained on the full dataset (batchsize = None). ") + else: + if batch_size is None or num_sessions is None: + raise ValueError( + f"batch_size ({batch_size}) and num_sessions ({num_sessions})" + f"should be provided if model is not provided.") + + nats_to_bits = np.log2(np.e) + chance_level = np.log(batch_size * num_sessions) + return (chance_level - infonce) * nats_to_bits + + def _consistency_scores( embeddings: List[Union[npt.NDArray, torch.Tensor]], datasets: List[Union[int, str]], diff --git a/tests/test_sklearn_metrics.py b/tests/test_sklearn_metrics.py index 58e12010..4e765ba7 100644 --- a/tests/test_sklearn_metrics.py +++ b/tests/test_sklearn_metrics.py @@ -383,3 +383,132 @@ def test_sklearn_runs_consistency(): with pytest.raises(ValueError, match="Invalid.*embeddings"): _, _, _ = cebra_sklearn_metrics.consistency_score( invalid_embeddings_runs, between="runs") + + +@pytest.mark.parametrize("seed", [42, 24, 10]) +def test_goodness_of_fit_score(seed): + """ + Ensure that the GoF score is close to 0 for a model fit on random data. + """ + cebra_model = cebra_sklearn_cebra.CEBRA( + model_architecture="offset1-model", + max_iterations=5, + batch_size=512, + ) + generator = torch.Generator().manual_seed(seed) + X = torch.rand(5000, 50, dtype=torch.float32, generator=generator) + y = torch.rand(5000, 5, dtype=torch.float32, generator=generator) + cebra_model.fit(X, y) + score = cebra_sklearn_metrics.goodness_of_fit_score(cebra_model, + X, + y, + session_id=0, + num_batches=500) + assert isinstance(score, float) + assert np.isclose(score, 0, atol=0.01) + + +@pytest.mark.parametrize("seed", [42, 24, 10]) +def test_goodness_of_fit_history(seed): + """ + Ensure that the GoF score is higher for a model fit on data with underlying + structure than for a model fit on random data. + """ + + # Generate data + generator = torch.Generator().manual_seed(seed) + X = torch.rand(1000, 50, dtype=torch.float32, generator=generator) + y_random = torch.rand(len(X), 5, dtype=torch.float32, generator=generator) + linear_map = torch.randn(50, 5, dtype=torch.float32, generator=generator) + y_linear = X @ linear_map + + def _fit_and_get_history(X, y): + cebra_model = cebra_sklearn_cebra.CEBRA( + model_architecture="offset1-model", + max_iterations=150, + batch_size=512, + device="cpu") + cebra_model.fit(X, y) + history = cebra_sklearn_metrics.goodness_of_fit_history(cebra_model) + # NOTE(stes): Ignore the first 5 iterations, they can have nonsensical values + # due to numerical issues. + return history[5:] + + history_random = _fit_and_get_history(X, y_random) + history_linear = _fit_and_get_history(X, y_linear) + + assert isinstance(history_random, np.ndarray) + assert history_random.shape[0] > 0 + # NOTE(stes): Ignore the first 5 iterations, they can have nonsensical values + # due to numerical issues. 
+ history_random_non_negative = history_random[history_random >= 0] + np.testing.assert_allclose(history_random_non_negative, 0, atol=0.075) + + assert isinstance(history_linear, np.ndarray) + assert history_linear.shape[0] > 0 + + assert np.all(history_linear[-20:] > history_random[-20:]) + + +@pytest.mark.parametrize("seed", [42, 24, 10]) +def test_infonce_to_goodness_of_fit(seed): + """Test the conversion from InfoNCE loss to goodness of fit metric.""" + # Test with model + cebra_model = cebra_sklearn_cebra.CEBRA( + model_architecture="offset10-model", + max_iterations=5, + batch_size=128, + ) + generator = torch.Generator().manual_seed(seed) + X = torch.rand(1000, 50, dtype=torch.float32, generator=generator) + cebra_model.fit(X) + + # Test single value + gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, + model=cebra_model) + assert isinstance(gof, float) + + # Test array of values + infonce_values = np.array([1.0, 2.0, 3.0]) + gof_array = cebra_sklearn_metrics.infonce_to_goodness_of_fit( + infonce_values, model=cebra_model) + assert isinstance(gof_array, np.ndarray) + assert gof_array.shape == infonce_values.shape + + # Test with explicit batch_size and num_sessions + gof = cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, + batch_size=128, + num_sessions=1) + assert isinstance(gof, float) + + # Test error cases + with pytest.raises(ValueError, match="batch_size.*should not be provided"): + cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, + model=cebra_model, + batch_size=128) + + with pytest.raises(ValueError, match="batch_size.*should not be provided"): + cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, + model=cebra_model, + num_sessions=1) + + # Test with unfitted model + unfitted_model = cebra_sklearn_cebra.CEBRA(max_iterations=5) + with pytest.raises(RuntimeError, match="Fit the CEBRA model first"): + cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, + model=unfitted_model) + + # Test with model having batch_size=None + none_batch_model = cebra_sklearn_cebra.CEBRA(batch_size=None, + max_iterations=5) + none_batch_model.fit(X) + with pytest.raises(ValueError, match="Computing the goodness of fit"): + cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, + model=none_batch_model) + + # Test missing batch_size or num_sessions when model is None + with pytest.raises(ValueError, match="batch_size.*and num_sessions"): + cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, batch_size=128) + + with pytest.raises(ValueError, match="batch_size.*and num_sessions"): + cebra_sklearn_metrics.infonce_to_goodness_of_fit(1.0, num_sessions=1) From 3100730c99a1129e27d7794e0199e2b88ed66e3b Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 2 Feb 2025 18:41:55 -0500 Subject: [PATCH 08/86] Support numpy 2, upgrade tests to support torch 2.6 (#221) * Drop numpy constraint * Implement workaround for pytables * better error message * pin numpy only for python 3.9 * update dependencies * Upgrade torch version * Fix based on python version * Add support for torch.load with weights_only=True * Implement safe loading for torch models starting in torch 2.6 * Fix windows specs * fix docstring * Revert changes to loading logic --- .github/workflows/build.yml | 2 +- cebra/data/load.py | 26 +++++++++++++--- cebra/integrations/sklearn/cebra.py | 48 ++++++++++++++++++++++++----- setup.cfg | 6 ++-- tests/test_dlc.py | 7 ++--- tests/test_load.py | 22 ++++++------- 6 files changed, 80 insertions(+), 31 deletions(-) diff --git a/.github/workflows/build.yml 
b/.github/workflows/build.yml index 3c4f68dd..5fed4c79 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -18,7 +18,7 @@ jobs: # We aim to support the versions on pytorch.org # as well as selected previous versions on # https://pytorch.org/get-started/previous-versions/ - torch-version: ["2.2.2", "2.4.0"] + torch-version: ["2.4.0", "2.6.0"] sklearn-version: ["latest"] include: - os: windows-latest diff --git a/cebra/data/load.py b/cebra/data/load.py index 6f1b86e5..02714ad0 100644 --- a/cebra/data/load.py +++ b/cebra/data/load.py @@ -275,11 +275,11 @@ def _is_dlc_df(h5_file: IO[bytes], df_keys: List[str]) -> bool: """ try: if ["_i_table", "table"] in df_keys: - df = pd.read_hdf(h5_file, key="table") + df = read_hdf(h5_file, key="table") else: - df = pd.read_hdf(h5_file, key=df_keys[0]) + df = read_hdf(h5_file, key=df_keys[0]) except KeyError: - df = pd.read_hdf(h5_file) + df = read_hdf(h5_file) return all(value in df.columns.names for value in ["scorer", "bodyparts", "coords"]) @@ -348,7 +348,7 @@ def load_from_h5(file: Union[pathlib.Path, str], key: str, Returns: A :py:func:`numpy.array` containing the data of interest extracted from the :py:class:`pandas.DataFrame`. """ - df = pd.read_hdf(file, key=key) + df = read_hdf(file, key=key) if columns is None: loaded_array = df.values elif isinstance(columns, list) and df.columns.nlevels == 1: @@ -716,3 +716,21 @@ def _get_loader(file_ending: str) -> _BaseLoader: if file_ending not in __loaders.keys() or file_ending == "": raise OSError(f"File ending {file_ending} not supported.") return __loaders[file_ending] + + +def read_hdf(filename, key=None): + """Read HDF5 file using pandas, with fallback to h5py if pandas fails. + + Args: + filename: Path to HDF5 file + key: Optional key to read from HDF5 file. If None, tries "df_with_missing" + then falls back to first available key. + + Returns: + pandas.DataFrame: The loaded data + + Raises: + RuntimeError: If both pandas and h5py fail to load the file + """ + + return pd.read_hdf(filename, key=key) diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index 9a74eeb6..c902307d 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -27,10 +27,11 @@ import numpy as np import numpy.typing as npt +import packaging.version import pkg_resources +import sklearn import sklearn.utils.validation as sklearn_utils_validation import torch -import sklearn from sklearn.base import BaseEstimator from sklearn.base import TransformerMixin from sklearn.utils.metaestimators import available_if @@ -43,11 +44,38 @@ import cebra.models import cebra.solver +# NOTE(stes): From torch 2.6 onwards, we need to specify the following list +# when loading CEBRA models to allow weights_only = True. 
+CEBRA_LOAD_SAFE_GLOBALS = [ + cebra.data.Offset, torch.torch_version.TorchVersion, np.dtype, + np.dtypes.Float64DType, np.dtypes.Int64DType +] + + def check_version(estimator): # NOTE(stes): required as a check for the old way of specifying tags # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165 - from packaging import version - return version.parse(sklearn.__version__) < version.parse("1.6.dev") + return packaging.version.parse( + sklearn.__version__) < packaging.version.parse("1.6.dev") + + +def _safe_torch_load(filename, weights_only, **kwargs): + if weights_only is None: + if packaging.version.parse( + torch.__version__) >= packaging.version.parse("2.6.0"): + weights_only = True + else: + weights_only = False + + if not weights_only: + checkpoint = torch.load(filename, weights_only=False, **kwargs) + else: + # NOTE(stes): This is only supported for torch 2.6+ + with torch.serialization.safe_globals(CEBRA_LOAD_SAFE_GLOBALS): + checkpoint = torch.load(filename, weights_only=True, **kwargs) + + return checkpoint + def _init_loader( is_cont: bool, @@ -1409,15 +1437,22 @@ def save(self, def load(cls, filename: str, backend: Literal["auto", "sklearn", "torch"] = "auto", + weights_only: bool = None, **kwargs) -> "CEBRA": """Load a model from disk. Args: filename: The path to the file in which to save the trained model. backend: A string identifying the used backend. + weights_only: Indicates whether unpickler should be restricted to loading only tensors, primitive types, + dictionaries and any types added via :py:func:`torch.serialization.add_safe_globals`. + See :py:func:`torch.load` with ``weights_only=True`` for more details. It it recommended to leave this + at the default value of ``None``, which sets the argument to ``False`` for torch<2.6, and ``True`` for + higher versions of torch. If you experience issues with loading custom models (specified outside + of the CEBRA package), you can try to set this to ``False`` if you trust the source of the model. kwargs: Optional keyword arguments passed directly to the loader. - Return: + Returns: The model to load. Note: @@ -1427,7 +1462,6 @@ def load(cls, For information about the file format please refer to :py:meth:`cebra.CEBRA.save`. Example: - >>> import cebra >>> import numpy as np >>> import tempfile @@ -1441,16 +1475,14 @@ def load(cls, >>> loaded_model = cebra.CEBRA.load(tmp_file) >>> embedding = loaded_model.transform(dataset) >>> tmp_file.unlink() - """ - supported_backends = ["auto", "sklearn", "torch"] if backend not in supported_backends: raise NotImplementedError( f"Unsupported backend: '{backend}'. 
Supported backends are: {', '.join(supported_backends)}" ) - checkpoint = torch.load(filename, **kwargs) + checkpoint = _safe_torch_load(filename, weights_only, **kwargs) if backend == "auto": backend = "sklearn" if isinstance(checkpoint, dict) else "torch" diff --git a/setup.cfg b/setup.cfg index 68263d73..2addd5d7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -31,11 +31,13 @@ where = python_requires = >=3.9 install_requires = joblib - numpy<2.0.0 + numpy<2.0;platform_system=="Windows" + numpy<2.0;platform_system!="Windows" and python_version<"3.10" + numpy;platform_system!="Windows" and python_version>="3.10" literate-dataclasses scikit-learn scipy - torch + torch>=2.4.0 tqdm matplotlib requests diff --git a/tests/test_dlc.py b/tests/test_dlc.py index a19fe593..8ab29abd 100644 --- a/tests/test_dlc.py +++ b/tests/test_dlc.py @@ -29,6 +29,7 @@ import cebra.integrations.deeplabcut as cebra_dlc from cebra import CEBRA from cebra import load_data +from cebra.data.load import read_hdf # NOTE(stes): The original data URL is # https://github.com/DeepLabCut/DeepLabCut/blob/main/examples @@ -54,11 +55,7 @@ def test_imports(): def _load_dlc_dataframe(filename): - try: - df = pd.read_hdf(filename, "df_with_missing") - except KeyError: - df = pd.read_hdf(filename) - return df + return read_hdf(filename) def _get_annotated_data(url, keypoints): diff --git a/tests/test_load.py b/tests/test_load.py index 2a9ef3b5..4524b29c 100644 --- a/tests/test_load.py +++ b/tests/test_load.py @@ -248,7 +248,7 @@ def generate_h5_no_array(filename, dtype): def generate_h5_dataframe(filename, dtype): A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"]) - df_A.to_hdf(filename, "df_A") + df_A.to_hdf(filename, key="df_A") loaded_A = cebra_load.load(filename, key="df_A") return A, loaded_A @@ -258,7 +258,7 @@ def generate_h5_dataframe_columns(filename, dtype): A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) A_col = A[:, :2] df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"]) - df_A.to_hdf(filename, "df_A") + df_A.to_hdf(filename, key="df_A") loaded_A = cebra_load.load(filename, key="df_A", columns=["a", "b"]) return A_col, loaded_A @@ -269,8 +269,8 @@ def generate_h5_multi_dataframe(filename, dtype): B = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"]) df_B = pd.DataFrame(np.array(B), columns=["c", "d", "e"]) - df_A.to_hdf(filename, "df_A") - df_B.to_hdf(filename, "df_B") + df_A.to_hdf(filename, key="df_A") + df_B.to_hdf(filename, key="df_B") loaded_A = cebra_load.load(filename, key="df_A") return A, loaded_A @@ -279,7 +279,7 @@ def generate_h5_multi_dataframe(filename, dtype): def generate_h5_single_dataframe_no_key(filename, dtype): A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype) df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"]) - df_A.to_hdf(filename, "df_A") + df_A.to_hdf(filename, key="df_A") loaded_A = cebra_load.load(filename) return A, loaded_A @@ -290,8 +290,8 @@ def generate_h5_multi_dataframe_no_key(filename, dtype): B = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype) df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"]) df_B = pd.DataFrame(np.array(B), columns=["c", "d", "e"]) - df_A.to_hdf(filename, "df_A") - df_B.to_hdf(filename, "df_B") + df_A.to_hdf(filename, key="df_A") + df_B.to_hdf(filename, key="df_B") _ = cebra_load.load(filename) @@ -304,7 +304,7 @@ def generate_h5_multicol_dataframe(filename, dtype): df_A = pd.DataFrame(A, 
columns=pd.MultiIndex.from_product([animals, keypoints])) - df_A.to_hdf(filename, "df_A") + df_A.to_hdf(filename, key="df_A") loaded_A = cebra_load.load(filename, key="df_A") return A, loaded_A @@ -313,7 +313,7 @@ def generate_h5_multicol_dataframe(filename, dtype): def generate_h5_dataframe_invalid_key(filename, dtype): A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype) df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"]) - df_A.to_hdf(filename, "df_A") + df_A.to_hdf(filename, key="df_A") _ = cebra_load.load(filename, key="df_B") @@ -321,7 +321,7 @@ def generate_h5_dataframe_invalid_key(filename, dtype): def generate_h5_dataframe_invalid_column(filename, dtype): A = np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]]).astype(dtype) df_A = pd.DataFrame(np.array(A), columns=["a", "b", "c"]) - df_A.to_hdf(filename, "df_A") + df_A.to_hdf(filename, key="df_A") _ = cebra_load.load(filename, key="df_A", columns=["d", "b"]) @@ -334,7 +334,7 @@ def generate_h5_multicol_dataframe_columns(filename, dtype): df_A = pd.DataFrame(A, columns=pd.MultiIndex.from_product([animals, keypoints])) - df_A.to_hdf(filename, "df_A") + df_A.to_hdf(filename, key="df_A") _ = cebra_load.load(filename, key="df_A", columns=["a", "b"]) From bea2c041a5210257c10df03b08597b487d995a39 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 2 Feb 2025 18:55:59 -0500 Subject: [PATCH 09/86] Release 0.5.0rc1 (#189) * Make bump_version script runnable on MacOS * Bump version to 0.5.0rc1 * fix minor formatting issues * remove commented code --------- Co-authored-by: Mackenzie Mathis --- Dockerfile | 2 +- Makefile | 2 +- PKGBUILD | 2 +- cebra/__init__.py | 2 +- cebra/integrations/sklearn/cebra.py | 2 +- docs/source/conf.py | 19 ++++++--------- reinstall.sh | 2 +- tools/bump_version.sh | 36 +++++++++++++++++++---------- 8 files changed, 37 insertions(+), 30 deletions(-) diff --git a/Dockerfile b/Dockerfile index d734ee6f..e8ac14a0 100644 --- a/Dockerfile +++ b/Dockerfile @@ -40,7 +40,7 @@ RUN make dist FROM cebra-base # install the cebra wheel -ENV WHEEL=cebra-0.4.0-py2.py3-none-any.whl +ENV WHEEL=cebra-0.5.0rc1-py2.py3-none-any.whl WORKDIR /build COPY --from=wheel /build/dist/${WHEEL} . 
RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets]' diff --git a/Makefile b/Makefile index ca8c5480..a1e8d3b2 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -CEBRA_VERSION := 0.4.0 +CEBRA_VERSION := 0.5.0rc1 dist: python3 -m pip install virtualenv diff --git a/PKGBUILD b/PKGBUILD index 07fa3a1d..91ba4a4e 100644 --- a/PKGBUILD +++ b/PKGBUILD @@ -1,7 +1,7 @@ # Maintainer: Steffen Schneider pkgname=python-cebra _pkgname=cebra -pkgver=0.4.0 +pkgver=0.5.0rc1 pkgrel=1 pkgdesc="Consistent Embeddings of high-dimensional Recordings using Auxiliary variables" url="https://cebra.ai" diff --git a/cebra/__init__.py b/cebra/__init__.py index 204cd2a2..edf1b5ee 100644 --- a/cebra/__init__.py +++ b/cebra/__init__.py @@ -66,7 +66,7 @@ import cebra.integrations.sklearn as sklearn -__version__ = "0.4.0" +__version__ = "0.5.0rc1" __all__ = ["CEBRA"] __allow_lazy_imports = False __lazy_imports = {} diff --git a/cebra/integrations/sklearn/cebra.py b/cebra/integrations/sklearn/cebra.py index c902307d..5fb267ac 100644 --- a/cebra/integrations/sklearn/cebra.py +++ b/cebra/integrations/sklearn/cebra.py @@ -51,7 +51,6 @@ np.dtypes.Float64DType, np.dtypes.Int64DType ] - def check_version(estimator): # NOTE(stes): required as a check for the old way of specifying tags # https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165 @@ -77,6 +76,7 @@ def _safe_torch_load(filename, weights_only, **kwargs): return checkpoint + def _init_loader( is_cont: bool, is_disc: bool, diff --git a/docs/source/conf.py b/docs/source/conf.py index 025a988b..c5e12b5a 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -28,18 +28,13 @@ # -- Path setup -------------------------------------------------------------- -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# +import datetime import os import sys sys.path.insert(0, os.path.abspath(".")) -import datetime - -import cebra +import cebra # noqa: E402 def get_years(start_year=2021): @@ -156,11 +151,6 @@ def get_years(start_year=2021): "url": "https://twitter.com/cebraAI", "icon": "fab fa-twitter", }, - # { - # "name": "DockerHub", - # "url": "https://hub.docker.com/r/stffsc/cebra", - # "icon": "fab fa-docker", - # }, { "name": "PyPI", "url": "https://pypi.org/project/cebra/", @@ -247,6 +237,9 @@ def get_years(start_year=2021): # Download link for the notebook, see # https://nbsphinx.readthedocs.io/en/0.3.0/prolog-and-epilog.html + +# fmt: off +# flake8: noqa: E501 nbsphinx_prolog = r""" .. only:: html @@ -269,3 +262,5 @@ def get_years(start_year=2021): ---- """ +# fmt: on +# flake8: enable=E501 diff --git a/reinstall.sh b/reinstall.sh index 778f98eb..549982a1 100755 --- a/reinstall.sh +++ b/reinstall.sh @@ -15,7 +15,7 @@ pip uninstall -y cebra # Get version info after uninstalling --- this will automatically get the # most recent version based on the source code in the current directory. # $(tools/get_cebra_version.sh) -VERSION=0.4.0 +VERSION=0.5.0rc1 echo "Upgrading to CEBRA v${VERSION}" # Upgrade the build system (PEP517/518 compatible) diff --git a/tools/bump_version.sh b/tools/bump_version.sh index fbc161b1..fb89f413 100755 --- a/tools/bump_version.sh +++ b/tools/bump_version.sh @@ -1,7 +1,7 @@ #!/bin/bash # Bump the CEBRA version to the specified value. # Edits all relevant files at once. 
-# +# # Usage: # tools/bump_version.sh 0.3.1rc1 @@ -10,24 +10,36 @@ if [ -z ${version} ]; then >&1 echo "Specify a version number." >&1 echo "Usage:" >&1 echo "tools/bump_version.sh " + exit 1 +fi + +# Determine the correct sed command based on the OS +# On macOS, the `sed` command requires an empty string argument after `-i` for in-place editing. +# On Linux and other Unix-like systems, the `sed` command only requires `-i` for in-place editing. +if [[ "$OSTYPE" == "darwin"* ]]; then + # macOS + SED_CMD="sed -i .bkp -e" +else + # Linux and other Unix-like systems + SED_CMD="sed -i -e" fi # python cebra version -sed -i "s/__version__ = .*/__version__ = \"${version}\"/" \ - cebra/__init__.py +$SED_CMD "s/__version__ = .*/__version__ = \"${version}\"/" cebra/__init__.py # reinstall script in root -sed -i "s/VERSION=.*/VERSION=${version}/" \ - reinstall.sh +$SED_CMD "s/VERSION=.*/VERSION=${version}/" reinstall.sh # Makefile -sed -i "s/CEBRA_VERSION := .*/CEBRA_VERSION := ${version}/" \ - Makefile +$SED_CMD "s/CEBRA_VERSION := .*/CEBRA_VERSION := ${version}/" Makefile -# Arch linux PKGBUILD -sed -i "s/pkgver=.*/pkgver=${version}/" \ - PKGBUILD +# Arch linux PKGBUILD +$SED_CMD "s/pkgver=.*/pkgver=${version}/" PKGBUILD # Dockerfile -sed -i "s/ENV WHEEL=cebra-.*\.whl/ENV WHEEL=cebra-${version}-py2.py3-none-any.whl/" \ - Dockerfile +$SED_CMD "s/ENV WHEEL=cebra-.*\.whl/ENV WHEEL=cebra-${version}-py2.py3-none-any.whl/" Dockerfile + +# Remove backup files +if [[ "$OSTYPE" == "darwin"* ]]; then + rm cebra/__init__.py.bkp reinstall.sh.bkp Makefile.bkp PKGBUILD.bkp Dockerfile.bkp +fi From c32ed67ab8bc1ff1afa4e47e1d09bb26ce0a3270 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Sun, 2 Feb 2025 19:46:43 -0500 Subject: [PATCH 10/86] Fix pypi action (#222) * force packaging upgrade to 24.2 for twine * Bump version to 0.5.0rc2 * remove universal compatibility option * revert tag * adapt files to new wheel name due to py3 --- .github/workflows/release-pypi.yml | 7 +++++++ Dockerfile | 2 +- PKGBUILD | 2 +- docs/source/contributing.rst | 4 ++-- pyproject.toml | 3 ++- reinstall.sh | 2 +- setup.cfg | 3 --- tools/build_docs.sh | 4 ++-- tools/bump_version.sh | 8 ++++++-- 9 files changed, 22 insertions(+), 13 deletions(-) diff --git a/.github/workflows/release-pypi.yml b/.github/workflows/release-pypi.yml index fc6d5c8e..ac078fd9 100644 --- a/.github/workflows/release-pypi.yml +++ b/.github/workflows/release-pypi.yml @@ -28,6 +28,13 @@ jobs: path: ~/.cache/pip key: ${{ runner.os }}-pip + - name: Install dependencies + run: | + pip install --upgrade pip + pip install wheel + # NOTE(stes) see https://github.com/pypa/twine/issues/1216#issuecomment-2629069669 + pip install "packaging>=24.2" + - name: Checkout code uses: actions/checkout@v3 diff --git a/Dockerfile b/Dockerfile index e8ac14a0..7cd326d5 100644 --- a/Dockerfile +++ b/Dockerfile @@ -40,7 +40,7 @@ RUN make dist FROM cebra-base # install the cebra wheel -ENV WHEEL=cebra-0.5.0rc1-py2.py3-none-any.whl +ENV WHEEL=cebra-0.5.0rc1-py3-none-any.whl WORKDIR /build COPY --from=wheel /build/dist/${WHEEL} . 
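Since bump_version.sh edits several files with independent sed patterns, a missed
target is easy to overlook. A minimal consistency check, run from the repository
root (a sketch, not part of the patch; the needles mirror the sed targets above,
and the Dockerfile is omitted because its wheel-name pattern changes in the next
patch):

    import pathlib

    version = "0.5.0rc1"  # the argument passed to tools/bump_version.sh
    needles = {
        "cebra/__init__.py": f'__version__ = "{version}"',
        "reinstall.sh": f"VERSION={version}",
        "Makefile": f"CEBRA_VERSION := {version}",
        "PKGBUILD": f"pkgver={version}",
    }
    for path, needle in needles.items():
        text = pathlib.Path(path).read_text()
        assert needle in text, f"{path} was not bumped to {version}"
    print("all version strings agree")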
RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets]' diff --git a/PKGBUILD b/PKGBUILD index 91ba4a4e..1f8b3db5 100644 --- a/PKGBUILD +++ b/PKGBUILD @@ -40,7 +40,7 @@ build() { package() { cd $srcdir/${_pkgname}-${pkgver} - pip install --ignore-installed --no-deps --root="${pkgdir}" dist/${_pkgname}-${pkgver}-py2.py3-none-any.whl + pip install --ignore-installed --no-deps --root="${pkgdir}" dist/${_pkgname}-${pkgver}-py3-none-any.whl find ${pkgdir} -iname __pycache__ -exec rm -r {} \; 2>/dev/null || echo install -Dm 644 LICENSE.md $pkgdir/usr/share/licenses/${pkgname}/LICENSE } diff --git a/docs/source/contributing.rst b/docs/source/contributing.rst index cc7ae0a8..7fcd16a1 100644 --- a/docs/source/contributing.rst +++ b/docs/source/contributing.rst @@ -155,13 +155,13 @@ Enter the build environment and build the package: host $ make interact docker $ make build # ... outputs ... - Successfully built cebra-X.X.XaX-py2.py3-none-any.whl + Successfully built cebra-X.X.XaX-py3-none-any.whl The built package can be found in ``dist/`` and can be installed locally with .. code:: bash - pip install dist/cebra-X.X.XaX-py2.py3-none-any.whl + pip install dist/cebra-X.X.XaX-py3-none-any.whl **Please do not distribute this package prior to the public release of the CEBRA repository, because it also contains parts of the source code.** diff --git a/pyproject.toml b/pyproject.toml index 4a927c6c..b64475e7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,8 @@ [build-system] requires = [ "setuptools>=43", - "wheel" + "wheel", + "packaging>=24.2" ] build-backend = "setuptools.build_meta" diff --git a/reinstall.sh b/reinstall.sh index 549982a1..ece080b8 100755 --- a/reinstall.sh +++ b/reinstall.sh @@ -24,4 +24,4 @@ python3 -m pip install --upgrade build python3 -m build --sdist --wheel . # Reinstall the package with most recent version -pip install --upgrade --no-cache-dir "dist/cebra-${VERSION}-py2.py3-none-any.whl[datasets,integrations]" +pip install --upgrade --no-cache-dir "dist/cebra-${VERSION}-py3-none-any.whl[datasets,integrations]" diff --git a/setup.cfg b/setup.cfg index 2addd5d7..9da156ec 100644 --- a/setup.cfg +++ b/setup.cfg @@ -112,6 +112,3 @@ dev = # docformatter[tomli] codespell cffconvert - -[bdist_wheel] -universal=1 diff --git a/tools/build_docs.sh b/tools/build_docs.sh index 3f5f36cd..38a7982e 100755 --- a/tools/build_docs.sh +++ b/tools/build_docs.sh @@ -62,8 +62,8 @@ FROM python:3.9 RUN python -m pip install --upgrade pip setuptools wheel \ && apt-get update -y && apt-get install -y pandoc git RUN pip install torch --extra-index-url=https://download.pytorch.org/whl/cpu -COPY dist/cebra-0.4.0-py2.py3-none-any.whl . -RUN pip install 'cebra-0.4.0-py2.py3-none-any.whl[docs]' +COPY dist/cebra-0.5.0rc1-py3-none-any.whl . 
+RUN pip install 'cebra-0.5.0rc1-py3-none-any.whl[docs]' EOF checkout_cebra_figures diff --git a/tools/bump_version.sh b/tools/bump_version.sh index fb89f413..17142f7e 100755 --- a/tools/bump_version.sh +++ b/tools/bump_version.sh @@ -37,9 +37,13 @@ $SED_CMD "s/CEBRA_VERSION := .*/CEBRA_VERSION := ${version}/" Makefile $SED_CMD "s/pkgver=.*/pkgver=${version}/" PKGBUILD # Dockerfile -$SED_CMD "s/ENV WHEEL=cebra-.*\.whl/ENV WHEEL=cebra-${version}-py2.py3-none-any.whl/" Dockerfile +$SED_CMD "s/ENV WHEEL=cebra-.*\.whl/ENV WHEEL=cebra-${version}-py3-none-any.whl/" Dockerfile + +# build_docs.sh +$SED_CMD "s/COPY dist\/cebra-.*-py3-none-any\.whl/COPY dist\/cebra-${version}-py3-none-any.whl/" tools/build_docs.sh +$SED_CMD "s/RUN pip install 'cebra-.*-py3-none-any\.whl/RUN pip install 'cebra-${version}-py3-none-any.whl/" tools/build_docs.sh # Remove backup files if [[ "$OSTYPE" == "darwin"* ]]; then - rm cebra/__init__.py.bkp reinstall.sh.bkp Makefile.bkp PKGBUILD.bkp Dockerfile.bkp + rm cebra/__init__.py.bkp reinstall.sh.bkp Makefile.bkp PKGBUILD.bkp Dockerfile.bkp tools/build_docs.sh.bkp fi From f99530cfc6d45ab7d7fae86f7fd8db3585d7e920 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=C3=8Dcaro?= Date: Tue, 18 Feb 2025 10:49:06 +0100 Subject: [PATCH 11/86] Update base.py (#224) This is a lazy solution to #223 --- cebra/solver/base.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cebra/solver/base.py b/cebra/solver/base.py index e95151e5..14a22c68 100644 --- a/cebra/solver/base.py +++ b/cebra/solver/base.py @@ -210,7 +210,8 @@ def fit( self.decoding(loader, valid_loader)) if save_hook is not None: save_hook(num_steps, self) - self.save(logdir, f"checkpoint_{num_steps:#07d}.pth") + if logdir is not None: + self.save(logdir, f"checkpoint_{num_steps:#07d}.pth") def step(self, batch: cebra.data.Batch) -> dict: """Perform a single gradient update. 
From c822ffa7655942929c8cf19917977459b94c2e96 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9lia=20Benquet?= <32598028+CeliaBenquet@users.noreply.github.com> Date: Sat, 1 Mar 2025 15:41:58 +0100 Subject: [PATCH 12/86] Change max consistency value to 100 instead of 99 (#227) * Change text consistency max from 99 to 100 * Update cebra/integrations/matplotlib.py --------- Co-authored-by: Mackenzie Mathis Co-authored-by: Steffen Schneider --- cebra/integrations/matplotlib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cebra/integrations/matplotlib.py b/cebra/integrations/matplotlib.py index 30af7fd4..c2696d4a 100644 --- a/cebra/integrations/matplotlib.py +++ b/cebra/integrations/matplotlib.py @@ -684,7 +684,7 @@ def _to_heatmap_format( else: heatmap_values[i, j] = score_dict[label_i, label_j] - return np.minimum(heatmap_values * 100, 99) + return heatmap_values * 100 def _create_text(self): """Create the text to add in the confusion matrix grid and the title.""" From b7133877ad0d8d688bd55fc352f5786bc12f360f Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Sat, 1 Mar 2025 18:23:50 +0100 Subject: [PATCH 13/86] Update assets.py --> force check for parent dir (#230) Update assets.py - mkdir was failing in 0.5.0rc1; attempt to fix --- cebra/data/assets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cebra/data/assets.py b/cebra/data/assets.py index 86695482..adea8413 100644 --- a/cebra/data/assets.py +++ b/cebra/data/assets.py @@ -93,7 +93,7 @@ def download_file_with_progress_bar(url: str, ) # Create the directory and any necessary parent directories - location_path.mkdir(exist_ok=True) + location_path.mkdir(parents=True, exist_ok=True) filename = filename_match.group(1) file_path = location_path / filename From 47945cad9dd6a51104447a6aa29de871cd3745c3 Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Sat, 1 Mar 2025 22:59:39 +0100 Subject: [PATCH 14/86] User docs minor edit (#229) * user note added to usage.rst - link added * Update usage.rst - more detailed note on the effect of temp. * Update usage.rst - add in temp to demo model - testout put thanks @stes * Update docs/source/usage.rst Co-authored-by: Steffen Schneider * Update docs/source/usage.rst Co-authored-by: Steffen Schneider * Update docs/source/usage.rst Co-authored-by: Steffen Schneider --------- Co-authored-by: Steffen Schneider --- docs/source/usage.rst | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 334f1bbc..53821e36 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -1,7 +1,7 @@ Using CEBRA =========== -This page covers a standard CEBRA usage. We recommend checking out the :py:doc:`demos` for in-depth CEBRA usage examples as well. Here we present a quick overview on how to use CEBRA on various datasets. Note that we provide two ways to interact with the code: +This page covers a standard CEBRA usage. We recommend checking out the :py:doc:`demos` for CEBRA usage examples as well. Here we present a quick overview on how to use CEBRA on various datasets. Note that we provide two ways to interact with the code: * For regular usage, we recommend leveraging the **high-level interface**, adhering to ``scikit-learn`` formatting. * Upon specific needs, advanced users might consider diving into the **low-level interface** that adheres to ``PyTorch`` formatting. @@ -12,7 +12,7 @@ Firstly, why use CEBRA? 
CEBRA is primarily designed for producing robust, consistent extractions of latent factors from time-series data. It supports three modes, and is a self-supervised representation learning algorithm that uses our modified contrastive learning approach designed for multi-modal time-series data. In short, it is a type of non-linear dimensionality reduction, like `tSNE `_ and `UMAP `_. We show in our original paper that it outperforms tSNE and UMAP at producing closer-to-ground-truth latents and is more consistent.

-That being said, CEBRA can be used on non-time-series data and it does not strictly require multi-modal data. In general, we recommend considering using CEBRA for measuring changes in consistency across conditions (brain areas, cells, animals), for hypothesis-guided decoding, and for topological exploration of the resulting embedding spaces. It can also be used for visualization and considering dynamics within the embedding space. For examples of how CEBRA can be used to map space, decode natural movies, and make hypotheses for neural coding of sensorimotor systems, see our paper (Schneider, Lee, Mathis, 2023).
+That being said, CEBRA can be used on non-time-series data and it does not strictly require multi-modal data. In general, we recommend considering using CEBRA for measuring changes in consistency across conditions (brain areas, cells, animals), for hypothesis-guided decoding, and for topological exploration of the resulting embedding spaces. It can also be used for visualization and considering dynamics within the embedding space. For examples of how CEBRA can be used to map space, decode natural movies, and make hypotheses for neural coding of sensorimotor systems, see `Schneider, Lee, Mathis. Nature 2023 `_.

The CEBRA workflow
------------------
@@ -22,7 +22,7 @@ We recommend to start with running CEBRA-Time (unsupervised) and look both at th

(1) Use CEBRA-Time for unsupervised data exploration.
(2) Consider running a hyperparameter sweep on the inputs to the model, such as :py:attr:`cebra.CEBRA.model_architecture`, :py:attr:`cebra.CEBRA.time_offsets`, :py:attr:`cebra.CEBRA.output_dimension`, and set :py:attr:`cebra.CEBRA.batch_size` to be as high as your GPU allows. You want to see clear structure in the 3D plot (the first 3 latents are shown by default).
-(3) Use CEBRA-Behavior with many different labels and combinations, then look at the InfoNCE loss - the lower the loss value, the better the fit (see :py:doc:`cebra-figures/figures/ExtendedDataFigure5`), and visualize the embeddings. The goal is to understand which labels are contributing to the structure you see in CEBRA-Time, and improve this structure. Again, you should consider a hyperparameter sweep.
+(3) Use CEBRA-Behavior with many different labels and combinations, then look at the InfoNCE loss - the lower the loss value, the better the fit (see :py:doc:`cebra-figures/figures/ExtendedDataFigure5`), and visualize the embeddings. The goal is to understand which labels are contributing to the structure you see in CEBRA-Time, and improve this structure. Again, you should consider a hyperparameter sweep (and avoid overfitting by performing a proper train/validation split; see Step 3 in our quick start guide below).
(4) Interpretability: now you can use these latents in downstream tasks, such as measuring consistency, decoding, and determining the dimensionality of your data with topological data analysis. All the steps to do this are described below.

Enjoy using CEBRA!
🔥🦓 @@ -179,7 +179,7 @@ We provide a set of pre-defined models. You can access (and search) a list of av Then, you can choose the one that fits best with your needs and provide it to the CEBRA model as the :py:attr:`~.CEBRA.model_architecture` parameter. -As an indication the table below presents the model architecture we used to train CEBRA on the datasets presented in our paper (Schneider, Lee, Mathis, 2022). +As an indication the table below presents the model architecture we used to train CEBRA on the datasets presented in our paper (Schneider, Lee, Mathis. Nature 2023). .. list-table:: :widths: 25 25 20 30 @@ -265,9 +265,8 @@ For standard usage we recommend the default values (i.e., ``InfoNCE`` and ``cosi .. rubric:: Temperature :py:attr:`~.CEBRA.temperature` -:py:attr:`~.CEBRA.temperature` has the largest effect on visualization of the embedding (see :py:doc:`cebra-figures/figures/ExtendedDataFigure2`). Hence, it is important that it is fitted to your specific data. +:py:attr:`~.CEBRA.temperature` has the largest effect on *visualization* of the embedding (see :py:doc:`cebra-figures/figures/ExtendedDataFigure2`). Hence, it is important that it is fitted to your specific data. Lower temperatures (e.g. around 0.1) will result in a more dispersed embedding, higher temperatures (larger than 1) will concentrate the embedding. -The simplest way to handle it is to use a *learnable temperature*. For that, set :py:attr:`~.CEBRA.temperature_mode` to ``auto``. :py:attr:`~.CEBRA.temperature` will be trained alongside the model. 🚀 For advance usage, you might need to find the optimal :py:attr:`~.CEBRA.temperature`. For that we recommend to perform a grid-search. @@ -307,7 +306,6 @@ Here is an example of a CEBRA model initialization: cebra_model = CEBRA( model_architecture = "offset10-model", batch_size = 1024, - temperature_mode="auto", learning_rate = 0.001, max_iterations = 10, time_offsets = 10, @@ -321,8 +319,7 @@ Here is an example of a CEBRA model initialization: .. testoutput:: CEBRA(batch_size=1024, learning_rate=0.001, max_iterations=10, - model_architecture='offset10-model', temperature_mode='auto', - time_offsets=10) + model_architecture='offset10-model', time_offsets=10) .. admonition:: See API docs :class: dropdown @@ -568,7 +565,8 @@ We provide a simple hyperparameters sweep to compare CEBRA models with different learning_rate = [0.001], time_offsets = 5, max_iterations = 5, - temperature_mode = "auto", + temperature_mode='constant', + temperature = 0.1, verbose = False) # 2. Define the datasets to iterate over @@ -820,7 +818,7 @@ It takes a CEBRA model and returns a 2D plot of the loss against the number of i Displaying the temperature """""""""""""""""""""""""" -:py:attr:`~.CEBRA.temperature` has the largest effect on the visualization of the embedding. Hence it might be interesting to check its evolution when ``temperature_mode=auto``. +:py:attr:`~.CEBRA.temperature` has the largest effect on the visualization of the embedding. Hence it might be interesting to check its evolution when ``temperature_mode=auto``. We recommend only using `auto` if you have first explored the `constant` setting. If you use the ``auto`` mode, please always check the time evolution of the temperature over time alongside the loss curve. To that extend, you can use the function :py:func:`~.plot_temperature`. @@ -1186,9 +1184,10 @@ Improve model performance 🧐 Below is a (non-exhaustive) list of actions you can try if your embedding looks different from what you were expecting. #. 
Assess that your model `converged `_. For that, observe if the training loss stabilizes itself around the end of the training or still seems to be decreasing. Refer to `Visualize the training loss`_ for more details on how to display the training loss. -#. Increase the number of iterations. It should be at least 10,000. +#. Increase the number of iterations. It typically should be at least 10,000. On small datasets, it can make sense to stop training earlier to avoid overfitting effects. #. Make sure the batch size is big enough. It should be at least 512. #. Fine-tune the model's hyperparameters, namely ``learning_rate``, ``output_dimension``, ``num_hidden_units`` and eventually ``temperature`` (by setting ``temperature_mode`` back to ``constant``). Refer to `Grid search`_ for more details on performing hyperparameters tuning. +#. To note, you should still be mindful of performing train/validation splits and shuffle controls to avoid `overfitting `_. @@ -1202,14 +1201,19 @@ Putting all previous snippet examples together, we obtain the following pipeline import cebra from numpy.random import uniform, randint from sklearn.model_selection import train_test_split + import os + import tempfile + from pathlib import Path # 1. Define a CEBRA model cebra_model = cebra.CEBRA( model_architecture = "offset10-model", batch_size = 512, learning_rate = 1e-4, - max_iterations = 10, # TODO(user): to change to at least 10'000 - max_adapt_iterations = 10, # TODO(user): to change to ~100-500 + temperature_mode='constant', + temperature = 0.1, + max_iterations = 10, # TODO(user): to change to ~500-10000 depending on dataset size + #max_adapt_iterations = 10, # TODO(user): use and to change to ~100-500 if adapting time_offsets = 10, output_dimension = 8, verbose = False @@ -1243,7 +1247,7 @@ Putting all previous snippet examples together, we obtain the following pipeline # time contrastive learning cebra_model.fit(train_data) # discrete behavior contrastive learning - cebra_model.fit(train_data, train_discrete_label,) + cebra_model.fit(train_data, train_discrete_label) # continuous behavior contrastive learning cebra_model.fit(train_data, train_continuous_label) # mixed behavior contrastive learning @@ -1257,10 +1261,10 @@ Putting all previous snippet examples together, we obtain the following pipeline cebra_model = cebra.CEBRA.load(tmp_file) train_embedding = cebra_model.transform(train_data) valid_embedding = cebra_model.transform(valid_data) - assert train_embedding.shape == (70, 8) - assert valid_embedding.shape == (30, 8) + assert train_embedding.shape == (70, 8) # TODO(user): change to split ratio & output dim + assert valid_embedding.shape == (30, 8) # TODO(user): change to split ratio & output dim - # 7. Evaluate the model performances + # 7. 
Evaluate the model performance (you can also check the train_data) goodness_of_fit = cebra.sklearn.metrics.infonce_loss(cebra_model, valid_data, valid_discrete_label, From 823c9ca8378beca3e029988fd85a2932c9e3c405 Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Mon, 3 Mar 2025 14:25:26 +0100 Subject: [PATCH 15/86] General Doc refresher (#232) * Update installation.rst - python 3.9+ * Update index.rst * Update figures.rst * Update index.rst -typo fix * Update usage.rst - update suggestion on data split * Update docs/source/usage.rst Co-authored-by: Steffen Schneider * Update usage.rst - indent error fixed * Update usage.rst - changed infoNCE to new GoF * Update usage.rst - finx numpy() doctest * Update usage.rst - small typo fix (label) * Update usage.rst --------- Co-authored-by: Steffen Schneider --- docs/source/figures.rst | 4 +- docs/source/index.rst | 39 +++++++++-------- docs/source/installation.rst | 6 +-- docs/source/usage.rst | 82 ++++++++++++++++++++---------------- 4 files changed, 72 insertions(+), 59 deletions(-) diff --git a/docs/source/figures.rst b/docs/source/figures.rst index 24b1987e..a4101f4a 100644 --- a/docs/source/figures.rst +++ b/docs/source/figures.rst @@ -1,7 +1,7 @@ Figures ======= -CEBRA was introduced in `Schneider, Lee and Mathis (2022)`_ and applied to various datasets across +CEBRA was introduced in `Schneider, Lee and Mathis (2023)`_ and applied to various datasets across animals and recording modalities. In this section, we provide reference code for reproducing the figures and experiments. Since especially @@ -56,4 +56,4 @@ differ in minor typographic details. -.. _Schneider, Lee and Mathis (2022): https://arxiv.org/abs/2204.00673 +.. _Schneider, Lee and Mathis (2023): https://www.nature.com/articles/s41586-023-06031-6 diff --git a/docs/source/index.rst b/docs/source/index.rst index c8231746..1a6ce4d2 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -34,27 +34,18 @@ Please support the development of CEBRA by starring and/or watching the project Installation and Setup ---------------------- -Please see the dedicated :doc:`Installation Guide ` for information on installation options using ``conda``, ``pip`` and ``docker``. - -Have fun! 😁 +Please see the dedicated :doc:`Installation Guide ` for information on installation options using ``conda``, ``pip`` and ``docker``. Have fun! 😁 Usage ----- Please head over to the :doc:`Usage ` tab to find step-by-step instructions to use CEBRA on your data. For example use cases, see the :doc:`Demos ` tab. -Integrations ------------- - -CEBRA can be directly integrated with existing libraries commonly used in data analysis. The ``cebra.integrations`` module -is getting actively extended. Right now, we offer integrations for ``scikit-learn``-like usage of CEBRA, a package making use of ``matplotlib`` to plot the CEBRA model results, as well as the -possibility to compute CEBRA embeddings on DeepLabCut_ outputs directly. - Licensing --------- - -Since version 0.4.0, CEBRA is open source software under an Apache 2.0 license. +The ideas presented in our package are currently patent pending (Patent No. WO2023143843). +Since version 0.4.0, CEBRA's source is licenced under an Apache 2.0 license. Prior versions 0.1.0 to 0.3.1 were released for academic use only. Please see the full license file on Github_ for further information. @@ -65,13 +56,19 @@ Contributing Please refer to the :doc:`Contributing ` tab to find our guidelines on contributions. 
-Code contributors +Code Contributors ----------------- -The CEBRA code was originally developed by Steffen Schneider, Jin H. Lee, and Mackenzie Mathis (up to internal version 0.0.2). As of March 2023, it is being actively extended and maintained by `Steffen Schneider`_, `Célia Benquet`_, and `Mackenzie Mathis`_. +The CEBRA code was originally developed by Steffen Schneider, Jin H. Lee, and Mackenzie Mathis (up to internal version 0.0.2). Please see our AUTHORS file for more information. -References ----------- +Integrations +------------ + +CEBRA can be directly integrated with existing libraries commonly used in data analysis. Namely, we provide a ``scikit-learn`` style interface to use CEBRA. Additionally, we offer integrations with our ``scikit-learn``-style of using CEBRA, a package making use of ``matplotlib`` and ``plotly`` to plot the CEBRA model results, as well as the possibility to compute CEBRA embeddings on DeepLabCut_ outputs directly. If you have another suggestion, please head over to Discussions_ on GitHub_! + + +Key References +-------------- .. code:: @article{schneider2023cebra, @@ -82,14 +79,22 @@ References year = {2023}, } + @article{xCEBRA2025, + author={Steffen Schneider and Rodrigo Gonz{\'a}lez Laiz and Anastasiia Filippova and Markus Frey and Mackenzie W Mathis}, + title = {Time-series attribution maps with regularized contrastive learning}, + journal = {AISTATS}, + url = {https://openreview.net/forum?id=aGrCXoTB4P}, + year = {2025}, + } + This documentation is based on the `PyData Theme`_. .. _`Twitter`: https://twitter.com/cebraAI .. _`PyData Theme`: https://github.com/pydata/pydata-sphinx-theme .. _`DeepLabCut`: https://deeplabcut.org +.. _`Discussions`: https://github.com/AdaptiveMotorControlLab/CEBRA/discussions .. _`Github`: https://github.com/AdaptiveMotorControlLab/cebra .. _`email`: mailto:mackenzie.mathis@epfl.ch .. _`Steffen Schneider`: https://github.com/stes -.. _`Célia Benquet`: https://github.com/CeliaBenquet .. _`Mackenzie Mathis`: https://github.com/MMathisLab diff --git a/docs/source/installation.rst b/docs/source/installation.rst index a9650452..c5823fa7 100644 --- a/docs/source/installation.rst +++ b/docs/source/installation.rst @@ -4,7 +4,7 @@ Installation Guide System Requirements ------------------- -CEBRA is written in Python (3.8+) and PyTorch. CEBRA is most effective when used with a GPU, but CPU-only support is provided. We provide instructions to run CEBRA on your system directly. The instructions below were tested on different compute setups with Ubuntu 18.04 or 20.04, using Nvidia GTX 2080, A4000, and V100 cards. Other setups are possible (including Windows), as long as CUDA 10.2+ support is guaranteed. +CEBRA is written in Python (3.9+) and PyTorch. CEBRA is most effective when used with a GPU, but CPU-only support is provided. We provide instructions to run CEBRA on your system directly. The instructions below were tested on different compute setups with Ubuntu 18.04 or 20.04, using Nvidia GTX 2080, A4000, and V100 cards. Other setups are possible (including Windows), as long as CUDA 10.2+ support is guaranteed. - Software dependencies and operating systems: - Linux or MacOS @@ -93,11 +93,11 @@ we outline different options below. * 🚀 For more advanced users, CEBRA has different extra install options that you can select based on your usecase: - * ``[integrations]``: This will install (experimental) support for our streamlit and jupyter integrations. 
+ * ``[integrations]``: This will install (experimental) support for integrations, such as plotly. * ``[docs]``: This will install additional dependencies for building the package documentation. * ``[dev]``: This will install additional dependencies for development, unit and integration testing, code formatting, etc. Install this extension if you want to work on a pull request. - * ``[demos]``: This will install additional dependencies for running our demo notebooks. + * ``[demos]``: This will install additional dependencies for running our demo notebooks in Jupyter. * ``[datasets]``: This extension will install additional dependencies to use the pre-installed datasets in ``cebra.datasets``. diff --git a/docs/source/usage.rst b/docs/source/usage.rst index 53821e36..8b60aa69 100644 --- a/docs/source/usage.rst +++ b/docs/source/usage.rst @@ -1207,42 +1207,47 @@ Putting all previous snippet examples together, we obtain the following pipeline # 1. Define a CEBRA model cebra_model = cebra.CEBRA( - model_architecture = "offset10-model", - batch_size = 512, - learning_rate = 1e-4, - temperature_mode='constant', - temperature = 0.1, - max_iterations = 10, # TODO(user): to change to ~500-10000 depending on dataset size - #max_adapt_iterations = 10, # TODO(user): use and to change to ~100-500 if adapting - time_offsets = 10, - output_dimension = 8, - verbose = False + model_architecture = "offset10-model", + batch_size = 512, + learning_rate = 1e-4, + temperature_mode='constant', + temperature = 0.1, + max_iterations = 10, # TODO(user): to change to ~500-10000 depending on dataset size + #max_adapt_iterations = 10, # TODO(user): use and to change to ~100-500 if adapting + time_offsets = 10, + output_dimension = 8, + verbose = False ) - + # 2. Load example data neural_data = cebra.load_data(file="neural_data.npz", key="neural") new_neural_data = cebra.load_data(file="neural_data.npz", key="new_neural") continuous_label = cebra.load_data(file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["continuous1", "continuous2", "continuous3"]) discrete_label = cebra.load_data(file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"]).flatten() - + + assert neural_data.shape == (100, 3) assert new_neural_data.shape == (100, 4) assert discrete_label.shape == (100, ) assert continuous_label.shape == (100, 3) - - # 3. Split data and labels - ( - train_data, - valid_data, - train_discrete_label, - valid_discrete_label, - train_continuous_label, - valid_continuous_label, - ) = train_test_split(neural_data, - discrete_label, - continuous_label, - test_size=0.3) - + + # 3. Split data and labels into train/validation + from sklearn.model_selection import train_test_split + + split_idx = int(0.8 * len(neural_data)) + # suggestion: 5%-20% depending on your dataset size; note that this splits the + # into an early and late part, which might not be ideal for your data/experiment! + # As a more involved alternative, consider e.g. a nested time-series split. + + train_data = neural_data[:split_idx] + valid_data = neural_data[split_idx:] + + train_continuous_label = continuous_label[:split_idx] + valid_continuous_label = continuous_label[split_idx:] + + train_discrete_label = discrete_label[:split_idx] + valid_discrete_label = discrete_label[split_idx:] + # 4. 
Fit the model # time contrastive learning cebra_model.fit(train_data) @@ -1252,33 +1257,36 @@ Putting all previous snippet examples together, we obtain the following pipeline cebra_model.fit(train_data, train_continuous_label) # mixed behavior contrastive learning cebra_model.fit(train_data, train_discrete_label, train_continuous_label) - + + # 5. Save the model tmp_file = Path(tempfile.gettempdir(), 'cebra.pt') cebra_model.save(tmp_file) - + # 6. Load the model and compute an embedding cebra_model = cebra.CEBRA.load(tmp_file) train_embedding = cebra_model.transform(train_data) valid_embedding = cebra_model.transform(valid_data) - assert train_embedding.shape == (70, 8) # TODO(user): change to split ratio & output dim - assert valid_embedding.shape == (30, 8) # TODO(user): change to split ratio & output dim - + + assert train_embedding.shape == (80, 8) # TODO(user): change to split ratio & output dim + assert valid_embedding.shape == (20, 8) # TODO(user): change to split ratio & output dim + # 7. Evaluate the model performance (you can also check the train_data) - goodness_of_fit = cebra.sklearn.metrics.infonce_loss(cebra_model, + goodness_of_fit = cebra.sklearn.metrics.goodness_of_fit_score(cebra_model, valid_data, valid_discrete_label, - valid_continuous_label, - num_batches=5) - + valid_continuous_label) + # 8. Adapt the model to a new session cebra_model.fit(new_neural_data, adapt = True) - + # 9. Decode discrete labels behavior from the embedding decoder = cebra.KNNDecoder() decoder.fit(train_embedding, train_discrete_label) prediction = decoder.predict(valid_embedding) - assert prediction.shape == (30,) + assert prediction.shape == (20,) + + 👉 For further guidance on different/customized applications of CEBRA on your own data, refer to the ``examples/`` folder or to the full documentation folder ``docs/``. 
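[Editor's note] The quick-start pipeline above splits the series into an early (train) and late (validation) block, and its step 3 alludes to a "nested time-series split" as a more involved alternative without prescribing one. Below is a minimal sketch of that idea, assuming scikit-learn's ``TimeSeriesSplit`` (our choice of splitter, not one mandated by the docs):

.. code:: python

    import numpy as np
    from sklearn.model_selection import TimeSeriesSplit

    neural_data = np.random.uniform(size=(100, 3))  # stand-in for the example data

    # Each fold trains on an initial segment and validates on the
    # contiguous block that follows it, so no future samples leak
    # into training.
    tscv = TimeSeriesSplit(n_splits=4)
    for train_idx, valid_idx in tscv.split(neural_data):
        train_data = neural_data[train_idx]
        valid_data = neural_data[valid_idx]
        # fit one CEBRA model per fold here and average a validation
        # metric (e.g. goodness_of_fit_score) across the folds

Averaging the validation metric across folds gives a less split-dependent estimate of model fit than a single early/late cut.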
From b677e673cfc618883d1db0854a70cce81a1e7254 Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Tue, 4 Mar 2025 22:58:59 +0100 Subject: [PATCH 16/86] render plotly in our docs, show code/doc version (#231) --- .github/workflows/docs.yml | 7 +++++++ cebra/integrations/plotly.py | 5 +++-- docs/Makefile | 5 +++++ docs/source/_static/css/custom.js | 6 ++++++ docs/source/conf.py | 28 +++++++++++++++++++++++++--- 5 files changed, 46 insertions(+), 5 deletions(-) create mode 100644 docs/source/_static/css/custom.js diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index 47b5862d..826d9e91 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -9,6 +9,12 @@ on: - main - public - dev + paths: + - '**.py' + - '**.ipynb' + - '**.js' + - '**.rst' + - '**.md' jobs: build: @@ -69,6 +75,7 @@ jobs: pip install torch --extra-index-url https://download.pytorch.org/whl/cpu pip install '.[docs]' + - name: Build docs run: | ls docs/source/cebra-figures diff --git a/cebra/integrations/plotly.py b/cebra/integrations/plotly.py index bbaa1de6..8b0515e4 100644 --- a/cebra/integrations/plotly.py +++ b/cebra/integrations/plotly.py @@ -28,6 +28,7 @@ import numpy.typing as npt import plotly.graph_objects import torch +import plotly.graph_objects as go from cebra.integrations.matplotlib import _EmbeddingPlot @@ -154,7 +155,7 @@ def _plot_3d(self, **kwargs) -> plotly.graph_objects.Figure: def plot_embedding_interactive( embedding: Union[npt.NDArray, torch.Tensor], embedding_labels: Optional[Union[npt.NDArray, torch.Tensor, str]] = "grey", - axis: Optional[plotly.graph_objects.Figure] = None, + axis: Optional["go.Figure"] = None, markersize: float = 1, idx_order: Optional[Tuple[int]] = None, alpha: float = 0.4, @@ -163,7 +164,7 @@ def plot_embedding_interactive( figsize: Tuple[int] = (5, 5), dpi: int = 100, **kwargs, -) -> plotly.graph_objects.Figure: +) -> "go.Figure": """Plot embedding in a 3D dimensional space. This is supposing that the dimensions provided to ``idx_order`` are in the range of the number of diff --git a/docs/Makefile b/docs/Makefile index 741d165e..2739f4af 100644 --- a/docs/Makefile +++ b/docs/Makefile @@ -18,6 +18,11 @@ help: html: PYTHONPATH=.. $(SPHINXBUILD) -M html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) +# Build multiple versions +html_versions: + for v in latest v0.2.0 v0.3.0 v0.4.0; do \ + PYTHONPATH=.. $(SPHINXBUILD) -b html "$(SOURCEDIR)" "$(BUILDDIR)/$$v"; \ + done # Remove the current temp folder and page build clean: rm -rf build diff --git a/docs/source/_static/css/custom.js b/docs/source/_static/css/custom.js new file mode 100644 index 00000000..f9afa170 --- /dev/null +++ b/docs/source/_static/css/custom.js @@ -0,0 +1,6 @@ +requirejs.config({ + paths: { + base: '/static/base', + plotly: 'https://cdn.plot.ly/plotly-2.12.1.min.js?noext', + }, +}); diff --git a/docs/source/conf.py b/docs/source/conf.py index c5e12b5a..28cf2b14 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -47,8 +47,8 @@ def get_years(start_year=2021): # -- Project information ----------------------------------------------------- project = "cebra" -copyright = f"""{get_years(2021)}, Steffen Schneider, Jin H Lee, Mackenzie Mathis""" -author = "Steffen Schneider, Jin H Lee, Mackenzie Mathis" +copyright = f"""{get_years(2021)}""" +author = "See AUTHORS.md" # The full version, including alpha/beta/rc tags release = cebra.__version__ @@ -57,6 +57,13 @@ def get_years(start_year=2021): # Add any Sphinx extension module names here, as strings. 
They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. + +#https://github.com/spatialaudio/nbsphinx/issues/128#issuecomment-1158712159 +html_js_files = [ + "require.min.js", # Add to your _static + "custom.js", +] + extensions = [ "sphinx.ext.autodoc", "sphinx.ext.napoleon", @@ -68,13 +75,13 @@ def get_years(start_year=2021): "sphinx_tabs.tabs", "sphinx.ext.mathjax", "IPython.sphinxext.ipython_console_highlighting", - # "sphinx_panels", # Note: package to avoid: no longer maintained. "sphinx_design", "sphinx_togglebutton", "sphinx.ext.doctest", "sphinx_gallery.load_style", ] + coverage_show_missing_items = True panels_add_bootstrap_css = False @@ -137,6 +144,21 @@ def get_years(start_year=2021): # a list of builtin themes. html_theme = "pydata_sphinx_theme" +html_context = { + "default_mode": "light", + "switcher": { + "version_match": "latest", # Adjust this dynamically per version + "versions": [ + ("latest", "/latest/"), + ("v0.2.0", "/v0.2.0/"), + ("v0.3.0", "/v0.3.0/"), + ("v0.4.0", "/v0.4.0/"), + ("v0.5.0rc1", "/v0.5.0rc1/"), + ], + }, + "navbar_start": ["version-switcher", "navbar-logo"], # Place the dropdown above the logo +} + # More info on theme options: # https://pydata-sphinx-theme.readthedocs.io/en/latest/user_guide/configuring.html html_theme_options = { From 37ed6f5ed953bb4222f81953ef53f5e0c2df00ce Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Thu, 6 Mar 2025 18:00:29 +0100 Subject: [PATCH 17/86] Update layout.html (#233) --- docs/source/_templates/layout.html | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/docs/source/_templates/layout.html b/docs/source/_templates/layout.html index 2994db97..0140a5cf 100644 --- a/docs/source/_templates/layout.html +++ b/docs/source/_templates/layout.html @@ -1,11 +1,15 @@ {% extends "pydata_sphinx_theme/layout.html" %} -{% block fonts %} +{% block extrahead %} + + + +{% endblock %} +{% block fonts %} - {% endblock %} {% block docs_sidebar %} From 09b89749863a9ec00e825b30811ba59fca18b715 Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Thu, 6 Mar 2025 19:22:55 +0100 Subject: [PATCH 18/86] Update conf.py (#234) - adding link to new notebook icon --- docs/source/conf.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/source/conf.py b/docs/source/conf.py index 28cf2b14..a58f24ec 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -219,6 +219,8 @@ def get_years(start_year=2021): ] nbsphinx_thumbnails = { + "demo_notebooks/CEBRA_best_practices": + "_static/thumbnails/cebra-best.png", "demo_notebooks/Demo_primate_reaching": "_static/thumbnails/ForelimbS1.png", "demo_notebooks/Demo_hippocampus": From aa0db43251f3315c17c6391e4f4823529748b25c Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Sat, 15 Mar 2025 13:59:38 +0100 Subject: [PATCH 19/86] Refactoring setup.cfg (#228) --- AUTHORS.md | 28 ++++++++++++++++++++++++++++ setup.cfg | 7 +++---- 2 files changed, 31 insertions(+), 4 deletions(-) create mode 100644 AUTHORS.md diff --git a/AUTHORS.md b/AUTHORS.md new file mode 100644 index 00000000..11415b12 --- /dev/null +++ b/AUTHORS.md @@ -0,0 +1,28 @@ + + + + +CEBRA was initially developed by **Mackenzie Mathis** and **Steffen Schneider** (2021+), who are co-inventors on the patent application [WO2023143843](https://infoscience.epfl.ch/entities/patent/0d9debed-4d22-47b7-bad1-f211e7010323). 
+**Jin Hwa Lee** contributed significantly to our first paper: + +> **Schneider, S., Lee, J.H., & Mathis, M.W.** +> [*Learnable latent embeddings for joint behavioural and neural analysis.*](https://doi.org/10.1038/s41586-023-06031-6) +> Nature 617, 360–368 (2023) + +CEBRA is actively developed by [**Mackenzie Mathis**](https://www.mackenziemathislab.org/) and [**Steffen Schneider**](https://dynamical-inference.ai/) and their labs. + +It is a publicly available tool that has benefited from contributions and suggestions from many individuals: [CEBRA/graphs/contributors](https://github.com/AdaptiveMotorControlLab/CEBRA/graphs/contributors). + +## CEBRA Extensions + +### 2023 +- **Steffen Schneider, Rodrigo González Laiz, Markus Frey, Mackenzie W. Mathis** + [*Identifiable attribution maps using regularized contrastive learning.*](https://sslneurips23.github.io/paper_pdfs/paper_80.pdf) + NeurIPS 4th Workshop on Self-Supervised Learning: Theory and Practice (2023) + +### 2025 +- **Steffen Schneider, Rodrigo González Laiz, Anastasiia Filippova, Markus Frey, Mackenzie W. Mathis** + [*Time-series attribution maps with regularized contrastive learning.*](https://openreview.net/forum?id=aGrCXoTB4P) + AISTATS (2025) + + diff --git a/setup.cfg b/setup.cfg index 9da156ec..9a3c3a41 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,8 +1,8 @@ [metadata] name = cebra version = attr: cebra.__version__ -author = Steffen Schneider, Jin H Lee, Mackenzie W Mathis -author_email = stes@hey.com +author = file: AUTHORS.md +author_email = stes@hey.com, mackenzie@post.harvard.edu description = Consistent Embeddings of high-dimensional Recordings using Auxiliary variables long_description = file: README.md long_description_content_type = text/markdown @@ -58,9 +58,9 @@ datasets = hdf5storage # for creating .mat files in new format openpyxl # for excel file format loading integrations = - jupyter pandas plotly + seaborn docs = sphinx==5.3 sphinx-gallery==0.10.1 @@ -83,7 +83,6 @@ demos = ipykernel jupyter nbconvert - seaborn # TODO(stes): Additional dependency for running # co-homology analysis # is ripser, which can be tricky to From 490196651350416136aa9d9770fb5e84f71f674b Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Sat, 15 Mar 2025 17:57:02 +0100 Subject: [PATCH 20/86] Home page landing update (#235) * website refresh --- docs/root/index.html | 266 +++++++++++++++++++++++++++---------------- 1 file changed, 170 insertions(+), 96 deletions(-) diff --git a/docs/root/index.html b/docs/root/index.html index 86015297..cee11753 100644 --- a/docs/root/index.html +++ b/docs/root/index.html @@ -7,21 +7,21 @@ - Learnable latent embeddings for joint behavioural and neural analysis - + CEBRA + - + - + @@ -36,7 +36,6 @@ CEBRA @@ -93,58 +116,26 @@
-

Learnable latent embeddings for joint behavioural and neural analysis

-
- - -
-
-
- Steffen Schneider*
- EPFL & IMPRS-IS - - -
-
- Jin Hwa Lee*
- EPFL - -
-
- Mackenzie Mathis
- EPFL - - -
+

CEBRA: a self-supervised learning algorithm for obtaining interpretable, Consistent EmBeddings of high-dimensional Recordings using Auxiliary variables

-
-
-
+ - -
+ - -
- CEBRA is a machine-learning - method that can be used to - compress time series in a way - that reveals otherwise hidden - structures in the variability of - the data. It excels on behavioural - and neural data recorded - simultaneously, and it can - decode activity from the visual - cortex of the mouse brain to - reconstruct a viewed video. +
- +

Demo Applications

-

Application of CEBRA-Behavior to rat hippocampus data (Grosmark and Buzsáki, 2016), showing position/neural activity (left), overlayed with decoding obtained by CEBRA. The current point in embedding space is highlighted (right). CEBRA obtains a median absolute error of 5cm (total track length: 160cm; see pre-print for details). Video is played at 2x real-time speed.

+

Application of CEBRA-Behavior to rat hippocampus data (Grosmark and Buzsáki, 2016), showing position/neural activity (left), overlaid with decoding obtained by CEBRA. The current point in embedding space is highlighted (right). CEBRA obtains a median absolute error of 5cm (total track length: 160cm; see Schneider et al. 2023 for details). Video is played at 2x real-time speed.

+
+ +
+ +
+ +
+ +

Interactive visualization of the CEBRA embedding for the rat hippocampus data. This 3D plot shows how neural activity is mapped to a lower-dimensional space that correlates with the animal's position and movement direction. Open In Colaboratory

+
+
- + -

CEBRA applied to mouse primary visual cortex, collected at the Allen Institute (de Vries et al. 2020, Siegle et al. 2021). 2-photon and Neuropixels recordings are embedded with CEBRA using DINO frame features as labels. - The embedding is used to decode the video frames using a kNN decoder on the CEBRA-Behavior embedding from the test set.

+

CEBRA applied to mouse primary visual cortex, collected at the Allen Institute (de Vries et al. 2020, Siegle et al. 2021). 2-photon and Neuropixels recordings are embedded with CEBRA using DINO frame features as labels. + The embedding is used to decode the video frames using a kNN decoder on the CEBRA-Behavior embedding from the test set.

+
+ +
+ + +

CEBRA applied to M1 and S1 neural data, demonstrating how neural activity from primary motor and somatosensory cortices can be effectively embedded and analyzed. See DeWolf et al. 2024 for details.

+
+
+

Publications

+ +
+
+
Learnable latent embeddings for joint behavioural and neural analysis
+

Steffen Schneider*, Jin Hwa Lee*, Mackenzie Weygandt Mathis. Nature 2023

+

A comprehensive introduction to CEBRA, demonstrating its capabilities in joint behavioral and neural analysis across various datasets and species.

+ Read Paper + Preprint +
+
+ +
+
+
Time-series attribution maps with regularized contrastive learning
+

Steffen Schneider, Rodrigo González Laiz, Anastasiia Filippova, Markus Frey, Mackenzie Weygandt Mathis. AISTATS 2025

+

An extension of CEBRA that provides attribution maps for time-series data using regularized contrastive learning.

+ Read Paper + Preprint + NeurIPS-W 2023 Version +
+
+
+ +
+

Patent Information

+ +
+
+
Patent Pending
+

Please note that EPFL has filed a patent titled "Dimensionality reduction of time-series data, and systems and devices that use the resultant embeddings", so if this does not work for your non-academic use case, please contact the Tech Transfer Office at EPFL.

+
+

- Abstract + Overview

@@ -209,31 +257,6 @@

-
-

- - Pre-Print -

-
- -
-

- The pre-print is available on arxiv at arxiv.org/abs/2204.00673. -

- -
-

@@ -244,8 +267,7 @@

You can find our official implementation of the CEBRA algorithm on GitHub: Watch and Star the repository to be notified of future updates and releases. - You can also follow us on Twitter or subscribe to our - mailing list for updates on the project. + You can also follow us on Twitter for updates on the project.

If you are interested in collaborations, please contact us via @@ -258,13 +280,13 @@

BibTeX

-

Please cite our paper as follows:

+

Please cite our papers as follows:

@article{schneider2023cebra,
-   author={Schneider, Steffen and Lee, Jin Hwa and Mathis, Mackenzie Weygandt},
+   author={Steffen Schneider and Jin Hwa Lee and Mackenzie Weygandt Mathis},
  title={Learnable latent embeddings for joint behavioural and neural analysis},
  journal={Nature},
  year={2023},
@@ -277,6 +299,58 @@

+ +
+
+ + @inproceedings{schneider2025timeseries,
+   title={Time-series attribution maps with regularized contrastive learning},
+   author={Steffen Schneider and Rodrigo Gonz{\'a}lez Laiz and Anastasiia Filippova and Markus Frey and Mackenzie Weygandt Mathis},
+   booktitle={The 28th International Conference on Artificial Intelligence and Statistics},
+   year={2025},
+   url={https://openreview.net/forum?id=aGrCXoTB4P}
+ } +
+
+
+ +
+

+ + Impact & Citations +

+
+ +
+

+ CEBRA has been cited in numerous high-impact publications across neuroscience, machine learning, and related fields. Our work has influenced research in neural decoding, brain-computer interfaces, computational neuroscience, and machine learning methods for time-series analysis. +

+ + + +
+
+

Our research has been cited in proceedings and journals including Nature, Science, ICML, Nature Neuroscience, Neuron, NeurIPS, ICLR, and others.

+
+
+
+ +
+
+ + MLAI Logo + +
+ © 2021 - present | EPFL Mathis Laboratory +
+
+
Webpage designed using Bootstrap 5 and Fontawesome 5. From b2357fd8fec19fc7e1f543d0f91631de30c199f2 Mon Sep 17 00:00:00 2001 From: Mackenzie Mathis Date: Thu, 17 Apr 2025 10:51:45 +0200 Subject: [PATCH 21/86] v0.5.0 (#238) --- AUTHORS.md | 28 +++++++++++++--------------- Dockerfile | 2 +- Makefile | 2 +- PKGBUILD | 2 +- cebra/__init__.py | 2 +- cebra/integrations/plotly.py | 25 +++++++++++++------------ docs/root/index.html | 30 +++++++++++++++--------------- docs/source/conf.py | 9 +++++---- docs/source/usage.rst | 32 ++++++++++++++++---------------- reinstall.sh | 2 +- setup.cfg | 4 +++- tools/build_docker.sh | 24 +++++++++++++++++++++++- tools/build_docs.sh | 4 ++-- 13 files changed, 95 insertions(+), 71 deletions(-) diff --git a/AUTHORS.md b/AUTHORS.md index 11415b12..17db8887 100644 --- a/AUTHORS.md +++ b/AUTHORS.md @@ -2,27 +2,25 @@ -CEBRA was initially developed by **Mackenzie Mathis** and **Steffen Schneider** (2021+), who are co-inventors on the patent application [WO2023143843](https://infoscience.epfl.ch/entities/patent/0d9debed-4d22-47b7-bad1-f211e7010323). -**Jin Hwa Lee** contributed significantly to our first paper: +CEBRA was initially developed by **Mackenzie Mathis** and **Steffen Schneider** (2021+), who are co-inventors on the patent application [WO2023143843](https://infoscience.epfl.ch/entities/patent/0d9debed-4d22-47b7-bad1-f211e7010323). +**Jin Hwa Lee** contributed significantly to our first paper: -> **Schneider, S., Lee, J.H., & Mathis, M.W.** -> [*Learnable latent embeddings for joint behavioural and neural analysis.*](https://doi.org/10.1038/s41586-023-06031-6) +> **Schneider, S., Lee, J.H., & Mathis, M.W.** +> [*Learnable latent embeddings for joint behavioural and neural analysis.*](https://doi.org/10.1038/s41586-023-06031-6) > Nature 617, 360–368 (2023) -CEBRA is actively developed by [**Mackenzie Mathis**](https://www.mackenziemathislab.org/) and [**Steffen Schneider**](https://dynamical-inference.ai/) and their labs. +CEBRA is actively developed by [**Mackenzie Mathis**](https://www.mackenziemathislab.org/) and [**Steffen Schneider**](https://dynamical-inference.ai/) and their labs. -It is a publicly available tool that has benefited from contributions and suggestions from many individuals: [CEBRA/graphs/contributors](https://github.com/AdaptiveMotorControlLab/CEBRA/graphs/contributors). +It is a publicly available tool that has benefited from contributions and suggestions from many individuals: [CEBRA/graphs/contributors](https://github.com/AdaptiveMotorControlLab/CEBRA/graphs/contributors). -## CEBRA Extensions +## CEBRA Extensions -### 2023 -- **Steffen Schneider, Rodrigo González Laiz, Markus Frey, Mackenzie W. Mathis** - [*Identifiable attribution maps using regularized contrastive learning.*](https://sslneurips23.github.io/paper_pdfs/paper_80.pdf) +### 2023 +- **Steffen Schneider, Rodrigo González Laiz, Markus Frey, Mackenzie W. Mathis** + [*Identifiable attribution maps using regularized contrastive learning.*](https://sslneurips23.github.io/paper_pdfs/paper_80.pdf) NeurIPS 4th Workshop on Self-Supervised Learning: Theory and Practice (2023) -### 2025 -- **Steffen Schneider, Rodrigo González Laiz, Anastasiia Filippova, Markus Frey, Mackenzie W. Mathis** - [*Time-series attribution maps with regularized contrastive learning.*](https://openreview.net/forum?id=aGrCXoTB4P) +### 2025 +- **Steffen Schneider, Rodrigo González Laiz, Anastasiia Filippova, Markus Frey, Mackenzie W. 
Mathis** + [*Time-series attribution maps with regularized contrastive learning.*](https://openreview.net/forum?id=aGrCXoTB4P) AISTATS (2025) - - diff --git a/Dockerfile b/Dockerfile index 7cd326d5..46c8a555 100644 --- a/Dockerfile +++ b/Dockerfile @@ -40,7 +40,7 @@ RUN make dist FROM cebra-base # install the cebra wheel -ENV WHEEL=cebra-0.5.0rc1-py3-none-any.whl +ENV WHEEL=cebra-0.5.0-py3-none-any.whl WORKDIR /build COPY --from=wheel /build/dist/${WHEEL} . RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets]' diff --git a/Makefile b/Makefile index a1e8d3b2..5b8cb107 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -CEBRA_VERSION := 0.5.0rc1 +CEBRA_VERSION := 0.5.0 dist: python3 -m pip install virtualenv diff --git a/PKGBUILD b/PKGBUILD index 1f8b3db5..7aa985a8 100644 --- a/PKGBUILD +++ b/PKGBUILD @@ -1,7 +1,7 @@ # Maintainer: Steffen Schneider pkgname=python-cebra _pkgname=cebra -pkgver=0.5.0rc1 +pkgver=0.5.0 pkgrel=1 pkgdesc="Consistent Embeddings of high-dimensional Recordings using Auxiliary variables" url="https://cebra.ai" diff --git a/cebra/__init__.py b/cebra/__init__.py index edf1b5ee..0eb1f645 100644 --- a/cebra/__init__.py +++ b/cebra/__init__.py @@ -66,7 +66,7 @@ import cebra.integrations.sklearn as sklearn -__version__ = "0.5.0rc1" +__version__ = "0.5.0" __all__ = ["CEBRA"] __allow_lazy_imports = False __lazy_imports = {} diff --git a/cebra/integrations/plotly.py b/cebra/integrations/plotly.py index 8b0515e4..2cfc5ec9 100644 --- a/cebra/integrations/plotly.py +++ b/cebra/integrations/plotly.py @@ -27,8 +27,8 @@ import numpy as np import numpy.typing as npt import plotly.graph_objects -import torch import plotly.graph_objects as go +import torch from cebra.integrations.matplotlib import _EmbeddingPlot @@ -153,17 +153,18 @@ def _plot_3d(self, **kwargs) -> plotly.graph_objects.Figure: def plot_embedding_interactive( - embedding: Union[npt.NDArray, torch.Tensor], - embedding_labels: Optional[Union[npt.NDArray, torch.Tensor, str]] = "grey", - axis: Optional["go.Figure"] = None, - markersize: float = 1, - idx_order: Optional[Tuple[int]] = None, - alpha: float = 0.4, - cmap: str = "cool", - title: str = "Embedding", - figsize: Tuple[int] = (5, 5), - dpi: int = 100, - **kwargs, + embedding: Union[npt.NDArray, torch.Tensor], + embedding_labels: Optional[Union[npt.NDArray, torch.Tensor, + str]] = "grey", + axis: Optional["go.Figure"] = None, + markersize: float = 1, + idx_order: Optional[Tuple[int]] = None, + alpha: float = 0.4, + cmap: str = "cool", + title: str = "Embedding", + figsize: Tuple[int] = (5, 5), + dpi: int = 100, + **kwargs, ) -> "go.Figure": """Plot embedding in a 3D dimensional space. diff --git a/docs/root/index.html b/docs/root/index.html index cee11753..aa740039 100644 --- a/docs/root/index.html +++ b/docs/root/index.html @@ -145,16 +145,16 @@

-

CEBRA is a machine-learning method that can be used to +

CEBRA is a machine-learning method that can be used to compress time series in a way that reveals otherwise hidden structures in the variability of the data. It excels on behavioural and neural data recorded simultaneously. We have shown it can be used to decode the activity from the visual cortex of the mouse brain to reconstruct a viewed video, to decode trajectories from the sensorimotor cortex of primates, and for decoding position during navigation. For these use cases and other demos, see our Documentation.

- +
@@ -171,12 +171,12 @@

Demo Applications

-
- +

Interactive visualization of the CEBRA embedding for the rat hippocampus data. This 3D plot shows how neural activity is mapped to a lower-dimensional space that correlates with the animal's position and movement direction. Open In Colaboratory

@@ -191,7 +191,7 @@

Demo Applications

CEBRA applied to mouse primary visual cortex, collected at the Allen Institute (de Vries et al. 2020, Siegle et al. 2021). 2-photon and Neuropixels recordings are embedded with CEBRA using DINO frame features as labels. The embedding is used to decode the video frames using a kNN decoder on the CEBRA-Behavior embedding from the test set.

- +