Skip to content

Commit 8b5c9b8

Browse files
committed
Run isort, ruff, yapf
1 parent edc467d commit 8b5c9b8

File tree

4 files changed

+9
-11
lines changed

4 files changed

+9
-11
lines changed

cebra/data/single_session.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -370,7 +370,7 @@ def __post_init__(self):
370370

371371
self._init_behavior_distribution()
372372
self._init_time_distribution()
373-
373+
374374
if self.conditional != "time_delta":
375375
raise NotImplementedError(
376376
"Hybrid training is currently only implemented using the ``time_delta`` "

cebra/integrations/sklearn/cebra.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@
5151
np.dtypes.Float64DType, np.dtypes.Int64DType
5252
]
5353

54+
5455
def check_version(estimator):
5556
# NOTE(stes): required as a check for the old way of specifying tags
5657
# https://github.com/scikit-learn/scikit-learn/pull/29677#issuecomment-2334229165
@@ -76,8 +77,6 @@ def _safe_torch_load(filename, weights_only, **kwargs):
7677
return checkpoint
7778

7879

79-
80-
8180
def _init_loader(
8281
is_cont: bool,
8382
is_disc: bool,

cebra/solver/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import abc
3434
import os
3535
import warnings
36-
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple, Union
36+
from typing import Callable, Dict, List, Literal, Optional, Tuple, Union
3737

3838
import literate_dataclasses as dataclasses
3939
import numpy.typing as npt

tests/test_solver.py

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,14 @@
5959
cebra.data.ContinuousMultiSessionDataLoader, "offset1-model"),
6060
("demo-continuous-multisession",
6161
cebra.data.ContinuousMultiSessionDataLoader, "offset10-model"),
62-
("demo-discrete-multisession",
63-
cebra.data.DiscreteMultiSessionDataLoader, "offset1-model"),
64-
("demo-discrete-multisession",
65-
cebra.data.DiscreteMultiSessionDataLoader, "offset10-model"),
62+
("demo-discrete-multisession", cebra.data.DiscreteMultiSessionDataLoader,
63+
"offset1-model"),
64+
("demo-discrete-multisession", cebra.data.DiscreteMultiSessionDataLoader,
65+
"offset10-model"),
6666
]:
6767
multi_session_tests.append((*args, cebra.solver.MultiSessionSolver))
6868

6969

70-
7170
def _get_loader(data, loader_initfunc):
7271
kwargs = dict(num_steps=5, batch_size=32)
7372
loader = loader_initfunc(data, **kwargs)
@@ -168,7 +167,7 @@ def test_single_session(data_name, loader_initfunc, model_architecture,
168167

169168
assert solver.num_sessions is None
170169
assert solver.n_features == X.shape[1]
171-
170+
172171
embedding = solver.transform(X)
173172
assert isinstance(embedding, torch.Tensor)
174173
assert embedding.shape == (X.shape[0], OUTPUT_DIMENSION)
@@ -498,7 +497,7 @@ def test_multi_session_2(data_name, loader_initfunc, solver_initfunc):
498497
assert isinstance(log, dict)
499498

500499
solver.fit(loader)
501-
500+
502501

503502
def create_model(model_name, input_dimension):
504503
return cebra.models.init(model_name,

0 commit comments

Comments
 (0)