Commit e7b6360

jd7-tr authored and facebook-github-bot committed
Add missing fields to KJT's PyTree flatten/unflatten logic for VBE KJT (#2952)
Summary:

# Context

* Currently the torchrec IR serializer does not support exporting variable batch (VBE) KJTs, because the `stride_per_key_per_rank` and `inverse_indices` fields are needed for deserializing VBE KJTs but are not included in KJT's PyTree flatten/unflatten functions.
* This diff updates KJT's PyTree flatten/unflatten functions to include `stride_per_key_per_rank` and `inverse_indices`.

# Ref

Differential Revision: D74295924
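As a minimal sketch of the round trip this change enables, reusing the tensors from the test that previously had to be skipped. `torch.utils._pytree` is the same flatten path `torch.export` exercises; the `inverse_indices_or_none()` accessor is assumed from the KJT API.

```python
import torch
from torch.utils import _pytree as pytree
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# A variable-batch KJT: per-key strides differ, so inverse_indices is
# needed to recover the batch layout after deserialization.
vbe_kjt = KeyedJaggedTensor(
    keys=["f1", "f2"],
    values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
    lengths=torch.tensor([3, 3, 2]),
    stride_per_key_per_rank=torch.tensor([[2], [1]]),
    inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
)

# Before this diff, flattening dropped the last two fields, so the
# rebuilt KJT silently lost its VBE metadata; now it should round-trip.
leaves, spec = pytree.tree_flatten(vbe_kjt)
rebuilt = pytree.tree_unflatten(leaves, spec)
assert rebuilt.inverse_indices_or_none() is not None
```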
1 parent 8fa8123 commit e7b6360

2 files changed (+64, -40 lines)

torchrec/ir/tests/test_serializer.py

Lines changed: 45 additions & 35 deletions
@@ -207,8 +207,14 @@ def forward(
             num_embeddings=10,
             feature_names=["f2"],
         )
+        config3 = EmbeddingBagConfig(
+            name="t3",
+            embedding_dim=5,
+            num_embeddings=10,
+            feature_names=["f3"],
+        )
         ebc = EmbeddingBagCollection(
-            tables=[config1, config2],
+            tables=[config1, config2, config3],
             is_weighted=False,
         )
 

@@ -293,42 +299,60 @@ def test_serialize_deserialize_ebc(self) -> None:
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))
 
-    @unittest.skip("Adding test for demonstrating VBE KJT flattening issue for now.")
     def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
         model = self.generate_model_for_vbe_kjt()
-        id_list_features = KeyedJaggedTensor(
-            keys=["f1", "f2"],
-            values=torch.tensor([5, 6, 7, 1, 2, 3, 0, 1]),
-            lengths=torch.tensor([3, 3, 2]),
-            stride_per_key_per_rank=[[2], [1]],
-            inverse_indices=(["f1", "f2"], torch.tensor([[0, 1, 0], [0, 0, 0]])),
+        kjt_1 = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
+            stride_per_key_per_rank=torch.tensor([[3], [2], [1]]),
+            inverse_indices=(
+                ["f1", "f2", "f3"],
+                torch.tensor([[0, 1, 2], [0, 1, 0], [0, 0, 0]]),
+            ),
+        )
+        kjt_2 = KeyedJaggedTensor(
+            keys=["f1", "f2", "f3"],
+            values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]),
+            lengths=torch.tensor([1, 2, 3, 2, 1, 1]),
+            stride_per_key_per_rank=torch.tensor([[1], [2], [3]]),
+            inverse_indices=(
+                ["f1", "f2", "f3"],
+                torch.tensor([[0, 0, 0], [0, 1, 0], [0, 1, 2]]),
+            ),
         )
 
-        eager_out = model(id_list_features)
+        eager_out = model(kjt_1)
+        eager_out_2 = model(kjt_2)
 
         # Serialize EBC
         model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
         ep = torch.export.export(
             model,
-            (id_list_features,),
+            (kjt_1,),
             {},
             strict=False,
             # Allows KJT to not be unflattened and run a forward on unflattened EP
             preserve_module_call_signature=(tuple(sparse_fqns)),
         )
 
         # Run forward on ExportedProgram
-        ep_output = ep.module()(id_list_features)
+        ep_output = ep.module()(kjt_1)
+        ep_output_2 = ep.module()(kjt_2)
 
+        self.assertEqual(len(ep_output), len(kjt_1.keys()))
+        self.assertEqual(len(ep_output_2), len(kjt_2.keys()))
         for i, tensor in enumerate(ep_output):
-            self.assertEqual(eager_out[i].shape, tensor.shape)
+            self.assertEqual(eager_out[i].shape[1], tensor.shape[1])
+        for i, tensor in enumerate(ep_output_2):
+            self.assertEqual(eager_out_2[i].shape[1], tensor.shape[1])
 
         # Deserialize EBC
         unflatten_ep = torch.export.unflatten(ep)
         deserialized_model = decapsulate_ir_modules(unflatten_ep, JsonSerializer)
 
         # check EBC config
-        for i in range(5):
+        for i in range(1):
             ebc_name = f"ebc{i + 1}"
             self.assertIsInstance(
                 getattr(deserialized_model, ebc_name), EmbeddingBagCollection
@@ -343,36 +367,22 @@ def test_serialize_deserialize_ebc_with_vbe_kjt(self) -> None:
             self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
             self.assertEqual(deserialized.feature_names, orginal.feature_names)
 
-        # check FPEBC config
-        for i in range(2):
-            fpebc_name = f"fpebc{i + 1}"
-            assert isinstance(
-                getattr(deserialized_model, fpebc_name),
-                FeatureProcessedEmbeddingBagCollection,
-            )
-
-            for deserialized, orginal in zip(
-                getattr(
-                    deserialized_model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-                getattr(
-                    model, fpebc_name
-                )._embedding_bag_collection.embedding_bag_configs(),
-            ):
-                self.assertEqual(deserialized.name, orginal.name)
-                self.assertEqual(deserialized.embedding_dim, orginal.embedding_dim)
-                self.assertEqual(deserialized.num_embeddings, orginal.num_embeddings)
-                self.assertEqual(deserialized.feature_names, orginal.feature_names)
-
         # Run forward on deserialized model and compare the output
         deserialized_model.load_state_dict(model.state_dict())
-        deserialized_out = deserialized_model(id_list_features)
+        deserialized_out = deserialized_model(kjt_1)
 
         self.assertEqual(len(deserialized_out), len(eager_out))
         for deserialized, orginal in zip(deserialized_out, eager_out):
             self.assertEqual(deserialized.shape, orginal.shape)
             self.assertTrue(torch.allclose(deserialized, orginal))
 
+        deserialized_out_2 = deserialized_model(kjt_2)
+
+        self.assertEqual(len(deserialized_out_2), len(eager_out_2))
+        for deserialized, orginal in zip(deserialized_out_2, eager_out_2):
+            self.assertEqual(deserialized.shape, orginal.shape)
+            self.assertTrue(torch.allclose(deserialized, orginal))
+
     def test_dynamic_shape_ebc_disabled_in_oss_compatibility(self) -> None:
         model = self.generate_model()
         feature1 = KeyedJaggedTensor.from_offsets_sync(

torchrec/sparse/jagged_tensor.py

Lines changed: 19 additions & 5 deletions
@@ -1107,7 +1107,7 @@ def _maybe_compute_stride_kjt(
         if inverse_indices is not None and inverse_indices[1].numel() > 0:
             return inverse_indices[1].shape[-1]
 
-        stride = int(stride_per_key_per_rank.sum(dim=1).max().item())
+        stride = int(stride_per_key_per_rank.sum(dim=-1).max().item())
     elif offsets is not None and offsets.numel() > 0:
         stride = (offsets.numel() - 1) // len(keys)
     elif lengths is not None:
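A quick worked check of the `dim=-1` change above, using the per-key strides from the new test: summing over the last dimension totals each key's stride across ranks, and the max becomes the KJT stride. For this 2-D layout `dim=1` and `dim=-1` agree; the negative index just pins the reduction to the rank dimension.

```python
import torch

# kjt_1's strides from the test: 3 keys, 1 rank each.
stride_per_key_per_rank = torch.tensor([[3], [2], [1]])

# Sum across ranks per key -> [3, 2, 1]; the max becomes the stride.
stride = int(stride_per_key_per_rank.sum(dim=-1).max().item())
assert stride == 3
```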
@@ -1728,6 +1728,8 @@ class KeyedJaggedTensor(Pipelineable, metaclass=JaggedTensorMeta):
         "_weights",
         "_lengths",
         "_offsets",
+        "_stride_per_key_per_rank",
+        "_inverse_indices",
     ]
 
     def __init__(
@@ -3016,7 +3018,10 @@ def dist_init(
 def _kjt_flatten(
     t: KeyedJaggedTensor,
 ) -> Tuple[List[Optional[torch.Tensor]], List[str]]:
-    return [getattr(t, a) for a in KeyedJaggedTensor._fields], t._keys
+    values = [getattr(t, a) for a in KeyedJaggedTensor._fields[:-1]]
+    values.append(t._inverse_indices[1] if t._inverse_indices is not None else None)
+
+    return values, t._keys
 
 
 def _kjt_flatten_with_keys(
@@ -3030,15 +3035,24 @@ def _kjt_flatten_with_keys(
 
 
 def _kjt_unflatten(
-    values: List[Optional[torch.Tensor]], context: List[str]  # context is the _keys
+    values: List[Optional[torch.Tensor]],
+    context: List[str],  # context is the _keys
 ) -> KeyedJaggedTensor:
-    return KeyedJaggedTensor(context, *values)
+    return KeyedJaggedTensor(
+        context,
+        *values[:-2],
+        stride_per_key_per_rank=values[-2],
+        inverse_indices=(context, values[-1]) if values[-1] is not None else None,
+    )
 
 
 def _kjt_flatten_spec(
     t: KeyedJaggedTensor, spec: TreeSpec
 ) -> List[Optional[torch.Tensor]]:
-    return [getattr(t, a) for a in KeyedJaggedTensor._fields]
+    values = [getattr(t, a) for a in KeyedJaggedTensor._fields[:-1]]
+    values.append(t._inverse_indices[1] if t._inverse_indices is not None else None)
+
+    return values
 
 
 register_pytree_node(
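The `register_pytree_node(` call truncated above is where these functions are wired in. For illustration, a self-contained toy that follows the same pattern: the optional metadata tensor is packed as the trailing child and restored on unflatten. `ToyJagged` and its helpers are invented for this sketch (not torchrec API), and `register_pytree_node` is assumed to be the public name in recent PyTorch (older releases spell it `_register_pytree_node`).

```python
from typing import List, Optional, Tuple

import torch
from torch.utils._pytree import register_pytree_node, tree_flatten, tree_unflatten


class ToyJagged:
    def __init__(
        self,
        keys: List[str],
        values: torch.Tensor,
        inverse_indices: Optional[Tuple[List[str], torch.Tensor]] = None,
    ) -> None:
        self._keys = keys
        self._values = values
        self._inverse_indices = inverse_indices


def _toy_flatten(t: "ToyJagged"):
    # Trailing child carries the optional tensor (or None), as in _kjt_flatten.
    children = [t._values, t._inverse_indices[1] if t._inverse_indices else None]
    return children, t._keys  # keys travel as the static context


def _toy_unflatten(children, context):
    # Rebuild the optional (keys, tensor) pair, as in _kjt_unflatten.
    inv = (context, children[-1]) if children[-1] is not None else None
    return ToyJagged(context, children[0], inverse_indices=inv)


register_pytree_node(ToyJagged, _toy_flatten, _toy_unflatten)

t = ToyJagged(["f1"], torch.tensor([1, 2, 3]), (["f1"], torch.tensor([[0, 0]])))
leaves, spec = tree_flatten(t)
assert tree_unflatten(leaves, spec)._inverse_indices is not None
```

Losing the trailing `None` is harmless here because PyTorch's pytree treats `None` as a zero-leaf node, so a non-VBE container still flattens and unflattens cleanly.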
