Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

improve object var symantics #4290

Merged
merged 10 commits into from
Nov 5, 2024
29 changes: 25 additions & 4 deletions reflex/utils/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ def serializer(
)

# Apply type transformation if requested
if to is not None:
if to is not None or ((to := type_hints.get("return")) is not None):
SERIALIZER_TYPES[type_] = to
get_serializer_type.cache_clear()

Expand Down Expand Up @@ -189,16 +189,37 @@ def get_serializer_type(type_: Type) -> Optional[Type]:
return None


def has_serializer(type_: Type) -> bool:
def has_serializer(type_: Type, into_type: Type | None = None) -> bool:
"""Check if there is a serializer for the type.

Args:
type_: The type to check.
into_type: The type to serialize into.

Returns:
Whether there is a serializer for the type.
"""
return get_serializer(type_) is not None
serializer_for_type = get_serializer(type_)
return serializer_for_type is not None and (
into_type is None or get_serializer_type(type_) == into_type
)


def can_serialize(type_: Type, into_type: Type | None = None) -> bool:
"""Check if there is a serializer for the type.

Args:
type_: The type to check.
into_type: The type to serialize into.

Returns:
Whether there is a serializer for the type.
"""
return has_serializer(type_, into_type) or (
isinstance(type_, type)
and dataclasses.is_dataclass(type_)
and (into_type is None or into_type is dict)
)


@serializer(to=str)
Expand All @@ -214,7 +235,7 @@ def serialize_type(value: type) -> str:
return value.__name__


@serializer
@serializer(to=dict)
def serialize_base(value: Base) -> dict:
"""Serialize a Base instance.

Expand Down
82 changes: 48 additions & 34 deletions reflex/vars/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,6 @@
if TYPE_CHECKING:
from reflex.state import BaseState

from .function import FunctionVar
from .number import (
BooleanVar,
NumberVar,
Expand Down Expand Up @@ -279,6 +278,24 @@ def _decode_var_immutable(value: str) -> tuple[VarData | None, str]:
return VarData.merge(*var_datas) if var_datas else None, value


def can_use_in_object_var(cls: GenericType) -> bool:
"""Check if the class can be used in an ObjectVar.

Args:
cls: The class to check.

Returns:
Whether the class can be used in an ObjectVar.
"""
if types.is_union(cls):
return all(can_use_in_object_var(t) for t in types.get_args(cls))
return (
inspect.isclass(cls)
and not issubclass(cls, Var)
and serializers.can_serialize(cls, dict)
)


@dataclasses.dataclass(
eq=False,
frozen=True,
Expand Down Expand Up @@ -565,36 +582,33 @@ def __format__(self, format_spec: str) -> str:
# Encode the _var_data into the formatted output for tracking purposes.
return f"{constants.REFLEX_VAR_OPENING_TAG}{hashed_var}{constants.REFLEX_VAR_CLOSING_TAG}{self._js_expr}"

@overload
def to(self, output: Type[StringVar]) -> StringVar: ...

@overload
def to(self, output: Type[str]) -> StringVar: ...

@overload
def to(self, output: Type[BooleanVar]) -> BooleanVar: ...
def to(self, output: Type[bool]) -> BooleanVar: ...

@overload
def to(
self, output: Type[NumberVar], var_type: type[int] | type[float] = float
) -> NumberVar: ...
def to(self, output: type[int] | type[float]) -> NumberVar: ...

@overload
def to(
self,
output: Type[ArrayVar],
var_type: type[list] | type[tuple] | type[set] = list,
output: type[list] | type[tuple] | type[set],
) -> ArrayVar: ...

@overload
def to(
self, output: Type[ObjectVar], var_type: types.GenericType = dict
) -> ObjectVar: ...
self, output: Type[ObjectVar], var_type: Type[VAR_INSIDE]
) -> ObjectVar[VAR_INSIDE]: ...

@overload
def to(
self, output: Type[FunctionVar], var_type: Type[Callable] = Callable
) -> FunctionVar: ...
self, output: Type[ObjectVar], var_type: None = None
) -> ObjectVar[VAR_TYPE]: ...

@overload
def to(self, output: VAR_SUBCLASS, var_type: None = None) -> VAR_SUBCLASS: ...

@overload
def to(
Expand Down Expand Up @@ -630,21 +644,19 @@ def to(
return get_to_operation(NoneVar).create(self) # type: ignore

# Handle fixed_output_type being Base or a dataclass.
try:
if issubclass(fixed_output_type, Base):
return self.to(ObjectVar, output)
except TypeError:
pass
if dataclasses.is_dataclass(fixed_output_type) and not issubclass(
fixed_output_type, Var
):
if can_use_in_object_var(fixed_output_type):
return self.to(ObjectVar, output)

if inspect.isclass(output):
for var_subclass in _var_subclasses[::-1]:
if issubclass(output, var_subclass.var_subclass):
current_var_type = self._var_type
if current_var_type is Any:
new_var_type = var_type
else:
new_var_type = var_type or current_var_type
to_operation_return = var_subclass.to_var_subclass.create(
value=self, _var_type=var_type
value=self, _var_type=new_var_type
)
return to_operation_return # type: ignore

Expand Down Expand Up @@ -707,11 +719,7 @@ def guess_type(self) -> Var:
):
return self.to(NumberVar, self._var_type)

if all(
inspect.isclass(t)
and (issubclass(t, Base) or dataclasses.is_dataclass(t))
for t in inner_types
):
if can_use_in_object_var(var_type):
return self.to(ObjectVar, self._var_type)

return self
Expand All @@ -730,13 +738,9 @@ def guess_type(self) -> Var:
if issubclass(fixed_type, var_subclass.python_types):
return self.to(var_subclass.var_subclass, self._var_type)

try:
if issubclass(fixed_type, Base):
return self.to(ObjectVar, self._var_type)
except TypeError:
pass
if dataclasses.is_dataclass(fixed_type):
if can_use_in_object_var(fixed_type):
return self.to(ObjectVar, self._var_type)

return self

def get_default_value(self) -> Any:
Expand Down Expand Up @@ -1181,6 +1185,9 @@ def json(self) -> str:

OUTPUT = TypeVar("OUTPUT", bound=Var)

VAR_SUBCLASS = TypeVar("VAR_SUBCLASS", bound=Var)
VAR_INSIDE = TypeVar("VAR_INSIDE")


class ToOperation:
"""A var operation that converts a var to another type."""
Expand Down Expand Up @@ -2888,6 +2895,8 @@ def dispatch(

V = TypeVar("V")

BASE_TYPE = TypeVar("BASE_TYPE", bound=Base)


class Field(Generic[T]):
"""Shadow class for Var to allow for type hinting in the IDE."""
Expand Down Expand Up @@ -2924,6 +2933,11 @@ def __get__(
self: Field[Dict[str, V]], instance: None, owner
) -> ObjectVar[Dict[str, V]]: ...

@overload
def __get__(
self: Field[BASE_TYPE], instance: None, owner
) -> ObjectVar[BASE_TYPE]: ...

@overload
def __get__(self, instance: None, owner) -> Var[T]: ...

Expand Down
4 changes: 3 additions & 1 deletion reflex/vars/number.py
Original file line number Diff line number Diff line change
Expand Up @@ -1116,7 +1116,9 @@ def boolify(value: Var):


@var_operation
def ternary_operation(condition: BooleanVar, if_true: Var[T], if_false: Var[U]):
def ternary_operation(
condition: BooleanVar, if_true: Var[T], if_false: Var[U]
) -> CustomVarOperationReturn[Union[T, U]]:
"""Create a ternary operation.

Args:
Expand Down
Loading
Loading