Revert D55277833
Summary:
This diff reverts D55277833
(The context such as a Sandcastle job, Task, SEV, etc. was not provided.)

Reviewed By: seanx92

Differential Revision: D62156250
Dark Knight authored and facebook-github-bot committed Sep 3, 2024
1 parent c43184e commit b7a5412
Showing 2 changed files with 5 additions and 54 deletions.
53 changes: 0 additions & 53 deletions torchrec/ir/tests/test_serializer.py
@@ -352,59 +352,6 @@ def test_deserialized_device(self) -> None:
                 continue
             assert param.device.type == device.type, f"{name} should be on {device}"
 
-    # pyre-ignore
-    @unittest.skipIf(
-        torch.cuda.device_count() <= 0,
-        "this test needs a GPU machine to run",
-    )
-    def test_deserialize_device_kt_regroup(self) -> None:
-        class Model(nn.Module):
-            def __init__(self, ebc):
-                super().__init__()
-                self.ebc = ebc
-
-            def forward(
-                self,
-                features: KeyedJaggedTensor,
-            ) -> List[torch.Tensor]:
-                kt = self.ebc(features)
-                return KeyedTensor.regroup([kt], [[key] for key in kt.keys()])
-
-        model = self.generate_model()
-        model = Model(model.ebc1)
-        id_list_features = KeyedJaggedTensor.from_offsets_sync(
-            keys=["f1", "f2", "f3"],
-            values=torch.tensor([0, 1, 2, 3, 2, 3]),
-            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 6]),
-        )
-        eager_out = model(id_list_features)
-
-        # Serialize EBC
-        model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
-        ep = torch.export.export(
-            model,
-            (id_list_features,),
-            {},
-            strict=False,
-            # Allows KJT to not be unflattened and run a forward on unflattened EP
-            preserve_module_call_signature=(tuple(sparse_fqns)),
-        )
-        unflatten_model = torch.export.unflatten(ep)
-        deserialized_model = decapsulate_ir_modules(
-            unflatten_model, JsonSerializer, torch.device("cuda")
-        )
-        device = torch.device("cuda")
-        deserialized_model.to(device)
-        id_list_features = id_list_features.to(device)
-
-        deserialized_model.load_state_dict(model.state_dict())
-        # Run forward on deserialized model
-        deserialized_out = deserialized_model(id_list_features)
-
-        for i, tensor in enumerate(deserialized_out):
-            assert eager_out[i].shape == tensor.shape
-            assert torch.allclose(eager_out[i].to(tensor), tensor)
-
     def test_compound_module(self) -> None:
         tb1_config = EmbeddingBagConfig(
             name="t1",
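For context, the KeyedTensor.regroup call at the heart of the removed test splits a pooled KeyedTensor back into one tensor per key. A minimal sketch of that behavior in isolation, without the EBC and torch.export machinery (shapes and values here are illustrative, not taken from the test):

import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

# Pooled embeddings for three keys, 2 dims each, concatenated along dim 1.
kt = KeyedTensor(
    keys=["f1", "f2", "f3"],
    length_per_key=[2, 2, 2],
    values=torch.randn(4, 6),  # batch_size 4 x (3 keys * 2 dims)
)

# One group per key, as in the removed test: yields three (4, 2) tensors.
out = KeyedTensor.regroup([kt], [[key] for key in kt.keys()])
assert all(t.shape == (4, 2) for t in out)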
6 changes: 5 additions & 1 deletion torchrec/sparse/jagged_tensor.py
@@ -2841,7 +2841,11 @@ def to_dict(self) -> Dict[str, torch.Tensor]:
     def regroup(
         keyed_tensors: List["KeyedTensor"], groups: List[List[str]]
     ) -> List[torch.Tensor]:
-        return permute_multi_embedding(keyed_tensors, groups)
+        # Fast path, one-to-one correspondence between keyed_tensors and groups
+        if _all_keys_used_once(keyed_tensors, groups) is True:
+            return _fbgemm_permute_pooled_embs(keyed_tensors, groups)
+        else:  # Fallback to slow path otherwise
+            return _regroup_keyed_tensors(keyed_tensors, groups)
 
     @staticmethod
     def regroup_as_dict(
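The restored dispatch above picks between two implementations: when every key across the input KeyedTensors is consumed by exactly one group, the fbgemm-backed permute path applies; otherwise regroup falls back to a pure-PyTorch implementation. A sketch of when each branch fires, with illustrative values:

import torch
from torchrec.sparse.jagged_tensor import KeyedTensor

kt = KeyedTensor(
    keys=["f1", "f2"],
    length_per_key=[3, 3],
    values=torch.randn(2, 6),
)

# Every key used exactly once -> _all_keys_used_once is True,
# so regroup takes the fbgemm-backed fast path.
fast = KeyedTensor.regroup([kt], [["f1"], ["f2"]])

# "f1" appears in two groups -> _all_keys_used_once is False,
# so regroup falls back to _regroup_keyed_tensors.
slow = KeyedTensor.regroup([kt], [["f1", "f2"], ["f1"]])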
