Skip to content

Commit 55e3c3c

Browse files
committed
[experimental] Increase the coverage
1 parent c8299e2 commit 55e3c3c

File tree

3 files changed

+27
-13
lines changed

3 files changed

+27
-13
lines changed

test/test_api/test_api.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -206,8 +206,8 @@ def _get_estimator(
206206
resampling_strategy,
207207
resampling_strategy_args,
208208
metric,
209-
total_walltime_limit=40,
210-
func_eval_time_limit_secs=10,
209+
total_walltime_limit=18,
210+
func_eval_time_limit_secs=6,
211211
**kwargs
212212
):
213213

@@ -253,6 +253,10 @@ def _check_tabular_task(estimator, X_test, y_test, task_type, resampling_strateg
253253

254254
_check_picklable(estimator, X_test)
255255

256+
representation = estimator.show_models()
257+
assert isinstance(representation, str)
258+
assert all(word in representation for word in ['Weight', 'Preprocessing', 'Estimator'])
259+
256260

257261
# Test
258262
# ====
@@ -314,10 +318,6 @@ def test_tabular_regression(openml_id, resampling_strategy, backend, resampling_
314318
n_successful_runs=1
315319
)
316320

317-
representation = estimator.show_models()
318-
assert isinstance(representation, str)
319-
assert all(word in representation for word in ['Weight', 'Preprocessing', 'Estimator'])
320-
321321

322322
@pytest.mark.parametrize('openml_id', (
323323
1590, # Adult to test NaN in categorical columns
@@ -354,8 +354,8 @@ def test_tabular_input_support(openml_id, backend):
354354
X_train=X_train, y_train=y_train,
355355
X_test=X_test, y_test=y_test,
356356
optimize_metric='accuracy',
357-
total_walltime_limit=150,
358-
func_eval_time_limit_secs=50,
357+
total_walltime_limit=30,
358+
func_eval_time_limit_secs=6,
359359
enable_traditional_pipeline=False,
360360
load_models=False,
361361
)

test/test_api/utils.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
from smac.runhistory.runhistory import DataOrigin, RunHistory, RunKey, RunValue, StatusType
44

5-
from autoPyTorch.constants import REGRESSION_TASKS
5+
# from autoPyTorch.constants import REGRESSION_TASKS
66
from autoPyTorch.evaluation.abstract_evaluator import fit_pipeline
7-
from autoPyTorch.evaluation.pipeline_class_collection import (
8-
DummyClassificationPipeline,
9-
DummyRegressionPipeline
10-
)
7+
# from autoPyTorch.evaluation.pipeline_class_collection import (
8+
# DummyClassificationPipeline,
9+
# DummyRegressionPipeline
10+
# )
1111
from autoPyTorch.evaluation.train_evaluator import TrainEvaluator
1212
from autoPyTorch.pipeline.traditional_tabular_classification import TraditionalTabularClassificationPipeline
1313
from autoPyTorch.utils.common import subsampler
@@ -29,13 +29,15 @@ def dummy_traditional_classification(self, time_left: int, func_eval_time_limit_
2929
# Fixtures
3030
# ========
3131
class DummyTrainEvaluator(TrainEvaluator):
32+
"""
3233
def _get_pipeline(self):
3334
if self.task_type in REGRESSION_TASKS:
3435
pipeline = DummyRegressionPipeline(config=1)
3536
else:
3637
pipeline = DummyClassificationPipeline(config=1)
3738
3839
return pipeline
40+
"""
3941

4042
def _fit_and_evaluate_loss(self, pipeline, split_id, train_indices, opt_indices):
4143
X = dict(train_indices=train_indices, val_indices=opt_indices, split_id=split_id, num_run=self.num_run)

test/test_evaluation/test_pipeline_class_collection.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
get_default_pipeline_config,
2222
get_pipeline_class,
2323
)
24+
import autoPyTorch.pipeline.tabular_regression
2425

2526

2627
def test_get_default_pipeline_config():
@@ -42,6 +43,17 @@ def test_get_pipeline_class(task_type, config):
4243
assert 'Classification' in pipeline_cls.__mro__[0].__name__
4344

4445

46+
@pytest.mark.parametrize('config,ans', (
47+
(1, DummyRegressionPipeline),
48+
('tradition', MyTraditionalTabularRegressionPipeline),
49+
(unittest.mock.Mock(spec=Configuration), autoPyTorch.pipeline.tabular_regression.TabularRegressionPipeline)
50+
))
51+
def test_get_pipeline_class_check_class(config, ans):
52+
task_type = TABULAR_REGRESSION
53+
pipeline_cls = get_pipeline_class(config, task_type)
54+
assert ans is pipeline_cls
55+
56+
4557
def test_get_pipeline_class_errors():
4658
with pytest.raises(RuntimeError):
4759
get_pipeline_class(config=1.5, task_type=TABULAR_CLASSIFICATION)

0 commit comments

Comments
 (0)