diff --git a/README.md b/README.md index e8d2581d2..6de427951 100644 --- a/README.md +++ b/README.md @@ -35,8 +35,8 @@ For more complex Nodes, the INode interface can be implemented directly. ```python class ConvertTime(INode): - def __init__(self, time=None, timezone=0, city=None): - super(ConvertTime, self).__init__() + def __init__(self, time=None, timezone=0, city=None, **kwargs): + super(ConvertTime, self).__init__(**kwargs) InputPlug('time', self) InputPlug('timezone', self, timezone) InputPlug('city', self, city) @@ -61,21 +61,27 @@ def WorldClock(time1, time2, time3): print('----------------------------------') ``` -Now we can create the Graph that represents the world clock. First create all the necessary Nodes. +Now we can create the Graph that represents the world clock: ```python -current_time = CurrentTime() -van = ConvertTime(city='Vancouver', timezone=-8) -ldn = ConvertTime(city='London', timezone=0) -muc = ConvertTime(city='Munich', timezone=1) -world_clock = WorldClock() +graph = Graph(name="WorldClockGraph") ``` -The nodes are now grouped into a Graph, representing the world clock +Now we create all the necessary Nodes: + +```python +current_time = CurrentTime(graph=graph) +van = ConvertTime(city='Vancouver', timezone=-8, graph=graph) +ldn = ConvertTime(city='London', timezone=0, graph=graph) +muc = ConvertTime(city='Munich', timezone=1, graph=graph) +world_clock = WorldClock(graph=graph) +``` + +By specifying the "graph" attribute on the Nodes get added to the Graph automatically. + The Nodes can now be wired together. The bitshift operator is used as a shorthand to connect the plugs. ```python -graph = Graph(name="WorldClockGraph", nodes=[current_time, van, ldn, muc, world_clock]) current_time.outputs['time'] >> van.inputs['time'] current_time.outputs['time'] >> ldn.inputs['time'] diff --git a/flowpipe/node.py b/flowpipe/node.py index bc90c2a4a..0fc774574 100644 --- a/flowpipe/node.py +++ b/flowpipe/node.py @@ -23,13 +23,12 @@ __all__ = ['INode'] - class INode(object): """Holds input and output Plugs and a method for computing.""" __metaclass__ = ABCMeta - def __init__(self, name=None, identifier=None, metadata=None): + def __init__(self, name=None, identifier=None, metadata=None, graph=None): """Initialize the input and output dictionaries and the name. Args: @@ -44,6 +43,8 @@ def __init__(self, name=None, identifier=None, metadata=None): self.omit = False self.file_location = inspect.getfile(self.__class__) self.class_name = self.__class__.__name__ + if graph is not None: + graph.add_node(self) def __unicode__(self): """Show all input and output Plugs.""" @@ -269,14 +270,14 @@ class FunctionNode(INode): """Wrap a function into a Node.""" def __init__(self, func=None, outputs=None, name=None, - identifier=None, metadata=None, **kwargs): + identifier=None, metadata=None, graph=None, **kwargs): """The data on the function is used to drive the Node. The function itself becomes the compute method. The function input args become the InputPlugs. Other function attributes, name, __doc__ also transfer to the Node. """ super(FunctionNode, self).__init__( - name or getattr(func, '__name__', None), identifier, metadata) + name or getattr(func, '__name__', None), identifier, metadata, graph) self._initialize(func, outputs or [], metadata) for plug, value in kwargs.items(): self.inputs[plug].value = value diff --git a/tests/test_graph.py b/tests/test_graph.py index 1dde8d2dc..48eec4d3c 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -9,8 +9,8 @@ class NodeForTesting(INode): - def __init__(self, name=None): - super(NodeForTesting, self).__init__(name) + def __init__(self, name=None, **kwargs): + super(NodeForTesting, self).__init__(name, **kwargs) OutputPlug('out', self) InputPlug('in1', self, 0) InputPlug('in2', self, 0) @@ -284,3 +284,15 @@ def test_nodes_are_only_added_once(): graph.add_node(node) assert len(graph.nodes) == 1 + + +def test_nodes_can_add_to_graph_on_init(): + graph = Graph() + node = NodeForTesting(graph=graph) + assert graph["NodeForTesting"] == node + + @Node() + def function(): + pass + node = function(graph=graph) + assert graph["function"] == node