Skip to content

Add unified encoder pytorch implementation #251

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 91 commits into
base: batched-inference-and-padding
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
91 commits
Select commit Hold shift + click to select a range
9898850
Fix linting errors in tests (#188)
stes Oct 27, 2024
36a91c7
Fix `scikit-learn` reference in conda environment files (#195)
stes Nov 8, 2024
5f46c32
Add support for new __sklearn_tags__ (#205)
stes Dec 16, 2024
7a4d3fc
Update workflows to actions/setup-python@v5, actions/cache@v4 (#212)
stes Jan 21, 2025
a79c2de
Fix deprecation warning force_all_finite -> ensure_all_finite for skl…
icarosadero Jan 22, 2025
7e74eda
Add tests to check legacy model loading (#214)
stes Jan 29, 2025
4e32661
Add improved goodness of fit implementation (#190)
stes Feb 2, 2025
3100730
Support numpy 2, upgrade tests to support torch 2.6 (#221)
stes Feb 2, 2025
bea2c04
Release 0.5.0rc1 (#189)
stes Feb 2, 2025
c32ed67
Fix pypi action (#222)
stes Feb 3, 2025
f99530c
Update base.py (#224)
icarosadero Feb 18, 2025
c822ffa
Change max consistency value to 100 instead of 99 (#227)
CeliaBenquet Mar 1, 2025
b713387
Update assets.py --> force check for parent dir (#230)
MMathisLab Mar 1, 2025
47945ca
User docs minor edit (#229)
MMathisLab Mar 1, 2025
823c9ca
General Doc refresher (#232)
MMathisLab Mar 3, 2025
b677e67
render plotly in our docs, show code/doc version (#231)
MMathisLab Mar 4, 2025
37ed6f5
Update layout.html (#233)
MMathisLab Mar 6, 2025
09b8974
Update conf.py (#234)
MMathisLab Mar 6, 2025
aa0db43
Refactoring setup.cfg (#228)
MMathisLab Mar 15, 2025
4901966
Home page landing update (#235)
MMathisLab Mar 15, 2025
b2357fd
v0.5.0 (#238)
MMathisLab Apr 17, 2025
ae3ef2a
Upgrade docs build (#241)
stes Apr 18, 2025
1127432
Allow indexing of the cebra docs (#242)
stes Apr 20, 2025
d86ccf0
Fix broken docs coverage workflows (#246)
stes Apr 23, 2025
92c8b1f
Add xCEBRA implementation (AISTATS 2025) (#225)
gonlairo Apr 23, 2025
a09d123
Fix linting errors in tests (#188)
stes Oct 27, 2024
521f003
Fix `scikit-learn` reference in conda environment files (#195)
stes Nov 8, 2024
46610e3
Add support for new __sklearn_tags__ (#205)
stes Dec 16, 2024
e8004ba
Update workflows to actions/setup-python@v5, actions/cache@v4 (#212)
stes Jan 21, 2025
ddc00f4
Fix deprecation warning force_all_finite -> ensure_all_finite for skl…
icarosadero Jan 22, 2025
7dc9f81
Add tests to check legacy model loading (#214)
stes Jan 29, 2025
a2a6c44
Add improved goodness of fit implementation (#190)
stes Feb 2, 2025
a3b143f
Support numpy 2, upgrade tests to support torch 2.6 (#221)
stes Feb 2, 2025
0d5d82a
Release 0.5.0rc1 (#189)
stes Feb 2, 2025
92fd9bc
Fix pypi action (#222)
stes Feb 3, 2025
69d91ef
Update base.py (#224)
icarosadero Feb 18, 2025
782b63a
Change max consistency value to 100 instead of 99 (#227)
CeliaBenquet Mar 1, 2025
d72b055
Update assets.py --> force check for parent dir (#230)
MMathisLab Mar 1, 2025
9fd91c3
User docs minor edit (#229)
MMathisLab Mar 1, 2025
8d636e9
General Doc refresher (#232)
MMathisLab Mar 3, 2025
36370be
render plotly in our docs, show code/doc version (#231)
MMathisLab Mar 4, 2025
f7f4d7f
Update layout.html (#233)
MMathisLab Mar 6, 2025
798f7b2
Update conf.py (#234)
MMathisLab Mar 6, 2025
4a2996d
Refactoring setup.cfg (#228)
MMathisLab Mar 15, 2025
7abd1b0
Home page landing update (#235)
MMathisLab Mar 15, 2025
673019a
v0.5.0 (#238)
MMathisLab Apr 17, 2025
9625680
Upgrade docs build (#241)
stes Apr 18, 2025
95e5296
Allow indexing of the cebra docs (#242)
stes Apr 20, 2025
20f5a77
Fix broken docs coverage workflows (#246)
stes Apr 23, 2025
0d85abb
Add xCEBRA implementation (AISTATS 2025) (#225)
gonlairo Apr 23, 2025
b19be59
start tests
gonlairo Jun 23, 2023
e908083
remove print statements
gonlairo Sep 27, 2023
3d2b1e3
first passing test
gonlairo Sep 27, 2023
3ef4bc1
move functionality to base file in solver and separate in functions
gonlairo Oct 27, 2023
ad56472
add test_select_model for multisession
gonlairo Oct 30, 2023
b73c123
remove float16
gonlairo Nov 24, 2023
d71ca8d
Improve modularity remove duplicate code and todos
CeliaBenquet Aug 21, 2024
3e91459
Add tests to solver
CeliaBenquet Aug 22, 2024
c6179ad
Fix save/load
CeliaBenquet Aug 22, 2024
dafabe5
Fix extra docs errors
CeliaBenquet Sep 18, 2024
7b0cc68
Add review updates
CeliaBenquet Sep 19, 2024
7dfd4b9
apply ruff auto-fixes
stes Oct 27, 2024
3acbdf4
fix linting errors
stes Jan 21, 2025
5745449
Run isort, ruff, yapf
CeliaBenquet Apr 23, 2025
fa3cd3e
Merge remote-tracking branch 'upstream/main' into batched-inference-a…
CeliaBenquet Apr 23, 2025
f082a1c
Update conf.py (#237)
MMathisLab Apr 23, 2025
d303e50
Update docs.yml to build from main (#248)
MMathisLab Apr 23, 2025
73f90ee
Update installation.rst --> add link to docker hub (#247)
MMathisLab Apr 23, 2025
1453885
Merge branch 'main' into batched-inference-and-padding
MMathisLab Apr 23, 2025
acd2111
Fix gaussian mixture dataset import
CeliaBenquet Apr 23, 2025
217a8a7
Fix all tests but xcebra tests
CeliaBenquet Apr 23, 2025
a1218aa
Fix pytorch API usage example
CeliaBenquet Apr 24, 2025
64d1db8
Make xCEBRA compatible with the batched inference & padding in solver
CeliaBenquet Apr 24, 2025
9875a38
Add some tests on transform() with xCEBRA
CeliaBenquet Apr 24, 2025
65fc455
Add some docstrings and typings and clean unnecessary changes
CeliaBenquet Apr 24, 2025
1d0c498
Implement review comments
CeliaBenquet Apr 24, 2025
4a25899
Fix sklearn test
CeliaBenquet Apr 25, 2025
b8945ae
Initial pass at integrating unifiedCEBRA
CeliaBenquet Apr 25, 2025
0d56e44
Add name in NOTE
CeliaBenquet Apr 25, 2025
c5dc011
Implement reviews on tests and typing
CeliaBenquet Apr 25, 2025
c9fa5c8
Fix import errors
CeliaBenquet Apr 28, 2025
9ba22bc
Merge branch 'batched-inference-and-padding' into unified-cebra
CeliaBenquet Apr 28, 2025
4632c04
Add select_model to aux solvers
CeliaBenquet Apr 28, 2025
a52f502
Merge branch 'batched-inference-and-padding' into unified-cebra
CeliaBenquet Apr 28, 2025
c22e40e
Fix tests
CeliaBenquet Apr 28, 2025
e8a1877
Add mask tests
CeliaBenquet Apr 28, 2025
22e3c47
Fix docs error
CeliaBenquet Apr 30, 2025
464f4aa
Merge branch 'batched-inference-and-padding' into unified-cebra
CeliaBenquet May 1, 2025
57c9494
Remove masking init()
CeliaBenquet May 1, 2025
0d953fc
Remove shuffled neurons in unified dataset
CeliaBenquet May 1, 2025
eba09b6
Remove extra datasets
CeliaBenquet May 1, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .dockerignore
Original file line number Diff line number Diff line change
Expand Up @@ -5,3 +5,5 @@ tests/
third_party/
tools/
PKGBUILD

!docs/requirements.txt
24 changes: 20 additions & 4 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ on:

jobs:
build:
timeout-minutes: 30
strategy:
fail-fast: true
matrix:
Expand All @@ -18,27 +19,33 @@ jobs:
# We aim to support the versions on pytorch.org
# as well as selected previous versions on
# https://pytorch.org/get-started/previous-versions/
torch-version: ["2.2.2", "2.4.0"]
torch-version: ["2.4.0", "2.6.0"]
sklearn-version: ["latest"]
include:
- os: windows-latest
torch-version: 2.4.0
python-version: "3.10"
sklearn-version: "latest"
- os: ubuntu-latest
torch-version: 2.4.0
python-version: "3.10"
sklearn-version: "legacy"

runs-on: ${{ matrix.os }}

steps:
- name: Cache dependencies
id: pip-cache
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}
key: pip-os_${{ runner.os }}-python_${{ matrix.python-version }}-torch_${{ matrix.torch-version }}-sklearn_${{ matrix.sklearn-version }}

- name: Checkout code
uses: actions/checkout@v2

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v4
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}

Expand All @@ -48,6 +55,11 @@ jobs:
python -m pip install torch==${{ matrix.torch-version }} --extra-index-url https://download.pytorch.org/whl/cpu
pip install '.[dev,datasets,integrations]'

- name: Check sklearn legacy version
if: matrix.sklearn-version == 'legacy'
run: |
pip install scikit-learn==1.4.2 '.[dev,datasets,integrations]'

- name: Run the formatter
run: |
make format
Expand All @@ -56,6 +68,10 @@ jobs:
run: |
make codespell

- name: Check the documentation coverage
run: |
make interrogate

- name: Check CITATION.cff validity
run: |
cffconvert --validate
Expand Down
82 changes: 0 additions & 82 deletions .github/workflows/doc-coverage.yml

This file was deleted.

29 changes: 16 additions & 13 deletions .github/workflows/docs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@ on:
pull_request:
branches:
- main
- public
- dev

jobs:
build:
Expand All @@ -17,7 +15,7 @@ jobs:
steps:
- name: Cache dependencies
id: pip-cache
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip
Expand Down Expand Up @@ -51,28 +49,33 @@ jobs:
path: docs/source/demo_notebooks
ref: main

- name: Set up Python ${{ matrix.python-version }}
uses: actions/setup-python@v2
- name: Set up Python 3.10
uses: actions/setup-python@v5
with:
python-version: ${{ matrix.python-version }}
python-version: "3.10"

- name: Install package
run: |
python -m pip install --upgrade pip setuptools wheel
# NOTE(stes) Pandoc version must be at least (2.14.2) but less than (4.0.0).
# as of 29/10/23. Ubuntu 22.04 which is used for ubuntu-latest only has an
# as of 29/10/23. Ubuntu 22.04 which is used for ubuntu-latest only has an
# old pandoc version (2.9.). We will hence install the latest version manually.
# previou: sudo apt-get install -y pandoc
wget https://github.com/jgm/pandoc/releases/download/3.1.9/pandoc-3.1.9-1-amd64.deb
sudo dpkg -i pandoc-3.1.9-1-amd64.deb
rm pandoc-3.1.9-1-amd64.deb
pip install torch --extra-index-url https://download.pytorch.org/whl/cpu
pip install '.[docs]'
# NOTE(stes): Updated to latest version as of 17/04/2025, v3.6.4.
wget -q https://github.com/jgm/pandoc/releases/download/3.6.4/pandoc-3.6.4-1-amd64.deb
sudo dpkg -i pandoc-3.6.4-1-amd64.deb
rm pandoc-3.6.4-1-amd64.deb
pip install -r docs/requirements.txt

- name: Check software versions
run: |
sphinx-build --version
pandoc --version

- name: Build docs
run: |
ls docs/source/cebra-figures
# later also add the -n option to check for broken links
export SPHINXBUILD="sphinx-build"
export SPHINXOPTS="-W --keep-going -n"
make docs

Expand Down
9 changes: 8 additions & 1 deletion .github/workflows/release-pypi.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,18 @@ jobs:
steps:
- name: Cache dependencies
id: pip-cache
uses: actions/cache@v3
uses: actions/cache@v4
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pip

- name: Install dependencies
run: |
pip install --upgrade pip
pip install wheel
# NOTE(stes) see https://github.com/pypa/twine/issues/1216#issuecomment-2629069669
pip install "packaging>=24.2"

- name: Checkout code
uses: actions/checkout@v3

Expand Down
19 changes: 19 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,25 @@ experiments/sweeps
exports/
demo_notebooks/
assets/
.remove

# demo run
.vscode/
auxiliary_behavior_data.h5
cebra_model.pt
data.npz
grid_search_models/
neural_data.npz
saved_models/

# demo run
.vscode/
auxiliary_behavior_data.h5
cebra_model.pt
data.npz
grid_search_models/
neural_data.npz
saved_models/

# Binary files
*.png
Expand Down
26 changes: 26 additions & 0 deletions AUTHORS.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
<img src="https://github.com/user-attachments/assets/1f327e57-8ee1-4a2f-afd3-2bbce885c2f8" width="200"/>



CEBRA was initially developed by **Mackenzie Mathis** and **Steffen Schneider** (2021+), who are co-inventors on the patent application [WO2023143843](https://infoscience.epfl.ch/entities/patent/0d9debed-4d22-47b7-bad1-f211e7010323).
**Jin Hwa Lee** contributed significantly to our first paper:

> **Schneider, S., Lee, J.H., & Mathis, M.W.**
> [*Learnable latent embeddings for joint behavioural and neural analysis.*](https://doi.org/10.1038/s41586-023-06031-6)
> Nature 617, 360–368 (2023)

CEBRA is actively developed by [**Mackenzie Mathis**](https://www.mackenziemathislab.org/) and [**Steffen Schneider**](https://dynamical-inference.ai/) and their labs.

It is a publicly available tool that has benefited from contributions and suggestions from many individuals: [CEBRA/graphs/contributors](https://github.com/AdaptiveMotorControlLab/CEBRA/graphs/contributors).

## CEBRA Extensions

### 2023
- **Steffen Schneider, Rodrigo González Laiz, Markus Frey, Mackenzie W. Mathis**
[*Identifiable attribution maps using regularized contrastive learning.*](https://sslneurips23.github.io/paper_pdfs/paper_80.pdf)
NeurIPS 4th Workshop on Self-Supervised Learning: Theory and Practice (2023)

### 2025
- **Steffen Schneider, Rodrigo González Laiz, Anastasiia Filippova, Markus Frey, Mackenzie W. Mathis**
[*Time-series attribution maps with regularized contrastive learning.*](https://openreview.net/forum?id=aGrCXoTB4P)
AISTATS (2025)
2 changes: 1 addition & 1 deletion Dockerfile
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ RUN make dist
FROM cebra-base

# install the cebra wheel
ENV WHEEL=cebra-0.4.0-py2.py3-none-any.whl
ENV WHEEL=cebra-0.6.0a1-py3-none-any.whl
WORKDIR /build
COPY --from=wheel /build/dist/${WHEEL} .
RUN pip install --no-cache-dir ${WHEEL}'[dev,integrations,datasets]'
Expand Down
4 changes: 2 additions & 2 deletions Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
CEBRA_VERSION := 0.4.0
CEBRA_VERSION := 0.6.0a1

dist:
python3 -m pip install virtualenv
Expand Down Expand Up @@ -55,7 +55,7 @@ interrogate:
--ignore-private \
--ignore-magic \
--omit-covered-files \
-f 90 \
-f 80 \
cebra

# Build documentation using sphinx
Expand Down
Loading