feat(DRAFT): Private API overhaul
**Public API is unchanged.**

The core changes simplify testing and extension:

- `_readers.py` -> `_reader.py`
  - with two new support modules, `_constraints` and `_readimpl`
- Functions (`BaseImpl`) declare what they support (`include`) and any restrictions (`exclude`) on that subset
  - This turns much of the imperative dispatch logic into set operations (see the first sketch below)
- Greatly improved `pyarrow` support
  - Utilizes the cached dataset schema
  - Provides additional fallback `.json` implementations (see the second sketch below)
  - `_stdlib_read_json_to_arrow` finally resolves the `"movies.json"` issue
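
To make the `include`/`exclude` idea concrete, here is a minimal sketch of set-based dispatch. It is illustrative only: `ReadImpl`, `matches`, and the tuple-set encoding are assumptions for this example, not the actual `BaseImpl` API from `_readimpl`.

```python
# Illustrative sketch only -- the real ``BaseImpl`` lives in ``_readimpl``
# and its exact signature is not shown in this diff.
from __future__ import annotations

from typing import Any

Items = frozenset[tuple[str, Any]]


class ReadImpl:
    """Pairs a read function's name with what it supports and what it rules out."""

    def __init__(self, name: str, include: Items, exclude: Items = frozenset()) -> None:
        self.name = name
        self.include = include
        self.exclude = exclude

    def matches(self, meta: dict[str, Any]) -> bool:
        # Set operations replace chains of if/elif checks: every required
        # item must be present, and no excluded item may appear.
        items = set(meta.items())
        return self.include <= items and self.exclude.isdisjoint(items)


read_csv = ReadImpl("read_csv", include=frozenset({("suffix", ".csv")}))
read_json = ReadImpl(
    "read_json",
    include=frozenset({("suffix", ".json")}),
    exclude=frozenset({("is_spatial", True)}),
)

meta = {"suffix": ".json", "is_spatial": False}
print([impl.name for impl in (read_csv, read_json) if impl.matches(meta)])
# -> ['read_json']
```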
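
`_stdlib_read_json_to_arrow` itself is not shown in this diff; below is a hedged sketch of the general technique its name suggests — parse with the stdlib `json` module, rotate the records into columns, then hand the result to `pyarrow` — assuming a list-of-records file like `movies.json`. The function name and the padding behavior are illustrative, not taken from the commit.

```python
# Sketch of the technique, not the actual implementation: stdlib ``json``
# does the parsing, so ``pyarrow`` never has to infer types from raw bytes.
import json
from pathlib import Path

import pyarrow as pa


def read_json_to_arrow(fp: Path) -> pa.Table:
    records = json.loads(fp.read_text("utf-8"))
    # Rotate list-of-records into {column: [values]}, preserving first-seen
    # column order and padding ragged records with None.
    columns: dict[str, None] = {}
    for record in records:
        columns.update(dict.fromkeys(record))
    rotated = {col: [record.get(col) for record in records] for col in columns}
    return pa.Table.from_pydict(rotated)
```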
dangotbanned committed Jan 29, 2025
1 parent d3b3ef2 commit b606a7d
Showing 8 changed files with 1,260 additions and 686 deletions.
106 changes: 95 additions & 11 deletions altair/datasets/_cache.py
@@ -5,10 +5,9 @@
from collections import defaultdict
from importlib.util import find_spec
from pathlib import Path
from typing import TYPE_CHECKING, ClassVar, Generic, TypeVar, cast, get_args
from typing import TYPE_CHECKING, ClassVar, TypeVar, cast, get_args

import narwhals.stable.v1 as nw
from narwhals.stable.v1.typing import IntoDataFrameT, IntoFrameT

from altair.datasets._exceptions import AltairDatasetsError
from altair.datasets._typing import Dataset
@@ -29,12 +28,18 @@
)
from io import IOBase
from typing import Any, Final
from urllib.request import OpenerDirector

from _typeshed import StrPath
from narwhals.stable.v1.dtypes import DType
from narwhals.stable.v1.typing import IntoExpr

from altair.datasets._typing import Metadata

if sys.version_info >= (3, 12):
from typing import Unpack
else:
from typing_extensions import Unpack
if sys.version_info >= (3, 11):
from typing import LiteralString
else:
@@ -43,8 +48,8 @@
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
from altair.datasets._readers import _Reader
from altair.datasets._typing import FlFieldStr
from altair.vegalite.v5.schema._typing import OneOrSeq

_Dataset: TypeAlias = "Dataset | LiteralString"
_FlSchema: TypeAlias = Mapping[str, FlFieldStr]
@@ -83,6 +88,10 @@
https://narwhals-dev.github.io/narwhals/api-reference/dtypes/
"""

_FIELD_TO_DTYPE: Mapping[FlFieldStr, type[DType]] = {
v: k for k, v in _DTYPE_TO_FIELD.items()
}


def _iter_metadata(df: nw.DataFrame[Any], /) -> Iterator[Metadata]:
"""
@@ -179,10 +188,7 @@ def rotated(self) -> Mapping[str, Sequence[Any]]:
self._rotated[k].append(v)
return self._rotated

def metadata(self, ns: Any, /) -> nw.LazyFrame:
data: Any = self.rotated
return nw.maybe_convert_dtypes(nw.from_dict(data, native_namespace=ns).lazy())

# TODO: Evaluate which errors are now obsolete
def __getitem__(self, key: _Dataset, /) -> Metadata:
if meta := self.get(key, None):
return meta
@@ -194,6 +200,7 @@ def __getitem__(self, key: _Dataset, /) -> Metadata:
msg = f"{key!r} does not refer to a known dataset."
raise TypeError(msg)

# TODO: Evaluate which errors are now obsolete
def url(self, name: _Dataset, /) -> str:
if meta := self.get(name, None):
if meta["suffix"] == ".parquet" and not find_spec("vegafusion"):
@@ -207,6 +214,9 @@
msg = f"{name!r} does not refer to a known dataset."
raise TypeError(msg)

def __repr__(self) -> str:
return f"<{type(self).__name__}: {'COLLECTED' if self._mapping else 'READY'}>"


class SchemaCache(CompressedCache["_Dataset", "_FlSchema"]):
"""
@@ -230,8 +240,10 @@ def __init__(
self,
*,
tp: type[MutableMapping[_Dataset, _FlSchema]] = dict["_Dataset", "_FlSchema"],
implementation: nw.Implementation = nw.Implementation.UNKNOWN,
) -> None:
self._mapping: MutableMapping[_Dataset, _FlSchema] = tp()
self._implementation: nw.Implementation = implementation

def read(self) -> Any:
import json
@@ -259,17 +271,72 @@ def by_dtype(self, name: _Dataset, *dtypes: type[DType]) -> list[str]:
else:
return list(match)

def is_active(self) -> bool:
return self._implementation in {
nw.Implementation.PANDAS,
nw.Implementation.MODIN,
nw.Implementation.PYARROW,
}

def schema_kwds(self, meta: Metadata, /) -> dict[str, Any]:
name: Any = meta["dataset_name"]
impl = self._implementation
if (impl.is_pandas_like() or impl.is_pyarrow()) and (self[name]):
suffix = meta["suffix"]
if impl.is_pandas_like():
if cols := self.by_dtype(name, nw.Date, nw.Datetime):
if suffix == ".json":
return {"convert_dates": cols}
elif suffix in {".csv", ".tsv"}:
return {"parse_dates": cols}
else:
schema = self.schema_pyarrow(name)
if suffix in {".csv", ".tsv"}:
from pyarrow.csv import ConvertOptions

return {"convert_options": ConvertOptions(column_types=schema)} # pyright: ignore[reportCallIssue]
elif suffix == ".parquet":
return {"schema": schema}

return {}

def schema(self, name: _Dataset, /) -> Mapping[str, DType]:
return {
column: _FIELD_TO_DTYPE[tp_str]() for column, tp_str in self[name].items()
}

# TODO: Open an issue in ``narwhals`` to try and get a public api for type conversion
def schema_pyarrow(self, name: _Dataset, /):
schema = self.schema(name)
if schema:
from narwhals._arrow.utils import narwhals_to_native_dtype
from narwhals.utils import Version

m = {k: narwhals_to_native_dtype(v, Version.V1) for k, v in schema.items()}
else:
m = {}
return nw.dependencies.get_pyarrow().schema(m)


class _SupportsScanMetadata(Protocol):
_opener: ClassVar[OpenerDirector]

def _scan_metadata(
self, *predicates: OneOrSeq[IntoExpr], **constraints: Unpack[Metadata]
) -> nw.LazyFrame: ...


class DatasetCache(Generic[IntoDataFrameT, IntoFrameT]):
class DatasetCache:
"""Opt-out caching of remote dataset requests."""

_ENV_VAR: ClassVar[LiteralString] = "ALTAIR_DATASETS_DIR"
_XDG_CACHE: ClassVar[Path] = (
Path(os.environ.get("XDG_CACHE_HOME", Path.home() / ".cache")) / "altair"
).resolve()

def __init__(self, reader: _Reader[IntoDataFrameT, IntoFrameT], /) -> None:
self._rd: _Reader[IntoDataFrameT, IntoFrameT] = reader
def __init__(self, reader: _SupportsScanMetadata, /) -> None:
self._rd: _SupportsScanMetadata = reader

def clear(self) -> None:
"""Delete all previously cached datasets."""
Expand Down Expand Up @@ -308,10 +375,24 @@ def download_all(self) -> None:
return None
print(f"Downloading {len(frame)} missing datasets...")
for meta in _iter_metadata(frame):
self._rd._download(meta["url"], self.path / (meta["sha"] + meta["suffix"]))
self._download_one(meta["url"], self.path_meta(meta))
print("Finished downloads")
return None

def _maybe_download(self, meta: Metadata, /) -> Path:
fp = self.path_meta(meta)
return (
fp
if (fp.exists() and fp.stat().st_size)
else self._download_one(meta["url"], fp)
)

def _download_one(self, url: str, fp: Path, /) -> Path:
with self._rd._opener.open(url) as f:
fp.touch()
fp.write_bytes(f.read())
return fp

@property
def path(self) -> Path:
"""
@@ -354,6 +435,9 @@ def path(self, source: StrPath | None, /) -> None:
else:
os.environ[self._ENV_VAR] = ""

def path_meta(self, meta: Metadata, /) -> Path:
return self.path / (meta["sha"] + meta["suffix"])

def __iter__(self) -> Iterator[Path]:
yield from self.path.iterdir()

115 changes: 115 additions & 0 deletions altair/datasets/_constraints.py
@@ -0,0 +1,115 @@
"""Set-like guards for matching metadata to an implementation."""

from __future__ import annotations

from collections.abc import Set
from itertools import chain
from typing import TYPE_CHECKING, Any

from narwhals.stable import v1 as nw

if TYPE_CHECKING:
import sys
from collections.abc import Iterable, Iterator

from altair.datasets._typing import Metadata

if sys.version_info >= (3, 12):
from typing import Unpack
else:
from typing_extensions import Unpack
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias

__all__ = [
"Items",
"MetaIs",
"is_arrow",
"is_csv",
"is_json",
"is_meta",
"is_not_tabular",
"is_parquet",
"is_spatial",
"is_tsv",
]

Items: TypeAlias = Set[tuple[str, Any]]


class MetaIs(Set[tuple[str, Any]]):
_requires: frozenset[tuple[str, Any]]

def __init__(self, kwds: frozenset[tuple[str, Any]], /) -> None:
object.__setattr__(self, "_requires", kwds)

@classmethod
def from_metadata(cls, meta: Metadata, /) -> MetaIs:
return cls(frozenset(meta.items()))

def to_metadata(self) -> Metadata:
if TYPE_CHECKING:

def collect(**kwds: Unpack[Metadata]) -> Metadata:
return kwds

return collect(**dict(self))
return dict(self)

def to_expr(self) -> nw.Expr:
return nw.all_horizontal(nw.col(name) == val for name, val in self)

def isdisjoint(self, other: Iterable[Any]) -> bool:
return super().isdisjoint(other)

def issubset(self, other: Iterable[Any]) -> bool:
return self._requires.issubset(other)

def __call__(self, meta: Items, /) -> bool:
return self._requires <= meta

def __hash__(self) -> int:
return hash(self._requires)

def __contains__(self, x: object) -> bool:
return self._requires.__contains__(x)

def __iter__(self) -> Iterator[tuple[str, Any]]:
yield from self._requires

def __len__(self) -> int:
return self._requires.__len__()

def __setattr__(self, name: str, value: Any):
msg = (
f"{type(self).__name__!r} is immutable.\n"
f"Could not assign self.{name} = {value}"
)
raise TypeError(msg)

def __repr__(self) -> str:
items = dict(self)
if not items:
contents = "<placeholder>"
elif suffix := items.pop("suffix", None):
contents = ", ".join(
chain([f"'*{suffix}'"], (f"{k}={v!r}" for k, v in items.items()))
)
else:
contents = ", ".join(f"{k}={v!r}" for k, v in items.items())
return f"is_meta({contents})"


def is_meta(**kwds: Unpack[Metadata]) -> MetaIs:
return MetaIs.from_metadata(kwds)


is_csv = is_meta(suffix=".csv")
is_json = is_meta(suffix=".json")
is_tsv = is_meta(suffix=".tsv")
is_arrow = is_meta(suffix=".arrow")
is_parquet = is_meta(suffix=".parquet")
is_spatial = is_meta(is_spatial=True)
is_not_tabular = is_meta(is_tabular=False)
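
For context, a short usage sketch based only on the `MetaIs` methods defined above; the metadata values here are invented for illustration.

```python
# Usage sketch; dataset metadata is made up for this example.
from altair.datasets._constraints import is_csv, is_json, is_meta

meta = {"dataset_name": "movies", "suffix": ".json", "is_spatial": False}

# ``__call__`` tests a metadata item-set via subset comparison ...
assert is_json(set(meta.items()))
assert not is_csv(set(meta.items()))

# ... while ``to_expr()`` builds the equivalent narwhals filter expression,
# e.g. for use with a ``_scan_metadata(...)`` implementation.
expr = is_meta(suffix=".json", is_spatial=False).to_expr()
```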
