Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Refactor] Made CrossValTypes, HoldoutValTypes to have split functions directly #108

Open
wants to merge 12 commits into
base: master
Choose a base branch
from
Open
Prev Previous commit
Next Next commit
[fix] Fix most test cases
nabenabe0928 committed May 19, 2021
commit 910e7d461402b965ea2ec688a5d7f89665dc3a6f
25 changes: 13 additions & 12 deletions autoPyTorch/datasets/resampling_strategy.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from enum import Enum
from functools import partial
from enum import IntEnum
from typing import List, NamedTuple, Optional, Tuple, Union

import numpy as np
@@ -92,7 +91,7 @@ def time_series(
return splits


class CrossValTypes(Enum):
class CrossValTypes(IntEnum):
"""The type of cross validation
This class is used to specify the cross validation function
@@ -107,11 +106,11 @@ class CrossValTypes(Enum):
>>> for cross_val_type in CrossValTypes:
print(cross_val_type.name, cross_val_type.value)
k_fold_cross_validation functools.partial(<function CrossValFuncs.k_fold_cross_validation at ...>)
time_series <function CrossValFuncs.time_series>
k_fold_cross_validation 100
time_series 101
"""
k_fold_cross_validation = partial(CrossValFuncs.k_fold_cross_validation)
time_series = partial(CrossValFuncs.time_series)
k_fold_cross_validation = 100
time_series = 101

def __call__(
self,
@@ -140,8 +139,9 @@ def __call__(

default_num_splits = _ResamplingStrategyArgs().num_splits
num_splits = num_splits if num_splits is not None else default_num_splits
split_fn = getattr(CrossValFuncs, self.name)

return self.value(
return split_fn(
random_state=random_state if shuffle else None,
num_splits=num_splits,
indices=indices,
@@ -150,7 +150,7 @@ def __call__(
)


class HoldoutValTypes(Enum):
class HoldoutValTypes(IntEnum):
"""The type of holdout validation
This class is used to specify the holdout validation function
@@ -164,7 +164,7 @@ class HoldoutValTypes(Enum):
>>> print(holdout_type.value)
functools.partial(<function HoldoutValTypes.holdout_validation at ...>)
0
>>> for holdout_type in HoldoutValTypes:
print(holdout_type.name)
@@ -174,7 +174,7 @@ class HoldoutValTypes(Enum):
Additionally, HoldoutValTypes.<function> can be called directly.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can you add an example to use it directly?

"""

holdout_validation = partial(HoldoutFuncs.holdout_validation)
holdout_validation = 0

def __call__(
self,
@@ -203,8 +203,9 @@ def __call__(

default_val_share = _ResamplingStrategyArgs().val_share
val_share = val_share if val_share is not None else default_val_share
split_fn = getattr(HoldoutFuncs, self.name)

return self.value(
return split_fn(
random_state=random_state if shuffle else None,
val_share=val_share,
indices=indices,