diff --git a/smaug/python/graph.py b/smaug/python/graph.py index f1aa449a..59367c7d 100644 --- a/smaug/python/graph.py +++ b/smaug/python/graph.py @@ -7,23 +7,23 @@ from smaug.core import types_pb2 from smaug.core import tensor_pb2 from smaug.python import global_vars +from smaug.python.node import Node from smaug.python.tensor import Tensor class Graph: def __init__( self, name="DefaultGraph", backend="Reference", mem_policy=types_pb2.AllDma): - assert (backend in global_vars.backend_alignment) - self.graph = graph_pb2.GraphProto() + if backend not in global_vars.backend_alignment: + raise ValueError("An unknown backend %s is used!" % backend) + self._name = name + self._backend = backend + self._mem_policy = mem_policy + self._nodes = [] self._node_names = {} - self.graph.name = name - self.graph.backend = backend - self.graph.mem_policy = mem_policy - self.alignment = global_vars.backend_alignment[backend] + self._alignment = global_vars.backend_alignment[self._backend] # Layout transformation is enabled by default. self._layout_trans_enabled = True - # This proto stores all the parameters in the network. - self.tensor_data_array = tensor_pb2.TensorDataArray() self._parent_graph = None def __enter__(self): @@ -39,11 +39,15 @@ def __exit__(self, *args): @property def backend(self): - return self.graph.backend + return self._backend @property def mem_policy(self): - return self.graph.mem_policy + return self._mem_policy + + @property + def alignment(self): + return self._alignment @property def layout_trans_enabled(self): @@ -51,8 +55,12 @@ def layout_trans_enabled(self): def merge(self, other): """Merge another graph into this.""" + for node in other.get_nodes(): + if self.get_node(node.name) is not None: + raise ValueError( + "The graph to be merged contains a node with the same name as one " + "in the current graph. Possibly merging a graph more than once?") self.get_nodes().extend(other.get_nodes()) - self.tensor_data_array.data_array.extend(other.tensor_data_array.data_array) def add_node( self, name, op, input_tensors, output_tensors_dims, @@ -75,24 +83,16 @@ def add_node( Returns: The output tensor of the added node. """ - node = self.graph.nodes.add() - node.name = self.create_unique_name(name) - node.op = op - - # Add the parameters to the node. - if params != None: - node.params.CopyFrom(params) + name = self.create_unique_name(name) + node = Node(name, op, params) + self._nodes.append(node) - # Update the node's parents field, and add every input tensor to the node. + # Add every input tensor to the node. for i,tensor in enumerate(input_tensors): if tensor.name == None: tensor.name = node.name + "/input%d" % i - if tensor.source is not None: - node.parents.append(tensor.source[0].name) - node.src_tensors_indices.append(tensor.source[1]) + node.add_input(tensor) tensor.targets.append(node) - input_tensor_proto = node.input_tensors.add() - tensor.to_tensor_proto(input_tensor_proto, self.tensor_data_array) # Create the output tensor (with the node as its source), and add it to the # node. 
@@ -101,33 +101,33 @@ def add_node( output_tensor = Tensor( dims=d, name="%s/output%d" % (node.name, i), data_layout=output_tensor_layout, data_type=output_tensor_dtype, - data_format=output_tensor_dformat, source=(node, i), - alignment=self.alignment) - output_tensor_proto = node.output_tensors.add() - output_tensor.to_tensor_proto(output_tensor_proto, self.tensor_data_array) + data_format=output_tensor_dformat, source=node, source_index=i, + alignment=self._alignment) + node.add_output(output_tensor) output_tensors.append(output_tensor) return output_tensors def get_node(self, node_name, recursive=False): - """Return a node in the graph proto by its name. + """Return a node in the graph by its name. Args: node_name: Node name. recursive: If true, recursively search the node in the parent graphs. Returns: - A NodeProto if we find the node. + A `Node` if the node is found, otherwise None. - for i in range(len(self.graph.nodes)): - if self.graph.nodes[i].name == node_name: - return self.graph.nodes[i] + for node in self._nodes: + if node.name == node_name: + return node if recursive and self._parent_graph is not None: return self._parent_graph.get_node(node_name, True) + return None def get_nodes(self): - """Return nodes in the graph proto.""" - return self.graph.nodes + """Return nodes in the graph.""" + return self._nodes def get_root_graph(self): """Return the root graph.""" @@ -167,6 +167,21 @@ def enable_layout_transform(self): """Enable automatic layout transformation.""" self._layout_trans_enabled = True + def to_proto(self): + """Serialize the graph. + + Returns: + A tuple of (`GraphProto`, `TensorDataArray`). + """ + graph_proto = graph_pb2.GraphProto() + graph_proto.name = self._name + graph_proto.backend = self._backend + graph_proto.mem_policy = self._mem_policy + tensor_data_array = tensor_pb2.TensorDataArray() + for node in self._nodes: + graph_proto.nodes.append(node.to_proto(tensor_data_array)) + return graph_proto, tensor_data_array + def write_graph(self, name=None): """Serialize the graph to a protobuf file. @@ -174,12 +189,14 @@ def write_graph(self, name=None): name: Name of the output protobuf file. If not specified, use the graph's name instead. """ - if name == None: - topo_name = self.graph.name + "_topo.pbtxt" - params_name = self.graph.name + "_params.pb" + graph_proto, tensor_data_array = self.to_proto() + if name is None: + name = self._name + topo_name = name + "_topo.pbtxt" + params_name = name + "_params.pb" with open(topo_name, "w") as f_topo, open(params_name, "wb") as f_params: - f_topo.write(text_format.MessageToString(self.graph)) - f_params.write(self.tensor_data_array.SerializeToString()) + f_topo.write(text_format.MessageToString(graph_proto)) + f_params.write(tensor_data_array.SerializeToString()) def print_summary(self): """Print the summary of the graph. @@ -189,28 +206,45 @@ def print_summary(self): input/output tensors. """ print("=================================================================") - print(" Summary of the network: %s (%s)" % (self.graph.name, - self.graph.backend)) + print(" Summary of the network: %s (%s)" % (self._name, self._backend)) print("=================================================================") print( "Host memory access policy: %s."
% - types_pb2.HostMemoryAccessPolicy.Name(self.graph.mem_policy)) + types_pb2.HostMemoryAccessPolicy.Name(self._mem_policy)) print("-----------------------------------------------------------------") - for node in self.graph.nodes: + for node in self._nodes: print("Name: %s (%s)" % (node.name, types_pb2.OpType.Name(node.op))) print("Parents:", end = '') - for i in node.parents: + for i in node.get_parents(): + print(i, end = ' ') + print("\nChildren:", end = '') + for i in node.get_children(): print(i, end = ' ') print("\nInput tensors:") - for t in node.input_tensors: + for t in node.inputs: print( " ", t.name, types_pb2.DataType.Name(t.data_type), t.shape.dims, types_pb2.DataLayout.Name(t.shape.layout), "alignment(%d)" % t.shape.alignment) print("Output tensors:") - for t in node.output_tensors: + for t in node.outputs: print( " ", t.name, types_pb2.DataType.Name(t.data_type), t.shape.dims, types_pb2.DataLayout.Name(t.shape.layout), "alignment(%d)" % t.shape.alignment) print("-----------------------------------------------------------------") + +def get_node_proto(graph_proto, node_name): + """Get a `NodeProto` from `GraphProto` by node name. + + Args: + graph_proto: A `GraphProto`. + node_name: Name of the node. + + Returns: + A `NodeProto` or None. + """ + for node_proto in graph_proto.nodes: + if node_proto.name == node_name: + return node_proto + return None diff --git a/smaug/python/node.py b/smaug/python/node.py new file mode 100644 index 00000000..0e0f04e3 --- /dev/null +++ b/smaug/python/node.py @@ -0,0 +1,122 @@ +import sys +import numpy as np + +from smaug.core import node_pb2 +from smaug.core import types_pb2 +from smaug.python import global_vars +from smaug.python import datatypes + +class Node: + def __init__(self, name, op, params=None, inputs=None, outputs=None): + """Create a node. + + A `Node` instance contains information about its corresponding operation, + including the operator type, parameters and input/output tensors. A `Graph` + is made up of `Node`s. When serialized, a `NodeProto` is created. + + Args: + name: Name of the node. + op: `OpType` representing the operation type of the node. + params: `Params` used by the operator (optional). + inputs: A list of `Tensor` (optional). + outputs: A list of `Tensor` (optional). + + Returns: + A `Node` instance. + """ + self._name = name + self._op = op + self._params = params + self._inputs = [] if inputs is None else inputs + self._outputs = [] if outputs is None else outputs + + @property + def name(self): + return self._name + + @property + def op(self): + return self._op + + @property + def inputs(self): + return self._inputs + + @property + def outputs(self): + return self._outputs + + def add_input(self, tensor): + """Add an input tensor to the node. + + Args: + tensor: A `Tensor`. + """ + self._inputs.append(tensor) + + def add_output(self, tensor): + """Add an output tensor to the node. + + Args: + tensor: A `Tensor`. + """ + self._outputs.append(tensor) + + def update_input(self, tensor, index): + """Update the `index`th input with `tensor`. + + Args: + tensor: A `Tensor` representing the new input. + index: The input index. + """ + self._inputs[index] = tensor + + def get_parents(self): + """Get the parents of the node. + + Returns: + A list of strings representing names of the parent nodes. + """ + parents = [] + for tensor in self._inputs: + if tensor.source is not None: + parents.append(tensor.source.name) + return parents + + def get_children(self): + """Get the children of the node. 
+ + Returns: + A list of strings representing names of the children nodes. + """ + children = [] + for tensor in self._outputs: + for target in tensor.targets: + children.append(target.name) + return children + + def to_proto(self, tensor_data_array): + """Serialize `Node` into `NodeProto`. + + Args: + tensor_data_array: `TensorDataArray` that tensor data gets serialized + into. + + Returns: + A `NodeProto`. + """ + node_proto = node_pb2.NodeProto() + node_proto.name = self._name + node_proto.op = self._op + if self._params is not None: + node_proto.params.CopyFrom(self._params) + for tensor in self._inputs: + if tensor.source is not None: + node_proto.parents.append(tensor.source.name) + node_proto.src_tensors_indices.append(tensor.source_index) + tensor_proto = node_proto.input_tensors.add() + tensor.to_tensor_proto(tensor_proto, tensor_data_array) + for tensor in self._outputs: + tensor_proto = node_proto.output_tensors.add() + tensor.to_tensor_proto(tensor_proto, tensor_data_array) + return node_proto diff --git a/smaug/python/ops/activation_ops_test.py b/smaug/python/ops/activation_ops_test.py index 4975bf85..526c1412 100755 --- a/smaug/python/ops/activation_ops_test.py +++ b/smaug/python/ops/activation_ops_test.py @@ -5,7 +5,7 @@ import unittest import numpy as np -from smaug.python.graph import Graph +from smaug.python.graph import Graph, get_node_proto from smaug.python.tensor import Tensor from smaug.python.ops import data_op from smaug.python.ops import activation_ops @@ -19,7 +19,8 @@ def __init__(self, *args, **kwargs): self.tensor_shape = [2, 32, 32, 32] def do_basic_test(self, test_graph, node_name, op_type, tensor_shape=None): - node = test_graph.get_node(node_name) + graph_proto, _ = test_graph.to_proto() + node = get_node_proto(graph_proto, node_name) if tensor_shape == None: tensor_shape = self.tensor_shape self.assertEqual(node.op, op_type) @@ -129,44 +130,45 @@ def test_fusing_activation_functions(self): activation=types_pb2.ReLU, name="bn_relu") act = nn_ops.mat_mul( act, weight_tensor, activation=types_pb2.ReLU, name="fc_relu") + graph_proto, _ = test_graph.to_proto() # None - node = test_graph.get_node("conv_none") + node = get_node_proto(graph_proto, "conv_none") self.assertEqual(node.params.act_params.activation, types_pb2.UnknownOp) # ReLU - node = test_graph.get_node("conv_relu") + node = get_node_proto(graph_proto, "conv_relu") self.assertEqual(node.params.act_params.activation, types_pb2.ReLU) # LReLU (default slope = 0.2) - node = test_graph.get_node("conv_lrelu") + node = get_node_proto(graph_proto, "conv_lrelu") self.assertEqual(node.params.act_params.activation, types_pb2.LReLU) self.assertAlmostEqual(node.params.act_params.lrelu_params.slope, 0.2) # ELU (default alpha = 0.1) - node = test_graph.get_node("conv_elu") + node = get_node_proto(graph_proto, "conv_elu") self.assertEqual(node.params.act_params.activation, types_pb2.ELU) self.assertAlmostEqual(node.params.act_params.elu_params.alpha, 0.1) # SELU (default alpha = 1.6733, lambda = 1.0507) - node = test_graph.get_node("conv_selu") + node = get_node_proto(graph_proto, "conv_selu") self.assertEqual(node.params.act_params.activation, types_pb2.SELU) self.assertAlmostEqual(node.params.act_params.elu_params.alpha, 1.6733, 5) self.assertAlmostEqual(node.params.act_params.elu_params.lambda_param, 1.0507, 5) # Tanh - node = test_graph.get_node("conv_tanh") + node = get_node_proto(graph_proto, "conv_tanh") self.assertEqual(node.params.act_params.activation, types_pb2.Tanh) # HardTanh (default min = -1, max = 
1) - node = test_graph.get_node("conv_hard_tanh") + node = get_node_proto(graph_proto, "conv_hard_tanh") self.assertEqual(node.params.act_params.activation, types_pb2.HardTanh) self.assertAlmostEqual(node.params.act_params.hard_tanh_params.min, -1) self.assertAlmostEqual(node.params.act_params.hard_tanh_params.max, 1) # Sigmoid - node = test_graph.get_node("conv_sigmoid") + node = get_node_proto(graph_proto, "conv_sigmoid") self.assertEqual(node.params.act_params.activation, types_pb2.Sigmoid) # Softmax - node = test_graph.get_node("conv_softmax") + node = get_node_proto(graph_proto, "conv_softmax") self.assertEqual(node.params.act_params.activation, types_pb2.Softmax) # Fusion with inner products and batch norms. - node = test_graph.get_node("bn_relu") + node = get_node_proto(graph_proto, "bn_relu") self.assertEqual(node.params.act_params.activation, types_pb2.ReLU) - node = test_graph.get_node("fc_relu") + node = get_node_proto(graph_proto, "fc_relu") self.assertEqual(node.params.act_params.activation, types_pb2.ReLU) if __name__ == "__main__": diff --git a/smaug/python/ops/common.py b/smaug/python/ops/common.py index 5a2c3d41..1ad469d3 100644 --- a/smaug/python/ops/common.py +++ b/smaug/python/ops/common.py @@ -23,7 +23,7 @@ def check_and_add_layout_transform(name, op, input_tensors): from smaug.python.ops.array_ops import reorder if not global_vars.get_graph().layout_trans_enabled: return input_tensors - backend = global_vars.get_graph().graph.backend + backend = global_vars.get_graph().backend for i in range(len(input_tensors)): expected_layoutset = global_vars.backend_layouts[backend][ op].input_layoutsets[i] diff --git a/smaug/python/ops/control_flow_ops.py b/smaug/python/ops/control_flow_ops.py index 65de4676..a7d5bf0b 100644 --- a/smaug/python/ops/control_flow_ops.py +++ b/smaug/python/ops/control_flow_ops.py @@ -75,28 +75,20 @@ def _insert_switch_nodes(predication, branch_result, graph): # This keeps track of all the tensors that come from nodes in the graph. internal_tensors = set() for node in nodes: - internal_tensors.update( - set([tensor.name for tensor in node.output_tensors])) + internal_tensors.update(set([tensor.name for tensor in node.outputs])) for node in nodes: - for i, tensor_proto in enumerate(node.input_tensors): + for i, tensor in enumerate(node.inputs): # If any input tensor of the graph appear in the graph workspace, then # this tensor is an external to the graph and we create a switch node # for it. # Don't create switch node for an existing one. if node.op == types_pb2.Switch: continue - if tensor_proto.name not in internal_tensors: - source_node = graph.get_node(node.parents[i], True) - tensor = tensor_utils.from_tensor_proto(tensor_proto) - if source_node is not None: - tensor.source = (source_node, node.src_tensors_indices[i]) + if tensor.name not in internal_tensors: switch_result = switch( tensor, predication)[switch_op_output_ports[branch_result]] - # Update the node with the switch node as its new parent. - switch_result.to_tensor_proto(node.input_tensors[i]) - switch_node = switch_result.source[0] - node.parents[i] = switch_node.name - node.src_tensors_indices[i] = switch_op_output_ports[branch_result] + # Update the node's input with the switch node result. 
+ node.update_input(switch_result, i) cur_graph = global_vars.get_graph() backend = cur_graph.backend @@ -110,7 +102,6 @@ def _insert_switch_nodes(predication, branch_result, graph): if not isinstance(res_t, (list, tuple)): res_t = [res_t] _insert_switch_nodes(predication, "true", subgraph_t) - cur_graph.merge(subgraph_t) # Build the subgraph for the false branch. with Graph(name="%s_false_branch" % name, backend=backend, @@ -119,7 +110,6 @@ def _insert_switch_nodes(predication, branch_result, graph): if not isinstance(res_f, (list, tuple)): res_f = [res_f] _insert_switch_nodes(predication, "false", subgraph_f) - cur_graph.merge(subgraph_f) # Add the merge nodes for the outputs. merges = [merge([t, f]) for (t, f) in zip(res_t, res_f)] diff --git a/smaug/python/ops/ops_test.py b/smaug/python/ops/ops_test.py index 02e6d2fb..bc3545c3 100755 --- a/smaug/python/ops/ops_test.py +++ b/smaug/python/ops/ops_test.py @@ -5,7 +5,7 @@ import unittest import numpy as np -from smaug.python.graph import Graph +from smaug.python.graph import Graph, get_node_proto from smaug.python.tensor import Tensor from smaug.python.ops import data_op from smaug.python.ops import activation_ops @@ -45,6 +45,9 @@ def assertEqualDims(self, dims, layout, expected_dims, expected_layout): else: assert False, "Other layouts not expected here!" + def get_node(self, name): + return get_node_proto(self.test_graph, name) + def build_test_sequential_graph(self, backend): """Create a sequential model.""" np_dtype = test_backend_dtypes[backend] @@ -100,7 +103,7 @@ def build_test_sequential_graph(self, backend): out = array_ops.stack(out, 4, 1, "stack") out0, out1, out2, out3 = array_ops.unstack(out, 1, "unstack") - self.test_graph = graph + self.test_graph, _ = graph.to_proto() self.backend = backend self.alignment = global_vars.backend_alignment[backend] @@ -160,16 +163,16 @@ def build_test_residual_graph(self, backend): math_ops.add(out0, out1, "add1"), math_ops.add(out2, out3, "add2"), "mul1") - self.test_graph = graph + self.test_graph, _ = graph.to_proto() self.backend = backend self.alignment = global_vars.backend_alignment[ - self.test_graph.graph.backend] + self.test_graph.backend] class SequentialGraphTest(OperatorTest): """Common tests for the sequential graph.""" def test_input_op(self): - node = self.test_graph.get_node("input") + node = self.get_node("input") self.assertEqual(node.op, types_pb2.Data) self.assertEqual(len(node.input_tensors), 1) self.assertEqual(len(node.output_tensors), 1) @@ -186,7 +189,7 @@ def test_convolution_op(self): expected_output_layout = global_vars.backend_layouts[self.backend][ types_pb2.Convolution3d].output_layoutset.layouts # The first convolution operator "conv0" - node = self.test_graph.get_node("conv0") + node = self.get_node("conv0") self.assertEqual(node.op, types_pb2.Convolution3d) self.assertEqual(len(node.input_tensors), 2) self.assertEqual(len(node.output_tensors), 1) @@ -211,7 +214,7 @@ def test_convolution_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) # The second convolution operator "conv1" - node = self.test_graph.get_node("conv1") + node = self.get_node("conv1") self.assertEqual(node.op, types_pb2.Convolution3d) self.assertEqual(len(node.input_tensors), 2) self.assertEqual(len(node.output_tensors), 1) @@ -237,7 +240,7 @@ def test_convolution_op(self): def test_relu_op(self): # The first relu operator "conv0_relu" - node = self.test_graph.get_node("conv0_relu") + node = self.get_node("conv0_relu") self.assertEqual(node.op, types_pb2.ReLU) 
self.assertEqual(len(node.input_tensors), 1) self.assertEqual(len(node.output_tensors), 1) @@ -252,7 +255,7 @@ def test_relu_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) # The second relu operator "conv1_relu" - node = self.test_graph.get_node("conv1_relu") + node = self.get_node("conv1_relu") self.assertEqual(node.op, types_pb2.ReLU) self.assertEqual(len(node.input_tensors), 1) self.assertEqual(len(node.output_tensors), 1) @@ -267,7 +270,7 @@ def test_relu_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) # The third relu operator "fc0_relu" - node = self.test_graph.get_node("fc0_relu") + node = self.get_node("fc0_relu") self.assertEqual(node.op, types_pb2.ReLU) self.assertEqual(len(node.input_tensors), 1) self.assertEqual(len(node.output_tensors), 1) @@ -280,7 +283,7 @@ def test_relu_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) def test_batch_norm_op(self): - node = self.test_graph.get_node("bn") + node = self.get_node("bn") self.assertEqual(node.op, types_pb2.BatchNorm) self.assertEqual(len(node.input_tensors), 5) self.assertEqual(len(node.output_tensors), 1) @@ -303,7 +306,7 @@ def test_batch_norm_op(self): def test_max_pool_op(self): expected_output_layout = global_vars.backend_layouts[self.backend][ types_pb2.MaxPooling].output_layoutset.layouts - node = self.test_graph.get_node("pool") + node = self.get_node("pool") self.assertEqual(node.op, types_pb2.MaxPooling) self.assertEqual(len(node.input_tensors), 1) self.assertEqual(len(node.output_tensors), 1) @@ -321,7 +324,7 @@ def test_max_pool_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) def test_flatten_op(self): - node = self.test_graph.get_node("flatten") + node = self.get_node("flatten") self.assertEqual(node.op, types_pb2.Reorder) self.assertEqual(len(node.input_tensors), 1) self.assertEqual(len(node.output_tensors), 1) @@ -338,7 +341,7 @@ def test_mat_mul_op(self): expected_output_layout = global_vars.backend_layouts[self.backend][ types_pb2.InnerProduct].output_layoutset.layouts # The first mat_mul operator "fc0" - node = self.test_graph.get_node("fc0") + node = self.get_node("fc0") self.assertEqual(node.op, types_pb2.InnerProduct) self.assertEqual(len(node.input_tensors), 2) self.assertEqual(len(node.output_tensors), 1) @@ -360,7 +363,7 @@ def test_mat_mul_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) # The second mat_mul operator "fc1" - node = self.test_graph.get_node("fc1") + node = self.get_node("fc1") self.assertEqual(node.op, types_pb2.InnerProduct) self.assertEqual(len(node.input_tensors), 2) self.assertEqual(len(node.output_tensors), 1) @@ -382,7 +385,7 @@ def test_mat_mul_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) def test_expand_dims_op(self): - node = self.test_graph.get_node("expand_dims") + node = self.get_node("expand_dims") self.assertEqual(node.op, types_pb2.Reshape) self.assertEqual(len(node.input_tensors), 1) self.assertEqual(len(node.output_tensors), 1) @@ -394,7 +397,7 @@ def test_expand_dims_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) def test_squeeze_op(self): - node = self.test_graph.get_node("squeeze") + node = self.get_node("squeeze") self.assertEqual(node.op, types_pb2.Reshape) self.assertEqual(len(node.input_tensors), 1) self.assertEqual(len(node.output_tensors), 1) @@ -406,7 +409,7 @@ def test_squeeze_op(self): 
self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) def test_reshape_op(self): - node = self.test_graph.get_node("reshape") + node = self.get_node("reshape") self.assertEqual(node.op, types_pb2.Reshape) self.assertEqual(len(node.input_tensors), 1) self.assertEqual(len(node.output_tensors), 1) @@ -418,7 +421,7 @@ def test_reshape_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) def test_repeat_op(self): - node = self.test_graph.get_node("repeat") + node = self.get_node("repeat") self.assertEqual(node.op, types_pb2.Repeat) self.assertEqual(len(node.input_tensors), 1) self.assertEqual(len(node.output_tensors), 1) @@ -432,7 +435,7 @@ def test_repeat_op(self): def test_stack_op(self): # stack op is implemented using expand_dims and repeat. Here we only test # the output. - node = self.test_graph.get_node("stack:repeat") + node = self.get_node("stack:repeat") self.assertEqual(node.output_tensors[0].data_type, self.expected_dtype) self.assertEqual(node.output_tensors[0].shape.dims, [8, 4, 10]) self.assertEqual(node.output_tensors[0].shape.layout, types_pb2.NTC) @@ -442,7 +445,7 @@ def test_unstack_op(self): # unstack op is implemented using split and squeeze. Here we only test # the output. for i in range(4): - node = self.test_graph.get_node("unstack:squeeze" + + node = self.get_node("unstack:squeeze" + ("_%d" % i if i > 0 else "")) self.assertEqual(node.output_tensors[0].data_type, self.expected_dtype) self.assertEqual(node.output_tensors[0].shape.dims, [8, 10]) @@ -453,7 +456,7 @@ class ResidualGraphTest(OperatorTest): """Common tests for the residual graph.""" def test_input_op(self): - node = self.test_graph.get_node("input") + node = self.get_node("input") self.assertEqual(node.op, types_pb2.Data) self.assertEqual(len(node.input_tensors), 1) self.assertEqual(len(node.output_tensors), 1) @@ -470,7 +473,7 @@ def test_convolution_op(self): expected_output_layout = global_vars.backend_layouts[self.backend][ types_pb2.Convolution3d].output_layoutset.layouts # The first convolution operator "conv0" - node = self.test_graph.get_node("conv0") + node = self.get_node("conv0") self.assertEqual(node.op, types_pb2.Convolution3d) self.assertEqual(len(node.input_tensors), 2) self.assertEqual(len(node.output_tensors), 1) @@ -495,7 +498,7 @@ def test_convolution_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) # The second convolution operator "conv1" - node = self.test_graph.get_node("conv1") + node = self.get_node("conv1") self.assertEqual(node.op, types_pb2.Convolution3d) self.assertEqual(len(node.input_tensors), 2) self.assertEqual(len(node.output_tensors), 1) @@ -520,7 +523,7 @@ def test_convolution_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) # The third convolution operator "conv2" - node = self.test_graph.get_node("conv2") + node = self.get_node("conv2") self.assertEqual(node.op, types_pb2.Convolution3d) self.assertEqual(len(node.input_tensors), 2) self.assertEqual(len(node.output_tensors), 1) @@ -545,7 +548,7 @@ def test_convolution_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) def test_relu_op(self): - node = self.test_graph.get_node("relu") + node = self.get_node("relu") self.assertEqual(node.op, types_pb2.ReLU) self.assertEqual(len(node.input_tensors), 1) self.assertEqual(len(node.output_tensors), 1) @@ -560,7 +563,7 @@ def test_relu_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) def 
test_batch_norm_op(self): - node = self.test_graph.get_node("bn") + node = self.get_node("bn") self.assertEqual(node.op, types_pb2.BatchNorm) self.assertEqual(len(node.input_tensors), 5) self.assertEqual(len(node.output_tensors), 1) @@ -582,7 +585,7 @@ def test_batch_norm_op(self): def test_add_op(self): # The first add operator (add) - node = self.test_graph.get_node("add") + node = self.get_node("add") self.assertEqual(node.op, types_pb2.EltwiseAdd) self.assertEqual(len(node.input_tensors), 2) self.assertEqual(len(node.output_tensors), 1) @@ -597,7 +600,7 @@ def test_add_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) # The second add operator (add1) - node = self.test_graph.get_node("add1") + node = self.get_node("add1") self.assertEqual(node.op, types_pb2.EltwiseAdd) self.assertEqual(len(node.input_tensors), 2) self.assertEqual(len(node.output_tensors), 1) @@ -612,7 +615,7 @@ def test_add_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) # The third add operator (add2) - node = self.test_graph.get_node("add2") + node = self.get_node("add2") self.assertEqual(node.op, types_pb2.EltwiseAdd) self.assertEqual(len(node.input_tensors), 2) self.assertEqual(len(node.output_tensors), 1) @@ -628,7 +631,7 @@ def test_add_op(self): def test_mul_op(self): # The first mul operator (mul) - node = self.test_graph.get_node("mul") + node = self.get_node("mul") self.assertEqual(node.op, types_pb2.EltwiseMul) self.assertEqual(len(node.input_tensors), 2) self.assertEqual(len(node.output_tensors), 1) @@ -643,7 +646,7 @@ def test_mul_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) # The second add operator (mul1) - node = self.test_graph.get_node("mul1") + node = self.get_node("mul1") self.assertEqual(node.op, types_pb2.EltwiseMul) self.assertEqual(len(node.input_tensors), 2) self.assertEqual(len(node.output_tensors), 1) @@ -658,7 +661,7 @@ def test_mul_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) def test_concat_op(self): - node = self.test_graph.get_node("concat") + node = self.get_node("concat") self.assertEqual(node.op, types_pb2.Concat) self.assertEqual(len(node.input_tensors), 2) self.assertEqual(len(node.output_tensors), 1) @@ -673,7 +676,7 @@ def test_concat_op(self): self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment) def test_split_op(self): - node = self.test_graph.get_node("split") + node = self.get_node("split") self.assertEqual(node.op, types_pb2.Split) self.assertEqual(len(node.input_tensors), 1) self.assertEqual(len(node.output_tensors), 4) @@ -702,64 +705,64 @@ def test_parent_children(self): between layers, so we delete this test from the above common tests. """ # input (Data). - node = self.test_graph.get_node("input") + node = self.get_node("input") self.assertEqual(len(node.parents), 0) # Reorder input from NCHW to NHWC. - node = self.test_graph.get_node("reorder") + node = self.get_node("reorder") self.assertEqual(node.parents[0], "input") # conv0 (Convolution). - node = self.test_graph.get_node("conv0") + node = self.get_node("conv0") self.assertEqual(node.parents[0], "reorder") # conv0_relu (ReLU). - node = self.test_graph.get_node("conv0_relu") + node = self.get_node("conv0_relu") self.assertEqual(node.parents[0], "conv0") # bn (BN). - node = self.test_graph.get_node("bn") + node = self.get_node("bn") self.assertEqual(node.parents[0], "conv0_relu") # conv1 (Convolution). 
- node = self.test_graph.get_node("conv1") + node = self.get_node("conv1") self.assertEqual(node.parents[0], "bn") # conv1_relu (ReLU). - node = self.test_graph.get_node("conv1_relu") + node = self.get_node("conv1_relu") self.assertEqual(node.parents[0], "conv1") # pool (MaxPooling). - node = self.test_graph.get_node("pool") + node = self.get_node("pool") self.assertEqual(node.parents[0], "conv1_relu") # flatten (Flatten). - node = self.test_graph.get_node("flatten") + node = self.get_node("flatten") self.assertEqual(node.parents[0], "pool") # fc0 (FC). - node = self.test_graph.get_node("fc0") + node = self.get_node("fc0") self.assertEqual(node.parents[0], "flatten") # fc0_relu (ReLU) - node = self.test_graph.get_node("fc0_relu") + node = self.get_node("fc0_relu") self.assertEqual(node.parents[0], "fc0") # fc1 (FC). - node = self.test_graph.get_node("fc1") + node = self.get_node("fc1") self.assertEqual(node.parents[0], "fc0_relu") # expand_dims (Reshape). - node = self.test_graph.get_node("expand_dims") + node = self.get_node("expand_dims") self.assertEqual(node.parents[0], "fc1") # squeeze (Reshape). - node = self.test_graph.get_node("squeeze") + node = self.get_node("squeeze") self.assertEqual(node.parents[0], "expand_dims") # reshape (Reshape). - node = self.test_graph.get_node("reshape") + node = self.get_node("reshape") self.assertEqual(node.parents[0], "squeeze") # repeat (Repeat). - node = self.test_graph.get_node("repeat") + node = self.get_node("repeat") self.assertEqual(node.parents[0], "reshape") # stack (Reshape + Repeat). - node = self.test_graph.get_node("stack:expand_dims") + node = self.get_node("stack:expand_dims") self.assertEqual(node.parents[0], "repeat") - node = self.test_graph.get_node("stack:repeat") + node = self.get_node("stack:repeat") self.assertEqual(node.parents[0], "stack:expand_dims") # unstack (Split + Squeeze). - node = self.test_graph.get_node("unstack:split") + node = self.get_node("unstack:split") self.assertEqual(node.parents[0], "stack:repeat") for i in range(4): child_name = "unstack:squeeze" + ("_%d" % i if i > 0 else "") - child_node = self.test_graph.get_node(child_name) + child_node = self.get_node(child_name) self.assertEqual(child_node.parents[0], "unstack:split") self.assertEqual(child_node.src_tensors_indices, [i]) @@ -778,65 +781,65 @@ def test_parent_children(self): """Test the parent/child relationship in the graph.""" # input (Data). - node = self.test_graph.get_node("input") + node = self.get_node("input") self.assertEqual(len(node.parents), 0) # conv0 (Convolution). - node = self.test_graph.get_node("conv0") + node = self.get_node("conv0") self.assertEqual(node.parents[0], "input") # conv0_relu (ReLU). - node = self.test_graph.get_node("conv0_relu") + node = self.get_node("conv0_relu") self.assertEqual(node.parents[0], "conv0") # bn (BN) - node = self.test_graph.get_node("bn") + node = self.get_node("bn") self.assertEqual(node.parents[0], "conv0_relu") # conv1 (Convolution). 
- node = self.test_graph.get_node("conv1") + node = self.get_node("conv1") self.assertEqual(node.parents[0], "bn") # conv1_relu (ReLU) - node = self.test_graph.get_node("conv1_relu") + node = self.get_node("conv1_relu") self.assertEqual(node.parents[0], "conv1") # pool (MaxPooling) - node = self.test_graph.get_node("pool") + node = self.get_node("pool") self.assertEqual(node.parents[0], "conv1_relu") # flatten (Flatten) - node = self.test_graph.get_node("flatten") + node = self.get_node("flatten") self.assertEqual(node.parents[0], "pool") # Transpose fc0 weights - node = self.test_graph.get_node("reorder") + node = self.get_node("reorder") # fc0 (FC) - node = self.test_graph.get_node("fc0") + node = self.get_node("fc0") self.assertEqual(node.parents, ["flatten", "reorder"]) # fc0_relu (ReLU) - node = self.test_graph.get_node("fc0_relu") + node = self.get_node("fc0_relu") self.assertEqual(node.parents[0], "fc0") # Transpose fc1/weights - node = self.test_graph.get_node("reorder_1") + node = self.get_node("reorder_1") # fc1 (FC) - node = self.test_graph.get_node("fc1") + node = self.get_node("fc1") self.assertEqual(node.parents, ["fc0_relu", "reorder_1"]) # expand_dims (Reshape). - node = self.test_graph.get_node("expand_dims") + node = self.get_node("expand_dims") self.assertEqual(node.parents[0], "fc1") # squeeze (Reshape). - node = self.test_graph.get_node("squeeze") + node = self.get_node("squeeze") self.assertEqual(node.parents[0], "expand_dims") # reshape (Reshape) - node = self.test_graph.get_node("reshape") + node = self.get_node("reshape") self.assertEqual(node.parents[0], "squeeze") # repeat (Repeat) - node = self.test_graph.get_node("repeat") + node = self.get_node("repeat") self.assertEqual(node.parents[0], "reshape") # stack (Reshape + Repeat). - node = self.test_graph.get_node("stack:expand_dims") + node = self.get_node("stack:expand_dims") self.assertEqual(node.parents[0], "repeat") - node = self.test_graph.get_node("stack:repeat") + node = self.get_node("stack:repeat") self.assertEqual(node.parents[0], "stack:expand_dims") # unstack (Split + Squeeze). - node = self.test_graph.get_node("unstack:split") + node = self.get_node("unstack:split") self.assertEqual(node.parents[0], "stack:repeat") for i in range(4): child_name = "unstack:squeeze" + ("_%d" % i if i > 0 else "") - child_node = self.test_graph.get_node(child_name) + child_node = self.get_node(child_name) self.assertEqual(child_node.parents[0], "unstack:split") self.assertEqual(child_node.src_tensors_indices, [i]) @@ -851,49 +854,49 @@ def test_parent_children(self): """Test the parent/child relationship in the graph.""" # input (Data). - node = self.test_graph.get_node("input") + node = self.get_node("input") self.assertEqual(len(node.parents), 0) # Reorder input from NCHW to NHWC. - node = self.test_graph.get_node("reorder") + node = self.get_node("reorder") self.assertEqual(node.parents[0], "input") # conv0 (Convolution). - node = self.test_graph.get_node("conv0") + node = self.get_node("conv0") self.assertEqual(node.parents[0], "reorder") # conv1 (Convolution). - node = self.test_graph.get_node("conv1") + node = self.get_node("conv1") self.assertEqual(node.parents[0], "reorder") # bn (BN). - node = self.test_graph.get_node("bn") + node = self.get_node("bn") self.assertEqual(node.parents[0], "conv1") # relu (ReLU). - node = self.test_graph.get_node("relu") + node = self.get_node("relu") self.assertEqual(node.parents[0], "bn") # conv2 (Convolution). 
- node = self.test_graph.get_node("conv2") + node = self.get_node("conv2") self.assertEqual(node.parents[0], "relu") # add (EltwiseAdd). - node = self.test_graph.get_node("add") + node = self.get_node("add") self.assertEqual(node.parents[0], "conv0") self.assertEqual(node.parents[1], "conv2") # mul (EltwiseMul). - node = self.test_graph.get_node("mul") + node = self.get_node("mul") self.assertEqual(node.parents, ["conv0", "add"]) # concat (Concat). - node = self.test_graph.get_node("concat") + node = self.get_node("concat") self.assertEqual(node.parents, ["conv0", "mul"]) # split (Split). - node = self.test_graph.get_node("split") + node = self.get_node("split") self.assertEqual(node.parents[0], "concat") # add1 (EltwiseAdd). - node = self.test_graph.get_node("add1") + node = self.get_node("add1") self.assertEqual(node.parents, ["split", "split"]) self.assertEqual(node.src_tensors_indices, [0, 1]) # add2 (EltwiseAdd). - node = self.test_graph.get_node("add2") + node = self.get_node("add2") self.assertEqual(node.parents, ["split", "split"]) self.assertEqual(node.src_tensors_indices, [2, 3]) # mul (EltwiseMul). - node = self.test_graph.get_node("mul1") + node = self.get_node("mul1") self.assertEqual(node.parents, ["add1", "add2"]) class RefResidualGraphTest(unittest.TestCase, ResidualGraphTest): @@ -907,45 +910,45 @@ def test_parent_children(self): """Test the parent/child relationship in the graph.""" # input (Data). - node = self.test_graph.get_node("input") + node = self.get_node("input") self.assertEqual(len(node.parents), 0) # conv0 (Convolution). - node = self.test_graph.get_node("conv0") + node = self.get_node("conv0") self.assertEqual(node.parents[0], "input") # conv1 (Convolution). - node = self.test_graph.get_node("conv1") + node = self.get_node("conv1") self.assertEqual(node.parents[0], "input") # bn (BN). - node = self.test_graph.get_node("bn") + node = self.get_node("bn") self.assertEqual(node.parents[0], "conv1") # relu (ReLU). - node = self.test_graph.get_node("relu") + node = self.get_node("relu") self.assertEqual(node.parents[0], "bn") # conv2 (Convolution). - node = self.test_graph.get_node("conv2") + node = self.get_node("conv2") self.assertEqual(node.parents[0], "relu") # add (EltwiseAdd). - node = self.test_graph.get_node("add") + node = self.get_node("add") self.assertEqual(node.parents, ["conv0", "conv2"]) # mul (EltwiseMul). - node = self.test_graph.get_node("mul") + node = self.get_node("mul") self.assertEqual(node.parents, ["conv0", "add"]) # concat (Concat). - node = self.test_graph.get_node("concat") + node = self.get_node("concat") self.assertEqual(node.parents, ["conv0", "mul"]) # split (Split). - node = self.test_graph.get_node("split") + node = self.get_node("split") self.assertEqual(node.parents[0], "concat") # add1 (EltwiseAdd). - node = self.test_graph.get_node("add1") + node = self.get_node("add1") self.assertEqual(node.parents, ["split", "split"]) self.assertEqual(node.src_tensors_indices, [0, 1]) # add2 (EltwiseAdd). - node = self.test_graph.get_node("add2") + node = self.get_node("add2") self.assertEqual(node.parents, ["split", "split"]) self.assertEqual(node.src_tensors_indices, [2, 3]) # mul (EltwiseMul). 
- node = self.test_graph.get_node("mul1") + node = self.get_node("mul1") self.assertEqual(node.parents, ["add1", "add2"]) if __name__ == "__main__": diff --git a/smaug/python/smaug_test.py b/smaug/python/smaug_test.py index a8ceb5b4..77d8bb18 100644 --- a/smaug/python/smaug_test.py +++ b/smaug/python/smaug_test.py @@ -15,6 +15,7 @@ class SmaugTest(unittest.TestCase): def setUp(self): + self._cwd = os.getcwd() self.run_dir = tempfile.mkdtemp() self.error_filename = os.path.join(self.run_dir, "stderr") self.graph_name = "test_graph" @@ -25,6 +26,7 @@ def setUp(self): def tearDown(self): """ Delete temporary files and outputs. """ shutil.rmtree(self.run_dir) + os.chdir(self._cwd) def launchSubprocess(self, cmd): with open(self.error_filename, "w") as f: diff --git a/smaug/python/subgraph_test.py b/smaug/python/subgraph_test.py index ab277b8b..4a982df7 100755 --- a/smaug/python/subgraph_test.py +++ b/smaug/python/subgraph_test.py @@ -31,7 +31,7 @@ def assertNodesConnected(self, graph, child_parent_map): """Test the connection among nodes.""" for child, parents in child_parent_map.items(): child_node = graph.get_node(child) - self.assertEqual(child_node.parents, parents) + self.assertEqual(child_node.get_parents(), parents) def test_subgraph_merge(self): with Graph(parent_graph_name, backend) as parent_graph: diff --git a/smaug/python/tensor.py b/smaug/python/tensor.py index 3c024f81..5bb43096 100644 --- a/smaug/python/tensor.py +++ b/smaug/python/tensor.py @@ -9,7 +9,7 @@ class Tensor: def __init__( self, dims=None, name=None, data_layout=types_pb2.NCHW, data_type=None, data_format=types_pb2.Uncompressed, tensor_data=None, source=None, - targets=None, alignment=None): + source_index=None, targets=None, alignment=None): """Create a tensor. Args: @@ -19,7 +19,8 @@ def __init__( data_type: Data type of the tensor. data_format: Data format of the tensor. tensor_data: A NumPy array that represents the tensor data. - source: A tuple (source_node, tensor_index) that represents this tensor's + source: A `Node` that represents this tensor's source node. + source_index: An int that represents this tensor's output index in its source node. targets: A list of nodes that use this tensor as inputs. alignment: Data alignment used in the tensor data. @@ -27,57 +28,97 @@ def __init__( Returns: A `Tensor` object. """ - self.shape = tensor_pb2.TensorShapeProto() - self.tensor_data = tensor_data + self._shape = tensor_pb2.TensorShapeProto() + self._tensor_data = tensor_data # If tensor_data is provided, deduce dims and data_type directly from it # (the kwargs are ignored if they are provided). 
- if self.tensor_data is not None: - self.deduce_attrs_from_data() + if self._tensor_data is not None: + self._deduce_attrs_from_data() else: - self.shape.dims.extend(dims) - self.data_type = data_type - - self.shape.layout = data_layout - self.name = name - self.data_format = data_format - self.source = source - self.targets = [] + self._shape.dims.extend(dims) + self._data_type = data_type + + self._shape.layout = data_layout + self._name = name + self._data_format = data_format + self._source = source + self._source_index = source_index + if self._source is not None and self._source_index is None: + raise ValueError( + "Please provide this tensor's output index in the source node!") + self._targets = [] if alignment != None: - self.shape.alignment = alignment + self._shape.alignment = alignment elif global_vars.get_graph() == None: - self.shape.alignment = 0 + self._shape.alignment = 0 else: - self.shape.alignment = global_vars.get_graph().alignment + self._shape.alignment = global_vars.get_graph().alignment # Do data padding if this Tensor contains data. - if self.tensor_data is not None: - pad_width = [(0, 0) for i in range(len(self.shape.dims) - 1)] - pad_width.append((0, self.calc_padding(self.shape.dims[-1]))) - self.tensor_data = np.pad(self.tensor_data, pad_width, 'constant') + if self._tensor_data is not None: + pad_width = [(0, 0) for i in range(len(self._shape.dims) - 1)] + pad_width.append((0, self.calc_padding(self._shape.dims[-1]))) + self._tensor_data = np.pad(self._tensor_data, pad_width, 'constant') + + @property + def name(self): + return self._name + + @name.setter + def name(self, name): + self._name = name + + @property + def shape(self): + return self._shape + + @property + def data_type(self): + return self._data_type + + @property + def data_format(self): + return self._data_format + + @property + def tensor_data(self): + return self._tensor_data + + @property + def source(self): + return self._source + + @property + def source_index(self): + return self._source_index + + @property + def targets(self): + return self._targets def dims(self, index): """This returns the size of the dimension.""" - assert index < len(self.shape.dims), "The dimension index is out of bound!" - return self.shape.dims[index] + assert index < len(self._shape.dims), "The dimension index is out of bound!" + return self._shape.dims[index] - def deduce_attrs_from_data(self): + def _deduce_attrs_from_data(self): """Deduce tensor attributes from the supplied tensor data. The deducible attributes include tensor shape dimensions and data type. """ # Deduce dims from tensor data. 
- self.shape.dims.extend(list(self.tensor_data.shape)) + self._shape.dims.extend(list(self._tensor_data.shape)) # Deduce data type from tensor data try: - self.data_type = datatypes.np_to_smaug_type[self.tensor_data.dtype.type] + self._data_type = datatypes.np_to_smaug_type[self._tensor_data.dtype.type] except KeyError: - assert False, "We don't support numpy dtype: %s" % self.tensor_data.dtype + assert False, "We don't support numpy dtype: %s" % self._tensor_data.dtype def calc_padding(self, value): """This returns the size we need to pad on the last dimension.""" - if self.shape.alignment == 0 or value % self.shape.alignment == 0: + if self._shape.alignment == 0 or value % self._shape.alignment == 0: return 0 - return (self.shape.alignment - (value % self.shape.alignment)) + return (self._shape.alignment - (value % self._shape.alignment)) def to_tensor_proto(self, tensor_proto, tensor_data_array=None): """Serialize the tensor into a tensor proto. @@ -86,39 +127,39 @@ def to_tensor_proto(self, tensor_proto, tensor_data_array=None): tensor_proto: The tensor proto this tensor gets serialized into. tensor_data_array: The tensor data array this tensor gets serialized into. """ - tensor_proto.name = self.name - tensor_proto.shape.CopyFrom(self.shape) - tensor_proto.data_type = self.data_type - tensor_proto.data_format = self.data_format - if self.tensor_data is not None and tensor_data_array is not None: + tensor_proto.name = self._name + tensor_proto.shape.CopyFrom(self._shape) + tensor_proto.data_type = self._data_type + tensor_proto.data_format = self._data_format + if self._tensor_data is not None and tensor_data_array is not None: # Since Protobuf doesn't support float16 data type, we pack two float16 # elements into one int32. - if self.data_type == types_pb2.Float16: + if self._data_type == types_pb2.Float16: # Numpy.view comes in handy here. Note that it won't work if # tensor_data's last dimension is of odd size. To solve that, we # flatten the tensor data, and if the flattened list is still of # odd size, we pad a zero at the end of the list. When we later # deserialize the tensor data, we know the correct shape of the # tensor, and the padded zero will be discarded. - self.tensor_data = self.tensor_data.flatten() - if self.tensor_data.size % 2 != 0: - self.tensor_data = np.append(self.tensor_data, np.float16(0)) - self.tensor_data = self.tensor_data.view(np.int32) + self._tensor_data = self._tensor_data.flatten() + if self._tensor_data.size % 2 != 0: + self._tensor_data = np.append(self._tensor_data, np.float16(0)) + self._tensor_data = self._tensor_data.view(np.int32) # Serialize the data into the proto. 
tensor_data_proto = tensor_data_array.data_array.add() tensor_data_proto.name = tensor_proto.name - data_list = [x for x in np.nditer(self.tensor_data)] - if self.data_type == types_pb2.Float16: + data_list = [x for x in np.nditer(self._tensor_data)] + if self._data_type == types_pb2.Float16: tensor_data_proto.half_data.extend(data_list) - elif self.data_type == types_pb2.Float32: + elif self._data_type == types_pb2.Float32: tensor_data_proto.float_data.extend(data_list) - elif self.data_type == types_pb2.Float64: + elif self._data_type == types_pb2.Float64: tensor_data_proto.double_data.extend(data_list) - elif self.data_type == types_pb2.Int32: + elif self._data_type == types_pb2.Int32: tensor_data_proto.int_data.extend(data_list) - elif self.data_type == types_pb2.Int64: + elif self._data_type == types_pb2.Int64: tensor_data_proto.int64_data.extend(data_list) - elif self.data_type == types_pb2.Bool: + elif self._data_type == types_pb2.Bool: tensor_data_proto.bool_data.extend(data_list) diff --git a/smaug/python/tensor_test.py b/smaug/python/tensor_test.py index 637d83bd..95a52f11 100755 --- a/smaug/python/tensor_test.py +++ b/smaug/python/tensor_test.py @@ -6,7 +6,7 @@ import numpy as np from smaug.python.tensor_utils import get_tensor_data -from smaug.python.graph import Graph +from smaug.python.graph import Graph, get_node_proto from smaug.python.tensor import Tensor from smaug.python.ops.data_op import input_data from smaug.core import types_pb2 @@ -30,14 +30,15 @@ def test_attr_reference(self): with Graph("test_graph", "Reference") as test_graph: input_tensor = Tensor(data_layout=types_pb2.NHWC, tensor_data=tensor_data) act = input_data(input_tensor, "input") - self.assertEqual(test_graph.graph.backend, "Reference") - node = test_graph.get_node("input") + graph_proto, tensor_data_array = test_graph.to_proto() + self.assertEqual(graph_proto.backend, "Reference") + node = get_node_proto(graph_proto, "input") self.assertEqual(node.input_tensors[0].data_type, types_pb2.Float32) self.assertEqual(node.input_tensors[0].shape.dims, [2, 2, 4, 4]) self.assertEqual(node.input_tensors[0].shape.layout, types_pb2.NHWC) self.assertEqual(node.input_tensors[0].shape.alignment, 0) - tensor_data_proto = get_tensor_data(test_graph.tensor_data_array, - node.input_tensors[0].name) + tensor_data_proto = get_tensor_data( + tensor_data_array, node.input_tensors[0].name) self.assertEqual(tensor_data_proto.float_data, list(tensor_data.flatten())) self.assertEqual(len(tensor_data_proto.half_data), 0) self.assertEqual(len(tensor_data_proto.double_data), 0) @@ -50,14 +51,15 @@ def test_attr_smv_no_padding(self): with Graph("test_graph", "SMV") as test_graph: input_tensor = Tensor(data_layout=types_pb2.NCHW, tensor_data=tensor_data) act = input_data(input_tensor, "input") - self.assertEqual(test_graph.graph.backend, "SMV") - node = test_graph.get_node("input") + graph_proto, tensor_data_array = test_graph.to_proto() + self.assertEqual(graph_proto.backend, "SMV") + node = get_node_proto(graph_proto, "input") self.assertEqual(node.input_tensors[0].data_type, types_pb2.Float16) self.assertEqual(node.input_tensors[0].shape.dims, [2, 2, 4, 8]) self.assertEqual(node.input_tensors[0].shape.layout, types_pb2.NCHW) self.assertEqual(node.input_tensors[0].shape.alignment, 8) - tensor_data_proto = get_tensor_data(test_graph.tensor_data_array, - node.input_tensors[0].name) + tensor_data_proto = get_tensor_data( + tensor_data_array, node.input_tensors[0].name) self.assertEqualFP16(tensor_data_proto.half_data, 
tensor_data.flatten()) self.assertEqual(len(tensor_data_proto.float_data), 0) self.assertEqual(len(tensor_data_proto.double_data), 0) @@ -71,14 +73,15 @@ def test_attr_smv_padding(self): with Graph("test_graph", "SMV") as test_graph: input_tensor = Tensor(data_layout=types_pb2.NCHW, tensor_data=tensor_data) act = input_data(input_tensor, "input") - self.assertEqual(test_graph.graph.backend, "SMV") - node = test_graph.get_node("input") + graph_proto, tensor_data_array = test_graph.to_proto() + self.assertEqual(graph_proto.backend, "SMV") + node = get_node_proto(graph_proto, "input") self.assertEqual(node.input_tensors[0].data_type, types_pb2.Float16) self.assertEqual(node.input_tensors[0].shape.dims, [2, 4]) self.assertEqual(node.input_tensors[0].shape.layout, types_pb2.NCHW) self.assertEqual(node.input_tensors[0].shape.alignment, 8) - tensor_data_proto = get_tensor_data(test_graph.tensor_data_array, - node.input_tensors[0].name) + tensor_data_proto = get_tensor_data( + tensor_data_array, node.input_tensors[0].name) self.assertEqualFP16( tensor_data_proto.half_data, np.array( @@ -96,10 +99,11 @@ def test_fp16_even(self): with Graph("test_graph", "Reference") as test_graph: input_tensor = Tensor(tensor_data=tensor_data) act = input_data(input_tensor, "input") - node = test_graph.get_node("input") + graph_proto, tensor_data_array = test_graph.to_proto() + node = get_node_proto(graph_proto, "input") self.assertEqual(node.input_tensors[0].data_type, types_pb2.Float16) - tensor_data_proto = get_tensor_data(test_graph.tensor_data_array, - node.input_tensors[0].name) + tensor_data_proto = get_tensor_data( + tensor_data_array, node.input_tensors[0].name) self.assertEqualFP16(tensor_data_proto.half_data, tensor_data.flatten()) def test_fp16_odd(self): @@ -108,10 +112,11 @@ def test_fp16_odd(self): with Graph("test_graph", "Reference") as test_graph: input_tensor = Tensor(tensor_data=tensor_data) act = input_data(input_tensor, "input") - node = test_graph.get_node("input") + graph_proto, tensor_data_array = test_graph.to_proto() + node = get_node_proto(graph_proto, "input") self.assertEqual(node.input_tensors[0].data_type, types_pb2.Float16) - tensor_data_proto = get_tensor_data(test_graph.tensor_data_array, - node.input_tensors[0].name) + tensor_data_proto = get_tensor_data( + tensor_data_array, node.input_tensors[0].name) self.assertEqualFP16(tensor_data_proto.half_data, tensor_data.flatten()) def test_fp16_odd_odd(self): @@ -123,10 +128,11 @@ def test_fp16_odd_odd(self): with Graph("test_graph", "Reference") as test_graph: input_tensor = Tensor(tensor_data=tensor_data) act = input_data(input_tensor, "input") - node = test_graph.get_node("input") + graph_proto, tensor_data_array = test_graph.to_proto() + node = get_node_proto(graph_proto, "input") self.assertEqual(node.input_tensors[0].data_type, types_pb2.Float16) - tensor_data_proto = get_tensor_data(test_graph.tensor_data_array, - node.input_tensors[0].name) + tensor_data_proto = get_tensor_data( + tensor_data_array, node.input_tensors[0].name) self.assertEqualFP16(tensor_data_proto.half_data, np.append(tensor_data.flatten(), np.float16(0))) diff --git a/smaug/python/tensor_utils.py b/smaug/python/tensor_utils.py index e473b558..4ab41177 100644 --- a/smaug/python/tensor_utils.py +++ b/smaug/python/tensor_utils.py @@ -20,53 +20,11 @@ def get_padded_shape(shape): shape.dims[-1] += alignment - remainder; return shape -def from_tensor_proto(tensor_proto, tensor_data_array=None): - """Restore a Tensor from a TensorProto. 
- - Args: - tensor_proto: A TensorProto. - tensor_data_array: a TensorDataArray that stores tensor data. - - Returns: - A Tensor deserialized from `tensor_proto`. - """ - name = tensor_proto.name - data_type = tensor_proto.data_type - tensor_data = None - if tensor_data_array is not None: - tensor_data_proto = get_tensor_data(tensor_data_array, name) - if tensor_data_proto is not None: - padded_shape = get_padded_shape(shape) - if data_type == types_pb2.Float16: - tensor_data = tensor_data_proto.half_data - if padded_shape.size % 2 != 0: - del tensor_data[-1] - elif data_type == types_pb2.Float32: - tensor_data = tensor_data_proto.float_data - elif data_type == types_pb2.Float64: - tensor_data = tensor_data_proto.double_data - elif data_type == types_pb2.Int32: - tensor_data = tensor_data_proto.int_data - elif data_type == types_pb2.Int64: - tensor_data = tensor_data_proto.int64_data - elif data_type == types_pb2.Bool: - tensor_data = tensor_data_proto.bool_data - # The data retrieved from the proto is one-dimensional, so make it back to - # shaped data. - tensor_data.reshape(padded_shape.dims) - - tensor = Tensor( - dims=tensor_proto.shape.dims, name=name, - data_layout=tensor_proto.shape.layout, data_type=data_type, - data_format=tensor_proto.data_format, tensor_data=tensor_data) - return tensor - def get_tensor_data_op(tensor): """Return the output of a data op if this tensor already has one created.""" for node in tensor.targets: if node.op == types_pb2.Data: - data_op_output = from_tensor_proto(node.output_tensors[0]) - data_op_output.source = (node, 0) + data_op_output = node.outputs[0] return data_op_output return None @@ -82,9 +40,7 @@ def get_tensor_reorder_op(tensor, layout): returned. """ for node in tensor.targets: - if node.op == types_pb2.Reorder and node.output_tensors[ - 0].shape.layout == layout: - reorder_op_output = from_tensor_proto(node.output_tensors[0]) - reorder_op_output.source = (node, 0) + if node.op == types_pb2.Reorder and node.outputs[0].shape.layout == layout: + reorder_op_output = node.outputs[0] return reorder_op_output return None
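The example below is a minimal usage sketch of the serialization flow introduced in this patch, assuming the module paths and signatures shown above (`Graph.to_proto`, `get_node_proto`, `Graph.write_graph`, `input_data`); the tensor shape, random data, and node name are illustrative only and not part of the change.

import numpy as np

from smaug.core import types_pb2
from smaug.python.graph import Graph, get_node_proto
from smaug.python.tensor import Tensor
from smaug.python.ops.data_op import input_data

# Illustrative input data; any float32 NHWC array works here.
tensor_data = np.random.rand(2, 2, 4, 4).astype(np.float32)

with Graph(name="example_graph", backend="Reference") as graph:
  input_tensor = Tensor(data_layout=types_pb2.NHWC, tensor_data=tensor_data)
  act = input_data(input_tensor, "input")

# Nodes now stay as in-memory Node objects; protos are only built on demand.
node = graph.get_node("input")
print(node.get_parents(), node.get_children())

# Serialization is explicit: to_proto() returns the topology and the params.
graph_proto, tensor_data_array = graph.to_proto()
node_proto = get_node_proto(graph_proto, "input")
print(types_pb2.OpType.Name(node_proto.op))

# write_graph() still emits <name>_topo.pbtxt and <name>_params.pb.
graph.write_graph()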
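A second sketch shows the new `Node` abstraction on its own, assuming it is used exactly as defined in smaug/python/node.py; in practice nodes are created through `Graph.add_node`, and the node name, dims, and layout below are hypothetical.

from smaug.core import tensor_pb2
from smaug.core import types_pb2
from smaug.python.node import Node
from smaug.python.tensor import Tensor

# Build a detached ReLU node with one input and one output tensor.
node = Node(name="relu0", op=types_pb2.ReLU)
x = Tensor(dims=[2, 32, 32, 32], name="relu0/input0",
           data_layout=types_pb2.NHWC, data_type=types_pb2.Float32,
           alignment=0)
node.add_input(x)
# Output tensors record their producer via source/source_index, replacing the
# old (node, index) tuple.
y = Tensor(dims=[2, 32, 32, 32], name="relu0/output0",
           data_layout=types_pb2.NHWC, data_type=types_pb2.Float32,
           source=node, source_index=0, alignment=0)
node.add_output(y)

# Serializing the node fills a NodeProto; any tensor data would be appended to
# the supplied TensorDataArray.
tensor_data_array = tensor_pb2.TensorDataArray()
node_proto = node.to_proto(tensor_data_array)
print(node_proto.name, len(node_proto.input_tensors), len(node_proto.output_tensors))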