diff --git a/torchrec/schema/utils.py b/torchrec/schema/utils.py index 0f9b897cb..b4f8a6075 100644 --- a/torchrec/schema/utils.py +++ b/torchrec/schema/utils.py @@ -8,6 +8,32 @@ # pyre-strict import inspect +import typing +from typing import Any + + +def _is_annot_compatible(prev: object, curr: object) -> bool: + if prev == curr: + return True + + if not (prev_origin := typing.get_origin(prev)): + return False + if not (curr_origin := typing.get_origin(curr)): + return False + + if prev_origin != curr_origin: + return False + + prev_args = typing.get_args(prev) + curr_args = typing.get_args(curr) + if len(prev_args) != len(curr_args): + return False + + for prev_arg, curr_arg in zip(prev_args, curr_args): + if not _is_annot_compatible(prev_arg, curr_arg): + return False + + return True def is_signature_compatible( @@ -84,6 +110,8 @@ def is_signature_compatible( return False # TODO: Account for Union Types? - if current_signature.return_annotation != previous_signature.return_annotation: + if not _is_annot_compatible( + previous_signature.return_annotation, current_signature.return_annotation + ): return False return True