Skip to content

Commit 36cef27

Browse files
committed
[refactor] Remove get_cross_validators and get_holdout_validators
Since each split function can be called directly from CrossValTypes and HoldoutValTypes, I removed these two functions.
1 parent 94df1e3 commit 36cef27

File tree

3 files changed

+230
-308
lines changed

3 files changed

+230
-308
lines changed

autoPyTorch/datasets/base_dataset.py

Lines changed: 54 additions & 109 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import os
22
import uuid
33
from abc import ABCMeta
4-
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
4+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
55

66
import numpy as np
77

@@ -14,15 +14,7 @@
1414
import torchvision
1515

1616
from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
17-
from autoPyTorch.datasets.resampling_strategy import (
18-
CrossValFunc,
19-
CrossValFuncs,
20-
CrossValTypes,
21-
DEFAULT_RESAMPLING_PARAMETERS,
22-
HoldOutFunc,
23-
HoldOutFuncs,
24-
HoldoutValTypes
25-
)
17+
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutTypes
2618
from autoPyTorch.utils.common import FitRequirement
2719

2820
BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
@@ -77,7 +69,7 @@ def __init__(
7769
dataset_name: Optional[str] = None,
7870
val_tensors: Optional[BaseDatasetInputType] = None,
7971
test_tensors: Optional[BaseDatasetInputType] = None,
80-
resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
72+
resampling_strategy: Union[CrossValTypes, HoldoutTypes] = HoldoutTypes.holdout,
8173
resampling_strategy_args: Optional[Dict[str, Any]] = None,
8274
shuffle: Optional[bool] = True,
8375
seed: Optional[int] = 42,
@@ -94,14 +86,14 @@ def __init__(
9486
validation data
9587
test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute):
9688
test data
97-
resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
98-
(default=HoldoutValTypes.holdout_validation):
89+
resampling_strategy (Union[CrossValTypes, HoldoutTypes]),
90+
(default=HoldoutTypes.holdout):
9991
strategy to split the training data.
10092
resampling_strategy_args (Optional[Dict[str, Any]]): arguments
10193
required for the chosen resampling strategy. If None, uses
10294
the default values provided in DEFAULT_RESAMPLING_PARAMETERS
10395
in ```datasets/resampling_strategy.py```.
104-
shuffle: Whether to shuffle the data before performing splits
96+
shuffle: Whether to shuffle the data when performing splits
10597
seed (int), (default=42): seed to be used for reproducibility.
10698
train_transforms (Optional[torchvision.transforms.Compose]):
10799
Additional Transforms to be applied to the training data
@@ -116,12 +108,12 @@ def __init__(
116108
if not hasattr(train_tensors[0], 'shape'):
117109
type_check(train_tensors, val_tensors)
118110
self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
119-
self.cross_validators: Dict[str, CrossValFunc] = {}
120-
self.holdout_validators: Dict[str, HoldOutFunc] = {}
121111
self.random_state = np.random.RandomState(seed=seed)
122112
self.shuffle = shuffle
123113
self.resampling_strategy = resampling_strategy
124114
self.resampling_strategy_args = resampling_strategy_args
115+
self.is_stratify = self.resampling_strategy.get('stratify', False)
116+
125117
self.task_type: Optional[str] = None
126118
self.issparse: bool = issparse(self.train_tensors[0])
127119
self.input_shape: Tuple[int] = self.train_tensors[0].shape[1:]
@@ -137,9 +129,6 @@ def __init__(
137129
# TODO: Look for a criteria to define small enough to preprocess
138130
self.is_small_preprocess = True
139131

140-
# Make sure cross validation splits are created once
141-
self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
142-
self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
143132
self.splits = self.get_splits_from_resampling_strategy()
144133

145134
# We also need to be able to transform the data, be it for pre-processing
@@ -205,7 +194,30 @@ def __len__(self) -> int:
205194
return self.train_tensors[0].shape[0]
206195

207196
def _get_indices(self) -> np.ndarray:
208-
return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))
197+
return np.arange(len(self))
198+
199+
def _process_resampling_strategy_args(self) -> None:
    """
    Validate the resampling strategy and its arguments before splitting.

    Only reads `self.resampling_strategy` and `self.resampling_strategy_args`;
    neither attribute is modified.

    Raises:
        ValueError: if `resampling_strategy` is not an instance of
            HoldoutTypes or CrossValTypes, if `val_share` lies outside
            [0, 1], or if `num_splits` is not a positive integer.
        TypeError: if `resampling_strategy_args` is neither a dict nor None.
    """
    if not isinstance(self.resampling_strategy, (HoldoutTypes, CrossValTypes)):
        raise ValueError(f"resampling_strategy {self.resampling_strategy} is not supported.")

    if self.resampling_strategy_args is None:
        # No arguments were supplied, so there is nothing to validate.
        # Returning here avoids calling `.get` on None below.
        return

    if not isinstance(self.resampling_strategy_args, dict):
        raise TypeError("resampling_strategy_args must be dict or None,"
                        f" but got {type(self.resampling_strategy_args)}")

    val_share = self.resampling_strategy_args.get('val_share', None)
    num_splits = self.resampling_strategy_args.get('num_splits', None)

    if val_share is not None and (val_share < 0 or val_share > 1):
        raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")

    if num_splits is not None:
        # Check the type first: comparing a non-numeric value with 0 would
        # raise a TypeError instead of the intended ValueError.
        if not isinstance(num_splits, int):
            raise ValueError(f"`num_splits` must be an integer, got {num_splits}.")
        if num_splits <= 0:
            raise ValueError(f"`num_splits` must be a positive integer, got {num_splits}.")
209221

210222
def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
211223
"""
@@ -214,100 +226,33 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]
214226
Returns
215227
(List[Tuple[List[int], List[int]]]): splits in the [train_indices, val_indices] format
216228
"""
217-
splits = []
218-
if isinstance(self.resampling_strategy, HoldoutValTypes):
219-
val_share = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
220-
'val_share', None)
221-
if self.resampling_strategy_args is not None:
222-
val_share = self.resampling_strategy_args.get('val_share', val_share)
223-
splits.append(
224-
self.create_holdout_val_split(
225-
holdout_val_type=self.resampling_strategy,
226-
val_share=val_share,
227-
)
229+
# check if the requirements are met and if we can get splits
230+
self._process_resampling_strategy_args()
231+
232+
labels_to_stratify = self.train_tensors[-1] if self.is_stratify else None
233+
234+
if isinstance(self.resampling_strategy, HoldoutTypes):
235+
val_share = self.resampling_strategy_args['val_share']
236+
237+
return self.resampling_strategy(
238+
random_state=self.random_state,
239+
val_share=val_share,
240+
shuffle=self.shuffle,
241+
indices=self._get_indices(),
242+
labels_to_stratify=labels_to_stratify
228243
)
229244
elif isinstance(self.resampling_strategy, CrossValTypes):
230-
num_splits = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
231-
'num_splits', None)
232-
if self.resampling_strategy_args is not None:
233-
num_splits = self.resampling_strategy_args.get('num_splits', num_splits)
234-
# Create the split if it was not created before
235-
splits.extend(
236-
self.create_cross_val_splits(
237-
cross_val_type=self.resampling_strategy,
238-
num_splits=cast(int, num_splits),
239-
)
245+
num_splits = self.resampling_strategy_args['num_splits']
246+
247+
return self.create_cross_val_splits(
248+
random_state=self.random_state,
249+
num_splits=int(num_splits),
250+
shuffle=self.shuffle,
251+
indices=self._get_indices(),
252+
labels_to_stratify=labels_to_stratify
240253
)
241254
else:
242255
raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
243-
return splits
244-
245-
def create_cross_val_splits(
246-
self,
247-
cross_val_type: CrossValTypes,
248-
num_splits: int
249-
) -> List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]:
250-
"""
251-
This function creates the cross validation split for the given task.
252-
253-
It is done once per dataset to have comparable results among pipelines
254-
Args:
255-
cross_val_type (CrossValTypes):
256-
num_splits (int): number of splits to be created
257-
258-
Returns:
259-
(List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]):
260-
list containing 'num_splits' splits.
261-
"""
262-
# Create just the split once
263-
# This is gonna be called multiple times, because the current dataset
264-
# is being used for multiple pipelines. That is, to be efficient with memory
265-
# we dump the dataset to memory and read it on a need basis. So this function
266-
# should be robust against multiple calls, and it does so by remembering the splits
267-
if not isinstance(cross_val_type, CrossValTypes):
268-
raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.')
269-
kwargs = {}
270-
if cross_val_type.is_stratified():
271-
# we need additional information about the data for stratification
272-
kwargs["stratify"] = self.train_tensors[-1]
273-
splits = self.cross_validators[cross_val_type.name](
274-
self.random_state, num_splits, self._get_indices(), **kwargs)
275-
return splits
276-
277-
def create_holdout_val_split(
278-
self,
279-
holdout_val_type: HoldoutValTypes,
280-
val_share: float,
281-
) -> Tuple[np.ndarray, np.ndarray]:
282-
"""
283-
This function creates the holdout split for the given task.
284-
285-
It is done once per dataset to have comparable results among pipelines
286-
Args:
287-
holdout_val_type (HoldoutValTypes):
288-
val_share (float): share of the validation data
289-
290-
Returns:
291-
(Tuple[np.ndarray, np.ndarray]): Tuple containing (train_indices, val_indices)
292-
"""
293-
if holdout_val_type is None:
294-
raise ValueError(
295-
'`val_share` specified, but `holdout_val_type` not specified.'
296-
)
297-
if self.val_tensors is not None:
298-
raise ValueError(
299-
'`val_share` specified, but the Dataset was a given a pre-defined split at initialization already.')
300-
if val_share < 0 or val_share > 1:
301-
raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")
302-
if not isinstance(holdout_val_type, HoldoutValTypes):
303-
raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.')
304-
kwargs = {}
305-
if holdout_val_type.is_stratified():
306-
# we need additional information about the data for stratification
307-
kwargs["stratify"] = self.train_tensors[-1]
308-
train, val = self.holdout_validators[holdout_val_type.name](
309-
self.random_state, val_share, self._get_indices(), **kwargs)
310-
return train, val
311256

312257
def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
313258
"""

0 commit comments

Comments
 (0)