diff --git a/docs/source/api.rst b/docs/source/api.rst index 5114f0e6..b18bb813 100644 --- a/docs/source/api.rst +++ b/docs/source/api.rst @@ -387,7 +387,7 @@ Regression :nosignatures: :template: class.rst - UCIDataModule + UCIRegressionDataModule .. currentmodule:: torch_uncertainty.datamodules.segmentation diff --git a/experiments/regression/uci_datasets/deep_ensemble.py b/experiments/regression/uci_datasets/deep_ensemble.py index 6628e6e6..2a8bdc8f 100644 --- a/experiments/regression/uci_datasets/deep_ensemble.py +++ b/experiments/regression/uci_datasets/deep_ensemble.py @@ -2,10 +2,10 @@ from torch_uncertainty import cli_main, init_args from torch_uncertainty.baselines import DeepEnsemblesBaseline -from torch_uncertainty.datamodules import UCIDataModule +from torch_uncertainty.datamodules import UCIRegressionDataModule if __name__ == "__main__": - args = init_args(DeepEnsemblesBaseline, UCIDataModule) + args = init_args(DeepEnsemblesBaseline, UCIRegressionDataModule) if args.root == "./data/": root = Path(__file__).parent.absolute().parents[2] else: @@ -15,7 +15,7 @@ # datamodule args.root = str(root / "data") - dm = UCIDataModule(dataset_name="kin8nm", **vars(args)) + dm = UCIRegressionDataModule(dataset_name="kin8nm", **vars(args)) # model args.task = "regression" diff --git a/experiments/regression/uci_datasets/mlp.py b/experiments/regression/uci_datasets/mlp.py index 7c187673..54a9fafc 100644 --- a/experiments/regression/uci_datasets/mlp.py +++ b/experiments/regression/uci_datasets/mlp.py @@ -3,7 +3,7 @@ from torch_uncertainty import TULightningCLI from torch_uncertainty.baselines.regression import MLPBaseline -from torch_uncertainty.datamodules import UCIDataModule +from torch_uncertainty.datamodules import UCIRegressionDataModule class MLPCLI(TULightningCLI): @@ -12,7 +12,7 @@ def add_arguments_to_parser(self, parser: LightningArgumentParser) -> None: def cli_main() -> MLPCLI: - return MLPCLI(MLPBaseline, UCIDataModule) + return MLPCLI(MLPBaseline, UCIRegressionDataModule) if __name__ == "__main__": diff --git a/tests/datamodules/test_uci_regression.py b/tests/datamodules/test_uci_regression.py index 1297666c..aeda20ea 100644 --- a/tests/datamodules/test_uci_regression.py +++ b/tests/datamodules/test_uci_regression.py @@ -1,14 +1,14 @@ from functools import partial from tests._dummies.dataset import DummyRegressionDataset -from torch_uncertainty.datamodules import UCIDataModule +from torch_uncertainty.datamodules import UCIRegressionDataModule -class TestUCIDataModule: - """Testing the UCIDataModule datamodule class.""" +class TestUCIRegressionDataModule: + """Testing the UCIRegressionDataModule datamodule class.""" def test_uci_regression(self): - dm = UCIDataModule( + dm = UCIRegressionDataModule( dataset_name="kin8nm", root="./data/", batch_size=128 ) diff --git a/torch_uncertainty/datamodules/__init__.py b/torch_uncertainty/datamodules/__init__.py index 670dfc4c..b4c9797d 100644 --- a/torch_uncertainty/datamodules/__init__.py +++ b/torch_uncertainty/datamodules/__init__.py @@ -1,9 +1,16 @@ # ruff: noqa: F401 from .abstract import TUDataModule -from .classification.cifar10 import CIFAR10DataModule -from .classification.cifar100 import CIFAR100DataModule -from .classification.imagenet import ImageNetDataModule -from .classification.mnist import MNISTDataModule -from .classification.tiny_imagenet import TinyImageNetDataModule +from .classification import ( + BankMarketingDataModule, + CIFAR10DataModule, + CIFAR100DataModule, + Dota2GamesDataModule, + HTRU2DataModule, + ImageNetDataModule, + MNISTDataModule, + OnlineShoppersDataModule, + SpamBaseDataModule, + TinyImageNetDataModule, +) from .segmentation import CamVidDataModule, CityscapesDataModule -from .uci_regression import UCIDataModule +from .uci_regression import UCIRegressionDataModule diff --git a/torch_uncertainty/datamodules/uci_regression.py b/torch_uncertainty/datamodules/uci_regression.py index a5cbe8af..4d1b304e 100644 --- a/torch_uncertainty/datamodules/uci_regression.py +++ b/torch_uncertainty/datamodules/uci_regression.py @@ -9,7 +9,7 @@ from .abstract import TUDataModule -class UCIDataModule(TUDataModule): +class UCIRegressionDataModule(TUDataModule): training_task = "regression" def __init__(