diff --git a/torchrec/pt2/checks.py b/torchrec/pt2/checks.py index 6bd8140f0..76626a9f8 100644 --- a/torchrec/pt2/checks.py +++ b/torchrec/pt2/checks.py @@ -73,13 +73,13 @@ def pt2_checks_tensor_slice( torch._check(end_offset >= start_offset) -def pt2_checks_all_is_size(list: List[int]) -> List[int]: +def pt2_checks_all_is_size(x: List[int]) -> List[int]: if torch.jit.is_scripting() or not is_pt2_compiling(): - return list + return x - for i in list: + for i in x: torch._check_is_size(i) - return list + return x def pt2_check_size_nonzero(x: torch.Tensor) -> torch.Tensor: