Skip to content

Commit

Permalink
[Feature] return probs in binary classification (#36)
Browse files Browse the repository at this point in the history
  • Loading branch information
RektPunk authored Sep 13, 2024
1 parent 34f0730 commit e8db7bc
Show file tree
Hide file tree
Showing 4 changed files with 6 additions and 11 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "rektgbm"
version = "0.1.4"
version = "0.1.5"
description = "No-brainer machine learning solution to achieve satisfactory performance"
authors = ["RektPunk <[email protected]>"]
license = "MIT"
Expand Down
2 changes: 1 addition & 1 deletion rektgbm/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,4 +3,4 @@
from rektgbm.gbm import RektGBM
from rektgbm.optimizer import RektOptimizer

__version__ = "0.1.4"
__version__ = "0.1.5"
11 changes: 3 additions & 8 deletions rektgbm/gbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def fit(
)
if self._task_type in {TaskType.binary, TaskType.multiclass, TaskType.rank}:
self.label_encoder = dataset.fit_transform_label()
self._label_encoder_used = True
self._is_label_encoder_used = True

if valid_set is not None and self.__is_label_encoder_used:
valid_set.transform_label(label_encoder=self.label_encoder)
Expand All @@ -66,13 +66,9 @@ def fit(

def predict(self, dataset: RektDataset):
preds = self.engine.predict(dataset=dataset)

if self._task_type in {TaskType.regression, TaskType.rank}:
if self._task_type in {TaskType.binary, TaskType.regression, TaskType.rank}:
return preds

if self._task_type == TaskType.binary:
preds = np.around(preds).astype(int)

if self._task_type == TaskType.multiclass:
if self.method == MethodName.lightgbm:
preds = np.argmax(preds, axis=1).astype(int)
Expand All @@ -81,9 +77,8 @@ def predict(self, dataset: RektDataset):

if self.__is_label_encoder_used:
preds = self.label_encoder.inverse_transform(series=preds)

return preds

@property
def __is_label_encoder_used(self) -> bool:
return getattr(self, "_label_encoder_used", False)
return getattr(self, "_is_label_encoder_used", False)
2 changes: 1 addition & 1 deletion tests/test_gbm.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def test_rektgbm_predict_binary(mock_dataset, mock_engine):
gbm._is_fitted = True

preds = gbm.predict(dataset=mock_dataset)
np.testing.assert_allclose(preds, [0, 1, 0, 1], rtol=1e-5)
np.testing.assert_allclose(preds, [0.1, 0.9, 0.2, 0.8], rtol=1e-5)


def test_rektgbm_predict_multiclass(mock_dataset, mock_engine):
Expand Down

1 comment on commit e8db7bc

@github-actions
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tests Skipped Failures Errors Time
91 0 💤 0 ❌ 0 🔥 5.522s ⏱️

Please sign in to comment.