Skip to content

Commit

Permalink
WIP/FIX(prune): PIN intermediate inputs if operation before must run
Browse files Browse the repository at this point in the history
+ Insert "PinInstructions" in the execution-plan to avoid overwriting.
+ Add `_overwrite_collector` in `compose()` to collect re-calculated values.
+ FIX the last TC in #25.
  • Loading branch information
ankostis committed Oct 3, 2019
1 parent 0830b7c commit 008d501
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 19 deletions.
42 changes: 35 additions & 7 deletions graphkit/base.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Copyright 2016, Yahoo Inc.
# Licensed under the terms of the Apache License, Version 2.0. See the LICENSE file associated with the project for terms.
# Import the abstract-base-classes module portably:
# Python 3 keeps the ABCs in `collections.abc`, Python 2 in `collections`.
try:
    from collections import abc
except ImportError:
    # FIX: the fallback previously bound the name `abstract`, leaving `abc`
    # undefined on Python 2 and breaking the `abc.MutableMapping` check below.
    import collections as abc


class Data(object):
"""
Expand Down Expand Up @@ -151,9 +156,12 @@ def __init__(self, **kwargs):

# set execution mode to single-threaded sequential by default
self._execution_method = "sequential"
self._overwrites_collector = None

def _compute(self, named_inputs, outputs=None):
return self.net.compute(outputs, named_inputs, method=self._execution_method)
return self.net.compute(
outputs, named_inputs, method=self._execution_method,
overwrites_collector=self._overwrites_collector)

def __call__(self, *args, **kwargs):
return self._compute(*args, **kwargs)
def set_execution_method(self, method):
    """
    Determine how the network will be executed.

    :param str method:
        If "parallel", execute graph operations concurrently
        using a threadpool.

    :raises ValueError:
        if `method` is not one of the supported choices
    """
    choices = ['parallel', 'sequential']
    if method not in choices:
        # FIX: the committed code was missing the `%` operator here, so it
        # attempted to *call* the format string and raised TypeError
        # instead of the intended ValueError.
        raise ValueError(
            "Invalid computation method %r! Must be one of %s"
            % (method, choices))
    self._execution_method = method

def set_overwrites_collector(self, collector):
    """
    Asks to put all *overwrites* into the `collector` after computing.

    An "overwrite" is an intermediate value that was calculated but NOT
    stored into the results, because it had also been given as an
    intermediate input value, while the operation that would overwrite it
    MUST still run for its other results.

    :param collector:
        a mutable dict to be filled with named values

    :raises ValueError:
        if `collector` is not a MutableMapping
    """
    if not isinstance(collector, abc.MutableMapping):
        raise ValueError(
            "Overwrites collector was not a MutableMapping, but: %r"
            % collector)
    self._overwrites_collector = collector

def plot(self, filename=None, show=False):
self.net.plot(filename=filename, show=show)

Expand Down
62 changes: 50 additions & 12 deletions graphkit/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,14 +34,26 @@ def __repr__(self):
return 'DeleteInstruction("%s")' % self


class PinInstruction(str):
    """
    An *execution-plan* instruction directing the runner NOT to store a
    freshly computed value into the network's values-cache, but to keep
    ("pin") the externally given value for that data-node instead.

    It ensures that given intermediate values are not overwritten when
    their providing operations could not be pruned, because their other
    outputs are needed elsewhere.
    """
    def __repr__(self):
        return 'PinInstruction("{}")'.format(self)


class Network(object):
"""
This is the main network implementation. The class contains all of the
code necessary to weave together operations into a directed-acyclic-graph (DAG)
and pass data through.
The computation, ie the execution of the *operations* for given *inputs*
and asked *outputs* is based on 3 data-structures:
and asked *outputs* is based on 4 data-structures:
- The ``networkx`` :attr:`graph` DAG, containing interchanging layers of
:class:`Operation` and :class:`DataPlaceholderNode` nodes.
Expand All @@ -68,6 +80,12 @@ class Network(object):
- the :var:`cache` local-var, initialized on each run of both
``_compute_xxx`` methods (for parallel or sequential executions), to
hold all given input & generated (aka intermediate) data values.
- the :var:`overwrites` local-var, initialized on each run of both
``_compute_xxx`` methods (for parallel or sequential executions), to
hold values calculated but overwritten (aka "pinned") by intermediate
input-values.
"""

def __init__(self, **kwargs):
Expand Down Expand Up @@ -122,7 +140,7 @@ def add_op(self, operation):

def list_layers(self, debug=False):
# Make a generic plan.
plan = self._build_execution_plan(self.graph)
plan = self._build_execution_plan(self.graph, ())
return [n for n in plan if debug or isinstance(n, Operation)]


Expand All @@ -134,7 +152,7 @@ def show_layers(self, debug=False, ret=False):
else:
print(s)

def _build_execution_plan(self, dag):
def _build_execution_plan(self, dag, inputs):
"""
Create the list of operation-nodes & *instructions* evaluating all
Expand Down Expand Up @@ -162,7 +180,10 @@ def _build_execution_plan(self, dag):
for i, node in enumerate(ordered_nodes):

if isinstance(node, DataPlaceholderNode):
continue
if node in inputs and self.graph.pred[node]:
# Command pinning only when there is another operation
# generating this data as output.
plan.append(PinInstruction(node))

elif isinstance(node, Operation):

Expand Down Expand Up @@ -297,7 +318,7 @@ def _solve_dag(self, outputs, inputs):
unsatisfied = self._collect_unsatisfied_operations(broken_dag, inputs)
shrinked_dag = dag.subgraph(broken_dag.nodes - unsatisfied)

plan = self._build_execution_plan(shrinked_dag)
plan = self._build_execution_plan(shrinked_dag, inputs)

return plan

Expand Down Expand Up @@ -331,7 +352,8 @@ def compile(self, outputs=(), inputs=()):



def compute(self, outputs, named_inputs, method=None):
def compute(
self, outputs, named_inputs, method=None, overwrites_collector=None):
"""
Run the graph. Any inputs to the network must be passed in by name.
Expand All @@ -350,6 +372,10 @@ def compute(self, outputs, named_inputs, method=None):
Set when invoking a composed graph or by
:meth:`~NetworkOperation.set_execution_method()`.
:param overwrites_collector:
(optional) a mutable dict to be filled with named values.
If missing, values are simply discarded.
:returns: a dictionary of output data objects, keyed by name.
"""

Expand All @@ -364,23 +390,33 @@ def compute(self, outputs, named_inputs, method=None):

# choose a method of execution
if method == "parallel":
self._compute_thread_pool_barrier_method(cache)
self._compute_thread_pool_barrier_method(
cache, overwrites_collector, named_inputs)
else:
self._compute_sequential_method(cache, outputs)
self._compute_sequential_method(
cache, overwrites_collector, named_inputs, outputs)

if not outputs:
# Return the whole cache as output, including input and
# intermediate data nodes.
return cache
result = cache

else:
# Filter outputs to just return what's needed.
# Note: list comprehensions exist in python 2.7+
return dict(i for i in cache.items() if i[0] in outputs)
result = dict(i for i in cache.items() if i[0] in outputs)

return result


def _pin_data_in_cache(self, value_name, cache, inputs, overwrites):
if overwrites is not None:
overwrites[value_name] = cache[value_name]
cache[value_name] = inputs[value_name]


def _compute_thread_pool_barrier_method(
self, cache, thread_pool_size=10
self, cache, overwrites, inputs, thread_pool_size=10
):
"""
This method runs the graph using a parallel pool of thread executors.
Expand Down Expand Up @@ -436,7 +472,7 @@ def _compute_thread_pool_barrier_method(
has_executed.add(op)


def _compute_sequential_method(self, cache, outputs):
def _compute_sequential_method(self, cache, overwrites, inputs, outputs):
"""
This method runs the graph one operation at a time in a single thread
"""
Expand Down Expand Up @@ -477,6 +513,8 @@ def _compute_sequential_method(self, cache, outputs):
print("removing data '%s' from cache." % step)
cache.pop(step)

elif isinstance(step, PinInstruction):
self._pin_data_in_cache(step, cache, inputs, overwrites)
else:
raise AssertionError("Unrecognized instruction.%r" % step)

Expand Down
9 changes: 9 additions & 0 deletions test/test_graphkit.py
Original file line number Diff line number Diff line change
Expand Up @@ -228,10 +228,17 @@ def test_pruning_multiouts_not_override_intermediates1():
# FAILs
# - on v1.2.4 with (overriden, asked) = (5, 15) instead of (1, 11)
# - on #18(unsatisfied) + #23(ordered-sets) like v1.2.4.
# FIXED on #26
assert netop({"a": 5, "overriden": 1}) == exp
# FAILs
# - on v1.2.4 with KeyError: 'e',
# - on #18(unsatisfied) + #23(ordered-sets) with empty result.
# FIXED on #26
assert netop({"a": 5, "overriden": 1, "c": 2}, ["asked"]) == filtdict(exp, "asked")

## Test multithreading
netop.set_execution_method("parallel")
assert netop({"a": 5, "overriden": 1}) == exp
assert netop({"a": 5, "overriden": 1, "c": 2}, ["asked"]) == filtdict(exp, "asked")


Expand All @@ -249,11 +256,13 @@ def test_pruning_multiouts_not_override_intermediates2():
# FAILs
# - on v1.2.4 with (overriden, asked) = (5, 70) instead of (1, 13)
# - on #18(unsatisfied) + #23(ordered-sets) like v1.2.4.
# FIXED on #26
assert netop({"a": 5, "overriden": 1, "c": 2}) == exp
# FAILs
# - on v1.2.4 with KeyError: 'e',
# - on #18(unsatisfied) + #23(ordered-sets) with empty result.
assert netop({"a": 5, "overriden": 1, "c": 2}, ["asked"]) == filtdict(exp, "asked")
# FIXED on #26


def test_pruning_with_given_intermediate_and_asked_out():
Expand Down

0 comments on commit 008d501

Please sign in to comment.