Skip to content

Commit

Permalink
Merge remote-tracking branch 'upstream/HEAD' into shorthand-namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned committed Aug 27, 2024
2 parents 6570912 + 030db9b commit 4e02ac4
Show file tree
Hide file tree
Showing 12 changed files with 264 additions and 221 deletions.
3 changes: 1 addition & 2 deletions altair/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from .core import (
SHORTHAND_KEYS,
SchemaBase,
display_traceback,
infer_encoding_types,
infer_vegalite_type_for_pandas,
Expand All @@ -13,7 +12,7 @@
from .deprecation import AltairDeprecationWarning, deprecated, deprecated_warn
from .html import spec_to_html
from .plugin_registry import PluginRegistry
from .schemapi import Optional, Undefined, is_undefined
from .schemapi import Optional, SchemaBase, Undefined, is_undefined

__all__ = (
"SHORTHAND_KEYS",
Expand Down
6 changes: 0 additions & 6 deletions altair/utils/_vegafusion_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,6 @@
)
from weakref import WeakValueDictionary

import narwhals.stable.v1 as nw

from altair.utils._importers import import_vegafusion
from altair.utils.core import DataFrameLike
from altair.utils.data import (
Expand Down Expand Up @@ -71,10 +69,6 @@ def vegafusion_data_transformer(
data: DataType | None = None, max_rows: int = 100000
) -> Callable[..., Any] | _VegaFusionReturnType:
"""VegaFusion Data Transformer."""
# Vegafusion does not support Narwhals, so if `data` is a Narwhals
# object, we make sure to extract the native object and let Vegafusion handle it.
# `strict=False` passes `data` through as-is if it is not a Narwhals object.
data = nw.to_native(data, strict=False)
if data is None:
return vegafusion_data_transformer
elif isinstance(data, DataFrameLike) and not isinstance(data, SupportsGeoInterface):
Expand Down
76 changes: 39 additions & 37 deletions altair/utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,12 +314,7 @@ def to_values(data: DataType) -> ToValuesReturnType:
# `strict=False` passes `data` through as-is if it is not a Narwhals object.
data_native = nw.to_native(data, strict=False)
if isinstance(data_native, SupportsGeoInterface):
if _is_pandas_dataframe(data_native):
data_native = sanitize_pandas_dataframe(data_native)
# Maybe the type could be further clarified here that it is
# SupportGeoInterface and then the ignore statement is not needed?
data_sanitized = sanitize_geo_interface(data_native.__geo_interface__)
return {"values": data_sanitized}
return {"values": _from_geo_interface(data_native)}
elif _is_pandas_dataframe(data_native):
data_native = sanitize_pandas_dataframe(data_native)
return {"values": data_native.to_dict(orient="records")}
Expand Down Expand Up @@ -350,32 +345,45 @@ def _compute_data_hash(data_str: str) -> str:
return hashlib.sha256(data_str.encode()).hexdigest()[:32]


def _from_geo_interface(data: SupportsGeoInterface | Any) -> dict[str, Any]:
    """
    Sanitize a ``__geo_interface__``, with a pre-sanitize step for ``pandas`` if needed.

    When ``data`` is a ``pandas`` dataframe it is first passed through
    ``sanitize_pandas_dataframe``; the ``__geo_interface__`` of the result
    (or of ``data`` itself otherwise) is then handed to ``sanitize_geo_interface``.

    Notes
    -----
    Split out to resolve typing issues related to:

    - Intersection types
    - ``typing.TypeGuard``
    - ``pd.DataFrame.__getattr__``
    """
    source = sanitize_pandas_dataframe(data) if _is_pandas_dataframe(data) else data
    geo = source.__geo_interface__
    return sanitize_geo_interface(geo)


def _data_to_json_string(data: DataType) -> str:
"""Return a JSON string representation of the input data."""
check_data_type(data)
# `strict=False` passes `data` through as-is if it is not a Narwhals object.
data_native = nw.to_native(data, strict=False)
if isinstance(data_native, SupportsGeoInterface):
if _is_pandas_dataframe(data_native):
data_native = sanitize_pandas_dataframe(data_native)
data_native = sanitize_geo_interface(data_native.__geo_interface__)
return json.dumps(data_native)
elif _is_pandas_dataframe(data_native):
data = sanitize_pandas_dataframe(data_native)
return data_native.to_json(orient="records", double_precision=15)
elif isinstance(data_native, dict):
if "values" not in data_native:
if isinstance(data, SupportsGeoInterface):
return json.dumps(_from_geo_interface(data))
elif _is_pandas_dataframe(data):
data = sanitize_pandas_dataframe(data)
return data.to_json(orient="records", double_precision=15)
elif isinstance(data, dict):
if "values" not in data:
msg = "values expected in data dict, but not present."
raise KeyError(msg)
return json.dumps(data_native["values"], sort_keys=True)
elif isinstance(data, nw.DataFrame):
return json.dumps(data.rows(named=True))
else:
msg = "to_json only works with data expressed as " "a DataFrame or as a dict"
raise NotImplementedError(msg)
return json.dumps(data["values"], sort_keys=True)
try:
data_nw = nw.from_native(data, eager_only=True)
except TypeError as exc:
msg = "to_json only works with data expressed as a DataFrame or as a dict"
raise NotImplementedError(msg) from exc
data_nw = sanitize_narwhals_dataframe(data_nw)
return json.dumps(data_nw.rows(named=True))


def _data_to_csv_string(data: dict | pd.DataFrame | DataFrameLike) -> str:
def _data_to_csv_string(data: DataType) -> str:
"""Return a CSV string representation of the input data."""
check_data_type(data)
if isinstance(data, SupportsGeoInterface):
Expand All @@ -398,18 +406,12 @@ def _data_to_csv_string(data: dict | pd.DataFrame | DataFrameLike) -> str:
msg = "pandas is required to convert a dict to a CSV string"
raise ImportError(msg) from exc
return pd.DataFrame.from_dict(data["values"]).to_csv(index=False)
elif isinstance(data, DataFrameLike):
# experimental interchange dataframe support
import pyarrow as pa
import pyarrow.csv as pa_csv

pa_table = arrow_table_from_dfi_dataframe(data)
csv_buffer = pa.BufferOutputStream()
pa_csv.write_csv(pa_table, csv_buffer)
return csv_buffer.getvalue().to_pybytes().decode()
else:
msg = "to_csv only works with data expressed as " "a DataFrame or as a dict"
raise NotImplementedError(msg)
try:
data_nw = nw.from_native(data, eager_only=True)
except TypeError as exc:
msg = "to_csv only works with data expressed as a DataFrame or as a dict"
raise NotImplementedError(msg) from exc
return data_nw.write_csv()


def arrow_table_from_dfi_dataframe(dfi_df: DataFrameLike) -> pa.Table:
Expand Down
129 changes: 64 additions & 65 deletions altair/utils/schemapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
Sequence,
TypeVar,
Union,
cast,
overload,
)
from typing_extensions import TypeAlias
Expand Down Expand Up @@ -113,7 +114,7 @@ def validate_jsonschema(
rootschema: dict[str, Any] | None = ...,
*,
raise_error: Literal[True] = ...,
) -> None: ...
) -> Never: ...


@overload
Expand All @@ -128,11 +129,11 @@ def validate_jsonschema(

def validate_jsonschema(
spec,
schema,
rootschema=None,
schema: dict[str, Any],
rootschema: dict[str, Any] | None = None,
*,
raise_error=True,
):
raise_error: bool = True,
) -> jsonschema.exceptions.ValidationError | None:
"""
Validates the passed in spec against the schema in the context of the rootschema.
Expand All @@ -149,7 +150,7 @@ def validate_jsonschema(

# Nothing special about this first error but we need to choose one
# which can be raised
main_error = next(iter(grouped_errors.values()))[0]
main_error: Any = next(iter(grouped_errors.values()))[0]
# All errors are then attached as a new attribute to ValidationError so that
# they can be used in SchemaValidationError to craft a more helpful
# error message. Setting a new attribute like this is not ideal as
Expand Down Expand Up @@ -833,6 +834,41 @@ def is_undefined(obj: Any) -> TypeIs[UndefinedType]:
return obj is Undefined


@overload
def _shallow_copy(obj: _CopyImpl) -> _CopyImpl: ...
@overload
def _shallow_copy(obj: Any) -> Any: ...
def _shallow_copy(obj: _CopyImpl | Any) -> _CopyImpl | Any:
    """
    Return a one-level copy of ``obj``.

    - ``SchemaBase`` instances are copied via ``obj.copy(deep=False)``.
    - ``list`` and ``dict`` instances are copied via their own ``.copy()``.
    - Anything else is returned unchanged, i.e. **by reference**.
    """
    if isinstance(obj, SchemaBase):
        return obj.copy(deep=False)
    return obj.copy() if isinstance(obj, (list, dict)) else obj


@overload
def _deep_copy(obj: _CopyImpl, by_ref: set[str]) -> _CopyImpl: ...
@overload
def _deep_copy(obj: Any, by_ref: set[str]) -> Any: ...
def _deep_copy(obj: _CopyImpl | Any, by_ref: set[str]) -> _CopyImpl | Any:
    """
    Recursively copy ``obj``, descending into ``SchemaBase``, ``list`` and ``dict``.

    Parameters
    ----------
    obj
        The object to copy. Any type other than ``SchemaBase``, ``list`` or
        ``dict`` is returned unchanged, i.e. **by reference**.
    by_ref
        Keys whose values are not copied but stored by reference. This applies
        to the keyword arguments of a ``SchemaBase`` and to ``dict`` keys.

    Notes
    -----
    A ``SchemaBase`` that exposes ``__deepcopy__`` is copied through that hook
    instead of being rebuilt from its ``_args`` and ``_kwds``.
    """
    copy = partial(_deep_copy, by_ref=by_ref)
    if isinstance(obj, SchemaBase):
        if copier := getattr(obj, "__deepcopy__", None):
            with debug_mode(False):
                # ``copier`` is already bound to ``obj``; its single argument is
                # the ``memo`` mapping of the ``copy`` protocol. No memo is
                # threaded through this function, so a fresh one is supplied
                # rather than passing ``obj`` itself in that slot.
                return copier({})
        # NOTE: ``args`` is a generator and is only consumed by the constructor
        # call below, i.e. while debug mode is disabled.
        args = (copy(arg) for arg in obj._args)
        kwds = {k: (copy(v) if k not in by_ref else v) for k, v in obj._kwds.items()}
        with debug_mode(False):
            return obj.__class__(*args, **kwds)
    elif isinstance(obj, list):
        return [copy(v) for v in obj]
    elif isinstance(obj, dict):
        return {k: (copy(v) if k not in by_ref else v) for k, v in obj.items()}
    else:
        return obj


class SchemaBase:
"""
Base class for schema wrappers.
Expand Down Expand Up @@ -870,7 +906,7 @@ def __init__(self, *args: Any, **kwds: Any) -> None:
if DEBUG_MODE and self._class_is_valid_at_instantiation:
self.to_dict(validate=True)

def copy( # noqa: C901
def copy(
self, deep: bool | Iterable[Any] = True, ignore: list[str] | None = None
) -> Self:
"""
Expand All @@ -887,53 +923,11 @@ def copy( # noqa: C901
A list of keys for which the contents should not be copied, but
only stored by reference.
"""

def _shallow_copy(obj):
if isinstance(obj, SchemaBase):
return obj.copy(deep=False)
elif isinstance(obj, list):
return obj[:]
elif isinstance(obj, dict):
return obj.copy()
else:
return obj

def _deep_copy(obj, ignore: list[str] | None = None):
if ignore is None:
ignore = []
if isinstance(obj, SchemaBase):
args = tuple(_deep_copy(arg) for arg in obj._args)
kwds = {
k: (_deep_copy(v, ignore=ignore) if k not in ignore else v)
for k, v in obj._kwds.items()
}
with debug_mode(False):
return obj.__class__(*args, **kwds)
elif isinstance(obj, list):
return [_deep_copy(v, ignore=ignore) for v in obj]
elif isinstance(obj, dict):
return {
k: (_deep_copy(v, ignore=ignore) if k not in ignore else v)
for k, v in obj.items()
}
else:
return obj

try:
deep = list(deep) # type: ignore[arg-type]
except TypeError:
deep_is_list = False
else:
deep_is_list = True

if deep and not deep_is_list:
return _deep_copy(self, ignore=ignore)

if deep is True:
return cast("Self", _deep_copy(self, set(ignore) if ignore else set()))
with debug_mode(False):
copy = self.__class__(*self._args, **self._kwds)
if deep_is_list:
# Assert statement is for the benefit of Mypy
assert isinstance(deep, list)
if _is_iterable(deep):
for attr in deep:
copy[attr] = _shallow_copy(copy._get(attr))
return copy
Expand All @@ -953,7 +947,7 @@ def __getattr__(self, attr):
return self._kwds[attr]
else:
try:
_getattr = super().__getattr__
_getattr = super().__getattr__ # pyright: ignore[reportAttributeAccessIssue]
except AttributeError:
_getattr = super().__getattribute__
return _getattr(attr)
Expand Down Expand Up @@ -1202,9 +1196,7 @@ def validate(
schema = cls._schema
# For the benefit of mypy
assert schema is not None
return validate_jsonschema(
instance, schema, rootschema=cls._rootschema or cls._schema
)
validate_jsonschema(instance, schema, rootschema=cls._rootschema or cls._schema)

@classmethod
def resolve_references(cls, schema: dict[str, Any] | None = None) -> dict[str, Any]:
Expand All @@ -1230,7 +1222,7 @@ def validate_property(
np_opt = sys.modules.get("numpy")
value = _todict(value, context={}, np_opt=np_opt, pd_opt=pd_opt)
props = cls.resolve_references(schema or cls._schema).get("properties", {})
return validate_jsonschema(
validate_jsonschema(
value, props.get(name, {}), rootschema=cls._rootschema or cls._schema
)

Expand All @@ -1240,6 +1232,13 @@ def __dir__(self) -> list[str]:

TSchemaBase = TypeVar("TSchemaBase", bound=SchemaBase)

_CopyImpl = TypeVar("_CopyImpl", SchemaBase, Dict[Any, Any], List[Any])
"""
Types which have an implementation in ``SchemaBase.copy()``.
All other types are returned **by reference**.
"""


def _is_dict(obj: Any | dict[Any, Any]) -> TypeIs[dict[Any, Any]]:
    """Return ``True`` if ``obj`` is a ``dict`` instance, narrowing its type for checkers."""
    return isinstance(obj, dict)
Expand Down Expand Up @@ -1325,11 +1324,11 @@ def from_dict(
@overload
def from_dict(
self,
dct: dict[str, Any],
tp: None = ...,
dct: dict[str, Any] | list[dict[str, Any]],
tp: Any = ...,
schema: Any = ...,
rootschema: None = ...,
default_class: type[TSchemaBase] = ...,
rootschema: Any = ...,
default_class: type[TSchemaBase] = ..., # pyright: ignore[reportInvalidTypeVarUse]
) -> TSchemaBase: ...
@overload
def from_dict(
Expand Down Expand Up @@ -1365,15 +1364,15 @@ def from_dict(
schema: dict[str, Any] | None = None,
rootschema: dict[str, Any] | None = None,
default_class: Any = _passthrough,
) -> TSchemaBase:
) -> TSchemaBase | SchemaBase:
"""Construct an object from a dict representation."""
target_tp: type[TSchemaBase]
target_tp: Any
current_schema: dict[str, Any]
if isinstance(dct, SchemaBase):
return dct # type: ignore[return-value]
return dct
elif tp is not None:
current_schema = tp._schema
root_schema = rootschema or tp._rootschema or current_schema
root_schema: dict[str, Any] = rootschema or tp._rootschema or current_schema
target_tp = tp
elif schema is not None:
# If there are multiple matches, we use the first one in the dict.
Expand Down
6 changes: 5 additions & 1 deletion altair/vegalite/v5/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
from typing_extensions import TypeAlias

import jsonschema
import narwhals.stable.v1 as nw

from altair import utils
from altair.expr import core as _expr_core
Expand Down Expand Up @@ -274,7 +275,7 @@ def _prepare_data(
# convert dataframes or objects with __geo_interface__ to dict
elif not isinstance(data, dict) and _is_data_type(data):
if func := data_transformers.get():
data = func(data)
data = func(nw.to_native(data, strict=False))

# convert string input to a URLData
elif isinstance(data, str):
Expand Down Expand Up @@ -1060,6 +1061,9 @@ def to_dict(self, *args: Any, **kwds: Any) -> _Conditional[_C]: # type: ignore[
m = super().to_dict(*args, **kwds)
return _Conditional(condition=m["condition"])

def __deepcopy__(self, memo: Any) -> Self:
    """
    Return a new instance of ``type(self)`` wrapping a copied condition.

    The copy is built from a fresh ``_Conditional`` whose ``condition`` is
    ``_deepcopy(self.condition)``. ``memo`` is accepted but not used by this
    implementation.
    """
    return type(self)(_Conditional(condition=_deepcopy(self.condition)))


class ChainedWhen(_BaseWhen):
"""
Expand Down
Loading

0 comments on commit 4e02ac4

Please sign in to comment.