From d06ee7e2ce451dfc07d7dcd79aba87a4296441af Mon Sep 17 00:00:00 2001 From: dengdifan Date: Thu, 13 Apr 2023 13:50:55 +0200 Subject: [PATCH] allow NoResamplingStrategyTypes to be passed to TrainEvaluator --- autoPyTorch/evaluation/train_evaluator.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/autoPyTorch/evaluation/train_evaluator.py b/autoPyTorch/evaluation/train_evaluator.py index e88c8eaca..e20cc6833 100644 --- a/autoPyTorch/evaluation/train_evaluator.py +++ b/autoPyTorch/evaluation/train_evaluator.py @@ -14,7 +14,7 @@ CLASSIFICATION_TASKS, MULTICLASSMULTIOUTPUT ) -from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes +from autoPyTorch.datasets.resampling_strategy import CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes from autoPyTorch.evaluation.abstract_evaluator import ( AbstractEvaluator, fit_and_suppress_warnings @@ -153,10 +153,10 @@ def __init__(self, backend: Backend, queue: Queue, search_space_updates=search_space_updates ) - if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes)): + if not isinstance(self.resampling_strategy, (CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes)): raise ValueError( f'resampling_strategy for TrainEvaluator must be in ' - f'(CrossValTypes, HoldoutValTypes), but got {self.resampling_strategy}' + f'(CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes), but got {self.resampling_strategy}' ) self.num_folds: int = len(self.splits)