Skip to content

Commit

Permalink
black formatting for tests
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasDrews97 committed Apr 11, 2024
1 parent 2d0b9db commit 121e134
Show file tree
Hide file tree
Showing 6 changed files with 756 additions and 273 deletions.
105 changes: 76 additions & 29 deletions tests/test_LocalClassifierPerLevel.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,9 @@ def test_sklearn_compatible_estimator(estimator, check):

@pytest.fixture
def digraph_logistic_regression():
digraph = LocalClassifierPerLevel(local_classifier=LogisticRegression(), calibration_method="ivap")
digraph = LocalClassifierPerLevel(
local_classifier=LogisticRegression(), calibration_method="ivap"
)
digraph.hierarchy_ = nx.DiGraph([("a", "b"), ("a", "c")])
digraph.y_ = np.array([["a", "b"], ["a", "c"]])
digraph.X_ = np.array([[1, 2], [3, 4]])
Expand All @@ -46,14 +48,12 @@ def test_initialize_local_classifiers(digraph_logistic_regression):
LogisticRegression,
)


def test_initialize_local_calibrators(digraph_logistic_regression):
digraph_logistic_regression._initialize_local_classifiers()
digraph_logistic_regression._initialize_local_calibrators()
for calibrator in digraph_logistic_regression.local_calibrators_:
assert isinstance(
calibrator,
_Calibrator
)
assert isinstance(calibrator, _Calibrator)


def test_fit_digraph(digraph_logistic_regression):
Expand All @@ -71,6 +71,7 @@ def test_fit_digraph(digraph_logistic_regression):
pytest.fail(repr(e))
assert 1


def test_calibrate_digraph(digraph_logistic_regression):
classifiers = [
LogisticRegression(),
Expand All @@ -80,7 +81,10 @@ def test_calibrate_digraph(digraph_logistic_regression):
digraph_logistic_regression.local_classifiers_ = classifiers
digraph_logistic_regression._fit_digraph(local_mode=True)

calibrators = [_Calibrator(classifier) for classifier in digraph_logistic_regression.local_classifiers_]
calibrators = [
_Calibrator(classifier)
for classifier in digraph_logistic_regression.local_classifiers_
]
digraph_logistic_regression.local_calibrators_ = calibrators
digraph_logistic_regression._calibrate_digraph(local_mode=True)

Expand All @@ -91,7 +95,6 @@ def test_calibrate_digraph(digraph_logistic_regression):
assert 1



def test_fit_digraph_joblib_multiprocessing(digraph_logistic_regression):
classifiers = [
LogisticRegression(),
Expand All @@ -108,6 +111,7 @@ def test_fit_digraph_joblib_multiprocessing(digraph_logistic_regression):
pytest.fail(repr(e))
assert 1


def test_calibrate_digraph_joblib_multiprocessing(digraph_logistic_regression):
classifiers = [
LogisticRegression(),
Expand All @@ -117,7 +121,10 @@ def test_calibrate_digraph_joblib_multiprocessing(digraph_logistic_regression):
digraph_logistic_regression.local_classifiers_ = classifiers
digraph_logistic_regression._fit_digraph(local_mode=True, use_joblib=True)

calibrators = [_Calibrator(classifier) for classifier in digraph_logistic_regression.local_classifiers_]
calibrators = [
_Calibrator(classifier)
for classifier in digraph_logistic_regression.local_classifiers_
]
digraph_logistic_regression.local_calibrators_ = calibrators
digraph_logistic_regression._calibrate_digraph(local_mode=True, use_joblib=True)

Expand All @@ -128,15 +135,15 @@ def test_calibrate_digraph_joblib_multiprocessing(digraph_logistic_regression):
assert 1



@pytest.fixture
def fitted_logistic_regression():
digraph = LocalClassifierPerLevel(
local_classifier=LogisticRegression(),
return_all_probabilities=True,
calibration_method="ivap",
probability_combiner=None)

probability_combiner=None,
)

digraph.separator_ = "::HiClass::Separator::"
digraph.hierarchy_ = nx.DiGraph(
[("r", "1"), ("r", "2"), ("1", "1.1"), ("1", "1.2"), ("2", "2.1"), ("2", "2.2")]
Expand All @@ -148,16 +155,38 @@ def fitted_logistic_regression():

# for predict_proba
tmp_labels = digraph._disambiguate(make_leveled(digraph.y_))
digraph.max_level_dimensions_ = np.array([len(np.unique(tmp_labels[:, level])) for level in range(tmp_labels.shape[1])])
digraph.global_classes_ = [np.unique(tmp_labels[:, level]).astype("str") for level in range(tmp_labels.shape[1])]
digraph.global_class_to_index_mapping_ = [{digraph.global_classes_[level][index]: index for index in range(len(digraph.global_classes_[level]))} for level in range(tmp_labels.shape[1])]
digraph.max_level_dimensions_ = np.array(
[len(np.unique(tmp_labels[:, level])) for level in range(tmp_labels.shape[1])]
)
digraph.global_classes_ = [
np.unique(tmp_labels[:, level]).astype("str")
for level in range(tmp_labels.shape[1])
]
digraph.global_class_to_index_mapping_ = [
{
digraph.global_classes_[level][index]: index
for index in range(len(digraph.global_classes_[level]))
}
for level in range(tmp_labels.shape[1])
]

classes_ = [digraph.global_classes_[0]]
for level in range(1, digraph.max_levels_):
classes_.append(np.sort(np.unique([label.split(digraph.separator_)[level] for label in digraph.global_classes_[level]])))
classes_.append(
np.sort(
np.unique(
[
label.split(digraph.separator_)[level]
for label in digraph.global_classes_[level]
]
)
)
)
digraph.classes_ = classes_
digraph.class_to_index_mapping_ = [{local_labels[index]: index for index in range(len(local_labels))} for local_labels in classes_]

digraph.class_to_index_mapping_ = [
{local_labels[index]: index for index in range(len(local_labels))}
for local_labels in classes_
]

digraph.dtype_ = "<U3"
digraph.root_ = "r"
Expand Down Expand Up @@ -194,17 +223,28 @@ def test_predict_proba(fitted_logistic_regression):
assert len(proba) == 2
assert proba[0].shape == (4, 2)
assert proba[1].shape == (4, 4)
assert_array_almost_equal(np.sum(proba[0], axis=1), np.ones(len(proba[0])), decimal=10)
assert_array_almost_equal(np.sum(proba[1], axis=1), np.ones(len(proba[1])), decimal=10)
assert_array_almost_equal(
np.sum(proba[0], axis=1), np.ones(len(proba[0])), decimal=10
)
assert_array_almost_equal(
np.sum(proba[1], axis=1), np.ones(len(proba[1])), decimal=10
)


def test_predict_proba_sparse(fitted_logistic_regression):
proba = fitted_logistic_regression.predict_proba(csr_matrix([[7, 8], [5, 6], [3, 4], [1, 2]]))
proba = fitted_logistic_regression.predict_proba(
csr_matrix([[7, 8], [5, 6], [3, 4], [1, 2]])
)
assert len(proba) == 2
assert proba[0].shape == (4, 2)
assert proba[1].shape == (4, 4)
assert_array_almost_equal(np.sum(proba[0], axis=1), np.ones(len(proba[0])), decimal=10)
assert_array_almost_equal(np.sum(proba[1], axis=1), np.ones(len(proba[1])), decimal=10)
assert_array_almost_equal(
np.sum(proba[0], axis=1), np.ones(len(proba[0])), decimal=10
)
assert_array_almost_equal(
np.sum(proba[1], axis=1), np.ones(len(proba[1])), decimal=10
)


def test_fit_predict():
lcpl = LocalClassifierPerLevel(local_classifier=LogisticRegression())
Expand All @@ -223,13 +263,15 @@ def test_fit_predict():
predictions = lcpl.predict(x)
assert_array_equal(y, predictions)


def test_fit_calibrate_predict_proba():
lcpl = LocalClassifierPerLevel(
local_classifier=LogisticRegression(),
local_classifier=LogisticRegression(),
return_all_probabilities=True,
calibration_method="ivap",
probability_combiner="geometric")

probability_combiner="geometric",
)

x = np.array([[1, 2], [3, 4]])
y = np.array([["a", "b"], ["b", "c"]])
ground_truth = np.array(
Expand All @@ -250,15 +292,20 @@ def test_fit_calibrate_predict_proba():
assert len(proba) == 2
assert proba[0].shape == (2, 2)
assert proba[1].shape == (2, 2)
assert_array_almost_equal(np.sum(proba[0], axis=1), np.ones(len(proba[0])), decimal=10)
assert_array_almost_equal(np.sum(proba[1], axis=1), np.ones(len(proba[1])), decimal=10)
assert_array_almost_equal(
np.sum(proba[0], axis=1), np.ones(len(proba[0])), decimal=10
)
assert_array_almost_equal(
np.sum(proba[1], axis=1), np.ones(len(proba[1])), decimal=10
)


def test_fit_calibrate_predict_predict_proba_bert():
classifier = LocalClassifierPerLevel(
local_classifier=LogisticRegression(),
local_classifier=LogisticRegression(),
return_all_probabilities=True,
calibration_method="ivap",
probability_combiner="geometric"
probability_combiner="geometric",
)

classifier.logger_ = logging.getLogger("HC")
Expand Down
Loading

0 comments on commit 121e134

Please sign in to comment.