diff --git a/config.template.toml b/config.template.toml
index 744dfc7953a4..8e9ed88ba818 100644
--- a/config.template.toml
+++ b/config.template.toml
@@ -238,6 +238,9 @@ codeact_enable_jupyter = true
 # length limit
 enable_history_truncation = true

+# Whether to enable plan routing to reasoning models
+#enable_plan_routing = false
+
 [agent.RepoExplorerAgent]
 # Example: use a cheaper model for RepoExplorerAgent to reduce cost, especially
 # useful when an agent doesn't demand high quality but uses a lot of tokens
@@ -288,6 +291,26 @@ llm_config = 'gpt3'
 # The security analyzer to use (For Headless / CLI only - In Web this is overridden by Session Init)
 #security_analyzer = ""

+################################ Model Routing ###############################
+# Configuration for model routing features
+##############################################################################
+[model_routing]
+
+# The reasoning model to use for plan generation
+reasoning_llm_config_name = 'reasoning_model'
+# The judge model used to decide when to route to the reasoning model
+judge_llm_config_name = 'judge_model'
+
+[llm.judge_model]
+model = "gpt-4o"
+api_key = ""
+for_routing = true
+
+[llm.reasoning_model]
+model = "o1"
+api_key = ""
+for_routing = true
+
 #################################### Eval ####################################
 # Configuration for the evaluation, please refer to the specific evaluation
 # plugin for the available options
diff --git a/evaluation/benchmarks/swe_bench/run_infer.py b/evaluation/benchmarks/swe_bench/run_infer.py
index 89fe618a6c34..1de7a1fd1661 100644
--- a/evaluation/benchmarks/swe_bench/run_infer.py
+++ b/evaluation/benchmarks/swe_bench/run_infer.py
@@ -1,4 +1,5 @@
 import asyncio
+import copy
 import json
 import os
 import tempfile
@@ -33,6 +34,7 @@
     AppConfig,
     get_llm_config_arg,
     get_parser,
+    load_from_toml,
 )
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.main import create_runtime, run_controller
@@ -148,14 +150,19 @@ def get_config(
             metadata.llm_config, metadata.eval_output_dir, instance['instance_id']
         )
     )
+    config_copy = copy.deepcopy(config)
+    load_from_toml(config_copy)

     agent_config = AgentConfig(
         codeact_enable_jupyter=False,
         codeact_enable_browsing=RUN_WITH_BROWSING,
         codeact_enable_llm_editor=False,
         condenser=metadata.condenser_config,
         enable_prompt_extensions=False,
+        enable_plan_routing=config_copy.get_agent_config().enable_plan_routing,
     )
     config.set_agent_config(agent_config)
+    config.routing_llms = config_copy.routing_llms
+    config.model_routing = config_copy.model_routing

     return config
diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py
index b636e40cb9f6..ddc8464e0060 100644
--- a/openhands/agenthub/codeact_agent/codeact_agent.py
+++ b/openhands/agenthub/codeact_agent/codeact_agent.py
@@ -6,7 +6,7 @@
 import openhands.agenthub.codeact_agent.function_calling as codeact_function_calling
 from openhands.controller.agent import Agent
 from openhands.controller.state.state import State
-from openhands.core.config import AgentConfig
+from openhands.core.config import AgentConfig, ModelRoutingConfig
 from openhands.core.logger import openhands_logger as logger
 from openhands.core.message import Message, TextContent
 from openhands.core.message_utils import (
@@ -19,12 +19,14 @@
 )
 from openhands.llm.llm import LLM
 from openhands.memory.condenser import Condenser
+from openhands.router import BaseRouter, LLMBasedPlanRouter
 from openhands.runtime.plugins import (
     AgentSkillsRequirement,
     JupyterRequirement,
     PluginRequirement,
 )
 from openhands.utils.prompt import PromptManager
+from openhands.utils.trajectory import format_trajectory

 class CodeActAgent(Agent):
@@ -60,11 +62,14 @@ def __init__(
         self,
         llm: LLM,
         config: AgentConfig,
+        model_routing_config: ModelRoutingConfig | None = None,
+        routing_llms: dict[str, LLM] | None = None,
     ) -> None:
         """Initializes a new instance of the CodeActAgent class.

         Parameters:
         - llm (LLM): The llm to be used by this agent
+        - routing_llms (dict[str, LLM]): The LLMs available for the router to select from
         """
         super().__init__(llm, config)
         self.pending_actions: deque[Action] = deque()
@@ -93,6 +98,18 @@ def __init__(
         self.condenser = Condenser.from_config(self.config.condenser)
         logger.debug(f'Using condenser: {self.condenser}')

+        self.router: BaseRouter | None = None
+
+        if config.enable_plan_routing:
+            assert model_routing_config is not None and routing_llms is not None
+            self.router = LLMBasedPlanRouter(
+                llm=self.llm,
+                routing_llms=routing_llms or dict(),
+                model_routing_config=model_routing_config,
+            )
+
+        self.active_llm: LLM | None = None  # The LLM chosen by the router
+
     def reset(self) -> None:
         """Resets the CodeAct Agent."""
         super().reset()
@@ -121,13 +138,30 @@ def step(self, state: State) -> Action:
         if latest_user_message and latest_user_message.content.strip() == '/exit':
             return AgentFinishAction()

+        params: dict = {}
+
+        # check if model routing is needed
+        if self.router:
+            messages = self._get_messages(state)
+            formatted_trajectory = format_trajectory(messages)
+            self.active_llm = self.router.should_route_to(formatted_trajectory)
+
+            if self.active_llm != self.llm:
+                logger.warning(f'🧭 Routing to custom model: {self.active_llm}')
+        else:
+            self.active_llm = self.llm
+
+        params['tools'] = self.tools
+        if not self.active_llm.is_function_calling_active():
+            params['mock_function_calling'] = True
+
         # prepare what we want to send to the LLM
+        # NOTE: this must happen only after self.active_llm has been set above
         messages = self._get_messages(state)
-        params: dict = {
-            'messages': self.llm.format_messages_for_llm(messages),
-        }
-        params['tools'] = self.tools
-        response = self.llm.completion(**params)
+        params['messages'] = self.active_llm.format_messages_for_llm(messages)
+
+        response = self.active_llm.completion(**params)
+
         actions = codeact_function_calling.response_to_actions(response)
         for action in actions:
             self.pending_actions.append(action)
@@ -168,6 +202,8 @@ def _get_messages(self, state: State) -> list[Message]:
         if not self.prompt_manager:
             raise Exception('Prompt Manager not instantiated.')

+        active_llm_ = self.active_llm or self.llm
+
         messages: list[Message] = self._initial_messages()

         # Condense the events from the state.
@@ -175,14 +211,14 @@ def _get_messages(self, state: State) -> list[Message]:

         messages += events_to_messages(
             events,
-            max_message_chars=self.llm.config.max_message_chars,
-            vision_is_active=self.llm.vision_is_active(),
+            max_message_chars=active_llm_.config.max_message_chars,
+            vision_is_active=active_llm_.vision_is_active(),
             enable_som_visual_browsing=self.config.enable_som_visual_browsing,
         )

         messages = self._enhance_messages(messages)

-        if self.llm.is_caching_prompt_active():
+        if active_llm_.is_caching_prompt_active():
             apply_prompt_caching(messages)

         return messages
@@ -191,13 +227,15 @@ def _initial_messages(self) -> list[Message]:
         """Creates the initial messages (including the system prompt) for the LLM conversation."""
         assert self.prompt_manager, 'Prompt Manager not instantiated.'

+        active_llm_ = self.active_llm or self.llm
+
         return [
             Message(
                 role='system',
                 content=[
                     TextContent(
                         text=self.prompt_manager.get_system_message(),
-                        cache_prompt=self.llm.is_caching_prompt_active(),
+                        cache_prompt=active_llm_.is_caching_prompt_active(),
                     )
                 ],
             )
diff --git a/openhands/agenthub/dummy_agent/agent.py b/openhands/agenthub/dummy_agent/agent.py
index b420a3d5d8ae..b663b2618ac2 100644
--- a/openhands/agenthub/dummy_agent/agent.py
+++ b/openhands/agenthub/dummy_agent/agent.py
@@ -45,7 +45,7 @@ class DummyAgent(Agent):
     without making any LLM calls.
     """

-    def __init__(self, llm: LLM, config: AgentConfig):
+    def __init__(self, llm: LLM, config: AgentConfig, **kwargs):
         super().__init__(llm, config)
         self.steps: list[ActionObs] = [
             {
diff --git a/openhands/controller/agent.py b/openhands/controller/agent.py
index 43a55d935249..8577b179b3d0 100644
--- a/openhands/controller/agent.py
+++ b/openhands/controller/agent.py
@@ -32,6 +32,7 @@ def __init__(
         self,
         llm: LLM,
         config: 'AgentConfig',
+        **kwargs,
     ):
         self.llm = llm
         self.config = config
diff --git a/openhands/core/config/__init__.py b/openhands/core/config/__init__.py
index d653f3e70ac4..4a02864921d4 100644
--- a/openhands/core/config/__init__.py
+++ b/openhands/core/config/__init__.py
@@ -6,6 +6,7 @@
     get_field_info,
 )
 from openhands.core.config.llm_config import LLMConfig
+from openhands.core.config.model_routing_config import ModelRoutingConfig
 from openhands.core.config.sandbox_config import SandboxConfig
 from openhands.core.config.security_config import SecurityConfig
 from openhands.core.config.utils import (
@@ -28,6 +29,7 @@
     'LLMConfig',
     'SandboxConfig',
     'SecurityConfig',
+    'ModelRoutingConfig',
     'load_app_config',
     'load_from_env',
     'load_from_toml',
diff --git a/openhands/core/config/agent_config.py b/openhands/core/config/agent_config.py
index 01a76575c5cd..222931fb96f9 100644
--- a/openhands/core/config/agent_config.py
+++ b/openhands/core/config/agent_config.py
@@ -33,3 +33,4 @@ class AgentConfig(BaseModel):
     disabled_microagents: list[str] | None = Field(default=None)
     condenser: CondenserConfig = Field(default_factory=NoOpCondenserConfig)
     enable_history_truncation: bool = Field(default=True)
+    enable_plan_routing: bool = Field(default=False)
diff --git a/openhands/core/config/app_config.py b/openhands/core/config/app_config.py
index 5965f06480c8..a5c548b8b196 100644
--- a/openhands/core/config/app_config.py
+++ b/openhands/core/config/app_config.py
@@ -10,6 +10,7 @@
     model_defaults_to_dict,
 )
 from openhands.core.config.llm_config import LLMConfig
+from openhands.core.config.model_routing_config import ModelRoutingConfig
 from openhands.core.config.sandbox_config import SandboxConfig
 from openhands.core.config.security_config import SecurityConfig

@@ -20,6 +21,7 @@ class AppConfig(BaseModel):
     Attributes:
         llms: Dictionary mapping LLM names to their configurations.
             The default configuration is stored under the 'llm' key.
+        routing_llms: Dictionary mapping routing LLM names to their configurations.
         agents: Dictionary mapping agent names to their configurations.
             The default configuration is stored under the 'agent' key.
         default_agent: Name of the default agent to use.
@@ -48,10 +50,12 @@ class AppConfig(BaseModel):
     """

     llms: dict[str, LLMConfig] = Field(default_factory=dict)
+    routing_llms: dict[str, LLMConfig] = Field(default_factory=dict)
     agents: dict = Field(default_factory=dict)
     default_agent: str = Field(default=OH_DEFAULT_AGENT)
     sandbox: SandboxConfig = Field(default_factory=SandboxConfig)
     security: SecurityConfig = Field(default_factory=SecurityConfig)
+    model_routing: ModelRoutingConfig = Field(default_factory=ModelRoutingConfig)
     runtime: str = Field(default='docker')
     file_store: str = Field(default='local')
     file_store_path: str = Field(default='/tmp/openhands_file_store')
@@ -95,7 +99,10 @@ def get_llm_config(self, name='llm') -> LLMConfig:
         return self.llms['llm']

     def set_llm_config(self, value: LLMConfig, name='llm') -> None:
-        self.llms[name] = value
+        if value.for_routing:
+            self.routing_llms[name] = value
+        else:
+            self.llms[name] = value

     def get_agent_config(self, name='agent') -> AgentConfig:
         """'agent' is the name for default config (for backward compatibility prior to 0.8)."""
diff --git a/openhands/core/config/llm_config.py b/openhands/core/config/llm_config.py
index 5497d7125823..014c8c6fde13 100644
--- a/openhands/core/config/llm_config.py
+++ b/openhands/core/config/llm_config.py
@@ -87,6 +87,7 @@ class LLMConfig(BaseModel):
     custom_tokenizer: str | None = Field(default=None)
     native_tool_calling: bool | None = Field(default=None)
     reasoning_effort: str | None = Field(default='high')
+    for_routing: bool = Field(default=False)

     model_config = {'extra': 'forbid'}
diff --git a/openhands/core/config/model_routing_config.py b/openhands/core/config/model_routing_config.py
new file mode 100644
index 000000000000..349389f3b88a
--- /dev/null
+++ b/openhands/core/config/model_routing_config.py
@@ -0,0 +1,6 @@
+from pydantic import BaseModel, Field
+
+
+class ModelRoutingConfig(BaseModel):
+    reasoning_llm_config_name: str = Field(default='reasoning_model')
+    judge_llm_config_name: str = Field(default='judge_model')
diff --git a/openhands/core/config/utils.py b/openhands/core/config/utils.py
index f057eb6ad2fe..eb7473c66c98 100644
--- a/openhands/core/config/utils.py
+++ b/openhands/core/config/utils.py
@@ -15,11 +15,9 @@
 from openhands.core import logger
 from openhands.core.config.agent_config import AgentConfig
 from openhands.core.config.app_config import AppConfig
-from openhands.core.config.config_utils import (
-    OH_DEFAULT_AGENT,
-    OH_MAX_ITERATIONS,
-)
+from openhands.core.config.config_utils import OH_DEFAULT_AGENT, OH_MAX_ITERATIONS
 from openhands.core.config.llm_config import LLMConfig
+from openhands.core.config.model_routing_config import ModelRoutingConfig
 from openhands.core.config.sandbox_config import SandboxConfig
 from openhands.core.config.security_config import SecurityConfig
 from openhands.storage import get_file_store
@@ -172,7 +170,6 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml') -> None:
             logger.openhands_logger.debug(
                 'Attempt to load default LLM config from config toml'
             )
-
             # Extract generic LLM fields, which are not nested LLM configs
             generic_llm_fields = {}
             for k, v in value.items():
@@ -203,13 +200,18 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml') -> None:
                     custom_llm_config = LLMConfig(**merged_llm_dict)
                     cfg.set_llm_config(custom_llm_config, nested_key)
-
         elif key is not None and key.lower() == 'security':
             logger.openhands_logger.debug(
                 'Attempt to load security config from config toml'
             )
             security_config = SecurityConfig(**value)
             cfg.security = security_config
+        elif key is not None and key.lower() == 'model_routing':
+            logger.openhands_logger.debug(
+                'Attempt to load model routing config from config toml'
+            )
+            model_routing_config = ModelRoutingConfig(**value)
+            cfg.model_routing = model_routing_config
         elif not key.startswith('sandbox') and key.lower() != 'core':
             logger.openhands_logger.warning(
                 f'Unknown key in {toml_file}: "{key}"'
diff --git a/openhands/core/setup.py b/openhands/core/setup.py
index 82bdaf0c204b..1adb36f95983 100644
--- a/openhands/core/setup.py
+++ b/openhands/core/setup.py
@@ -6,9 +6,7 @@
 from openhands.controller import AgentController
 from openhands.controller.agent import Agent
 from openhands.controller.state.state import State
-from openhands.core.config import (
-    AppConfig,
-)
+from openhands.core.config import AppConfig
 from openhands.core.logger import openhands_logger as logger
 from openhands.events import EventStream
 from openhands.events.event import Event
@@ -62,9 +60,18 @@ def create_agent(runtime: Runtime, config: AppConfig) -> Agent:
     agent_cls: Type[Agent] = Agent.get_cls(config.default_agent)
     agent_config = config.get_agent_config(config.default_agent)
     llm_config = config.get_llm_config_from_agent(config.default_agent)
+    routing_llms_config = config.routing_llms
+    model_routing_config = config.model_routing
+    routing_llms = {}
+    for config_name, routing_llm_config in routing_llms_config.items():
+        routing_llms[config_name] = LLM(
+            config=routing_llm_config,
+        )
     agent = agent_cls(
         llm=LLM(config=llm_config),
         config=agent_config,
+        model_routing_config=model_routing_config,
+        routing_llms=routing_llms,
     )
     if agent.prompt_manager:
         microagents = runtime.get_microagents_from_selected_repo(None)
diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index b40f11ca8396..e9aad1fe29b0 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -59,6 +59,7 @@
     'gpt-4o-mini',
     'gpt-4o',
     'o1-2024-12-17',
+    'o1',
     'o3-mini-2025-01-31',
     'o3-mini',
 ]
diff --git a/openhands/router/__init__.py b/openhands/router/__init__.py
new file mode 100644
index 000000000000..32058b2b386f
--- /dev/null
+++ b/openhands/router/__init__.py
@@ -0,0 +1,4 @@
+from openhands.router.base import BaseRouter
+from openhands.router.plan.llm_based import LLMBasedPlanRouter
+
+__all__ = ['BaseRouter', 'LLMBasedPlanRouter']
diff --git a/openhands/router/base.py b/openhands/router/base.py
new file mode 100644
index 000000000000..111cb23f6814
--- /dev/null
+++ b/openhands/router/base.py
@@ -0,0 +1,20 @@
+from abc import ABC, abstractmethod
+
+from openhands.core.config.model_routing_config import ModelRoutingConfig
+from openhands.llm.llm import LLM
+
+
+class BaseRouter(ABC):
+    def __init__(
+        self,
+        llm: LLM,
+        routing_llms: dict[str, LLM],
+        model_routing_config: ModelRoutingConfig,
+    ):
+        self.llm = llm
+        self.routing_llms = routing_llms
+        self.model_routing_config = model_routing_config
+
+    @abstractmethod
+    def should_route_to(self, prompt: str) -> LLM:
+        pass
diff --git a/openhands/router/plan/__init__.py b/openhands/router/plan/__init__.py
new file mode 100644
index 000000000000..323c4dddf224
--- /dev/null
+++ b/openhands/router/plan/__init__.py
@@ -0,0 +1,3 @@
+from openhands.router.plan.llm_based import LLMBasedPlanRouter
+
+__all__ = ['LLMBasedPlanRouter']
diff --git a/openhands/router/plan/llm_based.py b/openhands/router/plan/llm_based.py
new file mode 100644
index 000000000000..aaddbc1f09f0
--- /dev/null
+++ b/openhands/router/plan/llm_based.py
@@ -0,0 +1,80 @@
+from openhands.core.config import ModelRoutingConfig
+from openhands.llm.llm import LLM
+from openhands.router.base import BaseRouter
+from openhands.router.plan.prompts import (
+    TRAJECTORY_JUDGE_REASONING_SYSTEM_PROMPT,
+    TRAJECTORY_JUDGE_REASONING_USER_PROMPT,
+)
+
+
+class LLMBasedPlanRouter(BaseRouter):
+    """
+    Router that routes to the reasoning model when an LLM judge deems the prompt complex enough to require a step-by-step plan.
+    """
+
+    NUM_TURNS_GAP = 1
+
+    def __init__(
+        self,
+        llm: LLM,
+        routing_llms: dict[str, LLM],
+        model_routing_config: ModelRoutingConfig,
+    ):
+        super().__init__(llm, routing_llms, model_routing_config)
+
+        self._validate_model_routing_config(model_routing_config, routing_llms)
+
+        self.judge_llm = routing_llms[model_routing_config.judge_llm_config_name]
+        self.reasoning_llm = routing_llms[
+            model_routing_config.reasoning_llm_config_name
+        ]
+        self.routed_turns: list[int] = []
+        self.cur_turn_num = 0
+
+    def should_route_to(self, prompt: str) -> LLM:
+        self.cur_turn_num += 1
+
+        if self.cur_turn_num - max(self.routed_turns, default=0) < self.NUM_TURNS_GAP:
+            return self.llm
+
+        messages = [
+            {
+                'role': 'system',
+                'content': TRAJECTORY_JUDGE_REASONING_SYSTEM_PROMPT,
+            },
+            {
+                'role': 'user',
+                'content': TRAJECTORY_JUDGE_REASONING_USER_PROMPT.format(
+                    interaction_log=prompt
+                ),
+            },
+        ]
+
+        response = self.judge_llm.completion(
+            messages=messages,
+        )
+        should_route = int(response['choices'][0]['message']['content'].strip()) == 1
+
+        if should_route:
+            self.routed_turns.append(self.cur_turn_num)
+            return self.reasoning_llm
+        return self.llm
+
+    def _validate_model_routing_config(
+        self, model_routing_config: ModelRoutingConfig, routing_llms: dict[str, LLM]
+    ):
+        if (
+            not model_routing_config.judge_llm_config_name
+            or not model_routing_config.reasoning_llm_config_name
+        ):
+            raise ValueError(
+                'Judge LLM and Reasoning LLM config names must be provided'
+            )
+        if model_routing_config.judge_llm_config_name not in routing_llms:
+            raise ValueError(
+                f'Judge LLM config {model_routing_config.judge_llm_config_name} not found'
+            )
+        if model_routing_config.reasoning_llm_config_name not in routing_llms:
+            raise ValueError(
+                f'Reasoning LLM config {model_routing_config.reasoning_llm_config_name} not found'
+            )
diff --git a/openhands/router/plan/prompts.py b/openhands/router/plan/prompts.py
new file mode 100644
index 000000000000..90ecc336e8f5
--- /dev/null
+++ b/openhands/router/plan/prompts.py
@@ -0,0 +1,64 @@
+############################################
+######## PLAN GENERATION PROMPTS ########
+############################################
+
+USER_MESSAGE_PLANNING_ANALYZE_PROMPT = """Analyze this prompt to see if it requires detailed plan generation.
+
+Some example scenarios that require generating a step-by-step plan:
+
+1. Structured Rule-Based Tasks with Well-Defined Constraints
+   * Example: In a synthetic task, adhering to a sequence like loosening nuts before removing wheels is critical
+
+2. Tasks Requiring Step-by-Step Reasoning to plan a structured chain of actions
+   * Example: In a synthetic task, objects must be manipulated in a sequence to achieve a configuration
+
+3. Scenarios with Limited Resources or Strict Constraints
+   * Tasks that require resource-sensitive planning, such as minimizing actions or handling tools efficiently
+   * Example: In a synthetic task, we need to efficiently coordinate robot actions across rooms and minimize energy consumption costs
+
+4. Generalization in Familiar Symbolic Representations
+   * Tasks where the rules remain consistent, and the specific instances change.
+   * Example: When we need to adapt strategies to new but structured instances of tasks.
+
+5. Requests Requiring Self-Evaluation
+   * Self-evaluation mechanism enables the identification and correction of errors mid-process.
+   * Example: When we need to reevaluate actions and adjust plans or actions based on constraints.
+
+In the context of software engineering, below are some scenarios where plan generation is required:
+
+1. Dependency and Workflow Management
+   * Automating and optimizing CI/CD pipelines, build processes, and package dependency resolution.
+   * Example: Resolving complex dependency graphs or sequencing multi-step deployments.
+2. Code Refactoring and Debugging
+   * Planning systematic changes for refactoring large codebases and isolating root causes during debugging.
+   * Example: Refactoring monolithic code into modular components while preserving functionality.
+3. Infrastructure and Resource Planning
+   * Designing and optimizing Infrastructure as Code (IaC) changes and dynamic resource allocation.
+   * Example: Planning cloud resource provisioning while adhering to dependency constraints.
+4. High-level Requirements to Low-level Implementation Mapping
+   * Translating high-level requirements into detailed implementation steps and ensuring consistency.
+
+=== BEGIN USER MESSAGE ===
+{message}
+=== END USER MESSAGE ===
+
+Only respond with 0 for no plan generation required or 1 for plan generation required.
+"""
+
+############################################
+######## REASONING JUDGE PROMPTS ########
+############################################
+
+TRAJECTORY_JUDGE_REASONING_SYSTEM_PROMPT = """You are an expert judge evaluating AI assistant interactions. Your task is to determine if:
+- the AI assistant is struggling with some issues when performing the task and needs help from a human expert to guide it
+- the next step is complex and needs careful reasoning to solve, e.g. identifying a hard-to-find bug in a codebase
+
+Respond only with 0 if the AI assistant is not struggling or the task is not complex. Otherwise, respond with 1."""
+
+TRAJECTORY_JUDGE_REASONING_USER_PROMPT = """Please evaluate the following interaction (or the most recent part of the interaction) between an AI assistant and a user:
+
+=== INTERACTION LOG ===
+{interaction_log}
+=== END INTERACTION ===
+
+Based on the above interaction, do we need to provide additional guidance to the AI assistant, or is the task complex, requiring careful reasoning to solve? Respond with 0 if no guidance is needed or the task is not complex. Otherwise, respond with 1."""
diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py
index d7807fc94740..9b5c128d4379 100644
--- a/openhands/server/session/session.py
+++ b/openhands/server/session/session.py
@@ -110,6 +110,11 @@ async def initialize_agent(
         # TODO: override other LLM config & agent config groups (#2075)

         llm = self._create_llm(agent_cls)
+        routing_llms = {}
+        for config_name, routing_llm_config in self.config.routing_llms.items():
+            routing_llms[config_name] = LLM(
+                config=routing_llm_config,
+            )

         agent_config = self.config.get_agent_config(agent_cls)

         if settings.enable_default_condenser:
@@ -119,7 +124,12 @@ async def initialize_agent(
             logger.info(f'Enabling default condenser: {default_condenser_config}')
             agent_config.condenser = default_condenser_config

-        agent = Agent.get_cls(agent_cls)(llm, agent_config)
+        agent = Agent.get_cls(agent_cls)(
+            llm=llm,
+            config=agent_config,
+            model_routing_config=self.config.model_routing,
+            routing_llms=routing_llms,
+        )

         github_token = None
         selected_repository = None
diff --git a/openhands/utils/trajectory.py b/openhands/utils/trajectory.py
new file mode 100644
index 000000000000..bba09729f120
--- /dev/null
+++ b/openhands/utils/trajectory.py
@@ -0,0 +1,101 @@
+"""
+Utility functions for processing and formatting trajectories.
+Original code from: https://github.com/SWE-Gym/SWE-Gym/blob/main/scripts/openhands-verifier/aggregate_stats_pass_at_n.ipynb
+"""
+
+import json
+
+from litellm import ChatCompletionMessageToolCall
+
+from openhands.core.message import ImageContent, Message, TextContent
+
+
+def convert_content(content: list[TextContent | ImageContent]) -> str:
+    """Converts a list of message content to a single string."""
+    return '\n'.join(item.text for item in content if item.type == 'text')
+
+
+def convert_tool_call_to_string(tool_call: ChatCompletionMessageToolCall) -> str:
+    """Converts tool call arguments to a string representation."""
+    try:
+        args = json.loads(tool_call.function.arguments)
+    except json.JSONDecodeError as e:
+        raise ValueError(
+            f'Failed to parse arguments as JSON. Arguments: {tool_call.function.arguments}'
+        ) from e
+
+    tool_call_str = f'<function={tool_call.function.name}>\n'
+    for param_name, param_value in args.items():
+        is_multiline_value = isinstance(param_value, str) and '\n' in param_value
+        param_value = '\n' + param_value + '\n' if is_multiline_value else param_value
+        tool_call_str += f'<parameter={param_name}>{param_value}</parameter>\n'
+    tool_call_str += '</function>'
+    return tool_call_str
+
+
+def merge_user_messages(traj: list[Message]) -> list[Message]:
+    """Merges consecutive user messages into a single message."""
+    merged_traj = []
+    current_messages = []
+
+    for message in traj:
+        if message.role == 'user':
+            current_messages.append(message)
+        else:
+            if current_messages:
+                merged_content = '\n'.join(
+                    convert_content(msg.content) for msg in current_messages
+                )
+                merged_traj.append(
+                    Message(role='user', content=[TextContent(text=merged_content)])
+                )
+                current_messages = []
+            merged_traj.append(message)
+
+    if current_messages:
+        merged_content = '\n'.join(
+            convert_content(msg.content) for msg in current_messages
+        )
+        merged_traj.append(
+            Message(role='user', content=[TextContent(text=merged_content)])
+        )
+
+    return merged_traj
+
+
+def format_trajectory(traj: list[Message]) -> str:
+    """Formats the message trajectory into a human-readable string."""
+    output = ''
+    system_message = None
+
+    if traj:
+        # Handle system message if present
+        if traj[0].role == 'system':
+            system_message = traj[0]
+            traj = traj[1:]
+            content = convert_content(system_message.content)
+            output += "*** System Message that describes the assistant's behavior ***\n"
+            output += f'{content}\n'
+
+    # Merge consecutive user messages
+    merged_traj = merge_user_messages(traj)
+
+    # Process the merged trajectory
+    for i, message in enumerate(merged_traj):
+        role = message.role
+        content = convert_content(message.content)
+        turn_id = i // 2 + 1
+        output += '-' * 100 + '\n'
+        output += f'*** Turn {turn_id} - {role.upper() if role != "tool" else "TOOL EXECUTION RESULT"} ***\n'
+
+        if role == 'user' or role == 'tool' or role == 'assistant':
+            output += f'{content}\n'
+            if role == 'assistant' and message.tool_calls:
+                for toolcall_id, tool_call in enumerate(message.tool_calls):
+                    output += f'### Tool Call {toolcall_id}\n'
+                    output += f'{convert_tool_call_to_string(tool_call)}\n'
+        else:
+            raise ValueError(f'Unexpected role: {role}')
+
+    output += '-' * 100 + '\n'
+    return output
diff --git a/tests/unit/test_trajectory_formatter.py b/tests/unit/test_trajectory_formatter.py
new file mode 100644
index 000000000000..9dcbaa6cd08c
--- /dev/null
+++ b/tests/unit/test_trajectory_formatter.py
@@ -0,0 +1,117 @@
+import pytest
+from litellm import ChatCompletionMessageToolCall
+
+from openhands.core.message import Message, TextContent
+from openhands.utils.trajectory import format_trajectory
+
+
+# Helper function to create a mock ChatCompletionMessageToolCall
+def create_mock_tool_call(name: str, arguments: str):
+    return ChatCompletionMessageToolCall(
+        function={'name': name, 'arguments': arguments}
+    )
+
+
+def test_empty_trajectory():
+    traj = []
+    assert (
+        format_trajectory(traj)
+        == """----------------------------------------------------------------------------------------------------
+"""
+    )
+
+
+def test_system_message_only():
+    traj = [
+        Message(
+            role='system', content=[TextContent(text='System behavior description.')]
+        )
+    ]
+    expected_output = """*** System Message that describes the assistant's behavior ***
+System behavior description.
+----------------------------------------------------------------------------------------------------
+"""
+    assert format_trajectory(traj) == expected_output
+
+
+def test_user_messages_only():
+    traj = [
+        Message(
+            role='user',
+            content=[TextContent(text='Hello.'), TextContent(text='How are you?')],
+        )
+    ]
+    expected_output = """----------------------------------------------------------------------------------------------------
+*** Turn 1 - USER ***
+Hello.
+How are you?
+----------------------------------------------------------------------------------------------------
+"""
+    assert format_trajectory(traj) == expected_output
+
+
+def test_mixed_messages():
+    traj = [
+        Message(
+            role='system', content=[TextContent(text='System behavior description.')]
+        ),
+        Message(role='user', content=[TextContent(text='Hello.')]),
+        Message(role='assistant', content=[TextContent(text='Hi there!')]),
+        Message(role='user', content=[TextContent(text='你好')]),
+        Message(role='assistant', content=[TextContent(text='你好')]),
+    ]
+    expected_output = """*** System Message that describes the assistant's behavior ***
+System behavior description.
+----------------------------------------------------------------------------------------------------
+*** Turn 1 - USER ***
+Hello.
+----------------------------------------------------------------------------------------------------
+*** Turn 1 - ASSISTANT ***
+Hi there!
+----------------------------------------------------------------------------------------------------
+*** Turn 2 - USER ***
+你好
+----------------------------------------------------------------------------------------------------
+*** Turn 2 - ASSISTANT ***
+你好
+----------------------------------------------------------------------------------------------------
+"""
+    assert format_trajectory(traj) == expected_output
+
+
+def test_tool_call_handling():
+    tool_call = create_mock_tool_call(
+        name='fn', arguments='{"param1": "value1", "param2": "value2"}'
+    )
+    traj = [
+        Message(
+            role='assistant',
+            content=[TextContent(text='Running the tool.')],
+            tool_calls=[tool_call],
+        )
+    ]
+    expected_output = """----------------------------------------------------------------------------------------------------
+*** Turn 1 - ASSISTANT ***
+Running the tool.
+### Tool Call 0
+<function=fn>
+<parameter=param1>value1</parameter>
+<parameter=param2>value2</parameter>
+</function>
+----------------------------------------------------------------------------------------------------
+"""
+    print(format_trajectory(traj))
+    assert format_trajectory(traj) == expected_output
+
+
+def test_invalid_tool_call():
+    tool_call = create_mock_tool_call(name='fn', arguments='invalid json')
+    traj = [
+        Message(
+            role='assistant',
+            content=[TextContent(text='Running the tool.')],
+            tool_calls=[tool_call],
+        )
+    ]
+    with pytest.raises(ValueError, match='Failed to parse arguments as JSON'):
+        format_trajectory(traj)
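
Reviewer note: the sketch below (not part of the diff) shows how the pieces introduced here compose end to end. It is a minimal, hypothetical usage example: the model names are placeholders, API keys are assumed to be configured in the environment, and in the normal flow create_agent() / Session.initialize_agent() build the routing LLMs from [llm.*] config sections marked for_routing = true rather than by hand.

    # Minimal sketch, assuming configured credentials; model names are illustrative.
    from openhands.core.config import LLMConfig, ModelRoutingConfig
    from openhands.llm.llm import LLM
    from openhands.router import LLMBasedPlanRouter

    default_llm = LLM(config=LLMConfig(model='gpt-4o'))

    # Dict keys must match the config names in ModelRoutingConfig
    # (defaults: 'judge_model' and 'reasoning_model').
    routing_llms = {
        'judge_model': LLM(config=LLMConfig(model='gpt-4o', for_routing=True)),
        'reasoning_model': LLM(config=LLMConfig(model='o1', for_routing=True)),
    }

    router = LLMBasedPlanRouter(
        llm=default_llm,
        routing_llms=routing_llms,
        model_routing_config=ModelRoutingConfig(),  # default config names match the keys above
    )

    # Each agent step: the judge LLM scores the formatted trajectory (0 or 1),
    # and the router returns either the default LLM or the reasoning LLM.
    # Note this issues a real completion call against the judge model.
    active_llm = router.should_route_to('=== INTERACTION LOG ===\n(user request here)')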