Skip to content

Commit 7514fc0

Browse files
committed
[test] Add tests for pipeline repr
Since the modifications in tests removed the coverage on pipeline repr, I added tests to increase those parts. Basically, the decrease in the coverage happened due to the usage of dummy pipelines.
1 parent 55e3c3c commit 7514fc0

File tree

4 files changed

+34
-8
lines changed

4 files changed

+34
-8
lines changed

test/test_api/utils.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
from smac.runhistory.runhistory import DataOrigin, RunHistory, RunKey, RunValue, StatusType
44

5-
# from autoPyTorch.constants import REGRESSION_TASKS
5+
from autoPyTorch.constants import REGRESSION_TASKS
66
from autoPyTorch.evaluation.abstract_evaluator import fit_pipeline
7-
# from autoPyTorch.evaluation.pipeline_class_collection import (
8-
# DummyClassificationPipeline,
9-
# DummyRegressionPipeline
10-
# )
7+
from autoPyTorch.evaluation.pipeline_class_collection import (
8+
DummyClassificationPipeline,
9+
DummyRegressionPipeline
10+
)
1111
from autoPyTorch.evaluation.train_evaluator import TrainEvaluator
1212
from autoPyTorch.pipeline.traditional_tabular_classification import TraditionalTabularClassificationPipeline
1313
from autoPyTorch.utils.common import subsampler
@@ -29,15 +29,13 @@ def dummy_traditional_classification(self, time_left: int, func_eval_time_limit_
2929
# Fixtures
3030
# ========
3131
class DummyTrainEvaluator(TrainEvaluator):
32-
"""
3332
def _get_pipeline(self):
3433
if self.task_type in REGRESSION_TASKS:
3534
pipeline = DummyRegressionPipeline(config=1)
3635
else:
3736
pipeline = DummyClassificationPipeline(config=1)
3837

3938
return pipeline
40-
"""
4139

4240
def _fit_and_evaluate_loss(self, pipeline, split_id, train_indices, opt_indices):
4341
X = dict(train_indices=train_indices, val_indices=opt_indices, split_id=split_id, num_run=self.num_run)

test/test_evaluation/test_pipeline_class_collection.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
import pytest
88

9+
import autoPyTorch.pipeline.tabular_regression
910
from autoPyTorch.constants import (
1011
IMAGE_CLASSIFICATION,
1112
REGRESSION_TASKS,
@@ -21,7 +22,6 @@
2122
get_default_pipeline_config,
2223
get_pipeline_class,
2324
)
24-
import autoPyTorch.pipeline.tabular_regression
2525

2626

2727
def test_get_default_pipeline_config():

test/test_pipeline/test_tabular_classification.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -491,3 +491,17 @@ def test_train_pipeline_with_runtime(fit_dictionary_tabular_dummy):
491491

492492
# More than 200 epochs would have pass in 5 seconds for this dataset
493493
assert len(run_summary.performance_tracker['start_time']) > 100
494+
495+
496+
def test_get_pipeline_representation():
497+
pipeline = TabularClassificationPipeline(
498+
dataset_properties={
499+
'numerical_columns': None,
500+
'categorical_columns': None,
501+
'task_type': 'tabular_classification'
502+
}
503+
)
504+
repr = pipeline.get_pipeline_representation()
505+
print(repr)
506+
assert isinstance(repr, dict)
507+
assert all(word in repr for word in ['Preprocessing', 'Estimator'])

test/test_pipeline/test_tabular_regression.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -317,3 +317,17 @@ def test_pipeline_score(fit_dictionary_tabular_dummy):
317317
# we should be able to get a decent score on this dummy data
318318
assert r2_score >= 0.8, f"Pipeline:{pipeline} Config:{config} FitDict: {fit_dictionary_tabular_dummy}, " \
319319
f"{pipeline.named_steps['trainer'].run_summary.performance_tracker['train_metrics']}"
320+
321+
322+
def test_get_pipeline_representation():
323+
pipeline = TabularRegressionPipeline(
324+
dataset_properties={
325+
'numerical_columns': None,
326+
'categorical_columns': None,
327+
'task_type': 'tabular_classification'
328+
}
329+
)
330+
repr = pipeline.get_pipeline_representation()
331+
print(repr)
332+
assert isinstance(repr, dict)
333+
assert all(word in repr for word in ['Preprocessing', 'Estimator'])

0 commit comments

Comments
 (0)