diff --git a/.gitignore b/.gitignore
index fe04838..ee5ea53 100644
--- a/.gitignore
+++ b/.gitignore
@@ -168,3 +168,6 @@
 logs
 .gradio/
 .ruff_cache/
+
+run.sh
+run_experiments.sh
\ No newline at end of file
diff --git a/aide/__init__.py b/aide/__init__.py
index 3b1a82b..555f62f 100644
--- a/aide/__init__.py
+++ b/aide/__init__.py
@@ -12,6 +12,7 @@
     _load_cfg,
     prep_cfg,
 )
+from .parallel_agent import ParallelAgent
 
 
 @dataclass
@@ -43,11 +44,18 @@ def __init__(self, data_dir: str, goal: str, eval: str | None = None):
         prep_agent_workspace(self.cfg)
 
         self.journal = Journal()
-        self.agent = Agent(
-            task_desc=self.task_desc,
-            cfg=self.cfg,
-            journal=self.journal,
-        )
+        if self.cfg.agent.parallel.enabled:
+            self.agent = ParallelAgent(
+                task_desc=self.task_desc,
+                cfg=self.cfg,
+                journal=self.journal,
+            )
+        else:
+            self.agent = Agent(
+                task_desc=self.task_desc,
+                cfg=self.cfg,
+                journal=self.journal,
+            )
         self.interpreter = Interpreter(
             self.cfg.workspace_dir, **OmegaConf.to_container(self.cfg.exec) # type: ignore
         )
@@ -60,3 +68,10 @@ def run(self, steps: int) -> Solution:
 
         best_node = self.journal.get_best_node(only_good=False)
         return Solution(code=best_node.code, valid_metric=best_node.metric.value)
+
+    def cleanup(self):
+        """Cleanup resources"""
+        if hasattr(self, 'interpreter'):
+            self.interpreter.cleanup_session()
+        if isinstance(self.agent, ParallelAgent):
+            self.agent.cleanup()
diff --git a/aide/interpreter.py b/aide/interpreter.py
index 69171f5..21dfaea 100644
--- a/aide/interpreter.py
+++ b/aide/interpreter.py
@@ -49,11 +49,7 @@ def exception_summary(e, working_dir, exec_file_name, format_tb_ipython):
         tb_lines = traceback.format_exception(e)
         # skip parts of stack trace in weflow code
         tb_str = "".join(
-            [
-                line
-                for line in tb_lines
-                if "aide/" not in line and "importlib" not in line
-            ]
+            [l for l in tb_lines if "aide/" not in l and "importlib" not in l]
         )
         # tb_str = "".join([l for l in tb_lines])
 
@@ -203,7 +199,7 @@
 
         """
 
-        logger.debug(f"REPL is executing code (reset_session={reset_session})")
+        logger.info(f"REPL is executing code (reset_session={reset_session})")
 
         if reset_session:
             if self.process is not None:
@@ -224,9 +220,18 @@
         except queue.Empty:
             msg = "REPL child process failed to start execution"
             logger.critical(msg)
+            queue_dump = ""
             while not self.result_outq.empty():
-                logger.error(f"REPL output queue dump: {self.result_outq.get()}")
-            raise RuntimeError(msg) from None
+                queue_dump = self.result_outq.get()
+                logger.error(f"REPL output queue dump: {queue_dump[:1000]}")
+            self.cleanup_session()
+            return ExecutionResult(
+                term_out=[msg, queue_dump],
+                exec_time=0,
+                exc_type="RuntimeError",
+                exc_info={},
+                exc_stack=[],
+            )
 
         assert state[0] == "state:ready", state
         start_time = time.time()
@@ -246,11 +251,18 @@
             if not child_in_overtime and not self.process.is_alive():
                 msg = "REPL child process died unexpectedly"
                 logger.critical(msg)
+                queue_dump = ""
                 while not self.result_outq.empty():
-                    logger.error(
-                        f"REPL output queue dump: {self.result_outq.get()}"
-                    )
-                raise RuntimeError(msg) from None
+                    queue_dump = self.result_outq.get()
+                    logger.error(f"REPL output queue dump: {queue_dump[:1000]}")
+                self.cleanup_session()
+                return ExecutionResult(
+                    term_out=[msg, queue_dump],
+                    exec_time=0,
+                    exc_type="RuntimeError",
+                    exc_info={},
+                    exc_stack=[],
+                )
 
             # child is alive and still executing -> check if we should sigint..
             if self.timeout is None:
diff --git a/aide/parallel_agent.py b/aide/parallel_agent.py
new file mode 100644
index 0000000..862ba1a
--- /dev/null
+++ b/aide/parallel_agent.py
@@ -0,0 +1,195 @@
+import ray
+from typing import List, Optional, Any
+from .agent import Agent
+from .journal import Node, Journal
+from .interpreter import ExecutionResult, Interpreter
+from .utils.config import Config
+from omegaconf import OmegaConf
+from .utils import data_preview as dp
+import logging
+from pathlib import Path
+
+
+@ray.remote
+class ParallelWorker(Agent):
+    """Worker class that inherits from Agent to handle code generation and execution"""
+
+    def __init__(self, task_desc: str, cfg: Config, journal: Journal, data_preview: str):
+        super().__init__(task_desc, cfg, journal)
+        # Initialize interpreter for this worker
+        self.interpreter = Interpreter(
+            cfg.workspace_dir, **OmegaConf.to_container(cfg.exec) # type: ignore
+        )
+        # Initialize data preview
+        self.data_preview = data_preview
+        # Setup logger for this worker
+        actor_id = ray.get_runtime_context().get_actor_id()
+        self.logger = logging.getLogger(f"ParallelWorker-{actor_id}")
+        self.logger.setLevel(logging.INFO)
+
+    def generate_nodes(self, parent_node: Optional[Node], num_nodes: int) -> List[Node]:
+        """Generate multiple nodes in parallel"""
+        self.logger.info(f"Generating {num_nodes} nodes from parent: {parent_node}")
+        nodes = []
+        for _ in range(num_nodes):
+            if parent_node is None:
+                node = self._draft()
+            elif parent_node.is_buggy:
+                node = self._debug(parent_node)
+            else:
+                node = self._improve(parent_node)
+            nodes.append(node)
+        self.logger.info(f"Generated {len(nodes)} nodes")
+        return nodes
+
+    def execute_and_evaluate_node(self, node: Node) -> Node:
+        """Execute node's code and evaluate results"""
+        try:
+            actor_id = ray.get_runtime_context().get_actor_id()
+            self.logger.info(f"Worker {actor_id} executing node {node.id}")
+            # Execute code
+            result = self.interpreter.run(node.code, True)
+            # Absorb execution results
+            node.absorb_exec_result(result)
+            # Evaluate and update node metrics
+            self.parse_exec_result(node, result)
+            self.logger.info(f"Worker {actor_id} completed node {node.id} with metric: {node.metric.value if node.metric else 'None'}")
+            return node
+        except Exception as e:
+            self.logger.error(f"Worker {actor_id} failed executing node {node.id}: {str(e)}")
+            raise
+
+    def get_data_preview(self):
+        """Return the data preview"""
+        return self.data_preview
+
+    def cleanup_interpreter(self):
+        """Cleanup the interpreter session"""
+        self.interpreter.cleanup_session()
+
+    def search_and_generate(self, num_nodes: int) -> List[Node]:
+        """Independent search and generation by each worker"""
+        parent_node = self.search_policy()
+        self.logger.info(f"Worker selected parent node: {parent_node.id if parent_node else 'None'}")
+        return self.generate_nodes(parent_node, num_nodes)
+
+
+class ParallelAgent(Agent):
+    """Main agent class that implements parallel tree search"""
+
+    def __init__(self, task_desc: str, cfg: Config, journal: Journal):
+        super().__init__(task_desc, cfg, journal)
+
+        ray.init(
+            ignore_reinit_error=True,
+            logging_level=logging.INFO,
+        )
+
+        # Initialize data preview first
+        if cfg.agent.data_preview:
+            self.data_preview = dp.generate(cfg.workspace_dir)
+        else:
+            self.data_preview = None
+
+        self.num_workers = cfg.agent.parallel.num_workers
+        self.nodes_per_worker = cfg.agent.parallel.nodes_per_worker
+
+        # Setup logger for parallel execution
+        self.logger = logging.getLogger("aide.parallel")
+        self.logger.setLevel(logging.INFO)
+
+        self.workers = [
+            ParallelWorker.remote(
+                task_desc=task_desc,
+                cfg=cfg,
+                journal=journal,
+                data_preview=self.data_preview
+            )
+            for _ in range(self.num_workers)
+        ]
+
+    def step(self, exec_callback: Any = None):
+        """Single step of the tree search"""
+        step_num = len(self.journal)
+        self.logger.info(f"Starting step {step_num}")
+
+        if not self.journal.nodes:
+            self.update_data_preview()
+            self.logger.info("Updated data preview")
+
+        # Let workers independently search and generate nodes
+        node_futures = [
+            worker.search_and_generate.remote(self.nodes_per_worker)
+            for worker in self.workers
+        ]
+
+        # Wait for node generation
+        self.logger.info(f"Step {step_num}: Waiting for node generation to complete...")
+        generated_nodes = ray.get(node_futures)
+        total_nodes = sum(len(nodes) for nodes in generated_nodes)
+        self.logger.info(f"Step {step_num}: Generated {total_nodes} nodes total")
+
+        # Flatten list of nodes and maintain parent relationships
+        all_nodes = []
+        for worker_nodes in generated_nodes:
+            for node in worker_nodes:
+                all_nodes.append(node)
+
+        # Distribute execution work across workers (same layer parallel execution)
+        nodes_per_executor = max(1, len(all_nodes) // len(self.workers))
+        exec_futures = []
+
+        self.logger.info(f"Step {step_num}: Distributing {len(all_nodes)} nodes across {len(self.workers)} workers for execution")
+        for i, worker in enumerate(self.workers):
+            start_idx = i * nodes_per_executor
+            end_idx = start_idx + nodes_per_executor if i < len(self.workers) - 1 else len(all_nodes)
+            worker_nodes = all_nodes[start_idx:end_idx]
+
+            self.logger.info(f"Step {step_num}: Worker {i} assigned {len(worker_nodes)} nodes")
+            for node in worker_nodes:
+                exec_futures.append(worker.execute_and_evaluate_node.remote(node))
+
+        # Get execution results and update journal
+        self.logger.info(f"Step {step_num}: Waiting for {len(exec_futures)} node executions to complete...")
+        evaluated_nodes = ray.get(exec_futures)
+        self.logger.info(f"Step {step_num}: All node executions completed")
+
+        # Batch update journal and save results
+        successful_nodes = 0
+        buggy_nodes = 0
+        best_metric = float('-inf')
+
+        for node in evaluated_nodes:
+            if node.parent is None: # Check node's parent attribute instead of using parent_node
+                self.journal.draft_nodes.append(node)
+            self.journal.append(node)
+
+            # Track statistics
+            if node.is_buggy:
+                buggy_nodes += 1
+            else:
+                successful_nodes += 1
+                if node.metric and node.metric.value > best_metric:
+                    best_metric = node.metric.value
+
+        self.logger.info(
+            f"Step {step_num} completed: "
+            f"{successful_nodes} successful nodes, "
+            f"{buggy_nodes} buggy nodes, "
+            f"best metric: {best_metric if best_metric != float('-inf') else 'N/A'}"
+        )
+
+        # Save results
+        try:
+            from .utils.config import save_run
+            save_run(self.cfg, self.journal)
+            self.logger.info(f"Step {step_num}: Successfully saved run data to {self.cfg.log_dir}")
+        except Exception as e:
+            self.logger.error(f"Step {step_num}: Failed to save run: {str(e)}")
+
+    def cleanup(self):
+        """Cleanup Ray resources"""
+        for worker in self.workers:
+            ray.get(worker.cleanup_interpreter.remote())
+        ray.shutdown()
+
+    def update_data_preview(self):
+        """Update data preview from the first worker"""
+        if not hasattr(self, 'data_preview'):
+            self.data_preview = ray.get(self.workers[0].get_data_preview.remote())
diff --git a/aide/run.py b/aide/run.py
index 856ab28..13d9825 100644
--- a/aide/run.py
+++ b/aide/run.py
@@ -5,6 +5,7 @@
 
 from . import backend
 from .agent import Agent
+from .parallel_agent import ParallelAgent
 from .interpreter import Interpreter
 from .journal import Journal, Node
 from .journal2report import journal2report
@@ -63,6 +64,8 @@ def run():
     with Status("Preparing agent workspace (copying and extracting files) ..."):
         prep_agent_workspace(cfg)
 
+    global_step = 0
+
     def cleanup():
         if global_step == 0:
             shutil.rmtree(cfg.workspace_dir)
@@ -70,14 +73,26 @@ def cleanup():
     atexit.register(cleanup)
 
     journal = Journal()
-    agent = Agent(
-        task_desc=task_desc,
-        cfg=cfg,
-        journal=journal,
-    )
-    interpreter = Interpreter(
-        cfg.workspace_dir, **OmegaConf.to_container(cfg.exec) # type: ignore
-    )
+
+    # Choose agent type based on config
+    if cfg.agent.parallel.enabled:
+        agent = ParallelAgent(
+            task_desc=task_desc,
+            cfg=cfg,
+            journal=journal,
+        )
+        # No need for separate interpreter as each worker has its own
+        exec_callback = None
+    else:
+        agent = Agent(
+            task_desc=task_desc,
+            cfg=cfg,
+            journal=journal,
+        )
+        interpreter = Interpreter(
+            cfg.workspace_dir, **OmegaConf.to_container(cfg.exec) # type: ignore
+        )
+        exec_callback = interpreter.run
 
     global_step = len(journal)
     prog = Progress(
@@ -89,12 +104,6 @@ def cleanup():
     status = Status("[green]Generating code...")
     prog.add_task("Progress:", total=cfg.agent.steps, completed=global_step)
 
-    def exec_callback(*args, **kwargs):
-        status.update("[magenta]Executing code...")
-        res = interpreter.run(*args, **kwargs)
-        status.update("[green]Generating code...")
-        return res
-
     def generate_live():
         tree = journal_to_rich_tree(journal)
         prog.update(prog.task_ids[0], completed=global_step)
@@ -127,12 +136,23 @@ def generate_live():
         refresh_per_second=16,
         screen=True,
     ) as live:
-        while global_step < cfg.agent.steps:
-            agent.step(exec_callback=exec_callback)
-            save_run(cfg, journal)
-            global_step = len(journal)
-            live.update(generate_live())
-        interpreter.cleanup_session()
+        try:
+            while global_step < cfg.agent.steps:
+                if cfg.agent.parallel.enabled:
+                    status.update("[magenta]Generating and executing code in parallel...")
+                    agent.step()
+                else:
+                    status.update("[green]Generating code...")
+                    agent.step(exec_callback=exec_callback)
+                save_run(cfg, journal)
+                global_step = len(journal)
+                live.update(generate_live())
+        finally:
+            # Cleanup resources
+            if cfg.agent.parallel.enabled:
+                agent.cleanup()
+            else:
+                interpreter.cleanup_session()
 
     if cfg.generate_report:
         print("Generating final report from journal...")
diff --git a/aide/utils/config.py b/aide/utils/config.py
index 06a8b95..5f9e255 100644
--- a/aide/utils/config.py
+++ b/aide/utils/config.py
@@ -38,6 +38,12 @@ class SearchConfig:
     debug_prob: float
     num_drafts: int
 
+@dataclass
+class ParallelConfig:
+    enabled: bool
+    num_workers: int
+    nodes_per_worker: int
+
 
 @dataclass
 class AgentConfig:
@@ -51,6 +57,8 @@ class AgentConfig:
 
     search: SearchConfig
 
+    parallel: ParallelConfig
+
 
 @dataclass
 class ExecConfig:
diff --git a/aide/utils/config.yaml b/aide/utils/config.yaml
index 1e5d307..0596ed2 100644
--- a/aide/utils/config.yaml
+++ b/aide/utils/config.yaml
@@ -56,3 +56,8 @@ agent:
     max_debug_depth: 3
     debug_prob: 0.5
     num_drafts: 5
+
+  parallel:
+    enabled: true
+    num_workers: 4
+    nodes_per_worker: 2
diff --git a/requirements.txt b/requirements.txt
index 152e67e..726e923 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -92,3 +92,4 @@ xlrd
 backoff
 streamlit==1.40.2
 python-dotenv
+ray[default]==2.40.0
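
---

For reference, a minimal sketch of how the parallel path added in this diff could be exercised through the high-level API in `aide/__init__.py`. The class name `Experiment`, the dataset path, and the goal/eval strings are assumptions for illustration; only `__init__(data_dir, goal, eval)`, `run(steps)`, the returned `Solution`, and the new `cleanup()` appear in the hunks above, and `agent.parallel.enabled: true` is the new default set in `config.yaml`.

```python
import aide

# Placeholder inputs -- substitute any local dataset directory and task description.
exp = aide.Experiment(
    data_dir="data/house-prices",  # hypothetical path
    goal="Predict the sales price for each house",
    eval="RMSE between the log of predicted and observed prices",
)

try:
    # With agent.parallel.enabled, the Experiment constructs a ParallelAgent,
    # which starts num_workers Ray actors, each owning its own Interpreter.
    best = exp.run(steps=10)
    print(f"Best validation metric: {best.valid_metric}")
    print(best.code)
finally:
    # New in this diff: shuts down each worker's interpreter session and Ray.
    exp.cleanup()
```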