Batched inference CEBRA & padding at the Solver level #168

Open
wants to merge 107 commits into main from batched-inference-and-padding
107 commits
283de06
first proposal for batching in tranform method
gonlairo Jun 21, 2023
202e379
first running version of padding with batched inference
gonlairo Jun 22, 2023
1f1989d
start tests
gonlairo Jun 23, 2023
8665660
add pad_before_transform to fit function and add support for convolut…
gonlairo Sep 27, 2023
8d5b114
remove print statements
gonlairo Sep 27, 2023
32c5ecd
first passing test
gonlairo Sep 27, 2023
9928f63
add support for hybrid models
gonlairo Sep 28, 2023
be5630a
rewrite transform in sklearn API
gonlairo Sep 28, 2023
1300b20
baseline version of a torch.Datset
gonlairo Oct 16, 2023
bc6af24
move batching logic outside solver
gonlairo Oct 20, 2023
ec377b9
move functionality to base file in solver and separate in functions
gonlairo Oct 27, 2023
6f9ca98
add test_select_model for single session
gonlairo Oct 27, 2023
fbe7eb4
add checks and test for _process_batch
gonlairo Oct 27, 2023
463b0f8
add test_select_model for multisession
gonlairo Oct 30, 2023
5219171
make self.num_sessions compatible with single session training
gonlairo Oct 31, 2023
f9bd1a6
improve test_batched_transform_singlesession
gonlairo Nov 1, 2023
e23a7ef
make it work with small batches
gonlairo Nov 7, 2023
19c3f87
make test with multisession work
gonlairo Nov 8, 2023
87bebac
change to torch padding
gonlairo Nov 9, 2023
f0303e0
add argument to sklearn api
gonlairo Nov 9, 2023
8c8be85
add torch padding to _transform
gonlairo Nov 9, 2023
59df402
convert to torch if numpy array as inputs
gonlairo Nov 9, 2023
1aadc8b
add distinction between pad with data and pad with zeros and modify t…
gonlairo Nov 15, 2023
bc8ee25
differentiate between data padding and zero padding
gonlairo Nov 17, 2023
5e7a14c
remove float16
gonlairo Nov 24, 2023
928d882
change argument position
gonlairo Nov 27, 2023
07bac1c
clean test
gonlairo Nov 27, 2023
0823b54
clean test
gonlairo Nov 27, 2023
9fe3af3
Fix warning
CeliaBenquet Mar 26, 2024
b417a23
Improve modularity remove duplicate code and todos
CeliaBenquet Aug 21, 2024
83c1669
Add tests to solver
CeliaBenquet Aug 22, 2024
9c46eb9
Remove unused import in solver/utils
CeliaBenquet Aug 22, 2024
c845ec3
Fix test plot
CeliaBenquet Aug 22, 2024
9db3e37
Add some coverage
CeliaBenquet Aug 22, 2024
8e5f933
Fix save/load
CeliaBenquet Aug 22, 2024
d08e400
Remove duplicate configure_for in multi dataset
CeliaBenquet Aug 22, 2024
0c693dd
Make save/load cleaner
CeliaBenquet Aug 22, 2024
ae056b2
Merge branch 'main' into batched-inference-and-padding
CeliaBenquet Sep 18, 2024
794867b
Fix codespell errors
CeliaBenquet Sep 18, 2024
0bb6549
Fix docs compilation errors
CeliaBenquet Sep 18, 2024
04a102f
Fix formatting
CeliaBenquet Sep 18, 2024
7aab282
Fix extra docs errors
CeliaBenquet Sep 18, 2024
ffa66eb
Fix offset in docs
CeliaBenquet Sep 18, 2024
7f58607
Remove attribute ref
CeliaBenquet Sep 18, 2024
c2544c7
Add review updates
CeliaBenquet Sep 19, 2024
ad5da03
Merge branch 'main' into batched-inference-and-padding
stes Oct 20, 2024
f6aa2e6
Merge branch 'main' into batched-inference-and-padding
MMathisLab Oct 20, 2024
e1b7cc7
apply ruff auto-fixes
stes Oct 27, 2024
0eac868
Merge remote-tracking branch 'origin/main' into batched-inference-and…
stes Oct 27, 2024
81b964c
Concatenate last batches for batched inference (#200)
CeliaBenquet Jan 21, 2025
a09d123
Fix linting errors in tests (#188)
stes Oct 27, 2024
521f003
Fix `scikit-learn` reference in conda environment files (#195)
stes Nov 8, 2024
46610e3
Add support for new __sklearn_tags__ (#205)
stes Dec 16, 2024
e8004ba
Update workflows to actions/setup-python@v5, actions/cache@v4 (#212)
stes Jan 21, 2025
ddc00f4
Fix deprecation warning force_all_finite -> ensure_all_finite for skl…
icarosadero Jan 22, 2025
7dc9f81
Add tests to check legacy model loading (#214)
stes Jan 29, 2025
a2a6c44
Add improved goodness of fit implementation (#190)
stes Feb 2, 2025
a3b143f
Support numpy 2, upgrade tests to support torch 2.6 (#221)
stes Feb 2, 2025
0d5d82a
Release 0.5.0rc1 (#189)
stes Feb 2, 2025
92fd9bc
Fix pypi action (#222)
stes Feb 3, 2025
69d91ef
Update base.py (#224)
icarosadero Feb 18, 2025
782b63a
Change max consistency value to 100 instead of 99 (#227)
CeliaBenquet Mar 1, 2025
d72b055
Update assets.py --> force check for parent dir (#230)
MMathisLab Mar 1, 2025
9fd91c3
User docs minor edit (#229)
MMathisLab Mar 1, 2025
8d636e9
General Doc refresher (#232)
MMathisLab Mar 3, 2025
36370be
render plotly in our docs, show code/doc version (#231)
MMathisLab Mar 4, 2025
f7f4d7f
Update layout.html (#233)
MMathisLab Mar 6, 2025
798f7b2
Update conf.py (#234)
MMathisLab Mar 6, 2025
4a2996d
Refactoring setup.cfg (#228)
MMathisLab Mar 15, 2025
7abd1b0
Home page landing update (#235)
MMathisLab Mar 15, 2025
673019a
v0.5.0 (#238)
MMathisLab Apr 17, 2025
9625680
Upgrade docs build (#241)
stes Apr 18, 2025
95e5296
Allow indexing of the cebra docs (#242)
stes Apr 20, 2025
20f5a77
Fix broken docs coverage workflows (#246)
stes Apr 23, 2025
0d85abb
Add xCEBRA implementation (AISTATS 2025) (#225)
gonlairo Apr 23, 2025
b19be59
start tests
gonlairo Jun 23, 2023
e908083
remove print statements
gonlairo Sep 27, 2023
3d2b1e3
first passing test
gonlairo Sep 27, 2023
3ef4bc1
move functionality to base file in solver and separate in functions
gonlairo Oct 27, 2023
ad56472
add test_select_model for multisession
gonlairo Oct 30, 2023
b73c123
remove float16
gonlairo Nov 24, 2023
d71ca8d
Improve modularity remove duplicate code and todos
CeliaBenquet Aug 21, 2024
3e91459
Add tests to solver
CeliaBenquet Aug 22, 2024
c6179ad
Fix save/load
CeliaBenquet Aug 22, 2024
dafabe5
Fix extra docs errors
CeliaBenquet Sep 18, 2024
7b0cc68
Add review updates
CeliaBenquet Sep 19, 2024
7dfd4b9
apply ruff auto-fixes
stes Oct 27, 2024
3acbdf4
fix linting errors
stes Jan 21, 2025
5745449
Run isort, ruff, yapf
CeliaBenquet Apr 23, 2025
fa3cd3e
Merge remote-tracking branch 'upstream/main' into batched-inference-a…
CeliaBenquet Apr 23, 2025
1453885
Merge branch 'main' into batched-inference-and-padding
MMathisLab Apr 23, 2025
acd2111
Fix gaussian mixture dataset import
CeliaBenquet Apr 23, 2025
217a8a7
Fix all tests but xcebra tests
CeliaBenquet Apr 23, 2025
a1218aa
Fix pytorch API usage example
CeliaBenquet Apr 24, 2025
64d1db8
Make xCEBRA compatible with the batched inference & padding in solver
CeliaBenquet Apr 24, 2025
9875a38
Add some tests on transform() with xCEBRA
CeliaBenquet Apr 24, 2025
65fc455
Add some docstrings and typings and clean unnecessary changes
CeliaBenquet Apr 24, 2025
1d0c498
Implement review comments
CeliaBenquet Apr 24, 2025
4a25899
Fix sklearn test
CeliaBenquet Apr 25, 2025
0d56e44
Add name in NOTE
CeliaBenquet Apr 25, 2025
c5dc011
Implement reviews on tests and typing
CeliaBenquet Apr 25, 2025
c9fa5c8
Fix import errors
CeliaBenquet Apr 28, 2025
4632c04
Add select_model to aux solvers
CeliaBenquet Apr 28, 2025
22e3c47
Fix docs error
CeliaBenquet Apr 30, 2025
2fcfb7f
Add tests on the private functions in base solver
CeliaBenquet May 2, 2025
66fc6aa
Update tests and duplicate code based on review
CeliaBenquet May 5, 2025
4d68110
Merge branch 'main' into batched-inference-and-padding
CeliaBenquet May 5, 2025
2 changes: 2 additions & 0 deletions cebra/data/base.py
@@ -227,6 +227,8 @@ class Loader(abc.ABC, cebra.io.HasDevice):
doc="""A dataset instance specifying a ``__getitem__`` function.""",
)

time_offset: int = dataclasses.field(default=10)

num_steps: int = dataclasses.field(
default=None,
doc=
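Note: the only change here is that `time_offset` becomes a field of the base `Loader` dataclass, which lets the duplicated `time_offset` fields on the single- and multi-session loaders below be deleted. A minimal usage sketch under that assumption (argument values are illustrative, not from this diff):

import cebra.data
import cebra.datasets

# Any Loader subclass now accepts time_offset via the base-class field,
# instead of each subclass redefining it.
dataset = cebra.datasets.init("demo-continuous")
loader = cebra.data.ContinuousDataLoader(
    dataset=dataset,
    num_steps=10,
    batch_size=32,
    time_offset=10,  # previously declared separately on each loader
)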
36 changes: 25 additions & 11 deletions cebra/data/multi_session.py
@@ -26,9 +26,10 @@

import literate_dataclasses as dataclasses
import torch
import torch.nn as nn

import cebra.data as cebra_data
import cebra.distributions as cebra_distr
import cebra.distributions
from cebra.data.datatypes import Batch
from cebra.data.datatypes import BatchIndex

@@ -104,10 +105,25 @@ def load_batch(self, index: BatchIndex) -> List[Batch]:
) for session_id, session in enumerate(self.iter_sessions())
]

def configure_for(self, model):
self.offset = model.get_offset()
for session in self.iter_sessions():
session.configure_for(model)
def configure_for(self, model: "cebra.models.Model"):
"""Configure the dataset offset for the provided model.

Call this function before indexing the dataset. This sets the
:py:attr:`~.Dataset.offset` attribute of the dataset.

Args:
model: The model to configure the dataset for.
"""
if not isinstance(model, nn.ModuleList):
raise ValueError(
"The model must be a nn.ModuleList to configure the dataset.")
if len(model) != self.num_sessions:
raise ValueError(
f"The model must have {self.num_sessions} sessions, but got {len(model)}."
)

for i, session in enumerate(self.iter_sessions()):
session.configure_for(model[i])


@dataclasses.dataclass
@@ -119,12 +135,10 @@ class MultiSessionLoader(cebra_data.Loader):
dimension, it is better to use a :py:class:`cebra.data.single_session.MixedDataLoader`.
"""

time_offset: int = dataclasses.field(default=10)

def __post_init__(self):
super().__post_init__()
self.sampler = cebra_distr.MultisessionSampler(self.dataset,
self.time_offset)
self.sampler = cebra.distributions.MultisessionSampler(
self.dataset, self.time_offset)

def get_indices(self, num_samples: int) -> List[BatchIndex]:
ref_idx = self.sampler.sample_prior(self.batch_size)
@@ -149,7 +163,6 @@ class ContinuousMultiSessionDataLoader(MultiSessionLoader):
"""Contrastive learning conditioned on a continuous behavior variable."""

conditional: str = "time_delta"
time_offset: int = dataclasses.field(default=10)

@property
def index(self):
@@ -163,7 +176,8 @@ class DiscreteMultiSessionDataLoader(MultiSessionLoader):
# Overwrite sampler with the discrete implementation
# Generalize MultisessionSampler to avoid doing this?
def __post_init__(self):
self.sampler = cebra_distr.DiscreteMultisessionSampler(self.dataset)
self.sampler = cebra.distributions.DiscreteMultisessionSampler(
self.dataset)

@property
def index(self):
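Note: `configure_for` on a multisession dataset now enforces its contract explicitly: the argument must be an `nn.ModuleList` holding exactly one model per session, and each session is configured with its own entry. A sketch of the new behavior, assuming a `multi_dataset` with per-session neuron counts `n_neurons_per_session` (both names hypothetical):

import torch.nn as nn
import cebra.models

# One model per session, wrapped in an nn.ModuleList as the new check requires.
models = nn.ModuleList([
    cebra.models.init("offset10-model", num_neurons=n, num_units=32, num_output=8)
    for n in n_neurons_per_session  # e.g. [30, 45] for two sessions
])
multi_dataset.configure_for(models)
# A bare model, or a ModuleList whose length differs from
# multi_dataset.num_sessions, now raises a ValueError.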
2 changes: 0 additions & 2 deletions cebra/data/single_session.py
@@ -189,7 +189,6 @@ class ContinuousDataLoader(cebra_data.Loader):
and become equivalent to time contrastive learning.
""",
)
time_offset: int = dataclasses.field(default=10)
delta: float = dataclasses.field(default=0.1)

def __post_init__(self):
@@ -278,7 +277,6 @@ class MixedDataLoader(cebra_data.Loader):
"""

conditional: str = dataclasses.field(default="time_delta")
time_offset: int = dataclasses.field(default=10)

@property
def dindex(self):
2 changes: 1 addition & 1 deletion cebra/datasets/demo.py
@@ -32,7 +32,7 @@
import cebra.io
from cebra.datasets import register

_DEFAULT_NUM_TIMEPOINTS = 100000
_DEFAULT_NUM_TIMEPOINTS = 1_000


class DemoDataset(cebra.data.SingleSessionDataset):
109 changes: 40 additions & 69 deletions cebra/integrations/sklearn/cebra.py
@@ -51,6 +51,7 @@
np.dtypes.Float64DType, np.dtypes.Int64DType
]


def check_version(estimator):
# NOTE(stes): required as a check for the old way of specifying tags
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
@@ -76,7 +77,6 @@ def _safe_torch_load(filename, weights_only, **kwargs):
return checkpoint



def _init_loader(
is_cont: bool,
is_disc: bool,
@@ -129,7 +129,7 @@ def _init_loader(
(not is_cont, not is_disc, is_multi),
]
if any(all(combination) for combination in incompatible_combinations):
raise ValueError(f"Invalid index combination.\n"
raise ValueError("Invalid index combination.\n"
f"Continuous: {is_cont},\n"
f"Discrete: {is_disc},\n"
f"Hybrid training: {is_hybrid},\n"
@@ -293,7 +293,7 @@ def _require_arg(key):
"single-session",
)

error_message = (f"Invalid index combination.\n"
error_message = ("Invalid index combination.\n"
f"Continuous: {is_cont},\n"
f"Discrete: {is_disc},\n"
f"Hybrid training: {is_hybrid},\n"
@@ -340,7 +340,7 @@ def _load_cebra_with_sklearn_backend(cebra_info: Dict) -> "CEBRA":
if missing_keys:
raise ValueError(
f"Missing keys in data dictionary: {', '.join(missing_keys)}. "
f"You can try loading the CEBRA model with the torch backend.")
"You can try loading the CEBRA model with the torch backend.")

args, state, state_dict = cebra_info['args'], cebra_info[
'state'], cebra_info['state_dict']
@@ -656,12 +656,12 @@ def _get_dataset_multi(X: List[Iterable], y: List[Iterable]):
# TODO(celia): to make it work for multiple set of index. For now, y should be a tuple of one list only
if isinstance(y, tuple) and len(y) > 1:
raise NotImplementedError(
f"Support for multiple set of index is not implemented in multissesion training, "
"Support for multiple set of index is not implemented in multissesion training, "
f"got {len(y)} sets of indexes.")

if not _are_sessions_equal(X, y):
raise ValueError(
f"Invalid number of sessions: number of sessions in X and y need to match, "
"Invalid number of sessions: number of sessions in X and y need to match, "
f"got X:{len(X)} and y:{[len(y_i) for y_i in y]}.")

for session in range(len(X)):
@@ -685,8 +685,8 @@ def _get_dataset_multi(X: List[Iterable], y: List[Iterable]):
else:
if not _are_sessions_equal(X, y):
raise ValueError(
f"Invalid number of samples or labels sessions: provide one session for single-session training, "
f"and make sure the number of samples in X and y need match, "
"Invalid number of samples or labels sessions: provide one session for single-session training, "
"and make sure the number of samples in X and y match, "
f"got {len(X)} and {[len(y_i) for y_i in y]}.")
is_multisession = False
dataset = _get_dataset(X, y)
@@ -813,8 +813,6 @@ def _configure_for_all(
"receptive fields/offsets larger than 1 via the sklearn API. "
"Please use a different model, or revert to the pytorch "
"API for training.")

d.configure_for(model[n])
else:
if not isinstance(model, cebra.models.ConvolutionalModelMixin):
if len(model.get_offset()) > 1:
@@ -824,37 +822,13 @@
"Please use a different model, or revert to the pytorch "
"API for training.")

dataset.configure_for(model)
dataset.configure_for(model)

def _select_model(self, X: Union[npt.NDArray, torch.Tensor],
session_id: int):
# Choose the model and get its corresponding offset
if self.num_sessions is not None: # multisession implementation
if session_id is None:
raise RuntimeError(
"No session_id provided: multisession model requires a session_id to choose the model corresponding to your data shape."
)
if session_id >= self.num_sessions or session_id < 0:
raise RuntimeError(
f"Invalid session_id {session_id}: session_id for the current multisession model must be between 0 and {self.num_sessions-1}."
)
if self.n_features_[session_id] != X.shape[1]:
raise ValueError(
f"Invalid input shape: model for session {session_id} requires an input of shape"
f"(n_samples, {self.n_features_[session_id]}), got (n_samples, {X.shape[1]})."
)

model = self.model_[session_id]
model.to(self.device_)
else: # single session
if session_id is not None and session_id > 0:
raise RuntimeError(
f"Invalid session_id {session_id}: single session models only takes an optional null session_id."
)
model = self.model_

offset = model.get_offset()
return model, offset
if isinstance(X, np.ndarray):
X = torch.from_numpy(X)
return self.solver_._select_model(X, session_id=session_id)

def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
"""Check that the input labels are compatible with the labels used to fit the model.
@@ -876,7 +850,7 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
# Check that same number of index
if len(self.label_types_) != n_idx:
raise ValueError(
f"Number of index invalid: labels must have the same number of index as for fitting,"
"Number of index invalid: labels must have the same number of index as for fitting,"
f"expects {len(self.label_types_)}, got {n_idx} idx.")

for i in range(len(self.label_types_)): # for each index
@@ -889,12 +863,12 @@ def _check_labels_types(self, y: tuple, session_id: Optional[int] = None):
> 1): # is there more than one feature in the index
if label_types_idx[1][1] != y[i].shape[1]:
raise ValueError(
f"Labels invalid: must have the same number of features as the ones used for fitting,"
"Labels invalid: must have the same number of features as the ones used for fitting,"
f"expects {label_types_idx[1]}, got {y[i].shape}.")

if label_types_idx[0] != y[i].dtype:
raise ValueError(
f"Labels invalid: must have the same type of features as the ones used for fitting,"
"Labels invalid: must have the same type of features as the ones used for fitting,"
f"expects {label_types_idx[0]}, got {y[i].dtype}.")

def _prepare_fit(
@@ -1081,14 +1055,13 @@ def _partial_fit(

# Save variables of interest as semi-private attributes
self.model_ = model
self.n_features_ = ([
loader.dataset.get_input_dimension(session_id)
for session_id in range(loader.dataset.num_sessions)
] if is_multisession else loader.dataset.input_dimension)

self.n_features_ = solver.n_features
self.num_sessions_ = solver.num_sessions if hasattr(
solver, "num_sessions") else None
self.solver_ = solver
self.n_features_in_ = ([model[n].num_input for n in range(len(model))]
if is_multisession else model.num_input)
self.num_sessions_ = loader.dataset.num_sessions if is_multisession else None

return self

@@ -1236,11 +1209,13 @@ def fit(

def transform(self,
X: Union[npt.NDArray, torch.Tensor],
batch_size: Optional[int] = None,
session_id: Optional[int] = None) -> npt.NDArray:
"""Transform an input sequence and return the embedding.

Args:
X: A numpy array or torch tensor of size ``time x dimension``.
batch_size: If None, transform the input in a single batch. Otherwise, transform the input in batches of size ``batch_size`` and concatenate the results.
session_id: The session ID, an :py:class:`int` between 0 and :py:attr:`num_sessions` for
multisession, set to ``None`` for single session.

@@ -1255,37 +1230,28 @@ def transform(self,
>>> cebra_model = cebra.CEBRA(max_iterations=10)
>>> cebra_model.fit(dataset)
CEBRA(max_iterations=10)
>>> embedding = cebra_model.transform(dataset)
>>> embedding = cebra_model.transform(dataset, batch_size=200)

"""

sklearn_utils_validation.check_is_fitted(self, "n_features_")
model, offset = self._select_model(X, session_id)
self.solver_._check_is_session_id_valid(session_id=session_id)

# Input validation
X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))
input_dtype = X.dtype
if torch.is_tensor(X):
X = X.detach().cpu()

with torch.no_grad():
model.eval()

if self.pad_before_transform:
X = np.pad(X, ((offset.left, offset.right - 1), (0, 0)),
mode="edge")
X = torch.from_numpy(X).float().to(self.device_)
X = sklearn_utils.check_input_array(X, min_samples=len(self.offset_))

if isinstance(model, cebra.models.ConvolutionalModelMixin):
# Fully convolutional evaluation, switch (T, C) -> (1, C, T)
X = X.transpose(1, 0).unsqueeze(0)
output = model(X).cpu().numpy().squeeze(0).transpose(1, 0)
else:
# Standard evaluation, (T, C, dt)
output = model(X).cpu().numpy()
if isinstance(X, np.ndarray):
X = torch.from_numpy(X)

if input_dtype == "float64":
return output.astype(input_dtype)
with torch.no_grad():
output = self.solver_.transform(
inputs=X,
pad_before_transform=self.pad_before_transform,
session_id=session_id,
batch_size=batch_size)

return output
return output.detach().cpu().numpy()

def fit_transform(
self,
@@ -1501,6 +1467,11 @@ def load(cls,
else:
cebra_ = _check_type_checkpoint(checkpoint)

n_features = cebra_.n_features_
cebra_.solver_.n_features = ([
session_n_features for session_n_features in n_features
] if isinstance(n_features, list) else n_features)

return cebra_

def to(self, device: Union[str, torch.device]):
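Note: the net effect on the sklearn API is that `transform` now delegates padding, model selection, and batching to the solver, and gains a `batch_size` argument. A sketch of the intended usage (array sizes and hyperparameters are illustrative):

import numpy as np
import cebra

X = np.random.uniform(0, 1, (1000, 30)).astype("float32")
model = cebra.CEBRA(max_iterations=10, batch_size=512)
model.fit(X)

emb_full = model.transform(X)  # single pass over the whole input, as before
emb_batched = model.transform(X, batch_size=200)  # chunked inference
assert emb_full.shape == emb_batched.shape  # same shape, lower peak memory

With `pad_before_transform` left at its default, the input is padded at the edges so a convolutional receptive field yields one embedding per input sample; the batched path is expected to match the single-pass embedding up to numerical precision.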
3 changes: 2 additions & 1 deletion cebra/integrations/sklearn/metrics.py
@@ -83,7 +83,8 @@ def infonce_loss(
f"got {len(y[0])} sessions.")

model, _ = cebra_model._select_model(
X, session_id) # check session_id validity and corresponding model
X, session_id=session_id
) # check session_id validity and corresponding model
cebra_model._check_labels_types(y, session_id=session_id)

dataset, is_multisession = cebra_model._prepare_data(X, y) # single session
3 changes: 2 additions & 1 deletion cebra/integrations/sklearn/utils.py
@@ -92,7 +92,8 @@ def check_input_array(X: npt.NDArray, *, min_samples: int) -> npt.NDArray:
X,
accept_sparse=False,
accept_large_sparse=False,
dtype=("float16", "float32", "float64"),
# NOTE(celia): remove float16 because F.pad does not allow float16.
dtype=("float32", "float64"),
order=None,
copy=False,
ensure_2d=True,
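Note: since `float16` is dropped from the accepted dtypes (per the comment above, the padding path uses `F.pad`, which does not allow it), half-precision arrays should be cast before calling `transform`. A minimal sketch, reusing a fitted `model` as in the example above:

import numpy as np

X16 = np.random.rand(500, 30).astype(np.float16)
emb = model.transform(X16.astype(np.float32), batch_size=100)  # cast first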