Skip to content

Commit

Permalink
EventFnArgMismatch fix to support defaults args (#4004)
Browse files Browse the repository at this point in the history
* EventFnArgMismatch fix to support defaults args

* fixing type hint and docstring raises

* enforce stronger type checking

* unwrap var annotations :(

---------

Co-authored-by: Khaleel Al-Adhami <[email protected]>
  • Loading branch information
LeoGrosjean and adhami3310 authored Sep 26, 2024
1 parent 54c7b5a commit 60276cf
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 22 deletions.
75 changes: 56 additions & 19 deletions reflex/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,12 @@
get_type_hints,
)

from typing_extensions import get_args, get_origin

from reflex import constants
from reflex.utils import format
from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch
from reflex.utils.types import ArgsSpec
from reflex.utils.types import ArgsSpec, GenericType
from reflex.vars import VarData
from reflex.vars.base import LiteralVar, Var
from reflex.vars.function import FunctionStringVar, FunctionVar
Expand Down Expand Up @@ -417,7 +419,7 @@ class FileUpload:
on_upload_progress: Optional[Union[EventHandler, Callable]] = None

@staticmethod
def on_upload_progress_args_spec(_prog: Dict[str, Union[int, float, bool]]):
def on_upload_progress_args_spec(_prog: Var[Dict[str, Union[int, float, bool]]]):
"""Args spec for on_upload_progress event handler.
Returns:
Expand Down Expand Up @@ -910,6 +912,20 @@ def call_event_handler(
)


def unwrap_var_annotation(annotation: GenericType):
"""Unwrap a Var annotation or return it as is if it's not Var[X].
Args:
annotation: The annotation to unwrap.
Returns:
The unwrapped annotation.
"""
if get_origin(annotation) is Var and (args := get_args(annotation)):
return args[0]
return annotation


def parse_args_spec(arg_spec: ArgsSpec):
"""Parse the args provided in the ArgsSpec of an event trigger.
Expand All @@ -921,20 +937,54 @@ def parse_args_spec(arg_spec: ArgsSpec):
"""
spec = inspect.getfullargspec(arg_spec)
annotations = get_type_hints(arg_spec)

return arg_spec(
*[
Var(f"_{l_arg}").to(annotations.get(l_arg, FrontendEvent))
Var(f"_{l_arg}").to(
unwrap_var_annotation(annotations.get(l_arg, FrontendEvent))
)
for l_arg in spec.args
]
)


def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]:
"""Ensures that the function signature matches the passed argument specification
or raises an EventFnArgMismatch if they do not.
Args:
fn: The function to be validated.
arg_spec: The argument specification for the event trigger.
Returns:
The parsed arguments from the argument specification.
Raises:
EventFnArgMismatch: Raised if the number of mandatory arguments do not match
"""
fn_args = inspect.getfullargspec(fn).args
fn_defaults_args = inspect.getfullargspec(fn).defaults
n_fn_args = len(fn_args)
n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0
if isinstance(fn, types.MethodType):
n_fn_args -= 1 # subtract 1 for bound self arg
parsed_args = parse_args_spec(arg_spec)
if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args):
raise EventFnArgMismatch(
"The number of mandatory arguments accepted by "
f"{fn} ({n_fn_args - n_fn_defaults_args}) "
"does not match the arguments passed by the event trigger: "
f"{[str(v) for v in parsed_args]}\n"
"See https://reflex.dev/docs/events/event-arguments/"
)
return parsed_args


def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
"""Call a function to a list of event specs.
The function should return a single EventSpec, a list of EventSpecs, or a
single Var. The function signature must match the passed arg_spec or
EventFnArgsMismatch will be raised.
single Var.
Args:
fn: The function to call.
Expand All @@ -944,27 +994,14 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var:
The event specs from calling the function or a Var.
Raises:
EventFnArgMismatch: If the function signature doesn't match the arg spec.
EventHandlerValueError: If the lambda returns an unusable value.
"""
# Import here to avoid circular imports.
from reflex.event import EventHandler, EventSpec
from reflex.utils.exceptions import EventHandlerValueError

# Check that fn signature matches arg_spec
fn_args = inspect.getfullargspec(fn).args
n_fn_args = len(fn_args)
if isinstance(fn, types.MethodType):
n_fn_args -= 1 # subtract 1 for bound self arg
parsed_args = parse_args_spec(arg_spec)
if len(parsed_args) != n_fn_args:
raise EventFnArgMismatch(
"The number of arguments accepted by "
f"{fn} ({n_fn_args}) "
"does not match the arguments passed by the event trigger: "
f"{[str(v) for v in parsed_args]}\n"
"See https://reflex.dev/docs/events/event-arguments/"
)
parsed_args = check_fn_match_arg_spec(fn, arg_spec)

# Call the function with the parsed args.
out = fn(*parsed_args)
Expand Down
19 changes: 17 additions & 2 deletions reflex/utils/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import types
from functools import cached_property, lru_cache, wraps
from typing import (
TYPE_CHECKING,
Any,
Callable,
ClassVar,
Expand Down Expand Up @@ -96,8 +97,22 @@ def override(func: Callable) -> Callable:
StateVar = Union[PrimitiveType, Base, None]
StateIterVar = Union[list, set, tuple]

# ArgsSpec = Callable[[Var], list[Var]]
ArgsSpec = Callable
if TYPE_CHECKING:
from reflex.vars.base import Var

# ArgsSpec = Callable[[Var], list[Var]]
ArgsSpec = (
Callable[[], List[Var]]
| Callable[[Var], List[Var]]
| Callable[[Var, Var], List[Var]]
| Callable[[Var, Var, Var], List[Var]]
| Callable[[Var, Var, Var, Var], List[Var]]
| Callable[[Var, Var, Var, Var, Var], List[Var]]
| Callable[[Var, Var, Var, Var, Var, Var], List[Var]]
| Callable[[Var, Var, Var, Var, Var, Var, Var], List[Var]]
)
else:
ArgsSpec = Callable[..., List[Any]]


PrimitiveToAnnotation = {
Expand Down
2 changes: 1 addition & 1 deletion tests/units/test_event.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ def test_fn_with_args(_, arg1, arg2):

test_fn_with_args.__qualname__ = "test_fn_with_args"

def spec(a2: str) -> List[str]:
def spec(a2: Var[str]) -> List[Var[str]]:
return [a2]

handler = EventHandler(fn=test_fn_with_args)
Expand Down

0 comments on commit 60276cf

Please sign in to comment.