Skip to content

Commit

Permalink
🔨 Rename UCI regression dm
Browse files Browse the repository at this point in the history
  • Loading branch information
o-laurent committed Nov 17, 2024
1 parent a71f4cf commit cde2997
Show file tree
Hide file tree
Showing 6 changed files with 24 additions and 17 deletions.
2 changes: 1 addition & 1 deletion docs/source/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -387,7 +387,7 @@ Regression
:nosignatures:
:template: class.rst

UCIDataModule
UCIRegressionDataModule

.. currentmodule:: torch_uncertainty.datamodules.segmentation

Expand Down
6 changes: 3 additions & 3 deletions experiments/regression/uci_datasets/deep_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@

from torch_uncertainty import cli_main, init_args
from torch_uncertainty.baselines import DeepEnsemblesBaseline
from torch_uncertainty.datamodules import UCIDataModule
from torch_uncertainty.datamodules import UCIRegressionDataModule

if __name__ == "__main__":
args = init_args(DeepEnsemblesBaseline, UCIDataModule)
args = init_args(DeepEnsemblesBaseline, UCIRegressionDataModule)
if args.root == "./data/":
root = Path(__file__).parent.absolute().parents[2]
else:
Expand All @@ -15,7 +15,7 @@

# datamodule
args.root = str(root / "data")
dm = UCIDataModule(dataset_name="kin8nm", **vars(args))
dm = UCIRegressionDataModule(dataset_name="kin8nm", **vars(args))

# model
args.task = "regression"
Expand Down
4 changes: 2 additions & 2 deletions experiments/regression/uci_datasets/mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from torch_uncertainty import TULightningCLI
from torch_uncertainty.baselines.regression import MLPBaseline
from torch_uncertainty.datamodules import UCIDataModule
from torch_uncertainty.datamodules import UCIRegressionDataModule


class MLPCLI(TULightningCLI):
Expand All @@ -12,7 +12,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None:


def cli_main() -> MLPCLI:
return MLPCLI(MLPBaseline, UCIDataModule)
return MLPCLI(MLPBaseline, UCIRegressionDataModule)


if __name__ == "__main__":
Expand Down
8 changes: 4 additions & 4 deletions tests/datamodules/test_uci_regression.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from functools import partial

from tests._dummies.dataset import DummyRegressionDataset
from torch_uncertainty.datamodules import UCIDataModule
from torch_uncertainty.datamodules import UCIRegressionDataModule


class TestUCIDataModule:
"""Testing the UCIDataModule datamodule class."""
class TestUCIRegressionDataModule:
"""Testing the UCIRegressionDataModule datamodule class."""

def test_uci_regression(self):
dm = UCIDataModule(
dm = UCIRegressionDataModule(
dataset_name="kin8nm", root="./data/", batch_size=128
)

Expand Down
19 changes: 13 additions & 6 deletions torch_uncertainty/datamodules/__init__.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
# ruff: noqa: F401
from .abstract import TUDataModule
from .classification.cifar10 import CIFAR10DataModule
from .classification.cifar100 import CIFAR100DataModule
from .classification.imagenet import ImageNetDataModule
from .classification.mnist import MNISTDataModule
from .classification.tiny_imagenet import TinyImageNetDataModule
from .classification import (
BankMarketingDataModule,
CIFAR10DataModule,
CIFAR100DataModule,
Dota2GamesDataModule,
HTRU2DataModule,
ImageNetDataModule,
MNISTDataModule,
OnlineShoppersDataModule,
SpamBaseDataModule,
TinyImageNetDataModule,
)
from .segmentation import CamVidDataModule, CityscapesDataModule
from .uci_regression import UCIDataModule
from .uci_regression import UCIRegressionDataModule
2 changes: 1 addition & 1 deletion torch_uncertainty/datamodules/uci_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from .abstract import TUDataModule


class UCIDataModule(TUDataModule):
class UCIRegressionDataModule(TUDataModule):
training_task = "regression"

def __init__(
Expand Down

0 comments on commit cde2997

Please sign in to comment.