diff --git a/torchrec/distributed/embedding.py b/torchrec/distributed/embedding.py
index 5f16efc1b..a6db7095c 100644
--- a/torchrec/distributed/embedding.py
+++ b/torchrec/distributed/embedding.py
@@ -26,6 +26,7 @@
 )
 
 import torch
+from tensordict import TensorDict
 from torch import distributed as dist, nn
 from torch.autograd.profiler import record_function
 from torch.distributed._tensor import DTensor
@@ -88,7 +89,12 @@
 from torchrec.modules.utils import construct_jagged_tensors, SequenceVBEContext
 from torchrec.optim.fused import EmptyFusedOptimizer, FusedOptimizerModule
 from torchrec.optim.keyed import CombinedOptimizer, KeyedOptimizer
-from torchrec.sparse.jagged_tensor import _to_offsets, JaggedTensor, KeyedJaggedTensor
+from torchrec.sparse.jagged_tensor import (
+    _to_offsets,
+    JaggedTensor,
+    KeyedJaggedTensor,
+    td_to_kjt,
+)
 
 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -96,13 +102,6 @@
 except OSError:
     pass
 
-try:
-    from tensordict import TensorDict
-except ImportError:
-
-    class TensorDict:
-        pass
-
 
 logger: logging.Logger = logging.getLogger(__name__)
 
@@ -1146,25 +1145,49 @@ def _compute_sequence_vbe_context(
     def input_dist(
         self,
         ctx: EmbeddingCollectionContext,
-        features: KeyedJaggedTensor,
+        features: TypeUnion[KeyedJaggedTensor, TensorDict],
     ) -> Awaitable[Awaitable[KJTList]]:
+        feature_keys = list(features.keys())  # pyre-ignore[6]
         if self._has_uninitialized_input_dist:
-            self._create_input_dist(input_feature_names=features.keys())
+            self._create_input_dist(input_feature_names=feature_keys)
             self._has_uninitialized_input_dist = False
         with torch.no_grad():
            unpadded_features = None
-            if features.variable_stride_per_key():
+            if (
+                isinstance(features, KeyedJaggedTensor)
+                and features.variable_stride_per_key()
+            ):
                 unpadded_features = features
                 features = pad_vbe_kjt_lengths(unpadded_features)
 
-            if self._features_order:
+            if isinstance(features, KeyedJaggedTensor) and self._features_order:
                 features = features.permute(
                     self._features_order,
                     # pyre-fixme[6]: For 2nd argument expected `Optional[Tensor]`
                     #  but got `TypeUnion[Module, Tensor]`.
                     self._features_order_tensor,
                 )
-            features_by_shards = features.split(self._feature_splits)
+
+            if isinstance(features, KeyedJaggedTensor):
+                features_by_shards = features.split(self._feature_splits)
+            else:  # TensorDict
+                feature_names = (
+                    [feature_keys[i] for i in self._features_order]
+                    if self._features_order  # empty features_order means no reordering
+                    else feature_keys
+                )
+                feature_names = [name.split("@")[0] for name in feature_names]
+                feature_name_by_sharding_types: List[List[str]] = []
+                start = 0
+                for length in self._feature_splits:
+                    feature_name_by_sharding_types.append(
+                        feature_names[start : start + length]
+                    )
+                    start += length
+                features_by_shards = [
+                    td_to_kjt(features, names)
+                    for names in feature_name_by_sharding_types
+                ]
             if self._use_index_dedup:
                 features_by_shards = self._dedup_indices(ctx, features_by_shards)
 
diff --git a/torchrec/distributed/test_utils/test_sharding.py b/torchrec/distributed/test_utils/test_sharding.py
index dbd8f1007..f74c597f5 100644
--- a/torchrec/distributed/test_utils/test_sharding.py
+++ b/torchrec/distributed/test_utils/test_sharding.py
@@ -148,6 +148,7 @@ def gen_model_and_input(
     long_indices: bool = True,
     global_constant_batch: bool = False,
     num_inputs: int = 1,
+    input_type: str = "kjt",  # "kjt" or "td"
 ) -> Tuple[nn.Module, List[Tuple[ModelInput, List[ModelInput]]]]:
     torch.manual_seed(0)
     if dedup_feature_names:
@@ -178,9 +179,9 @@
         feature_processor_modules=feature_processor_modules,
     )
     inputs = []
-    for _ in range(num_inputs):
-        inputs.append(
-            (
+    if input_type == "kjt" and generate == ModelInput.generate_variable_batch_input:
+        for _ in range(num_inputs):
+            inputs.append(
                 cast(VariableBatchModelInputCallable, generate)(
                     average_batch_size=batch_size,
                     world_size=world_size,
@@ -189,8 +190,26 @@
                     weighted_tables=weighted_tables or [],
                     global_constant_batch=global_constant_batch,
                 )
-                if generate == ModelInput.generate_variable_batch_input
-                else cast(ModelInputCallable, generate)(
+            )
+    elif generate == ModelInput.generate:
+        for _ in range(num_inputs):
+            inputs.append(
+                ModelInput.generate(
+                    world_size=world_size,
+                    tables=tables,
+                    dedup_tables=dedup_tables,
+                    weighted_tables=weighted_tables or [],
+                    num_float_features=num_float_features,
+                    variable_batch_size=variable_batch_size,
+                    batch_size=batch_size,
+                    long_indices=long_indices,
+                    input_type=input_type,
+                )
+            )
+    else:
+        for _ in range(num_inputs):
+            inputs.append(
+                cast(ModelInputCallable, generate)(
                     world_size=world_size,
                     tables=tables,
                     dedup_tables=dedup_tables,
@@ -201,7 +220,6 @@
                     long_indices=long_indices,
                 )
             )
-        )
     return (model, inputs)
 
 
@@ -287,6 +305,7 @@ def sharding_single_rank_test(
     global_constant_batch: bool = False,
     world_size_2D: Optional[int] = None,
     node_group_size: Optional[int] = None,
+    input_type: str = "kjt",  # "kjt" or "td"
 ) -> None:
 
     with MultiProcessContext(rank, world_size, backend, local_size) as ctx:
@@ -310,6 +329,7 @@
             batch_size=batch_size,
             feature_processor_modules=feature_processor_modules,
             global_constant_batch=global_constant_batch,
+            input_type=input_type,
         )
         global_model = global_model.to(ctx.device)
         global_input = inputs[0][0].to(ctx.device)
diff --git a/torchrec/distributed/tests/test_sequence_model_parallel.py b/torchrec/distributed/tests/test_sequence_model_parallel.py
index aec092354..d13d819c3 100644
--- a/torchrec/distributed/tests/test_sequence_model_parallel.py
+++ b/torchrec/distributed/tests/test_sequence_model_parallel.py
@@ -376,3 +376,44 @@ def _test_sharding(
             variable_batch_per_feature=variable_batch_per_feature,
             global_constant_batch=True,
         )
+
+
+@skip_if_asan_class
+class TDSequenceModelParallelTest(SequenceModelParallelTest):
+
+    def test_sharding_variable_batch(self) -> None:
+        pass
+
+    def _test_sharding(
+        self,
+        sharders: List[TestEmbeddingCollectionSharder],
+        backend: str = "gloo",
+        world_size: int = 2,
+        local_size: Optional[int] = None,
+        constraints: Optional[Dict[str, ParameterConstraints]] = None,
+        model_class: Type[TestSparseNNBase] = TestSequenceSparseNN,
+        qcomms_config: Optional[QCommsConfig] = None,
+        apply_optimizer_in_backward_config: Optional[
+            Dict[str, Tuple[Type[torch.optim.Optimizer], Dict[str, Any]]]
+        ] = None,
+        variable_batch_size: bool = False,
+        variable_batch_per_feature: bool = False,
+    ) -> None:
+        self._run_multi_process_test(
+            callable=sharding_single_rank_test,
+            world_size=world_size,
+            local_size=local_size,
+            model_class=model_class,
+            tables=self.tables,
+            embedding_groups=self.embedding_groups,
+            sharders=sharders,
+            optim=EmbOptimType.EXACT_SGD,
+            backend=backend,
+            constraints=constraints,
+            qcomms_config=qcomms_config,
+            apply_optimizer_in_backward_config=apply_optimizer_in_backward_config,
+            variable_batch_size=variable_batch_size,
+            variable_batch_per_feature=variable_batch_per_feature,
+            global_constant_batch=True,
+            input_type="td",
+        )
diff --git a/torchrec/modules/embedding_modules.py b/torchrec/modules/embedding_modules.py
index b22e7492f..b84675d3e 100644
--- a/torchrec/modules/embedding_modules.py
+++ b/torchrec/modules/embedding_modules.py
@@ -456,7 +456,7 @@ def __init__(  # noqa C901
 
     def forward(
         self,
-        features: KeyedJaggedTensor,
+        features: Union[KeyedJaggedTensor, TensorDict],
     ) -> Dict[str, JaggedTensor]:
         """
         Run the EmbeddingBagCollection forward pass. This method takes in a `KeyedJaggedTensor`
@@ -470,7 +470,10 @@ def forward(
         """
 
         feature_embeddings: Dict[str, JaggedTensor] = {}
-        jt_dict: Dict[str, JaggedTensor] = features.to_dict()
+        if isinstance(features, KeyedJaggedTensor):
+            jt_dict: Dict[str, JaggedTensor] = features.to_dict()
+        else:
+            jt_dict = features
         for i, emb_module in enumerate(self.embeddings.values()):
             feature_names = self._feature_names[i]
             embedding_names = self._embedding_names_by_table[i]
@@ -483,6 +486,7 @@ def forward(
                 feature_embeddings[embedding_name] = JaggedTensor(
                     values=lookup,
                     lengths=f.lengths(),
+                    offsets=f.offsets() if isinstance(features, TensorDict) else None,
                     weights=f.values() if self._need_indices else None,
                 )
         return feature_embeddings
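
Review note, not part of the patch: the sketch below restates the feature-name regrouping that the new TensorDict branch of `input_dist` performs before each group of names is handed to `td_to_kjt`. The standalone helper name `group_feature_names` is hypothetical and exists only for illustration; in the patch the same logic runs inline against `self._features_order` and `self._feature_splits`.

```python
from typing import List


def group_feature_names(
    feature_keys: List[str],
    features_order: List[int],
    feature_splits: List[int],
) -> List[List[str]]:
    """Illustration of the TensorDict branch in input_dist (not the patch itself)."""
    # Reorder keys when a features_order was computed (empty means no reordering),
    # strip any "@..." suffix, then split into one name list per sharding type.
    names = (
        [feature_keys[i] for i in features_order] if features_order else feature_keys
    )
    names = [name.split("@")[0] for name in names]
    groups: List[List[str]] = []
    start = 0
    for length in feature_splits:
        groups.append(names[start : start + length])
        start += length
    return groups


# Three features, no reordering, split 2 + 1 across two sharding types;
# each resulting group would then be converted with td_to_kjt(features, names).
assert group_feature_names(["f0", "f1@suffix", "f2"], [], [2, 1]) == [
    ["f0", "f1"],
    ["f2"],
]
```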