From c984002d500f42ea14ba38e087f0c746af5e3722 Mon Sep 17 00:00:00 2001 From: Marco Edward Gorelli Date: Wed, 21 Aug 2024 21:25:53 +0100 Subject: [PATCH] fix: Pass native dataframe to data transformers (#3550) Co-authored-by: dangotbanned <125183946+dangotbanned@users.noreply.github.com> --- altair/utils/_vegafusion_data.py | 6 --- altair/utils/data.py | 76 ++++++++++++++++---------------- altair/vegalite/v5/api.py | 3 +- pyproject.toml | 2 +- tests/utils/test_data.py | 23 ++++++---- 5 files changed, 57 insertions(+), 53 deletions(-) diff --git a/altair/utils/_vegafusion_data.py b/altair/utils/_vegafusion_data.py index 99779b62e..970098d33 100644 --- a/altair/utils/_vegafusion_data.py +++ b/altair/utils/_vegafusion_data.py @@ -13,8 +13,6 @@ ) from weakref import WeakValueDictionary -import narwhals.stable.v1 as nw - from altair.utils._importers import import_vegafusion from altair.utils.core import DataFrameLike from altair.utils.data import ( @@ -71,10 +69,6 @@ def vegafusion_data_transformer( data: DataType | None = None, max_rows: int = 100000 ) -> Callable[..., Any] | _VegaFusionReturnType: """VegaFusion Data Transformer.""" - # Vegafusion does not support Narwhals, so if `data` is a Narwhals - # object, we make sure to extract the native object and let Vegafusion handle it. - # `strict=False` passes `data` through as-is if it is not a Narwhals object. - data = nw.to_native(data, strict=False) if data is None: return vegafusion_data_transformer elif isinstance(data, DataFrameLike) and not isinstance(data, SupportsGeoInterface): diff --git a/altair/utils/data.py b/altair/utils/data.py index 1986ec8c5..42c87ece4 100644 --- a/altair/utils/data.py +++ b/altair/utils/data.py @@ -314,12 +314,7 @@ def to_values(data: DataType) -> ToValuesReturnType: # `strict=False` passes `data` through as-is if it is not a Narwhals object. 
data_native = nw.to_native(data, strict=False) if isinstance(data_native, SupportsGeoInterface): - if _is_pandas_dataframe(data_native): - data_native = sanitize_pandas_dataframe(data_native) - # Maybe the type could be further clarified here that it is - # SupportGeoInterface and then the ignore statement is not needed? - data_sanitized = sanitize_geo_interface(data_native.__geo_interface__) - return {"values": data_sanitized} + return {"values": _from_geo_interface(data_native)} elif _is_pandas_dataframe(data_native): data_native = sanitize_pandas_dataframe(data_native) return {"values": data_native.to_dict(orient="records")} @@ -350,32 +345,45 @@ def _compute_data_hash(data_str: str) -> str: return hashlib.sha256(data_str.encode()).hexdigest()[:32] +def _from_geo_interface(data: SupportsGeoInterface | Any) -> dict[str, Any]: + """ + Sanitize a ``__geo_interface__`` w/ pre-sanitize step for ``pandas`` if needed. + + Notes + ----- + Split out to resolve typing issues related to: + - Intersection types + - ``typing.TypeGuard`` + - ``pd.DataFrame.__getattr__`` + """ + if _is_pandas_dataframe(data): + data = sanitize_pandas_dataframe(data) + return sanitize_geo_interface(data.__geo_interface__) + + def _data_to_json_string(data: DataType) -> str: """Return a JSON string representation of the input data.""" check_data_type(data) - # `strict=False` passes `data` through as-is if it is not a Narwhals object. 
- data_native = nw.to_native(data, strict=False) - if isinstance(data_native, SupportsGeoInterface): - if _is_pandas_dataframe(data_native): - data_native = sanitize_pandas_dataframe(data_native) - data_native = sanitize_geo_interface(data_native.__geo_interface__) - return json.dumps(data_native) - elif _is_pandas_dataframe(data_native): - data = sanitize_pandas_dataframe(data_native) - return data_native.to_json(orient="records", double_precision=15) - elif isinstance(data_native, dict): - if "values" not in data_native: + if isinstance(data, SupportsGeoInterface): + return json.dumps(_from_geo_interface(data)) + elif _is_pandas_dataframe(data): + data = sanitize_pandas_dataframe(data) + return data.to_json(orient="records", double_precision=15) + elif isinstance(data, dict): + if "values" not in data: msg = "values expected in data dict, but not present." raise KeyError(msg) - return json.dumps(data_native["values"], sort_keys=True) - elif isinstance(data, nw.DataFrame): - return json.dumps(data.rows(named=True)) - else: - msg = "to_json only works with data expressed as " "a DataFrame or as a dict" - raise NotImplementedError(msg) + return json.dumps(data["values"], sort_keys=True) + try: + data_nw = nw.from_native(data, eager_only=True) + except TypeError as exc: + msg = "to_json only works with data expressed as a DataFrame or as a dict" + raise NotImplementedError(msg) from exc + data_nw = sanitize_narwhals_dataframe(data_nw) + return json.dumps(data_nw.rows(named=True)) -def _data_to_csv_string(data: dict | pd.DataFrame | DataFrameLike) -> str: +def _data_to_csv_string(data: DataType) -> str: """Return a CSV string representation of the input data.""" check_data_type(data) if isinstance(data, SupportsGeoInterface): @@ -398,18 +406,12 @@ def _data_to_csv_string(data: dict | pd.DataFrame | DataFrameLike) -> str: msg = "pandas is required to convert a dict to a CSV string" raise ImportError(msg) from exc return 
pd.DataFrame.from_dict(data["values"]).to_csv(index=False) - elif isinstance(data, DataFrameLike): - # experimental interchange dataframe support - import pyarrow as pa - import pyarrow.csv as pa_csv - - pa_table = arrow_table_from_dfi_dataframe(data) - csv_buffer = pa.BufferOutputStream() - pa_csv.write_csv(pa_table, csv_buffer) - return csv_buffer.getvalue().to_pybytes().decode() - else: - msg = "to_csv only works with data expressed as " "a DataFrame or as a dict" - raise NotImplementedError(msg) + try: + data_nw = nw.from_native(data, eager_only=True) + except TypeError as exc: + msg = "to_csv only works with data expressed as a DataFrame or as a dict" + raise NotImplementedError(msg) from exc + return data_nw.write_csv() def arrow_table_from_dfi_dataframe(dfi_df: DataFrameLike) -> pa.Table: diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index e77d62e00..d352b060b 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -23,6 +23,7 @@ from typing_extensions import TypeAlias import jsonschema +import narwhals.stable.v1 as nw from altair import utils from altair.expr import core as _expr_core @@ -274,7 +275,7 @@ def _prepare_data( # convert dataframes or objects with __geo_interface__ to dict elif not isinstance(data, dict) and _is_data_type(data): if func := data_transformers.get(): - data = func(data) + data = func(nw.to_native(data, strict=False)) # convert string input to a URLData elif isinstance(data, str): diff --git a/pyproject.toml b/pyproject.toml index 4a0c3874c..4c85070a1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,7 +20,7 @@ dependencies = [ # If you update the minimum required jsonschema version, also update it in build.yml "jsonschema>=3.0", "packaging", - "narwhals>=1.1.0" + "narwhals>=1.5.2" ] description = "Vega-Altair: A declarative statistical visualization library for Python." 
readme = "README.md" diff --git a/tests/utils/test_data.py b/tests/utils/test_data.py index 673c2852c..30b5b7f8e 100644 --- a/tests/utils/test_data.py +++ b/tests/utils/test_data.py @@ -1,5 +1,7 @@ +from __future__ import annotations + from pathlib import Path -from typing import Any, Callable +from typing import Any, Callable, SupportsIndex, TypeVar import narwhals.stable.v1 as nw import pandas as pd @@ -15,6 +17,8 @@ to_values, ) +T = TypeVar("T") + def _pipe(data: Any, *funcs: Callable[..., Any]) -> Any: # Redefined to maintain existing tests @@ -24,13 +28,15 @@ def _pipe(data: Any, *funcs: Callable[..., Any]) -> Any: return data -def _create_dataframe(N): - data = pd.DataFrame({"x": range(N), "y": range(N)}) +def _create_dataframe( + n: SupportsIndex, /, tp: Callable[..., T] | type[Any] = pd.DataFrame +) -> T | Any: + data = tp({"x": range(n), "y": range(n)}) return data -def _create_data_with_values(N): - data = {"values": [{"x": i, "y": i + 1} for i in range(N)]} +def _create_data_with_values(n: SupportsIndex, /) -> dict[str, Any]: + data = {"values": [{"x": i, "y": i + 1} for i in range(n)]} return data @@ -127,19 +133,20 @@ def test_dict_to_json(): assert data == {"values": output} -def test_dataframe_to_csv(): +@pytest.mark.parametrize("tp", [pd.DataFrame, pl.DataFrame], ids=["pandas", "polars"]) +def test_dataframe_to_csv(tp: type[Any]) -> None: """ Test to_csv with dataframe input. - make certain the filename is deterministic - make certain the file contents match the data. """ - data = _create_dataframe(10) + data = _create_dataframe(10, tp=tp) try: result1 = _pipe(data, to_csv) result2 = _pipe(data, to_csv) filename = result1["url"] - output = pd.read_csv(filename) + output = tp(pd.read_csv(filename)) finally: Path(filename).unlink()