how to mark KJT offsets as dynamic?

Summary:
# context
1. KJT contains three necessary tensors: `_values`, `_lengths`, `_offsets`
**a.** the shape of `_values` is independent of the other two (it is just the total number of values)
**b.** dim(`_lengths`) = dim(`batch_size`) * const(`len(kjt.keys())`)
**c.** dim(`_offsets`) = dim(`_lengths`) + 1
2. `_lengths` and `_offsets` can be computed from each other, so a KJT usually stores only one of them in memory and computes the other when needed (a concrete example of these shape relations follows this list)
3. previously only `_values` was marked with a dynamic shape, because `batch_size` and `len(kjt.keys())` are constant across iterations
4. however, when we declare a KJT with both `_values` and `_offsets` marked as dynamic shapes, it fails to pass the export function
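
A minimal sketch (toy values, not taken from the test below) illustrating the shape relations in point 1, using the standard `KeyedJaggedTensor` constructor and accessors:
```
import torch

from torchrec.sparse.jagged_tensor import KeyedJaggedTensor

# 2 keys and batch_size = 3 -> _lengths has 2 * 3 = 6 entries, _offsets has 6 + 1 = 7
kjt = KeyedJaggedTensor.from_lengths_sync(
    keys=["f1", "f2"],
    values=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8]),  # shape depends only on the total number of values
    lengths=torch.tensor([2, 0, 1, 1, 3, 1]),       # batch_size * len(keys) entries
)
assert kjt.lengths().numel() == 3 * len(kjt.keys())        # dim(_lengths) = batch_size * len(kjt.keys())
assert kjt.offsets().numel() == kjt.lengths().numel() + 1  # dim(_offsets) = dim(_lengths) + 1
```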

# notes
1. the `feature2` in the test has **NO** impact on the failure, because the export errors out before `feature2` is used
2. the error is purely due to the change that marks `_offsets` as dynamic.

# investigation
* the dynamic shape of `_offsets` is set to `3 * batch_size + 1`, as shown in the dynamic-shape spec below:
```
{'features': [(<class 'torchrec.ir.utils.vlen1'>,), None, None, (<class 'torch.export.dynamic_shapes.3*batch_size1 + 1'>,)]}
```
* dynamic shape `s1` is created for `_offsets`, and dynamic shape `s2` is created for `batch_size`
* why is there no relation `s1 == 3 * batch_size + 1`? (a standalone derived-dim sketch follows the log below)
```
0702 09:50:39.181000 140316068409792 torch/fx/experimental/symbolic_shapes.py:3575] create_symbol s1 = 7 for L['args'][0][0]._offsets.size()[0] [2, 12884901886] (_export/non_strict_utils.py:93 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1"
V0702 09:50:39.183000 140316068409792 torch/fx/experimental/symbolic_shapes.py:5189] eval False == False [statically known]
I0702 09:50:39.190000 140316068409792 torch/fx/experimental/symbolic_shapes.py:3575] create_symbol s2 = 2 for batch_size1 [2, 4294967295] (export/dynamic_shapes.py:569 in _process_equalities), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2"
V0702 09:50:39.267000 140316068409792 torch/fx/experimental/symbolic_shapes.py:5189] eval ((s1 - 1)//3) >= 0 == True [statically known]
I0702 09:50:39.273000 140316068409792 torch/fx/experimental/symbolic_shapes.py:5104] eval Ne(((s1 - 1)//3), 0) [guard added] (_subclasses/functional_tensor.py:134 in __new__), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Ne(((s1 - 1)//3), 0)"
V0702 09:50:39.322000 140316068409792 torch/fx/experimental/symbolic_shapes.py:4736] _update_var_to_range s1 = VR[7, 7] (update)
I0702 09:50:39.330000 140316068409792 torch/fx/experimental/symbolic_shapes.py:4855] set_replacement s1 = 7 (range_refined_to_singleton) VR[7, 7]
```
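
The `3*batch_size1 + 1` entry above comes from torch.export's derived-dim mechanism. For reference, a minimal standalone sketch of that mechanism (toy module and tensors, unrelated to the torchrec test; assuming derived `Dim` expressions of the form `a*dim + b`):
```
import torch
from torch.export import Dim, export


class ToyModule(torch.nn.Module):
    def forward(self, lengths: torch.Tensor, offsets: torch.Tensor) -> torch.Tensor:
        return offsets[-1] + lengths.sum()


batch = Dim("batch_size", min=2)
dynamic_shapes = {
    "lengths": {0: 3 * batch},      # dim(lengths) = batch_size * 3 keys
    "offsets": {0: 3 * batch + 1},  # dim(offsets) = dim(lengths) + 1
}
ep = export(
    ToyModule(),
    (torch.tensor([2, 0, 1, 1, 3, 1]), torch.tensor([0, 2, 2, 3, 4, 7, 8])),
    dynamic_shapes=dynamic_shapes,
)
# offsets' first dim is declared as the derived dim 3*batch_size + 1,
# which is exactly the correlation the KJT spec above asks for
```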

# resolve the issue
* there is an internal flag, `_allow_complex_guards_as_runtime_asserts=True`, that can support this correlation (a hypothetical check of the emitted runtime asserts follows the `after` block)
* before
```
        ep = torch.export.export(
            model,
            (feature1,),
            {},
            dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
            strict=False,
            # Allows KJT to not be unflattened and run a forward on unflattened EP
            preserve_module_call_signature=tuple(sparse_fqns),
        )
```
* after
```
        ep = torch.export._trace._export(
            model,
            (feature1,),
            {},
            dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
            strict=False,
            # Allows KJT to not be unflattened and run a forward on unflattened EP
            preserve_module_call_signature=tuple(sparse_fqns),
            _allow_complex_guards_as_runtime_asserts=True,
        )
```
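
A hypothetical follow-up check (the snippet below reuses `ep` and `feature1` from the test and is not part of the diff): with the flag enabled, the complex guard relating `_offsets` to `batch_size` is expected to surface as a runtime assertion in the exported graph instead of failing export.
```
# inspect the exported graph for the emitted runtime assertions
# (e.g. _assert_scalar / sym_constrain_range style nodes)
print(ep.graph_module.code)

# running the unflattened module with a KJT whose offsets satisfy
# len(offsets) == 3 * batch_size + 1 should pass those assertions
unflattened = ep.module()
_ = unflattened(feature1)
```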

Differential Revision: D59201188
TroyGarden authored and facebook-github-bot committed Jul 2, 2024
1 parent dccf2e5 commit 80c8f0f
Showing 1 changed file with 5 additions and 4 deletions.
torchrec/ir/tests/test_serializer.py
```
@@ -258,22 +258,23 @@ def test_dynamic_shape_ebc(self) -> None:

         feature2 = KeyedJaggedTensor.from_offsets_sync(
             keys=["f1", "f2", "f3"],
-            values=torch.tensor([0, 1, 2, 3, 2, 3, 4]),
-            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7]),
+            values=torch.tensor([0, 1, 2, 3, 2, 3, 4, 5, 6, 8, 1, 2]),
+            offsets=torch.tensor([0, 2, 2, 3, 4, 5, 7, 8, 10, 12]),
         )
         eager_out = model(feature2)

         # Serialize EBC
-        collection = mark_dynamic_kjt(feature1)
+        collection = mark_dynamic_kjt(feature1, variable_length=True)
         model, sparse_fqns = encapsulate_ir_modules(model, JsonSerializer)
-        ep = torch.export.export(
+        ep = torch.export._trace._export(
             model,
             (feature1,),
             {},
             dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
             strict=False,
             # Allows KJT to not be unflattened and run a forward on unflattened EP
             preserve_module_call_signature=tuple(sparse_fqns),
+            _allow_complex_guards_as_runtime_asserts=True,
         )

         # Run forward on ExportedProgram
```
