Skip to content

Commit 6b577f6

Browse files
committed
[test] Increase the coverage
1 parent b6efccb commit 6b577f6

File tree

3 files changed

+78
-82
lines changed

3 files changed

+78
-82
lines changed

test/test_datasets/test_resampling_strategies.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,15 @@
11
import numpy as np
22

3-
from autoPyTorch.datasets.resampling_strategy import CrossValFuncs, HoldOutFuncs
3+
import pytest
4+
5+
from autoPyTorch.datasets.resampling_strategy import (
6+
CrossValFuncs,
7+
CrossValTypes,
8+
HoldOutFuncs,
9+
HoldoutValTypes,
10+
NoResamplingStrategyTypes,
11+
check_resampling_strategy
12+
)
413

514

615
def test_holdoutfuncs():
@@ -40,3 +49,12 @@ def test_crossvalfuncs():
4049
splits = split.stratified_k_fold_cross_validation(0, 10, X, stratify=y)
4150
assert len(splits) == 10
4251
assert all([0 in y[s[1]] for s in splits])
52+
53+
54+
def test_check_resampling_strategy():
55+
for rs in (CrossValTypes, HoldoutValTypes, NoResamplingStrategyTypes):
56+
for rs_func in rs:
57+
check_resampling_strategy(rs_func)
58+
59+
with pytest.raises(ValueError):
60+
check_resampling_strategy(None)

test/test_evaluation/test_evaluators.py

Lines changed: 50 additions & 81 deletions
Original file line numberDiff line numberDiff line change
@@ -143,34 +143,15 @@ def tearDown(self):
143143
if os.path.exists(self.ev_path):
144144
shutil.rmtree(self.ev_path)
145145

146-
def test_evaluate_loss(self):
147-
D = get_binary_classification_datamanager()
148-
backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch')
149-
backend_api.load_datamanager = lambda: D
150-
fixed_params_dict = self.fixed_params._asdict()
151-
fixed_params_dict.update(backend=backend_api)
152-
evaluator = Evaluator(
153-
queue=multiprocessing.Queue(),
154-
fixed_pipeline_params=FixedPipelineParams(**fixed_params_dict),
155-
evaluator_params=self.eval_params
156-
)
157-
evaluator.splits = None
158-
with pytest.raises(ValueError):
159-
evaluator.evaluate_loss()
160-
161-
@unittest.mock.patch('autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline')
162-
def test_holdout(self, pipeline_mock):
163-
pipeline_mock.fit_dictionary = {'budget_type': 'epochs', 'epochs': 50}
164-
# Binary iris, contains 69 train samples, 31 test samples
165-
D = get_binary_classification_datamanager()
146+
def _get_evaluator(self, pipeline_mock, data):
166147
pipeline_mock.predict_proba.side_effect = \
167148
lambda X, batch_size=None: np.tile([0.6, 0.4], (len(X), 1))
168149
pipeline_mock.side_effect = lambda **kwargs: pipeline_mock
169150
pipeline_mock.get_additional_run_info.return_value = None
170151

171152
_queue = multiprocessing.Queue()
172153
backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch')
173-
backend_api.load_datamanager = lambda: D
154+
backend_api.load_datamanager = lambda: data
174155

175156
fixed_params_dict = self.fixed_params._asdict()
176157
fixed_params_dict.update(backend=backend_api)
@@ -184,56 +165,72 @@ def test_holdout(self, pipeline_mock):
184165

185166
evaluator.evaluate_loss()
186167

168+
return evaluator
169+
170+
def _check_results(self, evaluator, ans):
187171
rval = read_queue(evaluator.queue)
188172
self.assertEqual(len(rval), 1)
189173
result = rval[0]['loss']
190174
self.assertEqual(len(rval[0]), 3)
191175
self.assertRaises(queue.Empty, evaluator.queue.get, timeout=1)
192-
176+
self.assertEqual(result, ans)
193177
self.assertEqual(evaluator._save_to_backend.call_count, 1)
194-
self.assertEqual(result, 0.5652173913043479)
195-
self.assertEqual(pipeline_mock.fit.call_count, 1)
196-
# 3 calls because of train, holdout and test set
197-
self.assertEqual(pipeline_mock.predict_proba.call_count, 3)
198-
call_args = evaluator._save_to_backend.call_args
199-
self.assertEqual(call_args[0][0].shape[0], len(D.splits[0][1]))
200-
self.assertIsNone(call_args[0][1])
201-
self.assertEqual(call_args[0][2].shape[0], D.test_tensors[1].shape[0])
202-
self.assertEqual(evaluator.pipelines[0].fit.call_count, 1)
203178

204-
@unittest.mock.patch('autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline')
205-
def test_cv(self, pipeline_mock):
206-
D = get_binary_classification_datamanager(resampling_strategy=CrossValTypes.k_fold_cross_validation)
179+
def _check_whether_save_y_opt_is_correct(self, resampling_strategy, ans):
180+
backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch')
181+
D = get_binary_classification_datamanager(resampling_strategy)
182+
backend_api.load_datamanager = lambda: D
183+
fixed_params_dict = self.fixed_params._asdict()
184+
fixed_params_dict.update(backend=backend_api, save_y_opt=True)
185+
evaluator = Evaluator(
186+
queue=multiprocessing.Queue(),
187+
fixed_pipeline_params=FixedPipelineParams(**fixed_params_dict),
188+
evaluator_params=self.eval_params
189+
)
190+
assert evaluator.fixed_pipeline_params.save_y_opt == ans
207191

208-
pipeline_mock.predict_proba.side_effect = \
209-
lambda X, batch_size=None: np.tile([0.6, 0.4], (len(X), 1))
210-
pipeline_mock.side_effect = lambda **kwargs: pipeline_mock
211-
pipeline_mock.get_additional_run_info.return_value = None
192+
def test_whether_save_y_opt_is_correct_for_no_resampling(self):
193+
self._check_whether_save_y_opt_is_correct(NoResamplingStrategyTypes.no_resampling, False)
212194

213-
_queue = multiprocessing.Queue()
195+
def test_whether_save_y_opt_is_correct_for_resampling(self):
196+
self._check_whether_save_y_opt_is_correct(CrossValTypes.k_fold_cross_validation, True)
197+
198+
def test_evaluate_loss(self):
199+
D = get_binary_classification_datamanager()
214200
backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch')
215201
backend_api.load_datamanager = lambda: D
216-
217202
fixed_params_dict = self.fixed_params._asdict()
218203
fixed_params_dict.update(backend=backend_api)
219204
evaluator = Evaluator(
220-
queue=_queue,
205+
queue=multiprocessing.Queue(),
221206
fixed_pipeline_params=FixedPipelineParams(**fixed_params_dict),
222207
evaluator_params=self.eval_params
223208
)
224-
evaluator._save_to_backend = unittest.mock.Mock(spec=evaluator._save_to_backend)
225-
evaluator._save_to_backend.return_value = True
209+
evaluator.splits = None
210+
with pytest.raises(ValueError):
211+
evaluator.evaluate_loss()
226212

227-
evaluator.evaluate_loss()
213+
@unittest.mock.patch('autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline')
214+
def test_holdout(self, pipeline_mock):
215+
D = get_binary_classification_datamanager()
216+
evaluator = self._get_evaluator(pipeline_mock, D)
217+
self._check_results(evaluator, ans=0.5652173913043479)
228218

229-
rval = read_queue(evaluator.queue)
230-
self.assertEqual(len(rval), 1)
231-
result = rval[0]['loss']
232-
self.assertEqual(len(rval[0]), 3)
233-
self.assertRaises(queue.Empty, evaluator.queue.get, timeout=1)
219+
self.assertEqual(pipeline_mock.fit.call_count, 1)
220+
# 3 calls because of train, holdout and test set
221+
self.assertEqual(pipeline_mock.predict_proba.call_count, 3)
222+
call_args = evaluator._save_to_backend.call_args
223+
self.assertEqual(call_args[0][0].shape[0], len(D.splits[0][1]))
224+
self.assertIsNone(call_args[0][1])
225+
self.assertEqual(call_args[0][2].shape[0], D.test_tensors[1].shape[0])
226+
self.assertEqual(evaluator.pipelines[0].fit.call_count, 1)
227+
228+
@unittest.mock.patch('autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline')
229+
def test_cv(self, pipeline_mock):
230+
D = get_binary_classification_datamanager(resampling_strategy=CrossValTypes.k_fold_cross_validation)
231+
evaluator = self._get_evaluator(pipeline_mock, D)
232+
self._check_results(evaluator, ans=0.463768115942029)
234233

235-
self.assertEqual(evaluator._save_to_backend.call_count, 1)
236-
self.assertEqual(result, 0.463768115942029)
237234
self.assertEqual(pipeline_mock.fit.call_count, 5)
238235
# 15 calls because of the training, holdout and
239236
# test set (3 sets x 5 folds = 15)
@@ -251,38 +248,10 @@ def test_cv(self, pipeline_mock):
251248

252249
@unittest.mock.patch('autoPyTorch.pipeline.tabular_classification.TabularClassificationPipeline')
253250
def test_no_resampling(self, pipeline_mock):
254-
pipeline_mock.fit_dictionary = {'budget_type': 'epochs', 'epochs': 10}
255-
# Binary iris, contains 69 train samples, 31 test samples
256251
D = get_binary_classification_datamanager(NoResamplingStrategyTypes.no_resampling)
257-
pipeline_mock.predict_proba.side_effect = \
258-
lambda X, batch_size=None: np.tile([0.6, 0.4], (len(X), 1))
259-
pipeline_mock.side_effect = lambda **kwargs: pipeline_mock
260-
pipeline_mock.get_additional_run_info.return_value = None
261-
262-
_queue = multiprocessing.Queue()
263-
backend_api = create(self.tmp_dir, self.output_dir, prefix='autoPyTorch')
264-
backend_api.load_datamanager = lambda: D
265-
266-
fixed_params_dict = self.fixed_params._asdict()
267-
fixed_params_dict.update(backend=backend_api)
268-
evaluator = Evaluator(
269-
queue=_queue,
270-
fixed_pipeline_params=FixedPipelineParams(**fixed_params_dict),
271-
evaluator_params=self.eval_params
272-
)
273-
evaluator._save_to_backend = unittest.mock.Mock(spec=evaluator._save_to_backend)
274-
evaluator._save_to_backend.return_value = True
252+
evaluator = self._get_evaluator(pipeline_mock, D)
253+
self._check_results(evaluator, ans=0.5806451612903225)
275254

276-
evaluator.evaluate_loss()
277-
278-
rval = read_queue(evaluator.queue)
279-
self.assertEqual(len(rval), 1)
280-
result = rval[0]['loss']
281-
self.assertEqual(len(rval[0]), 3)
282-
self.assertRaises(queue.Empty, evaluator.queue.get, timeout=1)
283-
284-
self.assertEqual(evaluator._save_to_backend.call_count, 1)
285-
self.assertEqual(result, 0.5806451612903225)
286255
self.assertEqual(pipeline_mock.fit.call_count, 1)
287256
# 2 calls because of train and test set
288257
self.assertEqual(pipeline_mock.predict_proba.call_count, 2)

test/test_evaluation/test_tae.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,15 @@ def test_check_run_info(self):
104104
with pytest.raises(ValueError):
105105
taq.run_wrapper(run_info)
106106

107+
def test_check_and_get_default_budget(self):
108+
taq = _create_taq()
109+
budget = taq._check_and_get_default_budget()
110+
assert isinstance(budget, float)
111+
112+
taq.fixed_pipeline_params = taq.fixed_pipeline_params._replace(budget_type='test')
113+
with pytest.raises(ValueError):
114+
taq._check_and_get_default_budget()
115+
107116
def test_cutoff_update_in_run_wrapper(self):
108117
taq = _create_taq()
109118
run_info = RunInfo(

0 commit comments

Comments
 (0)