diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py index 060c2d224..26168aced 100644 --- a/torchrec/ir/tests/test_serializer.py +++ b/torchrec/ir/tests/test_serializer.py @@ -18,6 +18,7 @@ from torchrec.ir.serializer import JsonSerializer from torchrec.ir.utils import ( + # bypass_pytree_ebc_regroup, decapsulate_ir_modules, encapsulate_ir_modules, mark_dynamic_kjt, @@ -521,3 +522,69 @@ def forward( self.assertTrue(deserialized_model.regroup._is_inited) for key in eager_out.keys(): self.assertEqual(deserialized_out[key].shape, eager_out[key].shape) + + def test_key_order_with_ebc_and_regroup(self) -> None: + tb1_config = EmbeddingBagConfig( + name="t1", + embedding_dim=3, + num_embeddings=10, + feature_names=["f1"], + ) + tb2_config = EmbeddingBagConfig( + name="t2", + embedding_dim=4, + num_embeddings=10, + feature_names=["f2"], + ) + tb3_config = EmbeddingBagConfig( + name="t3", + embedding_dim=5, + num_embeddings=10, + feature_names=["f3"], + ) + id_list_features = KeyedJaggedTensor.from_offsets_sync( + keys=["f1", "f2", "f3", "f4", "f5"], + values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 7, 8, 9, 1, 1, 2]), + offsets=torch.tensor([0, 2, 2, 3, 4, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 15]), + ) + ebc1 = EmbeddingBagCollection( + tables=[tb1_config, tb2_config, tb3_config], + is_weighted=False, + ) + ebc2 = EmbeddingBagCollection( + tables=[tb1_config, tb3_config, tb2_config], + is_weighted=False, + ) + ebc2.load_state_dict(ebc1.state_dict()) + regroup = KTRegroupAsDict([["f1", "f3"], ["f2"]], ["odd", "even"]) + class myModel(nn.Module): + def __init__(self, ebc, regroup): + super().__init__() + self.ebc = ebc + self.regroup = regroup + + def forward( + self, + features: KeyedJaggedTensor, + ) -> Dict[str, torch.Tensor]: + return self.regroup([self.ebc(features)]) + + model = myModel(ebc1, regroup) + eager_out = model(id_list_features) + + model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer) + ep = torch.export.export( + model, + (id_list_features,), + {}, + strict=False, + # Allows KJT to not be unflattened and run a forward on unflattened EP + preserve_module_call_signature=(tuple(sparse_fqns)), + ) + unflatten_ep = torch.export.unflatten(ep) + deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer) + deserialized_model.ebc = ebc2 + # bypass_pytree_ebc_regroup(deserialized_model) + deserialized_out = deserialized_model(id_list_features) + for key in eager_out.keys(): + torch.testing.assert_close(deserialized_out[key], eager_out[key])