Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

EventFnArgMismatch fix to support defaults args #4004

Merged
merged 5 commits into from
Sep 26, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
75 changes: 56 additions & 19 deletions reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
get_type_hints,
)

from typing_extensions import get_args, get_origin

from reflex import constants
from reflex.utils import format
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
from reflex.utils.types import ArgsSpec
from reflex.utils.types import ArgsSpec, GenericType
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
from reflex.vars.function import FunctionStringVar, FunctionVar
Expand Down Expand Up @@ -417,7 +419,7 @@ class FileUpload:
on_upload_progress: Optional[Union[EventHandler, Callable]] = None

@staticmethod
def on_upload_progress_args_spec(_prog: Dict[str, Union[int, float, bool]]):
def on_upload_progress_args_spec(_prog: Var[Dict[str, Union[int, float, bool]]]):
"""Args spec for on_upload_progress event handler.

Returns:
Expand Down Expand Up @@ -910,6 +912,20 @@ def call_event_handler(
)


def unwrap_var_annotation(annotation: GenericType):
"""Unwrap a Var annotation or return it as is if it's not Var[X].

Args:
annotation: The annotation to unwrap.

Returns:
The unwrapped annotation.
"""
if get_origin(annotation) is Var and (args := get_args(annotation)):
return args[0]
return annotation


def parse_args_spec(arg_spec: ArgsSpec):
"""Parse the args provided in the ArgsSpec of an event trigger.

Expand All @@ -921,20 +937,54 @@ def parse_args_spec(arg_spec: ArgsSpec):
"""
spec = inspect.getfullargspec(arg_spec)
annotations = get_type_hints(arg_spec)

return arg_spec(
*[
Var(f"_{l_arg}").to(annotations.get(l_arg, FrontendEvent))
Var(f"_{l_arg}").to(
unwrap_var_annotation(annotations.get(l_arg, FrontendEvent))
)
for l_arg in spec.args
]
)


def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
"""Ensures that the function signature matches the passed argument specification
or raises an EventFnArgMismatch if they do not.

Args:
fn: The function to be validated.
arg_spec: The argument specification for the event trigger.

Returns:
The parsed arguments from the argument specification.

Raises:
EventFnArgMismatch: Raised if the number of mandatory arguments do not match
"""
fn_args = inspect.getfullargspec(fn).args
fn_defaults_args = inspect.getfullargspec(fn).defaults
n_fn_args = len(fn_args)
n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0
if isinstance(fn, types.MethodType):
n_fn_args -= 1 # subtract 1 for bound self arg
parsed_args = parse_args_spec(arg_spec)
if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args):
raise EventFnArgMismatch(
"The number of mandatory arguments accepted by "
f"{fn} ({n_fn_args - n_fn_defaults_args}) "
"does not match the arguments passed by the event trigger: "
f"{[str(v) for v in parsed_args]}\n"
"See https://reflex.dev/docs/events/event-arguments/"
)
return parsed_args


def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
"""Call a function to a list of event specs.

The function should return a single EventSpec, a list of EventSpecs, or a
single Var. The function signature must match the passed arg_spec or
EventFnArgsMismatch will be raised.
single Var.

Args:
fn: The function to call.
Expand All @@ -944,27 +994,14 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
The event specs from calling the function or a Var.

Raises:
EventFnArgMismatch: If the function signature doesn't match the arg spec.
EventHandlerValueError: If the lambda returns an unusable value.
"""
# Import here to avoid circular imports.
from reflex.event import EventHandler, EventSpec
from reflex.utils.exceptions import EventHandlerValueError

# Check that fn signature matches arg_spec
fn_args = inspect.getfullargspec(fn).args
n_fn_args = len(fn_args)
if isinstance(fn, types.MethodType):
n_fn_args -= 1 # subtract 1 for bound self arg
parsed_args = parse_args_spec(arg_spec)
if len(parsed_args) != n_fn_args:
raise EventFnArgMismatch(
"The number of arguments accepted by "
f"{fn} ({n_fn_args}) "
"does not match the arguments passed by the event trigger: "
f"{[str(v) for v in parsed_args]}\n"
"See https://reflex.dev/docs/events/event-arguments/"
)
parsed_args = check_fn_match_arg_spec(fn, arg_spec)

# Call the function with the parsed args.
out = fn(*parsed_args)
Expand Down
19 changes: 17 additions & 2 deletions reflex/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import types
from functools import cached_property, lru_cache, wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Expand Down Expand Up @@ -96,8 +97,22 @@ def override(func: Callable) -> Callable:
StateVar = Union[PrimitiveType, Base, None]
StateIterVar = Union[list, set, tuple]

# ArgsSpec = Callable[[Var], list[Var]]
ArgsSpec = Callable
if TYPE_CHECKING:
from reflex.vars.base import Var

# ArgsSpec = Callable[[Var], list[Var]]
ArgsSpec = (
Callable[[], List[Var]]
| Callable[[Var], List[Var]]
| Callable[[Var, Var], List[Var]]
| Callable[[Var, Var, Var], List[Var]]
| Callable[[Var, Var, Var, Var], List[Var]]
| Callable[[Var, Var, Var, Var, Var], List[Var]]
| Callable[[Var, Var, Var, Var, Var, Var], List[Var]]
| Callable[[Var, Var, Var, Var, Var, Var, Var], List[Var]]
)
else:
ArgsSpec = Callable[..., List[Any]]


PrimitiveToAnnotation = {
Expand Down
2 changes: 1 addition & 1 deletion tests/units/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_fn_with_args(_, arg1, arg2):

test_fn_with_args.__qualname__ = "test_fn_with_args"

def spec(a2: str) -> List[str]:
def spec(a2: Var[str]) -> List[Var[str]]:
return [a2]

handler = EventHandler(fn=test_fn_with_args)
Expand Down
Loading