diff --git a/smaug/python/graph.py b/smaug/python/graph.py
index f1aa449a..59367c7d 100644
--- a/smaug/python/graph.py
+++ b/smaug/python/graph.py
@@ -7,23 +7,23 @@
 from smaug.core import types_pb2
 from smaug.core import tensor_pb2
 from smaug.python import global_vars
+from smaug.python.node import Node
 from smaug.python.tensor import Tensor
 
 class Graph:
   def __init__(
       self, name="DefaultGraph", backend="Reference",
       mem_policy=types_pb2.AllDma):
-    assert (backend in global_vars.backend_alignment)
-    self.graph = graph_pb2.GraphProto()
+    if backend not in global_vars.backend_alignment:
+      raise ValueError("Unknown backend: %s" % backend)
+    self._name = name
+    self._backend = backend
+    self._mem_policy = mem_policy
+    self._nodes = []
     self._node_names = {}
-    self.graph.name = name
-    self.graph.backend = backend
-    self.graph.mem_policy = mem_policy
-    self.alignment = global_vars.backend_alignment[backend]
+    self._alignment = global_vars.backend_alignment[self._backend]
     # Layout transformation is enabled by default.
     self._layout_trans_enabled = True
-    # This proto stores all the parameters in the network.
-    self.tensor_data_array = tensor_pb2.TensorDataArray()
     self._parent_graph = None
 
   def __enter__(self):
@@ -39,11 +39,15 @@ def __exit__(self, *args):
 
   @property
   def backend(self):
-    return self.graph.backend
+    return self._backend
 
   @property
   def mem_policy(self):
-    return self.graph.mem_policy
+    return self._mem_policy
+
+  @property
+  def alignment(self):
+    return self._alignment
 
   @property
   def layout_trans_enabled(self):
@@ -51,8 +55,12 @@ def layout_trans_enabled(self):
 
   def merge(self, other):
     """Merge another graph into this."""
+    for node in other.get_nodes():
+      if self.get_node(node.name) is not None:
+        raise ValueError(
+            "The graph to be merged contains a node with the same name as one "
+            "in the current graph. Possibly merging a graph more than once?")
     self.get_nodes().extend(other.get_nodes())
-    self.tensor_data_array.data_array.extend(other.tensor_data_array.data_array)
 
   def add_node(
       self, name, op, input_tensors, output_tensors_dims,
@@ -75,24 +83,16 @@ def add_node(
     Returns:
       The output tensor of the added node.
     """
-    node = self.graph.nodes.add()
-    node.name = self.create_unique_name(name)
-    node.op = op
-
-    # Add the parameters to the node.
-    if params != None:
-      node.params.CopyFrom(params)
+    name = self.create_unique_name(name)
+    node = Node(name, op, params)
+    self._nodes.append(node)
 
-    # Update the node's parents field, and add every input tensor to the node.
+    # Add every input tensor to the node.
     for i,tensor in enumerate(input_tensors):
       if tensor.name == None:
         tensor.name = node.name + "/input%d" % i
-      if tensor.source is not None:
-        node.parents.append(tensor.source[0].name)
-        node.src_tensors_indices.append(tensor.source[1])
+      node.add_input(tensor)
       tensor.targets.append(node)
-      input_tensor_proto = node.input_tensors.add()
-      tensor.to_tensor_proto(input_tensor_proto, self.tensor_data_array)
 
     # Create the output tensor (with the node as its source), and add it to the
     # node.
@@ -101,33 +101,33 @@ def add_node(
       output_tensor = Tensor(
           dims=d, name="%s/output%d" % (node.name, i),
           data_layout=output_tensor_layout, data_type=output_tensor_dtype,
-          data_format=output_tensor_dformat, source=(node, i),
-          alignment=self.alignment)
-      output_tensor_proto = node.output_tensors.add()
-      output_tensor.to_tensor_proto(output_tensor_proto, self.tensor_data_array)
+          data_format=output_tensor_dformat, source=node, source_index=i,
+          alignment=self._alignment)
+      node.add_output(output_tensor)
       output_tensors.append(output_tensor)
 
     return output_tensors
 
   def get_node(self, node_name, recursive=False):
-    """Return a node in the graph proto by its name.
+    """Return a node in the graph by its name.
 
     Args:
       node_name: Node name.
       recursive: If true, recursively search the node in the parent graphs.
 
     Returns:
-      A NodeProto if we find the node.
+      A `Node` if the node is found, otherwise None.
     """
-    for i in range(len(self.graph.nodes)):
-      if self.graph.nodes[i].name == node_name:
-        return self.graph.nodes[i]
+    for node in self._nodes:
+      if node.name == node_name:
+        return node
     if recursive and self._parent_graph is not None:
       return self._parent_graph.get_node(node_name, True)
+    return None
 
   def get_nodes(self):
-    """Return nodes in the graph proto."""
-    return self.graph.nodes
+    """Return nodes in the graph."""
+    return self._nodes
 
   def get_root_graph(self):
     """Return the root graph."""
@@ -167,6 +167,21 @@ def enable_layout_transform(self):
     """Enable automatic layout transformation."""
     self._layout_trans_enabled = True
 
+  def to_proto(self):
+    """Serialize the graph.
+
+    Returns:
+      A tuple of (`GraphProto`, `TensorDataArray`).
+    """
+    graph_proto = graph_pb2.GraphProto()
+    graph_proto.name = self._name
+    graph_proto.backend = self._backend
+    graph_proto.mem_policy = self._mem_policy
+    tensor_data_array = tensor_pb2.TensorDataArray()
+    for node in self._nodes:
+      graph_proto.nodes.append(node.to_proto(tensor_data_array))
+    return graph_proto, tensor_data_array
+
   def write_graph(self, name=None):
     """Serialize the graph to a protobuf file.
 
@@ -174,12 +189,14 @@ def write_graph(self, name=None):
       name: Name of the output protobuf file. If not specified, use the graph's
             name instead.
     """
-    if name == None:
-      topo_name = self.graph.name + "_topo.pbtxt"
-      params_name = self.graph.name + "_params.pb"
+    graph_proto, tensor_data_array = self.to_proto()
+    if name is None:
+      name = self._name
+    topo_name = name + "_topo.pbtxt"
+    params_name = name + "_params.pb"
     with open(topo_name, "w") as f_topo, open(params_name, "wb") as f_params:
-      f_topo.write(text_format.MessageToString(self.graph))
-      f_params.write(self.tensor_data_array.SerializeToString())
+      f_topo.write(text_format.MessageToString(graph_proto))
+      f_params.write(tensor_data_array.SerializeToString())
 
   def print_summary(self):
     """Print the summary of the graph.
@@ -189,28 +206,45 @@ def print_summary(self):
     input/output tensors.
     """
     print("=================================================================")
-    print("      Summary of the network: %s (%s)" % (self.graph.name,
-                                                     self.graph.backend))
+    print("      Summary of the network: %s (%s)" % (self._name, self._backend))
     print("=================================================================")
     print(
         "Host memory access policy: %s." %
-        types_pb2.HostMemoryAccessPolicy.Name(self.graph.mem_policy))
+        types_pb2.HostMemoryAccessPolicy.Name(self._mem_policy))
     print("-----------------------------------------------------------------")
-    for node in self.graph.nodes:
+    for node in self._nodes:
       print("Name: %s (%s)" % (node.name, types_pb2.OpType.Name(node.op)))
       print("Parents:", end = '')
-      for i in node.parents:
+      for i in node.get_parents():
+        print(i, end = ' ')
+      print("\nChildren:", end = '')
+      for i in node.get_children():
         print(i, end = ' ')
       print("\nInput tensors:")
-      for t in node.input_tensors:
+      for t in node.inputs:
         print(
             " ", t.name, types_pb2.DataType.Name(t.data_type), t.shape.dims,
             types_pb2.DataLayout.Name(t.shape.layout),
             "alignment(%d)" % t.shape.alignment)
       print("Output tensors:")
-      for t in node.output_tensors:
+      for t in node.outputs:
         print(
             " ", t.name, types_pb2.DataType.Name(t.data_type), t.shape.dims,
             types_pb2.DataLayout.Name(t.shape.layout),
             "alignment(%d)" % t.shape.alignment)
       print("-----------------------------------------------------------------")
+
+def get_node_proto(graph_proto, node_name):
+  """Get a `NodeProto` from `GraphProto` by node name.
+
+  Args:
+    graph_proto: A `GraphProto`.
+    node_name: Name of the node.
+
+  Returns:
+    A `NodeProto` or None.
+  """
+  for node_proto in graph_proto.nodes:
+    if node_proto.name == node_name:
+      return node_proto
+  return None
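
A minimal usage sketch of the new deferred serialization flow (not part of the
patch; the graph name and tensor shape below are illustrative), mirroring how
the updated tests build a graph and then call to_proto() and get_node_proto():

    import numpy as np
    from smaug.core import types_pb2
    from smaug.python.graph import Graph, get_node_proto
    from smaug.python.ops.data_op import input_data
    from smaug.python.tensor import Tensor

    with Graph(name="example_graph", backend="Reference") as graph:
      x = Tensor(data_layout=types_pb2.NHWC,
                 tensor_data=np.random.rand(2, 4, 4, 2).astype(np.float32))
      act = input_data(x, "input")

    # Serialization now happens on demand instead of during graph construction.
    graph_proto, tensor_data_array = graph.to_proto()
    node_proto = get_node_proto(graph_proto, "input")  # a NodeProto, or None
    graph.write_graph()  # example_graph_topo.pbtxt / example_graph_params.pb
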
diff --git a/smaug/python/node.py b/smaug/python/node.py
new file mode 100644
index 00000000..0e0f04e3
--- /dev/null
+++ b/smaug/python/node.py
@@ -0,0 +1,122 @@
+import sys
+import numpy as np
+
+from smaug.core import node_pb2
+from smaug.core import types_pb2
+from smaug.python import global_vars
+from smaug.python import datatypes
+
+class Node:
+  def __init__(self, name, op, params=None, inputs=None, outputs=None):
+    """Create a node.
+
+    A `Node` instance contains information about its corresponding operation,
+    including the operator type, parameters and input/output tensors. A `Graph`
+    is made up of `Node`s. Serializing a `Node` produces a `NodeProto`.
+
+    Args:
+      name: Name of the node.
+      op: `OpType` representing the operation type of the node.
+      params: `Params` used by the operator (optional).
+      inputs: A list of `Tensor` (optional).
+      outputs: A list of `Tensor` (optional).
+
+    Returns:
+      A `Node` instance.
+    """
+    self._name = name
+    self._op = op
+    self._params = params
+    self._inputs = [] if inputs is None else inputs
+    self._outputs = [] if outputs is None else outputs
+
+  @property
+  def name(self):
+    return self._name
+
+  @property
+  def op(self):
+    return self._op
+
+  @property
+  def inputs(self):
+    return self._inputs
+
+  @property
+  def outputs(self):
+    return self._outputs
+
+  def add_input(self, tensor):
+    """Add an input tensor to the node.
+
+    Args:
+      tensor: A `Tensor`.
+    """
+    self._inputs.append(tensor)
+
+  def add_output(self, tensor):
+    """Add an output tensor to the node.
+
+    Args:
+      tensor: A `Tensor`.
+    """
+    self._outputs.append(tensor)
+
+  def update_input(self, tensor, index):
+    """Update the `index`th input with `tensor`.
+
+    Args:
+      tensor: A `Tensor` representing the new input.
+      index: The input index.
+    """
+    self._inputs[index] = tensor
+
+  def get_parents(self):
+    """Get the parents of the node.
+
+    Returns:
+      A list of strings representing names of the parent nodes.
+    """
+    parents = []
+    for tensor in self._inputs:
+      if tensor.source is not None:
+        parents.append(tensor.source.name)
+    return parents
+
+  def get_children(self):
+    """Get the children of the node.
+
+    Returns:
+      A list of strings representing names of the child nodes.
+    """
+    children = []
+    for tensor in self._outputs:
+      for target in tensor.targets:
+        children.append(target.name)
+    return children
+
+  def to_proto(self, tensor_data_array):
+    """Serialize `Node` into `NodeProto`.
+
+    Args:
+      tensor_data_array: `TensorDataArray` that tensor data gets serialized
+        into.
+
+    Returns:
+      A `NodeProto`.
+    """
+    node_proto = node_pb2.NodeProto()
+    node_proto.name = self._name
+    node_proto.op = self._op
+    if self._params is not None:
+      node_proto.params.CopyFrom(self._params)
+    for tensor in self._inputs:
+      if tensor.source is not None:
+        node_proto.parents.append(tensor.source.name)
+        node_proto.src_tensors_indices.append(tensor.source_index)
+      tensor_proto = node_proto.input_tensors.add()
+      tensor.to_tensor_proto(tensor_proto, tensor_data_array)
+    for tensor in self._outputs:
+      tensor_proto = node_proto.output_tensors.add()
+      tensor.to_tensor_proto(tensor_proto, tensor_data_array)
+    return node_proto
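
A minimal sketch (not part of the patch; node names and shapes are
illustrative) of how parent/child relationships are now derived from tensor
sources and targets instead of being written into NodeProto at construction:

    from smaug.core import types_pb2
    from smaug.python.node import Node
    from smaug.python.tensor import Tensor

    producer = Node("conv0", types_pb2.Convolution3d)
    out = Tensor(dims=[1, 8, 8, 8], name="conv0/output0",
                 data_type=types_pb2.Float32, source=producer, source_index=0)
    producer.add_output(out)

    consumer = Node("relu0", types_pb2.ReLU)
    consumer.add_input(out)
    out.targets.append(consumer)

    print(consumer.get_parents())   # ["conv0"]
    print(producer.get_children())  # ["relu0"]
    # to_proto(tensor_data_array) later fills in parents and
    # src_tensors_indices from each input tensor's source and source_index.
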
diff --git a/smaug/python/ops/activation_ops_test.py b/smaug/python/ops/activation_ops_test.py
index 4975bf85..526c1412 100755
--- a/smaug/python/ops/activation_ops_test.py
+++ b/smaug/python/ops/activation_ops_test.py
@@ -5,7 +5,7 @@
 import unittest
 import numpy as np
 
-from smaug.python.graph import Graph
+from smaug.python.graph import Graph, get_node_proto
 from smaug.python.tensor import Tensor
 from smaug.python.ops import data_op
 from smaug.python.ops import activation_ops
@@ -19,7 +19,8 @@ def __init__(self, *args, **kwargs):
     self.tensor_shape = [2, 32, 32, 32]
 
   def do_basic_test(self, test_graph, node_name, op_type, tensor_shape=None):
-    node = test_graph.get_node(node_name)
+    graph_proto, _ = test_graph.to_proto()
+    node = get_node_proto(graph_proto, node_name)
     if tensor_shape == None:
       tensor_shape = self.tensor_shape
     self.assertEqual(node.op, op_type)
@@ -129,44 +130,45 @@ def test_fusing_activation_functions(self):
           activation=types_pb2.ReLU, name="bn_relu")
       act = nn_ops.mat_mul(
           act, weight_tensor, activation=types_pb2.ReLU, name="fc_relu")
+    graph_proto, _ = test_graph.to_proto()
     # None
-    node = test_graph.get_node("conv_none")
+    node = get_node_proto(graph_proto, "conv_none")
     self.assertEqual(node.params.act_params.activation, types_pb2.UnknownOp)
     # ReLU
-    node = test_graph.get_node("conv_relu")
+    node = get_node_proto(graph_proto, "conv_relu")
     self.assertEqual(node.params.act_params.activation, types_pb2.ReLU)
     # LReLU (default slope = 0.2)
-    node = test_graph.get_node("conv_lrelu")
+    node = get_node_proto(graph_proto, "conv_lrelu")
     self.assertEqual(node.params.act_params.activation, types_pb2.LReLU)
     self.assertAlmostEqual(node.params.act_params.lrelu_params.slope, 0.2)
     # ELU (default alpha = 0.1)
-    node = test_graph.get_node("conv_elu")
+    node = get_node_proto(graph_proto, "conv_elu")
     self.assertEqual(node.params.act_params.activation, types_pb2.ELU)
     self.assertAlmostEqual(node.params.act_params.elu_params.alpha, 0.1)
     # SELU (default alpha = 1.6733, lambda = 1.0507)
-    node = test_graph.get_node("conv_selu")
+    node = get_node_proto(graph_proto, "conv_selu")
     self.assertEqual(node.params.act_params.activation, types_pb2.SELU)
     self.assertAlmostEqual(node.params.act_params.elu_params.alpha, 1.6733, 5)
     self.assertAlmostEqual(node.params.act_params.elu_params.lambda_param,
                            1.0507, 5)
     # Tanh
-    node = test_graph.get_node("conv_tanh")
+    node = get_node_proto(graph_proto, "conv_tanh")
     self.assertEqual(node.params.act_params.activation, types_pb2.Tanh)
     # HardTanh (default min = -1, max = 1)
-    node = test_graph.get_node("conv_hard_tanh")
+    node = get_node_proto(graph_proto, "conv_hard_tanh")
     self.assertEqual(node.params.act_params.activation, types_pb2.HardTanh)
     self.assertAlmostEqual(node.params.act_params.hard_tanh_params.min, -1)
     self.assertAlmostEqual(node.params.act_params.hard_tanh_params.max, 1)
     # Sigmoid
-    node = test_graph.get_node("conv_sigmoid")
+    node = get_node_proto(graph_proto, "conv_sigmoid")
     self.assertEqual(node.params.act_params.activation, types_pb2.Sigmoid)
     # Softmax
-    node = test_graph.get_node("conv_softmax")
+    node = get_node_proto(graph_proto, "conv_softmax")
     self.assertEqual(node.params.act_params.activation, types_pb2.Softmax)
     # Fusion with inner products and batch norms.
-    node = test_graph.get_node("bn_relu")
+    node = get_node_proto(graph_proto, "bn_relu")
     self.assertEqual(node.params.act_params.activation, types_pb2.ReLU)
-    node = test_graph.get_node("fc_relu")
+    node = get_node_proto(graph_proto, "fc_relu")
     self.assertEqual(node.params.act_params.activation, types_pb2.ReLU)
 
 if __name__ == "__main__":
diff --git a/smaug/python/ops/common.py b/smaug/python/ops/common.py
index 5a2c3d41..1ad469d3 100644
--- a/smaug/python/ops/common.py
+++ b/smaug/python/ops/common.py
@@ -23,7 +23,7 @@ def check_and_add_layout_transform(name, op, input_tensors):
   from smaug.python.ops.array_ops import reorder
   if not global_vars.get_graph().layout_trans_enabled:
     return input_tensors
-  backend = global_vars.get_graph().graph.backend
+  backend = global_vars.get_graph().backend
   for i in range(len(input_tensors)):
     expected_layoutset = global_vars.backend_layouts[backend][
         op].input_layoutsets[i]
diff --git a/smaug/python/ops/control_flow_ops.py b/smaug/python/ops/control_flow_ops.py
index 65de4676..a7d5bf0b 100644
--- a/smaug/python/ops/control_flow_ops.py
+++ b/smaug/python/ops/control_flow_ops.py
@@ -75,28 +75,20 @@ def _insert_switch_nodes(predication, branch_result, graph):
     # This keeps track of all the tensors that come from nodes in the graph.
     internal_tensors = set()
     for node in nodes:
-      internal_tensors.update(
-          set([tensor.name for tensor in node.output_tensors]))
+      internal_tensors.update(set([tensor.name for tensor in node.outputs]))
     for node in nodes:
-      for i, tensor_proto in enumerate(node.input_tensors):
+      for i, tensor in enumerate(node.inputs):
         # If any input tensor of the graph appear in the graph workspace, then
         # this tensor is an external to the graph and we create a switch node
         # for it.
         # Don't create switch node for an existing one.
         if node.op == types_pb2.Switch:
           continue
-        if tensor_proto.name not in internal_tensors:
-          source_node = graph.get_node(node.parents[i], True)
-          tensor = tensor_utils.from_tensor_proto(tensor_proto)
-          if source_node is not None:
-            tensor.source = (source_node, node.src_tensors_indices[i])
+        if tensor.name not in internal_tensors:
           switch_result = switch(
               tensor, predication)[switch_op_output_ports[branch_result]]
-          # Update the node with the switch node as its new parent.
-          switch_result.to_tensor_proto(node.input_tensors[i])
-          switch_node = switch_result.source[0]
-          node.parents[i] = switch_node.name
-          node.src_tensors_indices[i] = switch_op_output_ports[branch_result]
+          # Update the node's input with the switch node result.
+          node.update_input(switch_result, i)
 
   cur_graph = global_vars.get_graph()
   backend = cur_graph.backend
@@ -110,7 +102,6 @@ def _insert_switch_nodes(predication, branch_result, graph):
     if not isinstance(res_t, (list, tuple)):
       res_t = [res_t]
     _insert_switch_nodes(predication, "true", subgraph_t)
-  cur_graph.merge(subgraph_t)
 
   # Build the subgraph for the false branch.
   with Graph(name="%s_false_branch" % name, backend=backend,
@@ -119,7 +110,6 @@ def _insert_switch_nodes(predication, branch_result, graph):
     if not isinstance(res_f, (list, tuple)):
       res_f = [res_f]
     _insert_switch_nodes(predication, "false", subgraph_f)
-  cur_graph.merge(subgraph_f)
 
   # Add the merge nodes for the outputs.
   merges = [merge([t, f]) for (t, f) in zip(res_t, res_f)]
diff --git a/smaug/python/ops/ops_test.py b/smaug/python/ops/ops_test.py
index 02e6d2fb..bc3545c3 100755
--- a/smaug/python/ops/ops_test.py
+++ b/smaug/python/ops/ops_test.py
@@ -5,7 +5,7 @@
 import unittest
 import numpy as np
 
-from smaug.python.graph import Graph
+from smaug.python.graph import Graph, get_node_proto
 from smaug.python.tensor import Tensor
 from smaug.python.ops import data_op
 from smaug.python.ops import activation_ops
@@ -45,6 +45,9 @@ def assertEqualDims(self, dims, layout, expected_dims, expected_layout):
     else:
       assert False, "Other layouts not expected here!"
 
+  def get_node(self, name):
+    return get_node_proto(self.test_graph, name)
+
   def build_test_sequential_graph(self, backend):
     """Create a sequential model."""
     np_dtype = test_backend_dtypes[backend]
@@ -100,7 +103,7 @@ def build_test_sequential_graph(self, backend):
       out = array_ops.stack(out, 4, 1, "stack")
       out0, out1, out2, out3 = array_ops.unstack(out, 1, "unstack")
 
-    self.test_graph = graph
+    self.test_graph, _ = graph.to_proto()
     self.backend = backend
     self.alignment = global_vars.backend_alignment[backend]
 
@@ -160,16 +163,16 @@ def build_test_residual_graph(self, backend):
           math_ops.add(out0, out1, "add1"), math_ops.add(out2, out3, "add2"),
           "mul1")
 
-    self.test_graph = graph
+    self.test_graph, _ = graph.to_proto()
     self.backend = backend
     self.alignment = global_vars.backend_alignment[
-        self.test_graph.graph.backend]
+        self.test_graph.backend]
 
 class SequentialGraphTest(OperatorTest):
   """Common tests for the sequential graph."""
 
   def test_input_op(self):
-    node = self.test_graph.get_node("input")
+    node = self.get_node("input")
     self.assertEqual(node.op, types_pb2.Data)
     self.assertEqual(len(node.input_tensors), 1)
     self.assertEqual(len(node.output_tensors), 1)
@@ -186,7 +189,7 @@ def test_convolution_op(self):
     expected_output_layout = global_vars.backend_layouts[self.backend][
         types_pb2.Convolution3d].output_layoutset.layouts
     # The first convolution operator "conv0"
-    node = self.test_graph.get_node("conv0")
+    node = self.get_node("conv0")
     self.assertEqual(node.op, types_pb2.Convolution3d)
     self.assertEqual(len(node.input_tensors), 2)
     self.assertEqual(len(node.output_tensors), 1)
@@ -211,7 +214,7 @@ def test_convolution_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
     # The second convolution operator "conv1"
-    node = self.test_graph.get_node("conv1")
+    node = self.get_node("conv1")
     self.assertEqual(node.op, types_pb2.Convolution3d)
     self.assertEqual(len(node.input_tensors), 2)
     self.assertEqual(len(node.output_tensors), 1)
@@ -237,7 +240,7 @@ def test_convolution_op(self):
 
   def test_relu_op(self):
     # The first relu operator "conv0_relu"
-    node = self.test_graph.get_node("conv0_relu")
+    node = self.get_node("conv0_relu")
     self.assertEqual(node.op, types_pb2.ReLU)
     self.assertEqual(len(node.input_tensors), 1)
     self.assertEqual(len(node.output_tensors), 1)
@@ -252,7 +255,7 @@ def test_relu_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
     # The second relu operator "conv1_relu"
-    node = self.test_graph.get_node("conv1_relu")
+    node = self.get_node("conv1_relu")
     self.assertEqual(node.op, types_pb2.ReLU)
     self.assertEqual(len(node.input_tensors), 1)
     self.assertEqual(len(node.output_tensors), 1)
@@ -267,7 +270,7 @@ def test_relu_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
     # The third relu operator "fc0_relu"
-    node = self.test_graph.get_node("fc0_relu")
+    node = self.get_node("fc0_relu")
     self.assertEqual(node.op, types_pb2.ReLU)
     self.assertEqual(len(node.input_tensors), 1)
     self.assertEqual(len(node.output_tensors), 1)
@@ -280,7 +283,7 @@ def test_relu_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
   def test_batch_norm_op(self):
-    node = self.test_graph.get_node("bn")
+    node = self.get_node("bn")
     self.assertEqual(node.op, types_pb2.BatchNorm)
     self.assertEqual(len(node.input_tensors), 5)
     self.assertEqual(len(node.output_tensors), 1)
@@ -303,7 +306,7 @@ def test_batch_norm_op(self):
   def test_max_pool_op(self):
     expected_output_layout = global_vars.backend_layouts[self.backend][
         types_pb2.MaxPooling].output_layoutset.layouts
-    node = self.test_graph.get_node("pool")
+    node = self.get_node("pool")
     self.assertEqual(node.op, types_pb2.MaxPooling)
     self.assertEqual(len(node.input_tensors), 1)
     self.assertEqual(len(node.output_tensors), 1)
@@ -321,7 +324,7 @@ def test_max_pool_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
   def test_flatten_op(self):
-    node = self.test_graph.get_node("flatten")
+    node = self.get_node("flatten")
     self.assertEqual(node.op, types_pb2.Reorder)
     self.assertEqual(len(node.input_tensors), 1)
     self.assertEqual(len(node.output_tensors), 1)
@@ -338,7 +341,7 @@ def test_mat_mul_op(self):
     expected_output_layout = global_vars.backend_layouts[self.backend][
         types_pb2.InnerProduct].output_layoutset.layouts
     # The first mat_mul operator "fc0"
-    node = self.test_graph.get_node("fc0")
+    node = self.get_node("fc0")
     self.assertEqual(node.op, types_pb2.InnerProduct)
     self.assertEqual(len(node.input_tensors), 2)
     self.assertEqual(len(node.output_tensors), 1)
@@ -360,7 +363,7 @@ def test_mat_mul_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
     # The second mat_mul operator "fc1"
-    node = self.test_graph.get_node("fc1")
+    node = self.get_node("fc1")
     self.assertEqual(node.op, types_pb2.InnerProduct)
     self.assertEqual(len(node.input_tensors), 2)
     self.assertEqual(len(node.output_tensors), 1)
@@ -382,7 +385,7 @@ def test_mat_mul_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
   def test_expand_dims_op(self):
-    node = self.test_graph.get_node("expand_dims")
+    node = self.get_node("expand_dims")
     self.assertEqual(node.op, types_pb2.Reshape)
     self.assertEqual(len(node.input_tensors), 1)
     self.assertEqual(len(node.output_tensors), 1)
@@ -394,7 +397,7 @@ def test_expand_dims_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
   def test_squeeze_op(self):
-    node = self.test_graph.get_node("squeeze")
+    node = self.get_node("squeeze")
     self.assertEqual(node.op, types_pb2.Reshape)
     self.assertEqual(len(node.input_tensors), 1)
     self.assertEqual(len(node.output_tensors), 1)
@@ -406,7 +409,7 @@ def test_squeeze_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
   def test_reshape_op(self):
-    node = self.test_graph.get_node("reshape")
+    node = self.get_node("reshape")
     self.assertEqual(node.op, types_pb2.Reshape)
     self.assertEqual(len(node.input_tensors), 1)
     self.assertEqual(len(node.output_tensors), 1)
@@ -418,7 +421,7 @@ def test_reshape_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
   def test_repeat_op(self):
-    node = self.test_graph.get_node("repeat")
+    node = self.get_node("repeat")
     self.assertEqual(node.op, types_pb2.Repeat)
     self.assertEqual(len(node.input_tensors), 1)
     self.assertEqual(len(node.output_tensors), 1)
@@ -432,7 +435,7 @@ def test_repeat_op(self):
   def test_stack_op(self):
     # stack op is implemented using expand_dims and repeat. Here we only test
     # the output.
-    node = self.test_graph.get_node("stack:repeat")
+    node = self.get_node("stack:repeat")
     self.assertEqual(node.output_tensors[0].data_type, self.expected_dtype)
     self.assertEqual(node.output_tensors[0].shape.dims, [8, 4, 10])
     self.assertEqual(node.output_tensors[0].shape.layout, types_pb2.NTC)
@@ -442,7 +445,7 @@ def test_unstack_op(self):
     # unstack op is implemented using split and squeeze. Here we only test
     # the output.
     for i in range(4):
-      node = self.test_graph.get_node("unstack:squeeze" +
+      node = self.get_node("unstack:squeeze" +
                                       ("_%d" % i if i > 0 else ""))
       self.assertEqual(node.output_tensors[0].data_type, self.expected_dtype)
       self.assertEqual(node.output_tensors[0].shape.dims, [8, 10])
@@ -453,7 +456,7 @@ class ResidualGraphTest(OperatorTest):
   """Common tests for the residual graph."""
 
   def test_input_op(self):
-    node = self.test_graph.get_node("input")
+    node = self.get_node("input")
     self.assertEqual(node.op, types_pb2.Data)
     self.assertEqual(len(node.input_tensors), 1)
     self.assertEqual(len(node.output_tensors), 1)
@@ -470,7 +473,7 @@ def test_convolution_op(self):
     expected_output_layout = global_vars.backend_layouts[self.backend][
         types_pb2.Convolution3d].output_layoutset.layouts
     # The first convolution operator "conv0"
-    node = self.test_graph.get_node("conv0")
+    node = self.get_node("conv0")
     self.assertEqual(node.op, types_pb2.Convolution3d)
     self.assertEqual(len(node.input_tensors), 2)
     self.assertEqual(len(node.output_tensors), 1)
@@ -495,7 +498,7 @@ def test_convolution_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
     # The second convolution operator "conv1"
-    node = self.test_graph.get_node("conv1")
+    node = self.get_node("conv1")
     self.assertEqual(node.op, types_pb2.Convolution3d)
     self.assertEqual(len(node.input_tensors), 2)
     self.assertEqual(len(node.output_tensors), 1)
@@ -520,7 +523,7 @@ def test_convolution_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
     # The third convolution operator "conv2"
-    node = self.test_graph.get_node("conv2")
+    node = self.get_node("conv2")
     self.assertEqual(node.op, types_pb2.Convolution3d)
     self.assertEqual(len(node.input_tensors), 2)
     self.assertEqual(len(node.output_tensors), 1)
@@ -545,7 +548,7 @@ def test_convolution_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
   def test_relu_op(self):
-    node = self.test_graph.get_node("relu")
+    node = self.get_node("relu")
     self.assertEqual(node.op, types_pb2.ReLU)
     self.assertEqual(len(node.input_tensors), 1)
     self.assertEqual(len(node.output_tensors), 1)
@@ -560,7 +563,7 @@ def test_relu_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
   def test_batch_norm_op(self):
-    node = self.test_graph.get_node("bn")
+    node = self.get_node("bn")
     self.assertEqual(node.op, types_pb2.BatchNorm)
     self.assertEqual(len(node.input_tensors), 5)
     self.assertEqual(len(node.output_tensors), 1)
@@ -582,7 +585,7 @@ def test_batch_norm_op(self):
 
   def test_add_op(self):
     # The first add operator (add)
-    node = self.test_graph.get_node("add")
+    node = self.get_node("add")
     self.assertEqual(node.op, types_pb2.EltwiseAdd)
     self.assertEqual(len(node.input_tensors), 2)
     self.assertEqual(len(node.output_tensors), 1)
@@ -597,7 +600,7 @@ def test_add_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
     # The second add operator (add1)
-    node = self.test_graph.get_node("add1")
+    node = self.get_node("add1")
     self.assertEqual(node.op, types_pb2.EltwiseAdd)
     self.assertEqual(len(node.input_tensors), 2)
     self.assertEqual(len(node.output_tensors), 1)
@@ -612,7 +615,7 @@ def test_add_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
     # The third add operator (add2)
-    node = self.test_graph.get_node("add2")
+    node = self.get_node("add2")
     self.assertEqual(node.op, types_pb2.EltwiseAdd)
     self.assertEqual(len(node.input_tensors), 2)
     self.assertEqual(len(node.output_tensors), 1)
@@ -628,7 +631,7 @@ def test_add_op(self):
 
   def test_mul_op(self):
     # The first mul operator (mul)
-    node = self.test_graph.get_node("mul")
+    node = self.get_node("mul")
     self.assertEqual(node.op, types_pb2.EltwiseMul)
     self.assertEqual(len(node.input_tensors), 2)
     self.assertEqual(len(node.output_tensors), 1)
@@ -643,7 +646,7 @@ def test_mul_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
     # The second add operator (mul1)
-    node = self.test_graph.get_node("mul1")
+    node = self.get_node("mul1")
     self.assertEqual(node.op, types_pb2.EltwiseMul)
     self.assertEqual(len(node.input_tensors), 2)
     self.assertEqual(len(node.output_tensors), 1)
@@ -658,7 +661,7 @@ def test_mul_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
   def test_concat_op(self):
-    node = self.test_graph.get_node("concat")
+    node = self.get_node("concat")
     self.assertEqual(node.op, types_pb2.Concat)
     self.assertEqual(len(node.input_tensors), 2)
     self.assertEqual(len(node.output_tensors), 1)
@@ -673,7 +676,7 @@ def test_concat_op(self):
     self.assertEqual(node.output_tensors[0].shape.alignment, self.alignment)
 
   def test_split_op(self):
-    node = self.test_graph.get_node("split")
+    node = self.get_node("split")
     self.assertEqual(node.op, types_pb2.Split)
     self.assertEqual(len(node.input_tensors), 1)
     self.assertEqual(len(node.output_tensors), 4)
@@ -702,64 +705,64 @@ def test_parent_children(self):
     between layers, so we delete this test from the above common tests.
     """
     # input (Data).
-    node = self.test_graph.get_node("input")
+    node = self.get_node("input")
     self.assertEqual(len(node.parents), 0)
     # Reorder input from NCHW to NHWC.
-    node = self.test_graph.get_node("reorder")
+    node = self.get_node("reorder")
     self.assertEqual(node.parents[0], "input")
     # conv0 (Convolution).
-    node = self.test_graph.get_node("conv0")
+    node = self.get_node("conv0")
     self.assertEqual(node.parents[0], "reorder")
     # conv0_relu (ReLU).
-    node = self.test_graph.get_node("conv0_relu")
+    node = self.get_node("conv0_relu")
     self.assertEqual(node.parents[0], "conv0")
     # bn (BN).
-    node = self.test_graph.get_node("bn")
+    node = self.get_node("bn")
     self.assertEqual(node.parents[0], "conv0_relu")
     # conv1 (Convolution).
-    node = self.test_graph.get_node("conv1")
+    node = self.get_node("conv1")
     self.assertEqual(node.parents[0], "bn")
     # conv1_relu (ReLU).
-    node = self.test_graph.get_node("conv1_relu")
+    node = self.get_node("conv1_relu")
     self.assertEqual(node.parents[0], "conv1")
     # pool (MaxPooling).
-    node = self.test_graph.get_node("pool")
+    node = self.get_node("pool")
     self.assertEqual(node.parents[0], "conv1_relu")
     # flatten (Flatten).
-    node = self.test_graph.get_node("flatten")
+    node = self.get_node("flatten")
     self.assertEqual(node.parents[0], "pool")
     # fc0 (FC).
-    node = self.test_graph.get_node("fc0")
+    node = self.get_node("fc0")
     self.assertEqual(node.parents[0], "flatten")
     # fc0_relu (ReLU)
-    node = self.test_graph.get_node("fc0_relu")
+    node = self.get_node("fc0_relu")
     self.assertEqual(node.parents[0], "fc0")
     # fc1 (FC).
-    node = self.test_graph.get_node("fc1")
+    node = self.get_node("fc1")
     self.assertEqual(node.parents[0], "fc0_relu")
     # expand_dims (Reshape).
-    node = self.test_graph.get_node("expand_dims")
+    node = self.get_node("expand_dims")
     self.assertEqual(node.parents[0], "fc1")
     # squeeze (Reshape).
-    node = self.test_graph.get_node("squeeze")
+    node = self.get_node("squeeze")
     self.assertEqual(node.parents[0], "expand_dims")
     # reshape (Reshape).
-    node = self.test_graph.get_node("reshape")
+    node = self.get_node("reshape")
     self.assertEqual(node.parents[0], "squeeze")
     # repeat (Repeat).
-    node = self.test_graph.get_node("repeat")
+    node = self.get_node("repeat")
     self.assertEqual(node.parents[0], "reshape")
     # stack (Reshape + Repeat).
-    node = self.test_graph.get_node("stack:expand_dims")
+    node = self.get_node("stack:expand_dims")
     self.assertEqual(node.parents[0], "repeat")
-    node = self.test_graph.get_node("stack:repeat")
+    node = self.get_node("stack:repeat")
     self.assertEqual(node.parents[0], "stack:expand_dims")
     # unstack (Split + Squeeze).
-    node = self.test_graph.get_node("unstack:split")
+    node = self.get_node("unstack:split")
     self.assertEqual(node.parents[0], "stack:repeat")
     for i in range(4):
       child_name = "unstack:squeeze" + ("_%d" % i if i > 0 else "")
-      child_node = self.test_graph.get_node(child_name)
+      child_node = self.get_node(child_name)
       self.assertEqual(child_node.parents[0], "unstack:split")
       self.assertEqual(child_node.src_tensors_indices, [i])
 
@@ -778,65 +781,65 @@ def test_parent_children(self):
     """Test the parent/child relationship in the graph."""
 
     # input (Data).
-    node = self.test_graph.get_node("input")
+    node = self.get_node("input")
     self.assertEqual(len(node.parents), 0)
     # conv0 (Convolution).
-    node = self.test_graph.get_node("conv0")
+    node = self.get_node("conv0")
     self.assertEqual(node.parents[0], "input")
     # conv0_relu (ReLU).
-    node = self.test_graph.get_node("conv0_relu")
+    node = self.get_node("conv0_relu")
     self.assertEqual(node.parents[0], "conv0")
     # bn (BN)
-    node = self.test_graph.get_node("bn")
+    node = self.get_node("bn")
     self.assertEqual(node.parents[0], "conv0_relu")
     # conv1 (Convolution).
-    node = self.test_graph.get_node("conv1")
+    node = self.get_node("conv1")
     self.assertEqual(node.parents[0], "bn")
     # conv1_relu (ReLU)
-    node = self.test_graph.get_node("conv1_relu")
+    node = self.get_node("conv1_relu")
     self.assertEqual(node.parents[0], "conv1")
     # pool (MaxPooling)
-    node = self.test_graph.get_node("pool")
+    node = self.get_node("pool")
     self.assertEqual(node.parents[0], "conv1_relu")
     # flatten (Flatten)
-    node = self.test_graph.get_node("flatten")
+    node = self.get_node("flatten")
     self.assertEqual(node.parents[0], "pool")
     # Transpose fc0 weights
-    node = self.test_graph.get_node("reorder")
+    node = self.get_node("reorder")
     # fc0 (FC)
-    node = self.test_graph.get_node("fc0")
+    node = self.get_node("fc0")
     self.assertEqual(node.parents, ["flatten", "reorder"])
     # fc0_relu (ReLU)
-    node = self.test_graph.get_node("fc0_relu")
+    node = self.get_node("fc0_relu")
     self.assertEqual(node.parents[0], "fc0")
     # Transpose fc1/weights
-    node = self.test_graph.get_node("reorder_1")
+    node = self.get_node("reorder_1")
     # fc1 (FC)
-    node = self.test_graph.get_node("fc1")
+    node = self.get_node("fc1")
     self.assertEqual(node.parents, ["fc0_relu", "reorder_1"])
     # expand_dims (Reshape).
-    node = self.test_graph.get_node("expand_dims")
+    node = self.get_node("expand_dims")
     self.assertEqual(node.parents[0], "fc1")
     # squeeze (Reshape).
-    node = self.test_graph.get_node("squeeze")
+    node = self.get_node("squeeze")
     self.assertEqual(node.parents[0], "expand_dims")
     # reshape (Reshape)
-    node = self.test_graph.get_node("reshape")
+    node = self.get_node("reshape")
     self.assertEqual(node.parents[0], "squeeze")
     # repeat (Repeat)
-    node = self.test_graph.get_node("repeat")
+    node = self.get_node("repeat")
     self.assertEqual(node.parents[0], "reshape")
     # stack (Reshape + Repeat).
-    node = self.test_graph.get_node("stack:expand_dims")
+    node = self.get_node("stack:expand_dims")
     self.assertEqual(node.parents[0], "repeat")
-    node = self.test_graph.get_node("stack:repeat")
+    node = self.get_node("stack:repeat")
     self.assertEqual(node.parents[0], "stack:expand_dims")
     # unstack (Split + Squeeze).
-    node = self.test_graph.get_node("unstack:split")
+    node = self.get_node("unstack:split")
     self.assertEqual(node.parents[0], "stack:repeat")
     for i in range(4):
       child_name = "unstack:squeeze" + ("_%d" % i if i > 0 else "")
-      child_node = self.test_graph.get_node(child_name)
+      child_node = self.get_node(child_name)
       self.assertEqual(child_node.parents[0], "unstack:split")
       self.assertEqual(child_node.src_tensors_indices, [i])
 
@@ -851,49 +854,49 @@ def test_parent_children(self):
     """Test the parent/child relationship in the graph."""
 
     # input (Data).
-    node = self.test_graph.get_node("input")
+    node = self.get_node("input")
     self.assertEqual(len(node.parents), 0)
     # Reorder input from NCHW to NHWC.
-    node = self.test_graph.get_node("reorder")
+    node = self.get_node("reorder")
     self.assertEqual(node.parents[0], "input")
     # conv0 (Convolution).
-    node = self.test_graph.get_node("conv0")
+    node = self.get_node("conv0")
     self.assertEqual(node.parents[0], "reorder")
     # conv1 (Convolution).
-    node = self.test_graph.get_node("conv1")
+    node = self.get_node("conv1")
     self.assertEqual(node.parents[0], "reorder")
     # bn (BN).
-    node = self.test_graph.get_node("bn")
+    node = self.get_node("bn")
     self.assertEqual(node.parents[0], "conv1")
     # relu (ReLU).
-    node = self.test_graph.get_node("relu")
+    node = self.get_node("relu")
     self.assertEqual(node.parents[0], "bn")
     # conv2 (Convolution).
-    node = self.test_graph.get_node("conv2")
+    node = self.get_node("conv2")
     self.assertEqual(node.parents[0], "relu")
     # add (EltwiseAdd).
-    node = self.test_graph.get_node("add")
+    node = self.get_node("add")
     self.assertEqual(node.parents[0], "conv0")
     self.assertEqual(node.parents[1], "conv2")
     # mul (EltwiseMul).
-    node = self.test_graph.get_node("mul")
+    node = self.get_node("mul")
     self.assertEqual(node.parents, ["conv0", "add"])
     # concat (Concat).
-    node = self.test_graph.get_node("concat")
+    node = self.get_node("concat")
     self.assertEqual(node.parents, ["conv0", "mul"])
     # split (Split).
-    node = self.test_graph.get_node("split")
+    node = self.get_node("split")
     self.assertEqual(node.parents[0], "concat")
     # add1 (EltwiseAdd).
-    node = self.test_graph.get_node("add1")
+    node = self.get_node("add1")
     self.assertEqual(node.parents, ["split", "split"])
     self.assertEqual(node.src_tensors_indices, [0, 1])
     # add2 (EltwiseAdd).
-    node = self.test_graph.get_node("add2")
+    node = self.get_node("add2")
     self.assertEqual(node.parents, ["split", "split"])
     self.assertEqual(node.src_tensors_indices, [2, 3])
     # mul (EltwiseMul).
-    node = self.test_graph.get_node("mul1")
+    node = self.get_node("mul1")
     self.assertEqual(node.parents, ["add1", "add2"])
 
 class RefResidualGraphTest(unittest.TestCase, ResidualGraphTest):
@@ -907,45 +910,45 @@ def test_parent_children(self):
     """Test the parent/child relationship in the graph."""
 
     # input (Data).
-    node = self.test_graph.get_node("input")
+    node = self.get_node("input")
     self.assertEqual(len(node.parents), 0)
     # conv0 (Convolution).
-    node = self.test_graph.get_node("conv0")
+    node = self.get_node("conv0")
     self.assertEqual(node.parents[0], "input")
     # conv1 (Convolution).
-    node = self.test_graph.get_node("conv1")
+    node = self.get_node("conv1")
     self.assertEqual(node.parents[0], "input")
     # bn (BN).
-    node = self.test_graph.get_node("bn")
+    node = self.get_node("bn")
     self.assertEqual(node.parents[0], "conv1")
     # relu (ReLU).
-    node = self.test_graph.get_node("relu")
+    node = self.get_node("relu")
     self.assertEqual(node.parents[0], "bn")
     # conv2 (Convolution).
-    node = self.test_graph.get_node("conv2")
+    node = self.get_node("conv2")
     self.assertEqual(node.parents[0], "relu")
     # add (EltwiseAdd).
-    node = self.test_graph.get_node("add")
+    node = self.get_node("add")
     self.assertEqual(node.parents, ["conv0", "conv2"])
     # mul (EltwiseMul).
-    node = self.test_graph.get_node("mul")
+    node = self.get_node("mul")
     self.assertEqual(node.parents, ["conv0", "add"])
     # concat (Concat).
-    node = self.test_graph.get_node("concat")
+    node = self.get_node("concat")
     self.assertEqual(node.parents, ["conv0", "mul"])
     # split (Split).
-    node = self.test_graph.get_node("split")
+    node = self.get_node("split")
     self.assertEqual(node.parents[0], "concat")
     # add1 (EltwiseAdd).
-    node = self.test_graph.get_node("add1")
+    node = self.get_node("add1")
     self.assertEqual(node.parents, ["split", "split"])
     self.assertEqual(node.src_tensors_indices, [0, 1])
     # add2 (EltwiseAdd).
-    node = self.test_graph.get_node("add2")
+    node = self.get_node("add2")
     self.assertEqual(node.parents, ["split", "split"])
     self.assertEqual(node.src_tensors_indices, [2, 3])
     # mul (EltwiseMul).
-    node = self.test_graph.get_node("mul1")
+    node = self.get_node("mul1")
     self.assertEqual(node.parents, ["add1", "add2"])
 
 if __name__ == "__main__":
diff --git a/smaug/python/smaug_test.py b/smaug/python/smaug_test.py
index a8ceb5b4..77d8bb18 100644
--- a/smaug/python/smaug_test.py
+++ b/smaug/python/smaug_test.py
@@ -15,6 +15,7 @@
 
 class SmaugTest(unittest.TestCase):
   def setUp(self):
+    self._cwd = os.getcwd()
     self.run_dir = tempfile.mkdtemp()
     self.error_filename = os.path.join(self.run_dir, "stderr")
     self.graph_name = "test_graph"
@@ -25,6 +26,7 @@ def setUp(self):
   def tearDown(self):
     """ Delete temporary files and outputs. """
     shutil.rmtree(self.run_dir)
+    os.chdir(self._cwd)
 
   def launchSubprocess(self, cmd):
     with open(self.error_filename, "w") as f:
diff --git a/smaug/python/subgraph_test.py b/smaug/python/subgraph_test.py
index ab277b8b..4a982df7 100755
--- a/smaug/python/subgraph_test.py
+++ b/smaug/python/subgraph_test.py
@@ -31,7 +31,7 @@ def assertNodesConnected(self, graph, child_parent_map):
     """Test the connection among nodes."""
     for child, parents in child_parent_map.items():
       child_node = graph.get_node(child)
-      self.assertEqual(child_node.parents, parents)
+      self.assertEqual(child_node.get_parents(), parents)
 
   def test_subgraph_merge(self):
     with Graph(parent_graph_name, backend) as parent_graph:
diff --git a/smaug/python/tensor.py b/smaug/python/tensor.py
index 3c024f81..5bb43096 100644
--- a/smaug/python/tensor.py
+++ b/smaug/python/tensor.py
@@ -9,7 +9,7 @@ class Tensor:
   def __init__(
       self, dims=None, name=None, data_layout=types_pb2.NCHW, data_type=None,
       data_format=types_pb2.Uncompressed, tensor_data=None, source=None,
-      targets=None, alignment=None):
+      source_index=None, targets=None, alignment=None):
     """Create a tensor.
 
     Args:
@@ -19,7 +19,8 @@ def __init__(
       data_type: Data type of the tensor.
       data_format: Data format of the tensor.
       tensor_data: A NumPy array that represents the tensor data.
-      source: A tuple (source_node, tensor_index) that represents this tensor's
+      source: The `Node` that produces this tensor.
+      source_index: An int that represents this tensor's output index in its
         source node.
       targets: A list of nodes that use this tensor as inputs.
       alignment: Data alignment used in the tensor data.
@@ -27,57 +28,97 @@ def __init__(
     Returns:
       A `Tensor` object.
     """
-    self.shape = tensor_pb2.TensorShapeProto()
-    self.tensor_data = tensor_data
+    self._shape = tensor_pb2.TensorShapeProto()
+    self._tensor_data = tensor_data
     # If tensor_data is provided, deduce dims and data_type directly from it
     # (the kwargs are ignored if they are provided).
-    if self.tensor_data is not None:
-      self.deduce_attrs_from_data()
+    if self._tensor_data is not None:
+      self._deduce_attrs_from_data()
     else:
-      self.shape.dims.extend(dims)
-      self.data_type = data_type
-
-    self.shape.layout = data_layout
-    self.name = name
-    self.data_format = data_format
-    self.source = source
-    self.targets = []
+      self._shape.dims.extend(dims)
+      self._data_type = data_type
+
+    self._shape.layout = data_layout
+    self._name = name
+    self._data_format = data_format
+    self._source = source
+    self._source_index = source_index
+    if self._source is not None and self._source_index is None:
+      raise ValueError(
+          "Please provide this tensor's output index in the source node!")
+    self._targets = []
     if alignment != None:
-      self.shape.alignment = alignment
+      self._shape.alignment = alignment
     elif global_vars.get_graph() == None:
-      self.shape.alignment = 0
+      self._shape.alignment = 0
     else:
-      self.shape.alignment = global_vars.get_graph().alignment
+      self._shape.alignment = global_vars.get_graph().alignment
 
     # Do data padding if this Tensor contains data.
-    if self.tensor_data is not None:
-      pad_width = [(0, 0) for i in range(len(self.shape.dims) - 1)]
-      pad_width.append((0, self.calc_padding(self.shape.dims[-1])))
-      self.tensor_data = np.pad(self.tensor_data, pad_width, 'constant')
+    if self._tensor_data is not None:
+      pad_width = [(0, 0) for i in range(len(self._shape.dims) - 1)]
+      pad_width.append((0, self.calc_padding(self._shape.dims[-1])))
+      self._tensor_data = np.pad(self._tensor_data, pad_width, 'constant')
+
+  @property
+  def name(self):
+    return self._name
+
+  @name.setter
+  def name(self, name):
+    self._name = name
+
+  @property
+  def shape(self):
+    return self._shape
+
+  @property
+  def data_type(self):
+    return self._data_type
+
+  @property
+  def data_format(self):
+    return self._data_format
+
+  @property
+  def tensor_data(self):
+    return self._tensor_data
+
+  @property
+  def source(self):
+    return self._source
+
+  @property
+  def source_index(self):
+    return self._source_index
+
+  @property
+  def targets(self):
+    return self._targets
 
   def dims(self, index):
     """This returns the size of the dimension."""
-    assert index < len(self.shape.dims), "The dimension index is out of bound!"
-    return self.shape.dims[index]
+    assert index < len(self._shape.dims), "The dimension index is out of bound!"
+    return self._shape.dims[index]
 
-  def deduce_attrs_from_data(self):
+  def _deduce_attrs_from_data(self):
     """Deduce tensor attributes from the supplied tensor data.
 
     The deducible attributes include tensor shape dimensions and data type.
     """
     # Deduce dims from tensor data.
-    self.shape.dims.extend(list(self.tensor_data.shape))
+    self._shape.dims.extend(list(self._tensor_data.shape))
     # Deduce data type from tensor data
     try:
-      self.data_type = datatypes.np_to_smaug_type[self.tensor_data.dtype.type]
+      self._data_type = datatypes.np_to_smaug_type[self._tensor_data.dtype.type]
     except KeyError:
-      assert False, "We don't support numpy dtype: %s" % self.tensor_data.dtype
+      assert False, "We don't support numpy dtype: %s" % self._tensor_data.dtype
 
   def calc_padding(self, value):
     """This returns the size we need to pad on the last dimension."""
-    if self.shape.alignment == 0 or value % self.shape.alignment == 0:
+    if self._shape.alignment == 0 or value % self._shape.alignment == 0:
       return 0
-    return (self.shape.alignment - (value % self.shape.alignment))
+    return (self._shape.alignment - (value % self._shape.alignment))
 
   def to_tensor_proto(self, tensor_proto, tensor_data_array=None):
     """Serialize the tensor into a tensor proto.
@@ -86,39 +127,39 @@ def to_tensor_proto(self, tensor_proto, tensor_data_array=None):
       tensor_proto: The tensor proto this tensor gets serialized into.
       tensor_data_array: The tensor data array this tensor gets serialized into.
     """
-    tensor_proto.name = self.name
-    tensor_proto.shape.CopyFrom(self.shape)
-    tensor_proto.data_type = self.data_type
-    tensor_proto.data_format = self.data_format
-    if self.tensor_data is not None and tensor_data_array is not None:
+    tensor_proto.name = self._name
+    tensor_proto.shape.CopyFrom(self._shape)
+    tensor_proto.data_type = self._data_type
+    tensor_proto.data_format = self._data_format
+    if self._tensor_data is not None and tensor_data_array is not None:
 
       # Since Protobuf doesn't support float16 data type, we pack two float16
       # elements into one int32.
-      if self.data_type == types_pb2.Float16:
+      if self._data_type == types_pb2.Float16:
         # Numpy.view comes in handy here. Note that it won't work if
         # tensor_data's last dimension is of odd size. To solve that, we
         # flatten the tensor data, and if the flattened list is still of
         # odd size, we pad a zero at the end of the list. When we later
         # deserialize the tensor data, we know the correct shape of the
         # tensor, and the padded zero will be discarded.
-        self.tensor_data = self.tensor_data.flatten()
-        if self.tensor_data.size % 2 != 0:
-          self.tensor_data = np.append(self.tensor_data, np.float16(0))
-        self.tensor_data = self.tensor_data.view(np.int32)
+        self._tensor_data = self._tensor_data.flatten()
+        if self._tensor_data.size % 2 != 0:
+          self._tensor_data = np.append(self._tensor_data, np.float16(0))
+        self._tensor_data = self._tensor_data.view(np.int32)
 
       # Serialize the data into the proto.
       tensor_data_proto = tensor_data_array.data_array.add()
       tensor_data_proto.name = tensor_proto.name
-      data_list = [x for x in np.nditer(self.tensor_data)]
-      if self.data_type == types_pb2.Float16:
+      data_list = [x for x in np.nditer(self._tensor_data)]
+      if self._data_type == types_pb2.Float16:
         tensor_data_proto.half_data.extend(data_list)
-      elif self.data_type == types_pb2.Float32:
+      elif self._data_type == types_pb2.Float32:
         tensor_data_proto.float_data.extend(data_list)
-      elif self.data_type == types_pb2.Float64:
+      elif self._data_type == types_pb2.Float64:
         tensor_data_proto.double_data.extend(data_list)
-      elif self.data_type == types_pb2.Int32:
+      elif self._data_type == types_pb2.Int32:
         tensor_data_proto.int_data.extend(data_list)
-      elif self.data_type == types_pb2.Int64:
+      elif self._data_type == types_pb2.Int64:
         tensor_data_proto.int64_data.extend(data_list)
-      elif self.data_type == types_pb2.Bool:
+      elif self._data_type == types_pb2.Bool:
         tensor_data_proto.bool_data.extend(data_list)
diff --git a/smaug/python/tensor_test.py b/smaug/python/tensor_test.py
index 637d83bd..95a52f11 100755
--- a/smaug/python/tensor_test.py
+++ b/smaug/python/tensor_test.py
@@ -6,7 +6,7 @@
 import numpy as np
 
 from smaug.python.tensor_utils import get_tensor_data
-from smaug.python.graph import Graph
+from smaug.python.graph import Graph, get_node_proto
 from smaug.python.tensor import Tensor
 from smaug.python.ops.data_op import input_data
 from smaug.core import types_pb2
@@ -30,14 +30,15 @@ def test_attr_reference(self):
     with Graph("test_graph", "Reference") as test_graph:
       input_tensor = Tensor(data_layout=types_pb2.NHWC, tensor_data=tensor_data)
       act = input_data(input_tensor, "input")
-    self.assertEqual(test_graph.graph.backend, "Reference")
-    node = test_graph.get_node("input")
+    graph_proto, tensor_data_array = test_graph.to_proto()
+    self.assertEqual(graph_proto.backend, "Reference")
+    node = get_node_proto(graph_proto, "input")
     self.assertEqual(node.input_tensors[0].data_type, types_pb2.Float32)
     self.assertEqual(node.input_tensors[0].shape.dims, [2, 2, 4, 4])
     self.assertEqual(node.input_tensors[0].shape.layout, types_pb2.NHWC)
     self.assertEqual(node.input_tensors[0].shape.alignment, 0)
-    tensor_data_proto = get_tensor_data(test_graph.tensor_data_array,
-                                        node.input_tensors[0].name)
+    tensor_data_proto = get_tensor_data(
+        tensor_data_array, node.input_tensors[0].name)
     self.assertEqual(tensor_data_proto.float_data, list(tensor_data.flatten()))
     self.assertEqual(len(tensor_data_proto.half_data), 0)
     self.assertEqual(len(tensor_data_proto.double_data), 0)
@@ -50,14 +51,15 @@ def test_attr_smv_no_padding(self):
     with Graph("test_graph", "SMV") as test_graph:
       input_tensor = Tensor(data_layout=types_pb2.NCHW, tensor_data=tensor_data)
       act = input_data(input_tensor, "input")
-    self.assertEqual(test_graph.graph.backend, "SMV")
-    node = test_graph.get_node("input")
+    graph_proto, tensor_data_array = test_graph.to_proto()
+    self.assertEqual(graph_proto.backend, "SMV")
+    node = get_node_proto(graph_proto, "input")
     self.assertEqual(node.input_tensors[0].data_type, types_pb2.Float16)
     self.assertEqual(node.input_tensors[0].shape.dims, [2, 2, 4, 8])
     self.assertEqual(node.input_tensors[0].shape.layout, types_pb2.NCHW)
     self.assertEqual(node.input_tensors[0].shape.alignment, 8)
-    tensor_data_proto = get_tensor_data(test_graph.tensor_data_array,
-                                        node.input_tensors[0].name)
+    tensor_data_proto = get_tensor_data(
+        tensor_data_array, node.input_tensors[0].name)
     self.assertEqualFP16(tensor_data_proto.half_data, tensor_data.flatten())
     self.assertEqual(len(tensor_data_proto.float_data), 0)
     self.assertEqual(len(tensor_data_proto.double_data), 0)
@@ -71,14 +73,15 @@ def test_attr_smv_padding(self):
     with Graph("test_graph", "SMV") as test_graph:
       input_tensor = Tensor(data_layout=types_pb2.NCHW, tensor_data=tensor_data)
       act = input_data(input_tensor, "input")
-    self.assertEqual(test_graph.graph.backend, "SMV")
-    node = test_graph.get_node("input")
+    graph_proto, tensor_data_array = test_graph.to_proto()
+    self.assertEqual(graph_proto.backend, "SMV")
+    node = get_node_proto(graph_proto, "input")
     self.assertEqual(node.input_tensors[0].data_type, types_pb2.Float16)
     self.assertEqual(node.input_tensors[0].shape.dims, [2, 4])
     self.assertEqual(node.input_tensors[0].shape.layout, types_pb2.NCHW)
     self.assertEqual(node.input_tensors[0].shape.alignment, 8)
-    tensor_data_proto = get_tensor_data(test_graph.tensor_data_array,
-                                        node.input_tensors[0].name)
+    tensor_data_proto = get_tensor_data(
+        tensor_data_array, node.input_tensors[0].name)
     self.assertEqualFP16(
         tensor_data_proto.half_data,
         np.array(
@@ -96,10 +99,11 @@ def test_fp16_even(self):
     with Graph("test_graph", "Reference") as test_graph:
       input_tensor = Tensor(tensor_data=tensor_data)
       act = input_data(input_tensor, "input")
-    node = test_graph.get_node("input")
+    graph_proto, tensor_data_array = test_graph.to_proto()
+    node = get_node_proto(graph_proto, "input")
     self.assertEqual(node.input_tensors[0].data_type, types_pb2.Float16)
-    tensor_data_proto = get_tensor_data(test_graph.tensor_data_array,
-                                        node.input_tensors[0].name)
+    tensor_data_proto = get_tensor_data(
+        tensor_data_array, node.input_tensors[0].name)
     self.assertEqualFP16(tensor_data_proto.half_data, tensor_data.flatten())
 
   def test_fp16_odd(self):
@@ -108,10 +112,11 @@ def test_fp16_odd(self):
     with Graph("test_graph", "Reference") as test_graph:
       input_tensor = Tensor(tensor_data=tensor_data)
       act = input_data(input_tensor, "input")
-    node = test_graph.get_node("input")
+    graph_proto, tensor_data_array = test_graph.to_proto()
+    node = get_node_proto(graph_proto, "input")
     self.assertEqual(node.input_tensors[0].data_type, types_pb2.Float16)
-    tensor_data_proto = get_tensor_data(test_graph.tensor_data_array,
-                                        node.input_tensors[0].name)
+    tensor_data_proto = get_tensor_data(
+        tensor_data_array, node.input_tensors[0].name)
     self.assertEqualFP16(tensor_data_proto.half_data, tensor_data.flatten())
 
   def test_fp16_odd_odd(self):
@@ -123,10 +128,11 @@ def test_fp16_odd_odd(self):
     with Graph("test_graph", "Reference") as test_graph:
       input_tensor = Tensor(tensor_data=tensor_data)
       act = input_data(input_tensor, "input")
-    node = test_graph.get_node("input")
+    graph_proto, tensor_data_array = test_graph.to_proto()
+    node = get_node_proto(graph_proto, "input")
     self.assertEqual(node.input_tensors[0].data_type, types_pb2.Float16)
-    tensor_data_proto = get_tensor_data(test_graph.tensor_data_array,
-                                        node.input_tensors[0].name)
+    tensor_data_proto = get_tensor_data(
+        tensor_data_array, node.input_tensors[0].name)
     self.assertEqualFP16(tensor_data_proto.half_data,
                          np.append(tensor_data.flatten(), np.float16(0)))
 
diff --git a/smaug/python/tensor_utils.py b/smaug/python/tensor_utils.py
index e473b558..4ab41177 100644
--- a/smaug/python/tensor_utils.py
+++ b/smaug/python/tensor_utils.py
@@ -20,53 +20,11 @@ def get_padded_shape(shape):
   shape.dims[-1] += alignment - remainder;
   return shape
 
-def from_tensor_proto(tensor_proto, tensor_data_array=None):
-  """Restore a Tensor from a TensorProto.
-
-  Args:
-    tensor_proto: A TensorProto.
-    tensor_data_array: a TensorDataArray that stores tensor data.
-
-  Returns:
-    A Tensor deserialized from `tensor_proto`.
-  """
-  name = tensor_proto.name
-  data_type = tensor_proto.data_type
-  tensor_data = None
-  if tensor_data_array is not None:
-    tensor_data_proto = get_tensor_data(tensor_data_array, name)
-    if tensor_data_proto is not None:
-      padded_shape = get_padded_shape(shape)
-      if data_type == types_pb2.Float16:
-        tensor_data = tensor_data_proto.half_data
-        if padded_shape.size % 2 != 0:
-          del tensor_data[-1]
-      elif data_type == types_pb2.Float32:
-        tensor_data = tensor_data_proto.float_data
-      elif data_type == types_pb2.Float64:
-        tensor_data = tensor_data_proto.double_data
-      elif data_type == types_pb2.Int32:
-        tensor_data = tensor_data_proto.int_data
-      elif data_type == types_pb2.Int64:
-        tensor_data = tensor_data_proto.int64_data
-      elif data_type == types_pb2.Bool:
-        tensor_data = tensor_data_proto.bool_data
-      # The data retrieved from the proto is one-dimensional, so make it back to
-      # shaped data.
-      tensor_data.reshape(padded_shape.dims)
-
-  tensor = Tensor(
-      dims=tensor_proto.shape.dims, name=name,
-      data_layout=tensor_proto.shape.layout, data_type=data_type,
-      data_format=tensor_proto.data_format, tensor_data=tensor_data)
-  return tensor
-
 def get_tensor_data_op(tensor):
   """Return the output of a data op if this tensor already has one created."""
   for node in tensor.targets:
     if node.op == types_pb2.Data:
-      data_op_output = from_tensor_proto(node.output_tensors[0])
-      data_op_output.source = (node, 0)
+      data_op_output = node.outputs[0]
       return data_op_output
   return None
 
@@ -82,9 +40,7 @@ def get_tensor_reorder_op(tensor, layout):
     returned.
   """
   for node in tensor.targets:
-    if node.op == types_pb2.Reorder and node.output_tensors[
-        0].shape.layout == layout:
-      reorder_op_output = from_tensor_proto(node.output_tensors[0])
-      reorder_op_output.source = (node, 0)
+    if node.op == types_pb2.Reorder and node.outputs[0].shape.layout == layout:
+      reorder_op_output = node.outputs[0]
       return reorder_op_output
   return None
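
With nodes keeping live Tensor objects, the reuse helpers above can return the
producing node's existing output tensor directly. A minimal sketch (not part
of the patch; `t` stands for any tensor already consumed by a data op in the
current graph):

    from smaug.core import types_pb2
    from smaug.python.tensor_utils import get_tensor_data_op

    data_out = get_tensor_data_op(t)
    if data_out is not None:
      # data_out is the Data node's output Tensor; its source is that node, so
      # downstream ops can reuse it instead of adding a duplicate data op.
      assert data_out.source.op == types_pb2.Data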