Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[pfto] Reduce redudant copy #751

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 36 additions & 34 deletions pytorch_pfn_extras/onnx/pfto_exporter/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,7 @@ def _type_to_proto(t: torch._C.TensorType) -> onnx.TypeProto:

if t.kind() == "IntType":
ret.tensor_type.elem_type = onnx.TensorProto.DataType.INT64 # type: ignore[attr-defined]
ret.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
ret.tensor_type.shape.SetInParent()
return ret

assert t.kind() == "TensorType", f"Not Tensor type(actual: {t.kind()}): {t}"
Expand All @@ -173,7 +173,7 @@ def _type_to_proto(t: torch._C.TensorType) -> onnx.TypeProto:
sym_hel.cast_pytorch_to_onnx[t.scalarType()] # type: ignore[index]
)

ret.tensor_type.shape.CopyFrom(onnx.TensorShapeProto())
ret.tensor_type.shape.SetInParent()
if t.sizes() is not None:
for s in t.sizes(): # type: ignore
d = ret.tensor_type.shape.dim.add()
Expand Down Expand Up @@ -936,23 +936,29 @@ def assign_onnx_values(
assert len(blocks) == 2
for attr_name, block in zip(["then_branch", "else_branch"], blocks):
sub_g = block2subgraph(f"{new_nd.name}_{attr_name}", block, new_nd.doc_string)
new_nd.attribute.append(onnx.helper.make_attribute(attr_name, sub_g))
attr = new_nd.attribute.add()
attr.name = attr_name
attr.type = onnx.AttributeProto.GRAPH
attr.g.CopyFrom(sub_g)
else:
assert len(list(n.blocks())) == 0, f"Node with block needs to be handled separately: {n}"
if n in self.node_doc_string:
new_nd.doc_string = self.node_doc_string[n]
for attr_name in n.attributeNames():
attr_kind = n.kindOf(attr_name)
if attr_kind == "t":
attr = onnx.helper.make_attribute(attr_name, _tensor_to_proto(n.t(attr_name)))
attr = new_nd.attribute.add()
attr.name = attr_name
attr.type = onnx.AttributeProto.TENSOR
attr.t.CopyFrom(_tensor_to_proto(n.t(attr_name)))
else:
if pytorch_pfn_extras.requires('1.13'):
attr_val = sym_hel._node_get(n, attr_name) # type: ignore[attr-defined]
else:
attr_val = n[attr_name]
# Could not use onnx.helper.make_attribute for
if isinstance(attr_val, list):
attr = onnx.AttributeProto()
attr = new_nd.attribute.add()
attr.name = attr_name
if attr_kind == "ss":
attr.type = onnx.AttributeProto.STRINGS
Expand All @@ -966,8 +972,8 @@ def assign_onnx_values(
else:
assert False, f"'{attr_kind}' typed attribute not supported"
else:
attr = onnx.helper.make_attribute(attr_name, attr_val)
new_nd.attribute.append(attr)
attr = new_nd.attribute.add()
attr.CopyFrom(onnx.helper.make_attribute(attr_name, attr_val))
assign_onnx_values(new_nd.input, new_nd.name, n.inputs())
assign_onnx_values(new_nd.output, new_nd.name, n.outputs())
onnx_nodes.append(new_nd)
Expand Down Expand Up @@ -1001,6 +1007,8 @@ def generate_onnx(self) -> onnx.ModelProto:

self.log("ONNX graph", self.g)

model = onnx.ModelProto()

with record("to_node_proto"):
onnx_nodes, onnx_vars, val_tab = self.generate_proto_nodes(self.g, {}, {})

Expand Down Expand Up @@ -1069,52 +1077,46 @@ def apply_dynamic_axes_info(out: onnx.ValueInfoProto, k: str) -> None:

with record("rename_onnx_vars"):
unique_onnx_vars: Dict[str, onnx.ValueInfoProto] = {}
identities: List[onnx.NodeProto] = []
for onnx_name, ox_v in onnx_vars.items():
if ox_v.name in unique_onnx_vars:
ox_n = onnx.NodeProto()
ox_n = model.graph.node.add()
ox_n.name = f"{val_tab[onnx_name]}_id"
ox_n.op_type = "Identity"
ox_n.input.append(ox_v.name)
ox_n.output.append(val_tab[onnx_name])
identities.append(ox_n)
else:
unique_onnx_vars[ox_v.name] = ox_v
onnx_nodes = identities + onnx_nodes

with record("make_graph"):
graph = onnx.helper.make_graph(
nodes=onnx_nodes,
name=self.traced.original_name,
inputs=onnx_inputs,
outputs=onnx_outputs,
initializer=[v for k, v in unique_onnx_vars.items()],
doc_string=None if self.strip_doc_string else self.graph_doc_string,
# TODO(twata): Use torch IR's value type info
# value_info=[
# self.values[k] for k in set(list(self.values.keys())) - set(inout_names)
# ],
)
graph = model.graph
graph.node.extend(onnx_nodes)
graph.name = self.traced.original_name
graph.input.extend(onnx_inputs)
graph.output.extend(onnx_outputs)
graph.initializer.extend([v for k, v in unique_onnx_vars.items()])
if not self.strip_doc_string:
graph.doc_string = self.graph_doc_string
# TODO(twata): Use torch IR's value type info
# graph.value_info=[
# self.values[k] for k in set(list(self.values.keys())) - set(inout_names)
# ]

self.log("ONNX printable graph", lambda: onnx.helper.printable_graph(graph))

def get_model_opset_imports(graph: onnx.GraphProto) -> List[onnx.OperatorSetIdProto]:
def set_model_opset_imports(model: onnx.ModelProto) -> None:
opsets = {onnx.defs.ONNX_DOMAIN: self.opset_version}
for node in graph.node:
for node in model.graph.node:
if node.domain != onnx.defs.ONNX_DOMAIN:
opsets[node.domain] = self.custom_opsets.get(node.domain, 1)
opset_imports = []
for domain, version in opsets.items():
opset_imports.append(onnx.helper.make_opsetid(domain, version))
return opset_imports
o = model.opset_import.add()
o.domain = domain
o.version = version

with record("make_model"):
model: onnx.ModelProto = onnx.helper.make_model_gen_version(
graph,
opset_imports=get_model_opset_imports(graph),
producer_name="pfto",
ir_version=_fix_ir_version,
)
set_model_opset_imports(model)
model.producer_name = "pfto"
model.ir_version = _fix_ir_version
with record("pfto.check_model"):
model = self.check_model(model)

Expand Down