Skip to content

Commit

Permalink
feat: Support non-variadic input to field.one_of
Browse files Browse the repository at this point in the history
  • Loading branch information
dangotbanned committed Jul 26, 2024
1 parent 35360b7 commit 7a4451a
Showing 1 changed file with 42 additions and 13 deletions.
55 changes: 42 additions & 13 deletions altair/vegalite/v5/_api_rfc.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, Literal, Mapping, Sequence, Union

from typing_extensions import TypeAlias

Expand Down Expand Up @@ -93,6 +93,38 @@ def _wrap_composition(predicate: Predicate, /) -> SelectionPredicateComposition:
return SelectionPredicateComposition(predicate.to_dict())


def _one_of_flatten(
values: tuple[OneOfType, ...] | tuple[Sequence[OneOfType]] | tuple[Any, ...], /
) -> Sequence[OneOfType]:
if (
len(values) == 1
and not isinstance(values[0], (str, bool, float, int, Mapping, SchemaBase))
and isinstance(values[0], Sequence)
):
return values[0]
elif len(values) > 1:
return values
else:
msg = (
f"Expected `values` to be either a single `Sequence` "
f"or used variadically, but got: {values!r}."
)
raise TypeError(msg)


def _one_of_variance(val_1: Any, *rest: OneOfType) -> Sequence[Any]:
# Required that all elements are the same type
tp = type(val_1)
if all(isinstance(v, tp) for v in rest):
return (val_1, *rest)
else:
msg = (
f"Expected all `values` to be of the same type, but got:\n"
f"{tuple(f'{type(v).__name__}' for v in (val_1, *rest))!r}"
)
raise TypeError(msg)


class agg:
"""
Utility class providing autocomplete for shorthand.
Expand Down Expand Up @@ -279,19 +311,16 @@ def __new__( # type: ignore[misc]

@classmethod
def one_of(
cls, field: str, /, *values: OneOfType, timeUnit: TimeUnitType = Undefined
cls,
field: str,
/,
*values: OneOfType | Sequence[OneOfType],
timeUnit: TimeUnitType = Undefined,
) -> SelectionPredicateComposition:
tp: type[Any] = type(values[0])
if all(isinstance(v, tp) for v in values):
vals: Sequence[Any] = values
p = FieldOneOfPredicate(field=field, oneOf=vals, timeUnit=timeUnit)
return _wrap_composition(p)
else:
msg = (
f"Expected all `values` to be of the same type, but got:\n"
f"{tuple(f"{type(v).__name__}" for v in values)!r}"
)
raise TypeError(msg)
seq = _one_of_flatten(values)
one_of = _one_of_variance(*seq)
p = FieldOneOfPredicate(field=field, oneOf=one_of, timeUnit=timeUnit)
return _wrap_composition(p)

@classmethod
def eq(
Expand Down

0 comments on commit 7a4451a

Please sign in to comment.