Skip to content

Commit

Permalink
Fix unleved bug with bert
Browse files Browse the repository at this point in the history
  • Loading branch information
mirand863 committed Dec 2, 2024
1 parent a344014 commit a4a006e
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 1 deletion.
2 changes: 1 addition & 1 deletion hiclass/HierarchicalClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _pre_fit(self, X, y, sample_weight):
)
else:
self.X_ = np.array(X)
self.y_ = np.array(y)
self.y_ = np.array(make_leveled(y))

if sample_weight is not None:
self.sample_weight_ = _check_sample_weight(sample_weight, X)
Expand Down
11 changes: 11 additions & 0 deletions tests/test_LocalClassifiers.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,3 +143,14 @@ def test_tmp_dir(classifier):
assert expected_name == name
check_is_fitted(classifier)
clf.fit(x, y)


@pytest.mark.parametrize("classifier", classifiers)
def test_bert_unleveled(classifier):
clf = classifier(
local_classifier=LogisticRegression(),
bert=True,
)
x = [[0, 1], [2, 3]]
y = [["a"], ["b", "c"]]
clf.fit(x, y)

0 comments on commit a4a006e

Please sign in to comment.