diff --git a/torchrec/ir/tests/test_serializer.py b/torchrec/ir/tests/test_serializer.py
index 7ff2adab2..060c2d224 100644
--- a/torchrec/ir/tests/test_serializer.py
+++ b/torchrec/ir/tests/test_serializer.py
@@ -352,59 +352,6 @@ def test_deserialized_device(self) -> None:
                 continue
             assert param.device.type == device.type, f"{name} should be on {device}"

-    # pyre-ignore
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 0,
-        "this test needs a GPU machine to run",
-    )
-    def test_deserialize_device_kt_regroup(self) -> None:
-        class Model(nn.Module):
-            def __init__(self, ebc):
-                super().__init__()
-                self.ebc = ebc
-
-            def forward(
-                self,
-                features: KeyedJaggedTensor,
-            ) -> List[torch.Tensor]:
-                kt = self.ebc(features)
-                return KeyedTensor.regroup([kt], [[key] for key in kt.keys()])
-
-        model = self.generate_model()
-        model = Model(model.ebc1)
-        id_list_features = KeyedJaggedTensor.from_offsets_sync(
-            keys=["f1", "f2", "f3"],
-            values=torch.tensor([0, 1, 2, 3, 2, 3]),
-            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
-        )
-        eager_out = model(id_list_features)
-
-        # Serialize EBC
-        model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
-        ep = torch.export.export(
-            model,
-            (id_list_features,),
-            {},
-            strict=False,
-            # Allows KJT to not be unflattened and run a forward on unflattened EP
-            preserve_module_call_signature=(tuple(sparse_fqns)),
-        )
-        unflatten_model = torch.export.unflatten(ep)
-        deserialized_model = decapsulate_ir_modules(
-            unflatten_model, JsonSerializer, torch.device("cuda")
-        )
-        device = torch.device("cuda")
-        deserialized_model.to(device)
-        id_list_features = id_list_features.to(device)
-
-        deserialized_model.load_state_dict(model.state_dict())
-        # Run forward on deserialized model
-        deserialized_out = deserialized_model(id_list_features)
-
-        for i, tensor in enumerate(deserialized_out):
-            assert eager_out[i].shape == tensor.shape
-            assert torch.allclose(eager_out[i].to(tensor), tensor)
-
     def test_compound_module(self) -> None:
         tb1_config = EmbeddingBagConfig(
             name="t1",
diff --git a/torchrec/sparse/jagged_tensor.py b/torchrec/sparse/jagged_tensor.py
index 10e004897..5a988ee13 100644
--- a/torchrec/sparse/jagged_tensor.py
+++ b/torchrec/sparse/jagged_tensor.py
@@ -2841,7 +2841,11 @@ def to_dict(self) -> Dict[str, torch.Tensor]:
     def regroup(
         keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
    ) -> List[torch.Tensor]:
-        return permute_multi_embedding(keyed_tensors, groups)
+        # Fast path: one-to-one correspondence between keyed_tensors and groups
+        if _all_keys_used_once(keyed_tensors, groups) is True:
+            return _fbgemm_permute_pooled_embs(keyed_tensors, groups)
+        else:  # fall back to the slow path
+            return _regroup_keyed_tensors(keyed_tensors, groups)

     @staticmethod
     def regroup_as_dict(
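
For reference, here is a minimal usage sketch of the dispatch above. It is not part of the patch; it assumes `torchrec` (with the fbgemm ops) is installed, and the key names, dimensions, and values are made up for illustration. When every key appears in exactly one output group, as here, the `_all_keys_used_once` check holds and `regroup` should take the fbgemm fast path; duplicating or dropping a key would route the call to `_regroup_keyed_tensors` instead.

```python
import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

# Two pooled-embedding KeyedTensors with disjoint keys (batch_size=2).
kt1 = KeyedTensor(
    keys=["f1", "f2"],
    length_per_key=[4, 4],     # embedding dim per key
    values=torch.randn(2, 8),  # [batch, 4 + 4]
)
kt2 = KeyedTensor(
    keys=["f3"],
    length_per_key=[4],
    values=torch.randn(2, 4),
)

# Each key is used exactly once across the groups, so this call should hit
# the fast path in the diff above; a duplicated key (e.g. "f1" in both
# groups) would fall back to _regroup_keyed_tensors.
dense, sparse = KeyedTensor.regroup([kt1, kt2], [["f1", "f3"], ["f2"]])
assert dense.shape == (2, 8)   # f1 (4) + f3 (4)
assert sparse.shape == (2, 4)  # f2 (4)
```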