diff --git a/altair/utils/__init__.py b/altair/utils/__init__.py
index 216645c3f..b6855e1ee 100644
--- a/altair/utils/__init__.py
+++ b/altair/utils/__init__.py
@@ -1,6 +1,5 @@
 from .core import (
     SHORTHAND_KEYS,
-    SchemaBase,
     display_traceback,
     infer_encoding_types,
     infer_vegalite_type_for_pandas,
@@ -13,7 +12,7 @@
 from .deprecation import AltairDeprecationWarning, deprecated, deprecated_warn
 from .html import spec_to_html
 from .plugin_registry import PluginRegistry
-from .schemapi import Optional, Undefined, is_undefined
+from .schemapi import Optional, SchemaBase, Undefined, is_undefined
 
 __all__ = (
     "SHORTHAND_KEYS",
diff --git a/altair/utils/_vegafusion_data.py b/altair/utils/_vegafusion_data.py
index 99779b62e..970098d33 100644
--- a/altair/utils/_vegafusion_data.py
+++ b/altair/utils/_vegafusion_data.py
@@ -13,8 +13,6 @@
 )
 from weakref import WeakValueDictionary
 
-import narwhals.stable.v1 as nw
-
 from altair.utils._importers import import_vegafusion
 from altair.utils.core import DataFrameLike
 from altair.utils.data import (
@@ -71,10 +69,6 @@ def vegafusion_data_transformer(
     data: DataType | None = None, max_rows: int = 100000
 ) -> Callable[..., Any] | _VegaFusionReturnType:
     """VegaFusion Data Transformer."""
-    # Vegafusion does not support Narwhals, so if `data` is a Narwhals
-    # object, we make sure to extract the native object and let Vegafusion handle it.
-    # `strict=False` passes `data` through as-is if it is not a Narwhals object.
-    data = nw.to_native(data, strict=False)
     if data is None:
         return vegafusion_data_transformer
     elif isinstance(data, DataFrameLike) and not isinstance(data, SupportsGeoInterface):
diff --git a/altair/utils/data.py b/altair/utils/data.py
index 1986ec8c5..42c87ece4 100644
--- a/altair/utils/data.py
+++ b/altair/utils/data.py
@@ -314,12 +314,7 @@ def to_values(data: DataType) -> ToValuesReturnType:
     # `strict=False` passes `data` through as-is if it is not a Narwhals object.
     data_native = nw.to_native(data, strict=False)
     if isinstance(data_native, SupportsGeoInterface):
-        if _is_pandas_dataframe(data_native):
-            data_native = sanitize_pandas_dataframe(data_native)
-        # Maybe the type could be further clarified here that it is
-        # SupportGeoInterface and then the ignore statement is not needed?
-        data_sanitized = sanitize_geo_interface(data_native.__geo_interface__)
-        return {"values": data_sanitized}
+        return {"values": _from_geo_interface(data_native)}
     elif _is_pandas_dataframe(data_native):
         data_native = sanitize_pandas_dataframe(data_native)
         return {"values": data_native.to_dict(orient="records")}
@@ -350,32 +345,45 @@ def _compute_data_hash(data_str: str) -> str:
     return hashlib.sha256(data_str.encode()).hexdigest()[:32]
 
 
+def _from_geo_interface(data: SupportsGeoInterface | Any) -> dict[str, Any]:
+    """
+    Sanitize a ``__geo_interface__`` w/ pre-sanitize step for ``pandas`` if needed.
+
+    Notes
+    -----
+    Split out to resolve typing issues related to:
+    - Intersection types
+    - ``typing.TypeGuard``
+    - ``pd.DataFrame.__getattr__``
+    """
+    if _is_pandas_dataframe(data):
+        data = sanitize_pandas_dataframe(data)
+    return sanitize_geo_interface(data.__geo_interface__)
+
+
 def _data_to_json_string(data: DataType) -> str:
     """Return a JSON string representation of the input data."""
     check_data_type(data)
-    # `strict=False` passes `data` through as-is if it is not a Narwhals object.
-    data_native = nw.to_native(data, strict=False)
-    if isinstance(data_native, SupportsGeoInterface):
-        if _is_pandas_dataframe(data_native):
-            data_native = sanitize_pandas_dataframe(data_native)
-        data_native = sanitize_geo_interface(data_native.__geo_interface__)
-        return json.dumps(data_native)
-    elif _is_pandas_dataframe(data_native):
-        data = sanitize_pandas_dataframe(data_native)
-        return data_native.to_json(orient="records", double_precision=15)
-    elif isinstance(data_native, dict):
-        if "values" not in data_native:
+    if isinstance(data, SupportsGeoInterface):
+        return json.dumps(_from_geo_interface(data))
+    elif _is_pandas_dataframe(data):
+        data = sanitize_pandas_dataframe(data)
+        return data.to_json(orient="records", double_precision=15)
+    elif isinstance(data, dict):
+        if "values" not in data:
             msg = "values expected in data dict, but not present."
             raise KeyError(msg)
-        return json.dumps(data_native["values"], sort_keys=True)
-    elif isinstance(data, nw.DataFrame):
-        return json.dumps(data.rows(named=True))
-    else:
-        msg = "to_json only works with data expressed as " "a DataFrame or as a dict"
-        raise NotImplementedError(msg)
+        return json.dumps(data["values"], sort_keys=True)
+    try:
+        data_nw = nw.from_native(data, eager_only=True)
+    except TypeError as exc:
+        msg = "to_json only works with data expressed as a DataFrame or as a dict"
+        raise NotImplementedError(msg) from exc
+    data_nw = sanitize_narwhals_dataframe(data_nw)
+    return json.dumps(data_nw.rows(named=True))
 
 
-def _data_to_csv_string(data: dict | pd.DataFrame | DataFrameLike) -> str:
+def _data_to_csv_string(data: DataType) -> str:
     """Return a CSV string representation of the input data."""
     check_data_type(data)
     if isinstance(data, SupportsGeoInterface):
@@ -398,18 +406,12 @@ def _data_to_csv_string(data: dict | pd.DataFrame | DataFrameLike) -> str:
             msg = "pandas is required to convert a dict to a CSV string"
             raise ImportError(msg) from exc
        return pd.DataFrame.from_dict(data["values"]).to_csv(index=False)
-    elif isinstance(data, DataFrameLike):
-        # experimental interchange dataframe support
-        import pyarrow as pa
-        import pyarrow.csv as pa_csv
-
-        pa_table = arrow_table_from_dfi_dataframe(data)
-        csv_buffer = pa.BufferOutputStream()
-        pa_csv.write_csv(pa_table, csv_buffer)
-        return csv_buffer.getvalue().to_pybytes().decode()
-    else:
-        msg = "to_csv only works with data expressed as " "a DataFrame or as a dict"
-        raise NotImplementedError(msg)
+    try:
+        data_nw = nw.from_native(data, eager_only=True)
+    except TypeError as exc:
+        msg = "to_csv only works with data expressed as a DataFrame or as a dict"
+        raise NotImplementedError(msg) from exc
+    return data_nw.write_csv()
 
 
 def arrow_table_from_dfi_dataframe(dfi_df: DataFrameLike) -> pa.Table:
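Note on the `data.py` hunks above: both private serializers now share a single narwhals tail — input that is not a geo interface, a pandas frame, or a dict is handed to `nw.from_native(..., eager_only=True)`, and the `TypeError` narwhals raises for unsupported objects is translated into the pre-existing `NotImplementedError`. A minimal sketch of the resulting dispatch, assuming pandas and polars are installed (these are private helpers, shown for illustration only):

    import pandas as pd
    import polars as pl

    from altair.utils.data import _data_to_csv_string, _data_to_json_string

    # pandas keeps its dedicated branch (sanitize, then DataFrame.to_json)
    _data_to_json_string(pd.DataFrame({"x": [1, 2]}))
    # polars (or any eager frame narwhals can wrap) now flows through
    # nw.from_native instead of the old pyarrow interchange path
    _data_to_csv_string(pl.DataFrame({"x": [1, 2]}))
    # dicts keep their own branch and still require a "values" key
    _data_to_json_string({"values": [{"x": 1, "y": 2}]})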
diff --git a/altair/utils/schemapi.py b/altair/utils/schemapi.py
index fdf0d6594..bc0b40581 100644
--- a/altair/utils/schemapi.py
+++ b/altair/utils/schemapi.py
@@ -25,6 +25,7 @@
     Sequence,
     TypeVar,
     Union,
+    cast,
     overload,
 )
 from typing_extensions import TypeAlias
@@ -113,7 +114,7 @@ def validate_jsonschema(
     rootschema: dict[str, Any] | None = ...,
     *,
     raise_error: Literal[True] = ...,
-) -> None: ...
+) -> Never: ...
 
 
 @overload
@@ -128,11 +129,11 @@
 def validate_jsonschema(
     spec,
-    schema,
-    rootschema=None,
+    schema: dict[str, Any],
+    rootschema: dict[str, Any] | None = None,
     *,
-    raise_error=True,
-):
+    raise_error: bool = True,
+) -> jsonschema.exceptions.ValidationError | None:
     """
     Validates the passed in spec against the schema in the context of the
     rootschema.
 
@@ -149,7 +150,7 @@ def validate_jsonschema(
 
     # Nothing special about this first error but we need to choose one
     # which can be raised
-    main_error = next(iter(grouped_errors.values()))[0]
+    main_error: Any = next(iter(grouped_errors.values()))[0]
     # All errors are then attached as a new attribute to ValidationError so that
     # they can be used in SchemaValidationError to craft a more helpful
     # error message. Setting a new attribute like this is not ideal as
@@ -833,6 +834,41 @@ def is_undefined(obj: Any) -> TypeIs[UndefinedType]:
     return obj is Undefined
 
 
+@overload
+def _shallow_copy(obj: _CopyImpl) -> _CopyImpl: ...
+@overload
+def _shallow_copy(obj: Any) -> Any: ...
+def _shallow_copy(obj: _CopyImpl | Any) -> _CopyImpl | Any:
+    if isinstance(obj, SchemaBase):
+        return obj.copy(deep=False)
+    elif isinstance(obj, (list, dict)):
+        return obj.copy()
+    else:
+        return obj
+
+
+@overload
+def _deep_copy(obj: _CopyImpl, by_ref: set[str]) -> _CopyImpl: ...
+@overload
+def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ...
+def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any:
+    copy = partial(_deep_copy, by_ref=by_ref)
+    if isinstance(obj, SchemaBase):
+        if copier := getattr(obj, "__deepcopy__", None):
+            with debug_mode(False):
+                return copier(obj)
+        args = (copy(arg) for arg in obj._args)
+        kwds = {k: (copy(v) if k not in by_ref else v) for k, v in obj._kwds.items()}
+        with debug_mode(False):
+            return obj.__class__(*args, **kwds)
+    elif isinstance(obj, list):
+        return [copy(v) for v in obj]
+    elif isinstance(obj, dict):
+        return {k: (copy(v) if k not in by_ref else v) for k, v in obj.items()}
+    else:
+        return obj
+
+
 class SchemaBase:
     """
     Base class for schema wrappers.
@@ -870,7 +906,7 @@ def __init__(self, *args: Any, **kwds: Any) -> None:
         if DEBUG_MODE and self._class_is_valid_at_instantiation:
             self.to_dict(validate=True)
 
-    def copy(  # noqa: C901
+    def copy(
         self, deep: bool | Iterable[Any] = True, ignore: list[str] | None = None
     ) -> Self:
         """
@@ -887,53 +923,11 @@ def copy(
             A list of keys for which the contents should not be copied, but
             only stored by reference.
""" - - def _shallow_copy(obj): - if isinstance(obj, SchemaBase): - return obj.copy(deep=False) - elif isinstance(obj, list): - return obj[:] - elif isinstance(obj, dict): - return obj.copy() - else: - return obj - - def _deep_copy(obj, ignore: list[str] | None = None): - if ignore is None: - ignore = [] - if isinstance(obj, SchemaBase): - args = tuple(_deep_copy(arg) for arg in obj._args) - kwds = { - k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) - for k, v in obj._kwds.items() - } - with debug_mode(False): - return obj.__class__(*args, **kwds) - elif isinstance(obj, list): - return [_deep_copy(v, ignore=ignore) for v in obj] - elif isinstance(obj, dict): - return { - k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) - for k, v in obj.items() - } - else: - return obj - - try: - deep = list(deep) # type: ignore[arg-type] - except TypeError: - deep_is_list = False - else: - deep_is_list = True - - if deep and not deep_is_list: - return _deep_copy(self, ignore=ignore) - + if deep is True: + return cast("Self", _deep_copy(self, set(ignore) if ignore else set())) with debug_mode(False): copy = self.__class__(*self._args, **self._kwds) - if deep_is_list: - # Assert statement is for the benefit of Mypy - assert isinstance(deep, list) + if _is_iterable(deep): for attr in deep: copy[attr] = _shallow_copy(copy._get(attr)) return copy @@ -953,7 +947,7 @@ def __getattr__(self, attr): return self._kwds[attr] else: try: - _getattr = super().__getattr__ + _getattr = super().__getattr__ # pyright: ignore[reportAttributeAccessIssue] except AttributeError: _getattr = super().__getattribute__ return _getattr(attr) @@ -1202,9 +1196,7 @@ def validate( schema = cls._schema # For the benefit of mypy assert schema is not None - return validate_jsonschema( - instance, schema, rootschema=cls._rootschema or cls._schema - ) + validate_jsonschema(instance, schema, rootschema=cls._rootschema or cls._schema) @classmethod def resolve_references(cls, schema: dict[str, Any] | None = None) -> dict[str, Any]: @@ -1230,7 +1222,7 @@ def validate_property( np_opt = sys.modules.get("numpy") value = _todict(value, context={}, np_opt=np_opt, pd_opt=pd_opt) props = cls.resolve_references(schema or cls._schema).get("properties", {}) - return validate_jsonschema( + validate_jsonschema( value, props.get(name, {}), rootschema=cls._rootschema or cls._schema ) @@ -1240,6 +1232,13 @@ def __dir__(self) -> list[str]: TSchemaBase = TypeVar("TSchemaBase", bound=SchemaBase) +_CopyImpl = TypeVar("_CopyImpl", SchemaBase, Dict[Any, Any], List[Any]) +""" +Types which have an implementation in ``SchemaBase.copy()``. + +All other types are returned **by reference**. +""" + def _is_dict(obj: Any | dict[Any, Any]) -> TypeIs[dict[Any, Any]]: return isinstance(obj, dict) @@ -1325,11 +1324,11 @@ def from_dict( @overload def from_dict( self, - dct: dict[str, Any], - tp: None = ..., + dct: dict[str, Any] | list[dict[str, Any]], + tp: Any = ..., schema: Any = ..., - rootschema: None = ..., - default_class: type[TSchemaBase] = ..., + rootschema: Any = ..., + default_class: type[TSchemaBase] = ..., # pyright: ignore[reportInvalidTypeVarUse] ) -> TSchemaBase: ... 
    @overload
    def from_dict(
@@ -1365,15 +1364,15 @@ def from_dict(
         schema: dict[str, Any] | None = None,
         rootschema: dict[str, Any] | None = None,
         default_class: Any = _passthrough,
-    ) -> TSchemaBase:
+    ) -> TSchemaBase | SchemaBase:
         """Construct an object from a dict representation."""
-        target_tp: type[TSchemaBase]
+        target_tp: Any
         current_schema: dict[str, Any]
         if isinstance(dct, SchemaBase):
-            return dct  # type: ignore[return-value]
+            return dct
         elif tp is not None:
             current_schema = tp._schema
-            root_schema = rootschema or tp._rootschema or current_schema
+            root_schema: dict[str, Any] = rootschema or tp._rootschema or current_schema
             target_tp = tp
         elif schema is not None:
             # If there are multiple matches, we use the first one in the dict.
diff --git a/altair/vegalite/v5/api.py b/altair/vegalite/v5/api.py
index e77d62e00..4e8fde039 100644
--- a/altair/vegalite/v5/api.py
+++ b/altair/vegalite/v5/api.py
@@ -23,6 +23,7 @@
 from typing_extensions import TypeAlias
 
 import jsonschema
+import narwhals.stable.v1 as nw
 
 from altair import utils
 from altair.expr import core as _expr_core
@@ -274,7 +275,7 @@ def _prepare_data(
     # convert dataframes or objects with __geo_interface__ to dict
     elif not isinstance(data, dict) and _is_data_type(data):
         if func := data_transformers.get():
-            data = func(data)
+            data = func(nw.to_native(data, strict=False))
 
     # convert string input to a URLData
     elif isinstance(data, str):
@@ -1060,6 +1061,9 @@ def to_dict(self, *args: Any, **kwds: Any) -> _Conditional[_C]:  # type: ignore[
         m = super().to_dict(*args, **kwds)
         return _Conditional(condition=m["condition"])
 
+    def __deepcopy__(self, memo: Any) -> Self:
+        return type(self)(_Conditional(condition=_deepcopy(self.condition)))
+
 
 class ChainedWhen(_BaseWhen):
     """
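`Then.to_dict` deliberately returns only the `condition` mapping, so the generic reconstruction in `_deep_copy` (`obj.__class__(*args, **kwds)`) cannot round-trip a `Then` instance; the new `__deepcopy__` is the hook `_deep_copy` probes for. A minimal sketch of the case this fixes, mirroring the regression test added further below:

    from copy import deepcopy

    import altair as alt

    then = alt.when(alt.datum.x > 0).then(alt.value("grey"))
    clone = deepcopy(then)  # now routed through Then.__deepcopy__
    assert clone is not then
    assert clone.to_dict() == then.to_dict()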
diff --git a/pyproject.toml b/pyproject.toml
index 4a0c3874c..4c85070a1 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -20,7 +20,7 @@ dependencies = [
     # If you update the minimum required jsonschema version, also update it in build.yml
     "jsonschema>=3.0",
     "packaging",
-    "narwhals>=1.1.0"
+    "narwhals>=1.5.2"
 ]
 description = "Vega-Altair: A declarative statistical visualization library for Python."
 readme = "README.md"
diff --git a/tests/test_jupyter_chart.py b/tests/test_jupyter_chart.py
index cec67153f..0630ce4b3 100644
--- a/tests/test_jupyter_chart.py
+++ b/tests/test_jupyter_chart.py
@@ -1,5 +1,8 @@
+from importlib.metadata import version as importlib_version
+
 import pandas as pd
 import pytest
+from packaging.version import Version
 
 import altair as alt
 from vega_datasets import data
@@ -17,6 +20,10 @@
 else:
     jupyter_chart = None  # type: ignore
 
+skip_requires_anywidget = pytest.mark.skipif(
+    not has_anywidget, reason="anywidget not importable"
+)
+
 try:
     import vegafusion  # type: ignore # noqa: F401
 
@@ -25,13 +32,23 @@
 except ImportError:
     transformers = ["default"]
 
+param_transformers = pytest.mark.parametrize("transformer", transformers)
+
+
+if Version(importlib_version("ipywidgets")) < Version("8.1.4"):
+    # See https://github.com/vega/altair/issues/3234#issuecomment-2268515312
+    _filterwarn = pytest.mark.filterwarnings(
+        "ignore:Deprecated in traitlets 4.1.*:DeprecationWarning"
+    )
+    jupyter_marks: pytest.MarkDecorator = skip_requires_anywidget(
+        _filterwarn(param_transformers)
+    )
+else:
+    jupyter_marks = skip_requires_anywidget(param_transformers)
 
-@pytest.mark.filterwarnings("ignore:Deprecated in traitlets 4.1.*:DeprecationWarning")
-@pytest.mark.parametrize("transformer", transformers)
-def test_chart_with_no_interactivity(transformer):
-    if not has_anywidget:
-        pytest.skip("anywidget not importable; skipping test")
 
+@jupyter_marks
+def test_chart_with_no_interactivity(transformer):
     with alt.data_transformers.enable(transformer):
         source = pd.DataFrame(
             {
@@ -56,12 +73,8 @@ def test_chart_with_no_interactivity(transformer):
         assert len(widget.params.trait_values()) == 0
 
 
-@pytest.mark.filterwarnings("ignore:Deprecated in traitlets 4.1.*:DeprecationWarning")
-@pytest.mark.parametrize("transformer", transformers)
+@jupyter_marks
 def test_interval_selection_example(transformer):
-    if not has_anywidget:
-        pytest.skip("anywidget not importable; skipping test")
-
     with alt.data_transformers.enable(transformer):
         source = data.cars()
         brush = alt.selection_interval(name="interval")
@@ -128,12 +141,8 @@ def test_interval_selection_example(transformer):
         assert selection.store == store
 
 
-@pytest.mark.filterwarnings("ignore:Deprecated in traitlets 4.1.*:DeprecationWarning")
-@pytest.mark.parametrize("transformer", transformers)
+@jupyter_marks
 def test_index_selection_example(transformer):
-    if not has_anywidget:
-        pytest.skip("anywidget not importable; skipping test")
-
     with alt.data_transformers.enable(transformer):
         source = data.cars()
         brush = alt.selection_point(name="index")
@@ -192,12 +201,8 @@ def test_index_selection_example(transformer):
         assert selection.store == store
 
 
-@pytest.mark.filterwarnings("ignore:Deprecated in traitlets 4.1.*:DeprecationWarning")
-@pytest.mark.parametrize("transformer", transformers)
+@jupyter_marks
 def test_point_selection(transformer):
-    if not has_anywidget:
-        pytest.skip("anywidget not importable; skipping test")
-
     with alt.data_transformers.enable(transformer):
         source = data.cars()
         brush = alt.selection_point(name="point", encodings=["color"], bind="legend")
@@ -259,12 +264,8 @@ def test_point_selection(transformer):
         assert selection.store == store
 
 
-@pytest.mark.filterwarnings("ignore:Deprecated in traitlets 4.1.*:DeprecationWarning")
-@pytest.mark.parametrize("transformer", transformers)
+@jupyter_marks
 def test_param_updates(transformer):
-    if not has_anywidget:
-        pytest.skip("anywidget not importable; skipping test")
-
     with alt.data_transformers.enable(transformer):
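The `jupyter_marks` pattern in `test_jupyter_chart.py` relies on `pytest.MarkDecorator` being callable: applying one mark to another yields a single decorator that stacks both marks on whatever test it eventually wraps. A generic sketch of the same composition (the mark names here are illustrative, not from the patch):

    import importlib.util

    import pytest

    requires_numpy = pytest.mark.skipif(
        importlib.util.find_spec("numpy") is None, reason="numpy not importable"
    )
    param_backend = pytest.mark.parametrize("backend", ["pandas", "polars"])
    marks = requires_numpy(param_backend)  # one reusable, combined decorator

    @marks
    def test_roundtrip(backend) -> None:
        assert backend in {"pandas", "polars"}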
         source = data.cars()
         size_param = alt.param(
diff --git a/tests/utils/test_data.py b/tests/utils/test_data.py
index 673c2852c..30b5b7f8e 100644
--- a/tests/utils/test_data.py
+++ b/tests/utils/test_data.py
@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 from pathlib import Path
-from typing import Any, Callable
+from typing import Any, Callable, SupportsIndex, TypeVar
 
 import narwhals.stable.v1 as nw
 import pandas as pd
@@ -15,6 +17,8 @@
     to_values,
 )
 
+T = TypeVar("T")
+
 
 def _pipe(data: Any, *funcs: Callable[..., Any]) -> Any:
     # Redefined to maintain existing tests
@@ -24,13 +28,15 @@ def _pipe(data: Any, *funcs: Callable[..., Any]) -> Any:
     return data
 
 
-def _create_dataframe(N):
-    data = pd.DataFrame({"x": range(N), "y": range(N)})
+def _create_dataframe(
+    n: SupportsIndex, /, tp: Callable[..., T] | type[Any] = pd.DataFrame
+) -> T | Any:
+    data = tp({"x": range(n), "y": range(n)})
     return data
 
 
-def _create_data_with_values(N):
-    data = {"values": [{"x": i, "y": i + 1} for i in range(N)]}
+def _create_data_with_values(n: SupportsIndex, /) -> dict[str, Any]:
+    data = {"values": [{"x": i, "y": i + 1} for i in range(n)]}
     return data
 
 
@@ -127,19 +133,20 @@ def test_dict_to_json():
     assert data == {"values": output}
 
 
-def test_dataframe_to_csv():
+@pytest.mark.parametrize("tp", [pd.DataFrame, pl.DataFrame], ids=["pandas", "polars"])
+def test_dataframe_to_csv(tp: type[Any]) -> None:
     """
     Test to_csv with dataframe input.
 
     - make certain the filename is deterministic
     - make certain the file contents match the data.
     """
-    data = _create_dataframe(10)
+    data = _create_dataframe(10, tp=tp)
     try:
         result1 = _pipe(data, to_csv)
         result2 = _pipe(data, to_csv)
         filename = result1["url"]
-        output = pd.read_csv(filename)
+        output = tp(pd.read_csv(filename))
     finally:
         Path(filename).unlink()
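With the constructor now a parameter, `_create_dataframe` builds the same fixture for either backend, which is what lets `test_dataframe_to_csv` keep a single test body. For instance (a sketch; assumes the helper is imported from the test module):

    import polars as pl

    # assumption: from tests.utils.test_data import _create_dataframe
    pdf = _create_dataframe(3)                   # defaults to pd.DataFrame
    plf = _create_dataframe(3, tp=pl.DataFrame)  # same data as a polars frame
    assert pdf["x"].tolist() == plf["x"].to_list() == [0, 1, 2]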
diff --git a/tests/utils/test_schemapi.py b/tests/utils/test_schemapi.py
index a6107601a..2f0b9faab 100644
--- a/tests/utils/test_schemapi.py
+++ b/tests/utils/test_schemapi.py
@@ -620,7 +620,7 @@ def chart_error_example__two_errors_with_one_in_nested_layered_chart():
     return chart
 
 
-def chart_error_example__four_errors():
+def chart_error_example__four_errors_hide_fourth():
     # Error 1: unknown is not a valid encoding channel option
     # Error 2: Invalid Y option value "asdf".
     # Error 3: another_unknown is not a valid encoding channel option
@@ -639,14 +639,16 @@
 )
 
 
-def id_func(val) -> str:
+def id_func_chart_error_example(val) -> str:
     """
-    Ensures the generated test-id name uses only `chart_func` and not `expected_error_message`.
+    Ensures the generated test-id name uses only the unique portion of `chart_func`.
 
-    Without this, the name is ``test_chart_validation_errors[chart_func-expected_error_message]``
+    Otherwise the name is like below, but ``...`` represents the full error message::
+
+        "test_chart_validation_errors[chart_error_example__two_errors_with_one_in_nested_layered_chart-...]"
     """
     if isinstance(val, types.FunctionType):
-        return val.__name__
+        return val.__name__.replace("chart_error_example__", "")
     else:
         return ""
 
@@ -821,7 +823,7 @@ def id_func(val) -> str:
         r"""'1' is an invalid value for `value`. Valid values are of type 'object', 'string', or 'null'.$""",
     ),
     (
-        chart_error_example__four_errors,
+        chart_error_example__four_errors_hide_fourth,
         r"""Multiple errors were found.
 
 Error 1: `Color` has no parameter named 'another_unknown'
@@ -856,7 +858,9 @@ def id_func(val) -> str:
 
 @pytest.mark.parametrize(
-    ("chart_func", "expected_error_message"), chart_funcs_error_message, ids=id_func
+    ("chart_func", "expected_error_message"),
+    chart_funcs_error_message,
+    ids=id_func_chart_error_example,
 )
 def test_chart_validation_errors(chart_func, expected_error_message):
     # For some wrong chart specifications such as an unknown encoding channel,
diff --git a/tests/vegalite/v5/test_api.py b/tests/vegalite/v5/test_api.py
index 29d68d1ea..241d47378 100644
--- a/tests/vegalite/v5/test_api.py
+++ b/tests/vegalite/v5/test_api.py
@@ -698,6 +698,27 @@ def test_when_condition_parity(
     assert chart_condition == chart_when
 
 
+def test_when_then_interactive() -> None:
+    """Copy-related regression found in https://github.com/vega/altair/pull/3394#issuecomment-2302995453."""
+    source = "https://cdn.jsdelivr.net/npm/vega-datasets@v1.29.0/data/movies.json"
+    predicate = (alt.datum.IMDB_Rating == None) | (  # noqa: E711
+        alt.datum.Rotten_Tomatoes_Rating == None  # noqa: E711
+    )
+
+    chart = (
+        alt.Chart(source)
+        .mark_point(invalid=None)
+        .encode(
+            x="IMDB_Rating:Q",
+            y="Rotten_Tomatoes_Rating:Q",
+            color=alt.when(predicate).then(alt.value("grey")),  # type: ignore[arg-type]
+        )
+    )
+    assert chart.interactive()
+    assert chart.copy()
+    assert chart.to_dict()
+
+
 def test_selection_to_dict():
     brush = alt.selection_interval()
diff --git a/tests/vegalite/v5/test_renderers.py b/tests/vegalite/v5/test_renderers.py
index c0c1333a7..f5ed6f922 100644
--- a/tests/vegalite/v5/test_renderers.py
+++ b/tests/vegalite/v5/test_renderers.py
@@ -1,8 +1,10 @@
 """Tests of various renderers."""
 
 import json
+from importlib.metadata import version as importlib_version
 
 import pytest
+from packaging.version import Version
 
 import altair.vegalite.v5 as alt
 
@@ -18,6 +20,20 @@
     anywidget = None  # type: ignore
 
 
+skip_requires_anywidget = pytest.mark.skipif(
+    not anywidget, reason="anywidget not importable"
+)
+if Version(importlib_version("ipywidgets")) < Version("8.1.4"):
+    # See https://github.com/vega/altair/issues/3234#issuecomment-2268515312
+    jupyter_marks = skip_requires_anywidget(
+        pytest.mark.filterwarnings(
+            "ignore:Deprecated in traitlets 4.1.*:DeprecationWarning"
+        )
+    )
+else:
+    jupyter_marks = skip_requires_anywidget
+
+
 @pytest.fixture
 def chart():
     return alt.Chart("data.csv").mark_point()
@@ -94,12 +110,9 @@ def test_renderer_with_none_embed_options(chart, renderer="mimetype"):
     assert bundle["image/svg+xml"].startswith("<svg")
 
 
-@pytest.mark.filterwarnings("ignore:Deprecated in traitlets 4.1.*:DeprecationWarning")
-def test_jupyter_renderer_mimetype(chart, renderer="jupyter"):
+@jupyter_marks
+def test_jupyter_renderer_mimetype(chart, renderer="jupyter") -> None:
     """Test that we get the expected widget mimetype when the jupyter renderer is enabled."""
-    if not anywidget:
-        pytest.skip("anywidget not importable; skipping test")
-
     with alt.renderers.enable(renderer):
         assert (
             "application/vnd.jupyter.widget-view+json"
""" - - def _shallow_copy(obj): - if isinstance(obj, SchemaBase): - return obj.copy(deep=False) - elif isinstance(obj, list): - return obj[:] - elif isinstance(obj, dict): - return obj.copy() - else: - return obj - - def _deep_copy(obj, ignore: list[str] | None = None): - if ignore is None: - ignore = [] - if isinstance(obj, SchemaBase): - args = tuple(_deep_copy(arg) for arg in obj._args) - kwds = { - k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) - for k, v in obj._kwds.items() - } - with debug_mode(False): - return obj.__class__(*args, **kwds) - elif isinstance(obj, list): - return [_deep_copy(v, ignore=ignore) for v in obj] - elif isinstance(obj, dict): - return { - k: (_deep_copy(v, ignore=ignore) if k not in ignore else v) - for k, v in obj.items() - } - else: - return obj - - try: - deep = list(deep) # type: ignore[arg-type] - except TypeError: - deep_is_list = False - else: - deep_is_list = True - - if deep and not deep_is_list: - return _deep_copy(self, ignore=ignore) - + if deep is True: + return cast("Self", _deep_copy(self, set(ignore) if ignore else set())) with debug_mode(False): copy = self.__class__(*self._args, **self._kwds) - if deep_is_list: - # Assert statement is for the benefit of Mypy - assert isinstance(deep, list) + if _is_iterable(deep): for attr in deep: copy[attr] = _shallow_copy(copy._get(attr)) return copy @@ -951,7 +945,7 @@ def __getattr__(self, attr): return self._kwds[attr] else: try: - _getattr = super().__getattr__ + _getattr = super().__getattr__ # pyright: ignore[reportAttributeAccessIssue] except AttributeError: _getattr = super().__getattribute__ return _getattr(attr) @@ -1200,9 +1194,7 @@ def validate( schema = cls._schema # For the benefit of mypy assert schema is not None - return validate_jsonschema( - instance, schema, rootschema=cls._rootschema or cls._schema - ) + validate_jsonschema(instance, schema, rootschema=cls._rootschema or cls._schema) @classmethod def resolve_references(cls, schema: dict[str, Any] | None = None) -> dict[str, Any]: @@ -1228,7 +1220,7 @@ def validate_property( np_opt = sys.modules.get("numpy") value = _todict(value, context={}, np_opt=np_opt, pd_opt=pd_opt) props = cls.resolve_references(schema or cls._schema).get("properties", {}) - return validate_jsonschema( + validate_jsonschema( value, props.get(name, {}), rootschema=cls._rootschema or cls._schema ) @@ -1238,6 +1230,13 @@ def __dir__(self) -> list[str]: TSchemaBase = TypeVar("TSchemaBase", bound=SchemaBase) +_CopyImpl = TypeVar("_CopyImpl", SchemaBase, Dict[Any, Any], List[Any]) +""" +Types which have an implementation in ``SchemaBase.copy()``. + +All other types are returned **by reference**. +""" + def _is_dict(obj: Any | dict[Any, Any]) -> TypeIs[dict[Any, Any]]: return isinstance(obj, dict) @@ -1323,11 +1322,11 @@ def from_dict( @overload def from_dict( self, - dct: dict[str, Any], - tp: None = ..., + dct: dict[str, Any] | list[dict[str, Any]], + tp: Any = ..., schema: Any = ..., - rootschema: None = ..., - default_class: type[TSchemaBase] = ..., + rootschema: Any = ..., + default_class: type[TSchemaBase] = ..., # pyright: ignore[reportInvalidTypeVarUse] ) -> TSchemaBase: ... 
@@ -951,7 +945,7 @@ def __getattr__(self, attr):
             return self._kwds[attr]
         else:
             try:
-                _getattr = super().__getattr__
+                _getattr = super().__getattr__  # pyright: ignore[reportAttributeAccessIssue]
             except AttributeError:
                 _getattr = super().__getattribute__
             return _getattr(attr)
@@ -1200,9 +1194,7 @@ def validate(
         schema = cls._schema
         # For the benefit of mypy
         assert schema is not None
-        return validate_jsonschema(
-            instance, schema, rootschema=cls._rootschema or cls._schema
-        )
+        validate_jsonschema(instance, schema, rootschema=cls._rootschema or cls._schema)
 
     @classmethod
     def resolve_references(cls, schema: dict[str, Any] | None = None) -> dict[str, Any]:
@@ -1228,7 +1220,7 @@ def validate_property(
         np_opt = sys.modules.get("numpy")
         value = _todict(value, context={}, np_opt=np_opt, pd_opt=pd_opt)
         props = cls.resolve_references(schema or cls._schema).get("properties", {})
-        return validate_jsonschema(
+        validate_jsonschema(
             value, props.get(name, {}), rootschema=cls._rootschema or cls._schema
         )
 
@@ -1238,6 +1230,13 @@ def __dir__(self) -> list[str]:
 
 TSchemaBase = TypeVar("TSchemaBase", bound=SchemaBase)
 
+_CopyImpl = TypeVar("_CopyImpl", SchemaBase, Dict[Any, Any], List[Any])
+"""
+Types which have an implementation in ``SchemaBase.copy()``.
+
+All other types are returned **by reference**.
+"""
+
 
 def _is_dict(obj: Any | dict[Any, Any]) -> TypeIs[dict[Any, Any]]:
     return isinstance(obj, dict)
@@ -1323,11 +1322,11 @@ def from_dict(
     @overload
     def from_dict(
         self,
-        dct: dict[str, Any],
-        tp: None = ...,
+        dct: dict[str, Any] | list[dict[str, Any]],
+        tp: Any = ...,
         schema: Any = ...,
-        rootschema: None = ...,
-        default_class: type[TSchemaBase] = ...,
+        rootschema: Any = ...,
+        default_class: type[TSchemaBase] = ...,  # pyright: ignore[reportInvalidTypeVarUse]
     ) -> TSchemaBase: ...
     @overload
     def from_dict(
@@ -1363,15 +1362,15 @@ def from_dict(
         schema: dict[str, Any] | None = None,
         rootschema: dict[str, Any] | None = None,
         default_class: Any = _passthrough,
-    ) -> TSchemaBase:
+    ) -> TSchemaBase | SchemaBase:
         """Construct an object from a dict representation."""
-        target_tp: type[TSchemaBase]
+        target_tp: Any
         current_schema: dict[str, Any]
         if isinstance(dct, SchemaBase):
-            return dct  # type: ignore[return-value]
+            return dct
         elif tp is not None:
             current_schema = tp._schema
-            root_schema = rootschema or tp._rootschema or current_schema
+            root_schema: dict[str, Any] = rootschema or tp._rootschema or current_schema
             target_tp = tp
         elif schema is not None:
             # If there are multiple matches, we use the first one in the dict.