Skip to content

Commit

Permalink
refactor gbm
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk committed Sep 13, 2024
1 parent e8db7bc commit f69f138
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 12 deletions.
15 changes: 3 additions & 12 deletions rektgbm/gbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,10 +49,8 @@ def fit(
)
if self._task_type in {TaskType.binary, TaskType.multiclass, TaskType.rank}:
self.label_encoder = dataset.fit_transform_label()
self._is_label_encoder_used = True

if valid_set is not None and self.__is_label_encoder_used:
valid_set.transform_label(label_encoder=self.label_encoder)
if valid_set:
valid_set.transform_label(label_encoder=self.label_encoder)

_objective = self.rekt_objective.get_objective_dict(method=self.method)
_metric = self.rekt_metric.get_metric_dict(method=self.method)
Expand All @@ -74,11 +72,4 @@ def predict(self, dataset: RektDataset):
preds = np.argmax(preds, axis=1).astype(int)
else:
preds = np.around(preds).astype(int)

if self.__is_label_encoder_used:
preds = self.label_encoder.inverse_transform(series=preds)
return preds

@property
def __is_label_encoder_used(self) -> bool:
return getattr(self, "_is_label_encoder_used", False)
return self.label_encoder.inverse_transform(series=preds)
3 changes: 3 additions & 0 deletions tests/test_gbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from rektgbm.base import MethodName
from rektgbm.dataset import RektDataset
from rektgbm.encoder import RektLabelEncoder
from rektgbm.engine import RektEngine
from rektgbm.gbm import RektGBM
from rektgbm.task import TaskType
Expand Down Expand Up @@ -90,6 +91,8 @@ def test_rektgbm_predict_multiclass(mock_dataset, mock_engine):
gbm.engine = mock_engine
gbm._task_type = TaskType.multiclass
gbm._is_fitted = True
gbm.label_encoder = RektLabelEncoder()
gbm.label_encoder.fit_label([0, 1, 2])

mock_engine.predict.return_value = np.array(
[[0.1, 0.7, 0.2], [0.3, 0.4, 0.3], [0.2, 0.2, 0.6]]
Expand Down

0 comments on commit f69f138

Please sign in to comment.