diff --git a/graphkit/base.py b/graphkit/base.py index 1c04e8d5..1424ef20 100644 --- a/graphkit/base.py +++ b/graphkit/base.py @@ -1,5 +1,10 @@ # Copyright 2016, Yahoo Inc. # Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms. +try: + from collections import abc +except ImportError: + import collections as abc + class Data(object): """ @@ -151,9 +156,12 @@ def __init__(self, **kwargs): # set execution mode to single-threaded sequential by default self._execution_method = "sequential" + self._overwrites_collector = None def _compute(self, named_inputs, outputs=None): - return self.net.compute(outputs, named_inputs, method=self._execution_method) + return self.net.compute( + outputs, named_inputs, method=self._execution_method, + overwrites_collector=self._overwrites_collector) def __call__(self, *args, **kwargs): return self._compute(*args, **kwargs) @@ -162,15 +170,35 @@ def set_execution_method(self, method): """ Determine how the network will be executed. - Args: - method: str - If "parallel", execute graph operations concurrently - using a threadpool. + :param str method: + If "parallel", execute graph operations concurrently + using a threadpool. """ - options = ['parallel', 'sequential'] - assert method in options + choices = ['parallel', 'sequential'] + if method not in choices: + raise ValueError( + "Invalid computation method %r! Must be one of %s" + % (method, choices)) self._execution_method = method + def set_overwrites_collector(self, collector): + """ + Asks to put all *overwrites* into the `collector` after computing + + An "overwrites" is an intermediate value calculated but NOT stored + into the results, because it has been given also as an intermediate + input value, and the operation that would overwrite it MUST run for + its other results. 
+ + :param collector: + a mutable dict to be filled with named values + """ + if not isinstance(collector, abc.MutableMapping): + raise ValueError( + "Overwrites collector was not a MutableMapping, but: %r" + % collector) + self._overwrites_collector = collector + def plot(self, filename=None, show=False): self.net.plot(filename=filename, show=show) diff --git a/graphkit/network.py b/graphkit/network.py index 4ef4b4c4..0dbba46c 100644 --- a/graphkit/network.py +++ b/graphkit/network.py @@ -34,6 +34,18 @@ def __repr__(self): return 'DeleteInstruction("%s")' % self +class PinInstruction(str): + """ + An instruction in the *execution plan* not to store the newly computed value + into network's values-cache but to pin it instead to some given value. + It is used to ensure that given intermediate values are not overwritten when + their providing functions could not be avoided, because their other outputs + are needed elsewhere. + """ + def __repr__(self): + return 'PinInstruction("%s")' % self + + class Network(object): """ This is the main network implementation. The class contains all of the @@ -41,7 +53,7 @@ class Network(object): and pass data through. The computation, ie the execution of the *operations* for given *inputs* - and asked *outputs* is based on 3 data-structures: + and asked *outputs* is based on 4 data-structures: - The ``networkx`` :attr:`graph` DAG, containing interchanging layers of :class:`Operation` and :class:`DataPlaceholderNode` nodes. @@ -68,6 +80,12 @@ class Network(object): - the :var:`cache` local-var, initialized on each run of both ``_compute_xxx`` methods (for parallel or sequential executions), to hold all given input & generated (aka intermediate) data values. + + - the :var:`overwrites` local-var, initialized on each run of both + ``_compute_xxx`` methods (for parallel or sequential executions), to + hold values calculated but overwritten (aka "pinned") by intermediate + input-values. 
+ """ def __init__(self, **kwargs): @@ -122,7 +140,7 @@ def add_op(self, operation): def list_layers(self, debug=False): # Make a generic plan. - plan = self._build_execution_plan(self.graph) + plan = self._build_execution_plan(self.graph, ()) return [n for n in plan if debug or isinstance(n, Operation)] @@ -134,7 +152,7 @@ def show_layers(self, debug=False, ret=False): else: print(s) - def _build_execution_plan(self, dag): + def _build_execution_plan(self, dag, inputs): """ Create the list of operation-nodes & *instructions* evaluating all @@ -162,7 +180,10 @@ def _build_execution_plan(self, dag): for i, node in enumerate(ordered_nodes): if isinstance(node, DataPlaceholderNode): - continue + if node in inputs and self.graph.pred[node]: + # Command pinning only when there is another operation + # generating this data as output. + plan.append(PinInstruction(node)) elif isinstance(node, Operation): @@ -297,7 +318,7 @@ def _solve_dag(self, outputs, inputs): unsatisfied = self._collect_unsatisfied_operations(broken_dag, inputs) shrinked_dag = dag.subgraph(broken_dag.nodes - unsatisfied) - plan = self._build_execution_plan(shrinked_dag) + plan = self._build_execution_plan(shrinked_dag, inputs) return plan @@ -331,7 +352,8 @@ def compile(self, outputs=(), inputs=()): - def compute(self, outputs, named_inputs, method=None): + def compute( + self, outputs, named_inputs, method=None, overwrites_collector=None): """ Run the graph. Any inputs to the network must be passed in by name. @@ -350,6 +372,10 @@ def compute(self, outputs, named_inputs, method=None): Set when invoking a composed graph or by :meth:`~NetworkOperation.set_execution_method()`. + :param overwrites_collector: + (optional) a mutable dict to be fillwed with named values. + If missing, values are simply discarded. + :returns: a dictionary of output data objects, keyed by name. 
""" @@ -364,23 +390,33 @@ def compute(self, outputs, named_inputs, method=None): # choose a method of execution if method == "parallel": - self._compute_thread_pool_barrier_method(cache) + self._compute_thread_pool_barrier_method( + cache, overwrites_collector, named_inputs) else: - self._compute_sequential_method(cache, outputs) + self._compute_sequential_method( + cache, overwrites_collector, named_inputs, outputs) if not outputs: # Return the whole cache as output, including input and # intermediate data nodes. - return cache + result = cache else: # Filter outputs to just return what's needed. # Note: list comprehensions exist in python 2.7+ - return dict(i for i in cache.items() if i[0] in outputs) + result = dict(i for i in cache.items() if i[0] in outputs) + + return result + + + def _pin_data_in_cache(self, value_name, cache, inputs, overwrites): + if overwrites is not None: + overwrites[value_name] = cache[value_name] + cache[value_name] = inputs[value_name] def _compute_thread_pool_barrier_method( - self, cache, thread_pool_size=10 + self, cache, overwrites, inputs, thread_pool_size=10 ): """ This method runs the graph using a parallel pool of thread executors. @@ -436,7 +472,7 @@ def _compute_thread_pool_barrier_method( has_executed.add(op) - def _compute_sequential_method(self, cache, outputs): + def _compute_sequential_method(self, cache, overwrites, inputs, outputs): """ This method runs the graph one operation at a time in a single thread """ @@ -477,6 +513,8 @@ def _compute_sequential_method(self, cache, outputs): print("removing data '%s' from cache." 
% step) cache.pop(step) + elif isinstance(step, PinInstruction): + self._pin_data_in_cache(step, cache, inputs, overwrites) else: raise AssertionError("Unrecognized instruction.%r" % step) diff --git a/test/test_graphkit.py b/test/test_graphkit.py index 47f536b3..918e6c61 100644 --- a/test/test_graphkit.py +++ b/test/test_graphkit.py @@ -228,10 +228,17 @@ def test_pruning_multiouts_not_override_intermediates1(): # FAILs # - on v1.2.4 with (overriden, asked) = (5, 15) instead of (1, 11) # - on #18(unsatisfied) + #23(ordered-sets) like v1.2.4. + # FIXED on #26 assert netop({"a": 5, "overriden": 1}) == exp # FAILs # - on v1.2.4 with KeyError: 'e', # - on #18(unsatisfied) + #23(ordered-sets) with empty result. + # FIXED on #26 + assert netop({"a": 5, "overriden": 1, "c": 2}, ["asked"]) == filtdict(exp, "asked") + + ## Test multithreading + netop.set_execution_method("parallel") + assert netop({"a": 5, "overriden": 1}) == exp assert netop({"a": 5, "overriden": 1, "c": 2}, ["asked"]) == filtdict(exp, "asked") @@ -249,11 +256,13 @@ def test_pruning_multiouts_not_override_intermediates2(): # FAILs # - on v1.2.4 with (overriden, asked) = (5, 70) instead of (1, 13) # - on #18(unsatisfied) + #23(ordered-sets) like v1.2.4. + # FIXED on #26 assert netop({"a": 5, "overriden": 1, "c": 2}) == exp # FAILs # - on v1.2.4 with KeyError: 'e', # - on #18(unsatisfied) + #23(ordered-sets) with empty result. assert netop({"a": 5, "overriden": 1, "c": 2}, ["asked"]) == filtdict(exp, "asked") + # FIXED on #26 def test_pruning_with_given_intermediate_and_asked_out():