diff --git a/torchrec/distributed/mc_modules.py b/torchrec/distributed/mc_modules.py index 5ab9fcefc..e513b3e35 100644 --- a/torchrec/distributed/mc_modules.py +++ b/torchrec/distributed/mc_modules.py @@ -663,7 +663,7 @@ def compute( values=features.values(), lengths=features.lengths(), # TODO: improve this temp solution by passing real weights - weights=torch.tensor(kjt.length_per_key()), + weights=torch.tensor(features.length_per_key()), ) } mcm = self._managed_collision_modules[table]