Skip to content

Commit

Permalink
ellipsis.
Browse files Browse the repository at this point in the history
  • Loading branch information
trivialfis committed Jun 17, 2024
1 parent c793524 commit 55c4f1b
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 8 deletions.
11 changes: 6 additions & 5 deletions python-package/xgboost/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from enum import IntEnum, unique
from functools import wraps
from inspect import Parameter, signature
from types import EllipsisType
from typing import (
Any,
Callable,
Expand Down Expand Up @@ -1829,7 +1830,7 @@ def __setstate__(self, state: Dict) -> None:
state["handle"] = handle
self.__dict__.update(state)

def __getitem__(self, val: Union[Integer, tuple, slice]) -> "Booster":
def __getitem__(self, val: Union[Integer, tuple, slice, EllipsisType]) -> "Booster":
"""Get a slice of the tree-based model.
.. versionadded:: 1.3.0
Expand All @@ -1838,21 +1839,21 @@ def __getitem__(self, val: Union[Integer, tuple, slice]) -> "Booster":
# convert to slice for all other types
if isinstance(val, (np.integer, int)):
val = slice(int(val), int(val + 1))
if isinstance(val, type(Ellipsis)):
if isinstance(val, EllipsisType):
val = slice(0, 0)
if isinstance(val, tuple):
raise ValueError("Only supports slicing through 1 dimension.")
# All supported types are now slice
# FIXME(jiamingy): Use `types.EllipsisType` once Python 3.10 is used.
if not isinstance(val, slice):
msg = _expect((int, slice, np.integer, type(Ellipsis)), type(val))
msg = _expect((int, slice, np.integer, EllipsisType), type(val))
raise TypeError(msg)

if isinstance(val.start, type(Ellipsis)) or val.start is None:
if isinstance(val.start, EllipsisType) or val.start is None:
start = 0
else:
start = val.start
if isinstance(val.stop, type(Ellipsis)) or val.stop is None:
if isinstance(val.stop, EllipsisType) or val.stop is None:
stop = 0
else:
stop = val.stop
Expand Down
8 changes: 5 additions & 3 deletions tests/python/test_basic_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -425,8 +425,8 @@ def run_slice(
np.testing.assert_allclose(merged, single, atol=1e-6)

@pytest.mark.skipif(**tm.no_sklearn())
@pytest.mark.parametrize("booster", ["gbtree", "dart"])
def test_slice(self, booster):
@pytest.mark.parametrize("booster_name", ["gbtree", "dart"])
def test_slice(self, booster_name: str) -> None:
from sklearn.datasets import make_classification

num_classes = 3
Expand All @@ -442,7 +442,7 @@ def test_slice(self, booster):
"num_parallel_tree": num_parallel_tree,
"subsample": 0.5,
"num_class": num_classes,
"booster": booster,
"booster": booster_name,
"objective": "multi:softprob",
},
num_boost_round=num_boost_round,
Expand All @@ -452,6 +452,8 @@ def test_slice(self, booster):

assert len(booster.get_dump()) == total_trees

assert booster[...].num_boosted_rounds() == num_boost_round

self.run_slice(
booster, dtrain, num_parallel_tree, num_classes, num_boost_round, False
)
Expand Down

0 comments on commit 55c4f1b

Please sign in to comment.