Skip to content

Commit

Permalink
Update TorchScript server (#990)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Nov 22, 2022
1 parent 7a5d30b commit 3e3a3e8
Show file tree
Hide file tree
Showing 2 changed files with 120 additions and 40 deletions.
151 changes: 116 additions & 35 deletions source/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,32 @@ def __init__(self, metadata, model):

def to_json(self):
''' Serialize model to JSON message '''
import torch # pylint: disable=import-outside-toplevel
json_model = {
'signature': 'netron:pytorch',
'format': 'TorchScript',
'format': 'TorchScript v' + torch.__version__,
'graphs': [ self.graph.to_json() ]
}
return json_model

class _Graph: # pylint: disable=too-few-public-methods

def __init__(self, metadata, graph):
def __init__(self, metadata, model):
self.metadata = metadata
self.value = graph
self.param = model
self.value = model.graph
self.nodes = []

def to_json(self): # pylint: disable=missing-function-docstring,too-many-locals
def _getattr(self, node):
if node.kind() == 'prim::Param':
return (self.param, '')
if node.kind() == 'prim::GetAttr':
name = node.s('name')
obj, parent = self._getattr(node.input().node())
return (getattr(obj, name), parent + '.' + name if len(parent) > 0 else name)
raise Exception()

def to_json(self): # pylint: disable=missing-function-docstring,too-many-locals,too-many-statements,too-many-branches
import torch # pylint: disable=import-outside-toplevel
graph = self.value
json_graph = {
Expand All @@ -61,52 +72,97 @@ def to_json(self): # pylint: disable=missing-function-docstring,too-many-locals
def argument(value):
if not value in arguments_map:
json_argument = {}
json_argument['name'] = str(value.unique()) + '>' + str(value.node().kind())
if value.isCompleteTensor():
json_argument['name'] = str(value.unique())
node = value.node()
if node.kind() == "prim::GetAttr":
tensor, name = self._getattr(node)
if tensor is not None and len(name) > 0:
json_argument['name'] = name
json_argument['initializer'] = {}
json_tensor_shape = {
'dimensions': list(tensor.shape)
}
json_argument['type'] = {
'dataType': data_type_map[tensor.dtype],
'shape': json_tensor_shape
}
elif value.isCompleteTensor():
json_tensor_shape = {
'dimensions': value.type().sizes()
}
json_argument['type'] = {
'dataType': data_type_map[value.type().dtype()],
'shape': json_tensor_shape
}
if value.node().kind() == "prim::Param":
json_argument['initializer'] = {}
arguments = json_graph['arguments']
arguments_map[value] = len(arguments)
arguments.append(json_argument)
return arguments_map[value]

for _ in graph.inputs():
# if len(_.uses()) == 0:
# continue
json_graph['inputs'].append({
'name': _.debugName(),
'arguments': [ argument(_) ]
})
for _ in graph.outputs():
for value in graph.inputs():
if len(value.uses()) != 0 and value.type().kind() != 'ClassType':
json_graph['inputs'].append({
'name': value.debugName(),
'arguments': [ argument(value) ]
})
for value in graph.outputs():
json_graph['outputs'].append({
'name': _.debugName(),
'arguments': [ argument(_) ]
'name': value.debugName(),
'arguments': [ argument(value) ]
})
scopes = {}
constants = {}
for node in graph.nodes():
if node.kind() == 'prim::Constant':
attributes = node.attributeNames()
if len(attributes) == 1 and attributes[0] == 'value':
selector = node.kindOf('value')
value = getattr(node, selector)('value')
if not isinstance(value, torch.Tensor):
constants[node] = 0
lists = {}
for node in graph.nodes():
# if node.kind() == 'prim::ListConstruct':
# continue
# if node.kind() == 'prim::Constant':
# continue
# if node.kind() == 'prim::GetAttr':
# continue
schema = node.schema() if hasattr(node, 'schema') else None
if node.kind() == 'prim::ListConstruct':
if all(_.node() in constants for _ in node.inputs()):
lists[node] = 0
for _ in graph.nodes():
if _ in lists:
continue
if _ in constants:
continue
if _.kind() == 'prim::GetAttr':
name = _.s('name')
key = _.output().debugName()
parent = _.input().node()
if parent.kind() == 'prim::GetAttr':
# parent_name = parent.s("name")
parent_key = parent.output().debugName()
parent_scope = scopes[parent_key]
scope = parent_scope.split('/')[-1]
scopes[key] = parent_scope + '/' + scope + '.' + name
else:
scopes[key] = '__module.' + name
if _.output().type().kind() != 'ClassType':
# node_py = NodePyOP(node)
# node_py.scopeName = scopes[key]
# nodes_py.append(node_py)
pass
continue
schema = _.schema() if hasattr(_, 'schema') else None
schema = self.metadata.type(schema) if schema and schema != '(no schema)' else None
json_node = {
'type': { 'name': node.kind() },
'type': {
'name': _.kind(),
'category': schema['category'] if schema and 'category' in schema else ''
},
'inputs': [],
'outputs': [],
'attributes': []
}
json_graph['nodes'].append(json_node)
for name in node.attributeNames():
value = node[name]
for name in _.attributeNames():
selector = _.kindOf(name)
value = getattr(_, selector)(name)
json_attribute = {
'name': name,
'value': value
Expand All @@ -118,21 +174,46 @@ def argument(value):
})
else:
json_node['attributes'].append(json_attribute)

for i, input_value in enumerate(node.inputs()):
input_schema = schema['inputs'][i] if schema and i < len(schema['inputs']) else None
name = input_schema['name'] if hasattr(input_schema, 'name') else 'input'
def constant_value(node):
selector = node.kindOf('value')
return getattr(node, selector)('value')
for i, value in enumerate(_.inputs()):
parameter = schema['inputs'][i] if schema and i < len(schema['inputs']) else None
name = parameter['name'] if parameter and 'name' in parameter else 'input'
node = value.node()
if node in constants:
json_attribute = {
'name': name,
'value': constant_value(node)
}
json_node['attributes'].append(json_attribute)
constants[node] = constants[node] + 1
continue
if node in lists:
json_attribute = {
'name': name,
'value': [ constant_value(_.node()) for _ in node.inputs() ]
}
json_node['attributes'].append(json_attribute)
lists[node] = lists[node] + 1
continue
json_parameter = {
'name': name,
'arguments': [ argument(input_value) ]
'arguments': [ argument(value) ]
}
json_node['inputs'].append(json_parameter)

for output_value in node.outputs():
for i, output_value in enumerate(_.outputs()):
parameter = schema['outputs'][i] if schema and i < len(schema['outputs']) else None
name = parameter['name'] if parameter and 'name' in parameter else 'output'
json_node['outputs'].append({
'name': 'x',
'name': name,
'arguments': [ argument(output_value) ]
})
for _ in graph.nodes():
if _.kind() == 'prim::Constant' and _ in constants and constants[_]:
if constants[_] != len(_.output().uses()):
pass
return json_graph

class Metadata: # pylint: disable=too-few-public-methods,missing-class-docstring
Expand Down
9 changes: 4 additions & 5 deletions test/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,10 @@ def _test_torchscript():
# graph, _ = torch.jit._get_trace_graph(model, args) # pylint: disable=protected-access
# torch.onnx._optimize_trace(graph, torch.onnx.OperatorExportTypes.ONNX)
trace = torch.jit.trace(model, args, strict=True)
graph = trace.graph
torch._C._jit_pass_inline(graph)
torch._C._jit_pass_inline(trace.graph)
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir/ir.h
netron.serve('resnet34', graph)
netron.serve('resnet34', trace)

# _test_onnx()
# _test_torchscript()
_test_onnx_list()
_test_torchscript()
# _test_onnx_list()

0 comments on commit 3e3a3e8

Please sign in to comment.