From 097a58e975f526f13691fa1ceda6234ebf1b5181 Mon Sep 17 00:00:00 2001 From: Fabio Date: Tue, 23 Apr 2024 17:17:21 +0200 Subject: [PATCH] Add encoder --- hiclass/LocalClassifierPerLevel.py | 9 ++++++--- hiclass/LocalClassifierPerParentNode.py | 9 ++++++--- 2 files changed, 12 insertions(+), 6 deletions(-) diff --git a/hiclass/LocalClassifierPerLevel.py b/hiclass/LocalClassifierPerLevel.py index 907e61cf..2420ac36 100644 --- a/hiclass/LocalClassifierPerLevel.py +++ b/hiclass/LocalClassifierPerLevel.py @@ -5,13 +5,13 @@ """ import hashlib +import numpy as np import pickle from copy import deepcopy -from os.path import exists - -import numpy as np from joblib import Parallel, delayed +from os.path import exists from sklearn.base import BaseEstimator +from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import check_array, check_is_fitted from hiclass.ConstantClassifier import ConstantClassifier @@ -273,6 +273,9 @@ def _fit_classifier(self, level, separator): classifier = ConstantClassifier() if not self.bert: try: + label_encoder = LabelEncoder() + label_encoder.fit(y) + y = label_encoder.transform(y) classifier.fit(X, y, sample_weight) except TypeError: classifier.fit(X, y) diff --git a/hiclass/LocalClassifierPerParentNode.py b/hiclass/LocalClassifierPerParentNode.py index 47f77475..77d674a5 100644 --- a/hiclass/LocalClassifierPerParentNode.py +++ b/hiclass/LocalClassifierPerParentNode.py @@ -5,13 +5,13 @@ """ import hashlib +import networkx as nx +import numpy as np import pickle from copy import deepcopy from os.path import exists - -import networkx as nx -import numpy as np from sklearn.base import BaseEstimator +from sklearn.preprocessing import LabelEncoder from sklearn.utils.validation import check_array, check_is_fitted from hiclass.ConstantClassifier import ConstantClassifier @@ -231,6 +231,9 @@ def _fit_classifier(self, node): classifier = ConstantClassifier() if not self.bert: try: + label_encoder = LabelEncoder() + label_encoder.fit(y) + y = label_encoder.transform(y) classifier.fit(X, y, sample_weight) except TypeError: classifier.fit(X, y)