From eeb23425090286523ab4fcf5d6a86a98c4b4bb4d Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Tue, 5 Nov 2024 03:36:14 +0100 Subject: [PATCH] Refactor history/event stream (#3808) --- .../usage/how-to/evaluation-harness.md | 4 +- .../usage/how-to/evaluation-harness.md | 4 +- .../usage/how-to/evaluation-harness.md | 4 +- evaluation/EDA/run_infer.py | 7 +- evaluation/agent_bench/run_infer.py | 5 +- evaluation/aider_bench/run_infer.py | 3 +- evaluation/biocoder/run_infer.py | 3 +- evaluation/bird/run_infer.py | 5 +- evaluation/browsing_delegation/run_infer.py | 3 +- evaluation/discoverybench/run_infer.py | 7 +- evaluation/gaia/run_infer.py | 5 +- evaluation/gorilla/run_infer.py | 5 +- evaluation/gpqa/run_infer.py | 5 +- evaluation/humanevalfix/run_infer.py | 3 +- evaluation/integration_tests/run_infer.py | 2 +- evaluation/logic_reasoning/run_infer.py | 5 +- evaluation/miniwob/run_infer.py | 5 +- evaluation/mint/run_infer.py | 9 +- evaluation/ml_bench/run_infer.py | 3 +- evaluation/scienceagentbench/run_infer.py | 3 +- evaluation/swe_bench/run_infer.py | 3 +- evaluation/toolqa/run_infer.py | 5 +- evaluation/utils/shared.py | 29 +- evaluation/webarena/run_infer.py | 5 +- .../agenthub/browsing_agent/browsing_agent.py | 4 +- .../agenthub/codeact_agent/codeact_agent.py | 6 +- .../codeact_swe_agent/codeact_swe_agent.py | 6 +- openhands/agenthub/delegator_agent/agent.py | 8 +- openhands/agenthub/dummy_agent/agent.py | 2 +- openhands/agenthub/micro/agent.py | 8 +- openhands/agenthub/planner_agent/prompt.py | 4 +- openhands/controller/agent_controller.py | 173 ++++++-- openhands/controller/state/state.py | 45 ++- openhands/controller/stuck.py | 2 +- openhands/core/config/app_config.py | 2 - openhands/core/main.py | 32 +- openhands/events/action/message.py | 2 +- openhands/events/stream.py | 21 +- openhands/memory/__init__.py | 3 +- openhands/memory/history.py | 224 ----------- openhands/runtime/utils/edit.py | 8 +- tests/runtime/test_stress_remote_runtime.py | 2 +- tests/unit/test_codeact_agent.py | 2 +- tests/unit/test_is_stuck.py | 376 +++++++++--------- tests/unit/test_micro_agents.py | 14 +- tests/unit/test_prompt_caching.py | 92 +++-- 46 files changed, 559 insertions(+), 609 deletions(-) delete mode 100644 openhands/memory/history.py diff --git a/docs/i18n/fr/docusaurus-plugin-content-docs/current/usage/how-to/evaluation-harness.md b/docs/i18n/fr/docusaurus-plugin-content-docs/current/usage/how-to/evaluation-harness.md index 785112ed6ea0..3f191053998f 100644 --- a/docs/i18n/fr/docusaurus-plugin-content-docs/current/usage/how-to/evaluation-harness.md +++ b/docs/i18n/fr/docusaurus-plugin-content-docs/current/usage/how-to/evaluation-harness.md @@ -161,7 +161,7 @@ Pour créer un workflow d'évaluation pour votre benchmark, suivez ces étapes : instruction=instruction, test_result=evaluation_result, metadata=metadata, - history=state.history.compatibility_for_eval_history_pairs(), + history=compatibility_for_eval_history_pairs(state.history), metrics=state.metrics.get() if state.metrics else None, error=state.last_error if state and state.last_error else None, ) @@ -260,7 +260,7 @@ def codeact_user_response(state: State | None) -> str: # vérifier si l'agent a essayé de parler à l'utilisateur 3 fois, si oui, faire savoir à l'agent qu'il peut abandonner user_msgs = [ event - for event in state.history.get_events() + for event in state.history if isinstance(event, MessageAction) and event.source == 'user' ] if len(user_msgs) >= 2: diff --git a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/usage/how-to/evaluation-harness.md b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/usage/how-to/evaluation-harness.md index a50bb18502e2..eb99a30ea3fd 100644 --- a/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/usage/how-to/evaluation-harness.md +++ b/docs/i18n/zh-Hans/docusaurus-plugin-content-docs/current/usage/how-to/evaluation-harness.md @@ -158,7 +158,7 @@ OpenHands 的主要入口点在 `openhands/core/main.py` 中。以下是它工 instruction=instruction, test_result=evaluation_result, metadata=metadata, - history=state.history.compatibility_for_eval_history_pairs(), + history=compatibility_for_eval_history_pairs(state.history), metrics=state.metrics.get() if state.metrics else None, error=state.last_error if state and state.last_error else None, ) @@ -257,7 +257,7 @@ def codeact_user_response(state: State | None) -> str: # 检查代理是否已尝试与用户对话 3 次,如果是,让代理知道它可以放弃 user_msgs = [ event - for event in state.history.get_events() + for event in state.history if isinstance(event, MessageAction) and event.source == 'user' ] if len(user_msgs) >= 2: diff --git a/docs/modules/usage/how-to/evaluation-harness.md b/docs/modules/usage/how-to/evaluation-harness.md index 622f7e5607ba..e4d1e5d15bc7 100644 --- a/docs/modules/usage/how-to/evaluation-harness.md +++ b/docs/modules/usage/how-to/evaluation-harness.md @@ -158,7 +158,7 @@ To create an evaluation workflow for your benchmark, follow these steps: instruction=instruction, test_result=evaluation_result, metadata=metadata, - history=state.history.compatibility_for_eval_history_pairs(), + history=compatibility_for_eval_history_pairs(state.history), metrics=state.metrics.get() if state.metrics else None, error=state.last_error if state and state.last_error else None, ) @@ -257,7 +257,7 @@ def codeact_user_response(state: State | None) -> str: # check if the agent has tried to talk to the user 3 times, if so, let the agent know it can give up user_msgs = [ event - for event in state.history.get_events() + for event in state.history if isinstance(event, MessageAction) and event.source == 'user' ] if len(user_msgs) >= 2: diff --git a/evaluation/EDA/run_infer.py b/evaluation/EDA/run_infer.py index 2c896939a751..fb5df3b44f01 100644 --- a/evaluation/EDA/run_infer.py +++ b/evaluation/EDA/run_infer.py @@ -8,6 +8,7 @@ from evaluation.utils.shared import ( EvalMetadata, EvalOutput, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -34,7 +35,7 @@ def codeact_user_response_eda(state: State) -> str: # retrieve the latest model message from history if state.history: - model_guess = state.history.get_last_agent_message() + model_guess = state.get_last_agent_message() assert game is not None, 'Game is not initialized.' msg = game.generate_user_response(model_guess) @@ -139,7 +140,7 @@ def process_instance( if state is None: raise ValueError('State should not be None.') - final_message = state.history.get_last_agent_message() + final_message = state.get_last_agent_message() logger.info(f'Final message: {final_message} | Ground truth: {instance["text"]}') test_result = game.reward() @@ -148,7 +149,7 @@ def process_instance( # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) # Save the output output = EvalOutput( diff --git a/evaluation/agent_bench/run_infer.py b/evaluation/agent_bench/run_infer.py index d6fcc62e0798..acdf60fe4850 100644 --- a/evaluation/agent_bench/run_infer.py +++ b/evaluation/agent_bench/run_infer.py @@ -16,6 +16,7 @@ from evaluation.utils.shared import ( EvalMetadata, EvalOutput, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -242,7 +243,7 @@ def process_instance( raw_ans = '' # retrieve the last agent message or thought - for event in state.history.get_events(reverse=True): + for event in reversed(state.history): if event.source == 'agent': if isinstance(event, AgentFinishAction): raw_ans = event.thought @@ -271,7 +272,7 @@ def process_instance( # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) metrics = state.metrics.get() if state.metrics else None diff --git a/evaluation/aider_bench/run_infer.py b/evaluation/aider_bench/run_infer.py index fa1bb9534a83..cddc4bfe7db9 100644 --- a/evaluation/aider_bench/run_infer.py +++ b/evaluation/aider_bench/run_infer.py @@ -15,6 +15,7 @@ from evaluation.utils.shared import ( EvalMetadata, EvalOutput, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -250,7 +251,7 @@ def process_instance( # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) metrics = state.metrics.get() if state.metrics else None # Save the output diff --git a/evaluation/biocoder/run_infer.py b/evaluation/biocoder/run_infer.py index 4535ccba4e4e..5ab4b3b88313 100644 --- a/evaluation/biocoder/run_infer.py +++ b/evaluation/biocoder/run_infer.py @@ -13,6 +13,7 @@ EvalMetadata, EvalOutput, codeact_user_response, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -299,7 +300,7 @@ def process_instance( # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) test_result['generated'] = test_result['metadata']['1_copy_change_code'] diff --git a/evaluation/bird/run_infer.py b/evaluation/bird/run_infer.py index adb498cd2eb1..248dbb66181c 100644 --- a/evaluation/bird/run_infer.py +++ b/evaluation/bird/run_infer.py @@ -16,6 +16,7 @@ from evaluation.utils.shared import ( EvalMetadata, EvalOutput, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -46,7 +47,7 @@ def codeact_user_response(state: State) -> str: # check if the agent has tried to talk to the user 3 times, if so, let the agent know it can give up user_msgs = [ event - for event in state.history.get_events() + for event in state.history if isinstance(event, MessageAction) and event.source == 'user' ] if len(user_msgs) > 2: @@ -431,7 +432,7 @@ def execute_sql(db_path, sql): # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) # Save the output output = EvalOutput( diff --git a/evaluation/browsing_delegation/run_infer.py b/evaluation/browsing_delegation/run_infer.py index c9fe2ebd18bc..5c1ab8c062e3 100644 --- a/evaluation/browsing_delegation/run_infer.py +++ b/evaluation/browsing_delegation/run_infer.py @@ -9,6 +9,7 @@ from evaluation.utils.shared import ( EvalMetadata, EvalOutput, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -89,7 +90,7 @@ def process_instance( # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) # find the last delegate action last_delegate_action = None diff --git a/evaluation/discoverybench/run_infer.py b/evaluation/discoverybench/run_infer.py index 77d72d04775a..72148a64e759 100644 --- a/evaluation/discoverybench/run_infer.py +++ b/evaluation/discoverybench/run_infer.py @@ -15,6 +15,7 @@ EvalMetadata, EvalOutput, codeact_user_response, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -173,14 +174,14 @@ def initialize_runtime(runtime: Runtime, data_files: list[str]): def get_last_agent_finish_action(state: State) -> AgentFinishAction: - for event in state.history.get_events(reverse=True): + for event in reversed(state.history): if isinstance(event, AgentFinishAction): return event return None def get_last_message_action(state: State) -> MessageAction: - for event in state.history.get_events(reverse=True): + for event in reversed(state.history): if isinstance(event, MessageAction): return event return None @@ -307,7 +308,7 @@ def process_instance( # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) # DiscoveryBench Evaluation eval_rec = run_eval_gold_vs_gen_NL_hypo_workflow( diff --git a/evaluation/gaia/run_infer.py b/evaluation/gaia/run_infer.py index c02cd0aee737..1fa0c00e6d6a 100644 --- a/evaluation/gaia/run_infer.py +++ b/evaluation/gaia/run_infer.py @@ -12,6 +12,7 @@ EvalMetadata, EvalOutput, codeact_user_response, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -166,7 +167,7 @@ def process_instance( model_answer_raw = '' # get the last message or thought from the agent - for event in state.history.get_events(reverse=True): + for event in reversed(state.history): if event.source == 'agent': if isinstance(event, AgentFinishAction): model_answer_raw = event.thought @@ -203,7 +204,7 @@ def process_instance( # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) # Save the output output = EvalOutput( diff --git a/evaluation/gorilla/run_infer.py b/evaluation/gorilla/run_infer.py index 873cb7f89694..e437f2b6075a 100644 --- a/evaluation/gorilla/run_infer.py +++ b/evaluation/gorilla/run_infer.py @@ -10,6 +10,7 @@ EvalMetadata, EvalOutput, codeact_user_response, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -101,7 +102,7 @@ def process_instance( raise ValueError('State should not be None.') # retrieve the last message from the agent - model_answer_raw = state.history.get_last_agent_message() + model_answer_raw = state.get_last_agent_message() # attempt to parse model_answer ast_eval_fn = instance['ast_eval'] @@ -114,7 +115,7 @@ def process_instance( # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) output = EvalOutput( instance_id=instance_id, diff --git a/evaluation/gpqa/run_infer.py b/evaluation/gpqa/run_infer.py index 8fd4034c9d5e..58db2e404fc8 100644 --- a/evaluation/gpqa/run_infer.py +++ b/evaluation/gpqa/run_infer.py @@ -28,6 +28,7 @@ from evaluation.utils.shared import ( EvalMetadata, EvalOutput, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -244,7 +245,7 @@ def process_instance( 'C': False, 'D': False, } - for event in state.history.get_events(reverse=True): + for event in reversed(state.history): if ( isinstance(event, AgentFinishAction) and event.source != 'user' @@ -300,7 +301,7 @@ def process_instance( instance_id=str(instance.instance_id), instruction=instruction, metadata=metadata, - history=state.history.compatibility_for_eval_history_pairs(), + history=compatibility_for_eval_history_pairs(state.history), metrics=metrics, error=state.last_error if state and state.last_error else None, test_result={ diff --git a/evaluation/humanevalfix/run_infer.py b/evaluation/humanevalfix/run_infer.py index 25fee65561fc..2aa184758b33 100644 --- a/evaluation/humanevalfix/run_infer.py +++ b/evaluation/humanevalfix/run_infer.py @@ -21,6 +21,7 @@ EvalMetadata, EvalOutput, codeact_user_response, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -255,7 +256,7 @@ def process_instance( # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) # Save the output output = EvalOutput( diff --git a/evaluation/integration_tests/run_infer.py b/evaluation/integration_tests/run_infer.py index 9232f9e7b1bf..5e3205fefe2e 100644 --- a/evaluation/integration_tests/run_infer.py +++ b/evaluation/integration_tests/run_infer.py @@ -129,7 +129,7 @@ def process_instance( # # result evaluation # # ============================================= - histories = [event_to_dict(event) for event in state.history.get_events()] + histories = [event_to_dict(event) for event in state.history] test_result: TestResult = test_class.verify_result(runtime, histories) metrics = state.metrics.get() if state.metrics else None diff --git a/evaluation/logic_reasoning/run_infer.py b/evaluation/logic_reasoning/run_infer.py index 5b7d35f21130..116b438b3ee9 100644 --- a/evaluation/logic_reasoning/run_infer.py +++ b/evaluation/logic_reasoning/run_infer.py @@ -8,6 +8,7 @@ EvalMetadata, EvalOutput, codeact_user_response, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -225,7 +226,7 @@ def process_instance( raise ValueError('State should not be None.') final_message = '' - for event in state.history.get_events(reverse=True): + for event in reversed(state.history): if isinstance(event, AgentFinishAction): final_message = event.thought break @@ -247,7 +248,7 @@ def process_instance( # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) # Save the output output = EvalOutput( diff --git a/evaluation/miniwob/run_infer.py b/evaluation/miniwob/run_infer.py index 65eb0eaa4422..715bdaa470ae 100644 --- a/evaluation/miniwob/run_infer.py +++ b/evaluation/miniwob/run_infer.py @@ -11,6 +11,7 @@ EvalMetadata, EvalOutput, codeact_user_response, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -182,7 +183,7 @@ def process_instance( # Instruction is the first message from the USER instruction = '' - for event in state.history.get_events(): + for event in state.history: if isinstance(event, MessageAction): instruction = event.content break @@ -194,7 +195,7 @@ def process_instance( # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) # Save the output output = EvalOutput( diff --git a/evaluation/mint/run_infer.py b/evaluation/mint/run_infer.py index 8017b194d8d8..2165c3c03fe4 100644 --- a/evaluation/mint/run_infer.py +++ b/evaluation/mint/run_infer.py @@ -13,6 +13,7 @@ from evaluation.utils.shared import ( EvalMetadata, EvalOutput, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -28,6 +29,7 @@ from openhands.core.logger import openhands_logger as logger from openhands.core.main import create_runtime, run_controller from openhands.events.action import ( + Action, CmdRunAction, MessageAction, ) @@ -45,7 +47,10 @@ def codeact_user_response_mint(state: State, task: Task, task_config: dict[str, task=task, task_config=task_config, ) - last_action = state.history.get_last_action() + last_action = next( + (event for event in reversed(state.history) if isinstance(event, Action)), + None, + ) result_state: TaskState = env.step(last_action.message or '') state.extra_data['task_state'] = result_state @@ -202,7 +207,7 @@ def process_instance( # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) # Save the output output = EvalOutput( diff --git a/evaluation/ml_bench/run_infer.py b/evaluation/ml_bench/run_infer.py index deec068f3392..2bb667e3c947 100644 --- a/evaluation/ml_bench/run_infer.py +++ b/evaluation/ml_bench/run_infer.py @@ -24,6 +24,7 @@ EvalMetadata, EvalOutput, codeact_user_response, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -256,7 +257,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool = # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) # Save the output output = EvalOutput( diff --git a/evaluation/scienceagentbench/run_infer.py b/evaluation/scienceagentbench/run_infer.py index 8fbe61718742..93a82855452e 100644 --- a/evaluation/scienceagentbench/run_infer.py +++ b/evaluation/scienceagentbench/run_infer.py @@ -10,6 +10,7 @@ EvalMetadata, EvalOutput, codeact_user_response, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -232,7 +233,7 @@ def process_instance( # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) # Save the output output = EvalOutput( diff --git a/evaluation/swe_bench/run_infer.py b/evaluation/swe_bench/run_infer.py index d5440664a61e..b06fba19a102 100644 --- a/evaluation/swe_bench/run_infer.py +++ b/evaluation/swe_bench/run_infer.py @@ -443,7 +443,8 @@ def process_instance( if state is None: raise ValueError('State should not be None.') - histories = [event_to_dict(event) for event in state.history.get_events()] + # NOTE: this is NO LONGER the event stream, but an agent history that includes delegate agent's events + histories = [event_to_dict(event) for event in state.history] metrics = state.metrics.get() if state.metrics else None # Save the output diff --git a/evaluation/toolqa/run_infer.py b/evaluation/toolqa/run_infer.py index 5c2c53422785..25633ce6ce23 100644 --- a/evaluation/toolqa/run_infer.py +++ b/evaluation/toolqa/run_infer.py @@ -9,6 +9,7 @@ EvalMetadata, EvalOutput, codeact_user_response, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -126,7 +127,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool = raise ValueError('State should not be None.') # retrieve the last message from the agent - model_answer_raw = state.history.get_last_agent_message() + model_answer_raw = state.get_last_agent_message() # attempt to parse model_answer correct = eval_answer(str(model_answer_raw), str(answer)) @@ -137,7 +138,7 @@ def process_instance(instance: Any, metadata: EvalMetadata, reset_logger: bool = # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) # Save the output output = EvalOutput( diff --git a/evaluation/utils/shared.py b/evaluation/utils/shared.py index 09bc61421e88..d5a6d6d89de8 100644 --- a/evaluation/utils/shared.py +++ b/evaluation/utils/shared.py @@ -18,6 +18,9 @@ from openhands.core.logger import openhands_logger as logger from openhands.events.action import Action from openhands.events.action.message import MessageAction +from openhands.events.event import Event +from openhands.events.serialization.event import event_to_dict +from openhands.events.utils import get_pairs_from_events class EvalMetadata(BaseModel): @@ -112,7 +115,14 @@ def codeact_user_response( if state.history: # check if the last action has an answer, if so, early exit if try_parse is not None: - last_action = state.history.get_last_action() + last_action = next( + ( + event + for event in reversed(state.history) + if isinstance(event, Action) + ), + None, + ) ans = try_parse(last_action) if ans is not None: return '/exit' @@ -120,7 +130,7 @@ def codeact_user_response( # check if the agent has tried to talk to the user 3 times, if so, let the agent know it can give up user_msgs = [ event - for event in state.history.get_events() + for event in state.history if isinstance(event, MessageAction) and event.source == 'user' ] if len(user_msgs) >= 2: @@ -428,3 +438,18 @@ def update_llm_config_for_completions_logging( f'{llm_config.log_completions_folder}' ) return llm_config + + +# history is now available as a filtered stream of events, rather than list of pairs of (Action, Observation) +# we rebuild the pairs here +# for compatibility with the existing output format in evaluations +# remove this when it's no longer necessary +def compatibility_for_eval_history_pairs( + history: list[Event], +) -> list[tuple[dict, dict]]: + history_pairs = [] + + for action, observation in get_pairs_from_events(history): + history_pairs.append((event_to_dict(action), event_to_dict(observation))) + + return history_pairs diff --git a/evaluation/webarena/run_infer.py b/evaluation/webarena/run_infer.py index cfc2bdae493a..531f134fd988 100644 --- a/evaluation/webarena/run_infer.py +++ b/evaluation/webarena/run_infer.py @@ -10,6 +10,7 @@ from evaluation.utils.shared import ( EvalMetadata, EvalOutput, + compatibility_for_eval_history_pairs, make_metadata, prepare_dataset, reset_logger_for_multiprocessing, @@ -166,7 +167,7 @@ def process_instance( # Instruction is the first message from the USER instruction = '' - for event in state.history.get_events(): + for event in state.history: if isinstance(event, MessageAction): instruction = event.content break @@ -178,7 +179,7 @@ def process_instance( # history is now available as a stream of events, rather than list of pairs of (Action, Observation) # for compatibility with the existing output format, we can remake the pairs here # remove when it becomes unnecessary - histories = state.history.compatibility_for_eval_history_pairs() + histories = compatibility_for_eval_history_pairs(state.history) # Save the output output = EvalOutput( diff --git a/openhands/agenthub/browsing_agent/browsing_agent.py b/openhands/agenthub/browsing_agent/browsing_agent.py index 0460506d04f3..822677bab526 100644 --- a/openhands/agenthub/browsing_agent/browsing_agent.py +++ b/openhands/agenthub/browsing_agent/browsing_agent.py @@ -150,13 +150,13 @@ def step(self, state: State) -> Action: last_obs = None last_action = None - if EVAL_MODE and len(state.history.get_events_as_list()) == 1: + if EVAL_MODE and len(state.history) == 1: # for webarena and miniwob++ eval, we need to retrieve the initial observation already in browser env # initialize and retrieve the first observation by issuing an noop OP # For non-benchmark browsing, the browser env starts with a blank page, and the agent is expected to first navigate to desired websites return BrowseInteractiveAction(browser_actions='noop()') - for event in state.history.get_events(): + for event in state.history: if isinstance(event, BrowseInteractiveAction): prev_actions.append(event.browser_actions) last_action = event diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py index a599959e30d0..314a2e0a089b 100644 --- a/openhands/agenthub/codeact_agent/codeact_agent.py +++ b/openhands/agenthub/codeact_agent/codeact_agent.py @@ -337,8 +337,8 @@ def step(self, state: State) -> Action: return self.pending_actions.popleft() # if we're done, go back - latest_user_message = state.history.get_last_user_message() - if latest_user_message and latest_user_message.strip() == '/exit': + last_user_message = state.get_last_user_message() + if last_user_message and last_user_message.strip() == '/exit': return AgentFinishAction() # prepare what we want to send to the LLM @@ -419,7 +419,7 @@ def _get_messages(self, state: State) -> list[Message]: pending_tool_call_action_messages: dict[str, Message] = {} tool_call_id_to_message: dict[str, Message] = {} - events = list(state.history.get_events()) + events = list(state.history) for event in events: # create a regular message from an event if isinstance(event, Action): diff --git a/openhands/agenthub/codeact_swe_agent/codeact_swe_agent.py b/openhands/agenthub/codeact_swe_agent/codeact_swe_agent.py index 6fc679aec449..7c5b039e8c47 100644 --- a/openhands/agenthub/codeact_swe_agent/codeact_swe_agent.py +++ b/openhands/agenthub/codeact_swe_agent/codeact_swe_agent.py @@ -154,8 +154,8 @@ def step(self, state: State) -> Action: - AgentFinishAction() - end the interaction """ # if we're done, go back - latest_user_message = state.history.get_last_user_message() - if latest_user_message and latest_user_message.strip() == '/exit': + last_user_message = state.get_last_user_message() + if last_user_message and last_user_message.strip() == '/exit': return AgentFinishAction() # prepare what we want to send to the LLM @@ -176,7 +176,7 @@ def _get_messages(self, state: State) -> list[Message]: Message(role='user', content=[TextContent(text=self.in_context_example)]), ] - for event in state.history.get_events(): + for event in state.history: # create a regular message from an event if isinstance(event, Action): message = self.get_action_message(event) diff --git a/openhands/agenthub/delegator_agent/agent.py b/openhands/agenthub/delegator_agent/agent.py index 29e0030423c7..7cb987c8c3f7 100644 --- a/openhands/agenthub/delegator_agent/agent.py +++ b/openhands/agenthub/delegator_agent/agent.py @@ -2,7 +2,7 @@ from openhands.controller.state.state import State from openhands.core.config import AgentConfig from openhands.events.action import Action, AgentDelegateAction, AgentFinishAction -from openhands.events.observation import AgentDelegateObservation +from openhands.events.observation import AgentDelegateObservation, Observation from openhands.llm.llm import LLM @@ -41,7 +41,11 @@ def step(self, state: State) -> Action: ) # last observation in history should be from the delegate - last_observation = state.history.get_last_observation() + last_observation = None + for event in reversed(state.history): + if isinstance(event, Observation): + last_observation = event + break if not isinstance(last_observation, AgentDelegateObservation): raise Exception('Last observation is not an AgentDelegateObservation') diff --git a/openhands/agenthub/dummy_agent/agent.py b/openhands/agenthub/dummy_agent/agent.py index dbe4c60cfafa..272e6c935f2e 100644 --- a/openhands/agenthub/dummy_agent/agent.py +++ b/openhands/agenthub/dummy_agent/agent.py @@ -164,7 +164,7 @@ def step(self, state: State) -> Action: if 'observations' in prev_step and prev_step['observations']: expected_observations = prev_step['observations'] - hist_events = state.history.get_last_events(len(expected_observations)) + hist_events = state.history[-len(expected_observations) :] if len(hist_events) < len(expected_observations): print( diff --git a/openhands/agenthub/micro/agent.py b/openhands/agenthub/micro/agent.py index 83225a3245cd..a9b0825afd9d 100644 --- a/openhands/agenthub/micro/agent.py +++ b/openhands/agenthub/micro/agent.py @@ -8,10 +8,10 @@ from openhands.core.message import ImageContent, Message, TextContent from openhands.core.utils import json from openhands.events.action import Action +from openhands.events.event import Event from openhands.events.serialization.action import action_from_dict from openhands.events.serialization.event import event_to_memory from openhands.llm.llm import LLM -from openhands.memory.history import ShortTermHistory def parse_response(orig_response: str) -> Action: @@ -32,16 +32,14 @@ class MicroAgent(Agent): prompt = '' agent_definition: dict = {} - def history_to_json( - self, history: ShortTermHistory, max_events: int = 20, **kwargs - ): + def history_to_json(self, history: list[Event], max_events: int = 20, **kwargs): """ Serialize and simplify history to str format """ processed_history = [] event_count = 0 - for event in history.get_events(reverse=True): + for event in reversed(history): if event_count >= max_events: break processed_history.append( diff --git a/openhands/agenthub/planner_agent/prompt.py b/openhands/agenthub/planner_agent/prompt.py index 017c25bbef05..7b73f4353131 100644 --- a/openhands/agenthub/planner_agent/prompt.py +++ b/openhands/agenthub/planner_agent/prompt.py @@ -117,7 +117,7 @@ def get_hint(latest_action_id: str) -> str: def get_prompt_and_images( state: State, max_message_chars: int -) -> tuple[str, list[str]]: +) -> tuple[str, list[str] | None]: """Gets the prompt for the planner agent. Formatted with the most recent action-observation pairs, current task, and hint based on last action @@ -136,7 +136,7 @@ def get_prompt_and_images( latest_action: Action = NullAction() # retrieve the latest HISTORY_SIZE events - for event_count, event in enumerate(state.history.get_events(reverse=True)): + for event_count, event in enumerate(reversed(state.history)): if event_count >= HISTORY_SIZE: break if latest_action == NullAction() and isinstance(event, Action): diff --git a/openhands/controller/agent_controller.py b/openhands/controller/agent_controller.py index 3eb7a7d066ab..8899e5e15bf6 100644 --- a/openhands/controller/agent_controller.py +++ b/openhands/controller/agent_controller.py @@ -1,7 +1,7 @@ import asyncio import copy import traceback -from typing import Callable, Type +from typing import Callable, ClassVar, Type import litellm @@ -36,6 +36,7 @@ AgentDelegateObservation, AgentStateChangedObservation, ErrorObservation, + NullObservation, Observation, ) from openhands.events.serialization.event import truncate_content @@ -61,6 +62,12 @@ class AgentController: parent: 'AgentController | None' = None delegate: 'AgentController | None' = None _pending_action: Action | None = None + filter_out: ClassVar[tuple[type[Event], ...]] = ( + NullAction, + NullObservation, + ChangeAgentStateAction, + AgentStateChangedObservation, + ) def __init__( self, @@ -121,8 +128,34 @@ def __init__( self.status_callback = status_callback async def close(self): - """Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream.""" + """Closes the agent controller, canceling any ongoing tasks and unsubscribing from the event stream. + + Note that it's fairly important that this closes properly, otherwise the state is incomplete.""" await self.set_agent_state_to(AgentState.STOPPED) + + # we made history, now is the time to rewrite it! + # the final state.history will be used by external scripts like evals, tests, etc. + # history will need to be complete WITH delegates events + # like the regular agent history, it does not include: + # - 'hidden' events, events with hidden=True + # - backend events (the default 'filtered out' types, types in self.filter_out) + start_id = self.state.start_id if self.state.start_id >= 0 else 0 + end_id = ( + self.state.end_id + if self.state.end_id >= 0 + else self.event_stream.get_latest_event_id() + ) + self.state.history = list( + self.event_stream.get_events( + start_id=start_id, + end_id=end_id, + reverse=False, + filter_out_type=self.filter_out, + filter_hidden=True, + ) + ) + + # unsubscribe from the event stream self.event_stream.unsubscribe(EventStreamSubscriber.AGENT_CONTROLLER) def log(self, level: str, message: str, extra: dict | None = None): @@ -178,6 +211,11 @@ async def on_event(self, event: Event): """ if hasattr(event, 'hidden') and event.hidden: return + + # if the event is not filtered out, add it to the history + if not any(isinstance(event, filter_type) for filter_type in self.filter_out): + self.state.history.append(event) + if isinstance(event, Action): await self._handle_action(event) elif isinstance(event, Observation): @@ -233,9 +271,6 @@ async def _handle_observation(self, observation: Observation): if self.state.agent_state == AgentState.USER_REJECTED: await self.set_agent_state_to(AgentState.AWAITING_USER_INPUT) return - - if isinstance(observation, AgentDelegateObservation): - self.state.history.on_event(observation) elif isinstance(observation, ErrorObservation): if self.state.agent_state == AgentState.ERROR: self.state.metrics.merge(self.state.local_metrics) @@ -362,6 +397,8 @@ async def start_delegate(self, action: AgentDelegateAction): delegate_level=self.state.delegate_level + 1, # global metrics should be shared between parent and child metrics=self.state.metrics, + # start on top of the stream + start_id=self.event_stream.get_latest_event_id() + 1, ) self.log( 'debug', @@ -463,9 +500,7 @@ async def _step(self) -> None: async def _delegate_step(self): """Executes a single step of the delegate agent.""" - self.log('debug', 'Delegate not none, awaiting...') await self.delegate._step() # type: ignore[union-attr] - self.log('debug', 'Delegate step done') assert self.delegate is not None delegate_state = self.delegate.get_agent_state() self.log('debug', f'Delegate state: {delegate_state}') @@ -473,7 +508,7 @@ async def _delegate_step(self): # update iteration that shall be shared across agents self.state.iteration = self.delegate.state.iteration - # emit AgentDelegateObservation when the delegate terminates due to error + # emit AgentDelegateObservation to mark delegate termination due to error delegate_outputs = ( self.delegate.state.outputs if self.delegate.state else {} ) @@ -488,10 +523,6 @@ async def _delegate_step(self): self.delegate = None self.delegateAction = None - self.event_stream.add_event( - ErrorObservation('Delegate agent encountered an error'), - EventSource.AGENT, - ) elif delegate_state in (AgentState.FINISHED, AgentState.REJECTED): self.log('debug', 'Delegate agent has finished execution') # retrieve delegate result @@ -574,8 +605,10 @@ def set_initial_state( max_iterations: The maximum number of iterations allowed for the task. confirmation_mode: Whether to enable confirmation mode. """ - # state from the previous session, state from a parent agent, or a new state - # note that this is called twice when restoring a previous session, first with state=None + # state can come from: + # - the previous session, in which case it has history + # - from a parent agent, in which case it has no history + # - None / a new state if state is None: self.state = State( inputs={}, @@ -585,27 +618,109 @@ def set_initial_state( else: self.state = state - # when restored from a previous session, the State object will have history, start_id, and end_id - # connect it to the event stream - self.state.history.set_event_stream(self.event_stream) + if self.state.start_id <= -1: + self.state.start_id = 0 - # if start_id was not set in State, we're starting fresh, at the top of the stream - start_id = self.state.start_id - if start_id == -1: - start_id = self.event_stream.get_latest_event_id() + 1 - else: self.log( - 'debug', f'AgentController {self.id} restoring from event {start_id}' + 'debug', + f'AgentController {self.id} initializing history from event {self.state.start_id}', + ) + + self._init_history() + + def _init_history(self): + """Initializes the agent's history from the event stream. + + The history is a list of events that: + - Excludes events of types listed in self.filter_out + - Excludes events with hidden=True attribute + - For delegate events (between AgentDelegateAction and AgentDelegateObservation): + - Excludes all events between the action and observation + - Includes the delegate action and observation themselves + """ + + # define range of events to fetch + # delegates start with a start_id and initially won't find any events + # otherwise we're restoring a previous session + start_id = self.state.start_id if self.state.start_id >= 0 else 0 + end_id = ( + self.state.end_id + if self.state.end_id >= 0 + else self.event_stream.get_latest_event_id() + ) + + # sanity check + if start_id > end_id + 1: + self.log( + 'debug', + f'start_id {start_id} is greater than end_id + 1 ({end_id + 1}). History will be empty.', ) + self.state.history = [] + return + + # Get all events, filtering out backend events and hidden events + events = list( + self.event_stream.get_events( + start_id=start_id, + end_id=end_id, + reverse=False, + filter_out_type=self.filter_out, + filter_hidden=True, + ) + ) + + # Find all delegate action/observation pairs + delegate_ranges: list[tuple[int, int]] = [] + delegate_action_ids: list[int] = [] # stack of unmatched delegate action IDs + + for event in events: + if isinstance(event, AgentDelegateAction): + delegate_action_ids.append(event.id) + # Note: we can get agent=event.agent and task=event.inputs.get('task','') + # if we need to track these in the future + + elif isinstance(event, AgentDelegateObservation): + # Match with most recent unmatched delegate action + if not delegate_action_ids: + self.log( + 'error', + f'Found AgentDelegateObservation without matching action at id={event.id}', + ) + continue + + action_id = delegate_action_ids.pop() + delegate_ranges.append((action_id, event.id)) + + # Filter out events between delegate action/observation pairs + if delegate_ranges: + filtered_events: list[Event] = [] + current_idx = 0 + + for start_id, end_id in sorted(delegate_ranges): + # Add events before delegate range + filtered_events.extend( + event for event in events[current_idx:] if event.id < start_id + ) + + # Add delegate action and observation + filtered_events.extend( + event for event in events if event.id in (start_id, end_id) + ) + + # Update index to after delegate range + current_idx = next( + (i for i, e in enumerate(events) if e.id > end_id), len(events) + ) + + # Add any remaining events after last delegate range + filtered_events.extend(events[current_idx:]) + + self.state.history = filtered_events + else: + self.state.history = events # make sure history is in sync self.state.start_id = start_id - self.state.history.start_id = start_id - - # if there was an end_id saved in State, set it in history - # currently not used, later useful for delegates - if self.state.end_id > -1: - self.state.history.end_id = self.state.end_id def _is_stuck(self): """Checks if the agent or its delegate is stuck in a loop. diff --git a/openhands/controller/state/state.py b/openhands/controller/state/state.py index 8e6c911c5ef6..96c0ab7e8322 100644 --- a/openhands/controller/state/state.py +++ b/openhands/controller/state/state.py @@ -11,9 +11,8 @@ MessageAction, ) from openhands.events.action.agent import AgentFinishAction -from openhands.events.observation import ErrorObservation +from openhands.events.event import Event, EventSource from openhands.llm.metrics import Metrics -from openhands.memory.history import ShortTermHistory from openhands.storage.files import FileStore @@ -78,7 +77,7 @@ class State: # max number of iterations for the current task max_iterations: int = 100 confirmation_mode: bool = False - history: ShortTermHistory = field(default_factory=ShortTermHistory) + history: list[Event] = field(default_factory=list) inputs: dict = field(default_factory=dict) outputs: dict = field(default_factory=dict) agent_state: AgentState = AgentState.LOADING @@ -94,6 +93,7 @@ class State: start_id: int = -1 end_id: int = -1 almost_stuck: int = 0 + delegates: dict[tuple[int, int], tuple[str, str]] = field(default_factory=dict) # NOTE: This will never be used by the controller, but it can be used by different # evaluation tasks to store extra data needed to track the progress/state of the task. extra_data: dict[str, Any] = field(default_factory=dict) @@ -116,7 +116,7 @@ def restore_from_session(sid: str, file_store: FileStore) -> 'State': pickled = base64.b64decode(encoded) state = pickle.loads(pickled) except Exception as e: - logger.warning(f'Failed to restore state from session: {e}') + logger.warning(f'Could not restore state from session: {e}') raise e # update state @@ -130,39 +130,40 @@ def restore_from_session(sid: str, file_store: FileStore) -> 'State': return state def __getstate__(self): + # don't pickle history, it will be restored from the event stream state = self.__dict__.copy() - - # save the relevant data from recent history - # so that we can restore it when the state is restored - if 'history' in state: - state['start_id'] = state['history'].start_id - state['end_id'] = state['history'].end_id - - # don't save history object itself - state.pop('history', None) + state['history'] = [] return state def __setstate__(self, state): self.__dict__.update(state) - # recreate the history object + # make sure we always have the attribute history if not hasattr(self, 'history'): - self.history = ShortTermHistory() - - self.history.start_id = self.start_id - self.history.end_id = self.end_id + self.history = [] - - def get_current_user_intent(self): + def get_current_user_intent(self) -> tuple[str | None, list[str] | None]: """Returns the latest user message and image(if provided) that appears after a FinishAction, or the first (the task) if nothing was finished yet.""" last_user_message = None last_user_message_image_urls: list[str] | None = [] - for event in self.history.get_events(reverse=True): + for event in reversed(self.history): if isinstance(event, MessageAction) and event.source == 'user': last_user_message = event.content last_user_message_image_urls = event.images_urls elif isinstance(event, AgentFinishAction): if last_user_message is not None: - return last_user_message + return last_user_message, None return last_user_message, last_user_message_image_urls + + def get_last_agent_message(self) -> str | None: + for event in reversed(self.history): + if isinstance(event, MessageAction) and event.source == EventSource.AGENT: + return event.content + return None + + def get_last_user_message(self) -> str | None: + for event in reversed(self.history): + if isinstance(event, MessageAction) and event.source == EventSource.USER: + return event.content + return None diff --git a/openhands/controller/stuck.py b/openhands/controller/stuck.py index 230d5f2e81ac..0eb0f4c893ca 100644 --- a/openhands/controller/stuck.py +++ b/openhands/controller/stuck.py @@ -28,7 +28,7 @@ def is_stuck(self): # filter out MessageAction with source='user' from history filtered_history = [ event - for event in self.state.history.get_events() + for event in self.state.history if not ( (isinstance(event, MessageAction) and event.source == EventSource.USER) or diff --git a/openhands/core/config/app_config.py b/openhands/core/config/app_config.py index a60c5070286f..6511f634983a 100644 --- a/openhands/core/config/app_config.py +++ b/openhands/core/config/app_config.py @@ -38,7 +38,6 @@ class AppConfig: e2b_api_key: The E2B API key. disable_color: Whether to disable color. For terminals that don't support color. debug: Whether to enable debugging. - enable_cli_session: Whether to enable saving and restoring the session when run from CLI. file_uploads_max_file_size_mb: Maximum file size for uploads in megabytes. 0 means no limit. file_uploads_restrict_file_types: Whether to restrict file types for file uploads. Defaults to False. file_uploads_allowed_extensions: List of allowed file extensions for uploads. ['.*'] means all extensions are allowed. @@ -67,7 +66,6 @@ class AppConfig: disable_color: bool = False jwt_secret: str = uuid.uuid4().hex debug: bool = False - enable_cli_session: bool = False file_uploads_max_file_size_mb: int = 0 file_uploads_restrict_file_types: bool = False file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: ['.*']) diff --git a/openhands/core/main.py b/openhands/core/main.py index c338f35e6bce..c077f9b9cc53 100644 --- a/openhands/core/main.py +++ b/openhands/core/main.py @@ -125,16 +125,18 @@ async def run_controller( runtime = create_runtime(config, sid=sid) event_stream = runtime.event_stream - # restore cli session if enabled + + # restore cli session if available initial_state = None - if config.enable_cli_session: - try: - logger.debug(f'Restoring agent state from cli session {event_stream.sid}') - initial_state = State.restore_from_session( - event_stream.sid, event_stream.file_store - ) - except Exception as e: - logger.debug(f'Error restoring state: {e}') + try: + logger.debug( + f'Trying to restore agent state from cli session {event_stream.sid} if available' + ) + initial_state = State.restore_from_session( + event_stream.sid, event_stream.file_store + ) + except Exception as e: + logger.debug(f'Cannot restore agent state: {e}') # init controller with this initial state controller = AgentController( @@ -157,7 +159,7 @@ async def run_controller( ) # start event is a MessageAction with the task, either resumed or new - if config.enable_cli_session and initial_state is not None: + if initial_state is not None: # we're resuming the previous session event_stream.add_event( MessageAction( @@ -168,7 +170,7 @@ async def run_controller( ), EventSource.USER, ) - elif initial_state is None: + else: # init with the provided actions event_stream.add_event(initial_user_action, EventSource.USER) @@ -202,8 +204,9 @@ async def on_event(event: Event): logger.error(f'Exception in main loop: {e}') # save session when we're about to close - if config.enable_cli_session: + if config.file_store is not None and config.file_store != 'memory': end_state = controller.get_state() + # NOTE: the saved state does not include delegates events end_state.save_to_session(event_stream.sid, event_stream.file_store) state = controller.get_state() @@ -212,10 +215,7 @@ async def on_event(event: Event): if config.trajectories_path is not None: file_path = os.path.join(config.trajectories_path, sid + '.json') os.makedirs(os.path.dirname(file_path), exist_ok=True) - histories = [ - event_to_trajectory(event) - for event in state.history.get_events(include_delegates=True) - ] + histories = [event_to_trajectory(event) for event in state.history] with open(file_path, 'w') as f: json.dump(histories, f) diff --git a/openhands/events/action/message.py b/openhands/events/action/message.py index 55fb21f359d3..0e3bb26a1cc2 100644 --- a/openhands/events/action/message.py +++ b/openhands/events/action/message.py @@ -7,7 +7,7 @@ @dataclass class MessageAction(Action): content: str - images_urls: list | None = None + images_urls: list[str] | None = None wait_for_response: bool = False action: str = ActionType.MESSAGE security_risk: ActionSecurityRisk | None = None diff --git a/openhands/events/stream.py b/openhands/events/stream.py index 680aba511e12..3bc4d9875a27 100644 --- a/openhands/events/stream.py +++ b/openhands/events/stream.py @@ -83,12 +83,27 @@ def _get_id_from_filename(filename: str) -> int: def get_events( self, - start_id=0, - end_id=None, - reverse=False, + start_id: int = 0, + end_id: int | None = None, + reverse: bool = False, filter_out_type: tuple[type[Event], ...] | None = None, filter_hidden=False, ) -> Iterable[Event]: + """ + Retrieve events from the event stream, optionally filtering out events of a given type + and events marked as hidden. + + Args: + start_id: The ID of the first event to retrieve. Defaults to 0. + end_id: The ID of the last event to retrieve. Defaults to the last event in the stream. + reverse: Whether to retrieve events in reverse order. Defaults to False. + filter_out_type: A tuple of event types to filter out. Typically used to filter out backend events from the agent. + filter_hidden: If True, filters out events with the 'hidden' attribute set to True. + + Yields: + Events from the stream that match the criteria. + """ + def should_filter(event: Event): if filter_hidden and hasattr(event, 'hidden') and event.hidden: return True diff --git a/openhands/memory/__init__.py b/openhands/memory/__init__.py index 0ce208cef581..12c499c768be 100644 --- a/openhands/memory/__init__.py +++ b/openhands/memory/__init__.py @@ -1,5 +1,4 @@ from openhands.memory.condenser import MemoryCondenser -from openhands.memory.history import ShortTermHistory from openhands.memory.memory import LongTermMemory -__all__ = ['LongTermMemory', 'ShortTermHistory', 'MemoryCondenser'] +__all__ = ['LongTermMemory', 'MemoryCondenser'] diff --git a/openhands/memory/history.py b/openhands/memory/history.py deleted file mode 100644 index 1e4cfb8b5f05..000000000000 --- a/openhands/memory/history.py +++ /dev/null @@ -1,224 +0,0 @@ -from typing import ClassVar, Iterable - -from openhands.core.logger import openhands_logger as logger -from openhands.events.action.action import Action -from openhands.events.action.agent import ( - AgentDelegateAction, - ChangeAgentStateAction, -) -from openhands.events.action.empty import NullAction -from openhands.events.action.message import MessageAction -from openhands.events.event import Event, EventSource -from openhands.events.observation.agent import AgentStateChangedObservation -from openhands.events.observation.delegate import AgentDelegateObservation -from openhands.events.observation.empty import NullObservation -from openhands.events.observation.observation import Observation -from openhands.events.serialization.event import event_to_dict -from openhands.events.stream import EventStream -from openhands.events.utils import get_pairs_from_events - - -class ShortTermHistory(list[Event]): - """A list of events that represents the short-term memory of the agent. - - This class provides methods to retrieve and filter the events in the history of the running agent from the event stream. - """ - - start_id: int - end_id: int - _event_stream: EventStream - delegates: dict[tuple[int, int], tuple[str, str]] - filter_out: ClassVar[tuple[type[Event], ...]] = ( - NullAction, - NullObservation, - ChangeAgentStateAction, - AgentStateChangedObservation, - ) - - def __init__(self): - super().__init__() - self.start_id = -1 - self.end_id = -1 - self.delegates = {} - - def set_event_stream(self, event_stream: EventStream): - self._event_stream = event_stream - - def get_events_as_list(self, include_delegates: bool = False) -> list[Event]: - """Return the history as a list of Event objects.""" - return list(self.get_events(include_delegates=include_delegates)) - - def get_events( - self, - reverse: bool = False, - include_delegates: bool = False, - include_hidden=False, - ) -> Iterable[Event]: - """Return the events as a stream of Event objects.""" - # TODO handle AgentRejectAction, if it's not part of a chunk ending with an AgentDelegateObservation - # or even if it is, because currently we don't add it to the summary - - # iterate from start_id to end_id, or reverse - start_id = self.start_id if self.start_id != -1 else 0 - end_id = ( - self.end_id - if self.end_id != -1 - else self._event_stream.get_latest_event_id() - ) - - for event in self._event_stream.get_events( - start_id=start_id, - end_id=end_id, - reverse=reverse, - filter_out_type=self.filter_out, - ): - if not include_hidden and hasattr(event, 'hidden') and event.hidden: - continue - # TODO add summaries - # and filter out events that were included in a summary - - # filter out the events from a delegate of the current agent - if not include_delegates and not any( - # except for the delegate action and observation themselves, currently - # AgentDelegateAction has id = delegate_start - # AgentDelegateObservation has id = delegate_end - delegate_start < event.id < delegate_end - for delegate_start, delegate_end in self.delegates.keys() - ): - yield event - elif include_delegates: - yield event - - def get_last_action(self, end_id: int = -1) -> Action | None: - """Return the last action from the event stream, filtered to exclude unwanted events.""" - # from end_id in reverse, find the first action - end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id - - last_action = next( - ( - event - for event in self._event_stream.get_events( - end_id=end_id, reverse=True, filter_out_type=self.filter_out - ) - if isinstance(event, Action) - ), - None, - ) - - return last_action - - def get_last_observation(self, end_id: int = -1) -> Observation | None: - """Return the last observation from the event stream, filtered to exclude unwanted events.""" - # from end_id in reverse, find the first observation - end_id = self._event_stream.get_latest_event_id() if end_id == -1 else end_id - - last_observation = next( - ( - event - for event in self._event_stream.get_events( - end_id=end_id, reverse=True, filter_out_type=self.filter_out - ) - if isinstance(event, Observation) - ), - None, - ) - - return last_observation - - def get_last_user_message(self) -> str: - """Return the content of the last user message from the event stream.""" - last_user_message = next( - ( - event.content - for event in self._event_stream.get_events(reverse=True) - if isinstance(event, MessageAction) and event.source == EventSource.USER - ), - None, - ) - - return last_user_message if last_user_message is not None else '' - - def get_last_agent_message(self) -> str: - """Return the content of the last agent message from the event stream.""" - last_agent_message = next( - ( - event.content - for event in self._event_stream.get_events(reverse=True) - if isinstance(event, MessageAction) - and event.source == EventSource.AGENT - ), - None, - ) - - return last_agent_message if last_agent_message is not None else '' - - def get_last_events(self, n: int) -> list[Event]: - """Return the last n events from the event stream.""" - # dummy agent is using this - # it should work, but it's not great to store temporary lists now just for a test - end_id = self._event_stream.get_latest_event_id() - start_id = max(0, end_id - n + 1) - - return list( - event - for event in self._event_stream.get_events( - start_id=start_id, - end_id=end_id, - filter_out_type=self.filter_out, - ) - ) - - def has_delegation(self) -> bool: - for event in self._event_stream.get_events(): - if isinstance(event, AgentDelegateObservation): - return True - return False - - def on_event(self, event: Event): - if not isinstance(event, AgentDelegateObservation): - return - - logger.debug('AgentDelegateObservation received') - - # figure out what this delegate's actions were - # from the last AgentDelegateAction to this AgentDelegateObservation - # and save their ids as start and end ids - # in order to use later to exclude them from parent stream - # or summarize them - delegate_end = event.id - delegate_start = -1 - delegate_agent: str = '' - delegate_task: str = '' - for prev_event in self._event_stream.get_events( - end_id=event.id - 1, reverse=True - ): - if isinstance(prev_event, AgentDelegateAction): - delegate_start = prev_event.id - delegate_agent = prev_event.agent - delegate_task = prev_event.inputs.get('task', '') - break - - if delegate_start == -1: - logger.error( - f'No AgentDelegateAction found for AgentDelegateObservation with id={delegate_end}' - ) - return - - self.delegates[(delegate_start, delegate_end)] = (delegate_agent, delegate_task) - logger.debug( - f'Delegate {delegate_agent} with task {delegate_task} ran from id={delegate_start} to id={delegate_end}' - ) - - # TODO remove me when unnecessary - # history is now available as a filtered stream of events, rather than list of pairs of (Action, Observation) - # we rebuild the pairs here - # for compatibility with the existing output format in evaluations - def compatibility_for_eval_history_pairs(self) -> list[tuple[dict, dict]]: - history_pairs = [] - - for action, observation in get_pairs_from_events( - self.get_events_as_list(include_delegates=True) - ): - history_pairs.append((event_to_dict(action), event_to_dict(observation))) - - return history_pairs diff --git a/openhands/runtime/utils/edit.py b/openhands/runtime/utils/edit.py index 6595760e061f..cd3ffd0b71ce 100644 --- a/openhands/runtime/utils/edit.py +++ b/openhands/runtime/utils/edit.py @@ -213,7 +213,9 @@ def edit(self, action: FileEditAction) -> Observation: if isinstance(obs, ErrorObservation): return obs if not isinstance(obs, FileWriteObservation): - raise ValueError(f'Expected FileWriteObservation, got {type(obs)}: {str(obs)}') + raise ValueError( + f'Expected FileWriteObservation, got {type(obs)}: {str(obs)}' + ) return FileEditObservation( content=get_diff('', action.content, action.path), path=action.path, @@ -222,7 +224,9 @@ def edit(self, action: FileEditAction) -> Observation: new_content=action.content, ) if not isinstance(obs, FileReadObservation): - raise ValueError(f'Expected FileReadObservation, got {type(obs)}: {str(obs)}') + raise ValueError( + f'Expected FileReadObservation, got {type(obs)}: {str(obs)}' + ) original_file_content = obs.content old_file_lines = original_file_content.split('\n') diff --git a/tests/runtime/test_stress_remote_runtime.py b/tests/runtime/test_stress_remote_runtime.py index 4d96ee132a2d..3a5d6d280726 100644 --- a/tests/runtime/test_stress_remote_runtime.py +++ b/tests/runtime/test_stress_remote_runtime.py @@ -181,7 +181,7 @@ def next_command(*args, **kwargs): test_result = {} if state is None: raise ValueError('State should not be None.') - histories = [event_to_dict(event) for event in state.history.get_events()] + histories = [event_to_dict(event) for event in state.history] metrics = state.metrics.get() if state.metrics else None # Save the output diff --git a/tests/unit/test_codeact_agent.py b/tests/unit/test_codeact_agent.py index 9e3dda6c2cdd..126ce788c77c 100644 --- a/tests/unit/test_codeact_agent.py +++ b/tests/unit/test_codeact_agent.py @@ -104,5 +104,5 @@ def test_error_observation_message(agent: CodeActAgent): def test_unknown_observation_message(agent: CodeActAgent): obs = Mock() - with pytest.raises(ValueError, match='Unknown observation type:'): + with pytest.raises(ValueError, match='Unknown observation type'): agent.get_observation_message(obs, tool_call_id_to_message={}) diff --git a/tests/unit/test_is_stuck.py b/tests/unit/test_is_stuck.py index af3ef6b83cbe..197d6d8462b7 100644 --- a/tests/unit/test_is_stuck.py +++ b/tests/unit/test_is_stuck.py @@ -17,8 +17,6 @@ from openhands.events.observation.empty import NullObservation from openhands.events.observation.error import ErrorObservation from openhands.events.stream import EventSource, EventStream -from openhands.events.utils import get_pairs_from_events -from openhands.memory.history import ShortTermHistory from openhands.storage import get_file_store @@ -55,22 +53,21 @@ def event_stream(temp_dir): class TestStuckDetector: @pytest.fixture - def stuck_detector(self, event_stream): + def stuck_detector(self): state = State(inputs={}, max_iterations=50) - state.history.set_event_stream(event_stream) - + state.history = [] # Initialize history as an empty list return StuckDetector(state) def _impl_syntax_error_events( self, - event_stream: EventStream, + state: State, error_message: str, random_line: bool, incidents: int = 4, ): for i in range(incidents): ipython_action = IPythonRunCellAction(code=code_snippet) - event_stream.add_event(ipython_action, EventSource.AGENT) + state.history.append(ipython_action) extra_number = (i + 1) * 10 if random_line else '42' extra_line = '\n' * (i + 1) if random_line else '' ipython_observation = IPythonRunCellObservation( @@ -79,15 +76,15 @@ def _impl_syntax_error_events( f'{error_message}{extra_line}' + jupyter_line_1 + jupyter_line_2, code=code_snippet, ) - ipython_observation._cause = ipython_action._id - event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT) + # ipython_observation._cause = ipython_action._id + state.history.append(ipython_observation) def _impl_unterminated_string_error_events( - self, event_stream: EventStream, random_line: bool, incidents: int = 4 + self, state: State, random_line: bool, incidents: int = 4 ): for i in range(incidents): ipython_action = IPythonRunCellAction(code=code_snippet) - event_stream.add_event(ipython_action, EventSource.AGENT) + state.history.append(ipython_action) line_number = (i + 1) * 10 if random_line else '1' ipython_observation = IPythonRunCellObservation( content=f'print(" Cell In[1], line {line_number}\nhello\n ^\nSyntaxError: unterminated string literal (detected at line {line_number})' @@ -95,34 +92,30 @@ def _impl_unterminated_string_error_events( + jupyter_line_2, code=code_snippet, ) - ipython_observation._cause = ipython_action._id - event_stream.add_event(ipython_observation, EventSource.ENVIRONMENT) + # ipython_observation._cause = ipython_action._ + state.history.append(ipython_observation) - def test_history_too_short( - self, stuck_detector: StuckDetector, event_stream: EventStream - ): + def test_history_too_short(self, stuck_detector: StuckDetector): + state = stuck_detector.state message_action = MessageAction(content='Hello', wait_for_response=False) message_action._source = EventSource.USER observation = NullObservation(content='') - observation._cause = message_action.id - event_stream.add_event(message_action, EventSource.USER) - event_stream.add_event(observation, EventSource.ENVIRONMENT) + # observation._cause = message_action.id + state.history.append(message_action) + state.history.append(observation) cmd_action = CmdRunAction(command='ls') - event_stream.add_event(cmd_action, EventSource.AGENT) + state.history.append(cmd_action) cmd_observation = CmdOutputObservation( command_id=1, command='ls', content='file1.txt\nfile2.txt' ) - cmd_observation._cause = cmd_action._id - event_stream.add_event(cmd_observation, EventSource.ENVIRONMENT) - - # stuck_detector.state.history.set_event_stream(event_stream) + # cmd_observation._cause = cmd_action._id + state.history.append(cmd_observation) assert stuck_detector.is_stuck() is False - def test_is_stuck_repeating_action_observation( - self, stuck_detector: StuckDetector, event_stream: EventStream - ): + def test_is_stuck_repeating_action_observation(self, stuck_detector: StuckDetector): + state = stuck_detector.state message_action = MessageAction(content='Done', wait_for_response=False) message_action._source = EventSource.USER @@ -130,135 +123,125 @@ def test_is_stuck_repeating_action_observation( hello_observation = NullObservation('') # 2 events - event_stream.add_event(hello_action, EventSource.USER) - event_stream.add_event(hello_observation, EventSource.ENVIRONMENT) + state.history.append(hello_action) + state.history.append(hello_observation) cmd_action_1 = CmdRunAction(command='ls') - event_stream.add_event(cmd_action_1, EventSource.AGENT) - cmd_observation_1 = CmdOutputObservation( - content='', command='ls', command_id=cmd_action_1._id - ) + cmd_action_1._id = 1 + state.history.append(cmd_action_1) + cmd_observation_1 = CmdOutputObservation(content='', command='ls', command_id=1) cmd_observation_1._cause = cmd_action_1._id - event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT) + state.history.append(cmd_observation_1) # 4 events cmd_action_2 = CmdRunAction(command='ls') - event_stream.add_event(cmd_action_2, EventSource.AGENT) - cmd_observation_2 = CmdOutputObservation( - content='', command='ls', command_id=cmd_action_2._id - ) + cmd_action_2._id = 2 + state.history.append(cmd_action_2) + cmd_observation_2 = CmdOutputObservation(content='', command='ls', command_id=2) cmd_observation_2._cause = cmd_action_2._id - event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT) + state.history.append(cmd_observation_2) # 6 events # random user message just because we can message_null_observation = NullObservation(content='') - event_stream.add_event(message_action, EventSource.USER) - event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT) + state.history.append(message_action) + state.history.append(message_null_observation) # 8 events assert stuck_detector.is_stuck() is False assert stuck_detector.state.almost_stuck == 2 cmd_action_3 = CmdRunAction(command='ls') - event_stream.add_event(cmd_action_3, EventSource.AGENT) - cmd_observation_3 = CmdOutputObservation( - content='', command='ls', command_id=cmd_action_3._id - ) + cmd_action_3._id = 3 + state.history.append(cmd_action_3) + cmd_observation_3 = CmdOutputObservation(content='', command='ls', command_id=3) cmd_observation_3._cause = cmd_action_3._id - event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT) + state.history.append(cmd_observation_3) # 10 events - assert len(collect_events(event_stream)) == 10 - assert len(list(stuck_detector.state.history.get_events())) == 8 + assert len(state.history) == 10 assert ( - len( - get_pairs_from_events( - stuck_detector.state.history.get_events_as_list( - include_delegates=True - ) - ) - ) - == 5 - ) + len(state.history) == 10 + ) # Adjusted since history is a list and the controller is not running + + # FIXME are we still testing this without this test? + # assert ( + # len( + # get_pairs_from_events(state.history) + # ) + # == 5 + # ) assert stuck_detector.is_stuck() is False assert stuck_detector.state.almost_stuck == 1 cmd_action_4 = CmdRunAction(command='ls') - event_stream.add_event(cmd_action_4, EventSource.AGENT) - cmd_observation_4 = CmdOutputObservation( - content='', command='ls', command_id=cmd_action_4._id - ) + cmd_action_4._id = 4 + state.history.append(cmd_action_4) + cmd_observation_4 = CmdOutputObservation(content='', command='ls', command_id=4) cmd_observation_4._cause = cmd_action_4._id - event_stream.add_event(cmd_observation_4, EventSource.ENVIRONMENT) + state.history.append(cmd_observation_4) # 12 events - assert len(collect_events(event_stream)) == 12 - assert len(list(stuck_detector.state.history.get_events())) == 10 - assert ( - len( - get_pairs_from_events( - stuck_detector.state.history.get_events_as_list( - include_delegates=True - ) - ) - ) - == 6 - ) + assert len(state.history) == 12 + # assert ( + # len( + # get_pairs_from_events(state.history) + # ) + # == 6 + # ) with patch('logging.Logger.warning') as mock_warning: assert stuck_detector.is_stuck() is True assert stuck_detector.state.almost_stuck == 0 mock_warning.assert_called_once_with('Action, Observation loop detected') - def test_is_stuck_repeating_action_error( - self, stuck_detector: StuckDetector, event_stream: EventStream - ): + def test_is_stuck_repeating_action_error(self, stuck_detector: StuckDetector): + state = stuck_detector.state # (action, error_observation), not necessarily the same error message_action = MessageAction(content='Done', wait_for_response=False) message_action._source = EventSource.USER hello_action = MessageAction(content='Hello', wait_for_response=False) hello_observation = NullObservation(content='') - event_stream.add_event(hello_action, EventSource.USER) - hello_observation._cause = hello_action._id - event_stream.add_event(hello_observation, EventSource.ENVIRONMENT) + state.history.append(hello_action) + # hello_observation._cause = hello_action._id + state.history.append(hello_observation) # 2 events cmd_action_1 = CmdRunAction(command='invalid_command') - event_stream.add_event(cmd_action_1, EventSource.AGENT) + state.history.append(cmd_action_1) error_observation_1 = ErrorObservation(content='Command not found') - error_observation_1._cause = cmd_action_1._id - event_stream.add_event(error_observation_1, EventSource.ENVIRONMENT) + # error_observation_1._cause = cmd_action_1._id + state.history.append(error_observation_1) # 4 events cmd_action_2 = CmdRunAction(command='invalid_command') - event_stream.add_event(cmd_action_2, EventSource.AGENT) + state.history.append(cmd_action_2) error_observation_2 = ErrorObservation( content='Command still not found or another error' ) - error_observation_2._cause = cmd_action_2._id - event_stream.add_event(error_observation_2, EventSource.ENVIRONMENT) + # error_observation_2._cause = cmd_action_2._id + state.history.append(error_observation_2) # 6 events message_null_observation = NullObservation(content='') - event_stream.add_event(message_action, EventSource.USER) - event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT) + state.history.append(message_action) + state.history.append(message_null_observation) # 8 events cmd_action_3 = CmdRunAction(command='invalid_command') - event_stream.add_event(cmd_action_3, EventSource.AGENT) + state.history.append(cmd_action_3) error_observation_3 = ErrorObservation(content='Different error') - error_observation_3._cause = cmd_action_3._id - event_stream.add_event(error_observation_3, EventSource.ENVIRONMENT) + # error_observation_3._cause = cmd_action_3._id + state.history.append(error_observation_3) # 10 events cmd_action_4 = CmdRunAction(command='invalid_command') - event_stream.add_event(cmd_action_4, EventSource.AGENT) + state.history.append(cmd_action_4) error_observation_4 = ErrorObservation(content='Command not found') - error_observation_4._cause = cmd_action_4._id - event_stream.add_event(error_observation_4, EventSource.ENVIRONMENT) + # error_observation_4._cause = cmd_action_4._id + state.history.append(error_observation_4) # 12 events with patch('logging.Logger.warning') as mock_warning: @@ -267,11 +250,10 @@ def test_is_stuck_repeating_action_error( 'Action, ErrorObservation loop detected' ) - def test_is_stuck_invalid_syntax_error( - self, stuck_detector: StuckDetector, event_stream: EventStream - ): + def test_is_stuck_invalid_syntax_error(self, stuck_detector: StuckDetector): + state = stuck_detector.state self._impl_syntax_error_events( - event_stream, + state, error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?', random_line=False, ) @@ -280,10 +262,11 @@ def test_is_stuck_invalid_syntax_error( assert stuck_detector.is_stuck() is True def test_is_not_stuck_invalid_syntax_error_random_lines( - self, stuck_detector: StuckDetector, event_stream: EventStream + self, stuck_detector: StuckDetector ): + state = stuck_detector.state self._impl_syntax_error_events( - event_stream, + state, error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?', random_line=True, ) @@ -292,10 +275,11 @@ def test_is_not_stuck_invalid_syntax_error_random_lines( assert stuck_detector.is_stuck() is False def test_is_not_stuck_invalid_syntax_error_only_three_incidents( - self, stuck_detector: StuckDetector, event_stream: EventStream + self, stuck_detector: StuckDetector ): + state = stuck_detector.state self._impl_syntax_error_events( - event_stream, + state, error_message='SyntaxError: invalid syntax. Perhaps you forgot a comma?', random_line=True, incidents=3, @@ -304,11 +288,10 @@ def test_is_not_stuck_invalid_syntax_error_only_three_incidents( with patch('logging.Logger.warning'): assert stuck_detector.is_stuck() is False - def test_is_stuck_incomplete_input_error( - self, stuck_detector: StuckDetector, event_stream: EventStream - ): + def test_is_stuck_incomplete_input_error(self, stuck_detector: StuckDetector): + state = stuck_detector.state self._impl_syntax_error_events( - event_stream, + state, error_message='SyntaxError: incomplete input', random_line=False, ) @@ -316,11 +299,10 @@ def test_is_stuck_incomplete_input_error( with patch('logging.Logger.warning'): assert stuck_detector.is_stuck() is True - def test_is_not_stuck_incomplete_input_error( - self, stuck_detector: StuckDetector, event_stream: EventStream - ): + def test_is_not_stuck_incomplete_input_error(self, stuck_detector: StuckDetector): + state = stuck_detector.state self._impl_syntax_error_events( - event_stream, + state, error_message='SyntaxError: incomplete input', random_line=True, ) @@ -329,239 +311,241 @@ def test_is_not_stuck_incomplete_input_error( assert stuck_detector.is_stuck() is False def test_is_not_stuck_ipython_unterminated_string_error_random_lines( - self, stuck_detector: StuckDetector, event_stream: EventStream + self, stuck_detector: StuckDetector ): - self._impl_unterminated_string_error_events(event_stream, random_line=True) + state = stuck_detector.state + self._impl_unterminated_string_error_events(state, random_line=True) with patch('logging.Logger.warning'): assert stuck_detector.is_stuck() is False def test_is_not_stuck_ipython_unterminated_string_error_only_three_incidents( - self, stuck_detector: StuckDetector, event_stream: EventStream + self, stuck_detector: StuckDetector ): + state = stuck_detector.state self._impl_unterminated_string_error_events( - event_stream, random_line=False, incidents=3 + state, random_line=False, incidents=3 ) with patch('logging.Logger.warning'): assert stuck_detector.is_stuck() is False def test_is_stuck_ipython_unterminated_string_error( - self, stuck_detector: StuckDetector, event_stream: EventStream + self, stuck_detector: StuckDetector ): - self._impl_unterminated_string_error_events(event_stream, random_line=False) + state = stuck_detector.state + self._impl_unterminated_string_error_events(state, random_line=False) with patch('logging.Logger.warning'): assert stuck_detector.is_stuck() is True def test_is_not_stuck_ipython_syntax_error_not_at_end( - self, stuck_detector: StuckDetector, event_stream: EventStream + self, stuck_detector: StuckDetector ): + state = stuck_detector.state # this test is to make sure we don't get false positives # since the "at line x" is changing in between! ipython_action_1 = IPythonRunCellAction(code='print("hello') - event_stream.add_event(ipython_action_1, EventSource.AGENT) + state.history.append(ipython_action_1) ipython_observation_1 = IPythonRunCellObservation( content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)\nThis is some additional output', code='print("hello', ) - ipython_observation_1._cause = ipython_action_1._id - event_stream.add_event(ipython_observation_1, EventSource.ENVIRONMENT) + # ipython_observation_1._cause = ipython_action_1._id + state.history.append(ipython_observation_1) ipython_action_2 = IPythonRunCellAction(code='print("hello') - event_stream.add_event(ipython_action_2, EventSource.AGENT) + state.history.append(ipython_action_2) ipython_observation_2 = IPythonRunCellObservation( content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 1)\nToo much output here on and on', code='print("hello', ) - ipython_observation_2._cause = ipython_action_2._id - event_stream.add_event(ipython_observation_2, EventSource.ENVIRONMENT) + # ipython_observation_2._cause = ipython_action_2._id + state.history.append(ipython_observation_2) ipython_action_3 = IPythonRunCellAction(code='print("hello') - event_stream.add_event(ipython_action_3, EventSource.AGENT) + state.history.append(ipython_action_3) ipython_observation_3 = IPythonRunCellObservation( content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 3)\nEnough', code='print("hello', ) - ipython_observation_3._cause = ipython_action_3._id - event_stream.add_event(ipython_observation_3, EventSource.ENVIRONMENT) + # ipython_observation_3._cause = ipython_action_3._id + state.history.append(ipython_observation_3) ipython_action_4 = IPythonRunCellAction(code='print("hello') - event_stream.add_event(ipython_action_4, EventSource.AGENT) + state.history.append(ipython_action_4) ipython_observation_4 = IPythonRunCellObservation( content='print("hello\n ^\nSyntaxError: unterminated string literal (detected at line 2)\nLast line of output', code='print("hello', ) - ipython_observation_4._cause = ipython_action_4._id - event_stream.add_event(ipython_observation_4, EventSource.ENVIRONMENT) + # ipython_observation_4._cause = ipython_action_4._id + state.history.append(ipython_observation_4) with patch('logging.Logger.warning') as mock_warning: assert stuck_detector.is_stuck() is False mock_warning.assert_not_called() def test_is_stuck_repeating_action_observation_pattern( - self, stuck_detector: StuckDetector, event_stream: EventStream + self, stuck_detector: StuckDetector ): + state = stuck_detector.state message_action = MessageAction(content='Come on', wait_for_response=False) message_action._source = EventSource.USER - event_stream.add_event(message_action, EventSource.USER) + state.history.append(message_action) message_observation = NullObservation(content='') - event_stream.add_event(message_observation, EventSource.ENVIRONMENT) + state.history.append(message_observation) cmd_action_1 = CmdRunAction(command='ls') - event_stream.add_event(cmd_action_1, EventSource.AGENT) + state.history.append(cmd_action_1) cmd_observation_1 = CmdOutputObservation( command_id=1, command='ls', content='file1.txt\nfile2.txt' ) - cmd_observation_1._cause = cmd_action_1._id - event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT) + # cmd_observation_1._cause = cmd_action_1._id + state.history.append(cmd_observation_1) read_action_1 = FileReadAction(path='file1.txt') - event_stream.add_event(read_action_1, EventSource.AGENT) + state.history.append(read_action_1) read_observation_1 = FileReadObservation( content='File content', path='file1.txt' ) - read_observation_1._cause = read_action_1._id - event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT) + # read_observation_1._cause = read_action_1._id + state.history.append(read_observation_1) cmd_action_2 = CmdRunAction(command='ls') - event_stream.add_event(cmd_action_2, EventSource.AGENT) + state.history.append(cmd_action_2) cmd_observation_2 = CmdOutputObservation( command_id=2, command='ls', content='file1.txt\nfile2.txt' ) - cmd_observation_2._cause = cmd_action_2._id - event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT) + # cmd_observation_2._cause = cmd_action_2._id + state.history.append(cmd_observation_2) read_action_2 = FileReadAction(path='file1.txt') - event_stream.add_event(read_action_2, EventSource.AGENT) + state.history.append(read_action_2) read_observation_2 = FileReadObservation( content='File content', path='file1.txt' ) - read_observation_2._cause = read_action_2._id - event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT) + # read_observation_2._cause = read_action_2._id + state.history.append(read_observation_2) message_action = MessageAction(content='Come on', wait_for_response=False) - event_stream.add_event(message_action, EventSource.USER) + message_action._source = EventSource.USER + state.history.append(message_action) message_null_observation = NullObservation(content='') - event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT) + state.history.append(message_null_observation) cmd_action_3 = CmdRunAction(command='ls') - event_stream.add_event(cmd_action_3, EventSource.AGENT) + state.history.append(cmd_action_3) cmd_observation_3 = CmdOutputObservation( command_id=3, command='ls', content='file1.txt\nfile2.txt' ) - cmd_observation_3._cause = cmd_action_3._id - event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT) + # cmd_observation_3._cause = cmd_action_3._id + state.history.append(cmd_observation_3) read_action_3 = FileReadAction(path='file1.txt') - event_stream.add_event(read_action_3, EventSource.AGENT) + state.history.append(read_action_3) read_observation_3 = FileReadObservation( content='File content', path='file1.txt' ) - read_observation_3._cause = read_action_3._id - event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT) + # read_observation_3._cause = read_action_3._id + state.history.append(read_observation_3) with patch('logging.Logger.warning') as mock_warning: assert stuck_detector.is_stuck() is True mock_warning.assert_called_once_with('Action, Observation pattern detected') - def test_is_stuck_not_stuck( - self, stuck_detector: StuckDetector, event_stream: EventStream - ): + def test_is_stuck_not_stuck(self, stuck_detector: StuckDetector): + state = stuck_detector.state message_action = MessageAction(content='Done', wait_for_response=False) message_action._source = EventSource.USER hello_action = MessageAction(content='Hello', wait_for_response=False) - event_stream.add_event(hello_action, EventSource.USER) + state.history.append(hello_action) hello_observation = NullObservation(content='') - hello_observation._cause = hello_action._id - event_stream.add_event(hello_observation, EventSource.ENVIRONMENT) + # hello_observation._cause = hello_action._id + state.history.append(hello_observation) cmd_action_1 = CmdRunAction(command='ls') - event_stream.add_event(cmd_action_1, EventSource.AGENT) + state.history.append(cmd_action_1) cmd_observation_1 = CmdOutputObservation( command_id=cmd_action_1.id, command='ls', content='file1.txt\nfile2.txt' ) - cmd_observation_1._cause = cmd_action_1._id - event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT) + # cmd_observation_1._cause = cmd_action_1._id + state.history.append(cmd_observation_1) read_action_1 = FileReadAction(path='file1.txt') - event_stream.add_event(read_action_1, EventSource.AGENT) + state.history.append(read_action_1) read_observation_1 = FileReadObservation( content='File content', path='file1.txt' ) - read_observation_1._cause = read_action_1._id - event_stream.add_event(read_observation_1, EventSource.ENVIRONMENT) + # read_observation_1._cause = read_action_1._id + state.history.append(read_observation_1) cmd_action_2 = CmdRunAction(command='pwd') - event_stream.add_event(cmd_action_2, EventSource.AGENT) + state.history.append(cmd_action_2) cmd_observation_2 = CmdOutputObservation( command_id=2, command='pwd', content='/home/user' ) - cmd_observation_2._cause = cmd_action_2._id - event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT) + # cmd_observation_2._cause = cmd_action_2._id + state.history.append(cmd_observation_2) read_action_2 = FileReadAction(path='file2.txt') - event_stream.add_event(read_action_2, EventSource.AGENT) + state.history.append(read_action_2) read_observation_2 = FileReadObservation( content='Another file content', path='file2.txt' ) - read_observation_2._cause = read_action_2._id - event_stream.add_event(read_observation_2, EventSource.ENVIRONMENT) + # read_observation_2._cause = read_action_2._id + state.history.append(read_observation_2) message_null_observation = NullObservation(content='') - event_stream.add_event(message_action, EventSource.USER) - event_stream.add_event(message_null_observation, EventSource.ENVIRONMENT) + state.history.append(message_action) + state.history.append(message_null_observation) cmd_action_3 = CmdRunAction(command='pwd') - event_stream.add_event(cmd_action_3, EventSource.AGENT) + state.history.append(cmd_action_3) cmd_observation_3 = CmdOutputObservation( command_id=cmd_action_3.id, command='pwd', content='/home/user' ) - cmd_observation_3._cause = cmd_action_3._id - event_stream.add_event(cmd_observation_3, EventSource.ENVIRONMENT) + # cmd_observation_3._cause = cmd_action_3._id + state.history.append(cmd_observation_3) read_action_3 = FileReadAction(path='file2.txt') - event_stream.add_event(read_action_3, EventSource.AGENT) + state.history.append(read_action_3) read_observation_3 = FileReadObservation( content='Another file content', path='file2.txt' ) - read_observation_3._cause = read_action_3._id - event_stream.add_event(read_observation_3, EventSource.ENVIRONMENT) + # read_observation_3._cause = read_action_3._id + state.history.append(read_observation_3) assert stuck_detector.is_stuck() is False - def test_is_stuck_monologue(self, stuck_detector, event_stream): - # Add events to the event stream + def test_is_stuck_monologue(self, stuck_detector): + state = stuck_detector.state + # Add events to the history list directly message_action_1 = MessageAction(content='Hi there!') - event_stream.add_event(message_action_1, EventSource.USER) message_action_1._source = EventSource.USER - + state.history.append(message_action_1) message_action_2 = MessageAction(content='Hi there!') - event_stream.add_event(message_action_2, EventSource.AGENT) message_action_2._source = EventSource.AGENT - + state.history.append(message_action_2) message_action_3 = MessageAction(content='How are you?') - event_stream.add_event(message_action_3, EventSource.USER) message_action_3._source = EventSource.USER + state.history.append(message_action_3) cmd_kill_action = CmdRunAction( command='echo 42', thought="I'm not stuck, he's stuck" ) - event_stream.add_event(cmd_kill_action, EventSource.AGENT) + state.history.append(cmd_kill_action) message_action_4 = MessageAction(content="I'm doing well, thanks for asking.") - event_stream.add_event(message_action_4, EventSource.AGENT) message_action_4._source = EventSource.AGENT - + state.history.append(message_action_4) message_action_5 = MessageAction(content="I'm doing well, thanks for asking.") - event_stream.add_event(message_action_5, EventSource.AGENT) message_action_5._source = EventSource.AGENT - + state.history.append(message_action_5) message_action_6 = MessageAction(content="I'm doing well, thanks for asking.") - event_stream.add_event(message_action_6, EventSource.AGENT) message_action_6._source = EventSource.AGENT + state.history.append(message_action_6) assert stuck_detector.is_stuck() @@ -572,16 +556,15 @@ def test_is_stuck_monologue(self, stuck_detector, event_stream): command='storybook', exit_code=0, ) - cmd_output_observation._cause = cmd_kill_action._id - event_stream.add_event(cmd_output_observation, EventSource.ENVIRONMENT) + # cmd_output_observation._cause = cmd_kill_action._id + state.history.append(cmd_output_observation) message_action_7 = MessageAction(content="I'm doing well, thanks for asking.") - event_stream.add_event(message_action_7, EventSource.AGENT) message_action_7._source = EventSource.AGENT - + state.history.append(message_action_7) message_action_8 = MessageAction(content="I'm doing well, thanks for asking.") - event_stream.add_event(message_action_8, EventSource.AGENT) message_action_8._source = EventSource.AGENT + state.history.append(message_action_8) with patch('logging.Logger.warning'): assert not stuck_detector.is_stuck() @@ -596,7 +579,6 @@ def controller(self): ) controller.delegate = None controller.state = Mock() - controller.state.history = ShortTermHistory() return controller def test_is_stuck_delegate_stuck(self, controller: AgentController): diff --git a/tests/unit/test_micro_agents.py b/tests/unit/test_micro_agents.py index 70553d851125..8cff14fdd4f2 100644 --- a/tests/unit/test_micro_agents.py +++ b/tests/unit/test_micro_agents.py @@ -10,10 +10,8 @@ from openhands.controller.agent import Agent from openhands.controller.state.state import State from openhands.core.config import AgentConfig -from openhands.events import EventSource from openhands.events.action import MessageAction from openhands.events.stream import EventStream -from openhands.memory.history import ShortTermHistory from openhands.storage import get_file_store @@ -74,10 +72,10 @@ def test_coder_agent_with_summary(event_stream: EventStream, agent_configs: dict ) assert coder_agent is not None + # give it some history task = 'This is a dummy task' - history = ShortTermHistory() - history.set_event_stream(event_stream) - event_stream.add_event(MessageAction(content=task), EventSource.USER) + history = list() + history.append(MessageAction(content=task)) summary = 'This is a dummy summary about this repo' state = State(history=history, inputs={'summary': summary}) @@ -119,10 +117,10 @@ def test_coder_agent_without_summary(event_stream: EventStream, agent_configs: d ) assert coder_agent is not None + # give it some history task = 'This is a dummy task' - history = ShortTermHistory() - history.set_event_stream(event_stream) - event_stream.add_event(MessageAction(content=task), EventSource.USER) + history = list() + history.append(MessageAction(content=task)) # set state without codebase summary state = State(history=history) diff --git a/tests/unit/test_prompt_caching.py b/tests/unit/test_prompt_caching.py index 1bd182857355..caa08b0e55fd 100644 --- a/tests/unit/test_prompt_caching.py +++ b/tests/unit/test_prompt_caching.py @@ -1,14 +1,12 @@ -from unittest.mock import MagicMock, Mock, patch +from unittest.mock import Mock, patch import pytest from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent from openhands.core.config import AgentConfig, LLMConfig -from openhands.events import EventSource, EventStream from openhands.events.action import CmdRunAction, MessageAction from openhands.events.observation import CmdOutputObservation from openhands.llm.llm import LLM -from openhands.storage import get_file_store @pytest.fixture @@ -19,12 +17,6 @@ def mock_llm(): return llm -@pytest.fixture -def mock_event_stream(tmp_path): - file_store = get_file_store('local', str(tmp_path)) - return EventStream('test_session', file_store) - - @pytest.fixture(params=[False, True]) def codeact_agent(mock_llm, request): config = AgentConfig() @@ -57,17 +49,28 @@ def model_dump(self): return MockModelResponse(content) -def test_get_messages_with_reminder(codeact_agent, mock_event_stream): - # Add some events to the stream - mock_event_stream.add_event(MessageAction('Initial user message'), EventSource.USER) - mock_event_stream.add_event(MessageAction('Sure!'), EventSource.AGENT) - mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER) - mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT) - mock_event_stream.add_event(MessageAction('Laaaaaaaast!'), EventSource.USER) +def test_get_messages_with_reminder(codeact_agent: CodeActAgent): + # Add some events to history + history = list() + message_action_1 = MessageAction('Initial user message') + message_action_1._source = 'user' + history.append(message_action_1) + message_action_2 = MessageAction('Sure!') + message_action_2._source = 'assistant' + history.append(message_action_2) + message_action_3 = MessageAction('Hello, agent!') + message_action_3._source = 'user' + history.append(message_action_3) + message_action_4 = MessageAction('Hello, user!') + message_action_4._source = 'assistant' + history.append(message_action_4) + message_action_5 = MessageAction('Laaaaaaaast!') + message_action_5._source = 'user' + history.append(message_action_5) codeact_agent.reset() messages = codeact_agent._get_messages( - Mock(history=mock_event_stream, max_iterations=5, iteration=0) + Mock(history=history, max_iterations=5, iteration=0) ) assert ( @@ -102,19 +105,20 @@ def test_get_messages_with_reminder(codeact_agent, mock_event_stream): ) -def test_get_messages_prompt_caching(codeact_agent, mock_event_stream): +def test_get_messages_prompt_caching(codeact_agent: CodeActAgent): + history = list() # Add multiple user and agent messages for i in range(15): - mock_event_stream.add_event( - MessageAction(f'User message {i}'), EventSource.USER - ) - mock_event_stream.add_event( - MessageAction(f'Agent message {i}'), EventSource.AGENT - ) + message_action_user = MessageAction(f'User message {i}') + message_action_user._source = 'user' + history.append(message_action_user) + message_action_agent = MessageAction(f'Agent message {i}') + message_action_agent._source = 'assistant' + history.append(message_action_agent) codeact_agent.reset() messages = codeact_agent._get_messages( - Mock(history=mock_event_stream, max_iterations=10, iteration=5) + Mock(history=history, max_iterations=10, iteration=5) ) # Check that only the last two user messages have cache_prompt=True @@ -136,18 +140,23 @@ def test_get_messages_prompt_caching(codeact_agent, mock_event_stream): assert cached_user_messages[3].content[0].text.startswith('User message 1') -def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream): +def test_get_messages_with_cmd_action(codeact_agent: CodeActAgent): if codeact_agent.config.function_calling: pytest.skip('Skipping this test for function calling') + history = list() + # Add a mix of actions and observations message_action_1 = MessageAction( "Let's list the contents of the current directory." ) - mock_event_stream.add_event(message_action_1, EventSource.USER) + message_action_1._source = 'user' + history.append(message_action_1) cmd_action_1 = CmdRunAction('ls -l', thought='List files in current directory') - mock_event_stream.add_event(cmd_action_1, EventSource.AGENT) + cmd_action_1._source = 'agent' + cmd_action_1._id = 'cmd_1' + history.append(cmd_action_1) cmd_observation_1 = CmdOutputObservation( content='total 0\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file1.txt\n-rw-r--r-- 1 user group 0 Jan 1 00:00 file2.txt', @@ -155,13 +164,17 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream): command='ls -l', exit_code=0, ) - mock_event_stream.add_event(cmd_observation_1, EventSource.ENVIRONMENT) + cmd_observation_1._source = 'user' + history.append(cmd_observation_1) message_action_2 = MessageAction("Now, let's create a new directory.") - mock_event_stream.add_event(message_action_2, EventSource.AGENT) + message_action_2._source = 'agent' + history.append(message_action_2) cmd_action_2 = CmdRunAction('mkdir new_directory', thought='Create a new directory') - mock_event_stream.add_event(cmd_action_2, EventSource.AGENT) + cmd_action_2._source = 'agent' + cmd_action_2._id = 'cmd_2' + history.append(cmd_action_2) cmd_observation_2 = CmdOutputObservation( content='', @@ -169,11 +182,12 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream): command='mkdir new_directory', exit_code=0, ) - mock_event_stream.add_event(cmd_observation_2, EventSource.ENVIRONMENT) + cmd_observation_2._source = 'user' + history.append(cmd_observation_2) codeact_agent.reset() messages = codeact_agent._get_messages( - Mock(history=mock_event_stream, max_iterations=5, iteration=0) + Mock(history=history, max_iterations=5, iteration=0) ) # Assert the presence of key elements in the messages @@ -218,19 +232,17 @@ def test_get_messages_with_cmd_action(codeact_agent, mock_event_stream): assert 'ENVIRONMENT REMINDER: You have 5 turns' in messages[5].content[1].text -def test_prompt_caching_headers(codeact_agent, mock_event_stream): +def test_prompt_caching_headers(codeact_agent: CodeActAgent): + history = list() if codeact_agent.config.function_calling: pytest.skip('Skipping this test for function calling') # Setup - mock_event_stream.add_event(MessageAction('Hello, agent!'), EventSource.USER) - mock_event_stream.add_event(MessageAction('Hello, user!'), EventSource.AGENT) - - mock_short_term_history = MagicMock() - mock_short_term_history.get_last_user_message.return_value = 'Hello, agent!' + history.append(MessageAction('Hello, agent!')) + history.append(MessageAction('Hello, user!')) mock_state = Mock() - mock_state.history = mock_short_term_history + mock_state.history = history mock_state.max_iterations = 5 mock_state.iteration = 0