From cb86e7a4365b29088e546a04fc673a53a443bd17 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Sat, 9 Nov 2024 17:32:31 +0000 Subject: [PATCH] fix: use PyCapsule Interface instead of Dataframe Interchange Protocol --- pyproject.toml | 1 + seaborn/_core/data.py | 66 +++++++++++++++++++--------------------- seaborn/_core/plot.py | 2 +- tests/_core/test_data.py | 2 +- tests/conftest.py | 4 +-- 5 files changed, 36 insertions(+), 39 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 0a4e497d0d..ae5b401f12 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -42,6 +42,7 @@ dev = [ "mypy", "pandas-stubs", "pre-commit", + "pyarrow", "flit", ] docs = [ diff --git a/seaborn/_core/data.py b/seaborn/_core/data.py index c17bfe95c5..2c2a0cf0d1 100644 --- a/seaborn/_core/data.py +++ b/seaborn/_core/data.py @@ -5,7 +5,6 @@ from collections.abc import Mapping, Sized from typing import cast -import warnings import pandas as pd from pandas import DataFrame @@ -269,9 +268,9 @@ def _assign_variables( def handle_data_source(data: object) -> pd.DataFrame | Mapping | None: """Convert the data source object to a common union representation.""" - if isinstance(data, pd.DataFrame) or hasattr(data, "__dataframe__"): + if isinstance(data, pd.DataFrame) or hasattr(data, "__arrow_c_stream__"): # Check for pd.DataFrame inheritance could be removed once - # minimal pandas version supports dataframe interchange (1.5.0). + # minimal pandas version supports PyCapsule Interface (2.2). data = convert_dataframe_to_pandas(data) elif data is not None and not isinstance(data, Mapping): err = f"Data source must be a DataFrame or Mapping, not {type(data)!r}." @@ -285,35 +284,32 @@ def convert_dataframe_to_pandas(data: object) -> pd.DataFrame: if isinstance(data, pd.DataFrame): return data - if not hasattr(pd.api, "interchange"): - msg = ( - "Support for non-pandas DataFrame objects requires a version of pandas " - "that implements the DataFrame interchange protocol. Please upgrade " - "your pandas version or coerce your data to pandas before passing " - "it to seaborn." - ) - raise TypeError(msg) - - if _version_predates(pd, "2.0.2"): - msg = ( - "DataFrame interchange with pandas<2.0.2 has some known issues. " - f"You are using pandas {pd.__version__}. " - "Continuing, but it is recommended to carefully inspect the results and to " - "consider upgrading." - ) - warnings.warn(msg, stacklevel=2) - - try: - # This is going to convert all columns in the input dataframe, even though - # we may only need one or two of them. It would be more efficient to select - # the columns that are going to be used in the plot prior to interchange. - # Solving that in general is a hard problem, especially with the objects - # interface where variables passed in Plot() may only be referenced later - # in Plot.add(). But noting here in case this seems to be a bottleneck. - return pd.api.interchange.from_dataframe(data) - except Exception as err: - msg = ( - "Encountered an exception when converting data source " - "to a pandas DataFrame. See traceback above for details." - ) - raise RuntimeError(msg) from err + if hasattr(data, '__arrow_c_stream__'): + try: + import pyarrow + except ImportError as err: + msg = "PyArrow is required for non-pandas Dataframe support." + raise RuntimeError(msg) from err + if _version_predates(pyarrow, '14.0.0'): + msg = "PyArrow>=14.0.0 is required for non-pandas Dataframe support." + raise RuntimeError(msg) + try: + # This is going to convert all columns in the input dataframe, even though + # we may only need one or two of them. It would be more efficient to select + # the columns that are going to be used in the plot prior to interchange. + # Solving that in general is a hard problem, especially with the objects + # interface where variables passed in Plot() may only be referenced later + # in Plot.add(). But noting here in case this seems to be a bottleneck. + return pyarrow.table(data).to_pandas() + except Exception as err: + msg = ( + "Encountered an exception when converting data source " + "to a pandas DataFrame. See traceback above for details." + ) + raise RuntimeError(msg) from err + + msg = ( + "Expected object which implements '__arrow_c_stream__' from the " + f"PyCapsule Interface, got: {type(data)}" + ) + raise TypeError(msg) diff --git a/seaborn/_core/plot.py b/seaborn/_core/plot.py index c9dc61c8a7..b021e7b959 100644 --- a/seaborn/_core/plot.py +++ b/seaborn/_core/plot.py @@ -349,7 +349,7 @@ def _resolve_positionals( if ( isinstance(args[0], (abc.Mapping, pd.DataFrame)) - or hasattr(args[0], "__dataframe__") + or hasattr(args[0], "__arrow_c_stream__") ): if data is not None: raise TypeError("`data` given by both name and position.") diff --git a/tests/_core/test_data.py b/tests/_core/test_data.py index 0e67ed37b4..2e7dff3c09 100644 --- a/tests/_core/test_data.py +++ b/tests/_core/test_data.py @@ -425,7 +425,7 @@ def test_data_interchange(self, mock_long_df, long_df): ) def test_data_interchange_failure(self, mock_long_df): - mock_long_df._data = None # Break __dataframe__() + mock_long_df.__arrow_c_stream__ = lambda _x: 1 / 0 # Break __arrow_c_stream__() with pytest.raises(RuntimeError, match="Encountered an exception"): PlotData(mock_long_df, {"x": "x"}) diff --git a/tests/conftest.py b/tests/conftest.py index 6ee53e7ee4..485e6ad67a 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -188,8 +188,8 @@ class MockInterchangeableDataFrame: def __init__(self, data): self._data = data - def __dataframe__(self, *args, **kwargs): - return self._data.__dataframe__(*args, **kwargs) + def __arrow_c_stream__(self, *args, **kwargs): + return self._data.__arrow_c_stream__() @pytest.fixture