diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index c341306fa0d8..ecf8749e7441 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -27,6 +27,7 @@ import scipy.sparse from .compat import ( + CFFI_INSTALLED, PANDAS_INSTALLED, PYARROW_INSTALLED, arrow_cffi, @@ -1706,8 +1707,8 @@ def __pred_for_pyarrow_table( predict_type: int, ) -> Tuple[np.ndarray, int]: """Predict for a PyArrow table.""" - if not PYARROW_INSTALLED: - raise LightGBMError("Cannot predict from Arrow without `pyarrow` installed.") + if not (PYARROW_INSTALLED and CFFI_INSTALLED): + raise LightGBMError("Cannot predict from Arrow without 'pyarrow' and 'cffi' installed.") # Check that the input is valid: we only handle numbers (for now) if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types): @@ -2458,8 +2459,8 @@ def __init_from_pyarrow_table( ref_dataset: Optional[_DatasetHandle], ) -> "Dataset": """Initialize data from a PyArrow table.""" - if not PYARROW_INSTALLED: - raise LightGBMError("Cannot init dataframe from Arrow without `pyarrow` installed.") + if not (PYARROW_INSTALLED and CFFI_INSTALLED): + raise LightGBMError("Cannot init Dataset from Arrow without 'pyarrow' and 'cffi' installed.") # Check that the input is valid: we only handle numbers (for now) if not all(arrow_is_integer(t) or arrow_is_floating(t) or arrow_is_boolean(t) for t in table.schema.types): diff --git a/python-package/lightgbm/compat.py b/python-package/lightgbm/compat.py index 96dee6522572..04a831131652 100644 --- a/python-package/lightgbm/compat.py +++ b/python-package/lightgbm/compat.py @@ -289,7 +289,6 @@ def __init__(self, *args: Any, **kwargs: Any): from pyarrow import ChunkedArray as pa_ChunkedArray from pyarrow import Table as pa_Table from pyarrow import chunked_array as pa_chunked_array - from pyarrow.cffi import ffi as arrow_cffi from pyarrow.types import is_boolean as arrow_is_boolean from pyarrow.types import is_floating as arrow_is_floating from pyarrow.types import is_integer as arrow_is_integer @@ -316,19 +315,8 @@ class pa_Table: # type: ignore def __init__(self, *args: Any, **kwargs: Any): pass - class arrow_cffi: # type: ignore - """Dummy class for pyarrow.cffi.ffi.""" - - CData = None - addressof = None - cast = None - new = None - - def __init__(self, *args: Any, **kwargs: Any): - pass - class pa_compute: # type: ignore - """Dummy class for pyarrow.compute.""" + """Dummy class for pyarrow.compute module.""" all = None equal = None @@ -338,6 +326,24 @@ class pa_compute: # type: ignore arrow_is_integer = None arrow_is_floating = None + +"""cffi""" +try: + from pyarrow.cffi import ffi as arrow_cffi + + CFFI_INSTALLED = True +except ImportError: + CFFI_INSTALLED = False + + class arrow_cffi: # type: ignore + """Dummy class for pyarrow.cffi.ffi.""" + + CData = None + + def __init__(self, *args: Any, **kwargs: Any): + pass + + """cpu_count()""" try: from joblib import cpu_count diff --git a/tests/python_package_test/conftest.py b/tests/python_package_test/conftest.py index 7d9c5b27079f..1f4a7943a9a9 100644 --- a/tests/python_package_test/conftest.py +++ b/tests/python_package_test/conftest.py @@ -1,6 +1,15 @@ import numpy as np import pytest +import lightgbm + + +@pytest.fixture(scope="function") +def missing_module_cffi(monkeypatch): + """Mock 'cffi' not being importable""" + monkeypatch.setattr(lightgbm.compat, "CFFI_INSTALLED", False) + monkeypatch.setattr(lightgbm.basic, "CFFI_INSTALLED", False) + @pytest.fixture(scope="function") def rng(): diff --git a/tests/python_package_test/test_arrow.py b/tests/python_package_test/test_arrow.py index d8246f3842de..b592d733d41e 100644 --- a/tests/python_package_test/test_arrow.py +++ b/tests/python_package_test/test_arrow.py @@ -454,3 +454,32 @@ def test_arrow_feature_name_manual(): ) booster = lgb.train({"num_leaves": 7}, dataset, num_boost_round=5) assert booster.feature_name() == ["c", "d"] + + +def test_dataset_construction_from_pa_table_without_cffi_raises_informative_error(missing_module_cffi): + with pytest.raises( + lgb.basic.LightGBMError, match="Cannot init Dataset from Arrow without 'pyarrow' and 'cffi' installed." + ): + lgb.Dataset( + generate_dummy_arrow_table(), + label=pa.array([0, 1, 0, 0, 1]), + params=dummy_dataset_params(), + ).construct() + + +def test_predicting_from_pa_table_without_cffi_raises_informative_error(missing_module_cffi): + data = generate_random_arrow_table(num_columns=3, num_datapoints=1_000, seed=42) + labels = generate_random_arrow_array(num_datapoints=data.shape[0], seed=42) + bst = lgb.train( + params={"num_leaves": 7, "verbose": -1}, + train_set=lgb.Dataset( + data.to_pandas(), + label=labels.to_pandas(), + ), + num_boost_round=2, + ) + + with pytest.raises( + lgb.basic.LightGBMError, match="Cannot predict from Arrow without 'pyarrow' and 'cffi' installed." + ): + bst.predict(data)