diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index 9ffc74e0c..38e4671e8 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -27,6 +27,7 @@
 
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
 from torch.distributed._tensor import DTensor
@@ -90,6 +91,7 @@
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
 from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -98,13 +100,6 @@
 except OSError:
     pass
 
-try:
-    from tensordict import TensorDict
-except ImportError:
-
-    class TensorDict:
-        pass
-
 
 def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
     return (
@@ -659,9 +654,7 @@ def __init__(
         self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
         # to support mean pooling callback hook
         self._has_mean_pooling_callback: bool = (
-            True
-            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
-            else False
+            PoolingType.MEAN.value in self._pooling_type_to_rs_features
         )
         self._dim_per_key: Optional[torch.Tensor] = None
         self._kjt_key_indices: Dict[str, int] = {}
@@ -1151,26 +1144,37 @@ def _create_inverse_indices_permute_indices(
 
     # pyre-ignore [14]
     def input_dist(
-        self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
+        self,
+        ctx: EmbeddingBagCollectionContext,
+        features: Union[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
-        ctx.variable_batch_per_feature = features.variable_stride_per_key()
-        ctx.inverse_indices = features.inverse_indices_or_none()
+        if isinstance(features, KeyedJaggedTensor):
+            ctx.variable_batch_per_feature = features.variable_stride_per_key()
+            ctx.inverse_indices = features.inverse_indices_or_none()
+            feature_keys = features.keys()
+        else:  # features is TensorDict
+            ctx.variable_batch_per_feature = False  # TD does not support variable batch
+            ctx.inverse_indices = None
+            feature_keys = list(features.keys())  # pyre-ignore[6]
         if self._has_uninitialized_input_dist:
-            self._create_input_dist(features.keys())
+            self._create_input_dist(feature_keys)
             self._has_uninitialized_input_dist = False
             if ctx.variable_batch_per_feature:
                 self._create_inverse_indices_permute_indices(ctx.inverse_indices)
             if self._has_mean_pooling_callback:
-                self._init_mean_pooling_callback(features.keys(), ctx.inverse_indices)
+                self._init_mean_pooling_callback(feature_keys, ctx.inverse_indices)
         with torch.no_grad():
-            if self._has_features_permute:
+            if isinstance(features, KeyedJaggedTensor) and self._has_features_permute:
                 features = features.permute(
                     self._features_order,
                     # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
                     #  but got `Union[Module, Tensor]`.
                     self._features_order_tensor,
                 )
-            if self._has_mean_pooling_callback:
+            if (
+                isinstance(features, KeyedJaggedTensor)
+                and self._has_mean_pooling_callback
+            ):
                 ctx.divisor = _create_mean_pooling_divisor(
                     lengths=features.lengths(),
                     stride=features.stride(),
@@ -1189,9 +1193,24 @@ def input_dist(
                     weights=features.weights_or_none(),
                 )
 
-            features_by_shards = features.split(
-                self._feature_splits,
-            )
+            if isinstance(features, KeyedJaggedTensor):
+                features_by_shards = features.split(
+                    self._feature_splits,
+                )
+            else:
+                feature_names = [feature_keys[i] for i in self._features_order]
+                feature_name_by_sharding_types: List[List[str]] = []
+                start = 0
+                for length in self._feature_splits:
+                    feature_name_by_sharding_types.append(
+                        feature_names[start : start + length]
+                    )
+                    start += length
+                features_by_shards = [
+                    maybe_td_to_kjt(features, names)
+                    for names in feature_name_by_sharding_types
+                ]
+
             awaitables = []
             for input_dist, features_by_shard, sharding_type in zip(
                 self._input_dists,
diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
index e8dc5eccb..fdb900fe0 100644
--- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
+++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
@@ -160,7 +160,7 @@ def main(
     tables = [
         EmbeddingBagConfig(
-            num_embeddings=(i + 1) * 1000,
+            num_embeddings=max(i + 1, 100) * 1000,
             embedding_dim=dim_emb,
             name="table_" + str(i),
             feature_names=["feature_" + str(i)],
@@ -169,7 +169,7 @@
     weighted_tables = [
         EmbeddingBagConfig(
-            num_embeddings=(i + 1) * 1000,
+            num_embeddings=max(i + 1, 100) * 1000,
             embedding_dim=dim_emb,
             name="weighted_table_" + str(i),
             feature_names=["weighted_feature_" + str(i)],
diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index 9a1878361..4ade3df2f 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -19,14 +19,7 @@
     pooling_type_to_str,
 )
 from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
-
-
-try:
-    from tensordict import TensorDict
-except ImportError:
-
-    class TensorDict:
-        pass
+from torchrec.sparse.tensor_dict import maybe_td_to_kjt
 
 
 @torch.fx.wrap
@@ -237,6 +230,7 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
             KeyedTensor
         """
         flat_feature_names: List[str] = []
+        features = maybe_td_to_kjt(features, None)
         for names in self._feature_names:
             flat_feature_names.extend(names)
         inverse_indices = reorder_inverse_indices(
diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 8468c9977..07278f852 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -49,11 +49,9 @@
 
 # OSS
 try:
-    from tensordict import TensorDict
+    pass
 except ImportError:
-
-    class TensorDict:
-        pass
+    pass
 
 
 logger: logging.Logger = logging.getLogger()
diff --git a/torchrec/sparse/tensor_dict.py b/torchrec/sparse/tensor_dict.py
new file mode 100644
index 000000000..5eadebd1b
--- /dev/null
+++ b/torchrec/sparse/tensor_dict.py
@@ -0,0 +1,43 @@
+#!/usr/bin/env python3
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import List, Optional
+
+import torch
+from tensordict import TensorDict
+
+from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
+
+
+def maybe_td_to_kjt(
+    features: KeyedJaggedTensor, keys: Optional[List[str]] = None
+) -> KeyedJaggedTensor:
+    if torch.jit.is_scripting():
+        assert isinstance(features, KeyedJaggedTensor)
+        return features
+    if isinstance(features, TensorDict):
+        if keys is None:
+            keys = list(features.keys())
+        values = torch.cat([features[key]._values for key in keys], dim=0)
+        lengths = torch.cat(
+            [
+                (
+                    (features[key]._lengths)
+                    if features[key]._lengths is not None
+                    else torch.diff(features[key]._offsets)
+                )
+                for key in keys
+            ],
+            dim=0,
+        )
+        return KeyedJaggedTensor(
+            keys=keys,
+            values=values,
+            lengths=lengths,
+        )
+    else:
+        return features
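
For context, a minimal usage sketch of the new torchrec.sparse.tensor_dict.maybe_td_to_kjt helper (illustrative only, not part of the diff). It assumes each TensorDict entry is a jagged-layout NestedTensor, which is what the helper implies by reading the private _values / _lengths / _offsets fields of every entry; the feature names, sample values, and batch_size below are made up.

# Minimal usage sketch (illustrative, not part of this diff). Assumes each
# TensorDict entry is a jagged-layout NestedTensor, since maybe_td_to_kjt reads
# the private _values / _lengths / _offsets fields of every entry.
import torch
from tensordict import TensorDict

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor
from torchrec.sparse.tensor_dict import maybe_td_to_kjt

# Two features over a batch of two samples: "f1" has bags of sizes [1, 2],
# "f2" has bags of sizes [2, 1].
td = TensorDict(
    {
        "f1": torch.nested.nested_tensor_from_jagged(
            values=torch.tensor([10, 11, 12]),
            offsets=torch.tensor([0, 1, 3]),
        ),
        "f2": torch.nested.nested_tensor_from_jagged(
            values=torch.tensor([20, 21, 22]),
            offsets=torch.tensor([0, 2, 3]),
        ),
    },
    batch_size=[2],
)

kjt = maybe_td_to_kjt(td)  # keys=None -> take every key from the TensorDict
assert isinstance(kjt, KeyedJaggedTensor)
print(kjt.keys())     # ['f1', 'f2']
print(kjt.values())   # tensor([10, 11, 12, 20, 21, 22])
print(kjt.lengths())  # tensor([1, 2, 2, 1])

# A KeyedJaggedTensor is returned as-is, which is what lets call sites such as
# EmbeddingBagCollection.forward accept either input type without branching.
assert maybe_td_to_kjt(kjt) is kjt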