Skip to content

Commit

Permalink
Apply suggestions from code review
Browse files Browse the repository at this point in the history
Co-authored-by: José Morales <[email protected]>
  • Loading branch information
jameslamb and jmoralez authored Nov 7, 2023
1 parent 29422e5 commit c0b507e
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 9 deletions.
10 changes: 2 additions & 8 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -194,10 +194,7 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[np.ndarray, np.

if argc == 4:
group = _get_group_from_constructed_dataset(dataset)
if group is not None:
return self.func(labels, preds, weight, group) # type: ignore[call-arg]
else:
return self.func(labels, preds, weight, group) # type: ignore[call-arg]
return self.func(labels, preds, weight, group) # type: ignore[call-arg]

raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}")

Expand Down Expand Up @@ -278,10 +275,7 @@ def __call__(

if argc == 4:
group = _get_group_from_constructed_dataset(dataset)
if group is not None:
return self.func(labels, preds, weight, group) # type: ignore[call-arg]
else:
return self.func(labels, preds, weight, group) # type: ignore[call-arg]
return self.func(labels, preds, weight, group) # type: ignore[call-arg]

raise TypeError(f"Self-defined eval function should have 2, 3 or 4 arguments, got {argc}")

Expand Down
2 changes: 1 addition & 1 deletion tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,7 +503,7 @@ def test_dataset_construction_overwrites_user_provided_metadata_fields():

X = np.array([[1.0, 2.0], [3.0, 4.0]])

position=np.array([0.0, 1.0], dtype=np.float32)
position = np.array([0.0, 1.0], dtype=np.float32)
if getenv('TASK', '') == 'cuda':
position = None

Expand Down

0 comments on commit c0b507e

Please sign in to comment.