diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py index 9a6feb667..f5e15f30c 100644 --- a/altair/vegalite/v5/api.py +++ b/altair/vegalite/v5/api.py @@ -4070,25 +4070,35 @@ def add_selection(self, *params) -> Self: # noqa: ANN002 return self.add_params(*params) def interactive( - self, name: str | None = None, bind_x: bool = True, bind_y: bool = True - ) -> Self: + self, + name: str | None = None, + bind_x: bool = True, + bind_y: bool = True, + tooltip: bool = True, + legend: bool | LegendChannel_T = False, + ) -> Chart: """ - Make chart axes scales interactive. + Add common interactive elements to the chart. Parameters ---------- - name : string + name : string or None The parameter name to use for the axes scales. This name should be - unique among all parameters within the chart. + unique among all parameters within the chart bind_x : boolean, default True - If true, then bind the interactive scales to the x-axis + Bind the interactive scales to the x-axis bind_y : boolean, default True - If true, then bind the interactive scales to the y-axis + Bind the interactive scales to the y-axis + tooltip : boolean, default True, + Add a tooltip containing the encodings used in the chart + legend : boolean or string, default True + A single encoding channel to be used to create a clickable legend. + The deafult is to guess from the spec based on the most commonly used legend encodings. Returns ------- chart : - copy of self, with interactive axes added + copy of self, with interactivity added """ encodings: list[SingleDefUnitChannel_T] = [] @@ -4096,7 +4106,81 @@ def interactive( encodings.append("x") if bind_y: encodings.append("y") - return self.add_params(selection_interval(bind="scales", encodings=encodings)) + chart: Chart = self.copy().add_params( + selection_interval(bind="scales", encodings=encodings) + ) + # We can't simply use configure_mark since configure methods + # are not allowed in layered specs + if tooltip: + chart = _add_tooltip(chart) + legend_encodings_missing = utils.is_undefined(chart.encoding) + if legend and not legend_encodings_missing: + facet_encoding: FacetedEncoding = chart.encoding + if not isinstance(legend, str): + legend = _infer_legend_encoding(facet_encoding) + + facet_legend = facet_encoding[legend] + legend_type = facet_legend["type"] + if utils.is_undefined(legend_type): + legend_type = facet_legend.to_dict(context={"data": chart.data})["type"] + + if legend_type == "nominal": + # TODO Ideally this would work for ordinal data too + legend_selection = selection_point(bind="legend", encodings=[legend]) + initial_computed_domain = param(expr=f"domain('{legend}')") + nonreactive_domain = param( + react=False, expr=initial_computed_domain.name + ) + scale = facet_legend["scale"] + if utils.is_undefined(scale): + scale = {"domain": nonreactive_domain} + else: + scale["domain"] = nonreactive_domain + chart = chart.add_params( + legend_selection, + initial_computed_domain, + nonreactive_domain, + ).transform_filter(legend_selection) + else: + msg = f"Expected only 'nominal' legend type but got {legend_type!r}" + raise NotImplementedError(msg) + return chart + + +LegendChannel_T: TypeAlias = Literal[ + "color", + "fill", + "shape", + "stroke", + "opacity", + "fillOpacity", + "strokeOpacity", + "strokeWidth", + "strokeDash", + "angle", # TODO Untested + "radius", # TODO Untested + "radius2", # TODO Untested + # "size", # TODO Currently size is not working, renders empty legend +] + + +def _add_tooltip(chart: _TChart, /) -> _TChart: + if isinstance(chart.mark, str): + chart.mark = {"type": chart.mark, "tooltip": True} + else: + chart.mark.tooltip = True + return chart + + +def _infer_legend_encoding(encoding: FacetedEncoding, /) -> LegendChannel_T: + """Set the legend to commonly used encodings by default.""" + _channels = t.get_args(LegendChannel_T) + it = (ch for ch in _channels if not utils.is_undefined(encoding[ch])) + if legend := next(it, None): + return legend + else: + msg = f"Unable to infer target channel for 'legend'.\n\n{encoding!r}" + raise NotImplementedError(msg) def _check_if_valid_subspec( @@ -5176,6 +5260,16 @@ def sphere() -> SphereGenerator: return core.SphereGenerator(sphere=True) +_TChart = TypeVar( + "_TChart", + Chart, + RepeatChart, + ConcatChart, + HConcatChart, + VConcatChart, + FacetChart, + LayerChart, +) ChartType: TypeAlias = Union[ Chart, RepeatChart, ConcatChart, HConcatChart, VConcatChart, FacetChart, LayerChart ] diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py index 390a7217f..7d733a12b 100644 --- a/tests/vegalite/v5/test_api.py +++ b/tests/vegalite/v5/test_api.py @@ -15,7 +15,7 @@ from datetime import date, datetime from importlib.metadata import version as importlib_version from importlib.util import find_spec -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Literal import duckdb import jsonschema @@ -33,14 +33,19 @@ if TYPE_CHECKING: from typing import Any + from altair.typing import ChartType, Optional from altair.vegalite.v5.api import _Conditional, _Conditions - from altair.vegalite.v5.schema._typing import Map + from altair.vegalite.v5.schema._typing import Map, SingleDefUnitChannel_T PANDAS_VERSION = Version(importlib_version("pandas")) +_MakeType = Literal[ + "layer", "hconcat", "vconcat", "concat", "facet", "facet_encoding", "repeat" +] -def getargs(*args, **kwargs): + +def getargs(*args, **kwargs) -> tuple[tuple[Any, ...], dict[str, Any]]: return args, kwargs @@ -51,7 +56,7 @@ def getargs(*args, **kwargs): } -def _make_chart_type(chart_type): +def _make_chart_type(chart_type: _MakeType) -> ChartType: data = pd.DataFrame( { "x": [28, 55, 43, 91, 81, 53, 19, 87], @@ -85,6 +90,22 @@ def _make_chart_type(chart_type): raise ValueError(msg) +@pytest.fixture( + params=( + "layer", + "hconcat", + "vconcat", + "concat", + "facet", + "facet_encoding", + "repeat", + ) +) +def all_chart_types(request) -> ChartType: + """Use the parameter name ``all_chart_types`` to automatically parameterise.""" + return _make_chart_type(request.param) + + @pytest.fixture def basic_chart() -> alt.Chart: data = pd.DataFrame( @@ -97,6 +118,18 @@ def basic_chart() -> alt.Chart: return alt.Chart(data).mark_bar().encode(x="a", y="b") +@pytest.fixture +def color_data() -> pl.DataFrame: + """10 rows, 3 columns ``"x:Q"``, ``"y:Q"``, ``"color:(N|O)"``.""" + return pl.DataFrame( + { + "x": [0.32, 0.86, 0.10, 0.30, 0.47, 0.0, 1.0, 0.91, 0.88, 0.12], + "y": [0.11, 0.33, 0.01, 0.04, 0.77, 0.1, 0.2, 0.23, 0.05, 0.29], + "color": list("ACABBCABBA"), + } + ) + + @pytest.fixture def cars(): return pd.DataFrame( @@ -1840,3 +1873,118 @@ def old_binding(input: Any, **kwargs: Any) -> alt.Binding: # NOTE: Both type checkers can detect the issue on the new signature with pytest.raises(TypeError, match=MISSING_INPUT): alt.binding(placeholder="Country", name="Search") # type: ignore[call-arg] + + +# TODO These chart types don't all work yet +@pytest.mark.parametrize( + "chart_type", + [ + "chart", + pytest.param( + "layer", marks=pytest.mark.xfail(reason="Not Implemented", raises=TypeError) + ), + pytest.param( + "facet", marks=pytest.mark.xfail(reason="Not Implemented", raises=TypeError) + ), + ], +) +def test_interactive_for_chart_types(chart_type: _MakeType): + chart = _make_chart_type(chart_type) + assert chart.interactive(legend=True).to_dict() # type: ignore[call-arg] + + +def test_interactive_with_no_encoding(all_chart_types): + # Should not raise error when there is no encoding + assert all_chart_types.interactive().to_dict() + + +def test_interactive_tooltip_added_for_all_encodings(): + # TODO Loop through all possible encodings + # and check that tooltip interactivity is added for all of them + assert "tooltip" in alt.Chart().mark_point().interactive().to_dict()["mark"] + assert ( + "tooltip" + not in alt.Chart().mark_point().interactive(tooltip=False).to_dict()["mark"] + ) + + +@pytest.mark.parametrize( + ("encoding", "err"), + [ + ("xOffset", NotImplementedError), + ("yOffset", NotImplementedError), + ("x2", NotImplementedError), + ("y2", NotImplementedError), + ("longitude", NotImplementedError), + ("latitude", NotImplementedError), + ("longitude2", NotImplementedError), + ("latitude2", NotImplementedError), + ("theta", NotImplementedError), + ("theta2", NotImplementedError), + ("radius", None), + ("radius2", KeyError), + ("color", None), + ("fill", None), + ("stroke", None), + ("opacity", None), + ("fillOpacity", None), + ("strokeOpacity", None), + ("strokeWidth", None), + ("strokeDash", None), + ("size", NotImplementedError), + ("angle", None), + ("shape", None), + ("key", NotImplementedError), + ("text", NotImplementedError), + ("href", NotImplementedError), + ("url", NotImplementedError), + ("description", NotImplementedError), + ], +) +def test_interactive_legend_made_interactive_for_legend_encodings( + color_data, encoding: SingleDefUnitChannel_T, err: type[Exception] | None +) -> None: + """Check that legend interactivity is added only for the allowed legend encodings.""" + chart = ( + alt.Chart(color_data).mark_point().encode(x="x", y="y", **{encoding: "color"}) + ) + if err is None: + assert chart.interactive(legend=True).to_dict() + else: + with pytest.raises(err): + chart.interactive(legend=True).to_dict() + + +def test_interactive_legend_made_interactive_for_appropriate_encodings_types( + color_data, +) -> None: + chart = alt.Chart(color_data).mark_point().encode(x="x", y="y") + + # TODO Reverse legend=False/True once we revert the default arg to true + assert len(chart.encode(color="color:N").interactive().to_dict()["params"]) == 1 + chart_with_nominal_legend_encoding = ( + chart.encode(color="color:N").interactive(legend=True).to_dict() + ) + assert len(chart_with_nominal_legend_encoding["params"]) == 4 + for param in chart_with_nominal_legend_encoding["params"]: + if "expr" in param: + assert param["expr"] == "domain('color')" or "react" in param + + # TODO These tests currently don't work because we are raising + # when the legend is not nominal. To be updated if we change that behavior + # TODO Could change this first test if we get interactivity working with nominal + # chart_with_ordinal_legend_encoding = ( + # chart.encode(color="color:O").interactive(legend=True).to_dict() + # ) + # assert len(chart_with_ordinal_legend_encoding["params"]) == 1 + + # chart_with_quantitative_legend_encoding = ( + # chart.encode(color="color:Q").interactive(legend=True).to_dict() + # ) + # assert len(chart_with_quantitative_legend_encoding["params"]) == 1 + + +def test_interactive_legend_encoding_correctly_picked_from_multiple_encodings(): + # TODO The legend should be chosen based on the priority order + # in the list of possible legend encodings + ...