From 121e1347c95389d3534955056d752838dc76d244 Mon Sep 17 00:00:00 2001 From: Lukas Drews Date: Thu, 11 Apr 2024 22:08:05 +0200 Subject: [PATCH] black formatting for tests --- tests/test_LocalClassifierPerLevel.py | 105 +++++-- tests/test_LocalClassifierPerNode.py | 168 ++++++++--- tests/test_LocalClassifierPerParentNode.py | 123 ++++++-- tests/test_ProbabilityCombiner.py | 204 ++++++++++--- tests/test_calibration.py | 336 +++++++++++++++------ tests/test_metrics.py | 93 ++++-- 6 files changed, 756 insertions(+), 273 deletions(-) diff --git a/tests/test_LocalClassifierPerLevel.py b/tests/test_LocalClassifierPerLevel.py index 6286a06d..880fc904 100644 --- a/tests/test_LocalClassifierPerLevel.py +++ b/tests/test_LocalClassifierPerLevel.py @@ -21,7 +21,9 @@ def test_sklearn_compatible_estimator(estimator, check): @pytest.fixture def digraph_logistic_regression(): - digraph = LocalClassifierPerLevel(local_classifier=LogisticRegression(), calibration_method="ivap") + digraph = LocalClassifierPerLevel( + local_classifier=LogisticRegression(), calibration_method="ivap" + ) digraph.hierarchy_ = nx.DiGraph([("a", "b"), ("a", "c")]) digraph.y_ = np.array([["a", "b"], ["a", "c"]]) digraph.X_ = np.array([[1, 2], [3, 4]]) @@ -46,14 +48,12 @@ def test_initialize_local_classifiers(digraph_logistic_regression): LogisticRegression, ) + def test_initialize_local_calibrators(digraph_logistic_regression): digraph_logistic_regression._initialize_local_classifiers() digraph_logistic_regression._initialize_local_calibrators() for calibrator in digraph_logistic_regression.local_calibrators_: - assert isinstance( - calibrator, - _Calibrator - ) + assert isinstance(calibrator, _Calibrator) def test_fit_digraph(digraph_logistic_regression): @@ -71,6 +71,7 @@ def test_fit_digraph(digraph_logistic_regression): pytest.fail(repr(e)) assert 1 + def test_calibrate_digraph(digraph_logistic_regression): classifiers = [ LogisticRegression(), @@ -80,7 +81,10 @@ def test_calibrate_digraph(digraph_logistic_regression): digraph_logistic_regression.local_classifiers_ = classifiers digraph_logistic_regression._fit_digraph(local_mode=True) - calibrators = [_Calibrator(classifier) for classifier in digraph_logistic_regression.local_classifiers_] + calibrators = [ + _Calibrator(classifier) + for classifier in digraph_logistic_regression.local_classifiers_ + ] digraph_logistic_regression.local_calibrators_ = calibrators digraph_logistic_regression._calibrate_digraph(local_mode=True) @@ -91,7 +95,6 @@ def test_calibrate_digraph(digraph_logistic_regression): assert 1 - def test_fit_digraph_joblib_multiprocessing(digraph_logistic_regression): classifiers = [ LogisticRegression(), @@ -108,6 +111,7 @@ def test_fit_digraph_joblib_multiprocessing(digraph_logistic_regression): pytest.fail(repr(e)) assert 1 + def test_calibrate_digraph_joblib_multiprocessing(digraph_logistic_regression): classifiers = [ LogisticRegression(), @@ -117,7 +121,10 @@ def test_calibrate_digraph_joblib_multiprocessing(digraph_logistic_regression): digraph_logistic_regression.local_classifiers_ = classifiers digraph_logistic_regression._fit_digraph(local_mode=True, use_joblib=True) - calibrators = [_Calibrator(classifier) for classifier in digraph_logistic_regression.local_classifiers_] + calibrators = [ + _Calibrator(classifier) + for classifier in digraph_logistic_regression.local_classifiers_ + ] digraph_logistic_regression.local_calibrators_ = calibrators digraph_logistic_regression._calibrate_digraph(local_mode=True, use_joblib=True) @@ -128,15 +135,15 @@ def test_calibrate_digraph_joblib_multiprocessing(digraph_logistic_regression): assert 1 - @pytest.fixture def fitted_logistic_regression(): digraph = LocalClassifierPerLevel( local_classifier=LogisticRegression(), return_all_probabilities=True, calibration_method="ivap", - probability_combiner=None) - + probability_combiner=None, + ) + digraph.separator_ = "::HiClass::Separator::" digraph.hierarchy_ = nx.DiGraph( [("r", "1"), ("r", "2"), ("1", "1.1"), ("1", "1.2"), ("2", "2.1"), ("2", "2.2")] @@ -148,16 +155,38 @@ def fitted_logistic_regression(): # for predict_proba tmp_labels = digraph._disambiguate(make_leveled(digraph.y_)) - digraph.max_level_dimensions_ = np.array([len(np.unique(tmp_labels[:, level])) for level in range(tmp_labels.shape[1])]) - digraph.global_classes_ = [np.unique(tmp_labels[:, level]).astype("str") for level in range(tmp_labels.shape[1])] - digraph.global_class_to_index_mapping_ = [{digraph.global_classes_[level][index]: index for index in range(len(digraph.global_classes_[level]))} for level in range(tmp_labels.shape[1])] + digraph.max_level_dimensions_ = np.array( + [len(np.unique(tmp_labels[:, level])) for level in range(tmp_labels.shape[1])] + ) + digraph.global_classes_ = [ + np.unique(tmp_labels[:, level]).astype("str") + for level in range(tmp_labels.shape[1]) + ] + digraph.global_class_to_index_mapping_ = [ + { + digraph.global_classes_[level][index]: index + for index in range(len(digraph.global_classes_[level])) + } + for level in range(tmp_labels.shape[1]) + ] classes_ = [digraph.global_classes_[0]] for level in range(1, digraph.max_levels_): - classes_.append(np.sort(np.unique([label.split(digraph.separator_)[level] for label in digraph.global_classes_[level]]))) + classes_.append( + np.sort( + np.unique( + [ + label.split(digraph.separator_)[level] + for label in digraph.global_classes_[level] + ] + ) + ) + ) digraph.classes_ = classes_ - digraph.class_to_index_mapping_ = [{local_labels[index]: index for index in range(len(local_labels))} for local_labels in classes_] - + digraph.class_to_index_mapping_ = [ + {local_labels[index]: index for index in range(len(local_labels))} + for local_labels in classes_ + ] digraph.dtype_ = "