Summary:

# context

1. KJT contains three necessary tensors: `_values`, `_lengths`, `_offsets`
   **a.** the shape of `_values` is independent
   **b.** dim(`_lengths`) = dim(`batch_size`) * const(`len(kjt.keys())`)
   **c.** dim(`_offsets`) = dim(`_lengths`) + 1
2. `_lengths` and `_offsets` can each be calculated from the other, so a KJT usually stores only one of them in memory and computes the other when needed.
3. previously only `_lengths` was marked as a dynamic shape, because `batch_size` and `len(kjt.keys())` are constant across iterations.
4. however, when we declare a KJT with both `_values` and `_offsets` as dynamic shapes, it won't pass the export function.

# notes

1. the `feature2` in the test has **NO** impact on the failure, because export errors out before `feature2` is used.
2. the error is purely due to the change that marks `_offsets` as dynamic.

# investigation

* `_offsets` is set to `3 * batch_size + 1` as shown below:
  ```
  {'features': [(<class 'torchrec.ir.utils.vlen1'>,), None, None, (<class 'torch.export.dynamic_shapes.3*batch_size1 + 1'>,)]}
  ```
* dynamic shape `s1` is created for `_offsets`, and dynamic shape `s2` is created for `batch_size`
* why is there no guard `s1 == 3*batch_size + 1`?
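The `_lengths`/`_offsets` relationship in the context section can be sketched in plain Python (toy values, not taken from the failing test):

```python
from itertools import accumulate

# Toy KJT-style data: 2 keys, batch_size 3, so len(lengths) == 2 * 3.
lengths = [2, 0, 1, 3, 1, 2]

# offsets are the exclusive prefix sum of lengths, hence
# len(offsets) == len(lengths) + 1 (point 1.c above).
offsets = [0] + list(accumulate(lengths))
assert len(offsets) == len(lengths) + 1

# Going back: lengths are the adjacent differences of offsets,
# which is why a KJT can store just one of the two tensors.
recovered = [offsets[i + 1] - offsets[i] for i in range(len(lengths))]
assert recovered == lengths
```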
```
I0702 09:50:39.181000 140316068409792 torch/fx/experimental/symbolic_shapes.py:3575] create_symbol s1 = 7 for L['args'][0][0]._offsets.size()[0] [2, 12884901886] (_export/non_strict_utils.py:93 in fakify), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s1"
V0702 09:50:39.183000 140316068409792 torch/fx/experimental/symbolic_shapes.py:5189] eval False == False [statically known]
I0702 09:50:39.190000 140316068409792 torch/fx/experimental/symbolic_shapes.py:3575] create_symbol s2 = 2 for batch_size1 [2, 4294967295] (export/dynamic_shapes.py:569 in _process_equalities), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="s2"
V0702 09:50:39.267000 140316068409792 torch/fx/experimental/symbolic_shapes.py:5189] eval ((s1 - 1)//3) >= 0 == True [statically known]
I0702 09:50:39.273000 140316068409792 torch/fx/experimental/symbolic_shapes.py:5104] eval Ne(((s1 - 1)//3), 0) [guard added] (_subclasses/functional_tensor.py:134 in __new__), for more info run with TORCHDYNAMO_EXTENDED_DEBUG_GUARD_ADDED="Ne(((s1 - 1)//3), 0)"
V0702 09:50:39.322000 140316068409792 torch/fx/experimental/symbolic_shapes.py:4736] _update_var_to_range s1 = VR[7, 7] (update)
I0702 09:50:39.330000 140316068409792 torch/fx/experimental/symbolic_shapes.py:4855] set_replacement s1 = 7 (range_refined_to_singleton) VR[7, 7]
```

# resolve the issue

* there is an internal flag `_allow_complex_guards_as_runtime_asserts=True` that supports this correlation
* before
  ```
  ep = torch.export.export(
      model,
      (feature1,),
      {},
      dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
      strict=False,
      # Allows KJT to not be unflattened and run a forward on unflattened EP
      preserve_module_call_signature=tuple(sparse_fqns),
  )
  ```
* after
  ```
  ep = torch.export._trace._export(
      model,
      (feature1,),
      {},
      dynamic_shapes=collection.dynamic_shapes(model, (feature1,)),
      strict=False,
      # Allows KJT to not be unflattened and run a forward on unflattened EP
      preserve_module_call_signature=tuple(sparse_fqns),
      _allow_complex_guards_as_runtime_asserts=True,
  )
  ```

Differential Revision: D59201188
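The guard arithmetic from the log can be checked by hand; a small sketch using the concrete values from the trace above (`s2 = batch_size = 2`, hence `s1 = 7`; `offsets_size` is a hypothetical helper, not a torchrec API):

```python
# _offsets has size num_keys * batch_size + 1; here len(kjt.keys()) == 3.
def offsets_size(batch_size: int, num_keys: int = 3) -> int:
    return num_keys * batch_size + 1

# The traced example input uses batch_size s2 = 2, giving s1 = 7,
# matching "create_symbol s1 = 7" in the log.
assert offsets_size(2) == 7

# The added guard Ne((s1 - 1)//3, 0) recovers batch_size from s1;
# without the runtime-assert flag, the solver refines s1 to the
# singleton range VR[7, 7] instead of keeping it symbolic.
s1 = 7
assert (s1 - 1) // 3 == 2
```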