Skip to content

Commit 1e00413

Browse files
committed
[update] Update the PR according to the latest version and disable shuffle
Since shuffling is already performed inside the split functions, we do not need to shuffle the data before splitting. For this reason, I removed the shuffle argument from BaseDataset and added a shuffle key to resampling_strategy_args instead.
1 parent 36cef27 commit 1e00413

File tree

13 files changed

+70
-71
lines changed

13 files changed

+70
-71
lines changed

autoPyTorch/api/base_task.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
STRING_TO_TASK_TYPES,
3535
)
3636
from autoPyTorch.datasets.base_dataset import BaseDataset
37-
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
37+
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutTypes
3838
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
3939
from autoPyTorch.ensemble.ensemble_selection import EnsembleSelection
4040
from autoPyTorch.ensemble.singlebest_ensemble import SingleBest
@@ -138,7 +138,7 @@ def __init__(
138138
include_components: Optional[Dict] = None,
139139
exclude_components: Optional[Dict] = None,
140140
backend: Optional[Backend] = None,
141-
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
141+
resampling_strategy: Union[CrossValTypes, HoldoutTypes] = HoldoutTypes.holdout,
142142
resampling_strategy_args: Optional[Dict[str, Any]] = None,
143143
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None,
144144
task_type: Optional[str] = None
@@ -1171,7 +1171,7 @@ def predict(
11711171
assert self.ensemble_ is not None, "Load models should error out if no ensemble"
11721172
self.ensemble_ = cast(Union[SingleBest, EnsembleSelection], self.ensemble_)
11731173

1174-
if isinstance(self.resampling_strategy, HoldoutValTypes):
1174+
if isinstance(self.resampling_strategy, HoldoutTypes):
11751175
models = self.models_
11761176
elif isinstance(self.resampling_strategy, CrossValTypes):
11771177
models = self.cv_models_

autoPyTorch/api/tabular_classification.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from autoPyTorch.datasets.base_dataset import BaseDataset
1616
from autoPyTorch.datasets.resampling_strategy import (
1717
CrossValTypes,
18-
HoldoutValTypes,
18+
HoldoutTypes,
1919
)
2020
from autoPyTorch.datasets.tabular_dataset import TabularDataset
2121
from autoPyTorch.pipeline.tabular_classification import TabularClassificationPipeline
@@ -72,7 +72,7 @@ def __init__(
7272
delete_output_folder_after_terminate: bool = True,
7373
include_components: Optional[Dict] = None,
7474
exclude_components: Optional[Dict] = None,
75-
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
75+
resampling_strategy: Union[CrossValTypes, HoldoutTypes] = HoldoutTypes.holdout,
7676
resampling_strategy_args: Optional[Dict[str, Any]] = None,
7777
backend: Optional[Backend] = None,
7878
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None

autoPyTorch/api/tabular_regression.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from autoPyTorch.datasets.base_dataset import BaseDataset
1616
from autoPyTorch.datasets.resampling_strategy import (
1717
CrossValTypes,
18-
HoldoutValTypes,
18+
HoldoutTypes,
1919
)
2020
from autoPyTorch.datasets.tabular_dataset import TabularDataset
2121
from autoPyTorch.pipeline.tabular_regression import TabularRegressionPipeline
@@ -64,7 +64,7 @@ def __init__(
6464
delete_output_folder_after_terminate: bool = True,
6565
include_components: Optional[Dict] = None,
6666
exclude_components: Optional[Dict] = None,
67-
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
67+
resampling_strategy: Union[CrossValTypes, HoldoutTypes] = HoldoutTypes.holdout,
6868
resampling_strategy_args: Optional[Dict[str, Any]] = None,
6969
backend: Optional[Backend] = None,
7070
search_space_updates: Optional[HyperparameterSearchSpaceUpdates] = None

autoPyTorch/datasets/base_dataset.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,6 @@ def __init__(
7171
test_tensors: Optional[BaseDatasetInputType] = None,
7272
resampling_strategy: Union[CrossValTypes, HoldoutTypes] = HoldoutTypes.holdout,
7373
resampling_strategy_args: Optional[Dict[str, Any]] = None,
74-
shuffle: Optional[bool] = True,
7574
seed: Optional[int] = 42,
7675
train_transforms: Optional[torchvision.transforms.Compose] = None,
7776
val_transforms: Optional[torchvision.transforms.Compose] = None,
@@ -89,10 +88,9 @@ def __init__(
8988
resampling_strategy (Union[CrossValTypes, HoldoutTypes]),
9089
(default=HoldoutTypes.holdout):
9190
strategy to split the training data.
92-
resampling_strategy_args (Optional[Dict[str, Any]]): arguments
93-
required for the chosen resampling strategy. If None, uses
94-
the default values provided in DEFAULT_RESAMPLING_PARAMETERS
95-
in ```datasets/resampling_strategy.py```.
91+
resampling_strategy_args (Optional[Dict[str, Any]]):
92+
arguments required for the chosen resampling strategy.
93+
The details are provided in autoPyTorch/datasets/resampling_strategy.py
9694
shuffle: Whether to shuffle the data when performing splits
9795
seed (int), (default=42): seed to be used for reproducibility.
9896
train_transforms (Optional[torchvision.transforms.Compose]):
@@ -109,9 +107,9 @@ def __init__(
109107
type_check(train_tensors, val_tensors)
110108
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
111109
self.random_state = np.random.RandomState(seed=seed)
112-
self.shuffle = shuffle
113110
self.resampling_strategy = resampling_strategy
114111
self.resampling_strategy_args = resampling_strategy_args
112+
self.shuffle = self.resampling_strategy_args['shuffle']
115113
self.is_stratify = self.resampling_strategy.get('stratify', False)
116114

117115
self.task_type: Optional[str] = None

autoPyTorch/datasets/image_dataset.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from autoPyTorch.datasets.base_dataset import BaseDataset
2424
from autoPyTorch.datasets.resampling_strategy import (
2525
CrossValTypes,
26-
HoldoutValTypes,
26+
HoldoutTypes,
2727
)
2828

2929
IMAGE_DATASET_INPUT = Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]
@@ -39,13 +39,12 @@ class ImageDataset(BaseDataset):
3939
validation data
4040
test (Union[Dataset, Tuple[Union[np.ndarray, List[str]], np.ndarray]]):
4141
testing data
42-
resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
43-
(default=HoldoutValTypes.holdout_validation):
42+
resampling_strategy (Union[CrossValTypes, HoldoutTypes]),
43+
(default=HoldoutTypes.holdout):
4444
strategy to split the training data.
45-
resampling_strategy_args (Optional[Dict[str, Any]]): arguments
46-
required for the chosen resampling strategy. If None, uses
47-
the default values provided in DEFAULT_RESAMPLING_PARAMETERS
48-
in ```datasets/resampling_strategy.py```.
45+
resampling_strategy_args (Optional[Dict[str, Any]]):
46+
arguments required for the chosen resampling strategy.
47+
The details are provided in autoPyTorch/datasets/resampling_strategy.py
4948
shuffle: Whether to shuffle the data before performing splits
5049
seed (int), (default=42): seed to be used for reproducibility.
5150
train_transforms (Optional[torchvision.transforms.Compose]):
@@ -57,9 +56,8 @@ def __init__(self,
5756
train: IMAGE_DATASET_INPUT,
5857
val: Optional[IMAGE_DATASET_INPUT] = None,
5958
test: Optional[IMAGE_DATASET_INPUT] = None,
60-
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
59+
resampling_strategy: Union[CrossValTypes, HoldoutTypes] = HoldoutTypes.holdout,
6160
resampling_strategy_args: Optional[Dict[str, Any]] = None,
62-
shuffle: Optional[bool] = True,
6361
seed: Optional[int] = 42,
6462
train_transforms: Optional[torchvision.transforms.Compose] = None,
6563
val_transforms: Optional[torchvision.transforms.Compose] = None,
@@ -72,7 +70,7 @@ def __init__(self,
7270
test = _create_image_dataset(data=test)
7371
self.mean, self.std = _calc_mean_std(train=train)
7472

75-
super().__init__(train_tensors=train, val_tensors=val, test_tensors=test, shuffle=shuffle,
73+
super().__init__(train_tensors=train, val_tensors=val, test_tensors=test,
7674
resampling_strategy=resampling_strategy, resampling_strategy_args=resampling_strategy_args,
7775
seed=seed,
7876
train_transforms=train_transforms,

autoPyTorch/datasets/resampling_strategy.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from enum import Enum
22
from functools import partial
3-
from typing import List, Optional, Tuple, Union
3+
from typing import List, NamedTuple, Optional, Tuple, Union
44

55
import numpy as np
66

@@ -16,6 +16,13 @@
1616
from torch.utils.data import Dataset
1717

1818

19+
class _ResamplingStrategyArgs(NamedTuple):
20+
val_share: float = 0.33
21+
num_splits: int = 5
22+
shuffle: bool = False
23+
stratify: bool = False
24+
25+
1926
class HoldoutFuncs():
2027
@staticmethod
2128
def holdout(

autoPyTorch/datasets/tabular_dataset.py

Lines changed: 8 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from autoPyTorch.datasets.base_dataset import BaseDataset
2121
from autoPyTorch.datasets.resampling_strategy import (
2222
CrossValTypes,
23-
HoldoutValTypes,
23+
HoldoutTypes,
2424
)
2525

2626

@@ -44,13 +44,12 @@ class TabularDataset(BaseDataset):
4444
Y (Union[np.ndarray, pd.Series]): training data targets.
4545
X_test (Optional[Union[np.ndarray, pd.DataFrame]]): input testing data.
4646
Y_test (Optional[Union[np.ndarray, pd.DataFrame]]): testing data targets
47-
resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
48-
(default=HoldoutValTypes.holdout_validation):
47+
resampling_strategy (Union[CrossValTypes, HoldoutTypes]),
48+
(default=HoldoutTypes.holdout):
4949
strategy to split the training data.
50-
resampling_strategy_args (Optional[Dict[str, Any]]): arguments
51-
required for the chosen resampling strategy. If None, uses
52-
the default values provided in DEFAULT_RESAMPLING_PARAMETERS
53-
in ```datasets/resampling_strategy.py```.
50+
resampling_strategy_args (Optional[Dict[str, Any]]):
51+
arguments required for the chosen resampling strategy.
52+
The details are provided in autoPyTorch/datasets/resampling_strategy.py
5453
shuffle: Whether to shuffle the data before performing splits
5554
seed (int), (default=42): seed to be used for reproducibility.
5655
train_transforms (Optional[torchvision.transforms.Compose]):
@@ -67,9 +66,8 @@ def __init__(self,
6766
Y: Union[np.ndarray, pd.Series],
6867
X_test: Optional[Union[np.ndarray, pd.DataFrame]] = None,
6968
Y_test: Optional[Union[np.ndarray, pd.DataFrame]] = None,
70-
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
69+
resampling_strategy: Union[CrossValTypes, HoldoutTypes] = HoldoutTypes.holdout,
7170
resampling_strategy_args: Optional[Dict[str, Any]] = None,
72-
shuffle: Optional[bool] = True,
7371
seed: Optional[int] = 42,
7472
train_transforms: Optional[torchvision.transforms.Compose] = None,
7573
val_transforms: Optional[torchvision.transforms.Compose] = None,
@@ -92,7 +90,7 @@ def __init__(self,
9290
self.num_features = validator.feature_validator.num_features
9391
self.categories = validator.feature_validator.categories
9492

95-
super().__init__(train_tensors=(X, Y), test_tensors=(X_test, Y_test), shuffle=shuffle,
93+
super().__init__(train_tensors=(X, Y), test_tensors=(X_test, Y_test),
9694
resampling_strategy=resampling_strategy,
9795
resampling_strategy_args=resampling_strategy_args,
9896
seed=seed, train_transforms=train_transforms,

autoPyTorch/datasets/time_series_dataset.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@ def __init__(self,
4141
val: Optional[TIME_SERIES_FORECASTING_INPUT] = None,
4242
resampling_strategy: Union[CrossValTypes, HoldoutTypes] = HoldoutTypes.holdout,
4343
resampling_strategy_args: Optional[Dict[str, Any]] = None,
44-
shuffle: Optional[bool] = False,
4544
seed: Optional[int] = 42,
4645
train_transforms: Optional[torchvision.transforms.Compose] = None,
4746
val_transforms: Optional[torchvision.transforms.Compose] = None,
@@ -69,7 +68,7 @@ def __init__(self,
6968
target_variables=target_variables,
7069
sequence_length=sequence_length,
7170
n_steps=n_steps)
72-
super().__init__(train_tensors=train, val_tensors=val, shuffle=shuffle,
71+
super().__init__(train_tensors=train, val_tensors=val,
7372
resampling_strategy=resampling_strategy, resampling_strategy_args=resampling_strategy_args,
7473
seed=seed,
7574
train_transforms=train_transforms,
@@ -129,15 +128,17 @@ def __init__(self,
129128
_check_time_series_inputs(train=train,
130129
val=val,
131130
task_type="time_series_classification")
132-
super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
131+
resampling_strategy_args = {'shuffle': True}
132+
super().__init__(train_tensors=train, val_tensors=val, resampling_strategy_args=resampling_strategy_args)
133133

134134

135135
class TimeSeriesRegressionDataset(BaseDataset):
136136
def __init__(self, train: Tuple[np.ndarray, np.ndarray], val: Optional[Tuple[np.ndarray, np.ndarray]] = None):
137137
_check_time_series_inputs(train=train,
138138
val=val,
139139
task_type="time_series_regression")
140-
super().__init__(train_tensors=train, val_tensors=val, shuffle=True)
140+
resampling_strategy_args = {'shuffle': True}
141+
super().__init__(train_tensors=train, val_tensors=val, resampling_strategy_args=resampling_strategy_args)
141142

142143

143144
def _check_time_series_inputs(task_type: str,

autoPyTorch/optimizer/smbo.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@
1919
from autoPyTorch.datasets.base_dataset import BaseDataset
2020
from autoPyTorch.datasets.resampling_strategy import (
2121
CrossValTypes,
22-
DEFAULT_RESAMPLING_PARAMETERS,
23-
HoldoutValTypes,
22+
HoldoutTypes,
2423
)
2524
from autoPyTorch.ensemble.ensemble_builder import EnsembleBuilderManager
2625
from autoPyTorch.evaluation.tae import ExecuteTaFuncWithQueue, get_cost_of_crash
@@ -93,7 +92,7 @@ def __init__(self,
9392
pipeline_config: typing.Dict[str, typing.Any],
9493
start_num_run: int = 1,
9594
seed: int = 1,
96-
resampling_strategy: typing.Union[HoldoutValTypes, CrossValTypes] = HoldoutValTypes.holdout_validation,
95+
resampling_strategy: typing.Union[HoldoutTypes, CrossValTypes] = HoldoutTypes.holdout,
9796
resampling_strategy_args: typing.Optional[typing.Dict[str, typing.Any]] = None,
9897
include: typing.Optional[typing.Dict[str, typing.Any]] = None,
9998
exclude: typing.Optional[typing.Dict[str, typing.Any]] = None,
@@ -173,9 +172,7 @@ def __init__(self,
173172

174173
# Evaluation
175174
self.resampling_strategy = resampling_strategy
176-
if resampling_strategy_args is None:
177-
resampling_strategy_args = DEFAULT_RESAMPLING_PARAMETERS[resampling_strategy]
178-
self.resampling_strategy_args = resampling_strategy_args
175+
self.resampling_strategy_args = resampling_strategy_args if resampling_strategy_args is None else {}
179176

180177
# and a bunch of useful limits
181178
self.worst_possible_result = get_cost_of_crash(self.metric)

examples/tabular/40_advanced/example_resampling_strategy.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
import sklearn.model_selection
2525

2626
from autoPyTorch.api.tabular_classification import TabularClassificationTask
27-
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes
27+
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutTypes
2828

2929

3030
if __name__ == '__main__':
@@ -48,11 +48,11 @@
4848
# To maintain logs of the run, set the next two as False
4949
delete_tmp_folder_after_terminate=True,
5050
delete_output_folder_after_terminate=True,
51-
# 'HoldoutValTypes.holdout_validation' with 'val_share': 0.33
51+
# 'HoldoutTypes.holdout' with 'val_share': 0.33
5252
# is the default argument setting for TabularClassificationTask.
5353
# It is explicitly specified in this example for demonstrational
5454
# purpose.
55-
resampling_strategy=HoldoutValTypes.holdout_validation,
55+
resampling_strategy=HoldoutTypes.holdout,
5656
resampling_strategy_args={'val_share': 0.33}
5757
)
5858

@@ -90,7 +90,7 @@
9090
# To maintain logs of the run, set the next two as False
9191
delete_tmp_folder_after_terminate=True,
9292
delete_output_folder_after_terminate=True,
93-
resampling_strategy=CrossValTypes.k_fold_cross_validation,
93+
resampling_strategy=CrossValTypes.k_fold,
9494
resampling_strategy_args={'num_splits': 3}
9595
)
9696

@@ -130,9 +130,9 @@
130130
delete_output_folder_after_terminate=True,
131131
# For demonstration purposes, we use
132132
# Stratified hold out validation. However,
133-
# one can also use CrossValTypes.stratified_k_fold_cross_validation.
134-
resampling_strategy=HoldoutValTypes.stratified_holdout_validation,
135-
resampling_strategy_args={'val_share': 0.33}
133+
# one can also use CrossValTypes.k_fold.
134+
resampling_strategy=HoldoutTypes.holdout,
135+
resampling_strategy_args={'val_share': 0.33, 'stratify': True}
136136
)
137137

138138
############################################################################

0 commit comments

Comments
 (0)