Skip to content

Commit 097a58e

Browse files
committed
Add encoder
1 parent 6f37990 commit 097a58e

File tree

2 files changed

+12
-6
lines changed

2 files changed

+12
-6
lines changed

hiclass/LocalClassifierPerLevel.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
"""
66

77
import hashlib
8+
import numpy as np
89
import pickle
910
from copy import deepcopy
10-
from os.path import exists
11-
12-
import numpy as np
1311
from joblib import Parallel, delayed
12+
from os.path import exists
1413
from sklearn.base import BaseEstimator
14+
from sklearn.preprocessing import LabelEncoder
1515
from sklearn.utils.validation import check_array, check_is_fitted
1616

1717
from hiclass.ConstantClassifier import ConstantClassifier
@@ -273,6 +273,9 @@ def _fit_classifier(self, level, separator):
273273
classifier = ConstantClassifier()
274274
if not self.bert:
275275
try:
276+
label_encoder = LabelEncoder()
277+
label_encoder.fit(y)
278+
y = label_encoder.transform(y)
276279
classifier.fit(X, y, sample_weight)
277280
except TypeError:
278281
classifier.fit(X, y)

hiclass/LocalClassifierPerParentNode.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,13 @@
55
"""
66

77
import hashlib
8+
import networkx as nx
9+
import numpy as np
810
import pickle
911
from copy import deepcopy
1012
from os.path import exists
11-
12-
import networkx as nx
13-
import numpy as np
1413
from sklearn.base import BaseEstimator
14+
from sklearn.preprocessing import LabelEncoder
1515
from sklearn.utils.validation import check_array, check_is_fitted
1616

1717
from hiclass.ConstantClassifier import ConstantClassifier
@@ -231,6 +231,9 @@ def _fit_classifier(self, node):
231231
classifier = ConstantClassifier()
232232
if not self.bert:
233233
try:
234+
label_encoder = LabelEncoder()
235+
label_encoder.fit(y)
236+
y = label_encoder.transform(y)
234237
classifier.fit(X, y, sample_weight)
235238
except TypeError:
236239
classifier.fit(X, y)

0 commit comments

Comments
 (0)