Skip to content

Commit

Permalink
Fix bug in graph parallelization.
Browse files Browse the repository at this point in the history
  • Loading branch information
syamajala committed Mar 22, 2019
1 parent 6ede017 commit fc04e39
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 57 deletions.
3 changes: 1 addition & 2 deletions networkfox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,8 @@
__author__ = 'hnguyen'
__version__ = '1.2.4'

from .functional import operation, compose
from .functional import operation, compose, If, Else

# For backwards compatibility
from .base import Operation, Var
from .network import Network
from .control import If, Else
43 changes: 0 additions & 43 deletions networkfox/control.py

This file was deleted.

54 changes: 53 additions & 1 deletion networkfox/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from itertools import chain

from .base import Operation, NetworkOperation, Var
from .base import Operation, NetworkOperation, Var, Control
from .network import Network
from .modifiers import optional

Expand Down Expand Up @@ -222,8 +222,60 @@ def order_preserving_uniquifier(seq, seen=None):

# compile network
net = Network()

control_nodes = list(filter(lambda op: isinstance(op, Control), operations))
for idx, control_node in enumerate(control_nodes):
if isinstance(control_node, If):
try:
if isinstance(control_nodes[idx+1], Else):
control_node.Else = control_nodes[idx+1]
except IndexError:
continue
elif isinstance(control_node, Else):
if not isinstance(control_nodes[idx-1], If):
raise Exception("Else not preceded by If")
control_node.If = control_nodes[idx-1]

for op in operations:
net.add_op(op)
net.compile()

return NetworkOperation(name=self.name, needs=needs, provides=provides, params={}, net=net)


class If(Control):

def __init__(self, condition_needs, condition, **kwargs):
super(If, self).__init__(**kwargs)
self.order = 1
self.condition_needs = condition_needs
self.condition = condition
self.computed_condition = False
self.Else = None

def __call__(self, *args):
self.graph = compose(name=self.name)(*args)
return self

def _compute_condition(self, named_inputs):
inputs = [named_inputs[d] for d in self.condition_needs]
self.computed_condition = self.condition(*inputs)
return self.computed_condition

def _compute(self, named_inputs, color=None):
return self.graph(named_inputs, color=color)


class Else(Control):

def __init__(self, **kwargs):
super(Else, self).__init__(**kwargs)
self.order = 2
self.If = None

def __call__(self, *args):
self.graph = compose(name=self.name)(*args)
return self

def _compute(self, named_inputs, color=None):
return self.graph(named_inputs, color=color)
17 changes: 6 additions & 11 deletions networkfox/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -323,9 +323,7 @@ def _compute_parallel(self, necessary_nodes, cache, outputs, named_inputs, pool,
# the upnext list contains a list of operations for scheduling
# in the current round of scheduling
upnext = []

for node in necessary_nodes:

# only delete if all successors for the data node have been executed
if isinstance(node, DeleteInstruction):
if outputs and node not in outputs:
Expand All @@ -338,15 +336,13 @@ def _compute_parallel(self, necessary_nodes, cache, outputs, named_inputs, pool,
elif isinstance(node, Control):
if hasattr(node, 'condition'):
if all(map(lambda need: need in cache, node.condition_needs)):
if_true = node._compute_condition(cache)
if if_true:
if node._compute_condition(cache):
if ready_to_schedule_operation(node, cache, has_executed):
upnext.append((node, cache, color))
else:
# assume short circuiting if statement
if ready_to_schedule_operation(node, cache, has_executed):
upnext.append((node, cache, color))
elif not if_true:
upnext.append((node, cache, color))
elif not node.If.computed_condition:
if ready_to_schedule_operation(node, cache, has_executed):
upnext.append((node, cache, color))

Expand Down Expand Up @@ -388,22 +384,20 @@ def _compute_serial(self, necessary_nodes, cache, outputs, named_inputs, color=N
"""

self.times = {}
if_true = False

for node in necessary_nodes:

if isinstance(node, Control):
if hasattr(node, 'condition'):
if all(map(lambda need: need in cache, node.condition_needs)):
if_true = node._compute_condition(cache)
if if_true:
if node._compute_condition(cache):
layer_outputs = node._compute(cache, color)
cache.update(layer_outputs)
else:
# assume short circuiting if statement
layer_outputs = node._compute(cache, color)
cache.update(layer_outputs)
elif not if_true:
elif not node.If.computed_condition:
layer_outputs = node._compute(cache, color)
cache.update(layer_outputs)

Expand Down Expand Up @@ -569,6 +563,7 @@ def ready_to_schedule_operation(op, cache, has_executed):

return True


def ready_to_delete_data_node(name, has_executed, graph):
"""
Determines if a DataPlaceholderNode is ready to be deleted from the
Expand Down

0 comments on commit fc04e39

Please sign in to comment.