import os
import uuid
from abc import ABCMeta
- from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast
+ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union

import numpy as np

import torchvision

from autoPyTorch.constants import CLASSIFICATION_OUTPUTS, STRING_TO_OUTPUT_TYPES
- from autoPyTorch.datasets.resampling_strategy import (
-     CrossValFunc,
-     CrossValFuncs,
-     CrossValTypes,
-     DEFAULT_RESAMPLING_PARAMETERS,
-     HoldOutFunc,
-     HoldOutFuncs,
-     HoldoutValTypes
- )
+ from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutTypes
from autoPyTorch.utils.common import FitRequirement

BaseDatasetInputType = Union[Tuple[np.ndarray, np.ndarray], Dataset]
@@ -77,7 +69,7 @@ def __init__(
        dataset_name: Optional[str] = None,
        val_tensors: Optional[BaseDatasetInputType] = None,
        test_tensors: Optional[BaseDatasetInputType] = None,
-         resampling_strategy: Union[CrossValTypes, HoldoutValTypes] = HoldoutValTypes.holdout_validation,
+         resampling_strategy: Union[CrossValTypes, HoldoutTypes] = HoldoutTypes.holdout,
        resampling_strategy_args: Optional[Dict[str, Any]] = None,
        shuffle: Optional[bool] = True,
        seed: Optional[int] = 42,
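
For context, a minimal usage sketch of the changed signature. `MyDataset` is a hypothetical stand-in for a concrete subclass of the base class defined in this file, the random arrays are placeholder data, and the `k_fold_cross_validation` member name is an assumption; only `HoldoutTypes.holdout` and the `val_share`/`num_splits` argument keys come from this diff.

# Hypothetical sketch; `MyDataset` and the CrossValTypes member name are assumptions.
import numpy as np

from autoPyTorch.datasets.base_dataset import BaseDataset
from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutTypes


class MyDataset(BaseDataset):
    """Placeholder subclass used only for this example."""


X = np.random.rand(100, 4)
y = np.random.randint(0, 2, size=100)

# New default strategy: HoldoutTypes.holdout (previously HoldoutValTypes.holdout_validation)
holdout_data = MyDataset(
    train_tensors=(X, y),
    resampling_strategy=HoldoutTypes.holdout,
    resampling_strategy_args={'val_share': 0.33},
)

# Cross-validation instead of holdout
cv_data = MyDataset(
    train_tensors=(X, y),
    resampling_strategy=CrossValTypes.k_fold_cross_validation,
    resampling_strategy_args={'num_splits': 5},
)
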
@@ -94,14 +86,14 @@ def __init__(
                validation data
            test_tensors (An optional tuple of objects that have a __len__ and a __getitem__ attribute):
                test data
-             resampling_strategy (Union[CrossValTypes, HoldoutValTypes]),
-                 (default=HoldoutValTypes.holdout_validation):
+             resampling_strategy (Union[CrossValTypes, HoldoutTypes]),
+                 (default=HoldoutTypes.holdout):
                strategy to split the training data.
            resampling_strategy_args (Optional[Dict[str, Any]]): arguments
                required for the chosen resampling strategy. If None, uses
                the default values provided in DEFAULT_RESAMPLING_PARAMETERS
                in ```datasets/resampling_strategy.py```.
-             shuffle: Whether to shuffle the data before performing splits
+             shuffle: Whether to shuffle the data when performing splits
            seed (int), (default=1): seed to be used for reproducibility.
            train_transforms (Optional[torchvision.transforms.Compose]):
                Additional Transforms to be applied to the training data
@@ -116,12 +108,12 @@ def __init__(
        if not hasattr(train_tensors[0], 'shape'):
            type_check(train_tensors, val_tensors)
        self.train_tensors, self.val_tensors, self.test_tensors = train_tensors, val_tensors, test_tensors
-         self.cross_validators: Dict[str, CrossValFunc] = {}
-         self.holdout_validators: Dict[str, HoldOutFunc] = {}
        self.random_state = np.random.RandomState(seed=seed)
        self.shuffle = shuffle
        self.resampling_strategy = resampling_strategy
        self.resampling_strategy_args = resampling_strategy_args
+         self.is_stratify = (self.resampling_strategy_args or {}).get('stratify', False)
+
        self.task_type: Optional[str] = None
        self.issparse: bool = issparse(self.train_tensors[0])
        self.input_shape: Tuple[int] = self.train_tensors[0].shape[1:]
@@ -137,9 +129,6 @@ def __init__(
        # TODO: Look for a criteria to define small enough to preprocess
        self.is_small_preprocess = True

-         # Make sure cross validation splits are created once
-         self.cross_validators = CrossValFuncs.get_cross_validators(*CrossValTypes)
-         self.holdout_validators = HoldOutFuncs.get_holdout_validators(*HoldoutValTypes)
        self.splits = self.get_splits_from_resampling_strategy()

        # We also need to be able to transform the data, be it for pre-processing
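
The splits computed in `__init__` above end up on `self.splits` and are consumed per split id by `get_dataset_for_training` (shown further down). A brief follow-up to the earlier sketch, reusing the hypothetical `holdout_data` object:

# Continues the hypothetical `holdout_data` example from above.
train_indices, val_indices = holdout_data.splits[0]  # [train_indices, val_indices] format
train_set, val_set = holdout_data.get_dataset_for_training(split_id=0)  # a (train, val) Dataset pair, per the annotation below
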
@@ -205,7 +194,30 @@ def __len__(self) -> int:
        return self.train_tensors[0].shape[0]

    def _get_indices(self) -> np.ndarray:
-         return self.random_state.permutation(len(self)) if self.shuffle else np.arange(len(self))
+         return np.arange(len(self))
+
+     def _process_resampling_strategy_args(self) -> None:
+         """Validate `resampling_strategy` and `resampling_strategy_args` before splitting."""
+         if not any(isinstance(self.resampling_strategy, val_type)
+                    for val_type in [HoldoutTypes, CrossValTypes]):
+             raise ValueError(f"resampling_strategy {self.resampling_strategy} is not supported.")
+
+         if self.resampling_strategy_args is not None and \
+                 not isinstance(self.resampling_strategy_args, dict):
+             raise TypeError("resampling_strategy_args must be dict or None,"
+                             f" but got {type(self.resampling_strategy_args)}")
+
+         # `resampling_strategy_args` may be None; normalise it so the lookups here
+         # and in `get_splits_from_resampling_strategy` do not fail.
+         if self.resampling_strategy_args is None:
+             self.resampling_strategy_args = {}
+
+         val_share = self.resampling_strategy_args.get('val_share', None)
+         num_splits = self.resampling_strategy_args.get('num_splits', None)
+
+         if val_share is not None and (val_share < 0 or val_share > 1):
+             raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")
+
+         if num_splits is not None:
+             if not isinstance(num_splits, int):
+                 raise ValueError(f"`num_splits` must be an integer, got {num_splits}.")
+             elif num_splits <= 0:
+                 raise ValueError(f"`num_splits` must be a positive integer, got {num_splits}.")

    def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]]]:
        """
@@ -214,100 +226,33 @@ def get_splits_from_resampling_strategy(self) -> List[Tuple[List[int], List[int]
        Returns
            (List[Tuple[List[int], List[int]]]): splits in the [train_indices, val_indices] format
        """
-         splits = []
-         if isinstance(self.resampling_strategy, HoldoutValTypes):
-             val_share = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
-                 'val_share', None)
-             if self.resampling_strategy_args is not None:
-                 val_share = self.resampling_strategy_args.get('val_share', val_share)
-             splits.append(
-                 self.create_holdout_val_split(
-                     holdout_val_type=self.resampling_strategy,
-                     val_share=val_share,
-                 )
+         # check if the requirements are met and if we can get splits
+         self._process_resampling_strategy_args()
+
+         labels_to_stratify = self.train_tensors[-1] if self.is_stratify else None
+
+         if isinstance(self.resampling_strategy, HoldoutTypes):
+             val_share = self.resampling_strategy_args['val_share']
+
+             return self.resampling_strategy(
+                 random_state=self.random_state,
+                 val_share=val_share,
+                 shuffle=self.shuffle,
+                 indices=self._get_indices(),
+                 labels_to_stratify=labels_to_stratify
            )
        elif isinstance(self.resampling_strategy, CrossValTypes):
-             num_splits = DEFAULT_RESAMPLING_PARAMETERS[self.resampling_strategy].get(
-                 'num_splits', None)
-             if self.resampling_strategy_args is not None:
-                 num_splits = self.resampling_strategy_args.get('num_splits', num_splits)
-             # Create the split if it was not created before
-             splits.extend(
-                 self.create_cross_val_splits(
-                     cross_val_type=self.resampling_strategy,
-                     num_splits=cast(int, num_splits),
-                 )
+             num_splits = self.resampling_strategy_args['num_splits']
+
+             return self.resampling_strategy(
+                 random_state=self.random_state,
+                 num_splits=int(num_splits),
+                 shuffle=self.shuffle,
+                 indices=self._get_indices(),
+                 labels_to_stratify=labels_to_stratify
            )
        else:
            raise ValueError(f"Unsupported resampling strategy={self.resampling_strategy}")
-         return splits
-
-     def create_cross_val_splits(
-         self,
-         cross_val_type: CrossValTypes,
-         num_splits: int
-     ) -> List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]:
-         """
-         This function creates the cross validation split for the given task.
-
-         It is done once per dataset to have comparable results among pipelines
-         Args:
-             cross_val_type (CrossValTypes):
-             num_splits (int): number of splits to be created
-
-         Returns:
-             (List[Tuple[Union[List[int], np.ndarray], Union[List[int], np.ndarray]]]):
-                 list containing 'num_splits' splits.
-         """
-         # Create just the split once
-         # This is gonna be called multiple times, because the current dataset
-         # is being used for multiple pipelines. That is, to be efficient with memory
-         # we dump the dataset to memory and read it on a need basis. So this function
-         # should be robust against multiple calls, and it does so by remembering the splits
-         if not isinstance(cross_val_type, CrossValTypes):
-             raise NotImplementedError(f'The selected `cross_val_type` "{cross_val_type}" is not implemented.')
-         kwargs = {}
-         if cross_val_type.is_stratified():
-             # we need additional information about the data for stratification
-             kwargs["stratify"] = self.train_tensors[-1]
-         splits = self.cross_validators[cross_val_type.name](
-             self.random_state, num_splits, self._get_indices(), **kwargs)
-         return splits
-
-     def create_holdout_val_split(
-         self,
-         holdout_val_type: HoldoutValTypes,
-         val_share: float,
-     ) -> Tuple[np.ndarray, np.ndarray]:
-         """
-         This function creates the holdout split for the given task.
-
-         It is done once per dataset to have comparable results among pipelines
-         Args:
-             holdout_val_type (HoldoutValTypes):
-             val_share (float): share of the validation data
-
-         Returns:
-             (Tuple[np.ndarray, np.ndarray]): Tuple containing (train_indices, val_indices)
-         """
-         if holdout_val_type is None:
-             raise ValueError(
-                 '`val_share` specified, but `holdout_val_type` not specified.'
-             )
-         if self.val_tensors is not None:
-             raise ValueError(
-                 '`val_share` specified, but the Dataset was a given a pre-defined split at initialization already.')
-         if val_share < 0 or val_share > 1:
-             raise ValueError(f"`val_share` must be between 0 and 1, got {val_share}.")
-         if not isinstance(holdout_val_type, HoldoutValTypes):
-             raise NotImplementedError(f'The specified `holdout_val_type` "{holdout_val_type}" is not supported.')
-         kwargs = {}
-         if holdout_val_type.is_stratified():
-             # we need additional information about the data for stratification
-             kwargs["stratify"] = self.train_tensors[-1]
-         train, val = self.holdout_validators[holdout_val_type.name](
-             self.random_state, val_share, self._get_indices(), **kwargs)
-         return train, val

    def get_dataset_for_training(self, split_id: int) -> Tuple[Dataset, Dataset]:
        """