diff --git a/graphkit/network.py b/graphkit/network.py index 06447271..1520c2ef 100644 --- a/graphkit/network.py +++ b/graphkit/network.py @@ -34,6 +34,18 @@ def __repr__(self): return 'DeleteInstruction("%s")' % self +class PinInstruction(str): + """ + An instruction in the *execution plan* not to store the newly compute value + into network's values-cache but to pin it instead to some given value. + It is used ensure that given intermediate values are not overwritten when + their providing functions could not be avoided, because their other outputs + are needed elesewhere. + """ + def __repr__(self): + return 'PinInstruction("%s")' % self + + class Network(object): """ This is the main network implementation. The class contains all of the @@ -187,45 +199,56 @@ def _build_execution_plan(self, dag): return plan - def _collect_unsatisfiable_operations(self, necessary_nodes, inputs): + def _collect_unsatisfied_operations(self, dag, inputs): """ - Traverse ordered graph and mark satisfied needs on each operation, + Traverse topologically sorted dag to collect un-satisfied operations. - collecting those missing at least one. - Since the graph is ordered, as soon as we're on an operation, - all its needs have been accounted, so we can get its satisfaction. + Unsatisfied operations are those suffering from ANY of the following: + + - They are missing at least one compulsory need-input. + Since the dag is ordered, as soon as we're on an operation, + all its needs have been accounted, so we can get its satisfaction. + + - Their provided outputs are not linked to any data in the dag. + An operation might not have any output link when :meth:`_solve_dag()` + has broken them, due to given intermediate inputs. - :param necessary_nodes: - the subset of the graph to consider but WITHOUT the initial data - (because that is what :meth:`compile()` can gives us...) + :param dag: + the graph to consider :param inputs: an iterable of the names of the input values return: - a list of unsatisfiable operations + a list of unsatisfied operations to prune """ - G = self.graph # shortcut - ok_data = set(inputs) # to collect producible data - op_satisfaction = defaultdict(set) # to collect operation satisfiable needs - unsatisfiables = [] # to collect operations with partial needs - # We also need inputs to mark op_satisfaction. - nodes = chain(necessary_nodes, inputs) # note that `inputs` are plain strings - for node in nx.topological_sort(G.subgraph(nodes)): + # To collect data that will be produced. + ok_data = set(inputs) + # To colect the map of operations --> satisfied-needs. + op_satisfaction = defaultdict(set) + # To collect the operations to drop. + unsatisfied = [] + for node in nx.topological_sort(dag): if isinstance(node, Operation): - real_needs = set(n for n in node.needs if not isinstance(n, optional)) - if real_needs.issubset(op_satisfaction[node]): - # mark all future data-provides as ok - ok_data.update(G.adj[node]) + if not dag.adj[node]: + # Prune operations ending up without any provided-outputs. + unsatisfied.append(node) else: - unsatisfiables.append(node) + real_needs = set(n for n in node.needs if not isinstance(n, optional)) + if real_needs.issubset(op_satisfaction[node]): + # We have a satisfied operation; mark its output-data + # as ok. + ok_data.update(dag.adj[node]) + else: + # Prune operations with partial inputs. + unsatisfied.append(node) elif isinstance(node, (DataPlaceholderNode, str)): # `str` are givens if node in ok_data: # mark satisfied-needs on all future operations - for future_op in G.adj[node]: + for future_op in dag.adj[node]: op_satisfaction[future_op].add(node) else: raise AssertionError("Unrecognized network graph node %r" % node) - return unsatisfiables + return unsatisfied def _solve_dag(self, outputs, inputs): @@ -246,50 +269,44 @@ def _solve_dag(self, outputs, inputs): :return: the subgraph comprising the solution - """ - graph = self.graph - if not outputs: + dag = self.graph + + # Ignore input names that aren't in the graph. + graph_inputs = iset(dag.nodes) & inputs # preserve order - # If caller requested all outputs, the necessary nodes are all - # nodes that are reachable from one of the inputs. Ignore input - # names that aren't in the graph. - necessary_nodes = set() # unordered, not iterated - for input_name in iter(inputs): - if graph.has_node(input_name): - necessary_nodes |= nx.descendants(graph, input_name) + # Scream if some requested outputs aren't in the graph. + unknown_outputs = iset(outputs) - dag.nodes + if unknown_outputs: + raise ValueError( + "Unknown output node(s) requested: %s" + % ", ".join(unknown_outputs)) - else: + dag = dag.copy() # preserve net's graph + + # Break the incoming edges to all given inputs. + # + # Nodes producing any given intermediate inputs are unecessary + # (unless they are also used elsewhere). + # To discover which ones to prune, we break their incoming edges + # and they will drop out while collecting ancestors from the outputs. + for given in graph_inputs: + dag.remove_edges_from(list(dag.in_edges(given))) + + if outputs: + # If caller requested specific outputs, we can prune any + # unrelated nodes further up the dag. + ending_in_outputs = set() + for input_name in outputs: + ending_in_outputs.update(nx.ancestors(dag, input_name)) + dag = dag.subgraph(ending_in_outputs | set(outputs)) - # If the caller requested a subset of outputs, find any nodes that - # are made unecessary because we were provided with an input that's - # deeper into the network graph. Ignore input names that aren't - # in the graph. - unnecessary_nodes = set() # unordered, not iterated - for input_name in iter(inputs): - if graph.has_node(input_name): - unnecessary_nodes |= nx.ancestors(graph, input_name) - - # Find the nodes we need to be able to compute the requested - # outputs. Raise an exception if a requested output doesn't - # exist in the graph. - necessary_nodes = set() # unordered, not iterated - for output_name in outputs: - if not graph.has_node(output_name): - raise ValueError("graphkit graph does not have an output " - "node named %s" % output_name) - necessary_nodes |= nx.ancestors(graph, output_name) - - # Get rid of the unnecessary nodes from the set of necessary ones. - necessary_nodes -= unnecessary_nodes # Drop (un-satifiable) operations with partial inputs. # See yahoo/graphkit#18 # - unsatisfiables = self._collect_unsatisfiable_operations(necessary_nodes, inputs) - necessary_nodes -= set(unsatisfiables) - - shrinked_dag = graph.subgraph(necessary_nodes) + unsatisfied = self._collect_unsatisfied_operations(dag, inputs) + shrinked_dag = dag.subgraph(dag.nodes - unsatisfied) return shrinked_dag diff --git a/test/test_graphkit.py b/test/test_graphkit.py index be4b0e86..01a7f133 100644 --- a/test/test_graphkit.py +++ b/test/test_graphkit.py @@ -265,12 +265,13 @@ def test_pruning_with_given_intermediate_and_asked_out(): operation(name="good_op", needs=["a", "given-2"], provides=["asked"])(add), ) - exp = {"given-1": 5, "b": 2, "given-2": 7, "a": 5, "asked": 12} + exp = {"given-1": 5, "b": 2, "given-2": 2, "a": 5, "asked": 7} # v1.2.4 is ok assert netop({"given-1": 5, "b": 2, "given-2": 2}) == exp # FAILS # - on v1.2.4 with KeyError: 'a', # - on #18 (unsatisfied) with no result. + # FIXED on #18+#26 (new dag solver). assert netop({"given-1": 5, "b": 2, "given-2": 2}, ["asked"]) == filtdict(exp, "asked")