diff --git a/reflex/event.py b/reflex/event.py index ac0c713ab8..d8f0a5f0f6 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -18,10 +18,12 @@ get_type_hints, ) +from typing_extensions import get_args, get_origin + from reflex import constants from reflex.utils import format from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch -from reflex.utils.types import ArgsSpec +from reflex.utils.types import ArgsSpec, GenericType from reflex.vars import VarData from reflex.vars.base import LiteralVar, Var from reflex.vars.function import FunctionStringVar, FunctionVar @@ -417,7 +419,7 @@ class FileUpload: on_upload_progress: Optional[Union[EventHandler, Callable]] = None @staticmethod - def on_upload_progress_args_spec(_prog: Dict[str, Union[int, float, bool]]): + def on_upload_progress_args_spec(_prog: Var[Dict[str, Union[int, float, bool]]]): """Args spec for on_upload_progress event handler. Returns: @@ -910,6 +912,20 @@ def call_event_handler( ) +def unwrap_var_annotation(annotation: GenericType): + """Unwrap a Var annotation or return it as is if it's not Var[X]. + + Args: + annotation: The annotation to unwrap. + + Returns: + The unwrapped annotation. + """ + if get_origin(annotation) is Var and (args := get_args(annotation)): + return args[0] + return annotation + + def parse_args_spec(arg_spec: ArgsSpec): """Parse the args provided in the ArgsSpec of an event trigger. @@ -921,20 +937,54 @@ def parse_args_spec(arg_spec: ArgsSpec): """ spec = inspect.getfullargspec(arg_spec) annotations = get_type_hints(arg_spec) + return arg_spec( *[ - Var(f"_{l_arg}").to(annotations.get(l_arg, FrontendEvent)) + Var(f"_{l_arg}").to( + unwrap_var_annotation(annotations.get(l_arg, FrontendEvent)) + ) for l_arg in spec.args ] ) +def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]: + """Ensures that the function signature matches the passed argument specification + or raises an EventFnArgMismatch if they do not. + + Args: + fn: The function to be validated. + arg_spec: The argument specification for the event trigger. + + Returns: + The parsed arguments from the argument specification. + + Raises: + EventFnArgMismatch: Raised if the number of mandatory arguments do not match + """ + fn_args = inspect.getfullargspec(fn).args + fn_defaults_args = inspect.getfullargspec(fn).defaults + n_fn_args = len(fn_args) + n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0 + if isinstance(fn, types.MethodType): + n_fn_args -= 1 # subtract 1 for bound self arg + parsed_args = parse_args_spec(arg_spec) + if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args): + raise EventFnArgMismatch( + "The number of mandatory arguments accepted by " + f"{fn} ({n_fn_args - n_fn_defaults_args}) " + "does not match the arguments passed by the event trigger: " + f"{[str(v) for v in parsed_args]}\n" + "See https://reflex.dev/docs/events/event-arguments/" + ) + return parsed_args + + def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var: """Call a function to a list of event specs. The function should return a single EventSpec, a list of EventSpecs, or a - single Var. The function signature must match the passed arg_spec or - EventFnArgsMismatch will be raised. + single Var. Args: fn: The function to call. @@ -944,7 +994,6 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var: The event specs from calling the function or a Var. Raises: - EventFnArgMismatch: If the function signature doesn't match the arg spec. EventHandlerValueError: If the lambda returns an unusable value. """ # Import here to avoid circular imports. @@ -952,19 +1001,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var: from reflex.utils.exceptions import EventHandlerValueError # Check that fn signature matches arg_spec - fn_args = inspect.getfullargspec(fn).args - n_fn_args = len(fn_args) - if isinstance(fn, types.MethodType): - n_fn_args -= 1 # subtract 1 for bound self arg - parsed_args = parse_args_spec(arg_spec) - if len(parsed_args) != n_fn_args: - raise EventFnArgMismatch( - "The number of arguments accepted by " - f"{fn} ({n_fn_args}) " - "does not match the arguments passed by the event trigger: " - f"{[str(v) for v in parsed_args]}\n" - "See https://reflex.dev/docs/events/event-arguments/" - ) + parsed_args = check_fn_match_arg_spec(fn, arg_spec) # Call the function with the parsed args. out = fn(*parsed_args) diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 63238f67be..41e1ed49a9 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -9,6 +9,7 @@ import types from functools import cached_property, lru_cache, wraps from typing import ( + TYPE_CHECKING, Any, Callable, ClassVar, @@ -96,8 +97,22 @@ def override(func: Callable) -> Callable: StateVar = Union[PrimitiveType, Base, None] StateIterVar = Union[list, set, tuple] -# ArgsSpec = Callable[[Var], list[Var]] -ArgsSpec = Callable +if TYPE_CHECKING: + from reflex.vars.base import Var + + # ArgsSpec = Callable[[Var], list[Var]] + ArgsSpec = ( + Callable[[], List[Var]] + | Callable[[Var], List[Var]] + | Callable[[Var, Var], List[Var]] + | Callable[[Var, Var, Var], List[Var]] + | Callable[[Var, Var, Var, Var], List[Var]] + | Callable[[Var, Var, Var, Var, Var], List[Var]] + | Callable[[Var, Var, Var, Var, Var, Var], List[Var]] + | Callable[[Var, Var, Var, Var, Var, Var, Var], List[Var]] + ) +else: + ArgsSpec = Callable[..., List[Any]] PrimitiveToAnnotation = { diff --git a/tests/units/test_event.py b/tests/units/test_event.py index a152c37493..3996a61014 100644 --- a/tests/units/test_event.py +++ b/tests/units/test_event.py @@ -97,7 +97,7 @@ def test_fn_with_args(_, arg1, arg2): test_fn_with_args.__qualname__ = "test_fn_with_args" - def spec(a2: str) -> List[str]: + def spec(a2: Var[str]) -> List[Var[str]]: return [a2] handler = EventHandler(fn=test_fn_with_args)