Skip to content

Commit

Permalink
fix: prevents exception when the pipeline contains multiple nested lo…
Browse files Browse the repository at this point in the history
…ops, due to the cycle detection removing the same edge multiple times (ref deepset-ai#8657)
  • Loading branch information
etirelli committed Dec 29, 2024
1 parent 3ea128c commit 9fb6509
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 21 deletions.
47 changes: 26 additions & 21 deletions haystack/core/pipeline/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1209,27 +1209,32 @@ def _break_supported_cycles_in_graph(self) -> Tuple[networkx.MultiDiGraph, Dict[
# sender_comp will be the last element of cycle and receiver_comp will be the first.
# So if cycle is [1, 2, 3, 4] we would call zip([1, 2, 3, 4], [2, 3, 4, 1]).
for sender_comp, receiver_comp in zip(cycle, cycle[1:] + cycle[:1]):
# We get the key and iterate those as we want to edit the graph data while
# iterating the edges and that would raise.
# Even though the connection key set in Pipeline.connect() uses only the
# sockets name we don't have clashes since it's only used to differentiate
# multiple edges between two nodes.
edge_keys = list(temp_graph.get_edge_data(sender_comp, receiver_comp).keys())
for edge_key in edge_keys:
edge_data = temp_graph.get_edge_data(sender_comp, receiver_comp)[edge_key]
receiver_socket = edge_data["to_socket"]
if not receiver_socket.is_variadic and receiver_socket.is_mandatory:
continue

# We found a breakable edge
sender_socket = edge_data["from_socket"]
edges_removed[sender_comp].append(sender_socket.name)
temp_graph.remove_edge(sender_comp, receiver_comp, edge_key)

graph_has_cycles = not networkx.is_directed_acyclic_graph(temp_graph)
if not graph_has_cycles:
# We removed all the cycles, we can stop
break
# for graphs with multiple nested cycles, we need to check if the edge hasn't
# been previously removed before we try to remove it again
edge_data = temp_graph.get_edge_data(sender_comp, receiver_comp)
if edge_data is not None:
# We get the key and iterate those as we want to edit the graph data while
# iterating the edges and that would raise.
# Even though the connection key set in Pipeline.connect() uses only the
# sockets name we don't have clashes since it's only used to differentiate
# multiple edges between two nodes.
edge_keys = list(edge_data.keys())

for edge_key in edge_keys:
edge_data = temp_graph.get_edge_data(sender_comp, receiver_comp)[edge_key]
receiver_socket = edge_data["to_socket"]
if not receiver_socket.is_variadic and receiver_socket.is_mandatory:
continue

# We found a breakable edge
sender_socket = edge_data["from_socket"]
edges_removed[sender_comp].append(sender_socket.name)
temp_graph.remove_edge(sender_comp, receiver_comp, edge_key)

graph_has_cycles = not networkx.is_directed_acyclic_graph(temp_graph)
if not graph_has_cycles:
# We removed all the cycles, we can stop
break

if not graph_has_cycles:
# We removed all the cycles, nice
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
fixes:
- |
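Prevents the pipeline from raising an exception when the graph contains multiple nested cycles. While breaking cycles, the same edge could be removed more than once; edges that were already removed are now skipped.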
25 changes: 25 additions & 0 deletions test/core/pipeline/test_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -1581,3 +1581,28 @@ def test__find_receivers_from(self):
),
)
]

def test__break_supported_cycles_in_graph(self):
    """
    Regression test: breaking cycles in a pipeline with nested cycles must not raise.

    The pipeline built here contains a nested cycle. Such pipelines are supported,
    but `_break_supported_cycles_in_graph` used to raise an exception on them.
    """
    # (name in the pipeline, generated class name, type of the single "value" input)
    component_specs = [
        ("comp1", "Comp1", int),
        ("comp2", "Comp2", Variadic[int]),
        ("comp3", "Comp3", Variadic[int]),
        ("comp4", "Comp4", Optional[int]),
        ("comp5", "Comp5", Variadic[int]),
    ]
    # Every component exposes one "value" input and one int "value" output.
    instances = {
        name: component_class(class_name, input_types={"value": input_type}, output_types={"value": int})()
        for name, class_name, input_type in component_specs
    }

    pipe = Pipeline()
    for name, instance in instances.items():
        pipe.add_component(name, instance)

    # Connection order is kept exactly as in the original scenario.
    connections = [
        ("comp1.value", "comp2.value"),
        ("comp2.value", "comp3.value"),
        ("comp3.value", "comp4.value"),
        ("comp3.value", "comp5.value"),
        ("comp4.value", "comp5.value"),
        ("comp4.value", "comp3.value"),
        ("comp5.value", "comp2.value"),
    ]
    for sender, receiver in connections:
        pipe.connect(sender, receiver)

    # The test passes as long as this call completes without raising.
    pipe._break_supported_cycles_in_graph()

0 comments on commit 9fb6509

Please sign in to comment.