Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
enyst committed Nov 25, 2024
1 parent 3e83fd2 commit 6e72669
Show file tree
Hide file tree
Showing 2 changed files with 31 additions and 19 deletions.
15 changes: 9 additions & 6 deletions openhands/agenthub/codeact_agent/codeact_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
ImageContent,
Message,
TextContent,
ToolResponseContent,
)
from openhands.events.action import (
Action,
Expand Down Expand Up @@ -156,10 +155,7 @@ def get_action_message(
FileEditAction,
BrowseInteractiveAction,
),
) or (
isinstance(action, CmdRunAction)
and action.source == 'agent'
):
) or (isinstance(action, CmdRunAction) and action.source == 'agent'):
tool_metadata = action.tool_call_metadata
assert tool_metadata is not None
llm_response: ModelResponse = tool_metadata.model_response
Expand Down Expand Up @@ -206,14 +202,21 @@ def get_action_message(
'Tool call metadata should NOT be None when function calling is enabled. Action: '
+ str(action)
)
llm_response: ModelResponse = tool_metadata.model_response
llm_response = tool_metadata.model_response
assistant_msg = llm_response.choices[0].message
return [
Message(
role=assistant_msg.role,
content=[TextContent(text=assistant_msg.content or '')],
)
]
elif isinstance(action, AgentFinishAction) and action.source == 'user':
return [
Message(
role='user',
content=[TextContent(text=action.thought or '')],
)
]
return []

def get_observation_message(
Expand Down
35 changes: 22 additions & 13 deletions openhands/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from litellm import ChatCompletionMessageToolCall
from pydantic import BaseModel, Field, model_serializer

from openhands.core.exceptions import FunctionCallConversionError
from openhands.llm.fn_call_converter import (
IN_CONTEXT_LEARNING_EXAMPLE_PREFIX,
IN_CONTEXT_LEARNING_EXAMPLE_SUFFIX,
Expand Down Expand Up @@ -277,7 +278,9 @@ def _native_function_serializer(self, use_list_format: bool) -> dict:
return message_dict

@classmethod
def convert_messages_to_non_native(cls, messages: list[dict], tools: list[dict]) -> list[dict]:
def convert_messages_to_non_native(
cls, messages: list[dict], tools: list[dict]
) -> list[dict]:
"""Convert a list of messages from native to non-native format.
Used when the API doesn't support native function calling."""
converted_messages = []
Expand All @@ -301,7 +304,9 @@ def convert_messages_to_non_native(cls, messages: list[dict], tools: list[dict])
if role == 'assistant' and 'tool_calls' in message:
tool_calls = message['tool_calls']
if len(tool_calls) != 1:
raise ValueError(f'Expected exactly one tool call, got {len(tool_calls)}')
raise ValueError(
f'Expected exactly one tool call, got {len(tool_calls)}'
)

# Create Message with both text and tool call content
content_list = []
Expand All @@ -311,14 +316,14 @@ def convert_messages_to_non_native(cls, messages: list[dict], tools: list[dict])
ToolCallContent(
function_name=tool_calls[0]['function']['name'],
function_arguments=tool_calls[0]['function']['arguments'],
tool_call_id=tool_calls[0]['id']
tool_call_id=tool_calls[0]['id'],
)
)
converted_messages.append(
Message(
role='assistant',
content=content_list,
function_calling_enabled=False
function_calling_enabled=False,
).model_dump()
)
continue
Expand All @@ -329,14 +334,14 @@ def convert_messages_to_non_native(cls, messages: list[dict], tools: list[dict])
ToolResponseContent(
tool_call_id=message['tool_call_id'],
name=message.get('name', 'function'),
content=message['content']
content=message['content'],
)
]
converted_messages.append(
Message(
role='tool',
content=content_list,
function_calling_enabled=False
function_calling_enabled=False,
).model_dump()
)
continue
Expand Down Expand Up @@ -416,11 +421,13 @@ def _tool_call_to_string(cls, tool_call: dict) -> str:
def _add_tools_description(cls, message: dict, tools: list[dict]) -> dict:
"""Add tools description to a system message."""
formatted_tools = cls._tools_to_description(tools)
system_prompt_suffix = SYSTEM_PROMPT_SUFFIX_TEMPLATE.format(description=formatted_tools)

system_prompt_suffix = SYSTEM_PROMPT_SUFFIX_TEMPLATE.format(
description=formatted_tools
)

content = message.get('content', '')
content_list = []

if isinstance(content, str):
content_list.append(TextContent(text=content + system_prompt_suffix))
elif isinstance(content, list):
Expand All @@ -433,7 +440,7 @@ def _add_tools_description(cls, message: dict, tools: list[dict]) -> dict:
)
# Add single suffix - use append
content_list.append(TextContent(text=system_prompt_suffix))

return {
**message,
'content': content_list,
Expand All @@ -444,11 +451,13 @@ def _add_in_context_learning(cls, message: dict) -> dict:
"""Add in-context learning example to first user message."""
content = message.get('content', '')
content_list = []

if isinstance(content, str):
content_list.append(
TextContent(
text=IN_CONTEXT_LEARNING_EXAMPLE_PREFIX + content + IN_CONTEXT_LEARNING_EXAMPLE_SUFFIX
text=IN_CONTEXT_LEARNING_EXAMPLE_PREFIX
+ content
+ IN_CONTEXT_LEARNING_EXAMPLE_SUFFIX
)
)
elif isinstance(content, list):
Expand All @@ -463,7 +472,7 @@ def _add_in_context_learning(cls, message: dict) -> dict:
)
# Add the suffix
content_list.append(TextContent(text=IN_CONTEXT_LEARNING_EXAMPLE_SUFFIX))

return {
**message,
'content': content_list,
Expand Down

0 comments on commit 6e72669

Please sign in to comment.