From d8a6f7a5c8d28084b3a4f7fe0bf3658dd6aff2d3 Mon Sep 17 00:00:00 2001 From: Aaron Orenstein Date: Tue, 21 Jan 2025 18:34:30 -0800 Subject: [PATCH] Teach is_signature_compatible() to dig into similar annotations Summary: D68450007 updated some annotations in pytorch. This function wasn't correctly evaluating `typing.Dict[X, Y]` and `dict[X, Y]` as the equivalent. Differential Revision: D68475380 --- torchrec/schema/utils.py | 30 +++++++++++++++++++++++++++++- 1 file changed, 29 insertions(+), 1 deletion(-) diff --git a/torchrec/schema/utils.py b/torchrec/schema/utils.py index 0f9b897cb..ba4b5beaa 100644 --- a/torchrec/schema/utils.py +++ b/torchrec/schema/utils.py @@ -8,6 +8,32 @@ # pyre-strict import inspect +import typing +from typing import Any + + +def _is_annot_compatible(prev: Any, curr: Any) -> bool: + if prev == curr: + return True + + if not (prev_origin := typing.get_origin(prev)): + return False + if not (curr_origin := typing.get_origin(curr)): + return False + + if prev_origin != curr_origin: + return False + + prev_args = typing.get_args(prev) + curr_args = typing.get_args(curr) + if len(prev_args) != len(curr_args): + return False + + for prev_arg, curr_arg in zip(prev_args, curr_args): + if not _is_annot_compatible(prev_arg, curr_arg): + return False + + return True def is_signature_compatible( @@ -84,6 +110,8 @@ def is_signature_compatible( return False # TODO: Account for Union Types? - if current_signature.return_annotation != previous_signature.return_annotation: + if not _is_annot_compatible( + previous_signature.return_annotation, current_signature.return_annotation + ): return False return True