Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
KT unflatten issue with torch.export
Summary: # context current error: ``` 1) torchrec.fb.ir.tests.test_serializer.TestSerializer: test_deserialized_device_vle 1) RuntimeError: Node ir_dynamic_batch_emb_lookup_default referenced nonexistent value id_list_features__values! Run Graph.lint() to diagnose such issues While executing %ir_dynamic_batch_emb_lookup_default : [num_users=1] = call_function[target=torch.ops.torchrec.ir_dynamic_batch_emb_lookup.default](args = ([%id_list_features__values, None, %id_list_features__lengths, None], %floordiv, [4, 5]), kwargs = {}) Original traceback: File "/data/users/hhy/fbsource/buck-out/v2/gen/fbcode/009ebbab256a7e75/torchrec/fb/ir/tests/__test_serializer__/test_serializer#link-tree/torchrec/fb/ir/tests/test_serializer.py", line 142, in forward return self.sparse_arch(id_list_features) File "torchrec/fb/ir/tests/test_serializer.py", line 446, in test_deserialized_device_vle output = deserialized_model(features_batch_3.to(device)) File "torch/nn/modules/module.py", line 1736, in _wrapped_call_impl return self._call_impl(*args, **kwargs) File "torch/nn/modules/module.py", line 1747, in _call_impl return forward_call(*args, **kwargs) File "torch/export/unflatten.py", line 482, in forward tree_out = torch.fx.Interpreter(self, graph=self.graph).run( File "torch/fx/interpreter.py", line 146, in run self.env[node] = self.run_node(node) File "torch/fx/interpreter.py", line 200, in run_node args, kwargs = self.fetch_args_kwargs_from_env(n) File "torch/fx/interpreter.py", line 372, in fetch_args_kwargs_from_env args = self.map_nodes_to_values(n.args, n) File "torch/fx/interpreter.py", line 394, in map_nodes_to_values return map_arg(args, load_arg) File "torch/fx/node.py", line 760, in map_arg return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) File "torch/fx/node.py", line 768, in map_aggregate t = tuple(map_aggregate(elem, fn) for elem in a) File "torch/fx/node.py", line 768, in <genexpr> t = tuple(map_aggregate(elem, fn) for elem in a) File "torch/fx/node.py", line 772, in map_aggregate return immutable_list(map_aggregate(elem, fn) for elem in a) File "torch/fx/node.py", line 772, in <genexpr> return immutable_list(map_aggregate(elem, fn) for elem in a) File "torch/fx/node.py", line 778, in map_aggregate return fn(a) File "torch/fx/node.py", line 760, in <lambda> return map_aggregate(a, lambda x: fn(x) if isinstance(x, Node) else x) File "torch/fx/interpreter.py", line 391, in load_arg raise RuntimeError(f'Node {n} referenced nonexistent value {n_arg}! Run Graph.lint() ' ``` Differential Revision: D59238744
- Loading branch information