Skip to content

Commit

Permalink
Update TorchScript server (#990)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Nov 24, 2022
1 parent 285d32b commit 65f82c4
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 66 deletions.
143 changes: 80 additions & 63 deletions source/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,11 @@ def to_json(self): # pylint: disable=missing-function-docstring,too-many-locals,
[ torch.int32, 'int32'], # pylint: disable=no-member
[ torch.int64, 'int64'], # pylint: disable=no-member
])
def constant_value(node):
if node.hasAttribute('value'):
selector = node.kindOf('value')
return getattr(node, selector)('value')
return None
arguments_map = {}
def argument(value):
if not value in arguments_map:
Expand All @@ -76,7 +81,8 @@ def argument(value):
node = value.node()
if node.kind() == "prim::GetAttr":
tensor, name = self._getattr(node)
if tensor is not None and len(name) > 0:
if tensor is not None and len(name) > 0 and \
isinstance(tensor, torch.Tensor):
json_argument['name'] = name
json_argument['initializer'] = {}
json_tensor_shape = {
Expand All @@ -86,6 +92,17 @@ def argument(value):
'dataType': data_type_map[tensor.dtype],
'shape': json_tensor_shape
}
elif node.kind() == "prim::Constant":
tensor = constant_value(node)
if tensor and isinstance(tensor, torch.Tensor):
json_argument['initializer'] = {}
json_tensor_shape = {
'dimensions': list(tensor.shape)
}
json_argument['type'] = {
'dataType': data_type_map[tensor.dtype],
'shape': json_tensor_shape
}
elif value.isCompleteTensor():
json_tensor_shape = {
'dimensions': value.type().sizes()
Expand All @@ -110,59 +127,35 @@ def argument(value):
'name': value.debugName(),
'arguments': [ argument(value) ]
})
scopes = {}
constants = {}
for node in graph.nodes():
if node.kind() == 'prim::Constant':
attributes = node.attributeNames()
if len(attributes) == 1 and attributes[0] == 'value':
selector = node.kindOf('value')
value = getattr(node, selector)('value')
if not isinstance(value, torch.Tensor):
constants[node] = 0
constants[node] = 0

lists = {}
for node in graph.nodes():
if node.kind() == 'prim::ListConstruct':
if all(_.node() in constants for _ in node.inputs()):
for _ in node.inputs():
constants[_.node()] += 1
lists[node] = 0
for _ in graph.nodes():
if _ in lists:
continue
if _ in constants:
continue
if _.kind() == 'prim::GetAttr':
name = _.s('name')
key = _.output().debugName()
parent = _.input().node()
if parent.kind() == 'prim::GetAttr':
# parent_name = parent.s("name")
parent_key = parent.output().debugName()
parent_scope = scopes[parent_key]
scope = parent_scope.split('/')[-1]
scopes[key] = parent_scope + '/' + scope + '.' + name
else:
scopes[key] = '__module.' + name
if _.output().type().kind() != 'ClassType':
# node_py = NodePyOP(node)
# node_py.scopeName = scopes[key]
# nodes_py.append(node_py)
pass
continue
schema = _.schema() if hasattr(_, 'schema') else None

def create_node(node):
schema = node.schema() if hasattr(node, 'schema') else None
schema = self.metadata.type(schema) if schema and schema != '(no schema)' else None
json_node = {
'type': {
'name': _.kind(),
'name': node.kind(),
'category': schema['category'] if schema and 'category' in schema else ''
},
'inputs': [],
'outputs': [],
'attributes': []
}
json_graph['nodes'].append(json_node)
for name in _.attributeNames():
selector = _.kindOf(name)
value = getattr(_, selector)(name)
for name in node.attributeNames():
selector = node.kindOf(name)
value = getattr(node, selector)(name)
json_attribute = {
'name': name,
'value': value
Expand All @@ -174,46 +167,70 @@ def argument(value):
})
else:
json_node['attributes'].append(json_attribute)
def constant_value(node):
selector = node.kindOf('value')
return getattr(node, selector)('value')
for i, value in enumerate(_.inputs()):

for i, value in enumerate(node.inputs()):
parameter = schema['inputs'][i] if schema and i < len(schema['inputs']) else None
name = parameter['name'] if parameter and 'name' in parameter else 'input'
node = value.node()
if node in constants:
json_attribute = {
'name': name,
'value': constant_value(node)
}
json_node['attributes'].append(json_attribute)
constants[node] = constants[node] + 1
parameter_name = parameter['name'] if parameter and 'name' in parameter else 'input'
parameter_type = parameter['type'] if parameter and 'type' in parameter else None
input_node = value.node()
if input_node in constants:
if parameter_type == 'Tensor' or value.type().kind() == 'TensorType':
json_node['inputs'].append({
'name': parameter_name,
'arguments': [ argument(value) ]
})
else:
json_attribute = {
'name': parameter_name,
'value': constant_value(input_node)
}
if parameter and 'type' in parameter:
json_attribute['type'] = parameter['type']
json_node['attributes'].append(json_attribute)
constants[input_node] = constants[input_node] + 1
continue
if node in lists:
if input_node in lists:
json_attribute = {
'name': name,
'value': [ constant_value(_.node()) for _ in node.inputs() ]
'name': parameter_name,
'value': [ constant_value(_.node()) for _ in input_node.inputs() ]
}
json_node['attributes'].append(json_attribute)
lists[node] = lists[node] + 1
lists[input_node] += 1
continue
json_parameter = {
'name': name,
if input_node.kind() == 'prim::TupleUnpack':
continue
if input_node.kind() == 'prim::TupleConstruct':
continue
json_node['inputs'].append({
'name': parameter_name,
'arguments': [ argument(value) ]
}
json_node['inputs'].append(json_parameter)
})

for i, output_value in enumerate(_.outputs()):
for i, value in enumerate(node.outputs()):
parameter = schema['outputs'][i] if schema and i < len(schema['outputs']) else None
name = parameter['name'] if parameter and 'name' in parameter else 'output'
json_node['outputs'].append({
'name': name,
'arguments': [ argument(output_value) ]
'arguments': [ argument(value) ]
})
for _ in graph.nodes():
if _.kind() == 'prim::Constant' and _ in constants and constants[_]:
if constants[_] != len(_.output().uses()):
pass

for node in graph.nodes():
if node in lists:
continue
if node in constants:
continue
if node.kind() == 'prim::GetAttr':
continue
create_node(node)

for node in graph.nodes():
if node.kind() == 'prim::Constant' and \
node in constants and constants[node] != len(node.output().uses()):
create_node(node)
if node.kind() == 'prim::ListConstruct' and \
node in lists and lists[node] != len(node.output().uses()):
create_node(node)

return json_graph

class Metadata: # pylint: disable=too-few-public-methods,missing-class-docstring
Expand Down
6 changes: 3 additions & 3 deletions test/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,13 @@ def _test_torchscript():
# model = torchvision.models.alexnet(weights=torchvision.models.AlexNet_Weights.DEFAULT)
# model = torchvision.models.resnet34(weights=torchvision.models.ResNet34_Weights.DEFAULT)
model = torchvision.models.resnet34()
file = os.path.join(test_data_dir, 'pytorch', 'resnet34-333f7ec4.pth')
state_dict = torch.load(file)
state_dict = torch.load(os.path.join(test_data_dir, 'pytorch', 'resnet34-333f7ec4.pth'))
model.load_state_dict(state_dict)
args = torch.zeros([1, 3, 224, 224])
trace = torch.jit.trace(model, args, strict=True)
# graph, _ = torch.jit._get_trace_graph(model, args) # pylint: disable=protected-access
# torch.onnx._optimize_trace(graph, torch.onnx.OperatorExportTypes.ONNX)
trace = torch.jit.trace(model, args, strict=True)
# trace = torch.load(os.path.join(test_data_dir, 'pytorch', 'fasterrcnn_resnet50_fpn.pt'))
torch._C._jit_pass_inline(trace.graph)
# https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/ir/ir.h
netron.serve('resnet34', trace)
Expand Down

0 comments on commit 65f82c4

Please sign in to comment.