diff --git a/py-polars/polars/dataframe/plotting.py b/py-polars/polars/dataframe/plotting.py index ed118e504656..ac787afd3882 100644 --- a/py-polars/polars/dataframe/plotting.py +++ b/py-polars/polars/dataframe/plotting.py @@ -6,15 +6,13 @@ import sys import altair as alt - from altair.typing import ( - ChannelColor, - ChannelOrder, - ChannelSize, - ChannelTooltip, - ChannelX, - ChannelY, - EncodeKwds, - ) + from altair.typing import ChannelColor as Color + from altair.typing import ChannelOrder as Order + from altair.typing import ChannelSize as Size + from altair.typing import ChannelTooltip as Tooltip + from altair.typing import ChannelX as X + from altair.typing import ChannelY as Y + from altair.typing import EncodeKwds from polars import DataFrame @@ -29,12 +27,15 @@ Encodings: TypeAlias = Dict[ str, - Union[ - ChannelX, ChannelY, ChannelColor, ChannelOrder, ChannelSize, ChannelTooltip - ], + Union[X, Y, Color, Order, Size, Tooltip], ] +def _add_tooltip(encodings: Encodings, /, **kwargs: Unpack[EncodeKwds]) -> None: + if "tooltip" not in kwargs: + encodings["tooltip"] = [*encodings.values(), *kwargs.values()] # type: ignore[assignment] + + class DataFramePlot: """DataFrame.plot namespace.""" @@ -45,10 +46,9 @@ def __init__(self, df: DataFrame) -> None: def bar( self, - x: ChannelX | None = None, - y: ChannelY | None = None, - color: ChannelColor | None = None, - tooltip: ChannelTooltip | None = None, + x: X | None = None, + y: Y | None = None, + color: Color | None = None, /, **kwargs: Unpack[EncodeKwds], ) -> alt.Chart: @@ -77,8 +77,6 @@ def bar( Column with y-coordinates of bars. color Column to color bars by. - tooltip - Columns to show values of when hovering over bars with pointer. **kwargs Additional keyword arguments passed to Altair. @@ -102,17 +100,15 @@ def bar( encodings["y"] = y if color is not None: encodings["color"] = color - if tooltip is not None: - encodings["tooltip"] = tooltip + _add_tooltip(encodings, **kwargs) return self._chart.mark_bar().encode(**encodings, **kwargs).interactive() def line( self, - x: ChannelX | None = None, - y: ChannelY | None = None, - color: ChannelColor | None = None, - order: ChannelOrder | None = None, - tooltip: ChannelTooltip | None = None, + x: X | None = None, + y: Y | None = None, + color: Color | None = None, + order: Order | None = None, /, **kwargs: Unpack[EncodeKwds], ) -> alt.Chart: @@ -142,8 +138,6 @@ def line( Column to color lines by. order Column to use for order of data points in lines. - tooltip - Columns to show values of when hovering over lines with pointer. **kwargs Additional keyword arguments passed to Altair. @@ -168,17 +162,15 @@ def line( encodings["color"] = color if order is not None: encodings["order"] = order - if tooltip is not None: - encodings["tooltip"] = tooltip + _add_tooltip(encodings, **kwargs) return self._chart.mark_line().encode(**encodings, **kwargs).interactive() def point( self, - x: ChannelX | None = None, - y: ChannelY | None = None, - color: ChannelColor | None = None, - size: ChannelSize | None = None, - tooltip: ChannelTooltip | None = None, + x: X | None = None, + y: Y | None = None, + color: Color | None = None, + size: Size | None = None, /, **kwargs: Unpack[EncodeKwds], ) -> alt.Chart: @@ -209,8 +201,6 @@ def point( Column to color points by. size Column which determines points' sizes. - tooltip - Columns to show values of when hovering over points with pointer. **kwargs Additional keyword arguments passed to Altair. @@ -234,8 +224,7 @@ def point( encodings["color"] = color if size is not None: encodings["size"] = size - if tooltip is not None: - encodings["tooltip"] = tooltip + _add_tooltip(encodings, **kwargs) return ( self._chart.mark_point() .encode( @@ -253,4 +242,10 @@ def __getattr__(self, attr: str) -> Callable[..., alt.Chart]: if method is None: msg = "Altair has no method 'mark_{attr}'" raise AttributeError(msg) - return lambda **kwargs: method().encode(**kwargs).interactive() + encodings: Encodings = {} + + def func(**kwargs: EncodeKwds) -> alt.Chart: + _add_tooltip(encodings, **kwargs) + return method().encode(**encodings, **kwargs).interactive() + + return func diff --git a/py-polars/polars/series/plotting.py b/py-polars/polars/series/plotting.py index cb5c6c93a1e1..5430d55c6ff3 100644 --- a/py-polars/polars/series/plotting.py +++ b/py-polars/polars/series/plotting.py @@ -2,6 +2,7 @@ from typing import TYPE_CHECKING, Callable +from polars.dataframe.plotting import _add_tooltip from polars.dependencies import altair as alt if TYPE_CHECKING: @@ -9,6 +10,8 @@ from altair.typing import EncodeKwds + from polars.dataframe.plotting import Encodings + if sys.version_info >= (3, 11): from typing import Unpack else: @@ -62,11 +65,13 @@ def hist( if self._series_name == "count()": msg = "Cannot use `plot.hist` when Series name is `'count()'`" raise ValueError(msg) + encodings: Encodings = { + "x": alt.X(f"{self._series_name}:Q", bin=True), + "y": "count()", + } + _add_tooltip(encodings, **kwargs) return ( - alt.Chart(self._df) - .mark_bar() - .encode(x=alt.X(f"{self._series_name}:Q", bin=True), y="count()", **kwargs) # type: ignore[misc] - .interactive() + alt.Chart(self._df).mark_bar().encode(**encodings, **kwargs).interactive() ) def kde( @@ -104,11 +109,13 @@ def kde( if self._series_name == "density": msg = "Cannot use `plot.kde` when Series name is `'density'`" raise ValueError(msg) + encodings: Encodings = {"x": self._series_name, "y": "density:Q"} + _add_tooltip(encodings, **kwargs) return ( alt.Chart(self._df) .transform_density(self._series_name, as_=[self._series_name, "density"]) .mark_area() - .encode(x=self._series_name, y="density:Q", **kwargs) # type: ignore[misc] + .encode(**encodings, **kwargs) .interactive() ) @@ -147,10 +154,12 @@ def line( if self._series_name == "index": msg = "Cannot call `plot.line` when Series name is 'index'" raise ValueError(msg) + encodings: Encodings = {"x": "index", "y": self._series_name} + _add_tooltip(encodings, **kwargs) return ( alt.Chart(self._df.with_row_index()) .mark_line() - .encode(x="index", y=self._series_name, **kwargs) # type: ignore[misc] + .encode(**encodings, **kwargs) .interactive() ) @@ -165,8 +174,10 @@ def __getattr__(self, attr: str) -> Callable[..., alt.Chart]: if method is None: msg = "Altair has no method 'mark_{attr}'" raise AttributeError(msg) - return ( - lambda **kwargs: method() - .encode(x="index", y=self._series_name, **kwargs) - .interactive() - ) + encodings: Encodings = {"x": "index", "y": self._series_name} + + def func(**kwargs: EncodeKwds) -> alt.Chart: + _add_tooltip(encodings, **kwargs) + return method().encode(**encodings, **kwargs).interactive() + + return func diff --git a/py-polars/tests/unit/operations/namespaces/test_plot.py b/py-polars/tests/unit/operations/namespaces/test_plot.py index fc2fbc02648a..5a4c1c21a596 100644 --- a/py-polars/tests/unit/operations/namespaces/test_plot.py +++ b/py-polars/tests/unit/operations/namespaces/test_plot.py @@ -17,6 +17,29 @@ def test_dataframe_plot() -> None: df.plot.area(x="length", y="width", color="species").to_json() +def test_dataframe_plot_tooltip() -> None: + df = pl.DataFrame( + { + "length": [1, 4, 6], + "width": [4, 5, 6], + "species": ["setosa", "setosa", "versicolor"], + } + ) + result = df.plot.line(x="length", y="width", color="species").to_dict() + assert result["encoding"]["tooltip"] == [ + {"field": "length", "type": "quantitative"}, + {"field": "width", "type": "quantitative"}, + {"field": "species", "type": "nominal"}, + ] + result = df.plot.line( + x="length", y="width", color="species", tooltip=["length", "width"] + ).to_dict() + assert result["encoding"]["tooltip"] == [ + {"field": "length", "type": "quantitative"}, + {"field": "width", "type": "quantitative"}, + ] + + def test_series_plot() -> None: # dry-run, check nothing errors s = pl.Series("a", [1, 4, 4, 4, 7, 2, 5, 3, 6]) @@ -26,6 +49,17 @@ def test_series_plot() -> None: s.plot.point().to_json() +def test_series_plot_tooltip() -> None: + s = pl.Series("a", [1, 4, 4, 4, 7, 2, 5, 3, 6]) + result = s.plot.line().to_dict() + assert result["encoding"]["tooltip"] == [ + {"field": "index", "type": "quantitative"}, + {"field": "a", "type": "quantitative"}, + ] + result = s.plot.line(tooltip=["a"]).to_dict() + assert result["encoding"]["tooltip"] == [{"field": "a", "type": "quantitative"}] + + def test_empty_dataframe() -> None: pl.DataFrame({"a": [], "b": []}).plot.point(x="a", y="b")