From 2cec090fa56a730a9982a36fbda4a64ebe646522 Mon Sep 17 00:00:00 2001 From: Leo Grosjean Date: Wed, 25 Sep 2024 22:51:21 +0200 Subject: [PATCH 1/4] EventFnArgMismatch fix to support defaults args --- reflex/event.py | 51 +++++++++++++++++++++++++++++++++---------------- 1 file changed, 35 insertions(+), 16 deletions(-) diff --git a/reflex/event.py b/reflex/event.py index ac0c713ab8..7849b8029a 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -929,12 +929,44 @@ def parse_args_spec(arg_spec: ArgsSpec): ) +def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec): + """Ensures that the function signature matches the passed argument specification + or raises an EventFnArgMismatch if they do not. + + Args: + fn (callable): The function to be validated. + arg_spec (Any): The argument specification for the event trigger. + + Returns: + list: The parsed arguments from the argument specification. + + Raises: + EventFnArgMismatch: Raised if the number of mandatory arguments in the + function's signature does not match the number of arguments in the argument specification. + """ + fn_args = inspect.getfullargspec(fn).args + fn_defaults_args = inspect.getfullargspec(fn).defaults + n_fn_args = len(fn_args) + n_fn_defaults_args = len(fn_defaults_args) if fn_defaults_args else 0 + if isinstance(fn, types.MethodType): + n_fn_args -= 1 # subtract 1 for bound self arg + parsed_args = parse_args_spec(arg_spec) + if not (n_fn_args - n_fn_defaults_args <= len(parsed_args) <= n_fn_args): + raise EventFnArgMismatch( + "The number of mandatory arguments accepted by " + f"{fn} ({n_fn_args - n_fn_defaults_args}) " + "does not match the arguments passed by the event trigger: " + f"{[str(v) for v in parsed_args]}\n" + "See https://reflex.dev/docs/events/event-arguments/" + ) + return parsed_args + + def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var: """Call a function to a list of event specs. The function should return a single EventSpec, a list of EventSpecs, or a - single Var. The function signature must match the passed arg_spec or - EventFnArgsMismatch will be raised. + single Var. Args: fn: The function to call. @@ -944,7 +976,6 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var: The event specs from calling the function or a Var. Raises: - EventFnArgMismatch: If the function signature doesn't match the arg spec. EventHandlerValueError: If the lambda returns an unusable value. """ # Import here to avoid circular imports. @@ -952,19 +983,7 @@ def call_event_fn(fn: Callable, arg_spec: ArgsSpec) -> list[EventSpec] | Var: from reflex.utils.exceptions import EventHandlerValueError # Check that fn signature matches arg_spec - fn_args = inspect.getfullargspec(fn).args - n_fn_args = len(fn_args) - if isinstance(fn, types.MethodType): - n_fn_args -= 1 # subtract 1 for bound self arg - parsed_args = parse_args_spec(arg_spec) - if len(parsed_args) != n_fn_args: - raise EventFnArgMismatch( - "The number of arguments accepted by " - f"{fn} ({n_fn_args}) " - "does not match the arguments passed by the event trigger: " - f"{[str(v) for v in parsed_args]}\n" - "See https://reflex.dev/docs/events/event-arguments/" - ) + parsed_args = check_fn_match_arg_spec(fn, arg_spec) # Call the function with the parsed args. out = fn(*parsed_args) From da90de86a22365892680aac1d0f2bda545e7aac4 Mon Sep 17 00:00:00 2001 From: Leo Grosjean Date: Thu, 26 Sep 2024 09:43:53 +0200 Subject: [PATCH 2/4] fixing type hint and docstring raises --- reflex/event.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/reflex/event.py b/reflex/event.py index 7849b8029a..2142b4d4c9 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -929,7 +929,7 @@ def parse_args_spec(arg_spec: ArgsSpec): ) -def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec): +def check_fn_match_arg_spec(fn: Callable | callable, arg_spec: ArgsSpec | Any): """Ensures that the function signature matches the passed argument specification or raises an EventFnArgMismatch if they do not. @@ -941,8 +941,7 @@ def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec): list: The parsed arguments from the argument specification. Raises: - EventFnArgMismatch: Raised if the number of mandatory arguments in the - function's signature does not match the number of arguments in the argument specification. + EventFnArgMismatch: Raised if the number of mandatory arguments do not match """ fn_args = inspect.getfullargspec(fn).args fn_defaults_args = inspect.getfullargspec(fn).defaults From fec7eb0e118d0a8717edae36cb9df706b083570b Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 26 Sep 2024 10:58:12 -0700 Subject: [PATCH 3/4] enforce stronger type checking --- reflex/event.py | 10 +++++----- reflex/utils/types.py | 19 +++++++++++++++++-- tests/test_event.py | 2 +- 3 files changed, 23 insertions(+), 8 deletions(-) diff --git a/reflex/event.py b/reflex/event.py index 2142b4d4c9..304a54995f 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -417,7 +417,7 @@ class FileUpload: on_upload_progress: Optional[Union[EventHandler, Callable]] = None @staticmethod - def on_upload_progress_args_spec(_prog: Dict[str, Union[int, float, bool]]): + def on_upload_progress_args_spec(_prog: Var[Dict[str, Union[int, float, bool]]]): """Args spec for on_upload_progress event handler. Returns: @@ -929,16 +929,16 @@ def parse_args_spec(arg_spec: ArgsSpec): ) -def check_fn_match_arg_spec(fn: Callable | callable, arg_spec: ArgsSpec | Any): +def check_fn_match_arg_spec(fn: Callable, arg_spec: ArgsSpec) -> List[Var]: """Ensures that the function signature matches the passed argument specification or raises an EventFnArgMismatch if they do not. Args: - fn (callable): The function to be validated. - arg_spec (Any): The argument specification for the event trigger. + fn: The function to be validated. + arg_spec: The argument specification for the event trigger. Returns: - list: The parsed arguments from the argument specification. + The parsed arguments from the argument specification. Raises: EventFnArgMismatch: Raised if the number of mandatory arguments do not match diff --git a/reflex/utils/types.py b/reflex/utils/types.py index 63238f67be..41e1ed49a9 100644 --- a/reflex/utils/types.py +++ b/reflex/utils/types.py @@ -9,6 +9,7 @@ import types from functools import cached_property, lru_cache, wraps from typing import ( + TYPE_CHECKING, Any, Callable, ClassVar, @@ -96,8 +97,22 @@ def override(func: Callable) -> Callable: StateVar = Union[PrimitiveType, Base, None] StateIterVar = Union[list, set, tuple] -# ArgsSpec = Callable[[Var], list[Var]] -ArgsSpec = Callable +if TYPE_CHECKING: + from reflex.vars.base import Var + + # ArgsSpec = Callable[[Var], list[Var]] + ArgsSpec = ( + Callable[[], List[Var]] + | Callable[[Var], List[Var]] + | Callable[[Var, Var], List[Var]] + | Callable[[Var, Var, Var], List[Var]] + | Callable[[Var, Var, Var, Var], List[Var]] + | Callable[[Var, Var, Var, Var, Var], List[Var]] + | Callable[[Var, Var, Var, Var, Var, Var], List[Var]] + | Callable[[Var, Var, Var, Var, Var, Var, Var], List[Var]] + ) +else: + ArgsSpec = Callable[..., List[Any]] PrimitiveToAnnotation = { diff --git a/tests/test_event.py b/tests/test_event.py index a152c37493..3996a61014 100644 --- a/tests/test_event.py +++ b/tests/test_event.py @@ -97,7 +97,7 @@ def test_fn_with_args(_, arg1, arg2): test_fn_with_args.__qualname__ = "test_fn_with_args" - def spec(a2: str) -> List[str]: + def spec(a2: Var[str]) -> List[Var[str]]: return [a2] handler = EventHandler(fn=test_fn_with_args) From 678d793b8450e44a9e8f843aee9b7250f519343c Mon Sep 17 00:00:00 2001 From: Khaleel Al-Adhami Date: Thu, 26 Sep 2024 11:10:09 -0700 Subject: [PATCH 4/4] unwrap var annotations :( --- reflex/event.py | 23 +++++++++++++++++++++-- 1 file changed, 21 insertions(+), 2 deletions(-) diff --git a/reflex/event.py b/reflex/event.py index 304a54995f..d8f0a5f0f6 100644 --- a/reflex/event.py +++ b/reflex/event.py @@ -18,10 +18,12 @@ get_type_hints, ) +from typing_extensions import get_args, get_origin + from reflex import constants from reflex.utils import format from reflex.utils.exceptions import EventFnArgMismatch, EventHandlerArgMismatch -from reflex.utils.types import ArgsSpec +from reflex.utils.types import ArgsSpec, GenericType from reflex.vars import VarData from reflex.vars.base import LiteralVar, Var from reflex.vars.function import FunctionStringVar, FunctionVar @@ -910,6 +912,20 @@ def call_event_handler( ) +def unwrap_var_annotation(annotation: GenericType): + """Unwrap a Var annotation or return it as is if it's not Var[X]. + + Args: + annotation: The annotation to unwrap. + + Returns: + The unwrapped annotation. + """ + if get_origin(annotation) is Var and (args := get_args(annotation)): + return args[0] + return annotation + + def parse_args_spec(arg_spec: ArgsSpec): """Parse the args provided in the ArgsSpec of an event trigger. @@ -921,9 +937,12 @@ def parse_args_spec(arg_spec: ArgsSpec): """ spec = inspect.getfullargspec(arg_spec) annotations = get_type_hints(arg_spec) + return arg_spec( *[ - Var(f"_{l_arg}").to(annotations.get(l_arg, FrontendEvent)) + Var(f"_{l_arg}").to( + unwrap_var_annotation(annotations.get(l_arg, FrontendEvent)) + ) for l_arg in spec.args ] )