From 5f607ff63ba757f98074e29da6384ab6bdcc9de9 Mon Sep 17 00:00:00 2001 From: Faran Ahmad Date: Sun, 22 Dec 2024 16:05:07 -0800 Subject: [PATCH] Fixing broken unit tests regarding module compatibility (#2651) Summary: Pull Request resolved: https://github.com/pytorch/torchrec/pull/2651 Fixing unit test breakage in the task: T211156854 Differential Revision: D67581406 fbshipit-source-id: 7e25585e23f6ec3f4924c8ebdaf82ca1ea0d4294 --- torchrec/distributed/embedding.py | 2 +- torchrec/distributed/quant_embedding.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py index ff2e4449e..6a3d63ba7 100644 --- a/torchrec/distributed/embedding.py +++ b/torchrec/distributed/embedding.py @@ -28,7 +28,7 @@ import torch from torch import distributed as dist, nn from torch.autograd.profiler import record_function -from torch.distributed._shard.sharding_spec.api import EnumerableShardingSpec +from torch.distributed._shard.sharding_spec import EnumerableShardingSpec from torch.distributed._tensor import DTensor from torch.nn.parallel import DistributedDataParallel from torchrec.distributed.embedding_sharding import ( diff --git a/torchrec/distributed/quant_embedding.py b/torchrec/distributed/quant_embedding.py index cb82b690a..2077297b7 100644 --- a/torchrec/distributed/quant_embedding.py +++ b/torchrec/distributed/quant_embedding.py @@ -17,7 +17,7 @@ IntNBitTableBatchedEmbeddingBagsCodegen, ) from torch import nn -from torch.distributed._shard.sharding_spec.api import EnumerableShardingSpec +from torch.distributed._shard.sharding_spec import EnumerableShardingSpec from torchrec.distributed.embedding import ( create_sharding_infos_by_sharding_device_group, EmbeddingShardingInfo,