Add missing fields to KJT's PyTree flatten/unflatten logic for VBE KJT #2952
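
Variable batch size embeddings (VBE) give a KeyedJaggedTensor two fields beyond the base set, _stride_per_key_per_rank and _inverse_indices, and the existing pytree registration flattened neither, so a VBE KJT lost both on any flatten/unflatten round trip (for example under torch.export). A minimal sketch of the round trip in question, assuming torchrec's public KeyedJaggedTensor constructor as used in the test diff below, PyTorch's torch.utils._pytree helpers, and the inverse_indices() accessor:

import torch
import torch.utils._pytree as pytree

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# A VBE KJT: per-key batch sizes differ (3, 2, 1), so reconstruction
# needs both stride_per_key_per_rank and inverse_indices.
kjt = KeyedJaggedTensor(
    keys=["f1", "f2", "f3"],
    values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
    lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
    stride_per_key_per_rank=torch.tensor([[3], [2], [1]]),
    inverse_indices=(
        ["f1", "f2", "f3"],
        torch.tensor([[0, 1, 2], [0, 1, 0], [0, 0, 0]]),
    ),
)

# Round-trip through the pytree registration amended below. Before this
# change, flatten emitted only values/weights/lengths/offsets, so both
# VBE fields were silently dropped on unflatten.
leaves, spec = pytree.tree_flatten(kjt)
rebuilt = pytree.tree_unflatten(leaves, spec)
assert rebuilt.inverse_indices()[1].equal(kjt.inverse_indices()[1])

Keeping tensors out of the static context matters because torch.export treats the context as a constant and would otherwise specialize on it.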

Open · wants to merge 1 commit into base: main
torchrec/ir/tests/test_serializer.py (80 changes: 45 additions & 35 deletions)
@@ -207,8 +207,14 @@ def forward(
num_embeddings=10,
feature_names=["f2"],
)
config3 = EmbeddingBagConfig(
name="t3",
embedding_dim=5,
num_embeddings=10,
feature_names=["f3"],
)
ebc = EmbeddingBagCollection(
tables=[config1, config2],
tables=[config1, config2, config3],
is_weighted=False,
)

@@ -293,42 +299,60 @@ def test_serialize_deserialize_ebc(self) -> None:
self.assertEqual(deserialized.shape, orginal.shape)
self.assertTrue(torch.allclose(deserialized, orginal))

@unittest.skip("Adding test for demonstrating VBE KJT flattening issue for now.")
def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
model = self.generate_model_for_vbe_kjt()
id_list_features = KeyedJaggedTensor(
keys=["f1", "f2"],
values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
lengths=torch.tensor([3, 3, 2]),
stride_per_key_per_rank=[[2], [1]],
inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
kjt_1 = KeyedJaggedTensor(
keys=["f1", "f2", "f3"],
values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
stride_per_key_per_rank=torch.tensor([[3], [2], [1]]),
inverse_indices=(
["f1", "f2", "f3"],
torch.tensor([[0, 1, 2], [0, 1, 0], [0, 0, 0]]),
),
)
kjt_2 = KeyedJaggedTensor(
keys=["f1", "f2", "f3"],
values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
stride_per_key_per_rank=torch.tensor([[1], [2], [3]]),
inverse_indices=(
["f1", "f2", "f3"],
torch.tensor([[0, 0, 0], [0, 1, 0], [0, 1, 2]]),
),
)

eager_out = model(id_list_features)
eager_out = model(kjt_1)
eager_out_2 = model(kjt_2)

# Serialize EBC
model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
ep = torch.export.export(
model,
(id_list_features,),
(kjt_1,),
{},
strict=False,
# Allows KJT to not be unflattened and run a forward on unflattened EP
preserve_module_call_signature=(tuple(sparse_fqns)),
)

# Run forward on ExportedProgram
ep_output = ep.module()(id_list_features)
ep_output = ep.module()(kjt_1)
ep_output_2 = ep.module()(kjt_2)

self.assertEqual(len(ep_output), len(kjt_1.keys()))
self.assertEqual(len(ep_output_2), len(kjt_2.keys()))
for i, tensor in enumerate(ep_output):
self.assertEqual(eager_out[i].shape, tensor.shape)
self.assertEqual(eager_out[i].shape[1], tensor.shape[1])
for i, tensor in enumerate(ep_output_2):
self.assertEqual(eager_out_2[i].shape[1], tensor.shape[1])

# Deserialize EBC
unflatten_ep = torch.export.unflatten(ep)
deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)

# check EBC config
for i in range(5):
for i in range(1):
ebc_name = f"ebc{i + 1}"
self.assertIsInstance(
getattr(deserialized_model, ebc_name), EmbeddingBagCollection
@@ -343,36 +367,22 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
self.assertEqual(deserialized.feature_names, orginal.feature_names)

# check FPEBC config
for i in range(2):
fpebc_name = f"fpebc{i + 1}"
assert isinstance(
getattr(deserialized_model, fpebc_name),
FeatureProcessedEmbeddingBagCollection,
)

for deserialized, orginal in zip(
getattr(
deserialized_model, fpebc_name
)._embedding_bag_collection.embedding_bag_configs(),
getattr(
model, fpebc_name
)._embedding_bag_collection.embedding_bag_configs(),
):
self.assertEqual(deserialized.name, orginal.name)
self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
self.assertEqual(deserialized.feature_names, orginal.feature_names)

# Run forward on deserialized model and compare the output
deserialized_model.load_state_dict(model.state_dict())
deserialized_out = deserialized_model(id_list_features)
deserialized_out = deserialized_model(kjt_1)

self.assertEqual(len(deserialized_out), len(eager_out))
for deserialized, orginal in zip(deserialized_out, eager_out):
self.assertEqual(deserialized.shape, orginal.shape)
self.assertTrue(torch.allclose(deserialized, orginal))

deserialized_out_2 = deserialized_model(kjt_2)

self.assertEqual(len(deserialized_out_2), len(eager_out_2))
for deserialized, orginal in zip(deserialized_out_2, eager_out_2):
self.assertEqual(deserialized.shape, orginal.shape)
self.assertTrue(torch.allclose(deserialized, orginal))

def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
model = self.generate_model()
feature1 = KeyedJaggedTensor.from_offsets_sync(
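The pair of inputs in the test above is deliberately mirrored: kjt_1 uses per-key strides [[3], [2], [1]] while kjt_2 uses [[1], [2], [3]], so an exported program that had specialized on the first input's per-key batch sizes could not reproduce the eager shapes on the second. For orientation, a small illustration of what the inverse indices encode, namely how to expand each key's deduplicated batch back to the full batch; the row contents below are hypothetical, only the indexing pattern is taken from kjt_1:

import torch

# kjt_1's entry for f3 is [0, 0, 0]: one deduplicated row broadcast to
# the full batch of 3. (f1's [0, 1, 2] is the identity: 3 distinct rows.)
dedup_rows = torch.tensor([[9.0, 0.0]])  # a single hypothetical row for f3
inverse = torch.tensor([0, 0, 0])        # kjt_1's inverse indices for f3
full_batch = dedup_rows[inverse]         # expands to 3 identical rows
assert full_batch.shape == (3, 2)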
torchrec/sparse/jagged_tensor.py (38 changes: 34 additions & 4 deletions)
@@ -1728,6 +1728,8 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
"_weights",
"_lengths",
"_offsets",
"_stride_per_key_per_rank",
"_inverse_indices",
]

def __init__(
@@ -3021,7 +3023,26 @@ def dist_init(
def _kjt_flatten(
t: KeyedJaggedTensor,
) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
"""
Used by PyTorch's pytree utilities for serialization and processing.
Extracts tensor attributes of a KeyedJaggedTensor and returns them
as a flat list, along with the necessary metadata to reconstruct the KeyedJaggedTensor.

Component tensors are returned as dynamic attributes.
KJT metadata are added as static specs.

Returns:
Tuple containing:
- List[Optional[torch.Tensor]]: All tensor attributes (_values, _weights, _lengths,
_offsets, _stride_per_key_per_rank, and the tensor part of _inverse_indices if present)
- Tuple[List[str], List[str]]: Metadata needed for reconstruction:
- List of keys from the original KeyedJaggedTensor
- List of inverse indices keys (if present, otherwise empty list)
"""
values = [getattr(t, a) for a in KeyedJaggedTensor._fields[:-1]]
values.append(t._inverse_indices[1] if t._inverse_indices is not None else None)

return values, t._keys


def _kjt_flatten_with_keys(
@@ -3035,15 +3056,24 @@ def _kjt_flatten_with_keys(


def _kjt_unflatten(
values: List[Optional[torch.Tensor]], context: List[str] # context is the _keys
values: List[Optional[torch.Tensor]],
context: List[str], # context is _keys
) -> KeyedJaggedTensor:
return KeyedJaggedTensor(context, *values)
return KeyedJaggedTensor(
context,
*values[:-2],
stride_per_key_per_rank=values[-2],
inverse_indices=(context, values[-1]) if values[-1] is not None else None,
)


def _kjt_flatten_spec(
t: KeyedJaggedTensor, spec: TreeSpec
) -> List[Optional[torch.Tensor]]:
return [getattr(t, a) for a in KeyedJaggedTensor._fields]
values = [getattr(t, a) for a in KeyedJaggedTensor._fields[:-1]]
values.append(t._inverse_indices[1] if t._inverse_indices is not None else None)

return values


register_pytree_node(
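With these changes, flattening emits the five tensor fields plus the tensor half of _inverse_indices as dynamic leaves, while the static context carries only the keys, which _kjt_unflatten then reuses as the inverse-indices keys. A short sketch of the resulting contract, reusing the kjt constructed in the first sketch above:

import torch.utils._pytree as pytree

# Dynamic leaves, in field order: _values, _weights, _lengths, _offsets,
# _stride_per_key_per_rank, and the inverse-indices tensor; the keys
# travel only in the static spec.
leaves, spec = pytree.tree_flatten(kjt)  # kjt from the first sketch
rebuilt = pytree.tree_unflatten(leaves, spec)
assert rebuilt.keys() == kjt.keys()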