fix: use PyCapsule Interface instead of Dataframe Interchange Protocol #3782

Open · wants to merge 1 commit into master
1 change: 1 addition & 0 deletions pyproject.toml
@@ -42,6 +42,7 @@ dev = [
     "mypy",
     "pandas-stubs",
     "pre-commit",
+    "pyarrow",
     "flit",
 ]
 docs = [
66 changes: 31 additions & 35 deletions seaborn/_core/data.py
@@ -5,7 +5,6 @@
 
 from collections.abc import Mapping, Sized
 from typing import cast
-import warnings
 
 import pandas as pd
 from pandas import DataFrame
@@ -269,9 +268,9 @@ def _assign_variables(

 def handle_data_source(data: object) -> pd.DataFrame | Mapping | None:
     """Convert the data source object to a common union representation."""
-    if isinstance(data, pd.DataFrame) or hasattr(data, "__dataframe__"):
+    if isinstance(data, pd.DataFrame) or hasattr(data, "__arrow_c_stream__"):
         # Check for pd.DataFrame inheritance could be removed once
-        # minimal pandas version supports dataframe interchange (1.5.0).
+        # minimal pandas version supports PyCapsule Interface (2.2).
         data = convert_dataframe_to_pandas(data)
     elif data is not None and not isinstance(data, Mapping):
         err = f"Data source must be a DataFrame or Mapping, not {type(data)!r}."
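For illustration (not part of the diff): the dispatch above can be sketched in isolation with stand-in objects. `FakeCapsuleFrame` and `classify_data_source` are hypothetical names invented here; a real exporter such as a polars DataFrame returns an Arrow PyCapsule from `__arrow_c_stream__`, and the pandas-passthrough branch is omitted to keep the sketch dependency-free.

```python
from collections.abc import Mapping


class FakeCapsuleFrame:
    """Hypothetical stand-in for a dataframe exporting the PyCapsule Interface."""

    def __arrow_c_stream__(self, requested_schema=None):
        # A real implementation returns a PyCapsule wrapping an ArrowArrayStream.
        raise NotImplementedError


def classify_data_source(data: object) -> str:
    # Mirrors the dispatch in handle_data_source: capsule exporters get
    # converted to pandas, Mappings and None pass through, anything else
    # is rejected with the same error message as the PR.
    if hasattr(data, "__arrow_c_stream__"):
        return "convert-to-pandas"
    if data is None or isinstance(data, Mapping):
        return "pass-through"
    raise TypeError(f"Data source must be a DataFrame or Mapping, not {type(data)!r}.")
```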
@@ -285,35 +284,32 @@ def convert_dataframe_to_pandas(data: object) -> pd.DataFrame:
     if isinstance(data, pd.DataFrame):
         return data
 
-    if not hasattr(pd.api, "interchange"):
-        msg = (
-            "Support for non-pandas DataFrame objects requires a version of pandas "
-            "that implements the DataFrame interchange protocol. Please upgrade "
-            "your pandas version or coerce your data to pandas before passing "
-            "it to seaborn."
-        )
-        raise TypeError(msg)
-
-    if _version_predates(pd, "2.0.2"):
-        msg = (
-            "DataFrame interchange with pandas<2.0.2 has some known issues. "
-            f"You are using pandas {pd.__version__}. "
-            "Continuing, but it is recommended to carefully inspect the results and to "
-            "consider upgrading."
-        )
-        warnings.warn(msg, stacklevel=2)
-
-    try:
-        # This is going to convert all columns in the input dataframe, even though
-        # we may only need one or two of them. It would be more efficient to select
-        # the columns that are going to be used in the plot prior to interchange.
-        # Solving that in general is a hard problem, especially with the objects
-        # interface where variables passed in Plot() may only be referenced later
-        # in Plot.add(). But noting here in case this seems to be a bottleneck.
-        return pd.api.interchange.from_dataframe(data)
-    except Exception as err:
-        msg = (
-            "Encountered an exception when converting data source "
-            "to a pandas DataFrame. See traceback above for details."
-        )
-        raise RuntimeError(msg) from err
+    if hasattr(data, '__arrow_c_stream__'):
+        try:
+            import pyarrow
+        except ImportError as err:
+            msg = "PyArrow is required for non-pandas Dataframe support."
Owner:

Is this generally a dependency of non-pandas dataframe libraries now? Or could this change introduce a regression for e.g. polars users who are currently leveraging the dataframe interchange protocol?

Contributor Author (MarcoGorelli):

Thanks for your review!

Polars doesn't depend on PyArrow, but polars.DataFrame.to_pandas always requires PyArrow. So, in practice, anyone working with both dataframe libraries may well have PyArrow installed already

To avoid requiring PyArrow for the cases when it's not necessary, one way could be to do something like:

  • try using the interchange protocol
  • if it raises, then fall back to the PyCapsule Interface (which currently requires PyArrow)

This has the upside of not requiring PyArrow in some cases, but the downside of hiding issues where the interchange protocol silently produces invalid results

It may be possible to do this PyCapsule Interface conversion in the future without PyArrow but with something lighter instead, like arro3 by @kylebarron (who I'm ccing in case he has comments too)

What would be your preference?
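The two-step strategy sketched in this comment could look roughly like the following. This is a hedged illustration, not code from the PR: `interchange_convert` and `capsule_convert` are injected stand-ins for `pd.api.interchange.from_dataframe` and a pyarrow-based conversion, so the control flow can be shown (and exercised) without either library installed.

```python
def convert_with_fallback(data, interchange_convert, capsule_convert):
    # Try the dataframe interchange protocol first; only if it raises,
    # fall back to the PyCapsule Interface when the object advertises it.
    try:
        return interchange_convert(data)
    except Exception:
        if hasattr(data, "__arrow_c_stream__"):
            return capsule_convert(data)
        raise
```

The downside the comment notes shows up here directly: when `interchange_convert` silently returns invalid results instead of raising, the fallback never triggers.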

kylebarron:
Some polars users may not have pyarrow installed. If seaborn needs to get pandas data, the only production-ready way to do Arrow -> pandas that I know of is using pyarrow.

As Marco mentions I'm working on arro3, which is a minimal library for Arrow in Python, but Pandas interop is not a primary concern, and it's not production-ready today.

Comment:
FWIW pandas 3.x is going to strongly incentivize users to install PyArrow, although it stops short of outright requiring it. In theory, the only people that shouldn't have PyArrow installed are those that operate in space/resource constrained environments, probably in headless environments like AWS Lambda where seaborn won't be used

Of course up to you how much you want to support non-PyArrow configurations, but the dataframe interchange protocol is relatively buggy and gets very little support, so you may find it easier altogether to force users towards PyArrow

Contributor Author (MarcoGorelli):

cuDF have said that they will deprecate the interchange format: rapidsai/cudf#17282

Plotly have stopped using it, so Seaborn is the only library left using it

At this point, I think there's a greater risk in keeping it - I don't want to force anything here of course, just making sure you're aware

Contributor Author (MarcoGorelli):

Could you please clarify what you mean by "if the infra isn't there yet"?

The Arrow C Interface already has quite widespread adoption and I'm not aware of edge cases in its implementations. @WillAyd wrote about switching his Pantab project over to it in Leveraging the Arrow C Data Interface, and noted:

> Almost immediately my issues went away [...] I felt more confident in the implementation and had to deal with less memory corruption / crashes than before. And, perhaps most importantly, I saved a lot of time.

That was nearly a year ago, and given that he's now suggesting it here in Plotly, I'd say that his experience has stayed just as positive


Regarding the PyArrow dependency, I'll also note that polars.DataFrame.to_pandas requires PyArrow, so any Polars user (such as myself) would already have needed PyArrow installed if they were converting to pandas via the official Polars method

Owner:

Basically my threshold is "do I need to think about it at all". I'm just not interested in the minutia of competing Python dataframe libraries or the various attempts to make them work better together. The previous approach was sold as a simple protocol that always works, but it turns out that wasn't the case. Maybe this new way is better, the problem is I have no real way to say for sure without spending a lot of time learning about something that doesn't interest me.

Contributor Author (MarcoGorelli):

Shall I close and leave you to remove cross-dataframe compatibility altogether?

Owner:

The problem is then I get issues bugging me about Polars, so I have to think about it anyway :D

Contributor Author (@MarcoGorelli), Nov 26, 2024:

😄 that's understandable

I'm aware that you said that using Narwhals was a complete non-starter, but just to showcase that as a possibility:

import narwhals.stable.v1 as nw
from narwhals.stable.v1.typing import IntoDataFrame
import polars as pl
import pandas as pd

def convert_dataframe_to_pandas(data: IntoDataFrame) -> pd.DataFrame:
    return nw.from_native(data).to_pandas()

and then leave it up to Narwhals to convert to pandas in the best way for each input library

Altair, Plotly, and Vegafusion are using it as a required dependency now, and Bokeh have a PR in progress to do the same


For completeness: the way the other libraries are using Narwhals is by making the whole logic dataframe-agnostic. In Plotly this resulted in 2-3x better performance for many plots involving group-bys (compared with converting all inputs to pandas), but I understand that you may not be interested in that

+            raise RuntimeError(msg) from err
+        if _version_predates(pyarrow, '14.0.0'):
+            msg = "PyArrow>=14.0.0 is required for non-pandas Dataframe support."
+            raise RuntimeError(msg)
+        try:
+            # This is going to convert all columns in the input dataframe, even though
+            # we may only need one or two of them. It would be more efficient to select
+            # the columns that are going to be used in the plot prior to interchange.
+            # Solving that in general is a hard problem, especially with the objects
+            # interface where variables passed in Plot() may only be referenced later
+            # in Plot.add(). But noting here in case this seems to be a bottleneck.
+            return pyarrow.table(data).to_pandas()
+        except Exception as err:
+            msg = (
+                "Encountered an exception when converting data source "
+                "to a pandas DataFrame. See traceback above for details."
+            )
+            raise RuntimeError(msg) from err
+
+    msg = (
+        "Expected object which implements '__arrow_c_stream__' from the "
+        f"PyCapsule Interface, got: {type(data)}"
+    )
+    raise TypeError(msg)
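As a rough, self-contained sketch of the control flow added above: `convert_via_capsule` is a hypothetical name, and `table_factory` stands in for `pyarrow.table` so the example runs without pyarrow installed. The error messages mirror the ones in the diff.

```python
def convert_via_capsule(data, table_factory):
    # Objects exporting __arrow_c_stream__ are built into an Arrow table
    # and converted to pandas; anything else is rejected with a TypeError.
    if not hasattr(data, "__arrow_c_stream__"):
        msg = (
            "Expected object which implements '__arrow_c_stream__' from the "
            f"PyCapsule Interface, got: {type(data)}"
        )
        raise TypeError(msg)
    try:
        return table_factory(data).to_pandas()
    except Exception as err:
        msg = (
            "Encountered an exception when converting data source "
            "to a pandas DataFrame. See traceback above for details."
        )
        raise RuntimeError(msg) from err
```

With pyarrow installed, the real call would be `pyarrow.table(data).to_pandas()`, where `pyarrow.table` consumes the capsule that `data.__arrow_c_stream__()` produces.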
2 changes: 1 addition & 1 deletion seaborn/_core/plot.py
@@ -349,7 +349,7 @@ def _resolve_positionals(

         if (
             isinstance(args[0], (abc.Mapping, pd.DataFrame))
-            or hasattr(args[0], "__dataframe__")
+            or hasattr(args[0], "__arrow_c_stream__")
         ):
             if data is not None:
                 raise TypeError("`data` given by both name and position.")
4 changes: 2 additions & 2 deletions seaborn/_core/typing.py
@@ -17,9 +17,9 @@
 VariableSpec = Union[ColumnName, Vector, None]
 VariableSpecList = Union[List[VariableSpec], Index, None]
 
-# A DataSource can be an object implementing __dataframe__, or a Mapping
+# A DataSource can be an object implementing __arrow_c_stream__, or a Mapping
 # (and is optional in all contexts where it is used).
-# I don't think there's an abc for "has __dataframe__", so we type as object
+# I don't think there's an abc for "has __arrow_c_stream__", so we type as object
 # but keep the (slightly odd) Union alias for better user-facing annotations.
 DataSource = Union[object, Mapping, None]

28 changes: 15 additions & 13 deletions tests/_core/test_data.py
@@ -1,6 +1,7 @@
 import functools
 import numpy as np
 import pandas as pd
+from seaborn.external.version import Version
 
 import pytest
 from numpy.testing import assert_array_equal
@@ -404,11 +405,11 @@ def test_bad_type(self, flat_list):
         with pytest.raises(TypeError, match=err):
             PlotData(flat_list, {})
 
-    @pytest.mark.skipif(
-        condition=not hasattr(pd.api, "interchange"),
-        reason="Tests behavior assuming support for dataframe interchange"
-    )
     def test_data_interchange(self, mock_long_df, long_df):
+        pytest.importorskip(
Owner: Nice, TIL

+            'pyarrow', '14.0',
+            reason="Tests behavior assuming support for PyCapsule Interface"
+        )

         variables = {"x": "x", "y": "z", "color": "a"}
         p = PlotData(mock_long_df, variables)

@@ -419,21 +420,22 @@ def test_data_interchange(self, mock_long_df, long_df):

         for var, col in variables.items():
             assert_vector_equal(p.frame[var], long_df[col])

-    @pytest.mark.skipif(
-        condition=not hasattr(pd.api, "interchange"),
-        reason="Tests behavior assuming support for dataframe interchange"
-    )
     def test_data_interchange_failure(self, mock_long_df):
+        pytest.importorskip(
+            'pyarrow', '14.0',
+            reason="Tests behavior assuming support for PyCapsule Interface"
+        )
 
-        mock_long_df._data = None  # Break __dataframe__()
+        mock_long_df.__arrow_c_stream__ = lambda _x: 1 / 0  # Break __arrow_c_stream__()
         with pytest.raises(RuntimeError, match="Encountered an exception"):
             PlotData(mock_long_df, {"x": "x"})

-    @pytest.mark.skipif(
-        condition=hasattr(pd.api, "interchange"),
-        reason="Tests graceful failure without support for dataframe interchange"
-    )
     def test_data_interchange_support_test(self, mock_long_df):
+        pyarrow = pytest.importorskip('pyarrow')
+        if Version(pyarrow.__version__) >= Version('14.0.0'):
+            pytest.skip(
+                reason="Tests graceful failure without support for PyCapsule Interface"
+            )
 
         with pytest.raises(TypeError, match="Support for non-pandas DataFrame"):
             PlotData(mock_long_df, {"x": "x"})
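The version gates in this diff (`_version_predates(pyarrow, '14.0.0')` and the `Version` comparison above) rely on seaborn's vendored version machinery. A naive sketch of that kind of check, assuming plain dotted release strings (the real helper also handles pre-release tags, which this deliberately does not):

```python
def version_predates(actual: str, required: str) -> bool:
    # Compare dotted release versions numerically, e.g. "13.0.0" < "14.0.0".
    # String comparison would get this wrong ("13" > "101" lexically).
    def parse(v: str) -> tuple[int, ...]:
        return tuple(int(part) for part in v.split("."))

    return parse(actual) < parse(required)
```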
8 changes: 4 additions & 4 deletions tests/_core/test_plot.py
@@ -170,11 +170,11 @@ def test_positional_x(self, long_df):
         assert p._data.source_data is None
         assert list(p._data.source_vars) == ["x"]
 
-    @pytest.mark.skipif(
-        condition=not hasattr(pd.api, "interchange"),
-        reason="Tests behavior assuming support for dataframe interchange"
-    )
     def test_positional_interchangeable_dataframe(self, mock_long_df, long_df):
+        pytest.importorskip(
+            'pyarrow', '14.0',
+            reason="Tests behavior assuming support for PyCapsule Interface"
+        )
 
         p = Plot(mock_long_df, x="x")
         assert_frame_equal(p._data.source_data, long_df)
7 changes: 4 additions & 3 deletions tests/conftest.py
@@ -188,11 +188,12 @@ class MockInterchangeableDataFrame:
     def __init__(self, data):
         self._data = data
 
-    def __dataframe__(self, *args, **kwargs):
-        return self._data.__dataframe__(*args, **kwargs)
+    def __arrow_c_stream__(self, *args, **kwargs):
+        return self._data.__arrow_c_stream__()
 
 
 @pytest.fixture
 def mock_long_df(long_df):
+    import pyarrow
 
-    return MockInterchangeableDataFrame(long_df)
+    return MockInterchangeableDataFrame(pyarrow.Table.from_pandas(long_df))
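To see what the updated fixture exercises, here is a dependency-free version of the same delegation pattern. `FakeArrowTable` is a hypothetical stand-in for the `pyarrow.Table` the real fixture wraps; the point is that the mock is not a dataframe itself and forwards only `__arrow_c_stream__`.

```python
class FakeArrowTable:
    # Stand-in for pyarrow.Table: the only capability the mock needs
    # to forward is the Arrow C stream export.
    def __arrow_c_stream__(self, requested_schema=None):
        return "capsule-sentinel"  # a real table returns a PyCapsule


class MockFrame:
    # Same shape as MockInterchangeableDataFrame in conftest.py: a thin
    # wrapper that delegates the capsule export to the wrapped object.
    def __init__(self, data):
        self._data = data

    def __arrow_c_stream__(self, *args, **kwargs):
        return self._data.__arrow_c_stream__()


mock = MockFrame(FakeArrowTable())
```

Because detection in `handle_data_source` is a plain `hasattr` check, this duck-typed wrapper is enough for seaborn to treat the mock as a convertible dataframe.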
16 changes: 8 additions & 8 deletions tests/test_axisgrid.py
@@ -708,11 +708,11 @@ def test_tick_params(self):
             assert mpl.colors.same_color(tick.tick2line.get_color(), color)
             assert tick.get_pad() == pad
 
-    @pytest.mark.skipif(
-        condition=not hasattr(pd.api, "interchange"),
-        reason="Tests behavior assuming support for dataframe interchange"
-    )
     def test_data_interchange(self, mock_long_df, long_df):
+        pytest.importorskip(
+            'pyarrow', '14.0',
+            reason="Tests behavior assuming support for PyCapsule Interface"
+        )
 
         g = ag.FacetGrid(mock_long_df, col="a", row="b")
         g.map(scatterplot, "x", "y")
@@ -1477,11 +1477,11 @@ def test_tick_params(self):
             assert mpl.colors.same_color(tick.tick2line.get_color(), color)
             assert tick.get_pad() == pad
 
-    @pytest.mark.skipif(
-        condition=not hasattr(pd.api, "interchange"),
-        reason="Tests behavior assuming support for dataframe interchange"
-    )
     def test_data_interchange(self, mock_long_df, long_df):
+        pytest.importorskip(
+            'pyarrow', '14.0',
+            reason="Tests behavior assuming support for PyCapsule Interface"
+        )
 
         g = ag.PairGrid(mock_long_df, vars=["x", "y", "z"], hue="a")
         g.map(scatterplot)