Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
refactor _maybe_compute_kjt_to_jt_dict (pytorch#2326)
Summary: Pull Request resolved: pytorch#2326 # context * want to resolve graph break: [failures_and_restarts](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpKJM3FI/failures_and_restarts.html), P1537573230 ``` Tried to use data-dependent value in the subsequent computation. This can happen when we encounter unbounded dynamic value that is unknown during tracing time. You will need to explicitly give hint to the compiler. Please take a look at torch._check OR torch._check_is_size APIs. Could not guard on data-dependent expression Eq(((2*u48)//(u48 + u49)), 0) (unhinted: Eq(((2*u48)//(u48 + u49)), 0)). (Size-like symbols: u49, u48) Potential framework code culprit (scroll up for full backtrace): File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_refs/__init__.py", line 3950, in unbind if guard_size_oblivious(t.shape[dim] == 0): For more information, run with TORCH_LOGS="dynamic" For extended logs when we create symbols, also add TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="u49,u48" If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1 For more debugging help, see https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing User Stack (most recent call last): (snipped, see stack below for prefix) ... File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/sparse/jagged_tensor.py", line 2241, in to_dict _jt_dict = _maybe_compute_kjt_to_jt_dict( File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torchrec/sparse/jagged_tensor.py", line 1226, in _maybe_compute_kjt_to_jt_dict split_lengths = torch.unbind( ``` * we added [shape check](https://fburl.com/code/p02u4mck): ``` if pt2_guard_size_oblivious(lengths.numel() > 0): strided_lengths = lengths.view(-1, stride) if not torch.jit.is_scripting() and is_torchdynamo_compiling(): torch._check(strided_lengths.shape[0] > 0) torch._check(strided_lengths.shape[1] > 0) split_lengths = torch.unbind( strided_lengths, dim=0, ) ``` * however the error is still there ``` File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/_refs/__init__.py", line 3950, in unbind if guard_size_oblivious(t.shape[dim] == 0): File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/fx/experimental/symbolic_shapes.py", line 253, in guard_size_oblivious return expr.node.guard_size_oblivious("", 0) File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/61f992c26f3f2773/aps_models/ads/icvr/__icvr_launcher_live__/icvr_launcher_live#link-tree/torch/fx/experimental/sym_node.py", line 503, in guard_size_oblivious r = self.shape_env.evaluate_expr( ``` * [implementation](https://fburl.com/code/20iue1ib) ``` register_decomposition(aten.unbind) def unbind(t: TensorLikeType, dim: int = 0) -> TensorSequenceType: from torch.fx.experimental.symbolic_shapes import guard_size_oblivious dim = utils.canonicalize_dim(t.ndim, dim) torch._check_index( len(t.shape) > 0, lambda: "Dimension specified as 0 but tensor has no dimensions", ) if guard_size_oblivious(t.shape[dim] == 0): # <------- here return () else: return tuple( torch.squeeze(s, dim) for s in torch.tensor_split(t, t.shape[dim], dim) ) ``` * with D61677207 [no graph break at _maybe_compute_kjt_to_jt_dict](https://interncache-all.fbcdn.net/manifold/tlparse_reports/tree/logs/.tmpNcI14t/failures_and_restarts.html) Differential Revision: D55277785
- Loading branch information