Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: workflow terminates after parallel task execution without triggering the merge node #12498

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
166 changes: 96 additions & 70 deletions api/core/workflow/graph_engine/graph_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import queue
import time
import uuid
from collections import defaultdict
from collections.abc import Generator, Mapping
from concurrent.futures import ThreadPoolExecutor, wait
from copy import copy, deepcopy
Expand Down Expand Up @@ -348,17 +349,13 @@ def _run(

next_node_id = edge.target_node_id
else:
final_node_id = None

if any(edge.run_condition for edge in edge_mappings):
# if the edges have run conditions, determine which branch to take based on the run-condition results
condition_edge_mappings: dict[str, list[GraphEdge]] = {}
for edge in edge_mappings:
if edge.run_condition:
run_condition_hash = edge.run_condition.hash
if run_condition_hash not in condition_edge_mappings:
condition_edge_mappings[run_condition_hash] = []

condition_edge_mappings[run_condition_hash].append(edge)

for _, sub_edge_mappings in condition_edge_mappings.items():
Expand All @@ -383,7 +380,29 @@ def _run(
continue

if len(sub_edge_mappings) == 1:
final_node_id = edge.target_node_id
next_node_id = edge.target_node_id
generator = self._run(
start_node_id=next_node_id,
in_parallel_id=in_parallel_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
handle_exceptions=handle_exceptions,
)
for item in generator:
yield item

next_edges = self.graph.edge_mapping.get(next_node_id, [])
if next_edges:
next_node_id = next_edges[0].target_node_id
generator = self._run(
start_node_id=next_node_id,
in_parallel_id=in_parallel_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
handle_exceptions=handle_exceptions,
)
for item in generator:
yield item
else:
parallel_generator = self._run_parallel_branches(
edge_mappings=sub_edge_mappings,
Expand All @@ -394,16 +413,21 @@ def _run(

for parallel_result in parallel_generator:
if isinstance(parallel_result, str):
final_node_id = parallel_result
next_node_id = parallel_result
generator = self._run(
start_node_id=next_node_id,
in_parallel_id=in_parallel_id,
parent_parallel_id=parent_parallel_id,
parent_parallel_start_node_id=parent_parallel_start_node_id,
handle_exceptions=handle_exceptions,
)
for item in generator:
yield item
else:
yield parallel_result

break

if not final_node_id:
break
break

next_node_id = final_node_id
elif (
node_instance.node_data.error_strategy == ErrorStrategy.FAIL_BRANCH
and node_instance.should_continue_on_error
Expand All @@ -418,6 +442,7 @@ def _run(
handle_exceptions=handle_exceptions,
)

final_node_id = None
for generated_item in parallel_generator:
if isinstance(generated_item, str):
final_node_id = generated_item
Expand All @@ -439,83 +464,83 @@ def _run_parallel_branches(
parallel_start_node_id: Optional[str] = None,
handle_exceptions: list[str] = [],
) -> Generator[GraphEngineEvent | str, None, None]:
# if the edges have no run conditions, run all target nodes in parallel
parallel_id = self.graph.node_parallel_mapping.get(edge_mappings[0].target_node_id)
if not parallel_id:
node_id = edge_mappings[0].target_node_id
node_config = self.graph.node_id_config_mapping.get(node_id)
if not node_config:
raise GraphRunFailedError(
f"Node {node_id} related parallel not found or incorrectly connected to multiple parallel branches."
)

node_title = node_config.get("data", {}).get("title")
raise GraphRunFailedError(
f"Node {node_title} related parallel not found or incorrectly connected to multiple parallel branches."
)

parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel:
raise GraphRunFailedError(f"Parallel {parallel_id} not found.")
target_nodes: defaultdict[str, list[GraphEdge]] = defaultdict(list)
for edge in edge_mappings:
target_nodes[edge.target_node_id].append(edge)

# run the parallel nodes in new threads and use a queue to collect their results
executed_node_ids = set()
q: queue.Queue = queue.Queue()
all_futures = []

# Create a list to store the threads
futures = []

# new thread
for edge in edge_mappings:
if (
edge.target_node_id not in self.graph.node_parallel_mapping
or self.graph.node_parallel_mapping.get(edge.target_node_id, "") != parallel_id
):
for target_node_id, edges in target_nodes.items():
if target_node_id in executed_node_ids:
continue

future = self.thread_pool.submit(
self._run_parallel_node,
**{
"flask_app": current_app._get_current_object(), # type: ignore[attr-defined]
"q": q,
"parallel_id": parallel_id,
"parallel_start_node_id": edge.target_node_id,
"parent_parallel_id": in_parallel_id,
"parent_parallel_start_node_id": parallel_start_node_id,
"handle_exceptions": handle_exceptions,
},
)
parallel_id = self.graph.node_parallel_mapping.get(target_node_id)
if not parallel_id:
raise GraphRunFailedError(f"Node {target_node_id} parallel not found")

future.add_done_callback(self.thread_pool.task_done_callback)
parallel = self.graph.parallel_mapping.get(parallel_id)
if not parallel:
raise GraphRunFailedError(f"Parallel {parallel_id} not found.")

futures.append(future)
for edge in edges:
if edge.target_node_id in executed_node_ids:
continue
if (
edge.target_node_id not in self.graph.node_parallel_mapping
or self.graph.node_parallel_mapping.get(edge.target_node_id, "") != parallel_id
):
continue

executed_node_ids.add(edge.target_node_id)

future = self.thread_pool.submit(
self._run_parallel_node,
**{
"flask_app": current_app._get_current_object(), # type: ignore[attr-defined]
"q": q,
"parallel_id": parallel_id,
"parallel_start_node_id": edge.target_node_id,
"parent_parallel_id": in_parallel_id,
"parent_parallel_start_node_id": parallel_start_node_id,
"handle_exceptions": handle_exceptions,
},
)
future.add_done_callback(self.thread_pool.task_done_callback)
all_futures.append((parallel_id, future))

succeeded_count = 0
while True:
branch_results = []

while succeeded_count < len(all_futures):
try:
event = q.get(timeout=1)
if event is None:
break

yield event
if event.parallel_id == parallel_id:
if isinstance(event, ParallelBranchRunSucceededEvent):
succeeded_count += 1
if succeeded_count == len(futures):
q.put(None)

continue
elif isinstance(event, ParallelBranchRunFailedEvent):
raise GraphRunFailedError(event.error)

if isinstance(event, ParallelBranchRunSucceededEvent):
succeeded_count += 1
branch_results.append(event)

if succeeded_count == len(all_futures):
q.put(None)

for parallel_id, _ in all_futures:
parallel = self.graph.parallel_mapping[parallel_id]
final_node_id = parallel.end_to_node_id
if final_node_id:
yield final_node_id

elif isinstance(event, ParallelBranchRunFailedEvent):
raise GraphRunFailedError(event.error)

except queue.Empty:
continue

# wait for all threads to finish
wait(futures)

# get final node id
final_node_id = parallel.end_to_node_id
if final_node_id:
yield final_node_id
wait([f for _, f in all_futures])

def _run_parallel_node(
self,
Expand Down Expand Up @@ -720,6 +745,7 @@ def _run_node(
if run_result.metadata and run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS):
# plus state total_tokens
self.graph_runtime_state.total_tokens += int(
# type: ignore[arg-type]
run_result.metadata.get(NodeRunMetadataKey.TOTAL_TOKENS) # type: ignore[arg-type]
)

Expand Down
Loading