Skip to content

Commit

Permalink
fix tool calls in non-native function calling messages
Browse files Browse the repository at this point in the history
  • Loading branch information
enyst committed Nov 15, 2024
1 parent 2ea1297 commit ffca2b4
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 18 deletions.
52 changes: 35 additions & 17 deletions openhands/core/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ class Message(BaseModel):
cache_enabled: bool = False
vision_enabled: bool = False
# function calling
function_calling_enabled: bool = False
# - tool calls (from LLM)
tool_calls: list[ChatCompletionMessageToolCall] | None = None
# - tool execution result (to LLM)
Expand All @@ -72,22 +73,21 @@ def serialize_model(self) -> dict:
# - into a single string: for providers that don't support list of content items (e.g. no vision, no tool calls)
# - into a list of content items: the new APIs of providers with vision/prompt caching/tool calls
# NOTE: remove this when litellm or providers support the new API
if (
self.cache_enabled
or self.vision_enabled
or self.tool_call_id is not None
or self.tool_calls is not None
):
if self.cache_enabled or self.vision_enabled or self.function_calling_enabled:
return self._list_serializer()
return self._string_serializer()

def _string_serializer(self):
def _string_serializer(self) -> dict:
# convert content to a single string
content = '\n'.join(
item.text for item in self.content if isinstance(item, TextContent)
)
return {'content': content, 'role': self.role}
message_dict: dict = {'content': content, 'role': self.role}

# add tool call keys if we have a tool call or response
return self._add_tool_call_keys(message_dict)

def _list_serializer(self):
def _list_serializer(self) -> dict:
content: list[dict] = []
role_tool_with_prompt_caching = False
for item in self.content:
Expand All @@ -102,24 +102,42 @@ def _list_serializer(self):
elif isinstance(item, ImageContent) and self.vision_enabled:
content.extend(d)

ret: dict = {'content': content, 'role': self.role}
message_dict: dict = {'content': content, 'role': self.role}

# pop content if it's empty
if not content or (
len(content) == 1
and content[0]['type'] == 'text'
and content[0]['text'] == ''
):
ret.pop('content')
message_dict.pop('content')

# some providers, like HF and Groq/llama, don't support a list here, only a single string
# if not self.function_calling_enabled:
# content_str = '\n'.join([item['text'] for item in content])
# message_dict['content'] = [{'type': 'text', 'text': content_str}]

if role_tool_with_prompt_caching:
ret['cache_control'] = {'type': 'ephemeral'}
message_dict['cache_control'] = {'type': 'ephemeral'}

# add tool call keys if we have a tool call or response
return self._add_tool_call_keys(message_dict)

def _add_tool_call_keys(self, message_dict: dict) -> dict:
"""Add tool call keys if we have a tool call or response.
NOTE: this is necessary for both native and non-native tool calling"""

# an assistant message calling a tool
if self.tool_calls is not None:
message_dict['tool_calls'] = self.tool_calls

# an observation message with tool response
if self.tool_call_id is not None:
assert (
self.name is not None
), 'name is required when tool_call_id is not None'
ret['tool_call_id'] = self.tool_call_id
ret['name'] = self.name
if self.tool_calls:
ret['tool_calls'] = self.tool_calls
return ret
message_dict['tool_call_id'] = self.tool_call_id
message_dict['name'] = self.name

return message_dict
2 changes: 1 addition & 1 deletion openhands/llm/fn_call_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -449,7 +449,7 @@ def convert_fncall_messages_to_non_fncall_messages(

# 4. TOOL MESSAGES (tool outputs)
elif role == 'tool':
# Convert tool result as assistant message
# Convert tool result as user message
tool_name = message.get('name', 'function')
prefix = f'EXECUTION RESULT of [{tool_name}]:\n'
# and omit "tool_call_id" AND "name"
Expand Down
1 change: 1 addition & 0 deletions openhands/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,6 +567,7 @@ def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dic
for message in messages:
message.cache_enabled = self.is_caching_prompt_active()
message.vision_enabled = self.vision_is_active()
message.function_calling_enabled = self.is_function_calling_active()

# let pydantic handle the serialization
return [message.model_dump() for message in messages]

0 comments on commit ffca2b4

Please sign in to comment.