Skip to content

Commit

Permalink
[python-package] Accept numpy generators as random_state (#6174)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes authored Nov 9, 2023
1 parent 5e90255 commit 501e6e6
Show file tree
Hide file tree
Showing 4 changed files with 24 additions and 9 deletions.
10 changes: 10 additions & 0 deletions python-package/lightgbm/compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,16 @@ def __init__(self, *args, **kwargs):

concat = None

"""numpy"""
try:
from numpy.random import Generator as np_random_Generator
except ImportError:
class np_random_Generator: # type: ignore
"""Dummy class for np.random.Generator."""

def __init__(self, *args, **kwargs):
pass

"""matplotlib"""
try:
import matplotlib # noqa: F401
Expand Down
6 changes: 3 additions & 3 deletions python-package/lightgbm/dask.py
Original file line number Diff line number Diff line change
Expand Up @@ -1142,7 +1142,7 @@ def __init__(
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None,
n_jobs: Optional[int] = None,
importance_type: str = 'split',
client: Optional[Client] = None,
Expand Down Expand Up @@ -1347,7 +1347,7 @@ def __init__(
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None,
n_jobs: Optional[int] = None,
importance_type: str = 'split',
client: Optional[Client] = None,
Expand Down Expand Up @@ -1517,7 +1517,7 @@ def __init__(
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None,
n_jobs: Optional[int] = None,
importance_type: str = 'split',
client: Optional[Client] = None,
Expand Down
10 changes: 7 additions & 3 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
_LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
_LGBMComputeSampleWeight, _LGBMCpuCount, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase,
dt_DataTable, pd_DataFrame)
dt_DataTable, np_random_Generator, pd_DataFrame)
from .engine import train

__all__ = [
Expand Down Expand Up @@ -448,7 +448,7 @@ def __init__(
colsample_bytree: float = 1.,
reg_alpha: float = 0.,
reg_lambda: float = 0.,
random_state: Optional[Union[int, np.random.RandomState]] = None,
random_state: Optional[Union[int, np.random.RandomState, 'np.random.Generator']] = None,
n_jobs: Optional[int] = None,
importance_type: str = 'split',
**kwargs
Expand Down Expand Up @@ -509,7 +509,7 @@ def __init__(
random_state : int, RandomState object or None, optional (default=None)
Random number seed.
If int, this number is used to seed the C++ code.
If RandomState object (numpy), a random integer is picked based on its state to seed the C++ code.
If RandomState or Generator object (numpy), a random integer is picked based on its state to seed the C++ code.
If None, default seeds in C++ code are used.
n_jobs : int or None, optional (default=None)
Number of parallel threads to use for training (can be changed at prediction time by
Expand Down Expand Up @@ -710,6 +710,10 @@ def _process_params(self, stage: str) -> Dict[str, Any]:

if isinstance(params['random_state'], np.random.RandomState):
params['random_state'] = params['random_state'].randint(np.iinfo(np.int32).max)
elif isinstance(params['random_state'], np_random_Generator):
params['random_state'] = int(
params['random_state'].integers(np.iinfo(np.int32).max)
)
if self._n_classes > 2:
for alias in _ConfigAliases.get('num_class'):
params.pop(alias, None)
Expand Down
7 changes: 4 additions & 3 deletions tests/python_package_test/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -534,11 +534,12 @@ def test_non_serializable_objects_in_callbacks(tmp_path):
assert gbm.booster_.attr_set_inside_callback == 40


def test_random_state_object():
@pytest.mark.parametrize("rng_constructor", [np.random.RandomState, np.random.default_rng])
def test_random_state_object(rng_constructor):
X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
state1 = np.random.RandomState(123)
state2 = np.random.RandomState(123)
state1 = rng_constructor(123)
state2 = rng_constructor(123)
clf1 = lgb.LGBMClassifier(n_estimators=10, subsample=0.5, subsample_freq=1, random_state=state1)
clf2 = lgb.LGBMClassifier(n_estimators=10, subsample=0.5, subsample_freq=1, random_state=state2)
# Test if random_state is properly stored
Expand Down

0 comments on commit 501e6e6

Please sign in to comment.