From dc22c4c9ebb95480f34033ae18611071c8019cb1 Mon Sep 17 00:00:00 2001 From: Takeshi Watanabe Date: Tue, 4 Jan 2022 21:00:37 +0900 Subject: [PATCH 01/38] [WIP] Support torch model reconstruction --- pytorch_pfn_extras/onnx/export/export.py | 13 ++++++- .../onnx/export/torch_reconstruct.py | 38 +++++++++++++++++++ stubs/torch/_C/__init__.pyi | 1 + 3 files changed, 51 insertions(+), 1 deletion(-) create mode 100644 pytorch_pfn_extras/onnx/export/torch_reconstruct.py diff --git a/pytorch_pfn_extras/onnx/export/export.py b/pytorch_pfn_extras/onnx/export/export.py index b59ee160f..f857efbae 100644 --- a/pytorch_pfn_extras/onnx/export/export.py +++ b/pytorch_pfn_extras/onnx/export/export.py @@ -89,10 +89,18 @@ def _to_tuple_if_not_sequence(v: Any) -> Any: def onnx_node_doc_string(onnx_node: torch._C.Node, torch_node: torch._C.Node) -> str: + inputs: List[torch._C.Value] = list(torch_node.inputs()) + nodes: List[torch._C.Node] = [torch_node] + while len(inputs) > 0: + n = inputs.pop().node() + if n is not None and n.kind() in ["onnx::Constant", "prim::Constant", "prim::ListConstruct"]: + nodes.insert(0, n) + inputs = list(n.inputs()) + inputs + nodes_str: str = "".join([repr(n) for n in nodes]) return f"""## Symbolic node {onnx_node} ## Original node -{torch_node} +{nodes_str} ## Scope {torch_node.scopeName()} ## Source Range @@ -876,4 +884,7 @@ def export( ex = _Exporter(model, inputs=args, **kwargs) ex.generate(f) + from pytorch_pfn_extras.onnx.export.torch_reconstruct import reconstruct + print(reconstruct(ex.model)) + return ex.outputs diff --git a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py new file mode 100644 index 000000000..04bd88f3a --- /dev/null +++ b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py @@ -0,0 +1,38 @@ +import marko +import onnx +import torch + +from typing import Dict, List + + +def reconstruct(model: onnx.ModelProto) -> torch._C.Graph: + g = torch._C.Graph() + values: Dict[str, torch._C.Value] = {} + inputs: List[torch._C.Value] = [] + for i in model.graph.input: + inputs.append(g.addInput()) + inputs[-1].setDebugName(i.name) + values[i.name] = inputs[-1] + original_lines: List[str] = [] + for n in model.graph.node: + p = marko.parser.Parser() + md = p.parse(n.doc_string) + original_paragraph: bool = False + for c in md.children: + if isinstance(c, marko.block.Paragraph) and original_paragraph: + lines = [line.children for line in c.children if isinstance(line, marko.inline.RawText)] + print(lines) + original_lines.extend(lines) + original_paragraph = False + if not isinstance(c, marko.block.Heading): + continue + if c.level != 2: + continue + if c.children[0].children == "Original node": + original_paragraph = True + original_lines = list(set(original_lines)) + print(original_lines) + + assert len(list(g.nodes())) == len(model.graph.node) + + return g diff --git a/stubs/torch/_C/__init__.pyi b/stubs/torch/_C/__init__.pyi index 969b98431..eb6b4b363 100644 --- a/stubs/torch/_C/__init__.pyi +++ b/stubs/torch/_C/__init__.pyi @@ -444,6 +444,7 @@ class Graph: def prependNode(self, n: Node) -> Node: ... def insertNode(self, n: Node) -> Node: ... def return_node(self) -> Node: ... + def addInput(self) -> Value: ... ... # Defined in torch/csrc/jit/ir/ir.h From d1414b8f512c4ba5af6081fc9e30b71ef601e1e4 Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 7 Jan 2022 04:57:06 +0000 Subject: [PATCH 02/38] Use parse_ir method to construct graph --- pytorch_pfn_extras/onnx/export/export.py | 2 + .../onnx/export/torch_reconstruct.py | 42 ++++++++++--------- 2 files changed, 25 insertions(+), 19 deletions(-) diff --git a/pytorch_pfn_extras/onnx/export/export.py b/pytorch_pfn_extras/onnx/export/export.py index f857efbae..7d4681724 100644 --- a/pytorch_pfn_extras/onnx/export/export.py +++ b/pytorch_pfn_extras/onnx/export/export.py @@ -100,7 +100,9 @@ def onnx_node_doc_string(onnx_node: torch._C.Node, torch_node: torch._C.Node) -> return f"""## Symbolic node {onnx_node} ## Original node +``` {nodes_str} +``` ## Scope {torch_node.scopeName()} ## Source Range diff --git a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py index 04bd88f3a..964e8c04d 100644 --- a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py @@ -6,33 +6,37 @@ def reconstruct(model: onnx.ModelProto) -> torch._C.Graph: - g = torch._C.Graph() - values: Dict[str, torch._C.Value] = {} - inputs: List[torch._C.Value] = [] - for i in model.graph.input: - inputs.append(g.addInput()) - inputs[-1].setDebugName(i.name) - values[i.name] = inputs[-1] original_lines: List[str] = [] for n in model.graph.node: - p = marko.parser.Parser() - md = p.parse(n.doc_string) original_paragraph: bool = False - for c in md.children: - if isinstance(c, marko.block.Paragraph) and original_paragraph: - lines = [line.children for line in c.children if isinstance(line, marko.inline.RawText)] - print(lines) - original_lines.extend(lines) + for c in marko.parser.Parser().parse(n.doc_string).children: + if isinstance(c, marko.block.FencedCode) and original_paragraph: + for lines in c.children: + if not isinstance(lines, marko.inline.RawText): + continue + for line in lines.children.split("\n"): + if len(line) == 0: + continue + original_lines.append(line) original_paragraph = False - if not isinstance(c, marko.block.Heading): - continue - if c.level != 2: + break + if not isinstance(c, marko.block.Heading) or c.level != 2: continue if c.children[0].children == "Original node": original_paragraph = True original_lines = list(set(original_lines)) - print(original_lines) - assert len(list(g.nodes())) == len(model.graph.node) + inputs: List[str] = [f"%{i.name}" for i in model.graph.input] + outputs: List[str] = [f"%{o.name}" for o in model.graph.output] + lines: str = "\n ".join(original_lines) + + src: str = f"""graph({", ".join(inputs)}): + {lines} + return ({", ".join(outputs)}) +""" + print(src) + + g: torch._C.Graph = torch._C.parse_ir(src) + torch._C._jit_pass_lint(g) return g From bbd3740cac15ad41ff69079e355b7092d38e8ff9 Mon Sep 17 00:00:00 2001 From: twata Date: Tue, 11 Jan 2022 06:14:07 +0000 Subject: [PATCH 03/38] Save scopes and fix order of nodes --- .../onnx/export/torch_reconstruct.py | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py index 964e8c04d..8c8ceeabe 100644 --- a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py @@ -1,12 +1,16 @@ import marko import onnx import torch +import re -from typing import Dict, List +from collections import OrderedDict +from typing import List def reconstruct(model: onnx.ModelProto) -> torch._C.Graph: original_lines: List[str] = [] + scopes: List[str] = [] + scope_re = re.compile("(.+), scope: (.+)") for n in model.graph.node: original_paragraph: bool = False for c in marko.parser.Parser().parse(n.doc_string).children: @@ -17,6 +21,13 @@ def reconstruct(model: onnx.ModelProto) -> torch._C.Graph: for line in lines.children.split("\n"): if len(line) == 0: continue + scope_match = re.match(scope_re, line) + if scope_match is not None: + scopes.append(scope_match[2]) + line = scope_match[1] + else: + scopes.append("") + line = line.replace("onnx::Constant", "prim::Constant") original_lines.append(line) original_paragraph = False break @@ -24,7 +35,7 @@ def reconstruct(model: onnx.ModelProto) -> torch._C.Graph: continue if c.children[0].children == "Original node": original_paragraph = True - original_lines = list(set(original_lines)) + original_lines = list(OrderedDict.fromkeys(original_lines)) inputs: List[str] = [f"%{i.name}" for i in model.graph.input] outputs: List[str] = [f"%{o.name}" for o in model.graph.output] From 3ab7088c5d4451cbc1c012f134ef3b8170d38adb Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jan 2022 04:06:57 +0000 Subject: [PATCH 04/38] Print inlined graph in original traced graph --- pytorch_pfn_extras/onnx/export/export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_pfn_extras/onnx/export/export.py b/pytorch_pfn_extras/onnx/export/export.py index 7d4681724..f9cecae84 100644 --- a/pytorch_pfn_extras/onnx/export/export.py +++ b/pytorch_pfn_extras/onnx/export/export.py @@ -222,7 +222,7 @@ def _run_trace(self) -> None: self.g = self.optimize_torch(self.g) self.log("Optimized graph", self.g) - self.log("Original traced graph", self.traced.graph) + self.log("Original traced graph", self.traced.inlined_graph) self.log("State dict", "\n".join([f"- {k}: {v}" for k, v in self.vars.items()])) def is_self(self, v: torch._C.Value) -> bool: From 4e2139b53b69272e4dc348eaacd542a4a7609080 Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jan 2022 07:16:46 +0000 Subject: [PATCH 05/38] Remove debug print --- pytorch_pfn_extras/onnx/export/torch_reconstruct.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py index 8c8ceeabe..b9afe0c71 100644 --- a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py @@ -45,7 +45,6 @@ def reconstruct(model: onnx.ModelProto) -> torch._C.Graph: {lines} return ({", ".join(outputs)}) """ - print(src) g: torch._C.Graph = torch._C.parse_ir(src) torch._C._jit_pass_lint(g) From 8b471cf4e95a01bbdf276ad59a614372299cc0bc Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jan 2022 07:17:06 +0000 Subject: [PATCH 06/38] Set debug name of outputs too --- pytorch_pfn_extras/onnx/export/export.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/pytorch_pfn_extras/onnx/export/export.py b/pytorch_pfn_extras/onnx/export/export.py index f9cecae84..0e99c7691 100644 --- a/pytorch_pfn_extras/onnx/export/export.py +++ b/pytorch_pfn_extras/onnx/export/export.py @@ -301,6 +301,9 @@ def optimize_torch(self, graph: torch._C.Graph) -> torch._C.Graph: inputs = list(graph.inputs()) for idx, n in enumerate(input_names): inputs[idx].setDebugName(n) + if self.output_names is not None: + for name, out in zip(self.output_names, graph.outputs()): + out.setDebugName(name) torch._C._jit_pass_onnx_set_dynamic_input_shape( # type: ignore[attr-defined] graph, self.dynamic_axes or {}, input_names or [] ) From d0f365100dd5492c42374cf9b7303068ddd57711 Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jan 2022 07:18:02 +0000 Subject: [PATCH 07/38] Restore value name and support initializers --- pytorch_pfn_extras/onnx/export/export.py | 10 +++++- .../onnx/export/torch_reconstruct.py | 33 +++++++++++++++---- 2 files changed, 35 insertions(+), 8 deletions(-) diff --git a/pytorch_pfn_extras/onnx/export/export.py b/pytorch_pfn_extras/onnx/export/export.py index 0e99c7691..740955844 100644 --- a/pytorch_pfn_extras/onnx/export/export.py +++ b/pytorch_pfn_extras/onnx/export/export.py @@ -678,6 +678,7 @@ def block2subgraph(name: str, b: torch._C.Block, doc_string: str) -> onnx.GraphP assert isinstance(self.vars[k], torch.Tensor) t: torch.Tensor = cast(torch.Tensor, self.vars[k]) onnx_vars[_unique_id(i)] = _tensor_to_proto(t, name=k) + onnx_vars[_unique_id(i)].doc_string = repr(i.node()) register_val_name(_unique_id(i), value_name(i), shadow=True) continue if _unique_id(i) not in val_tab: @@ -732,8 +733,15 @@ def assign_onnx_values( return onnx_nodes, onnx_vars, val_tab def generate_onnx(self) -> onnx.ModelProto: - # Convert prim and aten nodes to ONNX by using symbolic functions self.original_g: torch._C.Graph = self.g.copy() + + # Name all values to restore + for n in self.g.nodes(): + for o in n.outputs(): + if o.debugName() == str(o.unique()): + o.setDebugName(f"v{o.unique()}") + + # Convert prim and aten nodes to ONNX by using symbolic functions target_nodes = list(self.g.nodes()) for n in target_nodes: self.generate_onnx_node(self.g, n) diff --git a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py index b9afe0c71..55a22ab95 100644 --- a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py @@ -4,13 +4,16 @@ import re from collections import OrderedDict -from typing import List +from typing import List, Tuple -def reconstruct(model: onnx.ModelProto) -> torch._C.Graph: +def reconstruct(model: onnx.ModelProto) -> Tuple[torch._C.Graph, List[Tuple[str, torch.Tensor]]]: original_lines: List[str] = [] scopes: List[str] = [] - scope_re = re.compile("(.+), scope: (.+)") + scope_re = re.compile("(.+), scope: ([^ ]+)") + const_vals_re = re.compile(r"value= ([\d ]+) \[ \w+Type\{\d+\} \]") + const_val_re = re.compile(r"value=\{(\d+)\}") + func_re = re.compile(r" = ^(\w+)\(") for n in model.graph.node: original_paragraph: bool = False for c in marko.parser.Parser().parse(n.doc_string).children: @@ -23,12 +26,20 @@ def reconstruct(model: onnx.ModelProto) -> torch._C.Graph: continue scope_match = re.match(scope_re, line) if scope_match is not None: - scopes.append(scope_match[2]) + scope = scope_match[2].split("/")[-1] + scopes.append(scope) line = scope_match[1] else: scopes.append("") line = line.replace("onnx::Constant", "prim::Constant") + if "prim::Constant" in line: + line = re.sub(const_vals_re, lambda m: f"value=[{m[1].replace(' ', ', ')}]", line) + line = re.sub(const_val_re, r"value=\1", line) original_lines.append(line) + + func_match = re.match(func_re, line) + if func_match: + raise f"Function call not supported for: {func_match[1]}" original_paragraph = False break if not isinstance(c, marko.block.Heading) or c.level != 2: @@ -37,10 +48,18 @@ def reconstruct(model: onnx.ModelProto) -> torch._C.Graph: original_paragraph = True original_lines = list(OrderedDict.fromkeys(original_lines)) - inputs: List[str] = [f"%{i.name}" for i in model.graph.input] - outputs: List[str] = [f"%{o.name}" for o in model.graph.output] + inputs: List[str] = ["%" + i.name for i in model.graph.input] + outputs: List[str] = ["%" + o.name.split(".")[-1] for o in model.graph.output] lines: str = "\n ".join(original_lines) + initializer_name_re = re.compile(r"^%(\w+) [:=]") + params: List[Tuple[str, torch.Tensor]] = [] + for i in model.graph.initializer: + i_name = re.match(initializer_name_re, i.doc_string) + if i_name: + inputs.append(f"%{i_name[1]}") + params.append((i.name, torch.from_numpy(onnx.numpy_helper.to_array(i).copy()))) + src: str = f"""graph({", ".join(inputs)}): {lines} return ({", ".join(outputs)}) @@ -49,4 +68,4 @@ def reconstruct(model: onnx.ModelProto) -> torch._C.Graph: g: torch._C.Graph = torch._C.parse_ir(src) torch._C._jit_pass_lint(g) - return g + return g, params From 5df50f0bd355ed7cb7819556f6c7c998fb24afdc Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jan 2022 07:49:24 +0000 Subject: [PATCH 08/38] Cut out line processor to function --- .../onnx/export/torch_reconstruct.py | 45 +++++++++++-------- 1 file changed, 26 insertions(+), 19 deletions(-) diff --git a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py index 55a22ab95..6513b82dd 100644 --- a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py @@ -7,13 +7,33 @@ from typing import List, Tuple +_scope_re = re.compile("(.+), scope: ([^ ]+)") +_const_vals_re = re.compile(r"value= ([\d ]+) \[ \w+Type\{\d+\} \]") +_const_val_re = re.compile(r"value=\{(\d+)\}") +_func_re = re.compile(r" = \^(\w+)\(") + + +def _process_line(line: str) -> (str, str): + scope_match = re.match(_scope_re, line) + scope = "" + if scope_match is not None: + scope = scope_match[2].split("/")[-1] + line = scope_match[1] + line = line.replace("onnx::Constant", "prim::Constant") + if "prim::Constant" in line: + line = re.sub(_const_vals_re, lambda m: f"value=[{m[1].replace(' ', ', ')}]", line) + line = re.sub(_const_val_re, r"value=\1", line) + + func_match = re.search(_func_re, line) + if func_match: + raise RuntimeError(f"Function call not supported for: {func_match[1]} in line: {line}") + + return line, scope + + def reconstruct(model: onnx.ModelProto) -> Tuple[torch._C.Graph, List[Tuple[str, torch.Tensor]]]: original_lines: List[str] = [] scopes: List[str] = [] - scope_re = re.compile("(.+), scope: ([^ ]+)") - const_vals_re = re.compile(r"value= ([\d ]+) \[ \w+Type\{\d+\} \]") - const_val_re = re.compile(r"value=\{(\d+)\}") - func_re = re.compile(r" = ^(\w+)\(") for n in model.graph.node: original_paragraph: bool = False for c in marko.parser.Parser().parse(n.doc_string).children: @@ -24,22 +44,9 @@ def reconstruct(model: onnx.ModelProto) -> Tuple[torch._C.Graph, List[Tuple[str, for line in lines.children.split("\n"): if len(line) == 0: continue - scope_match = re.match(scope_re, line) - if scope_match is not None: - scope = scope_match[2].split("/")[-1] - scopes.append(scope) - line = scope_match[1] - else: - scopes.append("") - line = line.replace("onnx::Constant", "prim::Constant") - if "prim::Constant" in line: - line = re.sub(const_vals_re, lambda m: f"value=[{m[1].replace(' ', ', ')}]", line) - line = re.sub(const_val_re, r"value=\1", line) + line, scope = _process_line(line) original_lines.append(line) - - func_match = re.match(func_re, line) - if func_match: - raise f"Function call not supported for: {func_match[1]}" + scopes.append(scope) original_paragraph = False break if not isinstance(c, marko.block.Heading) or c.level != 2: From 380b4cf9a91245fa0554a6581d20c07a1045a2eb Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jan 2022 08:02:25 +0000 Subject: [PATCH 09/38] Cut out markdown processor --- .../onnx/export/torch_reconstruct.py | 54 +++++++++++-------- 1 file changed, 32 insertions(+), 22 deletions(-) diff --git a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py index 6513b82dd..b86972947 100644 --- a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py @@ -31,33 +31,43 @@ def _process_line(line: str) -> (str, str): return line, scope +def _process_markdown(md: str) -> Tuple[List[str], List[str]]: + lines: List[str] = [] + scopes: List[str] = [] + target_para: bool = False + for c in marko.parser.Parser().parse(md).children: + if isinstance(c, marko.block.FencedCode) and target_para: + for text in c.children: + if not isinstance(text, marko.inline.RawText): + continue + for line in text.children.split("\n"): + if len(line) == 0: + continue + line, scope = _process_line(line) + lines.append(line) + scopes.append(scope) + target_para = False + break + if not isinstance(c, marko.block.Heading) or c.level != 2: + continue + if c.children[0].children == "Original node": + target_para = True + + return lines, scopes + + def reconstruct(model: onnx.ModelProto) -> Tuple[torch._C.Graph, List[Tuple[str, torch.Tensor]]]: - original_lines: List[str] = [] + lines: List[str] = [] scopes: List[str] = [] for n in model.graph.node: - original_paragraph: bool = False - for c in marko.parser.Parser().parse(n.doc_string).children: - if isinstance(c, marko.block.FencedCode) and original_paragraph: - for lines in c.children: - if not isinstance(lines, marko.inline.RawText): - continue - for line in lines.children.split("\n"): - if len(line) == 0: - continue - line, scope = _process_line(line) - original_lines.append(line) - scopes.append(scope) - original_paragraph = False - break - if not isinstance(c, marko.block.Heading) or c.level != 2: - continue - if c.children[0].children == "Original node": - original_paragraph = True - original_lines = list(OrderedDict.fromkeys(original_lines)) + new_lines, new_scopes = _process_markdown(n.doc_string) + lines.extend(new_lines) + scopes.extend(new_scopes) + lines = list(OrderedDict.fromkeys(lines)) inputs: List[str] = ["%" + i.name for i in model.graph.input] outputs: List[str] = ["%" + o.name.split(".")[-1] for o in model.graph.output] - lines: str = "\n ".join(original_lines) + body = "\n ".join(lines) initializer_name_re = re.compile(r"^%(\w+) [:=]") params: List[Tuple[str, torch.Tensor]] = [] @@ -68,7 +78,7 @@ def reconstruct(model: onnx.ModelProto) -> Tuple[torch._C.Graph, List[Tuple[str, params.append((i.name, torch.from_numpy(onnx.numpy_helper.to_array(i).copy()))) src: str = f"""graph({", ".join(inputs)}): - {lines} + {body} return ({", ".join(outputs)}) """ From 546953cb1a41718bf95c9ca3f989b6423dc0f418 Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jan 2022 08:26:11 +0000 Subject: [PATCH 10/38] Fix doc string generation for If --- pytorch_pfn_extras/onnx/export/export.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_pfn_extras/onnx/export/export.py b/pytorch_pfn_extras/onnx/export/export.py index 740955844..b7939827a 100644 --- a/pytorch_pfn_extras/onnx/export/export.py +++ b/pytorch_pfn_extras/onnx/export/export.py @@ -430,7 +430,9 @@ def handle_if(self, g: torch._C.Graph, n: torch._C.Node) -> None: # Generated onnx node doc string should be added later since DCE isn't completed yet doc_str: str = f""" ## Original node +``` {n} +``` ## Scope {n.scopeName()} ## Source Range From b7bbfd16ef9f3bca71c7b60e636f61b58459d394 Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jan 2022 08:26:25 +0000 Subject: [PATCH 11/38] Skip torch.autograd.Function error --- pytorch_pfn_extras/onnx/export/export.py | 8 ++++++-- pytorch_pfn_extras/onnx/export/torch_reconstruct.py | 2 +- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/pytorch_pfn_extras/onnx/export/export.py b/pytorch_pfn_extras/onnx/export/export.py index b7939827a..af5ee7f52 100644 --- a/pytorch_pfn_extras/onnx/export/export.py +++ b/pytorch_pfn_extras/onnx/export/export.py @@ -899,7 +899,11 @@ def export( ex = _Exporter(model, inputs=args, **kwargs) ex.generate(f) - from pytorch_pfn_extras.onnx.export.torch_reconstruct import reconstruct - print(reconstruct(ex.model)) + try: + from pytorch_pfn_extras.onnx.export.torch_reconstruct import reconstruct + print(reconstruct(ex.model)) + except RuntimeError as e: + if not e.args[0].startswith("torch.autograd.Function call not supported"): + raise return ex.outputs diff --git a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py index b86972947..8781a24ae 100644 --- a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py @@ -26,7 +26,7 @@ def _process_line(line: str) -> (str, str): func_match = re.search(_func_re, line) if func_match: - raise RuntimeError(f"Function call not supported for: {func_match[1]} in line: {line}") + raise RuntimeError(f"torch.autograd.Function call not supported for: {func_match[1]} in line: {line}") return line, scope From 8b563b8e80b0d4b236e7e26a0e772347fd612fe6 Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jan 2022 08:55:03 +0000 Subject: [PATCH 12/38] Support literal onnx::SequenceConstruct --- pytorch_pfn_extras/onnx/export/export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_pfn_extras/onnx/export/export.py b/pytorch_pfn_extras/onnx/export/export.py index af5ee7f52..af2da38eb 100644 --- a/pytorch_pfn_extras/onnx/export/export.py +++ b/pytorch_pfn_extras/onnx/export/export.py @@ -93,7 +93,7 @@ def onnx_node_doc_string(onnx_node: torch._C.Node, torch_node: torch._C.Node) -> nodes: List[torch._C.Node] = [torch_node] while len(inputs) > 0: n = inputs.pop().node() - if n is not None and n.kind() in ["onnx::Constant", "prim::Constant", "prim::ListConstruct"]: + if n is not None and n.kind() in ["onnx::Constant", "prim::Constant", "prim::ListConstruct", "onnx::SequenceConstruct"]: nodes.insert(0, n) inputs = list(n.inputs()) + inputs nodes_str: str = "".join([repr(n) for n in nodes]) From 3bac50b3e64df4b178f1e3d8b62cca955e267ccb Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jan 2022 08:55:31 +0000 Subject: [PATCH 13/38] Skip initializer input to avoid duplicate --- pytorch_pfn_extras/onnx/export/torch_reconstruct.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py index 8781a24ae..4204daad2 100644 --- a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py @@ -4,7 +4,7 @@ import re from collections import OrderedDict -from typing import List, Tuple +from typing import List, Set, Tuple _scope_re = re.compile("(.+), scope: ([^ ]+)") @@ -65,7 +65,9 @@ def reconstruct(model: onnx.ModelProto) -> Tuple[torch._C.Graph, List[Tuple[str, scopes.extend(new_scopes) lines = list(OrderedDict.fromkeys(lines)) - inputs: List[str] = ["%" + i.name for i in model.graph.input] + skip_inputs: Set[str] = set([i.name for i in model.graph.initializer]) + + inputs: List[str] = ["%" + i.name for i in model.graph.input if i.name not in skip_inputs] outputs: List[str] = ["%" + o.name.split(".")[-1] for o in model.graph.output] body = "\n ".join(lines) From b5ebf7e64f4818d98c3e0e7ad97c6760563d6dcb Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jan 2022 09:18:47 +0000 Subject: [PATCH 14/38] Support more expression --- pytorch_pfn_extras/onnx/export/torch_reconstruct.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py index 4204daad2..f3c7fe788 100644 --- a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py @@ -8,8 +8,8 @@ _scope_re = re.compile("(.+), scope: ([^ ]+)") -_const_vals_re = re.compile(r"value= ([\d ]+) \[ \w+Type\{\d+\} \]") -_const_val_re = re.compile(r"value=\{(\d+)\}") +_const_vals_re = re.compile(r"value= ([\d\- ]+) \[ \w+Type\{\d+\} \]") +_const_val_re = re.compile(r"value=\{(-?[\d\.e-]+)\}") _func_re = re.compile(r" = \^(\w+)\(") From 4a1c088c95801c75380debb8d86937777275a3a2 Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jan 2022 09:19:24 +0000 Subject: [PATCH 15/38] Place onnx identity node instead to track identity op in onnx --- pytorch_pfn_extras/onnx/export/export.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/pytorch_pfn_extras/onnx/export/export.py b/pytorch_pfn_extras/onnx/export/export.py index af2da38eb..063c6e641 100644 --- a/pytorch_pfn_extras/onnx/export/export.py +++ b/pytorch_pfn_extras/onnx/export/export.py @@ -516,9 +516,13 @@ def list_added_nodes() -> List[torch._C.Node]: sym_nodes: List[torch._C.Node] = list_added_nodes() + # Place onnx::Identity node instead node when none is added + if len(sym_nodes) == 0: + sym_outs = g.op("Identity", sym_outs[0]), + sym_nodes = [sym_outs[0].node()] + self.log(f"Converting node {n.kind()}", n) - if len(sym_nodes) > 0: - self.log(f"Converted node {n.kind()}", "\n".join([str(i) for i in sym_nodes])) + self.log(f"Converted node {n.kind()}", "\n".join([str(i) for i in sym_nodes])) # Generate doc string before old node lifetime ends for sym_nd in sym_nodes: From 36dad083b17f1ffb9b449c7c48affb68f283362b Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jan 2022 10:18:01 +0000 Subject: [PATCH 16/38] mypy --- pytorch_pfn_extras/onnx/export/export.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pytorch_pfn_extras/onnx/export/export.py b/pytorch_pfn_extras/onnx/export/export.py index 063c6e641..17fcedcf4 100644 --- a/pytorch_pfn_extras/onnx/export/export.py +++ b/pytorch_pfn_extras/onnx/export/export.py @@ -743,9 +743,9 @@ def generate_onnx(self) -> onnx.ModelProto: # Name all values to restore for n in self.g.nodes(): - for o in n.outputs(): - if o.debugName() == str(o.unique()): - o.setDebugName(f"v{o.unique()}") + for n_o in n.outputs(): + if n_o.debugName() == str(n_o.unique()): + n_o.setDebugName(f"v{n_o.unique()}") # Convert prim and aten nodes to ONNX by using symbolic functions target_nodes = list(self.g.nodes()) From 024a73b2529350075d783f5e96a588adc4cbb05f Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jan 2022 10:18:12 +0000 Subject: [PATCH 17/38] Replace onnx::SequenceConstruct too --- pytorch_pfn_extras/onnx/export/torch_reconstruct.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py index f3c7fe788..4526f2599 100644 --- a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py @@ -20,6 +20,7 @@ def _process_line(line: str) -> (str, str): scope = scope_match[2].split("/")[-1] line = scope_match[1] line = line.replace("onnx::Constant", "prim::Constant") + line = line.replace("onnx::SequenceConstruct", "prim::ListConstruct") if "prim::Constant" in line: line = re.sub(_const_vals_re, lambda m: f"value=[{m[1].replace(' ', ', ')}]", line) line = re.sub(_const_val_re, r"value=\1", line) From 789f1ff3d3e2d46558b4cfdb07d18c7325905514 Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Jan 2022 10:18:48 +0000 Subject: [PATCH 18/38] Run check only in tests --- pytorch_pfn_extras/onnx/export/export.py | 7 ------- pytorch_pfn_extras/onnx/export/torch_reconstruct.py | 13 ++++++++++--- .../onnx_tests/test_as_output.py | 2 +- .../onnx_tests/test_export.py | 2 +- .../onnx_tests/test_export_testcase.py | 9 +++++++-- tests/pytorch_pfn_extras_tests/onnx_tests/utils.py | 7 ++++++- 6 files changed, 25 insertions(+), 15 deletions(-) diff --git a/pytorch_pfn_extras/onnx/export/export.py b/pytorch_pfn_extras/onnx/export/export.py index 17fcedcf4..213b5c259 100644 --- a/pytorch_pfn_extras/onnx/export/export.py +++ b/pytorch_pfn_extras/onnx/export/export.py @@ -903,11 +903,4 @@ def export( ex = _Exporter(model, inputs=args, **kwargs) ex.generate(f) - try: - from pytorch_pfn_extras.onnx.export.torch_reconstruct import reconstruct - print(reconstruct(ex.model)) - except RuntimeError as e: - if not e.args[0].startswith("torch.autograd.Function call not supported"): - raise - return ex.outputs diff --git a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py index 4526f2599..f3c8cb3c4 100644 --- a/pytorch_pfn_extras/onnx/export/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/export/torch_reconstruct.py @@ -13,7 +13,11 @@ _func_re = re.compile(r" = \^(\w+)\(") -def _process_line(line: str) -> (str, str): +class ReconstructError(Exception): + pass + + +def _process_line(line: str) -> Tuple[str, str]: scope_match = re.match(_scope_re, line) scope = "" if scope_match is not None: @@ -27,7 +31,7 @@ def _process_line(line: str) -> (str, str): func_match = re.search(_func_re, line) if func_match: - raise RuntimeError(f"torch.autograd.Function call not supported for: {func_match[1]} in line: {line}") + raise ReconstructError(f"torch.autograd.Function call not supported for: {func_match[1]} in line: {line}") return line, scope @@ -36,7 +40,7 @@ def _process_markdown(md: str) -> Tuple[List[str], List[str]]: lines: List[str] = [] scopes: List[str] = [] target_para: bool = False - for c in marko.parser.Parser().parse(md).children: + for c in marko.parser.Parser().parse(md).children: # type: ignore[union-attr] if isinstance(c, marko.block.FencedCode) and target_para: for text in c.children: if not isinstance(text, marko.inline.RawText): @@ -61,6 +65,8 @@ def reconstruct(model: onnx.ModelProto) -> Tuple[torch._C.Graph, List[Tuple[str, lines: List[str] = [] scopes: List[str] = [] for n in model.graph.node: + if len(n.doc_string) == 0 and n.op_type != "Constant": + raise ReconstructError(f"doc_string not found in node: {onnx.helper.printable_node(n)}. Please use strip_doc_string=False option") new_lines, new_scopes = _process_markdown(n.doc_string) lines.extend(new_lines) scopes.extend(new_scopes) @@ -84,6 +90,7 @@ def reconstruct(model: onnx.ModelProto) -> Tuple[torch._C.Graph, List[Tuple[str, {body} return ({", ".join(outputs)}) """ + print(src) g: torch._C.Graph = torch._C.parse_ir(src) torch._C._jit_pass_lint(g) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_as_output.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_as_output.py index 8a73f761c..55ffaae8c 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_as_output.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_as_output.py @@ -58,7 +58,7 @@ def forward(self, x): model = Net() x = torch.ones((1, 1, 32, 32)) - output_dir = _helper(model, x, 'as_output') + output_dir = _helper(model, x, 'as_output', check_reconstruct=False) actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx')) named_nodes = {n.name: n for n in actual_onnx.graph.node} diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py index 82b1d6120..e8caeb275 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py @@ -58,7 +58,7 @@ def forward(self, x): return Func.apply(x) + torch.tensor([10], dtype=torch.float) assert hasattr(Func, "symbolic") - run_model_test(Model(), (torch.rand((20,)),)) + run_model_test(Model(), (torch.rand((20,)),), check_reconstruct=False) class AnyModel(torch.nn.Module): diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py index eb707aa20..d62e5325a 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py @@ -19,6 +19,7 @@ from pytorch_pfn_extras.onnx import LARGE_TENSOR_DATA_THRESHOLD from pytorch_pfn_extras.onnx.strip_large_tensor import _strip_large_tensor_tool_impl from pytorch_pfn_extras.onnx.unstrip_tensor import unstrip +from pytorch_pfn_extras.onnx.export.torch_reconstruct import reconstruct output_dir = 'out' @@ -55,7 +56,7 @@ def _get_output_dir(d, **kwargs): return output_dir -def _helper(model, args, d, use_pfto=True, **kwargs): +def _helper(model, args, d, use_pfto=True, check_reconstruct=True, **kwargs): output_dir = _get_output_dir(d) if 'training' not in kwargs: kwargs['training'] = model.training @@ -63,7 +64,11 @@ def _helper(model, args, d, use_pfto=True, **kwargs): kwargs['do_constant_folding'] = False if 'metadata' not in kwargs: kwargs["metadata"] = False + if "strip_doc_string" not in kwargs: + kwargs["strip_doc_string"] = False export_testcase(model, args, output_dir, use_pfto=use_pfto, **kwargs) + if check_reconstruct and use_pfto and not kwargs["strip_doc_string"]: + reconstruct(onnx.load(os.path.join(output_dir, "model.onnx"))) return output_dir @@ -259,7 +264,7 @@ def test_export_testcase_strip_large_tensor_data(): output_dir = _helper( model, x, 'mnist_stripped_tensor_data', output_grad=True, strip_large_tensor_data=True, - metadata=True) + metadata=True, check_reconstruct=False) assert os.path.isdir(output_dir) assert os.path.isfile(os.path.join(output_dir, 'meta.json')) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py b/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py index f9609e049..966d87fc1 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py @@ -5,6 +5,7 @@ import onnxruntime as ort import torch from pytorch_pfn_extras.onnx.export import export as pfto_export +from pytorch_pfn_extras.onnx.export.torch_reconstruct import reconstruct def run_model_test( @@ -20,6 +21,7 @@ def run_model_test( strict_trace=True, mode="eval", use_gpu=False, + check_reconstruct=True, **kwargs, ) -> onnx.ModelProto: if mode == "train": @@ -82,4 +84,7 @@ def run_model_test( cmp = torch.isclose(torch.tensor(a), e.cpu(), rtol=rtol, atol=atol) assert cmp.all(), f"{cmp.logical_not().count_nonzero()} / {cmp.numel()} values failed" - return onnx.load(f.name) + onnx_model = onnx.load(f.name) + if check_reconstruct: + reconstruct(onnx_model) + return onnx_model From 605b37a1c36fa6febc34a638e7389fa5affc0ee7 Mon Sep 17 00:00:00 2001 From: Takeshi Watanabe Date: Mon, 7 Feb 2022 20:51:28 +0900 Subject: [PATCH 19/38] Fix import names --- .../pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py | 2 +- tests/pytorch_pfn_extras_tests/onnx_tests/utils.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py index 202dc3c66..ee15a81ff 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py @@ -19,7 +19,7 @@ from pytorch_pfn_extras.onnx import LARGE_TENSOR_DATA_THRESHOLD from pytorch_pfn_extras.onnx.strip_large_tensor import _strip_large_tensor_tool_impl from pytorch_pfn_extras.onnx.unstrip_tensor import unstrip -from pytorch_pfn_extras.onnx.export.torch_reconstruct import reconstruct +from pytorch_pfn_extras.onnx.pfto_exporter.torch_reconstruct import reconstruct output_dir = 'out' diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py b/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py index ce8ada7d5..7eb9b87ca 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py @@ -5,7 +5,7 @@ import onnxruntime as ort import torch from pytorch_pfn_extras.onnx.pfto_exporter.export import export as pfto_export -from pytorch_pfn_extras.onnx.export.torch_reconstruct import reconstruct +from pytorch_pfn_extras.onnx.pfto_exporter.torch_reconstruct import reconstruct def run_model_test( From b10113c717e8e70c35143e5894b290a976d75431 Mon Sep 17 00:00:00 2001 From: Takeshi Watanabe Date: Tue, 8 Feb 2022 16:14:42 +0900 Subject: [PATCH 20/38] Remove debug print --- pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py b/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py index f3c8cb3c4..8278b094a 100644 --- a/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py @@ -90,7 +90,6 @@ def reconstruct(model: onnx.ModelProto) -> Tuple[torch._C.Graph, List[Tuple[str, {body} return ({", ".join(outputs)}) """ - print(src) g: torch._C.Graph = torch._C.parse_ir(src) torch._C._jit_pass_lint(g) From c5e95a071c542244fec74f285e8df033f43a3ecf Mon Sep 17 00:00:00 2001 From: Takeshi Watanabe Date: Wed, 9 Feb 2022 12:48:21 +0900 Subject: [PATCH 21/38] Add marko to onnx dependency --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 71b32251a..e298f3b66 100644 --- a/setup.py +++ b/setup.py @@ -17,7 +17,7 @@ install_requires=['numpy', 'packaging', 'torch', 'typing-extensions>=3.10'], extras_require={ 'test': ['pytest', 'onnxruntime', 'torchvision'], - 'onnx': ['onnx'], + 'onnx': ['onnx', 'marko'], }, python_requires='>=3.6.0', packages=setuptools.find_packages(exclude=['tests', 'tests.*']), From 6fb5cb38ae34e93a46116364337f0970c9c24338 Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Wed, 23 Mar 2022 11:27:51 +0000 Subject: [PATCH 22/38] add marko --- .flexci/linux/build_and_push.sh | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.flexci/linux/build_and_push.sh b/.flexci/linux/build_and_push.sh index 58c7be7bf..3175f9e51 100755 --- a/.flexci/linux/build_and_push.sh +++ b/.flexci/linux/build_and_push.sh @@ -8,7 +8,7 @@ if [ "${IMAGE_BASE}" = "" ]; then fi TEST_PIP_PACKAGES=" -matplotlib tensorboard ipython ipywidgets pandas optuna onnx onnxruntime +matplotlib tensorboard ipython ipywidgets pandas optuna onnx onnxruntime marko pytest flake8 pysen[lint] pytest-cov " From 291561c33e3369ecf0299cfb49ae467bccdaac34 Mon Sep 17 00:00:00 2001 From: Kenichi Maehashi Date: Wed, 23 Mar 2022 11:28:35 +0000 Subject: [PATCH 23/38] add marko to Windows CI --- .flexci/windows/test.ps1 | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.flexci/windows/test.ps1 b/.flexci/windows/test.ps1 index 707697fb2..838f1d9ed 100644 --- a/.flexci/windows/test.ps1 +++ b/.flexci/windows/test.ps1 @@ -36,7 +36,7 @@ if ($test -eq "torch18") { RunOrDie python -V # Install common requirements -RunOrDie python -m pip install pytorch-ignite pytest flake8 matplotlib tensorboard onnx ipython ipywidgets pandas optuna cupy-cuda102 onnxruntime +RunOrDie python -m pip install pytorch-ignite pytest flake8 matplotlib tensorboard onnx ipython ipywidgets pandas optuna cupy-cuda102 onnxruntime marko RunOrDie python -m pip list # Install From 90e529c184cd4a94aba2b46b1c85cae8b3e863ed Mon Sep 17 00:00:00 2001 From: Takeshi Watanabe Date: Mon, 3 Oct 2022 17:27:24 +0900 Subject: [PATCH 24/38] Install missing package in cpu test --- .github/workflows/test-cpu.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/test-cpu.yml b/.github/workflows/test-cpu.yml index 96c00e67e..ecdf859c3 100644 --- a/.github/workflows/test-cpu.yml +++ b/.github/workflows/test-cpu.yml @@ -23,7 +23,7 @@ jobs: pip install 'torch==1.9.*' pip install 'torchvision==0.10.*' pip install pytest - pip install matplotlib tensorboard ipython ipywidgets pandas optuna onnx onnxruntime pytorch-ignite + pip install matplotlib tensorboard ipython ipywidgets pandas optuna onnx onnxruntime pytorch-ignite marko pip install -v -e . # Test PPE is importable with minimum dependency python -c 'import pytorch_pfn_extras' From f6366844ff87460d748f2734e38d6971ca48371a Mon Sep 17 00:00:00 2001 From: Takeshi Watanabe Date: Tue, 4 Oct 2022 11:52:31 +0900 Subject: [PATCH 25/38] Support typed constant --- pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py b/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py index 8278b094a..fc72d9fc8 100644 --- a/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py @@ -9,6 +9,7 @@ _scope_re = re.compile("(.+), scope: ([^ ]+)") _const_vals_re = re.compile(r"value= ([\d\- ]+) \[ \w+Type\{\d+\} \]") +_const_typed_val_re = re.compile(r"value=\[ \w+Type\{(-?[\d\.e-]+)\} \]") _const_val_re = re.compile(r"value=\{(-?[\d\.e-]+)\}") _func_re = re.compile(r" = \^(\w+)\(") @@ -27,6 +28,7 @@ def _process_line(line: str) -> Tuple[str, str]: line = line.replace("onnx::SequenceConstruct", "prim::ListConstruct") if "prim::Constant" in line: line = re.sub(_const_vals_re, lambda m: f"value=[{m[1].replace(' ', ', ')}]", line) + line = re.sub(_const_typed_val_re, r"value=\1", line) line = re.sub(_const_val_re, r"value=\1", line) func_match = re.search(_func_re, line) From 89d5a7a288cdef8d0af538872e2c36d8dae463ba Mon Sep 17 00:00:00 2001 From: Takeshi Watanabe Date: Tue, 4 Oct 2022 14:47:56 +0900 Subject: [PATCH 26/38] Disable check for now --- .../onnx_tests/test_export_testcase.py | 18 ++++++++++++------ 1 file changed, 12 insertions(+), 6 deletions(-) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py index bd6ace8bd..7e85895b9 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py @@ -75,7 +75,9 @@ def test_export_testcase(): model = Net().to('cpu') x = torch.zeros((1, 1, 28, 28)) - output_dir = _helper(model, x, 'mnist', output_grad=True, metadata=True) + output_dir = _helper( + model, x, 'mnist', output_grad=True, metadata=True, + check_reconstruct=False) assert os.path.isdir(output_dir) assert os.path.isfile(os.path.join(output_dir, 'meta.json')) @@ -159,10 +161,12 @@ def test_model_not_overwrite(): x = torch.zeros((1, 1, 28, 28)) dir_name = 'multiple_test_dataset' - output_dir = _helper(model, x, dir_name) + output_dir = _helper(model, x, dir_name, check_reconstruct=False) assert os.path.isdir(output_dir) - output_dir = _helper(model, x + 0.5, dir_name, model_overwrite=False) + output_dir = _helper( + model, x + 0.5, dir_name, + model_overwrite=False, check_reconstruct=False) test_data_set_dir = os.path.join(output_dir, 'test_data_set_1') assert os.path.isfile(os.path.join(test_data_set_dir, 'input_0.pb')) @@ -355,7 +359,8 @@ def test_export_testcase_options(): output_dir = _helper( model, x, 'mnist_stripped_tensor_data', - opset_version=11, strip_doc_string=False) + opset_version=11, strip_doc_string=False, + check_reconstruct=False) onnx_model = onnx.load(os.path.join( output_dir, 'model.onnx'), load_external_data=False) @@ -396,7 +401,8 @@ def test_export_testcase_with_unused_input(keep_initializers_as_inputs): output_dir = _helper( model, args=(x, unused), d='net_with_unused_input_without_input_names', opset_version=11, strip_doc_string=False, - keep_initializers_as_inputs=keep_initializers_as_inputs) + keep_initializers_as_inputs=keep_initializers_as_inputs, + check_reconstruct=False) assert os.path.isdir(output_dir) test_data_set_dir = os.path.join(output_dir, 'test_data_set_0') assert os.path.exists(os.path.join(test_data_set_dir, 'input_0.pb')) @@ -412,7 +418,7 @@ def test_export_testcase_with_unused_input(keep_initializers_as_inputs): model, args=(x, unused), d='net_with_unused_input_with_input_names', opset_version=11, strip_doc_string=False, keep_initializers_as_inputs=keep_initializers_as_inputs, - input_names=['x', 'unused']) + input_names=['x', 'unused'], check_reconstruct=False) assert os.path.isdir(output_dir) test_data_set_dir = os.path.join(output_dir, 'test_data_set_0') assert os.path.exists(os.path.join(test_data_set_dir, 'input_0.pb')) From c6e167d8b6dd4544c2850b2b29cc081616c2d5ea Mon Sep 17 00:00:00 2001 From: Takeshi Watanabe Date: Wed, 5 Oct 2022 17:23:14 +0900 Subject: [PATCH 27/38] Fix permission of script --- .flexci/linux/build_and_push.sh | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 .flexci/linux/build_and_push.sh diff --git a/.flexci/linux/build_and_push.sh b/.flexci/linux/build_and_push.sh old mode 100644 new mode 100755 From 1afc031ec9dc4d4628778691adc58b0a089afd96 Mon Sep 17 00:00:00 2001 From: Takeshi Watanabe Date: Thu, 6 Oct 2022 14:49:42 +0900 Subject: [PATCH 28/38] Fix initializer name handling --- .../onnx/pfto_exporter/torch_reconstruct.py | 2 +- .../onnx_tests/test_export_testcase.py | 13 ++++++------- 2 files changed, 7 insertions(+), 8 deletions(-) diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py b/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py index fc72d9fc8..f4f70d510 100644 --- a/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py @@ -80,7 +80,7 @@ def reconstruct(model: onnx.ModelProto) -> Tuple[torch._C.Graph, List[Tuple[str, outputs: List[str] = ["%" + o.name.split(".")[-1] for o in model.graph.output] body = "\n ".join(lines) - initializer_name_re = re.compile(r"^%(\w+) [:=]") + initializer_name_re = re.compile(r"^%([\w.]+) [:=]") params: List[Tuple[str, torch.Tensor]] = [] for i in model.graph.initializer: i_name = re.match(initializer_name_re, i.doc_string) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py index 7e85895b9..93b1a26ec 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py @@ -77,7 +77,7 @@ def test_export_testcase(): output_dir = _helper( model, x, 'mnist', output_grad=True, metadata=True, - check_reconstruct=False) + check_reconstruct=True, verbose=False) assert os.path.isdir(output_dir) assert os.path.isfile(os.path.join(output_dir, 'meta.json')) @@ -161,12 +161,12 @@ def test_model_not_overwrite(): x = torch.zeros((1, 1, 28, 28)) dir_name = 'multiple_test_dataset' - output_dir = _helper(model, x, dir_name, check_reconstruct=False) + output_dir = _helper(model, x, dir_name) assert os.path.isdir(output_dir) output_dir = _helper( model, x + 0.5, dir_name, - model_overwrite=False, check_reconstruct=False) + model_overwrite=False) test_data_set_dir = os.path.join(output_dir, 'test_data_set_1') assert os.path.isfile(os.path.join(test_data_set_dir, 'input_0.pb')) @@ -266,7 +266,7 @@ def test_export_testcase_strip_large_tensor_data(): output_dir = _helper( model, x, 'mnist_stripped_tensor_data', output_grad=True, strip_large_tensor_data=True, - metadata=True, check_reconstruct=False) + metadata=True) assert os.path.isdir(output_dir) assert os.path.isfile(os.path.join(output_dir, 'meta.json')) @@ -401,8 +401,7 @@ def test_export_testcase_with_unused_input(keep_initializers_as_inputs): output_dir = _helper( model, args=(x, unused), d='net_with_unused_input_without_input_names', opset_version=11, strip_doc_string=False, - keep_initializers_as_inputs=keep_initializers_as_inputs, - check_reconstruct=False) + keep_initializers_as_inputs=keep_initializers_as_inputs) assert os.path.isdir(output_dir) test_data_set_dir = os.path.join(output_dir, 'test_data_set_0') assert os.path.exists(os.path.join(test_data_set_dir, 'input_0.pb')) @@ -418,7 +417,7 @@ def test_export_testcase_with_unused_input(keep_initializers_as_inputs): model, args=(x, unused), d='net_with_unused_input_with_input_names', opset_version=11, strip_doc_string=False, keep_initializers_as_inputs=keep_initializers_as_inputs, - input_names=['x', 'unused'], check_reconstruct=False) + input_names=['x', 'unused']) assert os.path.isdir(output_dir) test_data_set_dir = os.path.join(output_dir, 'test_data_set_0') assert os.path.exists(os.path.join(test_data_set_dir, 'input_0.pb')) From 2e1c0237a689a13f755c08380aab6e5c8db3af6a Mon Sep 17 00:00:00 2001 From: Takeshi Watanabe Date: Thu, 6 Oct 2022 15:16:46 +0900 Subject: [PATCH 29/38] Skip reconstruct in stripped test --- .../onnx_tests/test_export_testcase.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py index 93b1a26ec..2624f172f 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py @@ -66,7 +66,8 @@ def _helper(model, args, d, use_pfto=True, check_reconstruct=True, **kwargs): kwargs["strip_doc_string"] = False export_testcase(model, args, output_dir, use_pfto=use_pfto, **kwargs) if check_reconstruct and use_pfto and not kwargs["strip_doc_string"]: - reconstruct(onnx.load(os.path.join(output_dir, "model.onnx"))) + reconstruct(pytorch_pfn_extras.onnx.load_model( + os.path.join(output_dir, "model.onnx"))) return output_dir @@ -266,7 +267,7 @@ def test_export_testcase_strip_large_tensor_data(): output_dir = _helper( model, x, 'mnist_stripped_tensor_data', output_grad=True, strip_large_tensor_data=True, - metadata=True) + metadata=True, check_reconstruct=False) assert os.path.isdir(output_dir) assert os.path.isfile(os.path.join(output_dir, 'meta.json')) From f1bed626432d16165b64e8da9c5e25d75439ad14 Mon Sep 17 00:00:00 2001 From: Takeshi Watanabe Date: Thu, 6 Oct 2022 15:40:08 +0900 Subject: [PATCH 30/38] Support unstripping too --- .../onnx/pfto_exporter/torch_reconstruct.py | 9 ++++++++- .../onnx_tests/test_export_testcase.py | 2 +- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py b/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py index f4f70d510..ea5511cf1 100644 --- a/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py @@ -6,6 +6,8 @@ from collections import OrderedDict from typing import List, Set, Tuple +import pytorch_pfn_extras.onnx.unstrip_tensor + _scope_re = re.compile("(.+), scope: ([^ ]+)") _const_vals_re = re.compile(r"value= ([\d\- ]+) \[ \w+Type\{\d+\} \]") @@ -86,7 +88,12 @@ def reconstruct(model: onnx.ModelProto) -> Tuple[torch._C.Graph, List[Tuple[str, i_name = re.match(initializer_name_re, i.doc_string) if i_name: inputs.append(f"%{i_name[1]}") - params.append((i.name, torch.from_numpy(onnx.numpy_helper.to_array(i).copy()))) + + i_u = onnx.TensorProto() + i_u.CopyFrom(i) + pytorch_pfn_extras.onnx.unstrip_tensor._unstrip_tensor(i_u) + t = torch.from_numpy(onnx.numpy_helper.to_array(i_u).copy()) + params.append((i.name, t)) src: str = f"""graph({", ".join(inputs)}): {body} diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py index 2624f172f..755b31c33 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export_testcase.py @@ -267,7 +267,7 @@ def test_export_testcase_strip_large_tensor_data(): output_dir = _helper( model, x, 'mnist_stripped_tensor_data', output_grad=True, strip_large_tensor_data=True, - metadata=True, check_reconstruct=False) + metadata=True) assert os.path.isdir(output_dir) assert os.path.isfile(os.path.join(output_dir, 'meta.json')) From 975a977703ccab0b237b465ceb6d29c4a13e19a3 Mon Sep 17 00:00:00 2001 From: Takeshi Watanabe Date: Thu, 6 Oct 2022 16:01:42 +0900 Subject: [PATCH 31/38] Mark shufflenet not reconstructible --- tests/pytorch_pfn_extras_tests/onnx_tests/test_torchvision.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_torchvision.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_torchvision.py index 183c8e6f7..78f0a2977 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_torchvision.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_torchvision.py @@ -40,5 +40,5 @@ def test_shufflenet(): run_model_test( torchvision.models.shufflenetv2.shufflenet_v2_x1_0(), (torch.rand(1, 3, 224, 224),), - use_gpu=True, + use_gpu=True, check_reconstruct=False ) From 784c378c012cacfb7893141361a55a5baa5fe653 Mon Sep 17 00:00:00 2001 From: twata Date: Wed, 5 Apr 2023 07:17:46 +0000 Subject: [PATCH 32/38] Make mypy happy --- pytorch_pfn_extras/onnx/pfto_exporter/export.py | 6 ++++-- stubs/torch/_C/__init__.pyi | 2 +- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/export.py b/pytorch_pfn_extras/onnx/pfto_exporter/export.py index 80d25e357..3efd2aa88 100644 --- a/pytorch_pfn_extras/onnx/pfto_exporter/export.py +++ b/pytorch_pfn_extras/onnx/pfto_exporter/export.py @@ -600,8 +600,10 @@ def list_added_nodes() -> List[torch._C.Node]: # Place onnx::Identity node instead node when none is added if len(sym_nodes) == 0: - sym_outs = g.op("Identity", sym_outs[0]), - sym_nodes = [sym_outs[0].node()] + sym_out = g.op("Identity", sym_outs[0]) + assert isinstance(sym_out, torch._C.Value) + sym_outs = sym_out, + sym_nodes = [sym_out.node()] self.log(f"Converting node {n.kind()}", n) if len(sym_nodes) > 0: diff --git a/stubs/torch/_C/__init__.pyi b/stubs/torch/_C/__init__.pyi index bfe8ef95a..0fba4d85a 100644 --- a/stubs/torch/_C/__init__.pyi +++ b/stubs/torch/_C/__init__.pyi @@ -384,7 +384,7 @@ def _dump_upgraders_map() -> Dict[str, str]: ... def _test_only_populate_upgraders(content: Dict[str, str]) -> None: ... def _test_only_remove_upgraders(content: Dict[str, str]) -> None: ... def merge_type_from_type_comment(decl: Decl, type_annotation_decl: Decl, is_method: _bool) -> Decl: ... -def parse_ir(input: str, parse_tensor_constants: _bool) -> Graph: ... +def parse_ir(input: str, parse_tensor_constants: _bool = False) -> Graph: ... def parse_schema(schema: str) -> FunctionSchema: ... def get_device(input: Tensor) -> _int: ... From 442401d952336fca429e7046612a70b57da5f31e Mon Sep 17 00:00:00 2001 From: twata Date: Wed, 5 Apr 2023 07:27:05 +0000 Subject: [PATCH 33/38] Use graph context --- pytorch_pfn_extras/onnx/pfto_exporter/export.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/export.py b/pytorch_pfn_extras/onnx/pfto_exporter/export.py index 3efd2aa88..6bfb61bfc 100644 --- a/pytorch_pfn_extras/onnx/pfto_exporter/export.py +++ b/pytorch_pfn_extras/onnx/pfto_exporter/export.py @@ -600,7 +600,7 @@ def list_added_nodes() -> List[torch._C.Node]: # Place onnx::Identity node instead node when none is added if len(sym_nodes) == 0: - sym_out = g.op("Identity", sym_outs[0]) + sym_out = g_ctx.op("Identity", sym_outs[0]) assert isinstance(sym_out, torch._C.Value) sym_outs = sym_out, sym_nodes = [sym_out.node()] From b6bedeabf58a68de26fad526d4b831e5c062e43a Mon Sep 17 00:00:00 2001 From: twata Date: Wed, 5 Apr 2023 07:27:37 +0000 Subject: [PATCH 34/38] Make some tests not supported --- tests/pytorch_pfn_extras_tests/onnx_tests/test_as_output.py | 4 ++-- tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py | 3 ++- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_as_output.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_as_output.py index c748d5e93..ffdfe07f6 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_as_output.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_as_output.py @@ -106,7 +106,7 @@ def forward(self, x): model = Net() x = torch.ones((1, 1, 32, 32)) - output_dir = _helper(model, x, 'as_output') + output_dir = _helper(model, x, 'as_output', check_reconstruct=False) actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx')) named_nodes = {n.name: n for n in actual_onnx.graph.node} @@ -138,7 +138,7 @@ def forward(self, x): model = Net() x = torch.ones((1, 1, 32, 32)) - output_dir = _helper(model, x, 'as_output') + output_dir = _helper(model, x, 'as_output', check_reconstruct=False) actual_onnx = onnx.load(os.path.join(output_dir, 'model.onnx')) named_nodes = {n.name: n for n in actual_onnx.graph.node} diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py index 977b5329b..79cc5b634 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py @@ -226,7 +226,8 @@ def forward(self, x): m = run_model_test( Model(), (torch.randn(2, 7, 17),), skip_oxrt=True, - custom_opsets={"org.chainer": ver}) + custom_opsets={"org.chainer": ver}, + check_reconstruct=False) assert len(m.opset_import) == 2 From b0f7864027646e97c88728f8a023e7b1f52e3b2d Mon Sep 17 00:00:00 2001 From: twata Date: Tue, 11 Apr 2023 04:42:02 +0000 Subject: [PATCH 35/38] Run reconstructed graph --- .../onnx/pfto_exporter/torch_reconstruct.py | 18 ++++++++++++++++-- .../onnx_tests/test_export.py | 3 ++- .../onnx_tests/utils.py | 13 ++++++++++++- 3 files changed, 30 insertions(+), 4 deletions(-) diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py b/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py index ea5511cf1..81c32d5b7 100644 --- a/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py @@ -31,7 +31,14 @@ def _process_line(line: str) -> Tuple[str, str]: if "prim::Constant" in line: line = re.sub(_const_vals_re, lambda m: f"value=[{m[1].replace(' ', ', ')}]", line) line = re.sub(_const_typed_val_re, r"value=\1", line) - line = re.sub(_const_val_re, r"value=\1", line) + if "[] = " in line: + line = re.sub(_const_val_re, r"value=[\1]", line) + else: + line = re.sub(_const_val_re, r"value=\1", line) + + line = line.replace("Bool(device=cpu)", "bool") + line = line.replace("Long(device=cpu)", "int") + line = line.replace("Double(device=cpu)", "float") func_match = re.search(_func_re, line) if func_match: @@ -95,10 +102,17 @@ def reconstruct(model: onnx.ModelProto) -> Tuple[torch._C.Graph, List[Tuple[str, t = torch.from_numpy(onnx.numpy_helper.to_array(i_u).copy()) params.append((i.name, t)) - src: str = f"""graph({", ".join(inputs)}): + if len(outputs) == 1: + src: str = f"""graph({", ".join(inputs)}): {body} return ({", ".join(outputs)}) """ + else: + src: str = f"""graph({", ".join(inputs)}): + {body} + %__out = prim::TupleConstruct({", ".join(outputs)}) + return (%__out) +""" g: torch._C.Graph = torch._C.parse_ir(src) torch._C._jit_pass_lint(g) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py index 79cc5b634..6e14f500a 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py @@ -197,7 +197,8 @@ def forward(self, *hidden): run_model_test( Model(), (torch.randn(2, 7, 17), torch.randn(2, 7, 17)), - skip_oxrt=True, output_names=["a", "b", "c"]) + skip_oxrt=True, output_names=["a", "b", "c"], + check_reconstruct=False) @pytest.mark.filterwarnings("ignore:The shape inference of org.chainer..Add type is missing:UserWarning") diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py b/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py index 7fe3c4baa..d5e435615 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py @@ -30,6 +30,7 @@ def run_model_test( assert mode == "eval" model.eval() + dev = "cpu" if use_gpu and torch.cuda.is_available(): dev = "cuda" model.to(dev) @@ -87,7 +88,17 @@ def run_model_test( assert len(te_model.graph.input) == len(pfto_model.graph.input) if check_reconstruct: - reconstruct(pfto_model) + pt, pt_params = reconstruct(pfto_model) + pt_f = torch._C._create_function_from_graph("forward", pt) + + torch.set_rng_state(rng_state) + pt_res = pt_f(*args, *[p[1].to(dev) for p in pt_params]) + if isinstance(pt_res, torch.Tensor): + pt_res = pt_res, + assert len(pt_res) == len(expected) + for a, e in zip(pt_res, expected): + cmp = torch.isclose(a.cpu(), e.cpu(), rtol=rtol, atol=atol) + assert cmp.all(), f"{cmp.logical_not().count_nonzero()} / {cmp.numel()} values failed" if skip_oxrt: return pfto_model From 61e3b7168bd7b52a7ab4838bfe55110c15895bf7 Mon Sep 17 00:00:00 2001 From: twata Date: Tue, 11 Apr 2023 04:47:08 +0000 Subject: [PATCH 36/38] redef --- pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py b/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py index 81c32d5b7..dbedd199f 100644 --- a/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py @@ -102,13 +102,14 @@ def reconstruct(model: onnx.ModelProto) -> Tuple[torch._C.Graph, List[Tuple[str, t = torch.from_numpy(onnx.numpy_helper.to_array(i_u).copy()) params.append((i.name, t)) + src: str = "" if len(outputs) == 1: - src: str = f"""graph({", ".join(inputs)}): + src = f"""graph({", ".join(inputs)}): {body} return ({", ".join(outputs)}) """ else: - src: str = f"""graph({", ".join(inputs)}): + src = f"""graph({", ".join(inputs)}): {body} %__out = prim::TupleConstruct({", ".join(outputs)}) return (%__out) From 9aaad9968f5d9567f7b1d3bcb6436a9c49e4c7a9 Mon Sep 17 00:00:00 2001 From: twata Date: Fri, 14 Apr 2023 06:05:47 +0000 Subject: [PATCH 37/38] Make tests pass in multiple torch versions --- pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py | 7 ++++++- tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py | 4 +++- 2 files changed, 9 insertions(+), 2 deletions(-) diff --git a/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py b/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py index dbedd199f..77f124041 100644 --- a/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py +++ b/pytorch_pfn_extras/onnx/pfto_exporter/torch_reconstruct.py @@ -30,16 +30,21 @@ def _process_line(line: str) -> Tuple[str, str]: line = line.replace("onnx::SequenceConstruct", "prim::ListConstruct") if "prim::Constant" in line: line = re.sub(_const_vals_re, lambda m: f"value=[{m[1].replace(' ', ', ')}]", line) - line = re.sub(_const_typed_val_re, r"value=\1", line) if "[] = " in line: + line = re.sub(_const_typed_val_re, r"value=[\1]", line) line = re.sub(_const_val_re, r"value=[\1]", line) else: + line = re.sub(_const_typed_val_re, r"value=\1", line) line = re.sub(_const_val_re, r"value=\1", line) line = line.replace("Bool(device=cpu)", "bool") line = line.replace("Long(device=cpu)", "int") line = line.replace("Double(device=cpu)", "float") + line = line.replace("Bool(requires_grad=0, device=cpu)", "bool") + line = line.replace("Long(requires_grad=0, device=cpu)", "int") + line = line.replace("Double(requires_grad=0, device=cpu)", "float") + func_match = re.search(_func_re, line) if func_match: raise ReconstructError(f"torch.autograd.Function call not supported for: {func_match[1]} in line: {line}") diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py index 6e14f500a..10603ab89 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py @@ -2,6 +2,7 @@ import pytest import torch +import pytorch_pfn_extras from pytorch_pfn_extras_tests.onnx_tests.utils import run_model_test @@ -166,7 +167,8 @@ def __init__(self): def forward(self, x): return torch.norm(x) - run_model_test(Net(), (torch.rand(2, 3, 5, 7),), opset_version=13) + check_reconstruct = pytorch_pfn_extras.requires("2.0") + run_model_test(Net(), (torch.rand(2, 3, 5, 7),), opset_version=13, check_reconstruct=check_reconstruct) def test_rand(): From c700340f4721615d47ded7da88de1398964ea8e2 Mon Sep 17 00:00:00 2001 From: twata Date: Tue, 23 May 2023 07:49:20 +0000 Subject: [PATCH 38/38] Fix test failures --- tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py | 1 + tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py | 3 +++ tests/pytorch_pfn_extras_tests/onnx_tests/utils.py | 1 + 3 files changed, 5 insertions(+) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py index 5abbb9e77..78e9a8f7a 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_export.py @@ -356,4 +356,5 @@ def forward(self, x): Proxy(), (x,), do_constant_folding=False, + check_reconstruct=False, ) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py b/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py index 18b9e544d..0660ce22b 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/test_grad.py @@ -99,6 +99,7 @@ def forward(self, x): f"grad_{use_pfto}", enable_onnx_checker=False, use_pfto=use_pfto, + check_reconstruct=False, output_names=["h"], ) @@ -172,6 +173,7 @@ def forward(self, x): f"grad_multi_times_{use_pfto}", enable_onnx_checker=False, use_pfto=use_pfto, + check_reconstruct=False, output_names=["h"], ) @@ -247,6 +249,7 @@ def forward(self, x): f"grad_multi_inputs_{use_pfto}", enable_onnx_checker=False, use_pfto=use_pfto, + check_reconstruct=False, output_names=["h"], ) diff --git a/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py b/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py index 685d1af8d..3b518eb9b 100644 --- a/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py +++ b/tests/pytorch_pfn_extras_tests/onnx_tests/utils.py @@ -68,6 +68,7 @@ def run_model_test( strict_trace=strict_trace, return_output=True, use_pfto=True, + strip_doc_string=False, **kwargs, ) if isinstance(actual, torch.Tensor):