diff --git a/haystack_experimental/core/super_component/utils.py b/haystack_experimental/core/super_component/utils.py index ba79d0c4..c65e2a63 100644 --- a/haystack_experimental/core/super_component/utils.py +++ b/haystack_experimental/core/super_component/utils.py @@ -2,7 +2,7 @@ # # SPDX-License-Identifier: Apache-2.0 -from typing import Annotated, Any, Union, get_args, get_origin +from typing import Annotated, Any, TypeVar, Union, get_args, get_origin from haystack.core.component.types import HAYSTACK_GREEDY_VARIADIC_ANNOTATION, HAYSTACK_VARIADIC_ANNOTATION @@ -10,7 +10,9 @@ class _delegate_default: """Custom object for delegating filling of default values to the underlying components.""" -def is_compatible(type1, type2, unwrap_nested: bool = True) -> bool: +T = TypeVar("T") + +def is_compatible(type1: T, type2: T, unwrap_nested: bool = True) -> bool: """ Check if two types are compatible (bidirectional/symmetric check). @@ -26,7 +28,7 @@ def is_compatible(type1, type2, unwrap_nested: bool = True) -> bool: return _types_are_compatible(type1_unwrapped, type2_unwrapped) -def _types_are_compatible(type1, type2) -> bool: # noqa: PLR0911 +def _types_are_compatible(type1: T, type2: T) -> bool: """ Core type compatibility check implementing symmetric matching. @@ -34,37 +36,43 @@ def _types_are_compatible(type1, type2) -> bool: # noqa: PLR0911 :param type2: Second unwrapped type to compare :return: True if types are compatible, False otherwise """ - if type1 is Any or type2 is Any: - return True - - if type1 == type2: + # Handle Any type and direct equality + if type1 is Any or type2 is Any or type1 == type2: return True type1_origin = get_origin(type1) type2_origin = get_origin(type2) - # Handle Union types (both directions) + # Handle Union types + if type1_origin is Union or type2_origin is Union: + return _check_union_compatibility(type1, type2, type1_origin, type2_origin) + + # Handle non-Union types + return _check_non_union_compatibility(type1, type2, type1_origin, type2_origin) + + +def _check_union_compatibility(type1: T, type2: T, type1_origin: Any, type2_origin: Any) -> bool: + """Handle all Union type compatibility cases.""" if type1_origin is Union and type2_origin is not Union: return any(_types_are_compatible(union_arg, type2) for union_arg in get_args(type1)) if type2_origin is Union and type1_origin is not Union: return any(_types_are_compatible(type1, union_arg) for union_arg in get_args(type2)) - if type1_origin is Union and type2_origin is Union: - return any(any(_types_are_compatible(arg1, arg2) - for arg2 in get_args(type2)) - for arg1 in get_args(type1)) + # Both are Union types + return any(any(_types_are_compatible(arg1, arg2) + for arg2 in get_args(type2)) + for arg1 in get_args(type1)) + +def _check_non_union_compatibility(type1: T, type2: T, type1_origin: Any, type2_origin: Any) -> bool: + """Handle non-Union type compatibility cases.""" # If no origin, compare types directly if not type1_origin and not type2_origin: return type1 == type2 - # If only one has origin, they're incompatible - if not type1_origin or not type2_origin: - return False - - # If different origins, types are incompatible - if type1_origin != type2_origin: + # Both must have origins and they must be equal + if not (type1_origin and type2_origin and type1_origin == type2_origin): return False # Compare generic type arguments @@ -77,8 +85,22 @@ def _types_are_compatible(type1, type2) -> bool: # noqa: PLR0911 return all(_types_are_compatible(t1_arg, t2_arg) for t1_arg, t2_arg in zip(type1_args, type2_args)) +def _handle_union_type_matches(type1: T, type2: T, type1_origin: T, type2_origin: T) -> bool: + """ + Handles cases where either type is Union. + """ + if type1_origin is Union and type2_origin is not Union: + return any(_types_are_compatible(union_arg, type2) + for union_arg in get_args(type1)) + if type2_origin is Union and type1_origin is not Union: + return any(_types_are_compatible(type1, union_arg) + for union_arg in get_args(type2)) + else: + return any(any(_types_are_compatible(arg1, arg2) + for arg2 in get_args(type2)) + for arg1 in get_args(type1)) -def _unwrap_all(t, recursive: bool): +def _unwrap_all(t: T, recursive: bool) -> T: """ Unwrap a type until no more unwrapping is possible. @@ -105,7 +127,7 @@ def _unwrap_all(t, recursive: bool): return t -def _is_variadic_type(t) -> bool: +def _is_variadic_type(t: T) -> bool: """Check if type is a Variadic or GreedyVariadic type.""" origin = get_origin(t) if origin is Annotated: @@ -114,7 +136,7 @@ def _is_variadic_type(t) -> bool: return False -def _is_optional_type(t) -> bool: +def _is_optional_type(t: T) -> bool: """Check if type is an Optional type.""" origin = get_origin(t) if origin is Union: @@ -123,7 +145,7 @@ def _is_optional_type(t) -> bool: return False -def _unwrap_variadics(t, recursive: bool): +def _unwrap_variadics(t: T, recursive: bool) -> T: """ Unwrap Variadic or GreedyVariadic annotated types. @@ -145,7 +167,7 @@ def _unwrap_variadics(t, recursive: bool): return inner_type -def _unwrap_optionals(t, recursive: bool): +def _unwrap_optionals(t: T, recursive: bool) -> T: """ Unwrap Optional[...] types (Union[X, None]). @@ -158,9 +180,9 @@ def _unwrap_optionals(t, recursive: bool): args = list(get_args(t)) args.remove(type(None)) - result = args[0] if len(args) == 1 else Union[tuple(args)] + result = args[0] if len(args) == 1 else Union[tuple(args)] # type: ignore # Only recursively unwrap if requested if recursive: - return _unwrap_all(result, recursive) - return result + return _unwrap_all(result, recursive) # type: ignore + return result # type: ignore