diff --git a/torchrec/distributed/embeddingbag.py b/torchrec/distributed/embeddingbag.py
index e24a695d9..d1a5ee51c 100644
--- a/torchrec/distributed/embeddingbag.py
+++ b/torchrec/distributed/embeddingbag.py
@@ -27,6 +27,7 @@
 import torch
 from fbgemm_gpu.permute_pooled_embedding_modules import PermutePooledEmbeddings
+from tensordict import TensorDict
 from torch import distributed as dist, nn, Tensor
 from torch.autograd.profiler import record_function
 from torch.distributed._tensor import DTensor
@@ -90,7 +91,12 @@
 )
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
-from torchrec.sparse.jagged_tensor import _to_offsets, KeyedJaggedTensor, KeyedTensor
+from torchrec.sparse.jagged_tensor import (
+    _to_offsets,
+    KeyedJaggedTensor,
+    KeyedTensor,
+    td_to_kjt,
+)

 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -99,13 +105,6 @@
 except OSError:
     pass

-try:
-    from tensordict import TensorDict
-except ImportError:
-
-    class TensorDict:
-        pass
-

 def _pin_and_move(tensor: torch.Tensor, device: torch.device) -> torch.Tensor:
     return (
@@ -662,9 +661,7 @@ def __init__(
         self._inverse_indices_permute_indices: Optional[torch.Tensor] = None
         # to support mean pooling callback hook
         self._has_mean_pooling_callback: bool = (
-            True
-            if PoolingType.MEAN.value in self._pooling_type_to_rs_features
-            else False
+            PoolingType.MEAN.value in self._pooling_type_to_rs_features
         )
         self._dim_per_key: Optional[torch.Tensor] = None
         self._kjt_key_indices: Dict[str, int] = {}
@@ -1171,26 +1168,37 @@ def _create_inverse_indices_permute_indices(

     # pyre-ignore [14]
     def input_dist(
-        self, ctx: EmbeddingBagCollectionContext, features: KeyedJaggedTensor
+        self,
+        ctx: EmbeddingBagCollectionContext,
+        features: Union[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
-        ctx.variable_batch_per_feature = features.variable_stride_per_key()
-        ctx.inverse_indices = features.inverse_indices_or_none()
+        if isinstance(features, KeyedJaggedTensor):
+            ctx.variable_batch_per_feature = features.variable_stride_per_key()
+            ctx.inverse_indices = features.inverse_indices_or_none()
+            feature_keys = features.keys()
+        else:  # features is TensorDict
+            ctx.variable_batch_per_feature = False  # TD does not support variable batch
+            ctx.inverse_indices = None
+            feature_keys = list(features.keys())  # pyre-ignore[6]
         if self._has_uninitialized_input_dist:
-            self._create_input_dist(features.keys())
+            self._create_input_dist(feature_keys)
             self._has_uninitialized_input_dist = False
             if ctx.variable_batch_per_feature:
                 self._create_inverse_indices_permute_indices(ctx.inverse_indices)
             if self._has_mean_pooling_callback:
-                self._init_mean_pooling_callback(features.keys(), ctx.inverse_indices)
+                self._init_mean_pooling_callback(feature_keys, ctx.inverse_indices)
         with torch.no_grad():
-            if self._has_features_permute:
+            if isinstance(features, KeyedJaggedTensor) and self._has_features_permute:
                 features = features.permute(
                     self._features_order,
                     # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
                     #  but got `Union[Module, Tensor]`.
                     self._features_order_tensor,
                 )
-            if self._has_mean_pooling_callback:
+            if (
+                isinstance(features, KeyedJaggedTensor)
+                and self._has_mean_pooling_callback
+            ):
                 ctx.divisor = _create_mean_pooling_divisor(
                     lengths=features.lengths(),
                     stride=features.stride(),
@@ -1209,9 +1217,30 @@ def input_dist(
                     weights=features.weights_or_none(),
                 )

-            features_by_shards = features.split(
-                self._feature_splits,
-            )
+            if isinstance(features, KeyedJaggedTensor):
+                features_by_shards = features.split(
+                    self._feature_splits,
+                )
+            else:
+                # _features_order is empty when the input keys already match the
+                # expected order, so fall back to feature_keys in that case.
+                feature_names = (
+                    [feature_keys[i] for i in self._features_order]
+                    if self._features_order
+                    else feature_keys
+                )
+                feature_name_by_sharding_types: List[List[str]] = []
+                start = 0
+                for length in self._feature_splits:
+                    feature_name_by_sharding_types.append(
+                        feature_names[start : start + length]
+                    )
+                    start += length
+                features_by_shards = [
+                    td_to_kjt(features, names)
+                    for names in feature_name_by_sharding_types
+                ]
+
             awaitables = []
             for input_dist, features_by_shard, sharding_type in zip(
                 self._input_dists,
diff --git a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
index e8dc5eccb..fdb900fe0 100644
--- a/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
+++ b/torchrec/distributed/train_pipeline/tests/pipeline_benchmarks.py
@@ -160,7 +160,7 @@ def main(
     tables = [
         EmbeddingBagConfig(
-            num_embeddings=(i + 1) * 1000,
+            num_embeddings=max(i + 1, 100) * 1000,
             embedding_dim=dim_emb,
             name="table_" + str(i),
             feature_names=["feature_" + str(i)],
@@ -169,7 +169,7 @@
     ]
     weighted_tables = [
         EmbeddingBagConfig(
-            num_embeddings=(i + 1) * 1000,
+            num_embeddings=max(i + 1, 100) * 1000,
             embedding_dim=dim_emb,
             name="weighted_table_" + str(i),
             feature_names=["weighted_feature_" + str(i)],
diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index 9a1878361..b22e7492f 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -12,21 +12,19 @@

 import torch
 import torch.nn as nn
+from tensordict import TensorDict
 from torchrec.modules.embedding_configs import (
     DataType,
     EmbeddingBagConfig,
     EmbeddingConfig,
     pooling_type_to_str,
 )
-from torchrec.sparse.jagged_tensor import JaggedTensor, KeyedJaggedTensor, KeyedTensor
-
-
-try:
-    from tensordict import TensorDict
-except ImportError:
-
-    class TensorDict:
-        pass
+from torchrec.sparse.jagged_tensor import (
+    JaggedTensor,
+    KeyedJaggedTensor,
+    KeyedTensor,
+    td_to_kjt,
+)


 @torch.fx.wrap
@@ -226,9 +224,10 @@ def __init__(
         self._feature_names: List[List[str]] = [table.feature_names for table in tables]
         self.reset_parameters()

-    def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
+    def forward(self, features: Union[KeyedJaggedTensor, TensorDict]) -> KeyedTensor:
         """
-        Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
-        and returns a `KeyedTensor`, which is the result of pooling the embeddings for each feature.
+        Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
+        (or a `TensorDict`, which is first converted to one) and returns a `KeyedTensor`,
+        which is the result of pooling the embeddings for each feature.
@@ -237,6 +236,8 @@ def forward(self, features: KeyedJaggedTensor) -> KeyedTensor:
             KeyedTensor
         """
         flat_feature_names: List[str] = []
+        if isinstance(features, TensorDict):
+            features = td_to_kjt(features)
         for names in self._feature_names:
             flat_feature_names.extend(names)
         inverse_indices = reorder_inverse_indices(
diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index fa28309e3..cf9db12fb 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -15,6 +15,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import torch
+from tensordict import TensorDict
 from torch.autograd.profiler import record_function
 from torch.fx._pytree import register_pytree_flatten_spec, TreeSpec
 from torch.utils._pytree import GetAttrKey, KeyEntry, register_pytree_node
@@ -49,11 +50,3 @@

-# OSS
-try:
-    from tensordict import TensorDict
-except ImportError:
-
-    class TensorDict:
-        pass
-

 logger: logging.Logger = logging.getLogger()
@@ -3027,6 +3026,33 @@ def dist_init(
     return kjt.sync()


+def td_to_kjt(td: TensorDict, keys: Optional[List[str]] = None) -> KeyedJaggedTensor:
+    """Converts a TensorDict of jagged-layout nested tensors into a KeyedJaggedTensor.
+
+    Values and lengths are concatenated in key order; entries that only carry
+    offsets have them converted to lengths via torch.diff.
+    """
+    if keys is None:
+        keys = list(td.keys())  # pyre-ignore[6]
+    values = torch.cat([td[key]._values for key in keys], dim=0)
+    lengths = torch.cat(
+        [
+            (
+                (td[key]._lengths)
+                if td[key]._lengths is not None
+                else torch.diff(td[key]._offsets)
+            )
+            for key in keys
+        ],
+        dim=0,
+    )
+    return KeyedJaggedTensor(
+        keys=keys,
+        values=values,
+        lengths=lengths,
+    )
+
+
 def _kjt_flatten(
     t: KeyedJaggedTensor,
 ) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
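
For reviewers, a minimal usage sketch of the new input path (not part of the patch). It assumes tensordict is installed and that each TensorDict entry is a jagged-layout torch.nested tensor, since td_to_kjt reads the private _values/_lengths/_offsets fields; the feature names, table sizes, and batch layout below are illustrative only.

    import torch
    from tensordict import TensorDict

    from torchrec import EmbeddingBagCollection, EmbeddingBagConfig
    from torchrec.sparse.jagged_tensor import td_to_kjt

    # Batch of 2 samples; each feature holds a variable-length ID list per sample.
    td = TensorDict(
        {
            "f1": torch.nested.nested_tensor(
                [torch.tensor([0, 1]), torch.tensor([2])], layout=torch.jagged
            ),
            "f2": torch.nested.nested_tensor(
                [torch.tensor([3]), torch.tensor([4, 5, 6])], layout=torch.jagged
            ),
        },
        batch_size=[2],
    )

    # Direct conversion: values/lengths are concatenated in key order.
    kjt = td_to_kjt(td)
    print(kjt.keys())     # ['f1', 'f2']
    print(kjt.lengths())  # tensor([2, 1, 1, 3])

    # The unsharded EBC forward now accepts the TensorDict as-is.
    ebc = EmbeddingBagCollection(
        tables=[
            EmbeddingBagConfig(
                name="t1", embedding_dim=8, num_embeddings=100, feature_names=["f1"]
            ),
            EmbeddingBagConfig(
                name="t2", embedding_dim=8, num_embeddings=100, feature_names=["f2"]
            ),
        ]
    )
    pooled = ebc(td)
    print(pooled["f1"].shape)  # torch.Size([2, 8])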