Skip to content

Commit

Permalink
chore: typing and lint
Browse files Browse the repository at this point in the history
  • Loading branch information
mathislucka committed Jan 28, 2025
1 parent 965970d commit c528b0d
Showing 1 changed file with 48 additions and 26 deletions.
74 changes: 48 additions & 26 deletions haystack_experimental/core/super_component/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,17 @@
#
# SPDX-License-Identifier: Apache-2.0

from typing import Annotated, Any, Union, get_args, get_origin
from typing import Annotated, Any, TypeVar, Union, get_args, get_origin

from haystack.core.component.types import HAYSTACK_GREEDY_VARIADIC_ANNOTATION, HAYSTACK_VARIADIC_ANNOTATION


class _delegate_default:
"""Custom object for delegating filling of default values to the underlying components."""

def is_compatible(type1, type2, unwrap_nested: bool = True) -> bool:
T = TypeVar("T")

def is_compatible(type1: T, type2: T, unwrap_nested: bool = True) -> bool:
"""
Check if two types are compatible (bidirectional/symmetric check).
Expand All @@ -26,45 +28,51 @@ def is_compatible(type1, type2, unwrap_nested: bool = True) -> bool:
return _types_are_compatible(type1_unwrapped, type2_unwrapped)


def _types_are_compatible(type1, type2) -> bool: # noqa: PLR0911
def _types_are_compatible(type1: T, type2: T) -> bool:
"""
Core type compatibility check implementing symmetric matching.
:param type1: First unwrapped type to compare
:param type2: Second unwrapped type to compare
:return: True if types are compatible, False otherwise
"""
if type1 is Any or type2 is Any:
return True

if type1 == type2:
# Handle Any type and direct equality
if type1 is Any or type2 is Any or type1 == type2:
return True

type1_origin = get_origin(type1)
type2_origin = get_origin(type2)

# Handle Union types (both directions)
# Handle Union types
if type1_origin is Union or type2_origin is Union:
return _check_union_compatibility(type1, type2, type1_origin, type2_origin)

# Handle non-Union types
return _check_non_union_compatibility(type1, type2, type1_origin, type2_origin)


def _check_union_compatibility(type1: T, type2: T, type1_origin: Any, type2_origin: Any) -> bool:
"""Handle all Union type compatibility cases."""
if type1_origin is Union and type2_origin is not Union:
return any(_types_are_compatible(union_arg, type2)
for union_arg in get_args(type1))
if type2_origin is Union and type1_origin is not Union:
return any(_types_are_compatible(type1, union_arg)
for union_arg in get_args(type2))
if type1_origin is Union and type2_origin is Union:
return any(any(_types_are_compatible(arg1, arg2)
for arg2 in get_args(type2))
for arg1 in get_args(type1))
# Both are Union types
return any(any(_types_are_compatible(arg1, arg2)
for arg2 in get_args(type2))
for arg1 in get_args(type1))


def _check_non_union_compatibility(type1: T, type2: T, type1_origin: Any, type2_origin: Any) -> bool:
"""Handle non-Union type compatibility cases."""
# If no origin, compare types directly
if not type1_origin and not type2_origin:
return type1 == type2

# If only one has origin, they're incompatible
if not type1_origin or not type2_origin:
return False

# If different origins, types are incompatible
if type1_origin != type2_origin:
# Both must have origins and they must be equal
if not (type1_origin and type2_origin and type1_origin == type2_origin):
return False

# Compare generic type arguments
Expand All @@ -77,8 +85,22 @@ def _types_are_compatible(type1, type2) -> bool: # noqa: PLR0911
return all(_types_are_compatible(t1_arg, t2_arg)
for t1_arg, t2_arg in zip(type1_args, type2_args))

def _handle_union_type_matches(type1: T, type2: T, type1_origin: T, type2_origin: T) -> bool:
"""
Handles cases where either type is Union.
"""
if type1_origin is Union and type2_origin is not Union:
return any(_types_are_compatible(union_arg, type2)
for union_arg in get_args(type1))
if type2_origin is Union and type1_origin is not Union:
return any(_types_are_compatible(type1, union_arg)
for union_arg in get_args(type2))
else:
return any(any(_types_are_compatible(arg1, arg2)
for arg2 in get_args(type2))
for arg1 in get_args(type1))

def _unwrap_all(t, recursive: bool):
def _unwrap_all(t: T, recursive: bool) -> T:
"""
Unwrap a type until no more unwrapping is possible.
Expand All @@ -105,7 +127,7 @@ def _unwrap_all(t, recursive: bool):
return t


def _is_variadic_type(t) -> bool:
def _is_variadic_type(t: T) -> bool:
"""Check if type is a Variadic or GreedyVariadic type."""
origin = get_origin(t)
if origin is Annotated:
Expand All @@ -114,7 +136,7 @@ def _is_variadic_type(t) -> bool:
return False


def _is_optional_type(t) -> bool:
def _is_optional_type(t: T) -> bool:
"""Check if type is an Optional type."""
origin = get_origin(t)
if origin is Union:
Expand All @@ -123,7 +145,7 @@ def _is_optional_type(t) -> bool:
return False


def _unwrap_variadics(t, recursive: bool):
def _unwrap_variadics(t: T, recursive: bool) -> T:
"""
Unwrap Variadic or GreedyVariadic annotated types.
Expand All @@ -145,7 +167,7 @@ def _unwrap_variadics(t, recursive: bool):
return inner_type


def _unwrap_optionals(t, recursive: bool):
def _unwrap_optionals(t: T, recursive: bool) -> T:
"""
Unwrap Optional[...] types (Union[X, None]).
Expand All @@ -158,9 +180,9 @@ def _unwrap_optionals(t, recursive: bool):

args = list(get_args(t))
args.remove(type(None))
result = args[0] if len(args) == 1 else Union[tuple(args)]
result = args[0] if len(args) == 1 else Union[tuple(args)] # type: ignore

# Only recursively unwrap if requested
if recursive:
return _unwrap_all(result, recursive)
return result
return _unwrap_all(result, recursive) # type: ignore
return result # type: ignore

0 comments on commit c528b0d

Please sign in to comment.