From b606a7d6e6980865930d8bb8cb720d6340855782 Mon Sep 17 00:00:00 2001 From: dangotbanned <125183946+dangotbanned@users.noreply.github.com> Date: Wed, 29 Jan 2025 15:05:47 +0000 Subject: [PATCH] feat(DRAFT): Private API overhaul **Public API is unchanged** Core changes are to simplify testing and extension: - `_readers.py` -> `_reader.py` - w/ two new support modules `_constraints`, and `_readimpl` - Functions (`BaseImpl`) are declared with what they support (`include`) and restrictions (`exclude`) on that subset - Transforms a lot of the imperative logic into set operations - Greatly improved `pyarrow` support - Utilize schema - Provides additional fallback `.json` implementations - `_stdlib_read_json_to_arrow` finally resolves `"movies.json"` issue --- altair/datasets/_cache.py | 106 +++++- altair/datasets/_constraints.py | 115 +++++++ altair/datasets/_exceptions.py | 78 +++-- altair/datasets/_loader.py | 40 ++- altair/datasets/_reader.py | 540 ++++++++++++++++++++++++++++++ altair/datasets/_readers.py | 574 -------------------------------- altair/datasets/_readimpl.py | 414 +++++++++++++++++++++++ tests/test_datasets.py | 79 ++--- 8 files changed, 1260 insertions(+), 686 deletions(-) create mode 100644 altair/datasets/_constraints.py create mode 100644 altair/datasets/_reader.py delete mode 100644 altair/datasets/_readers.py create mode 100644 altair/datasets/_readimpl.py diff --git a/altair/datasets/_cache.py b/altair/datasets/_cache.py index a415a8380..9abe09726 100644 --- a/altair/datasets/_cache.py +++ b/altair/datasets/_cache.py @@ -5,10 +5,9 @@ from collections import defaultdict from importlib.util import find_spec from pathlib import Path -from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar, cast, get_args +from typing import TYPE_CHECKING, ClassVar, TypeVar, cast, get_args import narwhals.stable.v1 as nw -from narwhals.stable.v1.typing import IntoDataFrameT, IntoFrameT from altair.datasets._exceptions import AltairDatasetsError from altair.datasets._typing import Dataset @@ -29,12 +28,18 @@ ) from io import IOBase from typing import Any, Final + from urllib.request import OpenerDirector from _typeshed import StrPath from narwhals.stable.v1.dtypes import DType + from narwhals.stable.v1.typing import IntoExpr from altair.datasets._typing import Metadata + if sys.version_info >= (3, 12): + from typing import Unpack + else: + from typing_extensions import Unpack if sys.version_info >= (3, 11): from typing import LiteralString else: @@ -43,8 +48,8 @@ from typing import TypeAlias else: from typing_extensions import TypeAlias - from altair.datasets._readers import _Reader from altair.datasets._typing import FlFieldStr + from altair.vegalite.v5.schema._typing import OneOrSeq _Dataset: TypeAlias = "Dataset | LiteralString" _FlSchema: TypeAlias = Mapping[str, FlFieldStr] @@ -83,6 +88,10 @@ https://narwhals-dev.github.io/narwhals/api-reference/dtypes/ """ +_FIELD_TO_DTYPE: Mapping[FlFieldStr, type[DType]] = { + v: k for k, v in _DTYPE_TO_FIELD.items() +} + def _iter_metadata(df: nw.DataFrame[Any], /) -> Iterator[Metadata]: """ @@ -179,10 +188,7 @@ def rotated(self) -> Mapping[str, Sequence[Any]]: self._rotated[k].append(v) return self._rotated - def metadata(self, ns: Any, /) -> nw.LazyFrame: - data: Any = self.rotated - return nw.maybe_convert_dtypes(nw.from_dict(data, native_namespace=ns).lazy()) - + # TODO: Evaluate which errors are now obsolete def __getitem__(self, key: _Dataset, /) -> Metadata: if meta := self.get(key, None): return meta @@ -194,6 +200,7 @@ def 
__getitem__(self, key: _Dataset, /) -> Metadata: msg = f"{key!r} does not refer to a known dataset." raise TypeError(msg) + # TODO: Evaluate which errors are now obsolete def url(self, name: _Dataset, /) -> str: if meta := self.get(name, None): if meta["suffix"] == ".parquet" and not find_spec("vegafusion"): @@ -207,6 +214,9 @@ def url(self, name: _Dataset, /) -> str: msg = f"{name!r} does not refer to a known dataset." raise TypeError(msg) + def __repr__(self) -> str: + return f"<{type(self).__name__}: {'COLLECTED' if self._mapping else 'READY'}>" + class SchemaCache(CompressedCache["_Dataset", "_FlSchema"]): """ @@ -230,8 +240,10 @@ def __init__( self, *, tp: type[MutableMapping[_Dataset, _FlSchema]] = dict["_Dataset", "_FlSchema"], + implementation: nw.Implementation = nw.Implementation.UNKNOWN, ) -> None: self._mapping: MutableMapping[_Dataset, _FlSchema] = tp() + self._implementation: nw.Implementation = implementation def read(self) -> Any: import json @@ -259,8 +271,63 @@ def by_dtype(self, name: _Dataset, *dtypes: type[DType]) -> list[str]: else: return list(match) + def is_active(self) -> bool: + return self._implementation in { + nw.Implementation.PANDAS, + nw.Implementation.PYARROW, + nw.Implementation.MODIN, + nw.Implementation.PYARROW, + } + + def schema_kwds(self, meta: Metadata, /) -> dict[str, Any]: + name: Any = meta["dataset_name"] + impl = self._implementation + if (impl.is_pandas_like() or impl.is_pyarrow()) and (self[name]): + suffix = meta["suffix"] + if impl.is_pandas_like(): + if cols := self.by_dtype(name, nw.Date, nw.Datetime): + if suffix == ".json": + return {"convert_dates": cols} + elif suffix in {".csv", ".tsv"}: + return {"parse_dates": cols} + else: + schema = self.schema_pyarrow(name) + if suffix in {".csv", ".tsv"}: + from pyarrow.csv import ConvertOptions + + return {"convert_options": ConvertOptions(column_types=schema)} # pyright: ignore[reportCallIssue] + elif suffix == ".parquet": + return {"schema": schema} + + return {} + + def schema(self, name: _Dataset, /) -> Mapping[str, DType]: + return { + column: _FIELD_TO_DTYPE[tp_str]() for column, tp_str in self[name].items() + } + + # TODO: Open an issue in ``narwhals`` to try and get a public api for type conversion + def schema_pyarrow(self, name: _Dataset, /): + schema = self.schema(name) + if schema: + from narwhals._arrow.utils import narwhals_to_native_dtype + from narwhals.utils import Version -class DatasetCache(Generic[IntoDataFrameT, IntoFrameT]): + m = {k: narwhals_to_native_dtype(v, Version.V1) for k, v in schema.items()} + else: + m = {} + return nw.dependencies.get_pyarrow().schema(m) + + +class _SupportsScanMetadata(Protocol): + _opener: ClassVar[OpenerDirector] + + def _scan_metadata( + self, *predicates: OneOrSeq[IntoExpr], **constraints: Unpack[Metadata] + ) -> nw.LazyFrame: ... 
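The ``SchemaCache.schema_kwds`` hook introduced in this hunk is what turns the bundled field types into reader-specific keyword arguments. A minimal sketch of the intended flow, assuming the packaged schema for ``"stocks"`` records its ``date`` column as a temporal type::

    import narwhals.stable.v1 as nw
    from altair.datasets._cache import SchemaCache

    cache = SchemaCache(implementation=nw.Implementation.PANDAS)
    # For a pandas-like backend reading ``.csv``/``.tsv``, temporal columns are
    # forwarded as ``parse_dates``; ``.json`` uses ``convert_dates`` instead.
    cache.schema_kwds({"dataset_name": "stocks", "suffix": ".csv"})
    # -> {"parse_dates": ["date"]}  (illustrative output)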
+ + +class DatasetCache: """Opt-out caching of remote dataset requests.""" _ENV_VAR: ClassVar[LiteralString] = "ALTAIR_DATASETS_DIR" @@ -268,8 +335,8 @@ class DatasetCache(Generic[IntoDataFrameT, IntoFrameT]): Path(os.environ.get("XDG_CACHE_HOME", Path.home() / ".cache")) / "altair" ).resolve() - def __init__(self, reader: _Reader[IntoDataFrameT, IntoFrameT], /) -> None: - self._rd: _Reader[IntoDataFrameT, IntoFrameT] = reader + def __init__(self, reader: _SupportsScanMetadata, /) -> None: + self._rd: _SupportsScanMetadata = reader def clear(self) -> None: """Delete all previously cached datasets.""" @@ -308,10 +375,24 @@ def download_all(self) -> None: return None print(f"Downloading {len(frame)} missing datasets...") for meta in _iter_metadata(frame): - self._rd._download(meta["url"], self.path / (meta["sha"] + meta["suffix"])) + self._download_one(meta["url"], self.path_meta(meta)) print("Finished downloads") return None + def _maybe_download(self, meta: Metadata, /) -> Path: + fp = self.path_meta(meta) + return ( + fp + if (fp.exists() and fp.stat().st_size) + else self._download_one(meta["url"], fp) + ) + + def _download_one(self, url: str, fp: Path, /) -> Path: + with self._rd._opener.open(url) as f: + fp.touch() + fp.write_bytes(f.read()) + return fp + @property def path(self) -> Path: """ @@ -354,6 +435,9 @@ def path(self, source: StrPath | None, /) -> None: else: os.environ[self._ENV_VAR] = "" + def path_meta(self, meta: Metadata, /) -> Path: + return self.path / (meta["sha"] + meta["suffix"]) + def __iter__(self) -> Iterator[Path]: yield from self.path.iterdir() diff --git a/altair/datasets/_constraints.py b/altair/datasets/_constraints.py new file mode 100644 index 000000000..e5eaa3b97 --- /dev/null +++ b/altair/datasets/_constraints.py @@ -0,0 +1,115 @@ +"""Set-like guards for matching metadata to an implementation.""" + +from __future__ import annotations + +from collections.abc import Set +from itertools import chain +from typing import TYPE_CHECKING, Any + +from narwhals.stable import v1 as nw + +if TYPE_CHECKING: + import sys + from collections.abc import Iterable, Iterator + + from altair.datasets._typing import Metadata + + if sys.version_info >= (3, 12): + from typing import Unpack + else: + from typing_extensions import Unpack + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + +__all__ = [ + "Items", + "MetaIs", + "is_arrow", + "is_csv", + "is_json", + "is_meta", + "is_not_tabular", + "is_parquet", + "is_spatial", + "is_tsv", +] + +Items: TypeAlias = Set[tuple[str, Any]] + + +class MetaIs(Set[tuple[str, Any]]): + _requires: frozenset[tuple[str, Any]] + + def __init__(self, kwds: frozenset[tuple[str, Any]], /) -> None: + object.__setattr__(self, "_requires", kwds) + + @classmethod + def from_metadata(cls, meta: Metadata, /) -> MetaIs: + return cls(frozenset(meta.items())) + + def to_metadata(self) -> Metadata: + if TYPE_CHECKING: + + def collect(**kwds: Unpack[Metadata]) -> Metadata: + return kwds + + return collect(**dict(self)) + return dict(self) + + def to_expr(self) -> nw.Expr: + return nw.all_horizontal(nw.col(name) == val for name, val in self) + + def isdisjoint(self, other: Iterable[Any]) -> bool: + return super().isdisjoint(other) + + def issubset(self, other: Iterable[Any]) -> bool: + return self._requires.issubset(other) + + def __call__(self, meta: Items, /) -> bool: + return self._requires <= meta + + def __hash__(self) -> int: + return hash(self._requires) + + def __contains__(self, x: 
object) -> bool: + return self._requires.__contains__(x) + + def __iter__(self) -> Iterator[tuple[str, Any]]: + yield from self._requires + + def __len__(self) -> int: + return self._requires.__len__() + + def __setattr__(self, name: str, value: Any): + msg = ( + f"{type(self).__name__!r} is immutable.\n" + f"Could not assign self.{name} = {value}" + ) + raise TypeError(msg) + + def __repr__(self) -> str: + items = dict(self) + if not items: + contents = "" + elif suffix := items.pop("suffix", None): + contents = ", ".join( + chain([f"'*{suffix}'"], (f"{k}={v!r}" for k, v in items.items())) + ) + else: + contents = ", ".join(f"{k}={v!r}" for k, v in items.items()) + return f"is_meta({contents})" + + +def is_meta(**kwds: Unpack[Metadata]) -> MetaIs: + return MetaIs.from_metadata(kwds) + + +is_csv = is_meta(suffix=".csv") +is_json = is_meta(suffix=".json") +is_tsv = is_meta(suffix=".tsv") +is_arrow = is_meta(suffix=".arrow") +is_parquet = is_meta(suffix=".parquet") +is_spatial = is_meta(is_spatial=True) +is_not_tabular = is_meta(is_tabular=False) diff --git a/altair/datasets/_exceptions.py b/altair/datasets/_exceptions.py index 36dba27ef..2f9c13d45 100644 --- a/altair/datasets/_exceptions.py +++ b/altair/datasets/_exceptions.py @@ -5,7 +5,7 @@ if TYPE_CHECKING: from collections.abc import Sequence - from altair.datasets._readers import _Backend + from altair.datasets._reader import _Backend from altair.datasets._typing import Metadata @@ -26,6 +26,19 @@ def from_url(cls, meta: Metadata, /) -> AltairDatasetsError: raise NotImplementedError(msg) return cls(msg) + @classmethod + def from_tabular(cls, meta: Metadata, backend_name: str, /) -> AltairDatasetsError: + install_other = None + mid = "\n" + if not meta["is_image"] and not meta["is_tabular"]: + install_other = "polars" + if meta["is_spatial"]: + mid = f"Geospatial data is not supported natively by {backend_name!r}." + elif meta["is_json"]: + mid = f"Non-tabular json is not supported natively by {backend_name!r}." + msg = f"{_failed_tabular(meta)}{mid}{_suggest_url(meta, install_other)}" + return cls(msg) + @classmethod def from_priority(cls, priority: Sequence[_Backend], /) -> AltairDatasetsError: msg = f"Found no supported backend, searched:\n{priority!r}" @@ -33,12 +46,12 @@ def from_priority(cls, priority: Sequence[_Backend], /) -> AltairDatasetsError: def module_not_found( - backend_name: str, reqs: str | tuple[str, ...], missing: str + backend_name: str, reqs: Sequence[str], missing: str ) -> ModuleNotFoundError: - if isinstance(reqs, tuple): - depends = ", ".join(f"{req!r}" for req in reqs) + " packages" + if len(reqs) == 1: + depends = f"{reqs[0]!r} package" else: - depends = f"{reqs!r} package" + depends = ", ".join(f"{req!r}" for req in reqs) + " packages" msg = ( f"Backend {backend_name!r} requires the {depends}, but {missing!r} could not be found.\n" f"This can be installed with pip using:\n" @@ -49,29 +62,6 @@ def module_not_found( return ModuleNotFoundError(msg, name=missing) -def image(meta: Metadata, /) -> AltairDatasetsError: - msg = f"{_failed_tabular(meta)}\n{_suggest_url(meta)}" - return AltairDatasetsError(msg) - - -def geospatial(meta: Metadata, backend_name: str) -> NotImplementedError: - msg = ( - f"{_failed_tabular(meta)}" - f"Geospatial data is not supported natively by {backend_name!r}." 
- f"{_suggest_url(meta, 'polars')}" - ) - return NotImplementedError(msg) - - -def non_tabular_json(meta: Metadata, backend_name: str) -> NotImplementedError: - msg = ( - f"{_failed_tabular(meta)}" - f"Non-tabular json is not supported natively by {backend_name!r}." - f"{_suggest_url(meta, 'polars')}" - ) - return NotImplementedError(msg) - - def _failed_url(meta: Metadata, /) -> str: return f"Unable to load {meta['file_name']!r} via url.\n" @@ -87,3 +77,35 @@ def _suggest_url(meta: Metadata, install_other: str | None = None) -> str: " from altair.datasets import url\n" f" url({meta['dataset_name']!r})" ) + + +# TODO: +# - Use `AltairDatasetsError` +# - Remove notes from doc +# - Improve message and how data is selected +def implementation_not_found(meta: Metadata, /) -> NotImplementedError: + """ + Search finished without finding a *declared* incompatibility. + + Notes + ----- + - New kind of error + - Previously, every backend had a function assigned + - But they might not all work + - Now, only things that are known to be widely safe are added + - Should probably suggest using a pre-defined backend that supports everything + - What can reach here? + - `is_image` (all) + - `"pandas"` (using inference wont trigger these) + - `.arrow` (w/o `pyarrow`) + - `.parquet` (w/o either `pyarrow` or `fastparquet`) + """ + INDENT = " " * 4 + record = f",\n{INDENT}".join( + f"{k}={v!r}" + for k, v in meta.items() + if not (k.startswith(("is_", "sha", "bytes", "has_"))) + or (v is True and k.startswith("is_")) + ) + msg = f"Found no implementation that supports:\n{INDENT}{record}" + return NotImplementedError(msg) diff --git a/altair/datasets/_loader.py b/altair/datasets/_loader.py index 8f13ab2de..9b55daf70 100644 --- a/altair/datasets/_loader.py +++ b/altair/datasets/_loader.py @@ -2,9 +2,10 @@ from typing import TYPE_CHECKING, Generic, final, overload -from narwhals.stable.v1.typing import IntoDataFrameT, IntoFrameT +from narwhals.stable.v1.typing import IntoDataFrameT -from altair.datasets._readers import _Reader, backend +from altair.datasets import _reader +from altair.datasets._reader import IntoFrameT if TYPE_CHECKING: import sys @@ -13,14 +14,16 @@ import pandas as pd import polars as pl import pyarrow as pa + from narwhals.stable import v1 as nw from altair.datasets._cache import DatasetCache + from altair.datasets._reader import Reader if sys.version_info >= (3, 11): - from typing import LiteralString + from typing import LiteralString, Self else: - from typing_extensions import LiteralString - from altair.datasets._readers import _Backend + from typing_extensions import LiteralString, Self + from altair.datasets._reader import _Backend from altair.datasets._typing import Dataset, Extension @@ -43,7 +46,7 @@ class Loader(Generic[IntoDataFrameT, IntoFrameT]): https://github.com/vega/vega-datasets """ - _reader: _Reader[IntoDataFrameT, IntoFrameT] + _reader: Reader[IntoDataFrameT, IntoFrameT] @overload @classmethod @@ -55,16 +58,18 @@ def from_backend( @classmethod def from_backend( cls, backend_name: Literal["pandas", "pandas[pyarrow]"], / - ) -> Loader[pd.DataFrame, pd.DataFrame]: ... + ) -> Loader[pd.DataFrame, nw.LazyFrame]: ... @overload @classmethod def from_backend( cls, backend_name: Literal["pyarrow"], / - ) -> Loader[pa.Table, pa.Table]: ... + ) -> Loader[pa.Table, nw.LazyFrame]: ... 
@classmethod - def from_backend(cls, backend_name: _Backend = "polars", /) -> Loader[Any, Any]: + def from_backend( + cls: type[Loader[Any, Any]], backend_name: _Backend = "polars", / + ) -> Loader[Any, Any]: """ Initialize a new loader, with the specified backend. @@ -128,8 +133,12 @@ def from_backend(cls, backend_name: _Backend = "polars", /) -> Loader[Any, Any]: .. _JSON format not supported: https://arrow.apache.org/docs/python/json.html#reading-json-files """ - obj = Loader.__new__(Loader) - obj._reader = backend(backend_name) + return cls.from_reader(_reader._from_backend(backend_name)) + + @classmethod + def from_reader(cls, reader: Reader[IntoDataFrameT, IntoFrameT], /) -> Self: + obj = cls.__new__(cls) + obj._reader = reader return obj def __call__( @@ -278,7 +287,7 @@ def url( return self._reader.url(name, suffix) @property - def cache(self) -> DatasetCache[IntoDataFrameT, IntoFrameT]: + def cache(self) -> DatasetCache: """ Caching of remote dataset requests. @@ -361,12 +370,9 @@ def __call__( def __getattr__(name): if name == "load": - from altair.datasets._readers import infer_backend - - reader = infer_backend() + reader = _reader.infer_backend() global load - load = _Load.__new__(_Load) - load._reader = reader + load = _Load.from_reader(reader) return load else: msg = f"module {__name__!r} has no attribute {name!r}" diff --git a/altair/datasets/_reader.py b/altair/datasets/_reader.py new file mode 100644 index 000000000..eacc516ba --- /dev/null +++ b/altair/datasets/_reader.py @@ -0,0 +1,540 @@ +""" +Backend for ``alt.datasets.Loader``. + +Notes +----- +Extending would be more ergonomic if `read`, `scan`, `_constraints` were available under a single export:: + + from altair.datasets import ext, reader + import polars as pl + + impls = ( + ext.read(pl.read_parquet, ext.is_parquet), + ext.read(pl.read_csv, ext.is_csv), + ext.read(pl.read_json, ext.is_json), + ) + user_reader = reader(impls) + user_reader.dataset("airports") +""" + +from __future__ import annotations + +from collections import Counter +from collections.abc import Mapping +from importlib import import_module +from importlib.util import find_spec +from itertools import chain +from pathlib import Path +from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, cast, overload +from urllib.request import build_opener as _build_opener + +from narwhals.stable import v1 as nw +from narwhals.stable.v1.typing import IntoDataFrameT, IntoExpr +from packaging.requirements import Requirement + +from altair.datasets import _readimpl +from altair.datasets._cache import CsvCache, DatasetCache, SchemaCache, _iter_metadata +from altair.datasets._constraints import is_parquet +from altair.datasets._exceptions import ( + AltairDatasetsError, + implementation_not_found, + module_not_found, +) +from altair.datasets._readimpl import IntoFrameT, is_available +from altair.datasets._typing import EXTENSION_SUFFIXES + +if TYPE_CHECKING: + import sys + from collections.abc import Callable, Sequence + from urllib.request import OpenerDirector + + import pandas as pd + import polars as pl + import pyarrow as pa + + from altair.datasets._readimpl import BaseImpl, R, ReadImpl, ScanImpl + from altair.datasets._typing import Dataset, Extension, Metadata + from altair.vegalite.v5.schema._typing import OneOrSeq + + if sys.version_info >= (3, 13): + from typing import TypeIs, TypeVar + else: + from typing_extensions import TypeIs, TypeVar + if sys.version_info >= (3, 12): + from typing import Unpack + else: + from typing_extensions import 
Unpack + if sys.version_info >= (3, 11): + from typing import LiteralString + else: + from typing_extensions import LiteralString + if sys.version_info >= (3, 10): + from typing import TypeAlias + else: + from typing_extensions import TypeAlias + _Polars: TypeAlias = Literal["polars"] + _Pandas: TypeAlias = Literal["pandas"] + _PyArrow: TypeAlias = Literal["pyarrow"] + _PandasAny: TypeAlias = Literal[_Pandas, "pandas[pyarrow]"] + _Backend: TypeAlias = Literal[_Polars, _PandasAny, _PyArrow] + _CuDF: TypeAlias = Literal["cudf"] + _Dask: TypeAlias = Literal["dask"] + _DuckDB: TypeAlias = Literal["duckdb"] + _Ibis: TypeAlias = Literal["ibis"] + _PySpark: TypeAlias = Literal["pyspark"] + _NwSupport: TypeAlias = Literal[ + _Polars, _Pandas, _PyArrow, _CuDF, _Dask, _DuckDB, _Ibis, _PySpark + ] + _NwSupportT = TypeVar( + "_NwSupportT", + _Polars, + _Pandas, + _PyArrow, + _CuDF, + _Dask, + _DuckDB, + _Ibis, + _PySpark, + ) + + +class Reader(Generic[IntoDataFrameT, IntoFrameT]): + """ + Modular file reader, targeting remote & local tabular resources. + + .. warning:: + Use ``reader(...)`` instead of instantiating ``Reader`` directly. + """ + + # TODO: Docs + _read: Sequence[ReadImpl[IntoDataFrameT]] + """Eager file read functions.""" + + # TODO: Docs + _scan: Sequence[ScanImpl[IntoFrameT]] + """ + *Optionally*-lazy file read/scan functions. + + Used exclusively for ``metadata.parquet``. + + Currently ``"polars"`` is the only lazy option. + All others defer to the eager variant. + """ + + _name: str + """ + Used in error messages, repr and matching ``@overload``(s). + + Otherwise, has no concrete meaning. + """ + + _implementation: nw.Implementation + """ + Corresponding `narwhals implementation`_. + + .. _narwhals implementation: + https://github.com/narwhals-dev/narwhals/blob/9b6a355530ea46c590d5a6d1d0567be59c0b5742/narwhals/utils.py#L61-L290 + """ + + _opener: ClassVar[OpenerDirector] = _build_opener() + _metadata_path: ClassVar[Path] = ( + Path(__file__).parent / "_metadata" / "metadata.parquet" + ) + + def __init__( + self, + read: Sequence[ReadImpl[IntoDataFrameT]], + scan: Sequence[ScanImpl[IntoFrameT]], + name: str, + implementation: nw.Implementation, + ) -> None: + self._read = read + self._scan = scan + self._name = name + self._implementation = implementation + self._schema_cache = SchemaCache(implementation=implementation) + + # TODO: Finish working on presentation + # - The contents of both are functional + def profile(self, mode: Literal["any", "each"]): + """ + Describe which datasets/groups are supported. + + Focusing on actual datasets, rather than describing wrapped functions (repr) + + .. note:: + Having this public to make testing easier (``tests.test_datasets.is_polars_backed_pyarrow``) + """ + if mode == "any": + relevant_columns = set( + chain.from_iterable(impl._relevant_columns for impl in self._read) + ) + frame = self._scan_metadata().select("dataset_name", *relevant_columns) + it = (impl._include_expr for impl in self._read) + # BUG: ``narwhals`` raises a ``ValueError`` when ``__invert__``-ing a previously used Expr? 
+ # - Can't reproduce trivially + # - Doesnt seem to be related to genexp + inc_expr = nw.any_horizontal(*it) + include = _dataset_names(frame, inc_expr) + exclude = _dataset_names(frame, ~nw.col("dataset_name").is_in(include)) + return {"include": include, "exclude": exclude} + elif mode == "each": + # FIXME: Rough draft of how to group results + # - Don't really want a nested dict + m = {} + frame = self._scan_metadata() + for impl in self._read: + name = impl._contents + m[name] = {"include": _dataset_names(frame, impl._include_expr)} + if impl.exclude: + m[name].update(exclude=_dataset_names(frame, impl._exclude_expr)) + return m + else: + msg = f"Unexpected {mode=}" + raise TypeError(msg) + + def __repr__(self) -> str: + from textwrap import indent + + PREFIX = " " * 4 + NL = "\n" + body = f"read\n{indent(NL.join(el._contents for el in self._read), PREFIX)}" + if self._scan: + body += ( + f"\nscan\n{indent(NL.join(el._contents for el in self._scan), PREFIX)}" + ) + return f"Reader[{self._name}] {self._implementation!r}\n{body}" + + def read_fn(self, meta: Metadata, /) -> Callable[..., IntoDataFrameT]: + return self._solve(meta, self._read) + + def scan_fn(self, meta: Metadata | Path | str, /) -> Callable[..., IntoFrameT]: + meta = meta if isinstance(meta, Mapping) else {"suffix": _into_suffix(meta)} + return self._solve(meta, self._scan) + + @property + def cache(self) -> DatasetCache: + return DatasetCache(self) + + def dataset( + self, + name: Dataset | LiteralString, + suffix: Extension | None = None, + /, + **kwds: Any, + ) -> IntoDataFrameT: + frame = self._query(name, suffix) + meta = next(_iter_metadata(frame)) + fn = self.read_fn(meta) + fn_kwds = self._merge_kwds(meta, kwds) + if self.cache.is_active(): + fp = self.cache._maybe_download(meta) + return fn(fp, **fn_kwds) + else: + with self._opener.open(meta["url"]) as f: + return fn(f, **fn_kwds) + + def url( + self, name: Dataset | LiteralString, suffix: Extension | None = None, / + ) -> str: + frame = self._query(name, suffix) + meta = next(_iter_metadata(frame)) + if is_parquet(meta.items()) and not is_available("vegafusion"): + raise AltairDatasetsError.from_url(meta) + url = meta["url"] + if isinstance(url, str): + return url + else: + msg = f"Expected 'str' but got {type(url).__name__!r}\nfrom {url!r}." + raise TypeError(msg) + + def _query( + self, name: Dataset | LiteralString, suffix: Extension | None = None, / + ) -> nw.DataFrame[IntoDataFrameT]: + """ + Query a tabular version of `vega-datasets/datapackage.json`_. + + Applies a filter, erroring out when no results would be returned. + + .. _vega-datasets/datapackage.json: + https://github.com/vega/vega-datasets/blob/main/datapackage.json + """ + constraints = _into_constraints(name, suffix) + frame = self._scan_metadata(**constraints).collect() + if not frame.is_empty(): + return frame + else: + msg = f"Found no results for:\n {constraints!r}" + raise ValueError(msg) + + # TODO: Docs + def _merge_kwds(self, meta: Metadata, kwds: dict[str, Any], /) -> Mapping[str, Any]: + """ + Hook to utilize ``meta`` to extend ``kwds`` with known helpful defaults. + + - User provided arguments have a higher precedence. 
+ - The keywords for schemas vary between libraries + - pandas is internally inconsistent + - By default, returns unchanged + """ + if self._schema_cache.is_active() and ( + schema := self._schema_cache.schema_kwds(meta) + ): + kwds = schema | kwds if kwds else schema + return kwds + + @property + def _metadata_frame(self) -> nw.LazyFrame: + fp = self._metadata_path + return nw.from_native(self.scan_fn(fp)(fp)).lazy() + + def _scan_metadata( + self, *predicates: OneOrSeq[IntoExpr], **constraints: Unpack[Metadata] + ) -> nw.LazyFrame: + if predicates or constraints: + return self._metadata_frame.filter(*predicates, **constraints) + return self._metadata_frame + + # TODO: Docs + def _solve( + self, meta: Metadata, impls: Sequence[BaseImpl[R]], / + ) -> Callable[..., R]: + """ + Return the first function meeting constraints of meta. + + Notes + ----- + - Iterate over impls + - Each one can either match or signal an error + - An error blocks any additional checking + - Both include & exclude + - Uses ``ItemsView`` to support set ops + - `meta` isn't iterated over + - Leaves the door open for caching the search space + """ + items = meta.items() + it = (some for impl in impls if (some := impl.unwrap_or(items))) + if fn_or_err := next(it, None): + if _is_err(fn_or_err): + raise fn_or_err.from_tabular(meta, self._name) + return fn_or_err + if meta["is_image"]: + raise AltairDatasetsError.from_tabular(meta, self._name) + raise implementation_not_found(meta) + + +# TODO: Review after finishing `profile` +# NOTE: Temp helper function for `Reader.profile` +def _dataset_names( + frame: nw.LazyFrame, + *predicates: OneOrSeq[IntoExpr], + **constraints: Unpack[Metadata], +): + return ( + frame.filter(*predicates, **constraints) + .select("dataset_name") + .collect() + .get_column("dataset_name") + .to_list() + ) + + +class _NoParquetReader(Reader[IntoDataFrameT, IntoFrameT]): + def __repr__(self) -> str: + return f"{super().__repr__()}\ncsv_cache\n {self.csv_cache!r}" + + @property + def csv_cache(self) -> CsvCache: + if not hasattr(self, "_csv_cache"): + self._csv_cache = CsvCache() + return self._csv_cache + + @property + def _metadata_frame(self) -> nw.LazyFrame: + ns = self._implementation.to_native_namespace() + data = cast("dict[str, Any]", self.csv_cache.rotated) + return nw.maybe_convert_dtypes(nw.from_dict(data, native_namespace=ns)).lazy() + + +@overload +def reader( + read_fns: Sequence[ReadImpl[IntoDataFrameT]], + scan_fns: tuple[()] = ..., + *, + name: str | None = ..., + implementation: nw.Implementation = ..., +) -> Reader[IntoDataFrameT, nw.LazyFrame]: ... + + +@overload +def reader( + read_fns: Sequence[ReadImpl[IntoDataFrameT]], + scan_fns: Sequence[ScanImpl[IntoFrameT]], + *, + name: str | None = ..., + implementation: nw.Implementation = ..., +) -> Reader[IntoDataFrameT, IntoFrameT]: ... 
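Until the ergonomic ``ext`` namespace sketched in the module docstring exists, the same composition is already possible through the private modules. A hedged example wiring ``polars`` read functions by hand (the import paths are the current private locations and may change)::

    import polars as pl

    from altair.datasets._constraints import is_csv, is_json, is_parquet
    from altair.datasets._reader import reader
    from altair.datasets._readimpl import read

    impls = (
        read(pl.read_parquet, is_parquet),
        read(pl.read_csv, is_csv, try_parse_dates=True),
        read(pl.read_json, is_json),
    )
    user_reader = reader(impls, name="polars")
    user_reader.dataset("airports")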
+ + +def reader( + read_fns: Sequence[ReadImpl[IntoDataFrameT]], + scan_fns: Sequence[ScanImpl[IntoFrameT]] = (), + *, + name: str | None = None, + implementation: nw.Implementation = nw.Implementation.UNKNOWN, +) -> Reader[IntoDataFrameT, IntoFrameT] | Reader[IntoDataFrameT, nw.LazyFrame]: + name = name or Counter(el._inferred_package for el in read_fns).most_common(1)[0][0] + if implementation is nw.Implementation.UNKNOWN: + implementation = _into_implementation(Requirement(name)) + if scan_fns: + return Reader(read_fns, scan_fns, name, implementation) + if stolen := _steal_eager_parquet(read_fns): + return Reader(read_fns, stolen, name, implementation) + else: + return _NoParquetReader[IntoDataFrameT](read_fns, (), name, implementation) + + +def infer_backend( + *, priority: Sequence[_Backend] = ("polars", "pandas[pyarrow]", "pandas", "pyarrow") +) -> Reader[Any, Any]: + """ + Return the first available reader in order of `priority`. + + Notes + ----- + - ``"polars"``: can natively load every dataset (including ``(Geo|Topo)JSON``) + - ``"pandas[pyarrow]"``: can load *most* datasets, guarantees ``.parquet`` support + - ``"pandas"``: supports ``.parquet``, if `fastparquet`_ is installed + - ``"pyarrow"``: least reliable + + .. _fastparquet: + https://github.com/dask/fastparquet + """ + it = (_from_backend(name) for name in priority if is_available(_requirements(name))) + if reader := next(it, None): + return reader + raise AltairDatasetsError.from_priority(priority) + + +@overload +def _from_backend(name: _Polars, /) -> Reader[pl.DataFrame, pl.LazyFrame]: ... +@overload +def _from_backend(name: _PandasAny, /) -> Reader[pd.DataFrame, nw.LazyFrame]: ... +@overload +def _from_backend(name: _PyArrow, /) -> Reader[pa.Table, nw.LazyFrame]: ... + + +# FIXME: The order this is defined in makes splitting the module complicated +# - Can't use a classmethod, since some result in a subclass used +def _from_backend(name: _Backend, /) -> Reader[Any, Any]: + """ + Reader initialization dispatcher. + + FIXME: Works, but defining these in mixed shape functions seems off. + """ + if not _is_backend(name): + msg = f"Unknown backend {name!r}" + raise TypeError(msg) + implementation = _into_implementation(name) + if name == "polars": + rd, sc = _readimpl.pl_only() + return reader(rd, sc, name=name, implementation=implementation) + elif name == "pandas[pyarrow]": + return reader(_readimpl.pd_pyarrow(), name=name, implementation=implementation) + elif name == "pandas": + return reader(_readimpl.pd_only(), name=name, implementation=implementation) + elif name == "pyarrow": + return reader(_readimpl.pa_any(), name=name, implementation=implementation) + + +def _is_backend(obj: Any) -> TypeIs[_Backend]: + return obj in {"polars", "pandas", "pandas[pyarrow]", "pyarrow"} + + +def _is_err(obj: Any) -> TypeIs[type[AltairDatasetsError]]: + return obj is AltairDatasetsError + + +def _into_constraints( + name: Dataset | LiteralString, suffix: Extension | None, / +) -> Metadata: + """Transform args into a mapping to column names.""" + m: Metadata = {} + if "." 
in name: + m["file_name"] = name + elif suffix is None: + m["dataset_name"] = name + elif suffix.startswith("."): + m = {"dataset_name": name, "suffix": suffix} + else: + msg = ( + f"Expected 'suffix' to be one of {EXTENSION_SUFFIXES!r},\n" + f"but got: {suffix!r}" + ) + raise TypeError(msg) + return m + + +def _into_implementation( + backend: _NwSupport | _PandasAny | Requirement, / +) -> nw.Implementation: + primary = _import_guarded(backend) + mapping: Mapping[LiteralString, nw.Implementation] = { + "polars": nw.Implementation.POLARS, + "pandas": nw.Implementation.PANDAS, + "pyarrow": nw.Implementation.PYARROW, + "cudf": nw.Implementation.CUDF, + "dask": nw.Implementation.DASK, + "duckdb": nw.Implementation.DUCKDB, + "ibis": nw.Implementation.IBIS, + "pyspark": nw.Implementation.PYSPARK, + } + if impl := mapping.get(primary): + return impl + msg = f"Package {primary!r} is not supported by `narhwals`." + raise ValueError(msg) + + +def _into_suffix(obj: Path | str, /) -> Any: + if isinstance(obj, Path): + return obj.suffix + elif isinstance(obj, str): + return obj + else: + msg = f"Unexpected type {type(obj).__name__!r}" + raise TypeError(msg) + + +def _steal_eager_parquet( + read_fns: Sequence[ReadImpl[IntoDataFrameT]], / +) -> Sequence[ScanImpl[nw.LazyFrame]] | None: + if convertable := next((rd for rd in read_fns if rd.include <= is_parquet), None): + return (convertable.to_scan_impl(),) + return None + + +@overload +def _import_guarded(req: _PandasAny, /) -> _Pandas: ... + + +@overload +def _import_guarded(req: _NwSupportT, /) -> _NwSupportT: ... + + +@overload +def _import_guarded(req: Requirement, /) -> LiteralString: ... + + +def _import_guarded(req: Any, /) -> LiteralString: + requires = _requirements(req) + for name in requires: + if spec := find_spec(name): + import_module(spec.name) + else: + raise module_not_found(str(req), requires, missing=name) + return requires[0] + + +def _requirements(req: Requirement | str, /) -> tuple[Any, ...]: + req = Requirement(req) if isinstance(req, str) else req + return (req.name, *req.extras) diff --git a/altair/datasets/_readers.py b/altair/datasets/_readers.py deleted file mode 100644 index a1f66dee1..000000000 --- a/altair/datasets/_readers.py +++ /dev/null @@ -1,574 +0,0 @@ -""" -Backends for ``alt.datasets.Loader``. - -- Interfacing with the cached metadata. 
- - But not updating it -- Performing requests from those urls -- Dispatching read function on file extension -""" - -from __future__ import annotations - -import urllib.request -from collections.abc import Callable, Iterable, Mapping, Sequence -from functools import partial -from importlib import import_module -from importlib.util import find_spec -from itertools import chain -from pathlib import Path -from typing import ( - TYPE_CHECKING, - Any, - ClassVar, - Final, - Literal, - Protocol, - TypeVar, - overload, -) - -import narwhals.stable.v1 as nw -from narwhals.stable.v1.typing import IntoDataFrameT, IntoExpr, IntoFrameT - -from altair.datasets import _exceptions as _ds_exc -from altair.datasets._cache import CsvCache, DatasetCache, SchemaCache, _iter_metadata -from altair.datasets._typing import EXTENSION_SUFFIXES, Metadata, is_ext_read - -if TYPE_CHECKING: - import sys - from io import IOBase - from urllib.request import OpenerDirector - - import pandas as pd - import polars as pl - import pyarrow as pa - from _typeshed import StrPath - from pyarrow.csv import read_csv as pa_read_csv # noqa: F401 - from pyarrow.feather import read_table as pa_read_feather # noqa: F401 - from pyarrow.json import read_json as pa_read_json # noqa: F401 - from pyarrow.parquet import read_table as pa_read_parquet # noqa: F401 - - if sys.version_info >= (3, 13): - from typing import TypeIs, Unpack - else: - from typing_extensions import TypeIs, Unpack - if sys.version_info >= (3, 11): - from typing import LiteralString - else: - from typing_extensions import LiteralString - if sys.version_info >= (3, 10): - from typing import TypeAlias - else: - from typing_extensions import TypeAlias - from packaging.requirements import Requirement - - from altair.datasets._typing import Dataset, Extension, Metadata - from altair.vegalite.v5.schema._typing import OneOrSeq - - _IntoSuffix: TypeAlias = "StrPath | Metadata" - _ExtensionScan: TypeAlias = Literal[".parquet"] - _T = TypeVar("_T") - - # NOTE: Using a constrained instead of bound `TypeVar` - # error: Incompatible return value type (got "DataFrame[Any] | LazyFrame[Any]", expected "FrameT") [return-value] - # - https://typing.readthedocs.io/en/latest/spec/generics.html#introduction - # - https://typing.readthedocs.io/en/latest/spec/generics.html#type-variables-with-an-upper-bound - # https://github.com/narwhals-dev/narwhals/blob/21b8436567de3631c584ef67632317ad70ae5de0/narwhals/typing.py#L59 - FrameT = TypeVar("FrameT", nw.DataFrame[Any], nw.LazyFrame) - - _Polars: TypeAlias = Literal["polars"] - _Pandas: TypeAlias = Literal["pandas"] - _PyArrow: TypeAlias = Literal["pyarrow"] - _ConcreteT = TypeVar("_ConcreteT", _Polars, _Pandas, _PyArrow) - _PandasAny: TypeAlias = Literal[_Pandas, "pandas[pyarrow]"] - _Backend: TypeAlias = Literal[_Polars, _PandasAny, _PyArrow] - - -__all__ = ["backend", "infer_backend"] - -_METADATA: Final[Path] = Path(__file__).parent / "_metadata" / "metadata.parquet" - - -class _Reader(Protocol[IntoDataFrameT, IntoFrameT]): - """ - Describes basic IO for remote & local tabular resources. - - Subclassing this protocol directly will provide a *mostly* complete implementation. - - Each of the following must be explicitly assigned: - - _Reader._read_fn - _Reader._scan_fn - _Reader._name - """ - - _read_fn: Mapping[Extension, Callable[..., IntoDataFrameT]] - """ - Eager file read functions. - - Each corresponds to a known file extension within ``vega-datasets``. 
- """ - - _scan_fn: Mapping[_ExtensionScan, Callable[..., IntoFrameT]] - """ - *Optionally*-lazy file read/scan functions. - - Used exclusively for ``metadata.parquet``. - - Currently ``"polars"`` is the only lazy option. - """ - - _name: LiteralString - """ - Used in error messages, repr and matching ``@overload``(s). - - Otherwise, has no concrete meaning. - """ - - _opener: ClassVar[OpenerDirector] = urllib.request.build_opener() - - def read_fn(self, source: _IntoSuffix, /) -> Callable[..., IntoDataFrameT]: - return self._read_fn[_extract_suffix(source, is_ext_read)] - - def scan_fn(self, source: _IntoSuffix, /) -> Callable[..., IntoFrameT]: - return self._scan_fn[_extract_suffix(source, is_ext_scan)] - - def _schema_kwds(self, meta: Metadata, /) -> dict[str, Any]: - """Hook to provide additional schema metadata on read.""" - return {} - - def _maybe_fn(self, meta: Metadata, /) -> Callable[..., IntoDataFrameT]: - """Backend specific tweaks/errors/warnings, based on ``Metadata``.""" - if meta["is_image"]: - raise _ds_exc.image(meta) - return self.read_fn(meta) - - def dataset( - self, - name: Dataset | LiteralString, - suffix: Extension | None = None, - /, - **kwds: Any, - ) -> IntoDataFrameT: - df = self.query(**_extract_constraints(name, suffix)) - meta = next(_iter_metadata(df)) - fn = self._maybe_fn(meta) - url = meta["url"] - if default_kwds := self._schema_kwds(meta): - kwds = default_kwds | kwds if kwds else default_kwds - - if self.cache.is_active(): - fp = self.cache.path / (meta["sha"] + meta["suffix"]) - if not (fp.exists() and fp.stat().st_size): - self._download(url, fp) - return fn(fp, **kwds) - else: - with self._opener.open(url) as f: - return fn(f, **kwds) - - def url( - self, - name: Dataset | LiteralString, - suffix: Extension | None = None, - /, - ) -> str: - frame = self.query(**_extract_constraints(name, suffix)) - meta = next(_iter_metadata(frame)) - if meta["suffix"] == ".parquet" and not is_available("vegafusion"): - raise _ds_exc.AltairDatasetsError.from_url(meta) - url = meta["url"] - if isinstance(url, str): - return url - else: - msg = f"Expected 'str' but got {type(url).__name__!r}\nfrom {url!r}." - raise TypeError(msg) - - def query( - self, *predicates: OneOrSeq[IntoExpr], **constraints: Unpack[Metadata] - ) -> nw.DataFrame[IntoDataFrameT]: - """ - Query a tabular version of `vega-datasets/datapackage.json`_. - - Applies a filter, erroring out when no results would be returned. - - Notes - ----- - Arguments correspond to those seen in `pl.LazyFrame.filter`_. - - .. _vega-datasets/datapackage.json: - https://github.com/vega/vega-datasets/blob/main/datapackage.json - .. 
_pl.LazyFrame.filter: - https://docs.pola.rs/api/python/stable/reference/lazyframe/api/polars.LazyFrame.filter.html - """ - frame = self._scan_metadata(*predicates, **constraints).collect() - if not frame.is_empty(): - return frame - else: - terms = "\n".join(f"{t!r}" for t in (predicates, constraints) if t) - msg = f"Found no results for:\n {terms}" - raise ValueError(msg) - - def _scan_metadata( - self, *predicates: OneOrSeq[IntoExpr], **constraints: Unpack[Metadata] - ) -> nw.LazyFrame: - if predicates or constraints: - return self._metadata.filter(*predicates, **constraints) - return self._metadata - - @property - def _metadata(self) -> nw.LazyFrame: - return nw.from_native(self.scan_fn(_METADATA)(_METADATA)).lazy() - - def _download(self, url: str, fp: Path, /) -> None: - with self._opener.open(url) as f: - fp.touch() - fp.write_bytes(f.read()) - - @property - def cache(self) -> DatasetCache[IntoDataFrameT, IntoFrameT]: - return DatasetCache(self) - - def _import(self, name: str, /) -> Any: - if spec := find_spec(name): - return import_module(spec.name) - raise _ds_exc.module_not_found(self._name, _requirements(self._name), name) # type: ignore[call-overload] - - def __repr__(self) -> str: - return f"Reader[{self._name}]" - - def __init__(self, name: LiteralString, /) -> None: ... - - -class _PandasReaderBase(_Reader["pd.DataFrame", "pd.DataFrame"], Protocol): - """ - Provides temporal column names as keyword arguments on read. - - Related - ------- - - https://github.com/vega/altair/pull/3631#issuecomment-2480816377 - - https://github.com/vega/vega-datasets/pull/631 - - https://pandas.pydata.org/docs/reference/api/pandas.read_csv.html - - https://pandas.pydata.org/docs/reference/api/pandas.read_json.html - """ - - _schema_cache: SchemaCache - - def _schema_kwds(self, meta: Metadata, /) -> dict[str, Any]: - name: Any = meta["dataset_name"] - suffix = meta["suffix"] - if cols := self._schema_cache.by_dtype(name, nw.Date, nw.Datetime): - if suffix == ".json": - return {"convert_dates": cols} - elif suffix in {".csv", ".tsv"}: - return {"parse_dates": cols} - return super()._schema_kwds(meta) - - def _maybe_fn(self, meta: Metadata, /) -> Callable[..., pd.DataFrame]: - fn = super()._maybe_fn(meta) - if meta["is_spatial"]: - raise _ds_exc.geospatial(meta, self._name) - return fn - - -class _PandasReader(_PandasReaderBase): - def __init__(self, name: _Pandas, /) -> None: - self._name = _requirements(name) - if not TYPE_CHECKING: - pd = self._import(self._name) - self._read_fn = { - ".csv": pd.read_csv, - ".json": pd.read_json, - ".tsv": partial["pd.DataFrame"](pd.read_csv, sep="\t"), - ".arrow": pd.read_feather, - ".parquet": pd.read_parquet, - } - self._scan_fn = {".parquet": pd.read_parquet} - self._supports_parquet: bool = is_available( - "pyarrow", "fastparquet", require_all=False - ) - self._csv_cache = CsvCache() - self._schema_cache = SchemaCache() - - @property - def _metadata(self) -> nw.LazyFrame: - if self._supports_parquet: - return super()._metadata - return self._csv_cache.metadata(nw.dependencies.get_pandas()) - - -class _PandasPyArrowReader(_PandasReaderBase): - def __init__(self, name: Literal["pandas[pyarrow]"], /) -> None: - _pd, _pa = _requirements(name) - self._name = name - if not TYPE_CHECKING: - pd = self._import(_pd) - pa = self._import(_pa) # noqa: F841 - - self._read_fn = { - ".csv": partial["pd.DataFrame"](pd.read_csv, dtype_backend=_pa), - ".json": partial["pd.DataFrame"](pd.read_json, dtype_backend=_pa), - ".tsv": partial["pd.DataFrame"](pd.read_csv, sep="\t", 
dtype_backend=_pa), - ".arrow": partial(pd.read_feather, dtype_backend=_pa), - ".parquet": partial(pd.read_parquet, dtype_backend=_pa), - } - self._scan_fn = {".parquet": partial(pd.read_parquet, dtype_backend=_pa)} - self._schema_cache = SchemaCache() - - -def _pl_read_json_roundtrip(source: Path | IOBase, /, **kwds: Any) -> pl.DataFrame: - """ - Try to utilize better date parsing available in `pl.read_csv`_. - - `pl.read_json`_ has few options when compared to `pl.read_csv`_. - - Chaining the two together - *where possible* - is still usually faster than `pandas.read_json`_. - - .. _pl.read_json: - https://docs.pola.rs/api/python/stable/reference/api/polars.read_json.html - .. _pl.read_csv: - https://docs.pola.rs/api/python/stable/reference/api/polars.read_csv.html - .. _pandas.read_json: - https://pandas.pydata.org/docs/reference/api/pandas.read_json.html - """ - from io import BytesIO - - import polars as pl - - df = pl.read_json(source, **kwds) - if any(tp.is_nested() for tp in df.schema.dtypes()): - # NOTE: Inferred as `(Geo|Topo)JSON`, which wouldn't be supported by `read_csv` - return df - buf = BytesIO() - df.write_csv(buf) - if kwds: - SHARED_KWDS = {"schema", "schema_overrides", "infer_schema_length"} - kwds = {k: v for k, v in kwds.items() if k in SHARED_KWDS} - return pl.read_csv(buf, try_parse_dates=True, **kwds) - - -class _PolarsReader(_Reader["pl.DataFrame", "pl.LazyFrame"]): - def __init__(self, name: _Polars, /) -> None: - self._name = _requirements(name) - if not TYPE_CHECKING: - pl = self._import(self._name) - self._read_fn = { - ".csv": partial(pl.read_csv, try_parse_dates=True), - ".json": _pl_read_json_roundtrip, - ".tsv": partial(pl.read_csv, separator="\t", try_parse_dates=True), - ".arrow": pl.read_ipc, - ".parquet": pl.read_parquet, - } - self._scan_fn = {".parquet": pl.scan_parquet} - - -class _PyArrowReader(_Reader["pa.Table", "pa.Table"]): - """ - Reader backed by `pyarrow.Table`_. - - Warning - ------- - **JSON**: Only supports `line-delimited`_ JSON. - Likely to raise the following error: - - ArrowInvalid: JSON parse error: Column() changed from object to array in row 0 - - .. _pyarrow.Table: - https://arrow.apache.org/docs/python/generated/pyarrow.Table.html - .. 
_line-delimited: - https://arrow.apache.org/docs/python/json.html#reading-json-files - """ - - def _maybe_fn(self, meta: Metadata, /) -> Callable[..., pa.Table]: - fn = super()._maybe_fn(meta) - if fn == self._read_json_polars: - return fn - elif meta["is_json"]: - if meta["is_tabular"]: - return self._read_json_tabular - elif meta["is_spatial"]: - raise _ds_exc.geospatial(meta, self._name) - else: - raise _ds_exc.non_tabular_json(meta, self._name) - else: - return fn - - def _read_json_tabular(self, source: Any, /, **kwds: Any) -> pa.Table: - import json - - if not isinstance(source, Path): - obj = json.load(source) - else: - with Path(source).open(encoding="utf-8") as f: - obj = json.load(f) - pa = nw.dependencies.get_pyarrow() - return pa.Table.from_pylist(obj) - - def _read_json_polars(self, source: Any, /, **kwds: Any) -> pa.Table: - return _pl_read_json_roundtrip(source).to_arrow() - - def __init__(self, name: _PyArrow, /) -> None: - self._name = _requirements(name) - if not TYPE_CHECKING: - pa = self._import(self._name) # noqa: F841 - pa_read_csv = self._import(f"{self._name}.csv").read_csv - pa_read_feather = self._import(f"{self._name}.feather").read_table - pa_read_parquet = self._import(f"{self._name}.parquet").read_table - - # NOTE: Prefer `polars` since it is zero-copy and fast - if find_spec("polars") is not None: - pa_read_json = self._read_json_polars - else: - pa_read_json = self._import(f"{self._name}.json").read_json - - # NOTE: Stubs suggest using a dataclass, but no way to construct it - tab_sep: Any = {"delimiter": "\t"} - - self._read_fn = { - ".csv": pa_read_csv, - ".json": pa_read_json, - ".tsv": partial(pa_read_csv, parse_options=tab_sep), - ".arrow": pa_read_feather, - ".parquet": pa_read_parquet, - } - self._scan_fn = {".parquet": pa_read_parquet} - - -def _extract_constraints( - name: Dataset | LiteralString, suffix: Extension | None, / -) -> Metadata: - """Transform args into a mapping to column names.""" - constraints: Metadata = {} - if name.endswith(EXTENSION_SUFFIXES): - fp = Path(name) - constraints["dataset_name"] = fp.stem - constraints["suffix"] = fp.suffix - return constraints - elif suffix is not None: - if not is_ext_read(suffix): - msg = ( - f"Expected 'suffix' to be one of {EXTENSION_SUFFIXES!r},\n" - f"but got: {suffix!r}" - ) - raise TypeError(msg) - else: - constraints["suffix"] = suffix - constraints["dataset_name"] = name - return constraints - - -def _extract_suffix(source: _IntoSuffix, guard: Callable[..., TypeIs[_T]], /) -> _T: - suffix: Any = ( - Path(source).suffix if not isinstance(source, Mapping) else source["suffix"] - ) - if guard(suffix): - return suffix - else: - msg = f"Unexpected file extension {suffix!r}, from:\n{source}" - raise TypeError(msg) - - -def is_ext_scan(suffix: Any) -> TypeIs[_ExtensionScan]: - return suffix == ".parquet" - - -def is_available( - pkg_names: str | Iterable[str], *more_pkg_names: str, require_all: bool = True -) -> bool: - """ - Check for importable package(s), without raising on failure. - - Parameters - ---------- - pkg_names, more_pkg_names - One or more packages. - require_all - * ``True`` every package. - * ``False`` at least one package. 
- """ - if not more_pkg_names and isinstance(pkg_names, str): - return find_spec(pkg_names) is not None - pkgs_names = pkg_names if not isinstance(pkg_names, str) else (pkg_names,) - names = chain(pkgs_names, more_pkg_names) - fn = all if require_all else any - return fn(find_spec(name) is not None for name in names) - - -def infer_backend( - *, priority: Sequence[_Backend] = ("polars", "pandas[pyarrow]", "pandas", "pyarrow") -) -> _Reader[Any, Any]: - """ - Return the first available reader in order of `priority`. - - Notes - ----- - - ``"polars"``: can natively load every dataset (including ``(Geo|Topo)JSON``) - - ``"pandas[pyarrow]"``: can load *most* datasets, guarantees ``.parquet`` support - - ``"pandas"``: supports ``.parquet``, if `fastparquet`_ is installed - - ``"pyarrow"``: least reliable - - .. _fastparquet: - https://github.com/dask/fastparquet - """ - it = (backend(name) for name in priority if is_available(_requirements(name))) - if reader := next(it, None): - return reader - raise _ds_exc.AltairDatasetsError.from_priority(priority) - - -@overload -def backend(name: _Polars, /) -> _Reader[pl.DataFrame, pl.LazyFrame]: ... - - -@overload -def backend(name: _PandasAny, /) -> _Reader[pd.DataFrame, pd.DataFrame]: ... - - -@overload -def backend(name: _PyArrow, /) -> _Reader[pa.Table, pa.Table]: ... - - -def backend(name: _Backend, /) -> _Reader[Any, Any]: - """Reader initialization dispatcher.""" - if name == "polars": - return _PolarsReader(name) - elif name == "pandas[pyarrow]": - return _PandasPyArrowReader(name) - elif name == "pandas": - return _PandasReader(name) - elif name == "pyarrow": - return _PyArrowReader(name) - elif name in {"ibis", "cudf", "dask", "modin"}: - msg = "Supported by ``narwhals``, not investigated yet" - raise NotImplementedError(msg) - else: - msg = f"Unknown backend {name!r}" - raise TypeError(msg) - - -@overload -def _requirements(s: _ConcreteT, /) -> _ConcreteT: ... - - -@overload -def _requirements(s: Literal["pandas[pyarrow]"], /) -> tuple[_Pandas, _PyArrow]: ... 
-
-
-def _requirements(s: Any, /) -> Any:
-    concrete: set[Literal[_Polars, _Pandas, _PyArrow]] = {"polars", "pandas", "pyarrow"}
-    if s in concrete:
-        return s
-    else:
-        from packaging.requirements import Requirement
-
-        req = Requirement(s)
-        supports_extras: set[Literal[_Pandas]] = {"pandas"}
-        if req.name in supports_extras and req.extras == {"pyarrow"}:
-            return req.name, "pyarrow"
-        return _requirements_unknown(req)
-
-
-def _requirements_unknown(req: Requirement | str, /) -> Any:
-    from packaging.requirements import Requirement
-
-    req = Requirement(req) if isinstance(req, str) else req
-    return (req.name, *req.extras)
diff --git a/altair/datasets/_readimpl.py b/altair/datasets/_readimpl.py
new file mode 100644
index 000000000..119352db5
--- /dev/null
+++ b/altair/datasets/_readimpl.py
@@ -0,0 +1,414 @@
+"""Individual read functions and situations they support."""
+
+from __future__ import annotations
+
+import sys
+from enum import Enum
+from functools import partial, wraps
+from importlib.util import find_spec
+from itertools import chain
+from operator import itemgetter
+from pathlib import Path
+from typing import TYPE_CHECKING, Any, Generic, Literal
+
+from narwhals.stable import v1 as nw
+from narwhals.stable.v1.dependencies import get_pandas, get_polars
+from narwhals.stable.v1.typing import IntoDataFrameT
+
+from altair.datasets._constraints import (
+    is_arrow,
+    is_csv,
+    is_json,
+    is_meta,
+    is_not_tabular,
+    is_parquet,
+    is_spatial,
+    is_tsv,
+)
+from altair.datasets._exceptions import AltairDatasetsError
+
+if sys.version_info >= (3, 13):
+    from typing import TypeVar
+else:
+    from typing_extensions import TypeVar
+
+if TYPE_CHECKING:
+    from collections.abc import Callable, Iterable, Iterator, Sequence
+    from io import IOBase
+    from types import ModuleType
+
+    import pandas as pd
+    import polars as pl
+    import pyarrow as pa
+    from narwhals.stable.v1 import typing as nwt
+
+    from altair.datasets._constraints import Items, MetaIs
+
+__all__ = ["is_available", "pa_any", "pd_only", "pd_pyarrow", "pl_only", "read", "scan"]
+
+R = TypeVar("R")
+IntoFrameT = TypeVar(
+    "IntoFrameT",
+    bound="nwt.NativeFrame | nw.DataFrame[Any] | nw.LazyFrame | nwt.DataFrameLike",
+    default=nw.LazyFrame,
+)
+
+
+class Skip(Enum):
+    """Falsy sentinel."""
+
+    skip = 0
+
+    def __bool__(self) -> Literal[False]:
+        return False
+
+    def __repr__(self) -> Literal[""]:
+        return ""
+
+
+class BaseImpl(Generic[R]):
+    fn: Callable[..., R]
+    """Wrapped read function."""
+    include: MetaIs
+    """Passing this makes ``fn`` a candidate."""
+    exclude: MetaIs
+    """Passing this overrides ``include``, transforming into an error."""
+
+    def __init__(
+        self,
+        fn: Callable[..., R],
+        include: MetaIs,
+        exclude: MetaIs | None,
+        kwds: dict[str, Any],
+        /,
+    ) -> None:
+        exclude = exclude or self._exclude_none()
+        if not include.isdisjoint(exclude):
+            intersection = ", ".join(f"{k}={v!r}" for k, v in include & exclude)
+            msg = f"Constraints overlap at: `{intersection}`\ninclude={include!r}\nexclude={exclude!r}"
+            raise TypeError(msg)
+        object.__setattr__(self, "fn", partial(fn, **kwds) if kwds else fn)
+        object.__setattr__(self, "include", include)
+        object.__setattr__(self, "exclude", exclude)
+
+    # TODO: Consider renaming
+    # NOTE:
+    # - Fn means call it
+    # - Err means raise it
+    # - Skip means it's safe to check other impls
+    def unwrap_or(
+        self, meta: Items, /
+    ) -> Callable[..., R] | type[AltairDatasetsError] | Skip:
+        if self.include.issubset(meta):
+            return self.fn if self.exclude.isdisjoint(meta) else
AltairDatasetsError + return Skip.skip + + @classmethod + def _exclude_none(cls) -> MetaIs: + return is_meta() + + def __setattr__(self, name: str, value: Any): + msg = ( + f"{type(self).__name__!r} is immutable.\n" + f"Could not assign self.{name} = {value}" + ) + raise TypeError(msg) + + @property + def _inferred_package(self) -> str: + return _root_package_name(_unwrap_partial(self.fn), "UNKNOWN") + + def __repr__(self) -> str: + tp_name = f"{type(self).__name__}[{self._inferred_package}?]" + return f"{tp_name}({self._contents})" + + # TODO: Consider renaming + @property + def _contents(self) -> str: + if isinstance(self.fn, partial): + fn = _unwrap_partial(self.fn) + it = (f"{k}={v!r}" for k, v in self.fn.keywords.items()) + fn_repr = f"{fn.__name__}(..., {', '.join(it)})" + else: + fn_repr = f"{self.fn.__name__}(...)" + if self.exclude: + params = f"include={self.include!r}, exclude={self.exclude!r}" + else: + params = repr(self.include) + return f"{fn_repr}, {params}" + + @property + def _relevant_columns(self) -> Iterator[str]: + name = itemgetter(0) + yield from (name(obj) for obj in chain(self.include, self.exclude)) + + @property + def _include_expr(self) -> nw.Expr: + return ( + self.include.to_expr() & ~self.exclude.to_expr() + if self.exclude + else self.include.to_expr() + ) + + @property + def _exclude_expr(self) -> nw.Expr: + if self.exclude: + return self.include.to_expr() & self.exclude.to_expr() + msg = f"Unable to generate an exclude expression without setting exclude\n\n{self!r}" + raise TypeError(msg) + + +def _unwrap_partial(fn: Any, /) -> Any: + # NOTE: ``functools._unwrap_partial`` + func = fn + while isinstance(func, partial): + func = func.func + return func + + +class ScanImpl(BaseImpl[IntoFrameT]): ... + + +class ReadImpl(BaseImpl[IntoDataFrameT]): + def to_scan_impl(self) -> ScanImpl[nw.LazyFrame]: + return ScanImpl(_into_scan_fn(self.fn), self.include, self.exclude, {}) + + +def _into_scan_fn(fn: Callable[..., IntoDataFrameT], /) -> Callable[..., nw.LazyFrame]: + @wraps(_unwrap_partial(fn)) + def wrapper(*args: Any, **kwds: Any) -> nw.LazyFrame: + return nw.from_native(fn(*args, **kwds)).lazy() + + return wrapper + + +def _root_package_name(obj: Any, default: str, /) -> str: + # NOTE: Defers importing `inspect`, if we can get the module name + if hasattr(obj, "__module__"): + return obj.__module__.split(".")[0] + else: + from inspect import getmodule + + module = getmodule(obj) + if module and (pkg := module.__package__): + return pkg.split(".")[0] + return default + + +def is_available( + pkg_names: str | Iterable[str], *more_pkg_names: str, require_all: bool = True +) -> bool: + """ + Check for importable package(s), without raising on failure. + + Parameters + ---------- + pkg_names, more_pkg_names + One or more packages. + require_all + * ``True`` every package. + * ``False`` at least one package. 
+ """ + if not more_pkg_names and isinstance(pkg_names, str): + return find_spec(pkg_names) is not None + pkgs_names = pkg_names if not isinstance(pkg_names, str) else (pkg_names,) + names = chain(pkgs_names, more_pkg_names) + fn = all if require_all else any + return fn(find_spec(name) is not None for name in names) + + +def read( + fn: Callable[..., IntoDataFrameT], + /, + include: MetaIs, + exclude: MetaIs | None = None, + **kwds: Any, +) -> ReadImpl[IntoDataFrameT]: + return ReadImpl(fn, include, exclude, kwds) + + +def scan( + fn: Callable[..., IntoFrameT], + /, + include: MetaIs, + exclude: MetaIs | None = None, + **kwds: Any, +) -> ScanImpl[IntoFrameT]: + return ScanImpl(fn, include, exclude, kwds) + + +def pl_only() -> tuple[ + Sequence[ReadImpl[pl.DataFrame]], Sequence[ScanImpl[pl.LazyFrame]] +]: + import polars as pl + + read_fns = ( + read(pl.read_csv, is_csv, try_parse_dates=True), + read(_pl_read_json_roundtrip(get_polars()), is_json), + read(pl.read_csv, is_tsv, separator="\t", try_parse_dates=True), + read(pl.read_ipc, is_arrow), + read(pl.read_parquet, is_parquet), + ) + scan_fns = (scan(pl.scan_parquet, is_parquet),) + return read_fns, scan_fns + + +def pd_only() -> Sequence[ReadImpl[pd.DataFrame]]: + import pandas as pd + + opt: Sequence[ReadImpl[pd.DataFrame]] + if is_available("pyarrow"): + opt = read(pd.read_feather, is_arrow), read(pd.read_parquet, is_parquet) + elif is_available("fastparquet"): + opt = (read(pd.read_parquet, is_parquet),) + else: + opt = () + return ( + read(pd.read_csv, is_csv), + read(_pd_read_json(get_pandas()), is_json, exclude=is_spatial), + read(pd.read_csv, is_tsv, sep="\t"), + *opt, + ) + + +def pd_pyarrow() -> Sequence[ReadImpl[pd.DataFrame]]: + import pandas as pd + + kwds: dict[str, Any] = {"dtype_backend": "pyarrow"} + return ( + read(pd.read_csv, is_csv, **kwds), + read(_pd_read_json(get_pandas()), is_json, exclude=is_spatial, **kwds), + read(pd.read_csv, is_tsv, sep="\t", **kwds), + read(pd.read_feather, is_arrow, **kwds), + read(pd.read_parquet, is_parquet, **kwds), + ) + + +def pa_any() -> Sequence[ReadImpl[pa.Table]]: + from pyarrow import csv, feather, parquet + + return ( + read(csv.read_csv, is_csv), + _pa_read_json_impl(), + read(csv.read_csv, is_tsv, parse_options={"delimiter": "\t"}), + read(feather.read_table, is_arrow), + read(parquet.read_table, is_parquet), + ) + + +def _pa_read_json_impl() -> ReadImpl[pa.Table]: + """ + Mitigating ``pyarrow``'s `line-delimited`_ JSON requirement. + + .. 
_line-delimited: + https://arrow.apache.org/docs/python/json.html#reading-json-files + """ + if is_available("polars"): + return read(_pl_read_json_roundtrip_to_arrow(get_polars()), is_json) + elif is_available("pandas"): + return read(_pd_read_json_to_arrow(get_pandas()), is_json, exclude=is_spatial) + return read(_stdlib_read_json_to_arrow, is_json, exclude=is_not_tabular) + + +def _pd_read_json(ns: ModuleType, /) -> Callable[..., pd.DataFrame]: + @wraps(ns.read_json) + def fn(source: Path | Any, /, **kwds: Any) -> pd.DataFrame: + return _pd_fix_dtypes_nw(ns.read_json(source, **kwds), **kwds).to_native() + + return fn + + +def _pd_fix_dtypes_nw( + df: pd.DataFrame, /, *, dtype_backend: Any = None, **kwds: Any +) -> nw.DataFrame[pd.DataFrame]: + kwds = {"dtype_backend": dtype_backend} if dtype_backend else {} + return ( + df.convert_dtypes(**kwds) + .pipe(nw.from_native, eager_only=True) + .with_columns(nw.selectors.by_dtype(nw.Object).cast(nw.String)) + ) + + +def _pd_read_json_to_arrow(ns: ModuleType, /) -> Callable[..., pa.Table]: + @wraps(ns.read_json) + def fn(source: Path | Any, /, *, schema: Any = None, **kwds: Any) -> pa.Table: + """``schema`` is only here to swallow the ``SchemaCache`` if used.""" + return ( + ns.read_json(source, **kwds) + .pipe(_pd_fix_dtypes_nw, dtype_backend="pyarrow") + .to_arrow() + ) + + return fn + + +def _pl_read_json_roundtrip(ns: ModuleType, /) -> Callable[..., pl.DataFrame]: + """ + Try to utilize better date parsing available in `pl.read_csv`_. + + `pl.read_json`_ has few options when compared to `pl.read_csv`_. + + Chaining the two together - *where possible* - is still usually faster than `pandas.read_json`_. + + .. _pl.read_json: + https://docs.pola.rs/api/python/stable/reference/api/polars.read_json.html + .. _pl.read_csv: + https://docs.pola.rs/api/python/stable/reference/api/polars.read_csv.html + .. 
_pandas.read_json: + https://pandas.pydata.org/docs/reference/api/pandas.read_json.html + """ + from io import BytesIO + + @wraps(ns.read_json) + def fn(source: Path | IOBase, /, **kwds: Any) -> pl.DataFrame: + df = ns.read_json(source, **kwds) + if any(tp.is_nested() for tp in df.schema.dtypes()): + return df + buf = BytesIO() + df.write_csv(buf) + if kwds: + SHARED_KWDS = {"schema", "schema_overrides", "infer_schema_length"} + kwds = {k: v for k, v in kwds.items() if k in SHARED_KWDS} + return ns.read_csv(buf, try_parse_dates=True, **kwds) + + return fn + + +def _pl_read_json_roundtrip_to_arrow(ns: ModuleType, /) -> Callable[..., pa.Table]: + eager = _pl_read_json_roundtrip(ns) + + @wraps(ns.read_json) + def fn(source: Path | IOBase, /, **kwds: Any) -> pa.Table: + return eager(source).to_arrow() + + return fn + + +def _stdlib_read_json(source: Path | Any, /) -> Any: + import json + + if not isinstance(source, Path): + return json.load(source) + else: + with Path(source).open(encoding="utf-8") as f: + return json.load(f) + + +def _stdlib_read_json_to_arrow(source: Path | Any, /, **kwds: Any) -> pa.Table: + import pyarrow as pa + + rows: list[dict[str, Any]] = _stdlib_read_json(source) + try: + return pa.Table.from_pylist(rows, **kwds) + except TypeError: + import csv + import io + + from pyarrow import csv as pa_csv + + with io.StringIO() as f: + writer = csv.DictWriter(f, rows[0].keys(), dialect=csv.unix_dialect) + writer.writeheader() + writer.writerows(rows) + with io.BytesIO(f.getvalue().encode()) as f2: + return pa_csv.read_csv(f2) diff --git a/tests/test_datasets.py b/tests/test_datasets.py index 0855b73af..3765fa69b 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -29,7 +29,7 @@ import polars as pl from _pytest.mark.structures import ParameterSet - from altair.datasets._readers import _Backend, _PandasAny, _Polars, _PyArrow + from altair.datasets._reader import _Backend, _PandasAny, _Polars, _PyArrow from altair.vegalite.v5.schema._typing import OneOrSeq if sys.version_info >= (3, 10): @@ -117,11 +117,14 @@ def is_url(name: Dataset, fn_url: Callable[..., str], /) -> bool: def is_polars_backed_pyarrow(loader: Loader[Any, Any], /) -> bool: """User requested ``pyarrow``, but also has ``polars`` installed.""" # NOTE: Would prefer if there was a *less* private method to test this. 
- return bool( - is_loader_backend(loader, "pyarrow") - and (fn := getattr(loader._reader, "_read_json_polars", None)) - and fn == loader._reader.read_fn("dummy.json") - ) + from altair.datasets._constraints import is_meta + + if is_loader_backend(loader, "pyarrow"): + items = is_meta(suffix=".json", is_spatial=True) + impls = loader._reader._read + it = (some for impl in impls if (some := impl.unwrap_or(items))) + return callable(next(it, None)) + return False @backends @@ -151,7 +154,7 @@ def test_load_infer_priority(monkeypatch: pytest.MonkeyPatch) -> None: See Also -------- - ``altair.datasets._readers.infer_backend`` + ``altair.datasets._reader.infer_backend`` """ import altair.datasets._loader from altair.datasets import load @@ -247,7 +250,7 @@ def test_url(name: Dataset) -> None: def test_url_no_backend(monkeypatch: pytest.MonkeyPatch) -> None: from altair.datasets._cache import csv_cache - from altair.datasets._readers import infer_backend + from altair.datasets._reader import infer_backend priority: Any = ("fake_mod_1", "fake_mod_2", "fake_mod_3", "fake_mod_4") assert csv_cache._mapping == {} @@ -318,7 +321,7 @@ def test_dataset_not_found(backend: _Backend) -> None: with pytest.raises( ERR_NO_RESULT, match=re.compile( - rf"{MSG_NO_RESULT}.+{SUFFIX}.+{incorrect_suffix}.+{NAME}.+{real_name}", + rf"{MSG_NO_RESULT}.+{NAME}.+{real_name}.+{SUFFIX}.+{incorrect_suffix}", re.DOTALL, ), ): @@ -326,19 +329,7 @@ def test_dataset_not_found(backend: _Backend) -> None: def test_reader_missing_dependencies() -> None: - from packaging.requirements import Requirement - - from altair.datasets._readers import _Reader - - class MissingDeps(_Reader): - def __init__(self, name) -> None: - self._name = name - reqs = Requirement(name) - for req in (reqs.name, *reqs.extras): - self._import(req) - - self._read_fn = {} - self._scan_fn = {} + from altair.datasets._reader import _import_guarded fake_name = "not_a_real_package" real_name = "altair" @@ -351,7 +342,7 @@ def __init__(self, name) -> None: flags=re.DOTALL, ), ): - MissingDeps(fake_name) + _import_guarded(fake_name) # type: ignore with pytest.raises( ModuleNotFoundError, match=re.compile( @@ -359,7 +350,7 @@ def __init__(self, name) -> None: flags=re.DOTALL, ), ): - MissingDeps(backend) + _import_guarded(backend) # type: ignore @backends @@ -494,38 +485,10 @@ def test_reader_cache_disable(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) - assert not load.cache.is_empty() -# TODO: Investigate adding schemas for `pyarrow`. @pytest.mark.parametrize( - ("name", "fallback"), - [ - ("cars", "polars"), - ("movies", "polars"), - ("wheat", "polars"), - ("barley", "polars"), - ("gapminder", "polars"), - ("income", "polars"), - ("burtin", "polars"), - ("cars", None), - pytest.param( - "movies", - None, - marks=pytest.mark.xfail( - True, - raises=TypeError, - reason=( - "msg: `Expected bytes, got a 'int' object`\n" - "Isn't happy with the mixed `int`/`str` column." 
- ), - strict=True, - ), - ), - ("wheat", None), - ("barley", None), - ("gapminder", None), - ("income", None), - ("burtin", None), - ], + "name", ["cars", "movies", "wheat", "barley", "gapminder", "income", "burtin"] ) +@pytest.mark.parametrize("fallback", ["polars", None]) @backends_pyarrow def test_pyarrow_read_json( backend: _PyArrow, @@ -550,7 +513,7 @@ def test_spatial(backend: _Backend, name: Dataset) -> None: rf"{name}.+geospatial.+native.+{re.escape(backend)}.+try.+polars.+url", flags=re.DOTALL | re.IGNORECASE, ) - with pytest.raises(NotImplementedError, match=pattern): + with pytest.raises(AltairDatasetsError, match=pattern): load(name) @@ -558,7 +521,11 @@ def test_spatial(backend: _Backend, name: Dataset) -> None: @datasets_debug def test_all_datasets(polars_loader: PolarsLoader, name: Dataset) -> None: if name in {"7zip", "ffox", "gimp"}: - with pytest.raises(AltairDatasetsError, match=rf"{name}.+tabular"): + pattern = re.compile( + rf"Unable to load.+{name}.png.+as tabular data", + flags=re.DOTALL | re.IGNORECASE, + ) + with pytest.raises((AltairDatasetsError, NotImplementedError), match=pattern): polars_loader(name) else: frame = polars_loader(name)
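
For reviewers, a minimal sketch of how the new `_readimpl` pieces are meant to compose. This is illustrative only and not part of the diff above: it assumes `pandas` is installed, that the suffix-based constraints (`is_csv`, `is_json`) match on `Metadata["suffix"]` as their names suggest, and that the `resolve` loop and error messages are hypothetical stand-ins for the dispatch that actually lives in `_reader.py`.

    # Illustrative sketch only -- not part of the patch.
    from typing import Any

    import pandas as pd

    from altair.datasets._constraints import is_csv, is_json, is_meta, is_spatial
    from altair.datasets._exceptions import AltairDatasetsError
    from altair.datasets._readimpl import read

    # Each candidate is declared once with what it supports (``include``)
    # and what it must refuse (``exclude``).
    impls = (
        read(pd.read_csv, is_csv),
        # Simplified: the patch wraps ``pd.read_json`` via ``_pd_read_json``.
        read(pd.read_json, is_json, exclude=is_spatial),
    )

    def resolve(**meta: Any) -> Any:
        """Probe each candidate with the dataset's metadata items."""
        items = is_meta(**meta)
        for impl in impls:
            fn_or_err = impl.unwrap_or(items)
            if fn_or_err is AltairDatasetsError:
                # ``include`` matched but ``exclude`` also matched.
                raise fn_or_err(f"Unsupported dataset: {meta!r}")  # hypothetical message
            elif fn_or_err:
                return fn_or_err  # the wrapped read function
            # Falsy ``Skip.skip``: constraint not met, try the next candidate.
        raise AltairDatasetsError(f"No implementation found for {meta!r}")

    fn = resolve(suffix=".csv")  # -> ``pd.read_csv``
    try:
        resolve(suffix=".json", is_spatial=True)
    except AltairDatasetsError:
        ...  # matched ``is_json``, but refused by ``exclude=is_spatial``

The design choice this illustrates: because `include`/`exclude` are set-like constraints, candidate selection reduces to `issubset`/`isdisjoint` checks plus the falsy `Skip` sentinel, which is exactly the pattern the updated `is_polars_backed_pyarrow` test helper relies on.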