Skip to content

Commit

Permalink
added: test about reproducibility using seed
Browse files Browse the repository at this point in the history
  • Loading branch information
Caparrini committed Jan 23, 2024
1 parent 3a74a7d commit 047f09c
Showing 1 changed file with 26 additions and 0 deletions.
26 changes: 26 additions & 0 deletions mloptimizer/test/test_genoptimizer/test_optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
GradientBoostingOptimizer, SVCOptimizer, XGBClassifierOptimizer, KerasClassifierOptimizer, \
CustomXGBClassifierOptimizer, CatBoostClassifierOptimizer, \
BaseOptimizer
from mloptimizer.evaluation import kfold_score
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, \
balanced_accuracy_score, precision_score, recall_score, \
average_precision_score, log_loss, mean_squared_error, mean_absolute_error, \
Expand Down Expand Up @@ -39,3 +40,28 @@ def test_get_subclasses():
]
assert all([subclass.__name__ in subclasses_names for subclass in subclasses]) and \
len(subclasses) == len(subclasses_names)

Check warning on line 42 in mloptimizer/test/test_genoptimizer/test_optimizers.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

Incorrect type

Expected type 'Sized', got 'Iterable' instead


@pytest.mark.parametrize('optimizer',
(TreeOptimizer, ForestOptimizer,
# ExtraTreesOptimizer, GradientBoostingOptimizer,
XGBClassifierOptimizer,
# SVCOptimizer,KerasClassifierOptimizer
))
def test_reproducibility(optimizer):
X, y = load_iris(return_X_y=True)

Check notice on line 52 in mloptimizer/test/test_genoptimizer/test_optimizers.py

View workflow job for this annotation

GitHub Actions / Qodana Community for Python

PEP 8 naming convention violation

Variable in function should be lowercase
population = 2
generations = 2
seed = 25
distinct_seed = 2
optimizer1 = optimizer(X, y, score_function=balanced_accuracy_score,
eval_function=kfold_score, seed=seed)
result1 = optimizer1.optimize_clf(population=population, generations=generations)
optimizer2 = optimizer(X, y, score_function=balanced_accuracy_score,
eval_function=kfold_score, seed=seed)
result2 = optimizer2.optimize_clf(population=population, generations=generations)
optimizer3 = optimizer(X, y, score_function=balanced_accuracy_score,
eval_function=kfold_score, seed=distinct_seed)
result3 = optimizer3.optimize_clf(population=population, generations=generations)
assert str(result1) == str(result2)
assert str(result1) != str(result3)

0 comments on commit 047f09c

Please sign in to comment.