From 917690c37c0ac99336c52011ef5696cce2516993 Mon Sep 17 00:00:00 2001 From: Huanyu He Date: Mon, 30 Dec 2024 13:09:44 -0800 Subject: [PATCH] add NJT/TD support for EBC and pipeline benchmark (#2581) Summary: # Documents * [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv) * [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79) {F1949248817} # Context * As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict) * Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EBC ==> Output (KT)` * In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT. * In distributed mode, we do the conversion inside the `ShardedEmbeddingBagCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication. * In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication. * ref: D63436011 # Details * `td_to_kjt` implemented in python, which has cpu perf regression. But it's not on the training critical path so it has a minimal impact on the overall training QPS (see test plan benchmark results) * Currently only support EBC use case WARNING: `TensorDict` does **NOT** support weighted jagged tensor, **Nor** variable batch_size neither. NOTE: All the following comparisons are between the **`KJT.permute`** in the KJT input scenario and the **`TD-KJT conversion`** in the TD input scenario. * Both `KJT.permute` and `TD-KJT conversion` are correctly marked in the `TrainPipelineBase` traces `TD-KJT conversion` has more real executions in CPU, but the heavy-lifting computation is in GPU, which is delayed/blocked by the backward pass of the previous batch. GPU runtime has a small difference ~10%. {F1949366822} * For the `Copy-Batch-To-GPU` part, TD has more fragmented `HtoD` comms while KJT has a single contiguous `HtoD` comm Runtime-wise they are similar ~10% {F1949374305} * In the most commonly used `TrainPipelineSparseDist`, where the `Copy-Batch-To-GPU` and the cpu runtime are not on the critical path, we do observe very similar training QPS in the pipeline benchmark ~1% {F1949390271} ``` TrainPipelineSparseDist | Runtime (P90): 26.737 s | Memory (P90): 34.801 GB (TD) TrainPipelineSparseDist | Runtime (P90): 26.539 s | Memory (P90): 34.765 GB (KJT) ``` * increased data size, GPU runtime is 4x {F1949386106} # Conclusion 1. [Enablement] With this approach (replacing the `KJT permute` with `TD-KJT conversion`), the EBC can now take `TensorDict` as the module input in both single-GPU and multi-GPU (sharded) scenarios, tested with TrainPipelineBase, TrainPipelineSparseDist, TrainPipelineSemiSync, and TrainPipelinePrefetch. 2. [Performance] The TD host-to-device data transfer might not necessarily be a concern/blocker for the most commonly used train pipeline (TrainPipelineSparseDist). 2. [Feature Support] In order to become production-ready, the TensorDict needs to (1) integrate the `KJT.weights` data, and (2) to support the variable batch size, which are almost used in all the production models. 3. [Improvement] There are two major operations we can improve: (1) move TensorDict from host to device, and (2) convert TD to KJT. Currently they are both in the vanilla state. Since we are not sure how the real traces would be like with production models, we can't tell if these improvements are needed/helpful. Differential Revision: D65103519 --- torchrec/distributed/embeddingbag.py | 59 ++++++++++++------- .../tests/pipeline_benchmarks.py | 4 +- torchrec/modules/embedding_modules.py | 10 +--- torchrec/sparse/jagged_tensor.py | 6 +- torchrec/sparse/tensor_dict.py | 43 ++++++++++++++ 5 files changed, 88 insertions(+), 34 deletions(-) create mode 100644 torchrec/sparse/tensor_dict.py diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py index 9ffc74e0c..38e4671e8 100644 --- a/torchrec/distributed/embeddingbag.py +++ b/torchrec/distributed/embeddingbag.py @@ -27,6 +27,7 @@ import torch from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings +from tensordict import TensorDict from torch import distributed as dist, nn, Tensor from torch.autograd.profiler import record_function from torch.distributed._tensor import DTensor @@ -90,6 +91,7 @@ from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor +from torchrec.sparse.tensor_dict import maybe_td_to_kjt try: torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops") @@ -98,13 +100,6 @@ except OSError: pass -try: - from tensordict import TensorDict -except ImportError: - - class TensorDict: - pass - def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor: return ( @@ -659,9 +654,7 @@ def __init__( self._inverse_indices_permute_indices: Optional[torch.Tensor] = None # to support mean pooling callback hook self._has_mean_pooling_callback: bool = ( - True - if PoolingType.MEAN.value in self._pooling_type_to_rs_features - else False + PoolingType.MEAN.value in self._pooling_type_to_rs_features ) self._dim_per_key: Optional[torch.Tensor] = None self._kjt_key_indices: Dict[str, int] = {} @@ -1151,26 +1144,37 @@ def _create_inverse_indices_permute_indices( # pyre-ignore [14] def input_dist( - self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor + self, + ctx: EmbeddingBagCollectionContext, + features: Union[KeyedJaggedTensor, TensorDict], ) -> Awaitable[Awaitable[KJTList]]: - ctx.variable_batch_per_feature = features.variable_stride_per_key() - ctx.inverse_indices = features.inverse_indices_or_none() + if isinstance(features, KeyedJaggedTensor): + ctx.variable_batch_per_feature = features.variable_stride_per_key() + ctx.inverse_indices = features.inverse_indices_or_none() + feature_keys = features.keys() + else: # features is TensorDict + ctx.variable_batch_per_feature = False # TD does not support variable batch + ctx.inverse_indices = None + feature_keys = list(features.keys()) # pyre-ignore[6] if self._has_uninitialized_input_dist: - self._create_input_dist(features.keys()) + self._create_input_dist(feature_keys) self._has_uninitialized_input_dist = False if ctx.variable_batch_per_feature: self._create_inverse_indices_permute_indices(ctx.inverse_indices) if self._has_mean_pooling_callback: - self._init_mean_pooling_callback(features.keys(), ctx.inverse_indices) + self._init_mean_pooling_callback(feature_keys, ctx.inverse_indices) with torch.no_grad(): - if self._has_features_permute: + if isinstance(features, KeyedJaggedTensor) and self._has_features_permute: features = features.permute( self._features_order, # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]` # but got `Union[Module, Tensor]`. self._features_order_tensor, ) - if self._has_mean_pooling_callback: + if ( + isinstance(features, KeyedJaggedTensor) + and self._has_mean_pooling_callback + ): ctx.divisor = _create_mean_pooling_divisor( lengths=features.lengths(), stride=features.stride(), @@ -1189,9 +1193,24 @@ def input_dist( weights=features.weights_or_none(), ) - features_by_shards = features.split( - self._feature_splits, - ) + if isinstance(features, KeyedJaggedTensor): + features_by_shards = features.split( + self._feature_splits, + ) + else: + feature_names = [feature_keys[i] for i in self._features_order] + feature_name_by_sharding_types: List[List[str]] = [] + start = 0 + for length in self._feature_splits: + feature_name_by_sharding_types.append( + feature_names[start : start + length] + ) + start += length + features_by_shards = [ + maybe_td_to_kjt(features, names) + for names in feature_name_by_sharding_types + ] + awaitables = [] for input_dist, features_by_shard, sharding_type in zip( self._input_dists, diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py index e8dc5eccb..fdb900fe0 100644 --- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py +++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py @@ -160,7 +160,7 @@ def main( tables = [ EmbeddingBagConfig( - num_embeddings=(i + 1) * 1000, + num_embeddings=max(i + 1, 100) * 1000, embedding_dim=dim_emb, name="table_" + str(i), feature_names=["feature_" + str(i)], @@ -169,7 +169,7 @@ def main( ] weighted_tables = [ EmbeddingBagConfig( - num_embeddings=(i + 1) * 1000, + num_embeddings=max(i + 1, 100) * 1000, embedding_dim=dim_emb, name="weighted_table_" + str(i), feature_names=["weighted_feature_" + str(i)], diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py index 9a1878361..4ade3df2f 100644 --- a/torchrec/modules/embedding_modules.py +++ b/torchrec/modules/embedding_modules.py @@ -19,14 +19,7 @@ pooling_type_to_str, ) from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor - - -try: - from tensordict import TensorDict -except ImportError: - - class TensorDict: - pass +from torchrec.sparse.tensor_dict import maybe_td_to_kjt @torch.fx.wrap @@ -237,6 +230,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor: KeyedTensor """ flat_feature_names: List[str] = [] + features = maybe_td_to_kjt(features, None) for names in self._feature_names: flat_feature_names.extend(names) inverse_indices = reorder_inverse_indices( diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py index 8468c9977..07278f852 100644 --- a/torchrec/sparse/jagged_tensor.py +++ b/torchrec/sparse/jagged_tensor.py @@ -49,11 +49,9 @@ # OSS try: - from tensordict import TensorDict + pass except ImportError: - - class TensorDict: - pass + pass logger: logging.Logger = logging.getLogger() diff --git a/torchrec/sparse/tensor_dict.py b/torchrec/sparse/tensor_dict.py new file mode 100644 index 000000000..5eadebd1b --- /dev/null +++ b/torchrec/sparse/tensor_dict.py @@ -0,0 +1,43 @@ +#!/usr/bin/env python3 +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import List, Optional + +import torch +from tensordict import TensorDict + +from torchrec.sparse.jagged_tensor import KeyedJaggedTensor + + +def maybe_td_to_kjt( + features: KeyedJaggedTensor, keys: Optional[List[str]] = None +) -> KeyedJaggedTensor: + if torch.jit.is_scripting(): + assert isinstance(features, KeyedJaggedTensor) + return features + if isinstance(features, TensorDict): + if keys is None: + keys = list(features.keys()) + values = torch.cat([features[key]._values for key in keys], dim=0) + lengths = torch.cat( + [ + ( + (features[key]._lengths) + if features[key]._lengths is not None + else torch.diff(features[key]._offsets) + ) + for key in keys + ], + dim=0, + ) + return KeyedJaggedTensor( + keys=keys, + values=values, + lengths=lengths, + ) + else: + return features