Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
add NJT/TD support for EC (pytorch#2596)
Summary: # Documents * [TorchRec NJT Work Items](https://fburl.com/gdoc/gcqq6luv) * [KJT <> TensorDict](https://docs.google.com/document/d/1zqJL5AESnoKeIt5VZ6K1289fh_1QcSwu76yo0nB4Ecw/edit?tab=t.0#heading=h.bn9zwvg79) {F1949248817} # Context * Continued from previous D66465376, which adds NJT/TD support for EBC, this diff is for EC * As depicted above, we are extending TorchRec input data type from KJT (KeyedJaggedTensor) to TD (TensorDict) * Basically we can support TensorDict in both **eager mode** and **distributed (sharded) mode**: `Input (Union[KJT, TD]) ==> EC ==> Output (KT)` * In eager mode, we directly call `td_to_kjt` in the forward function to convert TD to KJT. * In distributed mode, we do the conversion inside the `ShardedEmbeddingCollection`, specifically in the `input_dist`, where the input sparse features are prepared (permuted) for the `KJTAllToAll` communication. * In the KJT scenario, the input KJT would be permuted (and partially duplicated in some cases), followed by the `KJTAllToAll` communication. While in the TD scenario, the input TD will directly be converted to the permuted KJT ready for the following `KJTAllToAll` communication. NOTE: This diff re-used a number of existing test framework/cases with minimal critical changes in the `EmbeddingCollection` and `shardedEmbeddingCollection`. Please see the follwoing verifications for NJT/TD correctness. # Verification - input with TensorDict * breakpoint at [sharding_single_rank_test](https://fburl.com/code/x74s13fd) * sharded model ``` (Pdb) local_model DistributedModelParallel( (_dmp_wrapped_module): DistributedDataParallel( (module): TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): ShardedEmbeddingCollection( (lookups): GroupedEmbeddingsLookup( (_emb_modules): ModuleList( (0): BatchedDenseEmbedding( (_emb_module): DenseTableBatchedEmbeddingBagsCodegen() ) ) ) (_input_dists): RwSparseFeaturesDist( (_dist): KJTAllToAll() ) (_output_dists): RwSequenceEmbeddingDist( (_dist): SequenceEmbeddingsAllToAll() ) (embeddings): ModuleDict( (table_0): Module() (table_1): Module() (table_2): Module() (table_3): Module() (table_4): Module() (table_5): Module() ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ) ) ``` * TD input ``` (Pdb) local_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([4, j5]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([4, j6]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([4, j7]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([4, j8]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895], device='cuda:0')) ``` * unsharded model ``` (Pdb) global_model TestSequenceSparseNN( (dense): TestDenseArch( (linear): Linear(in_features=16, out_features=8, bias=True) ) (sparse): TestSequenceSparseArch( (ec): EmbeddingCollection( (embeddings): ModuleDict( (table_0): Embedding(11, 16) (table_1): Embedding(22, 16) (table_2): Embedding(33, 16) (table_3): Embedding(44, 16) (table_4): Embedding(11, 16) (table_5): Embedding(22, 16) ) ) ) (over): TestSequenceOverArch( (linear): Linear(in_features=1928, out_features=16, bias=True) ) ) ``` * TD input ``` (Pdb) global_input ModelInput(float_features=tensor([[0.8893, 0.6990, 0.6512, 0.9617, 0.5531, 0.9029, 0.8455, 0.9288, 0.2433, 0.8901, 0.8849, 0.3849, 0.4535, 0.9318, 0.5002, 0.8056], [0.1978, 0.4822, 0.2907, 0.9947, 0.6707, 0.4246, 0.2294, 0.6623, 0.7146, 0.1914, 0.6517, 0.9449, 0.5650, 0.2358, 0.6787, 0.3671], [0.3964, 0.6190, 0.7695, 0.6526, 0.7095, 0.2790, 0.0581, 0.2470, 0.8315, 0.9374, 0.0215, 0.3572, 0.0516, 0.1447, 0.0811, 0.2678], [0.0475, 0.9740, 0.0039, 0.6126, 0.9783, 0.5080, 0.5583, 0.0703, 0.8320, 0.9837, 0.3936, 0.6329, 0.8229, 0.8486, 0.7715, 0.9617], [0.6807, 0.7970, 0.1164, 0.8487, 0.7730, 0.1654, 0.5599, 0.5923, 0.3909, 0.4720, 0.9423, 0.7868, 0.3710, 0.6075, 0.6849, 0.1366], [0.0246, 0.5967, 0.2838, 0.8114, 0.3761, 0.3963, 0.7792, 0.9119, 0.4026, 0.4769, 0.1477, 0.0923, 0.0723, 0.4416, 0.4560, 0.9548], [0.8666, 0.6254, 0.9162, 0.1954, 0.8466, 0.6498, 0.3412, 0.2098, 0.9786, 0.3349, 0.7625, 0.3615, 0.8880, 0.0751, 0.8417, 0.5380], [0.2857, 0.6871, 0.6694, 0.8206, 0.5142, 0.5641, 0.3780, 0.9441, 0.0964, 0.2007, 0.1148, 0.8054, 0.1520, 0.3742, 0.6364, 0.9797]], device='cuda:0'), idlist_features=TensorDict( fields={ feature_0: NestedTensor(shape=torch.Size([8, j1]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_1: NestedTensor(shape=torch.Size([8, j2]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_2: NestedTensor(shape=torch.Size([8, j3]), device=cuda:0, dtype=torch.int64, is_shared=True), feature_3: NestedTensor(shape=torch.Size([8, j4]), device=cuda:0, dtype=torch.int64, is_shared=True)}, batch_size=torch.Size([]), device=cuda:0, is_shared=True), idscore_features=None, label=tensor([0.2093, 0.6164, 0.1763, 0.1895, 0.3132, 0.2133, 0.4997, 0.0055], device='cuda:0')) ``` Differential Revision: D66521351
- Loading branch information