forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_pytorch_onnx_no_runtime.py
114 lines (97 loc) · 3.65 KB
/
test_pytorch_onnx_no_runtime.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
# Owner(s): ["module: onnx"]
"""Tests for onnx export that don't run the exported model."""
import io
import unittest
from typing import Optional, Type
import onnx
import torch
from torch import Tensor
from torch.onnx import symbolic_helper
from torch.testing._internal.common_utils import (
instantiate_parametrized_tests,
parametrize,
)
class TestOptionalOutput(unittest.TestCase):
    """Export checks for scripted modules that return ``Optional[Tensor]``.

    The models are only exported and inspected; they are never executed by a
    runtime.
    """

    # TODO: Move these tests to test_pytorch_onnx_onnxruntime once
    # ONNX Runtime 1.11 is released and supports opset 16.

    # NOTE: the ``forward`` bodies below are compiled by ``torch.jit.script``,
    # so their statements are the actual test fixtures and are kept minimal.

    class IfNoneInput(torch.nn.Module):
        """Result starts as ``None`` and is set to the input inside an ``if``."""

        def forward(self, x) -> Optional[Tensor]:
            y: Optional[Tensor] = None
            if x.size(0) > 1:
                y = x
            return y

    class IfNoneOutput(torch.nn.Module):
        """Result starts as the input and is set to ``None`` inside an ``if``."""

        def forward(self, x) -> Optional[Tensor]:
            y: Optional[Tensor] = x
            if x.size(0) > 1:
                y = None
            return y

    class LoopNoneInput(torch.nn.Module):
        """Result starts as ``None`` and is set to the input inside a loop."""

        def forward(self, x) -> Optional[Tensor]:
            y: Optional[Tensor] = None
            for _ in range(x.size(0)):
                y = x
            return y

    class LoopNoneOutput(torch.nn.Module):
        """Result starts as the input and is set to ``None`` inside a loop."""

        def forward(self, x) -> Optional[Tensor]:
            y: Optional[Tensor] = x
            for _ in range(x.size(0)):
                y = None
            return y

    @parametrize(
        "module_class",
        (IfNoneInput, IfNoneOutput, LoopNoneInput, LoopNoneOutput),
        name_fn=lambda module_class: module_class.__name__,
    )
    @parametrize("x_size", (0, 1), name_fn=lambda x_size: str(x_size))
    def test_optional_output(
        self, module_class: Type[torch.nn.Module], x_size: int
    ):
        """Check that the exported graph output is typed as an optional tensor.

        Args:
            module_class: One of the nested fixture modules; it is instantiated,
                scripted and exported at opset 15.
            x_size: Length of the 1-D ``torch.ones`` tensor used as the
                example input.
        """
        axis_label = "condition"
        sample = torch.ones(x_size)
        # Scripting keeps the control flow, without which this test would
        # not be meaningful.
        scripted = torch.jit.script(module_class())

        buffer = io.BytesIO()
        torch.onnx.export(
            scripted,
            (sample,),
            buffer,
            opset_version=15,
            # A dynamic first axis ensures the condition is not constant.
            dynamic_axes={"x": {0: axis_label}},
            input_names=["x"],
        )
        exported_model = onnx.load_from_string(buffer.getvalue())

        # Build the type the output is expected to have: optional(tensor) with
        # the input's element type and the single dynamic axis.
        dtype_index = symbolic_helper.scalar_type_to_pytorch_type.index(
            sample.dtype
        )
        elem_type = symbolic_helper.scalar_type_to_onnx[dtype_index].value
        tensor_type = onnx.helper.make_tensor_type_proto(elem_type, (axis_label,))
        optional_type = onnx.helper.make_optional_type_proto(tensor_type)

        self.assertEqual(optional_type, exported_model.graph.output[0].type)

        branch_attr_names = ("then_branch", "else_branch")
        for graph_node in exported_model.graph.node:
            if graph_node.op_type != "If":
                continue
            # The output types of both branches should match.
            for attribute in graph_node.attribute:
                if attribute.name not in branch_attr_names:
                    continue
                self.assertEqual(optional_type, attribute.g.output[0].type)

    def test_uninitialized_optional(self):
        """Export a scripted module with nested ``if``s and an early return.

        There are no assertions: the test passes as long as the export call
        does not raise.
        """

        class Module(torch.nn.Module):
            def forward(self, y: Optional[Tensor]) -> Optional[Tensor]:
                if y is not None:
                    if y.shape[1] < 5:
                        if y.size(0) == 1:
                            y = y + 4
                        else:
                            return y
                return y

        sample = torch.ones((3, 4), dtype=torch.int)
        scripted = torch.jit.script(Module())
        sink = io.BytesIO()
        torch.onnx.export(
            scripted,
            sample,
            sink,
            opset_version=15,
            dynamic_axes={"y": {0: "y0", 1: "y1"}},
            input_names=["y"],
        )
# Process the @parametrize-decorated tests of the class; this has to run after
# the class body above has been defined.
# NOTE(review): presumably this generates one concrete test method per
# parameter combination — confirm against torch.testing._internal.common_utils.
instantiate_parametrized_tests(TestOptionalOutput)

# Run the test suite when this file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()