Skip to content

Commit

Permalink
small fix of main.finetune method and its unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
technocreep committed Oct 16, 2024
1 parent 3620311 commit 257b24a
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 3 deletions.
5 changes: 3 additions & 2 deletions fedot_ind/api/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,10 +232,10 @@ def predict_proba(self,
Args:
predict_mode: ``default='default'``. Defines the mode of prediction. Could be 'default' or 'probs'.
predict_data: tuple with test_features and test_target
calibrate_probs: ``default=False``. If True, calibrate probabilities
Returns:
the array with prediction probabilities
:param calibrate_probs:
"""
self.predict_data = self._process_input_data(predict_data)
Expand All @@ -262,7 +262,8 @@ def finetune(self,

train_data = self._process_input_data(train_data) if \
not self.api_controller.condition_check.input_data_is_fedot_type(train_data) else train_data
tuning_params = ApiConverter.tuning_params_is_none(tuning_params)
if tuning_params is None:
tuning_params = ApiConverter.tuning_params_is_none(tuning_params)
tuning_params['metric'] = FEDOT_TUNING_METRICS[self.config_dict['problem']]

for tuner_name, tuner_type in FEDOT_TUNER_STRATEGY.items():
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/api/main/test_api_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def test_finetune(fedot_industrial_classification):
industrial = fedot_industrial_classification
data = univariate_clf_data()
industrial.fit(data)
industrial.finetune(data)
industrial.finetune(train_data=data, tuning_params={'tuning_timeout': 0.1})
assert industrial.solver is not None


Expand Down

0 comments on commit 257b24a

Please sign in to comment.