Skip to content

Commit

Permalink
ENH(DAG): NEW SOLVER
Browse files Browse the repository at this point in the history
+ Pruning behaves correctly also when outputs given;
  this happens by breaking incoming provide-links
  to any given intermedediate inputs.
+ Unsatisfied detection now includes those without outputs
  due to broken links (above).
+ Remove some uneeded "glue" from unsatisfied-detection code,
  leftover from previous compile() refactoring.
+ Renamed satisfiable --> satisfied.
+ Improved unknown output requested raise-message.
+ x2 TCs, in #24 and 1st in #25 now PASS.
- 2x TCs in #25 still FAIL, and need "Pinning" of given-inputs
  (the operation MUST and MUST NOT run in these cases).
  • Loading branch information
ankostis committed Oct 3, 2019
1 parent 2ce2a43 commit 7e851b1
Show file tree
Hide file tree
Showing 2 changed files with 78 additions and 60 deletions.
135 changes: 76 additions & 59 deletions graphkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,18 @@ def __repr__(self):
return 'DeleteInstruction("%s")' % self


class PinInstruction(str):
"""
An instruction in the *execution plan* not to store the newly compute value
into network's values-cache but to pin it instead to some given value.
It is used ensure that given intermediate values are not overwritten when
their providing functions could not be avoided, because their other outputs
are needed elesewhere.
"""
def __repr__(self):
return 'PinInstruction("%s")' % self


class Network(object):
"""
This is the main network implementation. The class contains all of the
Expand Down Expand Up @@ -187,45 +199,56 @@ def _build_execution_plan(self, dag):

return plan

def _collect_unsatisfiable_operations(self, necessary_nodes, inputs):
def _collect_unsatisfied_operations(self, dag, inputs):
"""
Traverse ordered graph and mark satisfied needs on each operation,
Traverse topologically sorted dag to collect un-satisfied operations.
collecting those missing at least one.
Since the graph is ordered, as soon as we're on an operation,
all its needs have been accounted, so we can get its satisfaction.
Unsatisfied operations are those suffering from ANY of the following:
- They are missing at least one compulsory need-input.
Since the dag is ordered, as soon as we're on an operation,
all its needs have been accounted, so we can get its satisfaction.
- Their provided outputs are not linked to any data in the dag.
An operation might not have any output link when :meth:`_solve_dag()`
has broken them, due to given intermediate inputs.
:param necessary_nodes:
the subset of the graph to consider but WITHOUT the initial data
(because that is what :meth:`compile()` can gives us...)
:param dag:
the graph to consider
:param inputs:
an iterable of the names of the input values
return:
a list of unsatisfiable operations
a list of unsatisfied operations to prune
"""
G = self.graph # shortcut
ok_data = set(inputs) # to collect producible data
op_satisfaction = defaultdict(set) # to collect operation satisfiable needs
unsatisfiables = [] # to collect operations with partial needs
# We also need inputs to mark op_satisfaction.
nodes = chain(necessary_nodes, inputs) # note that `inputs` are plain strings
for node in nx.topological_sort(G.subgraph(nodes)):
# To collect data that will be produced.
ok_data = set(inputs)
# To colect the map of operations --> satisfied-needs.
op_satisfaction = defaultdict(set)
# To collect the operations to drop.
unsatisfied = []
for node in nx.topological_sort(dag):
if isinstance(node, Operation):
real_needs = set(n for n in node.needs if not isinstance(n, optional))
if real_needs.issubset(op_satisfaction[node]):
# mark all future data-provides as ok
ok_data.update(G.adj[node])
if not dag.adj[node]:
# Prune operations ending up without any provided-outputs.
unsatisfied.append(node)
else:
unsatisfiables.append(node)
real_needs = set(n for n in node.needs if not isinstance(n, optional))
if real_needs.issubset(op_satisfaction[node]):
# We have a satisfied operation; mark its output-data
# as ok.
ok_data.update(dag.adj[node])
else:
# Prune operations with partial inputs.
unsatisfied.append(node)
elif isinstance(node, (DataPlaceholderNode, str)): # `str` are givens
if node in ok_data:
# mark satisfied-needs on all future operations
for future_op in G.adj[node]:
for future_op in dag.adj[node]:
op_satisfaction[future_op].add(node)
else:
raise AssertionError("Unrecognized network graph node %r" % node)

return unsatisfiables
return unsatisfied


def _solve_dag(self, outputs, inputs):
Expand All @@ -246,50 +269,44 @@ def _solve_dag(self, outputs, inputs):
:return:
the subgraph comprising the solution
"""
graph = self.graph
if not outputs:
dag = self.graph

# Ignore input names that aren't in the graph.
graph_inputs = iset(dag.nodes) & inputs # preserve order

# If caller requested all outputs, the necessary nodes are all
# nodes that are reachable from one of the inputs. Ignore input
# names that aren't in the graph.
necessary_nodes = set() # unordered, not iterated
for input_name in iter(inputs):
if graph.has_node(input_name):
necessary_nodes |= nx.descendants(graph, input_name)
# Scream if some requested outputs aren't in the graph.
unknown_outputs = iset(outputs) - dag.nodes
if unknown_outputs:
raise ValueError(
"Unknown output node(s) requested: %s"
% ", ".join(unknown_outputs))

else:
dag = dag.copy() # preserve net's graph

# Break the incoming edges to all given inputs.
#
# Nodes producing any given intermediate inputs are unecessary
# (unless they are also used elsewhere).
# To discover which ones to prune, we break their incoming edges
# and they will drop out while collecting ancestors from the outputs.
for given in graph_inputs:
dag.remove_edges_from(list(dag.in_edges(given)))

if outputs:
# If caller requested specific outputs, we can prune any
# unrelated nodes further up the dag.
ending_in_outputs = set()
for input_name in outputs:
ending_in_outputs.update(nx.ancestors(dag, input_name))
dag = dag.subgraph(ending_in_outputs | set(outputs))

# If the caller requested a subset of outputs, find any nodes that
# are made unecessary because we were provided with an input that's
# deeper into the network graph. Ignore input names that aren't
# in the graph.
unnecessary_nodes = set() # unordered, not iterated
for input_name in iter(inputs):
if graph.has_node(input_name):
unnecessary_nodes |= nx.ancestors(graph, input_name)

# Find the nodes we need to be able to compute the requested
# outputs. Raise an exception if a requested output doesn't
# exist in the graph.
necessary_nodes = set() # unordered, not iterated
for output_name in outputs:
if not graph.has_node(output_name):
raise ValueError("graphkit graph does not have an output "
"node named %s" % output_name)
necessary_nodes |= nx.ancestors(graph, output_name)

# Get rid of the unnecessary nodes from the set of necessary ones.
necessary_nodes -= unnecessary_nodes

# Drop (un-satifiable) operations with partial inputs.
# See yahoo/graphkit#18
#
unsatisfiables = self._collect_unsatisfiable_operations(necessary_nodes, inputs)
necessary_nodes -= set(unsatisfiables)

shrinked_dag = graph.subgraph(necessary_nodes)
unsatisfied = self._collect_unsatisfied_operations(dag, inputs)
shrinked_dag = dag.subgraph(dag.nodes - unsatisfied)

return shrinked_dag

Expand Down
3 changes: 2 additions & 1 deletion test/test_graphkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -265,12 +265,13 @@ def test_pruning_with_given_intermediate_and_asked_out():
operation(name="good_op", needs=["a", "given-2"], provides=["asked"])(add),
)

exp = {"given-1": 5, "b": 2, "given-2": 7, "a": 5, "asked": 12}
exp = {"given-1": 5, "b": 2, "given-2": 2, "a": 5, "asked": 7}
# v1.2.4 is ok
assert netop({"given-1": 5, "b": 2, "given-2": 2}) == exp
# FAILS
# - on v1.2.4 with KeyError: 'a',
# - on #18 (unsatisfied) with no result.
# FIXED on #18+#26 (new dag solver).
assert netop({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")


Expand Down

0 comments on commit 7e851b1

Please sign in to comment.