From d4b896f9f5e05f8908d0d7c9c811384a5d1d3bdd Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Wed, 22 Jan 2025 16:37:24 -0800 Subject: [PATCH] Teach is_signature_compatible() to dig into similar annotations (#2693) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2693 D68450007 updated some annotations in pytorch. This function wasn't correctly evaluating `typing.Dict[X, Y]` and `dict[X, Y]` as the equivalent. Reviewed By: izaitsevfb Differential Revision: D68475380 fbshipit-source-id: 3b71ab41f95e6c20986ebe6fbf6f9cbe3b3d58f9 --- torchrec/schema/utils.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/torchrec/schema/utils.py b/torchrec/schema/utils.py index 0f9b897cb..b4f8a6075 100644 --- a/torchrec/schema/utils.py +++ b/torchrec/schema/utils.py @@ -8,6 +8,32 @@ # pyre-strict import inspect +import typing +from typing import Any + + +def _is_annot_compatible(prev: object, curr: object) -> bool: + if prev == curr: + return True + + if not (prev_origin := typing.get_origin(prev)): + return False + if not (curr_origin := typing.get_origin(curr)): + return False + + if prev_origin != curr_origin: + return False + + prev_args = typing.get_args(prev) + curr_args = typing.get_args(curr) + if len(prev_args) != len(curr_args): + return False + + for prev_arg, curr_arg in zip(prev_args, curr_args): + if not _is_annot_compatible(prev_arg, curr_arg): + return False + + return True def is_signature_compatible( @@ -84,6 +110,8 @@ def is_signature_compatible( return False # TODO: Account for Union Types? - if current_signature.return_annotation != previous_signature.return_annotation: + if not _is_annot_compatible( + previous_signature.return_annotation, current_signature.return_annotation + ): return False return True