diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py index 907e61cf..2420ac36 100644 --- a/hiclass/LocalClassifierPerLevel.py +++ b/hiclass/LocalClassifierPerLevel.py @@ -5,13 +5,13 @@ """ import hashlib +import numpy as np import pickle from copy import deepcopy -from os.path import exists - -import numpy as np from joblib import Parallel, delayed +from os.path import exists from sklearn.base import BaseEstimator +from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import check_array, check_is_fitted from hiclass.ConstantClassifier import ConstantClassifier @@ -273,6 +273,9 @@ def _fit_classifier(self, level, separator): classifier = ConstantClassifier() if not self.bert: try: + label_encoder = LabelEncoder() + label_encoder.fit(y) + y = label_encoder.transform(y) classifier.fit(X, y, sample_weight) except TypeError: classifier.fit(X, y) diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index 47f77475..77d674a5 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -5,13 +5,13 @@ """ import hashlib +import networkx as nx +import numpy as np import pickle from copy import deepcopy from os.path import exists - -import networkx as nx -import numpy as np from sklearn.base import BaseEstimator +from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import check_array, check_is_fitted from hiclass.ConstantClassifier import ConstantClassifier @@ -231,6 +231,9 @@ def _fit_classifier(self, node): classifier = ConstantClassifier() if not self.bert: try: + label_encoder = LabelEncoder() + label_encoder.fit(y) + y = label_encoder.transform(y) classifier.fit(X, y, sample_weight) except TypeError: classifier.fit(X, y)