Skip to content

Commit

Permalink
add NJT/TD support for EC (pytorch#2596)
Browse files Browse the repository at this point in the history
Summary:

# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)
 {F1949248817} 

# Context
* Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC
* As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict)
* Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)`
* In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT.
* In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. 
While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication.
NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness.

# Verification - input with TensorDict
* breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd)
* sharded model
```
(Pdb) local_model
DistributedModelParallel(
  (_dmp_wrapped_module): DistributedDataParallel(
    (module): TestSequenceSparseNN(
      (dense): TestDenseArch(
        (linear): Linear(in_features=16, out_features=8, bias=True)
      )
      (sparse): TestSequenceSparseArch(
        (ec): ShardedEmbeddingCollection(
          (lookups): 
           GroupedEmbeddingsLookup(
              (_emb_modules): ModuleList(
                (0): BatchedDenseEmbedding(
                  (_emb_module): DenseTableBatchedEmbeddingBagsCodegen()
                )
              )
            )
           (_input_dists): 
           RwSparseFeaturesDist(
              (_dist): KJTAllToAll()
            )
           (_output_dists): 
           RwSequenceEmbeddingDist(
              (_dist): SequenceEmbeddingsAllToAll()
            )
          (embeddings): ModuleDict(
            (table_0): Module()
            (table_1): Module()
            (table_2): Module()
            (table_3): Module()
            (table_4): Module()
            (table_5): Module()
          )
        )
      )
      (over): TestSequenceOverArch(
        (linear): Linear(in_features=1928, out_features=16, bias=True)
      )
    )
  )
)
```
* TD input
```
(Pdb) local_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0'))
```
* unsharded model
```
(Pdb) global_model
TestSequenceSparseNN(
  (dense): TestDenseArch(
    (linear): Linear(in_features=16, out_features=8, bias=True)
  )
  (sparse): TestSequenceSparseArch(
    (ec): EmbeddingCollection(
      (embeddings): ModuleDict(
        (table_0): Embedding(11, 16)
        (table_1): Embedding(22, 16)
        (table_2): Embedding(33, 16)
        (table_3): Embedding(44, 16)
        (table_4): Embedding(11, 16)
        (table_5): Embedding(22, 16)
      )
    )
  )
  (over): TestSequenceOverArch(
    (linear): Linear(in_features=1928, out_features=16, bias=True)
  )
)
```
* TD input
```
(Pdb) global_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433,
         0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146,
         0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315,
         0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320,
         0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617],
        [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909,
         0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366],
        [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026,
         0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548],
        [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786,
         0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380],
        [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964,
         0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055],
       device='cuda:0'))
```

Differential Revision: D66521351
  • Loading branch information
TroyGarden authored and facebook-github-bot committed Jan 5, 2025
1 parent 47012bb commit 5d16527
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 17 deletions.
20 changes: 11 additions & 9 deletions torchrec/distributed/embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
)

import torch
from tensordict import TensorDict
from torch import distributed as dist, nn
from torch.autograd.profiler import record_function
from torch.distributed._shard.sharding_spec import EnumerableShardingSpec
Expand Down Expand Up @@ -90,20 +91,14 @@
from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor
from torchrec.sparse.tensor_dict import maybe_td_to_kjt

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops_cpu")
except OSError:
pass

try:
from tensordict import TensorDict
except ImportError:

class TensorDict:
pass


logger: logging.Logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -1205,8 +1200,15 @@ def _compute_sequence_vbe_context(
def input_dist(
self,
ctx: EmbeddingCollectionContext,
features: KeyedJaggedTensor,
features: TypeUnion[KeyedJaggedTensor, TensorDict],
) -> Awaitable[Awaitable[KJTList]]:
need_permute: bool = True
if isinstance(features, TensorDict):
feature_keys = list(features.keys()) # pyre-ignore[6]
if self._features_order:
feature_keys = [feature_keys[i] for i in self._features_order]
need_permute = False
features = maybe_td_to_kjt(features, feature_keys) # pyre-ignore[6]
if self._has_uninitialized_input_dist:
self._create_input_dist(input_feature_names=features.keys())
self._has_uninitialized_input_dist = False
Expand All @@ -1216,7 +1218,7 @@ def input_dist(
unpadded_features = features
features = pad_vbe_kjt_lengths(unpadded_features)

if self._features_order:
if need_permute and self._features_order:
features = features.permute(
self._features_order,
# pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
Expand Down
32 changes: 26 additions & 6 deletions torchrec/distributed/test_utils/test_sharding.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@ def gen_model_and_input(
long_indices: bool = True,
global_constant_batch: bool = False,
num_inputs: int = 1,
input_type: str = "kjt", # "kjt" or "td"
) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
torch.manual_seed(0)
if dedup_feature_names:
Expand Down Expand Up @@ -177,9 +178,9 @@ def gen_model_and_input(
feature_processor_modules=feature_processor_modules,
)
inputs = []
for _ in range(num_inputs):
inputs.append(
(
if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
for _ in range(num_inputs):
inputs.append(
cast(VariableBatchModelInputCallable, generate)(
average_batch_size=batch_size,
world_size=world_size,
Expand All @@ -188,8 +189,26 @@ def gen_model_and_input(
weighted_tables=weighted_tables or [],
global_constant_batch=global_constant_batch,
)
if generate == ModelInput.generate_variable_batch_input
else cast(ModelInputCallable, generate)(
)
elif generate == ModelInput.generate:
for _ in range(num_inputs):
inputs.append(
ModelInput.generate(
world_size=world_size,
tables=tables,
dedup_tables=dedup_tables,
weighted_tables=weighted_tables or [],
num_float_features=num_float_features,
variable_batch_size=variable_batch_size,
batch_size=batch_size,
long_indices=long_indices,
input_type=input_type,
)
)
else:
for _ in range(num_inputs):
inputs.append(
cast(ModelInputCallable, generate)(
world_size=world_size,
tables=tables,
dedup_tables=dedup_tables,
Expand All @@ -200,7 +219,6 @@ def gen_model_and_input(
long_indices=long_indices,
)
)
)
return (model, inputs)


Expand Down Expand Up @@ -286,6 +304,7 @@ def sharding_single_rank_test(
global_constant_batch: bool = False,
world_size_2D: Optional[int] = None,
node_group_size: Optional[int] = None,
input_type: str = "kjt", # "kjt" or "td"
) -> None:
with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
# Generate model & inputs.
Expand All @@ -308,6 +327,7 @@ def sharding_single_rank_test(
batch_size=batch_size,
feature_processor_modules=feature_processor_modules,
global_constant_batch=global_constant_batch,
input_type=input_type,
)
global_model = global_model.to(ctx.device)
global_input = inputs[0][0].to(ctx.device)
Expand Down
41 changes: 41 additions & 0 deletions torchrec/distributed/tests/test_sequence_model_parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,3 +376,44 @@ def _test_sharding(
variable_batch_per_feature=variable_batch_per_feature,
global_constant_batch=True,
)


@skip_if_asan_class
class TDSequenceModelParallelTest(SequenceModelParallelTest):

def test_sharding_variable_batch(self) -> None:
pass

def _test_sharding(
self,
sharders: List[TestEmbeddingCollectionSharder],
backend: str = "gloo",
world_size: int = 2,
local_size: Optional[int] = None,
constraints: Optional[Dict[str, ParameterConstraints]] = None,
model_class: Type[TestSparseNNBase] = TestSequenceSparseNN,
qcomms_config: Optional[QCommsConfig] = None,
apply_optimizer_in_backward_config: Optional[
Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
] = None,
variable_batch_size: bool = False,
variable_batch_per_feature: bool = False,
) -> None:
self._run_multi_process_test(
callable=sharding_single_rank_test,
world_size=world_size,
local_size=local_size,
model_class=model_class,
tables=self.tables,
embedding_groups=self.embedding_groups,
sharders=sharders,
optim=EmbOptimType.EXACT_SGD,
backend=backend,
constraints=constraints,
qcomms_config=qcomms_config,
apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
variable_batch_size=variable_batch_size,
variable_batch_per_feature=variable_batch_per_feature,
global_constant_batch=True,
input_type="td",
)
8 changes: 6 additions & 2 deletions torchrec/modules/embedding_modules.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,10 @@ def __init__(
self._feature_names: List[List[str]] = [table.feature_names for table in tables]
self.reset_parameters()

def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
def forward(
self,
features: KeyedJaggedTensor, # can also take TensorDict as input
) -> KeyedTensor:
"""
Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
Expand Down Expand Up @@ -450,7 +453,7 @@ def __init__( # noqa C901

def forward(
self,
features: KeyedJaggedTensor,
features: KeyedJaggedTensor, # can also take TensorDict as input
) -> Dict[str, JaggedTensor]:
"""
Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
Expand All @@ -463,6 +466,7 @@ def forward(
Dict[str, JaggedTensor]
"""

features = maybe_td_to_kjt(features, None)
feature_embeddings: Dict[str, JaggedTensor] = {}
jt_dict: Dict[str, JaggedTensor] = features.to_dict()
for i, emb_module in enumerate(self.embeddings.values()):
Expand Down

0 comments on commit 5d16527

Please sign in to comment.