From b8daf38e4154c2e73071e79d7f00fa88922be46e Mon Sep 17 00:00:00 2001 From: Pratik Aher Date: Mon, 8 Jul 2024 13:48:24 -0700 Subject: [PATCH] Bug fix for local run in TensorAllToAllValuesAwaitable Summary: bug fix when local workers is 1. we should take in variable ids instead of one with splits Reviewed By: sarckk Differential Revision: D59491275 --- torchrec/distributed/dist_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchrec/distributed/dist_data.py b/torchrec/distributed/dist_data.py index be37cc5a3..0f8b50240 100644 --- a/torchrec/distributed/dist_data.py +++ b/torchrec/distributed/dist_data.py @@ -1558,7 +1558,7 @@ def __init__( self._dist_values: torch.Tensor if self._workers == 1: - self._dist_values = input_splits + self._dist_values = input return else: if input.dim() > 1: