From 24b17d8890252d356fa5ee4945065001760f698f Mon Sep 17 00:00:00 2001
From: Huanyu He
Date: Wed, 4 Dec 2024 12:28:37 -0800
Subject: [PATCH] add NJT/TD support for EC (#2596)

Summary:
# Documents
* [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv)
* [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79)

{F1949248817}

# Context
* Continuing from D66465376, which added NJT/TD support for EBC, this diff adds the same support for EC.
* As depicted above, we are extending the TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict).
* TensorDict is supported in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (Dict[str, JT])` (a minimal eager-mode sketch follows this list).
* In eager mode, we directly call `td_to_kjt` in the forward function to convert the TD to a KJT.
* In distributed mode, the conversion happens inside `ShardedEmbeddingCollection`, specifically in `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication.
* In the KJT scenario, the input KJT is permuted (and partially duplicated in some cases) before the `KJTAllToAll` communication; in the TD scenario, the input TD is converted directly into the permuted KJT that feeds the same `KJTAllToAll` communication.
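
To make the eager-mode flow concrete, here is a minimal sketch of the usage this diff enables. It is illustrative only: the table/feature configs are made up for this sketch, and it assumes a PyTorch build with jagged-layout `NestedTensor` (`torch.nested.nested_tensor(..., layout=torch.jagged)`) plus a recent `tensordict` release that accepts NestedTensor values.

```python
import torch
from tensordict import TensorDict
from torchrec.modules.embedding_configs import EmbeddingConfig
from torchrec.modules.embedding_modules import EmbeddingCollection

# Two small tables, one feature each (hypothetical names/sizes for this sketch).
ec = EmbeddingCollection(
    tables=[
        EmbeddingConfig(
            name="table_0", embedding_dim=16, num_embeddings=11, feature_names=["feature_0"]
        ),
        EmbeddingConfig(
            name="table_1", embedding_dim=16, num_embeddings=22, feature_names=["feature_1"]
        ),
    ]
)

# Jagged id lists per feature, held as NJTs inside a TensorDict
# (batch_size=[] matches the dumps in the Verification section below).
td = TensorDict(
    {
        "feature_0": torch.nested.nested_tensor(
            [torch.tensor([1, 2]), torch.tensor([3])], layout=torch.jagged
        ),
        "feature_1": torch.nested.nested_tensor(
            [torch.tensor([4]), torch.tensor([5, 6, 7])], layout=torch.jagged
        ),
    },
    batch_size=[],
)

# With this diff, EC accepts the TD directly (same call as with a KJT) and
# returns a Dict[str, JaggedTensor], one embedding row per id.
embeddings = ec(td)
print(embeddings["feature_0"].values().shape)  # expected: torch.Size([3, 16])
```
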
NOTE: This diff reuses a number of existing test frameworks and cases with minimal but critical changes in `EmbeddingCollection` and `ShardedEmbeddingCollection`. Please see the following verification for NJT/TD correctness.

# Verification - input with TensorDict
* breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd)
* sharded model
```
(Pdb) local_model
DistributedModelParallel(
  (_dmp_wrapped_module): DistributedDataParallel(
    (module): TestSequenceSparseNN(
      (dense): TestDenseArch(
        (linear): Linear(in_features=16, out_features=8, bias=True)
      )
      (sparse): TestSequenceSparseArch(
        (ec): ShardedEmbeddingCollection(
          (lookups): GroupedEmbeddingsLookup(
            (_emb_modules): ModuleList(
              (0): BatchedDenseEmbedding(
                (_emb_module): DenseTableBatchedEmbeddingBagsCodegen()
              )
            )
          )
          (_input_dists): RwSparseFeaturesDist(
            (_dist): KJTAllToAll()
          )
          (_output_dists): RwSequenceEmbeddingDist(
            (_dist): SequenceEmbeddingsAllToAll()
          )
          (embeddings): ModuleDict(
            (table_0): Module()
            (table_1): Module()
            (table_2): Module()
            (table_3): Module()
            (table_4): Module()
            (table_5): Module()
          )
        )
      )
      (over): TestSequenceOverArch(
        (linear): Linear(in_features=1928, out_features=16, bias=True)
      )
    )
  )
)
```
* TD input
```
(Pdb) local_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0'))
```
* unsharded model
```
(Pdb) global_model
TestSequenceSparseNN(
  (dense): TestDenseArch(
    (linear): Linear(in_features=16, out_features=8, bias=True)
  )
  (sparse): TestSequenceSparseArch(
    (ec): EmbeddingCollection(
      (embeddings): ModuleDict(
        (table_0): Embedding(11, 16)
        (table_1): Embedding(22, 16)
        (table_2): Embedding(33, 16)
        (table_3): Embedding(44, 16)
        (table_4): Embedding(11, 16)
        (table_5): Embedding(22, 16)
      )
    )
  )
  (over): TestSequenceOverArch(
    (linear): Linear(in_features=1928, out_features=16, bias=True)
  )
)
```
* TD input
```
(Pdb) global_input
ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056],
        [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671],
        [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678],
        [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617],
        [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909, 0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366],
        [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026, 0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548],
        [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786, 0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380],
        [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964, 0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]],
       device='cuda:0'), idlist_features=TensorDict(
    fields={
        feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True),
        feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)},
    batch_size=torch.Size([]),
    device=cuda:0,
    is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055], device='cuda:0'))
```
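
For reference, the NJT fields in the dumps above map onto a KJT roughly as follows. This is a hand-rolled sketch of what the TD-to-KJT conversion amounts to for one sharding group of feature names; `td_group_to_kjt` is a hypothetical helper written only for illustration, not the actual `td_to_kjt` that `input_dist` imports from `torchrec.sparse.jagged_tensor` (whose internals may differ).

```python
from typing import List

import torch
from tensordict import TensorDict
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


def td_group_to_kjt(td: TensorDict, keys: List[str]) -> KeyedJaggedTensor:
    # For each NJT field: values() is the packed 1-D id tensor, and
    # offsets().diff() recovers the per-sample lengths. A KJT expects the
    # values/lengths of all keys concatenated key by key, which is exactly
    # what the per-key loops below produce.
    values = torch.cat([td[key].values() for key in keys])
    lengths = torch.cat([td[key].offsets().diff() for key in keys])
    return KeyedJaggedTensor.from_lengths_sync(keys=keys, values=values, lengths=lengths)


# In the sharded path, input_dist builds one such KJT per sharding-type group, e.g.:
# features_by_shards = [td_group_to_kjt(td, names) for names in feature_name_by_sharding_types]
```
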
Differential Revision: D66521351
---
 torchrec/distributed/embedding.py             | 50 ++++++++++++++-----
 .../distributed/test_utils/test_sharding.py   | 32 +++++++++---
 .../tests/test_sequence_model_parallel.py     | 41 +++++++++++++++
 torchrec/modules/embedding_modules.py         |  8 ++-
 4 files changed, 110 insertions(+), 21 deletions(-)

diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 5f16efc1b..a6db7095c 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -26,6 +26,7 @@
 )
 
 import torch
+from tensordict import TensorDict
 from torch import distributed as dist, nn
 from torch.autograd.profiler import record_function
 from torch.distributed._tensor import DTensor
@@ -88,7 +89,12 @@
 from torchrec.modules.utils import construct_jagged_tensors, SequenceVBEContext
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
-from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import (
+    _to_offsets,
+    JaggedTensor,
+    KeyedJaggedTensor,
+    td_to_kjt,
+)
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -96,13 +102,6 @@
 except OSError:
     pass
 
-try:
-    from tensordict import TensorDict
-except ImportError:
-
-    class TensorDict:
-        pass
-
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -1146,25 +1145,50 @@ def _compute_sequence_vbe_context(
     def input_dist(
         self,
         ctx: EmbeddingCollectionContext,
-        features: KeyedJaggedTensor,
+        features: TypeUnion[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
+        # torch.distributed.breakpoint()
+        feature_keys = list(features.keys())  # pyre-ignore[6]
         if self._has_uninitialized_input_dist:
-            self._create_input_dist(input_feature_names=features.keys())
+            self._create_input_dist(input_feature_names=feature_keys)
             self._has_uninitialized_input_dist = False
         with torch.no_grad():
            unpadded_features = None
-            if features.variable_stride_per_key():
+            if (
+                isinstance(features, KeyedJaggedTensor)
+                and features.variable_stride_per_key()
+            ):
                 unpadded_features = features
                 features = pad_vbe_kjt_lengths(unpadded_features)
 
-            if self._features_order:
+            if isinstance(features, KeyedJaggedTensor) and self._features_order:
                 features = features.permute(
                     self._features_order,
                     # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
                     #  but got `TypeUnion[Module, Tensor]`.
                     self._features_order_tensor,
                 )
-            features_by_shards = features.split(self._feature_splits)
+
+            if isinstance(features, KeyedJaggedTensor):
+                features_by_shards = features.split(self._feature_splits)
+            else:  # TensorDict
+                feature_names = (
+                    [feature_keys[i] for i in self._features_order]
+                    if self._features_order  # empty features_order means no reordering
+                    else feature_keys
+                )
+                feature_names = [name.split("@")[0] for name in feature_names]
+                feature_name_by_sharding_types: List[List[str]] = []
+                start = 0
+                for length in self._feature_splits:
+                    feature_name_by_sharding_types.append(
+                        feature_names[start : start + length]
+                    )
+                    start += length
+                features_by_shards = [
+                    td_to_kjt(features, names)
+                    for names in feature_name_by_sharding_types
+                ]
             if self._use_index_dedup:
                 features_by_shards = self._dedup_indices(ctx, features_by_shards)
 
diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py
index dbd8f1007..f74c597f5 100644
--- a/torchrec/distributed/test_utils/test_sharding.py
+++ b/torchrec/distributed/test_utils/test_sharding.py
@@ -148,6 +148,7 @@ def gen_model_and_input(
     long_indices: bool = True,
     global_constant_batch: bool = False,
     num_inputs: int = 1,
+    input_type: str = "kjt",  # "kjt" or "td"
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:
@@ -178,9 +179,9 @@ def gen_model_and_input(
         feature_processor_modules=feature_processor_modules,
     )
     inputs = []
-    for _ in range(num_inputs):
-        inputs.append(
-            (
+    if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
+        for _ in range(num_inputs):
+            inputs.append(
                 cast(VariableBatchModelInputCallable, generate)(
                     average_batch_size=batch_size,
                     world_size=world_size,
@@ -189,8 +190,26 @@ def gen_model_and_input(
                     weighted_tables=weighted_tables or [],
                     global_constant_batch=global_constant_batch,
                 )
-                if generate == ModelInput.generate_variable_batch_input
-                else cast(ModelInputCallable, generate)(
+            )
+    elif generate == ModelInput.generate:
+        for _ in range(num_inputs):
+            inputs.append(
+                ModelInput.generate(
+                    world_size=world_size,
+                    tables=tables,
+                    dedup_tables=dedup_tables,
+                    weighted_tables=weighted_tables or [],
+                    num_float_features=num_float_features,
+                    variable_batch_size=variable_batch_size,
+                    batch_size=batch_size,
+                    long_indices=long_indices,
+                    input_type=input_type,
+                )
+            )
+    else:
+        for _ in range(num_inputs):
+            inputs.append(
+                cast(ModelInputCallable, generate)(
                     world_size=world_size,
                     tables=tables,
                     dedup_tables=dedup_tables,
@@ -201,7 +220,6 @@ def gen_model_and_input(
                     long_indices=long_indices,
                 )
             )
-        )
     return (model, inputs)
 
 
@@ -287,6 +305,7 @@ def sharding_single_rank_test(
     global_constant_batch: bool = False,
     world_size_2D: Optional[int] = None,
     node_group_size: Optional[int] = None,
+    input_type: str = "kjt",  # "kjt" or "td"
 ) -> None:
 
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
@@ -310,6 +329,7 @@ def sharding_single_rank_test(
             batch_size=batch_size,
             feature_processor_modules=feature_processor_modules,
             global_constant_batch=global_constant_batch,
+            input_type=input_type,
         )
         global_model = global_model.to(ctx.device)
         global_input = inputs[0][0].to(ctx.device)
diff --git a/torchrec/distributed/tests/test_sequence_model_parallel.py b/torchrec/distributed/tests/test_sequence_model_parallel.py
index aec092354..d13d819c3 100644
--- a/torchrec/distributed/tests/test_sequence_model_parallel.py
+++ b/torchrec/distributed/tests/test_sequence_model_parallel.py
@@ -376,3 +376,44 @@ def _test_sharding(
             variable_batch_per_feature=variable_batch_per_feature,
             global_constant_batch=True,
         )
+
+
+@skip_if_asan_class
+class TDSequenceModelParallelTest(SequenceModelParallelTest):
+
+    def test_sharding_variable_batch(self) -> None:
+        pass
+
+    def _test_sharding(
+        self,
+        sharders: List[TestEmbeddingCollectionSharder],
+        backend: str = "gloo",
+        world_size: int = 2,
+        local_size: Optional[int] = None,
+        constraints: Optional[Dict[str, ParameterConstraints]] = None,
+        model_class: Type[TestSparseNNBase] = TestSequenceSparseNN,
+        qcomms_config: Optional[QCommsConfig] = None,
+        apply_optimizer_in_backward_config: Optional[
+            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
+        ] = None,
+        variable_batch_size: bool = False,
+        variable_batch_per_feature: bool = False,
+    ) -> None:
+        self._run_multi_process_test(
+            callable=sharding_single_rank_test,
+            world_size=world_size,
+            local_size=local_size,
+            model_class=model_class,
+            tables=self.tables,
+            embedding_groups=self.embedding_groups,
+            sharders=sharders,
+            optim=EmbOptimType.EXACT_SGD,
+            backend=backend,
+            constraints=constraints,
+            qcomms_config=qcomms_config,
+            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            variable_batch_size=variable_batch_size,
+            variable_batch_per_feature=variable_batch_per_feature,
+            global_constant_batch=True,
+            input_type="td",
+        )
diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index b22e7492f..b84675d3e 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -456,7 +456,7 @@ def __init__(  # noqa C901
 
     def forward(
         self,
-        features: KeyedJaggedTensor,
+        features: Union[KeyedJaggedTensor, TensorDict],
     ) -> Dict[str, JaggedTensor]:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
@@ -470,7 +470,10 @@
         """
 
         feature_embeddings: Dict[str, JaggedTensor] = {}
-        jt_dict: Dict[str, JaggedTensor] = features.to_dict()
+        if isinstance(features, KeyedJaggedTensor):
+            jt_dict: Dict[str, JaggedTensor] = features.to_dict()
+        else:
+            jt_dict = features
         for i, emb_module in enumerate(self.embeddings.values()):
             feature_names = self._feature_names[i]
             embedding_names = self._embedding_names_by_table[i]
@@ -483,6 +486,7 @@ def forward(
                 feature_embeddings[embedding_name] = JaggedTensor(
                     values=lookup,
                     lengths=f.lengths(),
+                    offsets=f.offsets() if isinstance(features, TensorDict) else None,
                     weights=f.values() if self._need_indices else None,
                 )
         return feature_embeddings