Skip to content

Commit b9c25f1

Browse files
Thiago Crepaldipytorchmergebot
Thiago Crepaldi
authored andcommitted
Ignore shape inference exception from Caffe2 ATen fallback (pytorch#90408)
Fixes pytorch#87318 Pull Request resolved: pytorch#90408 Approved by: https://github.com/BowenBao
1 parent c988de1 commit b9c25f1

File tree

2 files changed

+25
-11
lines changed

2 files changed

+25
-11
lines changed

test/onnx/test_export_modes.py

-5
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@
55
import shutil
66
import sys
77
import tempfile
8-
import unittest
98

109
import torch
1110
import torch.nn as nn
@@ -87,10 +86,6 @@ def foo(a):
8786
x = torch.ones(3)
8887
torch.onnx.export(foo, (x,), f)
8988

90-
# TODO(87318): Can't pass even with Caffe2
91-
@unittest.skip(
92-
"RuntimeError: ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type"
93-
)
9489
@common_utils.skipIfNoCaffe2
9590
@common_utils.skipIfNoLapack
9691
def test_caffe2_aten_fallback(self):

torch/onnx/utils.py

+25-6
Original file line numberDiff line numberDiff line change
@@ -686,9 +686,19 @@ def _optimize_graph(
686686
graph = _C._jit_pass_canonicalize(graph)
687687
_C._jit_pass_lint(graph)
688688
if GLOBALS.onnx_shape_inference:
689-
_C._jit_pass_onnx_graph_shape_type_inference(
690-
graph, params_dict, GLOBALS.export_onnx_opset_version
691-
)
689+
try:
690+
_C._jit_pass_onnx_graph_shape_type_inference(
691+
graph, params_dict, GLOBALS.export_onnx_opset_version
692+
)
693+
except RuntimeError as exc:
694+
if (
695+
_C_onnx._CAFFE2_ATEN_FALLBACK
696+
and exc.args[0]
697+
== "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!"
698+
):
699+
# Caffe2 builds can have UNKNOWN_SCALAR for some tensors
700+
pass
701+
692702
return graph
693703

694704

@@ -1183,9 +1193,18 @@ def _model_to_graph(
11831193
_C._jit_pass_dce_allow_deleting_nodes_with_side_effects(graph)
11841194

11851195
if GLOBALS.onnx_shape_inference:
1186-
_C._jit_pass_onnx_graph_shape_type_inference(
1187-
graph, params_dict, GLOBALS.export_onnx_opset_version
1188-
)
1196+
try:
1197+
_C._jit_pass_onnx_graph_shape_type_inference(
1198+
graph, params_dict, GLOBALS.export_onnx_opset_version
1199+
)
1200+
except RuntimeError as exc:
1201+
if (
1202+
_C_onnx._CAFFE2_ATEN_FALLBACK
1203+
and exc.args[0]
1204+
== "ScalarType UNKNOWN_SCALAR is an unexpected tensor scalar type!"
1205+
):
1206+
# Caffe2 builds can have UNKNOWN_SCALAR for some tensors
1207+
pass
11891208

11901209
params_dict = _C._jit_pass_onnx_eliminate_unused_items(graph, params_dict)
11911210

0 commit comments

Comments
 (0)