Batched inference CEBRA & padding at the Solver level #168

Conversation


@CeliaBenquet CeliaBenquet commented Aug 23, 2024

fix https://github.com/AdaptiveMotorControlLab/CEBRA-dev/pull/746
fix #199

This PR adds batched inference and padding at the solver level.

Example usage of the new PyTorch API:

    import numpy as np
    import cebra.datasets
    import torch

    if torch.cuda.is_available():
        device = "cuda"
    else:
        device = "cpu"

    neural_data = cebra.load_data(file="neural_data.npz", key="neural")

    discrete_label = cebra.load_data(
        file="auxiliary_behavior_data.h5", key="auxiliary_variables", columns=["discrete"],
    )

    # 1. Define a CEBRA-ready dataset
    input_data = cebra.data.TensorDataset(
        torch.from_numpy(neural_data).type(torch.FloatTensor),
        discrete=torch.from_numpy(np.array(discrete_label[:, 0])).type(torch.LongTensor),
    ).to(device)

    # 2. Define a CEBRA model
    neural_model = cebra.models.init(
        name="offset10-model",
        num_neurons=input_data.input_dimension,
        num_units=32,
        num_output=2,
    ).to(device)

    input_data.configure_for(neural_model)

    # 3. Define the Loss Function Criterion and Optimizer
    crit = cebra.models.criterions.LearnableCosineInfoNCE(
        temperature=1,
    ).to(device)

    opt = torch.optim.Adam(
        list(neural_model.parameters()) + list(crit.parameters()),
        lr=0.001,
        weight_decay=0,
    )

    # 4. Initialize the CEBRA model
    solver = cebra.solver.init(
        name="single-session",
        model=neural_model,
        criterion=crit,
        optimizer=opt,
        tqdm_on=True,
    ).to(device)

    # 5. Define Data Loader
    loader = cebra.data.single_session.DiscreteDataLoader(
        dataset=input_data, num_steps=10, batch_size=200, prior="uniform"
    )

    # 6. Fit Model
    solver.fit(loader=loader)

    # 7. Transform Embedding
    x_train_emb = solver.transform(
        torch.from_numpy(neural_data).type(torch.FloatTensor).to(device),
        pad_before_transform=True,
        batch_size=512).to(device)

    # 8. Plot Embedding
    cebra.plot_embedding(
        x_train_emb.cpu(),
        discrete_label[:,0],
        markersize=10,
    )

Everything is similar to the previous implementation except for the inference part: the input no longer needs to be padded by hand before being passed to the model, since solver.transform() now handles padding (and optional batching) internally.
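
For intuition, the bookkeeping that now lives inside the solver roughly looks like the sketch below (illustrative only, assuming a fully convolutional model; the actual implementation in this PR differs in the details):

    import torch
    import torch.nn.functional as F

    @torch.no_grad()
    def batched_transform_sketch(model, inputs, left, right, batch_size=512):
        # Illustrative sketch, not the CEBRA implementation. Assumes a fully
        # convolutional `model` mapping (1, neurons, window) to
        # (1, output_dim, window - left - right), where left/right stand for
        # the model's receptive field around each sample.
        # inputs: (time, neurons) tensor.
        padded = F.pad(inputs.T.unsqueeze(0), (left, right), mode="replicate")
        embeddings = []
        for start in range(0, inputs.shape[0], batch_size):
            # Clamp the end index so the last (possibly incomplete) batch
            # never reaches past the input length.
            end = min(start + batch_size, inputs.shape[0])
            chunk = padded[:, :, start:end + left + right]  # batch plus its context
            embeddings.append(model(chunk).squeeze(0).T)    # (end - start, output_dim)
        return torch.cat(embeddings, dim=0)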

@CeliaBenquet CeliaBenquet requested a review from MMathisLab April 24, 2025 09:01

@stes stes left a comment


some early comments; apologies if i have asked some of these before

Comment on lines 314 to 320
@pytest.mark.parametrize(
    "data_name, loader_initfunc, model_architecture, solver_initfunc",
    multi_session_tests)
def test_multi_session(data_name, loader_initfunc, model_architecture,
                       solver_initfunc):
    data = cebra.datasets.init(data_name)
    loader = _get_loader(data, loader_initfunc)

why the changes here? i.e. did anything change that would cause the "old" multi session tests to break?

@CeliaBenquet (Member, Author) commented Apr 24, 2025

I re-established _get_loader as it was, but added a return value because I need the dataset to configure it with the model.

Apart from that:

  • I added the model_architecture, as offset1-model is a special case for padding at transform.
  • I added the configure_for(model) call, as this is now handled in the solver.
  • I added some tests on the transform (this was not tested at all before), similar to the sklearn tests but at the pytorch level (see the sketch below for the kind of check involved).
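
As a rough illustration of the kind of check such a test can perform (a hypothetical sketch, not the actual test code added in this PR):

    import torch

    def test_transform_batched_matches_unbatched(solver, neural_data, device):
        # `solver`, `neural_data` and `device` are assumed fixtures; treating
        # the omitted batch_size as "no batching" is an assumption as well.
        inputs = torch.from_numpy(neural_data).float().to(device)
        emb_full = solver.transform(inputs, pad_before_transform=True)
        emb_batched = solver.transform(inputs,
                                       pad_before_transform=True,
                                       batch_size=128)
        # With padding enabled, the embedding keeps one row per input sample,
        # and batching should not change the result.
        assert emb_full.shape == emb_batched.shape
        assert emb_full.shape[0] == len(inputs)
        assert torch.allclose(emb_full, emb_batched, atol=1e-6)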

Comment on lines 193 to 215
single_session_tests_select_model = []
single_session_hybrid_tests_select_model = []
for model_name in ["offset1-model", "offset10-model"]:
    for session_id in [None, 0, 5]:
        for args in [
            ("demo-discrete", model_name, session_id,
             cebra.data.DiscreteDataLoader),
            ("demo-continuous", model_name, session_id,
             cebra.data.ContinuousDataLoader),
            ("demo-mixed", model_name, session_id, cebra.data.MixedDataLoader),
        ]:
            single_session_tests_select_model.append(
                (*args, cebra.solver.SingleSessionSolver))
            single_session_hybrid_tests_select_model.append(
                (*args, cebra.solver.SingleSessionHybridSolver))

multi_session_tests_select_model = []
for model_name in ["offset10-model"]:
    for session_id in [None, 0, 1, 5, 2, 6, 4]:
        for args in [("demo-continuous-multisession", model_name, session_id,
                      cebra.data.ContinuousMultiSessionDataLoader)]:
            multi_session_tests_select_model.append(
                (*args, cebra.solver.MultiSessionSolver))

can you wrap the for loops here (quite complex) in functions, and only do the assignment on the global level?

@CeliaBenquet (Member, Author) replied:

I proposed something lmk if that's what you meant :)
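
For reference, one way to wrap the parametrization in a helper and keep only the assignment at module level could look like the following sketch (illustrative; the helper name and exact structure in the PR may differ):

    import cebra.data
    import cebra.solver

    # Sketch only -- the actual helper in the PR may be structured differently.
    def _make_select_model_params(model_names, session_ids, loader_specs, solver):
        params = []
        for model_name in model_names:
            for session_id in session_ids:
                for dataset_name, loader in loader_specs:
                    params.append(
                        (dataset_name, model_name, session_id, loader, solver))
        return params

    single_session_tests_select_model = _make_select_model_params(
        model_names=["offset1-model", "offset10-model"],
        session_ids=[None, 0, 5],
        loader_specs=[("demo-discrete", cebra.data.DiscreteDataLoader),
                      ("demo-continuous", cebra.data.ContinuousDataLoader),
                      ("demo-mixed", cebra.data.MixedDataLoader)],
        solver=cebra.solver.SingleSessionSolver,
    )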

@MMathisLab

doc error is:

    /home/runner/work/CEBRA/CEBRA/cebra/data/single_session.py:docstring of cebra.data.single_session.SingleSessionDataset.configure_for:3: WARNING: py:attr reference target not found: cebra_data.Dataset.offset
    /home/runner/work/CEBRA/CEBRA/cebra/data/multi_session.py:docstring of cebra.data.multi_session.MultiSessionDataset.configure_for:3: WARNING: py:attr reference target not found: cebra_data.Dataset.offset
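
For context, this warning usually means Sphinx cannot resolve the cross-reference target (here the module alias cebra_data is not an importable path). A hypothetical docstring sketch using a fully-qualified target instead, purely to illustrate the kind of fix such a warning calls for (not necessarily the change made in this PR):

    # Hypothetical sketch -- the real configure_for docstring may differ.
    def configure_for(self, model):
        """Configure the dataset offset for the provided model.

        Sets the :py:attr:`cebra.data.Dataset.offset` attribute of the
        dataset, using a fully qualified path that Sphinx can resolve.

        Args:
            model: The model to configure the dataset for.
        """
        self.offset = model.get_offset()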

@MMathisLab

@CeliaBenquet not sure I see your edits post review; did you push them?

@MMathisLab MMathisLab closed this Apr 24, 2025
@MMathisLab MMathisLab reopened this Apr 24, 2025
@CeliaBenquet CeliaBenquet requested a review from stes April 24, 2025 16:06

@stes stes left a comment


Left some initial comments; the broader discussion is mostly about the API design in the solver/base class --- let's discuss offline.


@stes stes left a comment


Ok, the review got a bit longer again; I realized I missed a few things in the last review. High-level comments:

  • I made some comments in the solver which could be fine; I think some arguments were moved from the sklearn class to the solver class, but the motivation for that is not entirely clear. This mostly needs one round of discussion so we can settle on a good API design. Specifically, what is the use case for storing these variables in the solver now, and where are they called?
  • the new transform function adds a lot of duplicated code that should be unified; again, this could be discussed first

Comment on lines +331 to +332
if hasattr(self, "n_features"):
    state_dict["n_features"] = self.n_features

Why is this an attribute of the solver, vs. being returned directly from the model? For sklearn it makes sense to fix this, but for the solver this could also simply be a property to be returned from the model? Where is this used?

E.g., what would happen for an xCEBRA solver, where you have not a single feature dim but multiple?

@CeliaBenquet (Member, Author) commented May 2, 2025

For the multi-session case that's already how it works, and there it's a list.

I don't think n_features can be a property, because it can only be defined based on the inputs provided to fit(), and it needs to be reset later if we adapt the solver. It is used (1) to be saved with the solver, since it's needed when reloading it, (2) to be called from the sklearn wrapper, and (3) to check whether the solver is fitted when calling transform().

For xCEBRA, that's just similar to the original sklearn one but at a lower level, so yes, we need to think about it, but we would have had to in any case.
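
To illustrate the intent described above (a minimal sketch of the assumed pattern, not the exact CEBRA implementation):

    # Minimal sketch of the assumed pattern -- not the actual CEBRA solver code.
    class SolverSketch:

        def fit(self, loader):
            # Record the input dimensionality seen during training; for the
            # multi-session case this is a list with one entry per session.
            self.n_features = loader.dataset.input_dimension
            # ... training loop ...

        def transform(self, inputs, pad_before_transform=True, batch_size=None):
            # n_features doubles as the "is fitted" check at transform time.
            if not hasattr(self, "n_features"):
                raise ValueError("Solver is not fitted yet; call fit() first.")
            # ... padding / batched forward pass ...

        def state_dict(self):
            state = {}
            # Persist n_features so a reloaded solver can be used for transform().
            if hasattr(self, "n_features"):
                state["n_features"] = self.n_features
            return state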

@MMathisLab MMathisLab requested a review from stes May 12, 2025 09:44
@MMathisLab

At the risk of it not being perfect, let's merge this now; @CeliaBenquet can document the remaining questions on the API design in an issue, but getting #251 merged is a priority 🦾

@MMathisLab MMathisLab merged commit 7ae5e1e into AdaptiveMotorControlLab:main May 23, 2025
11 checks passed
@stes stes deleted the batched-inference-and-padding branch June 5, 2025 19:40
stes pushed a commit that referenced this pull request Jun 5, 2025
* start tests

* remove print statements

* first passing test

* move functionality to base file in solver and separate in functions

* add test_select_model for multisession

* remove float16

* Improve modularity remove duplicate code and todos

* Add tests to solver

* Fix save/load

* Fix extra docs errors

* Add review updates

* apply ruff auto-fixes

* fix linting errors

* Run isort, ruff, yapf

* Fix gaussian mixture dataset import

* Fix all tests but xcebra tests

* Fix pytorch API usage example

* Make xCEBRA compatible with the batched inference & padding in solver

* Add some tests on transform() with xCEBRA

* Add some docstrings and typings and clean unnecessary changes

* Implement review comments

* Fix sklearn test

* Initial pass at integrating unifiedCEBRA

* Add name in NOTE

* Implement reviews on tests and typing

* Fix import errors

* Add select_model to aux solvers

* Fix tests

* Add mask tests

* Fix docs error

* Remove masking init()

* Remove shuffled neurons in unified dataset

* Remove extra datasets

* Add tests on the private functions in base solver

* Update tests and duplicate code based on review

* Fix quantized_embedding_norm undefined when `normalize=False` (#249)

* Fix tests

* Adapt unified code to get_model method

* Update mask.py

add headers to new files

* Update masking.py

- header

* Update test_data_masking.py

- header

* Implement review comments and fix typos

* Fix docs errors

* Remove np.int typing error

* Fix docstring warning

* Fix indentation docstrings

* Implement review comments

* Fix circular import and abstract method

* Add maskedmixin to __all__

* Implement extra review comments

* Change masking kwargs as tuple and not dict in sklearn impl

* Add integrations/decoders.py

* Fix typo

* minor simplification in solver

---------

Note, some comments in this PR overlap with #168 and #225, which were developed in parallel.
@stes stes mentioned this pull request Jun 5, 2025
Labels: CLA signed, enhancement (New feature or request)

Successfully merging this pull request may close these issues:

  • Last complete batch indexes for batched inference can go above the input length

5 participants