From 584bb441309ae286b7fb42cbaf54ae88f77a595f Mon Sep 17 00:00:00 2001
From: Shabab Ayub
Date: Thu, 1 Feb 2024 09:35:35 -0800
Subject: [PATCH] Deepcopy FP module even if on meta device (#1676)

Summary:
Pull Request resolved: https://github.com/pytorch/torchrec/pull/1676

When we fx trace, even though there are 2 FP modules (one per card), the model was
sharded on meta, so each rank only holds a reference to the FP module on rank 0.
FX then deduplicates that shared reference and the traced graph only shows the FP
on rank 0.

Do a deepcopy even when on the meta device so that each rank explicitly has its own
copy and fx will persist it.

Reviewed By: lequytra, tissue3

Differential Revision: D53294788

fbshipit-source-id: 8056241c8c47a40e7d7d200e07de741b3dd24603
---
 torchrec/distributed/quant_embeddingbag.py | 3 ++-
 1 file changed, 2 insertions(+), 1 deletion(-)

diff --git a/torchrec/distributed/quant_embeddingbag.py b/torchrec/distributed/quant_embeddingbag.py
index 2d724db7b..de1b58993 100644
--- a/torchrec/distributed/quant_embeddingbag.py
+++ b/torchrec/distributed/quant_embeddingbag.py
@@ -5,6 +5,7 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
 from typing import Any, Dict, List, Optional, Type
 
 import torch
@@ -409,7 +410,7 @@ def __init__(
         self.feature_processors_per_rank: nn.ModuleList = torch.nn.ModuleList()
         for i in range(env.world_size):
             self.feature_processors_per_rank.append(
-                feature_processor
+                copy.deepcopy(feature_processor)
                 if device_type == "meta"
                 else copy_to_device(
                     feature_processor,
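
Below is a minimal, self-contained sketch of the pattern this patch adopts. It is
not the TorchRec implementation: FeatureProcessor and replicate_per_rank are
hypothetical stand-ins, and the non-meta branch only approximates what TorchRec's
copy_to_device helper does. The point it illustrates is that deep-copying the
feature processor per rank, even on the meta device, keeps each rank's submodule a
distinct object, so torch.fx does not collapse them into a single shared reference.

import copy

import torch
import torch.nn as nn


class FeatureProcessor(nn.Module):
    """Hypothetical stand-in for a TorchRec feature processor."""

    def __init__(self) -> None:
        super().__init__()
        # Parameter lives on the meta device, as in meta-sharded init.
        self.weight = nn.Parameter(torch.empty(8, device="meta"))

    def forward(self, features: torch.Tensor) -> torch.Tensor:
        return features  # placeholder


def replicate_per_rank(
    feature_processor: nn.Module, world_size: int, device_type: str
) -> nn.ModuleList:
    """Give each rank its own module instance instead of a shared reference."""
    per_rank = nn.ModuleList()
    for rank in range(world_size):
        if device_type == "meta":
            # Nothing to materialize on meta; a deepcopy is enough to make
            # the per-rank submodules distinct objects for fx tracing.
            per_rank.append(copy.deepcopy(feature_processor))
        else:
            # Stand-in for the real-device path; TorchRec uses its
            # copy_to_device helper to place the copy on rank's device.
            per_rank.append(
                copy.deepcopy(feature_processor).to(torch.device(device_type, rank))
            )
    return per_rank


if __name__ == "__main__":
    fps = replicate_per_rank(FeatureProcessor(), world_size=2, device_type="meta")
    # Distinct objects per rank, so fx tracing keeps both submodules.
    assert fps[0] is not fps[1]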