✨ parallel execution with ray (WecoAI#34)
dexhunter committed Dec 10, 2024
1 parent 0fbe106 commit fc8a6d6
Showing 8 changed files with 296 additions and 37 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -168,3 +168,6 @@ logs

.gradio/
.ruff_cache/

+run.sh
+run_experiments.sh
25 changes: 20 additions & 5 deletions aide/__init__.py
@@ -12,6 +12,7 @@
_load_cfg,
prep_cfg,
)
+from .parallel_agent import ParallelAgent


@dataclass
@@ -43,11 +44,18 @@ def __init__(self, data_dir: str, goal: str, eval: str | None = None):
prep_agent_workspace(self.cfg)

self.journal = Journal()
-self.agent = Agent(
-    task_desc=self.task_desc,
-    cfg=self.cfg,
-    journal=self.journal,
-)
+if self.cfg.agent.parallel.enabled:
+    self.agent = ParallelAgent(
+        task_desc=self.task_desc,
+        cfg=self.cfg,
+        journal=self.journal,
+    )
+else:
+    self.agent = Agent(
+        task_desc=self.task_desc,
+        cfg=self.cfg,
+        journal=self.journal,
+    )
self.interpreter = Interpreter(
self.cfg.workspace_dir, **OmegaConf.to_container(self.cfg.exec) # type: ignore
)
@@ -60,3 +68,10 @@ def run(self, steps: int) -> Solution:

best_node = self.journal.get_best_node(only_good=False)
return Solution(code=best_node.code, valid_metric=best_node.metric.value)

+def cleanup(self):
+    """Cleanup resources"""
+    if hasattr(self, 'interpreter'):
+        self.interpreter.cleanup_session()
+    if isinstance(self.agent, ParallelAgent):
+        self.agent.cleanup()
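
With this change, the constructor in aide/__init__.py selects ParallelAgent when cfg.agent.parallel.enabled is set and falls back to the single-process Agent otherwise, and the class gains a cleanup() method that closes the interpreter session and, in the parallel case, shuts down the Ray workers. A minimal usage sketch of that flow, assuming the enclosing class is the package's Experiment entry point and that the parallel flag is switched on via the config, neither of which is confirmed by this diff:

    # Illustrative sketch only; the Experiment name and the config path for the
    # parallel flag are assumptions, not part of this commit.
    import aide

    exp = aide.Experiment(
        data_dir="data/my-task",           # assumed example dataset directory
        goal="Predict the target column",  # natural-language task description
        eval="RMSE on a held-out split",   # evaluation metric description
    )
    # With agent.parallel.enabled set to true, exp.agent is a ParallelAgent;
    # otherwise it is the regular single-process Agent.
    try:
        best = exp.run(steps=10)
        print(best.valid_metric)
    finally:
        exp.cleanup()  # closes the interpreter session and, if parallel, the Ray workers
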
36 changes: 24 additions & 12 deletions aide/interpreter.py
@@ -49,11 +49,7 @@ def exception_summary(e, working_dir, exec_file_name, format_tb_ipython):
tb_lines = traceback.format_exception(e)
# skip parts of stack trace in weflow code
tb_str = "".join(
-    [
-        line
-        for line in tb_lines
-        if "aide/" not in line and "importlib" not in line
-    ]
+    [l for l in tb_lines if "aide/" not in l and "importlib" not in l]

Check failure in GitHub Actions / Python Linting: Ruff E741 at aide/interpreter.py:52:20: Ambiguous variable name: `l`
)
# tb_str = "".join([l for l in tb_lines])

@@ -203,7 +199,7 @@ def run(self, code: str, reset_session=True) -> ExecutionResult:
"""

-logger.debug(f"REPL is executing code (reset_session={reset_session})")
+logger.info(f"REPL is executing code (reset_session={reset_session})")

if reset_session:
if self.process is not None:
@@ -224,9 +220,18 @@ def run(self, code: str, reset_session=True) -> ExecutionResult:
except queue.Empty:
msg = "REPL child process failed to start execution"
logger.critical(msg)
+queue_dump = ""
while not self.result_outq.empty():
-    logger.error(f"REPL output queue dump: {self.result_outq.get()}")
-raise RuntimeError(msg) from None
+    queue_dump = self.result_outq.get()
+    logger.error(f"REPL output queue dump: {queue_dump[:1000]}")
+self.cleanup_session()
+return ExecutionResult(
+    term_out=[msg, queue_dump],
+    exec_time=0,
+    exc_type="RuntimeError",
+    exc_info={},
+    exc_stack=[],
+)
assert state[0] == "state:ready", state
start_time = time.time()

@@ -246,11 +251,18 @@ def run(self, code: str, reset_session=True) -> ExecutionResult:
if not child_in_overtime and not self.process.is_alive():
msg = "REPL child process died unexpectedly"
logger.critical(msg)
+queue_dump = ""
while not self.result_outq.empty():
-    logger.error(
-        f"REPL output queue dump: {self.result_outq.get()}"
-    )
-raise RuntimeError(msg) from None
+    queue_dump = self.result_outq.get()
+    logger.error(f"REPL output queue dump: {queue_dump[:1000]}")
+self.cleanup_session()
+return ExecutionResult(
+    term_out=[msg, queue_dump],
+    exec_time=0,
+    exc_type="RuntimeError",
+    exc_info={},
+    exc_stack=[],
+)

# child is alive and still executing -> check if we should sigint..
if self.timeout is None:
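
The two failure paths above no longer raise RuntimeError; they drain the output queue into queue_dump, clean up the session, and return an ExecutionResult, so a dead or stuck REPL child becomes data the caller can inspect rather than an exception that would take down a Ray worker mid-step. A rough caller-side sketch of how such a result might be handled; the handle_result helper is illustrative and not part of this commit:

    from aide.interpreter import ExecutionResult

    def handle_result(result: ExecutionResult) -> None:
        # Hypothetical helper: on the failure paths above, term_out holds the error
        # message plus the last line drained from the REPL output queue.
        if result.exc_type == "RuntimeError":
            print("interpreter failure:", " | ".join(result.term_out))
        else:
            print(f"execution finished in {result.exec_time:.1f}s")
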
195 changes: 195 additions & 0 deletions aide/parallel_agent.py
@@ -0,0 +1,195 @@
import ray
from typing import List, Optional, Any
from .agent import Agent
from .journal import Node, Journal
from .interpreter import ExecutionResult, Interpreter

Check failure in GitHub Actions / Python Linting: Ruff F401 at aide/parallel_agent.py:5:26: `.interpreter.ExecutionResult` imported but unused
from .utils.config import Config
from omegaconf import OmegaConf
from .utils import data_preview as dp
import logging
from pathlib import Path

Check failure in GitHub Actions / Python Linting: Ruff F401 at aide/parallel_agent.py:10:21: `pathlib.Path` imported but unused

@ray.remote
class ParallelWorker(Agent):
"""Worker class that inherits from Agent to handle code generation and execution"""
def __init__(self, task_desc: str, cfg: Config, journal: Journal, data_preview: str):
super().__init__(task_desc, cfg, journal)
# Initialize interpreter for this worker
self.interpreter = Interpreter(
cfg.workspace_dir, **OmegaConf.to_container(cfg.exec) # type: ignore
)
# Initialize data preview
self.data_preview = data_preview
# Setup logger for this worker
actor_id = ray.get_runtime_context().get_actor_id()
self.logger = logging.getLogger(f"ParallelWorker-{actor_id}")
self.logger.setLevel(logging.INFO)

def generate_nodes(self, parent_node: Optional[Node], num_nodes: int) -> List[Node]:
"""Generate multiple nodes in parallel"""
self.logger.info(f"Generating {num_nodes} nodes from parent: {parent_node}")
nodes = []
for _ in range(num_nodes):
if parent_node is None:
node = self._draft()
elif parent_node.is_buggy:
node = self._debug(parent_node)
else:
node = self._improve(parent_node)
nodes.append(node)
self.logger.info(f"Generated {len(nodes)} nodes")
return nodes

def execute_and_evaluate_node(self, node: Node) -> Node:
"""Execute node's code and evaluate results"""
try:
actor_id = ray.get_runtime_context().get_actor_id()
self.logger.info(f"Worker {actor_id} executing node {node.id}")
# Execute code
result = self.interpreter.run(node.code, True)
# Absorb execution results
node.absorb_exec_result(result)
# Evaluate and update node metrics
self.parse_exec_result(node, result)
self.logger.info(f"Worker {actor_id} completed node {node.id} with metric: {node.metric.value if node.metric else 'None'}")
return node
except Exception as e:
self.logger.error(f"Worker {actor_id} failed executing node {node.id}: {str(e)}")
raise

def get_data_preview(self):
"""Return the data preview"""
return self.data_preview

def cleanup_interpreter(self):
"""Cleanup the interpreter session"""
self.interpreter.cleanup_session()

def search_and_generate(self, num_nodes: int) -> List[Node]:
"""Independent search and generation by each worker"""
parent_node = self.search_policy()
self.logger.info(f"Worker selected parent node: {parent_node.id if parent_node else 'None'}")
return self.generate_nodes(parent_node, num_nodes)

class ParallelAgent(Agent):
"""Main agent class that implements parallel tree search"""
def __init__(self, task_desc: str, cfg: Config, journal: Journal):
super().__init__(task_desc, cfg, journal)

ray.init(
ignore_reinit_error=True,
logging_level=logging.INFO,
)

# Initialize data preview first
if cfg.agent.data_preview:
self.data_preview = dp.generate(cfg.workspace_dir)
else:
self.data_preview = None

self.num_workers = cfg.agent.parallel.num_workers
self.nodes_per_worker = cfg.agent.parallel.nodes_per_worker

# Setup logger for parallel execution
self.logger = logging.getLogger("aide.parallel")
self.logger.setLevel(logging.INFO)

self.workers = [
ParallelWorker.remote(
task_desc=task_desc,
cfg=cfg,
journal=journal,
data_preview=self.data_preview
)
for _ in range(self.num_workers)
]

def step(self, exec_callback: Any = None):
"""Single step of the tree search"""
step_num = len(self.journal)
self.logger.info(f"Starting step {step_num}")

if not self.journal.nodes:
self.update_data_preview()
self.logger.info("Updated data preview")

# Let workers independently search and generate nodes
node_futures = [
worker.search_and_generate.remote(self.nodes_per_worker)
for worker in self.workers
]

# Wait for node generation
self.logger.info(f"Step {step_num}: Waiting for node generation to complete...")
generated_nodes = ray.get(node_futures)
total_nodes = sum(len(nodes) for nodes in generated_nodes)
self.logger.info(f"Step {step_num}: Generated {total_nodes} nodes total")

# Flatten list of nodes and maintain parent relationships
all_nodes = []
for worker_nodes in generated_nodes:
for node in worker_nodes:
all_nodes.append(node)

# Distribute execution work across workers (same layer parallel execution)
nodes_per_executor = max(1, len(all_nodes) // len(self.workers))
exec_futures = []

self.logger.info(f"Step {step_num}: Distributing {len(all_nodes)} nodes across {len(self.workers)} workers for execution")
for i, worker in enumerate(self.workers):
start_idx = i * nodes_per_executor
end_idx = start_idx + nodes_per_executor if i < len(self.workers) - 1 else len(all_nodes)
worker_nodes = all_nodes[start_idx:end_idx]

self.logger.info(f"Step {step_num}: Worker {i} assigned {len(worker_nodes)} nodes")
for node in worker_nodes:
exec_futures.append(worker.execute_and_evaluate_node.remote(node))

# Get execution results and update journal
self.logger.info(f"Step {step_num}: Waiting for {len(exec_futures)} node executions to complete...")
evaluated_nodes = ray.get(exec_futures)
self.logger.info(f"Step {step_num}: All node executions completed")

# Batch update journal and save results
successful_nodes = 0
buggy_nodes = 0
best_metric = float('-inf')

for node in evaluated_nodes:
if node.parent is None: # Check node's parent attribute instead of using parent_node
self.journal.draft_nodes.append(node)
self.journal.append(node)

# Track statistics
if node.is_buggy:
buggy_nodes += 1
else:
successful_nodes += 1
if node.metric and node.metric.value > best_metric:
best_metric = node.metric.value

self.logger.info(
f"Step {step_num} completed: "
f"{successful_nodes} successful nodes, "
f"{buggy_nodes} buggy nodes, "
f"best metric: {best_metric if best_metric != float('-inf') else 'N/A'}"
)

# Save results
try:
from .utils.config import save_run
save_run(self.cfg, self.journal)
self.logger.info(f"Step {step_num}: Successfully saved run data to {self.cfg.log_dir}")
except Exception as e:
self.logger.error(f"Step {step_num}: Failed to save run: {str(e)}")

def cleanup(self):
"""Cleanup Ray resources"""
for worker in self.workers:
ray.get(worker.cleanup_interpreter.remote())
ray.shutdown()

def update_data_preview(self):
"""Update data preview from the first worker"""
if not hasattr(self, 'data_preview'):
self.data_preview = ray.get(self.workers[0].get_data_preview.remote())
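
ParallelAgent.step() follows a standard Ray actor fan-out/fan-in pattern: long-lived ParallelWorker actors are created once, each .remote() call immediately returns a future, ray.get() blocks until the batch is done, and execution work is then re-dispatched to the same actors in contiguous chunks. A stripped-down, self-contained sketch of that pattern using a toy Worker actor (not the ParallelWorker class above):

    import ray

    ray.init(ignore_reinit_error=True)

    @ray.remote
    class Worker:
        """Toy stand-in for ParallelWorker: keeps per-actor state, runs methods remotely."""

        def __init__(self, worker_id: int):
            self.worker_id = worker_id

        def generate(self, num_items: int) -> list:
            # Fan-out: each actor produces its own batch independently.
            return [f"w{self.worker_id}-item{i}" for i in range(num_items)]

        def evaluate(self, item: str) -> str:
            return f"{item}:evaluated"

    num_workers, items_per_worker = 2, 3
    workers = [Worker.remote(i) for i in range(num_workers)]

    # One generation call per worker; ray.get blocks until all batches are back (fan-in).
    batches = ray.get([w.generate.remote(items_per_worker) for w in workers])
    all_items = [item for batch in batches for item in batch]

    # Re-dispatch in contiguous chunks per worker, mirroring step()'s distribution logic.
    chunk = max(1, len(all_items) // num_workers)
    futures = []
    for i, w in enumerate(workers):
        start = i * chunk
        end = start + chunk if i < num_workers - 1 else len(all_items)
        futures.extend(w.evaluate.remote(item) for item in all_items[start:end])

    print(ray.get(futures))
    ray.shutdown()
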