Fix non-function calls messages #5026

Merged: 6 commits, Nov 21, 2024
48 changes: 31 additions & 17 deletions openhands/core/message.py
@@ -56,6 +56,7 @@ class Message(BaseModel):
     cache_enabled: bool = False
     vision_enabled: bool = False
     # function calling
+    function_calling_enabled: bool = False
     # - tool calls (from LLM)
     tool_calls: list[ChatCompletionMessageToolCall] | None = None
     # - tool execution result (to LLM)
@@ -72,22 +73,22 @@ def serialize_model(self) -> dict:
         # - into a single string: for providers that don't support list of content items (e.g. no vision, no tool calls)
         # - into a list of content items: the new APIs of providers with vision/prompt caching/tool calls
         # NOTE: remove this when litellm or providers support the new API
-        if (
-            self.cache_enabled
-            or self.vision_enabled
-            or self.tool_call_id is not None
-            or self.tool_calls is not None
-        ):
+        if self.cache_enabled or self.vision_enabled or self.function_calling_enabled:
             return self._list_serializer()
         # some providers, like HF and Groq/llama, don't support a list here, but a single string
         return self._string_serializer()

-    def _string_serializer(self):
+    def _string_serializer(self) -> dict:
         # convert content to a single string
         content = '\n'.join(
             item.text for item in self.content if isinstance(item, TextContent)
         )
-        return {'content': content, 'role': self.role}
+        message_dict: dict = {'content': content, 'role': self.role}
+
+        # add tool call keys if we have a tool call or response
+        return self._add_tool_call_keys(message_dict)

-    def _list_serializer(self):
+    def _list_serializer(self) -> dict:
         content: list[dict] = []
         role_tool_with_prompt_caching = False
         for item in self.content:
@@ -102,24 +103,37 @@ def _list_serializer(self):
             elif isinstance(item, ImageContent) and self.vision_enabled:
                 content.extend(d)

-        ret: dict = {'content': content, 'role': self.role}
+        message_dict: dict = {'content': content, 'role': self.role}

         # pop content if it's empty
         if not content or (
             len(content) == 1
             and content[0]['type'] == 'text'
             and content[0]['text'] == ''
         ):
-            ret.pop('content')
+            message_dict.pop('content')

         if role_tool_with_prompt_caching:
-            ret['cache_control'] = {'type': 'ephemeral'}
+            message_dict['cache_control'] = {'type': 'ephemeral'}

+        # add tool call keys if we have a tool call or response
+        return self._add_tool_call_keys(message_dict)
+
+    def _add_tool_call_keys(self, message_dict: dict) -> dict:
+        """Add tool call keys if we have a tool call or response.
+
+        NOTE: this is necessary for both native and non-native tool calling"""
+
+        # an assistant message calling a tool
+        if self.tool_calls is not None:
+            message_dict['tool_calls'] = self.tool_calls
+
+        # an observation message with tool response
         if self.tool_call_id is not None:
             assert (
                 self.name is not None
             ), 'name is required when tool_call_id is not None'
-            ret['tool_call_id'] = self.tool_call_id
-            ret['name'] = self.name
-        if self.tool_calls:
-            ret['tool_calls'] = self.tool_calls
-        return ret
+            message_dict['tool_call_id'] = self.tool_call_id
+            message_dict['name'] = self.name
+
+        return message_dict
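
To see what the reworked serializer does end to end, here is a minimal, self-contained sketch. `ToyMessage` is a hypothetical stand-in, not the actual openhands `Message` class (which carries more fields and content types); the point it illustrates is that tool-call keys are now attached by both serializers, so a tool result keeps its `tool_call_id`/`name` whether or not function calling is native:

```python
from pydantic import BaseModel


class ToyMessage(BaseModel):
    """Hypothetical stand-in for openhands.core.message.Message."""

    role: str
    text: str
    cache_enabled: bool = False
    vision_enabled: bool = False
    function_calling_enabled: bool = False
    tool_calls: list | None = None
    tool_call_id: str | None = None
    name: str | None = None

    def serialize(self) -> dict:
        # list-of-content-items form only for caching / vision / native function calling
        if self.cache_enabled or self.vision_enabled or self.function_calling_enabled:
            message_dict: dict = {
                'role': self.role,
                'content': [{'type': 'text', 'text': self.text}],
            }
        else:
            # plain string form for providers that don't accept a content list
            message_dict = {'role': self.role, 'content': self.text}
        # tool call keys are added in BOTH branches now
        return self._add_tool_call_keys(message_dict)

    def _add_tool_call_keys(self, message_dict: dict) -> dict:
        if self.tool_calls is not None:
            message_dict['tool_calls'] = self.tool_calls
        if self.tool_call_id is not None:
            assert self.name is not None, 'name is required with tool_call_id'
            message_dict['tool_call_id'] = self.tool_call_id
            message_dict['name'] = self.name
        return message_dict


obs = ToyMessage(role='tool', text='ok', tool_call_id='call_1', name='execute_bash')
print(obs.serialize())
# {'role': 'tool', 'content': 'ok', 'tool_call_id': 'call_1', 'name': 'execute_bash'}

obs.function_calling_enabled = True
print(obs.serialize())
# {'role': 'tool', 'content': [{'type': 'text', 'text': 'ok'}],
#  'tool_call_id': 'call_1', 'name': 'execute_bash'}
```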
12 changes: 7 additions & 5 deletions openhands/llm/fn_call_converter.py
@@ -319,9 +319,8 @@ def convert_fncall_messages_to_non_fncall_messages(
     converted_messages = []
     first_user_message_encountered = False
     for message in messages:
-        role, content = message['role'], message['content']
-        if content is None:
-            content = ''
+        role = message['role']
+        content = message.get('content', '')

         # 1. SYSTEM MESSAGES
         # append system prompt suffix to content
@@ -338,6 +337,7 @@
                     f'Unexpected content type {type(content)}. Expected str or list. Content: {content}'
                 )
            converted_messages.append({'role': 'system', 'content': content})
+
        # 2. USER MESSAGES (no change)
        elif role == 'user':
            # Add in-context learning example for the first user message
@@ -446,10 +446,12 @@
                    f'Unexpected content type {type(content)}. Expected str or list. Content: {content}'
                )
            converted_messages.append({'role': 'assistant', 'content': content})
+
        # 4. TOOL MESSAGES (tool outputs)
        elif role == 'tool':
-            # Convert tool result as assistant message
-            prefix = f'EXECUTION RESULT of [{message["name"]}]:\n'
+            # Convert tool result as user message
+            tool_name = message.get('name', 'function')
+            prefix = f'EXECUTION RESULT of [{tool_name}]:\n'
            # and omit "tool_call_id" AND "name"
            if isinstance(content, str):
                content = prefix + content
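
As a rough illustration of the tool branch above, here is a hedged, standalone sketch (`convert_tool_message` is a made-up name; the real converter handles list content and the other roles too). The `message.get('name', 'function')` fallback means a tool result with no `name` key no longer raises a `KeyError`, and, per the updated comment, the result is re-sent as a user message:

```python
def convert_tool_message(message: dict) -> dict:
    """Sketch: re-emit a 'tool' result for models without native
    function calling, dropping 'tool_call_id' and 'name'."""
    content = message.get('content', '')
    # fall back to a generic label if the tool name is missing
    tool_name = message.get('name', 'function')
    prefix = f'EXECUTION RESULT of [{tool_name}]:\n'
    if isinstance(content, str):
        content = prefix + content
    # the tool output goes back to the model as a plain user message
    return {'role': 'user', 'content': content}


print(convert_tool_message({
    'role': 'tool',
    'tool_call_id': 'call_1',
    'name': 'execute_bash',
    'content': 'hello.txt created',
}))
# {'role': 'user', 'content': 'EXECUTION RESULT of [execute_bash]:\nhello.txt created'}
```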
21 changes: 11 additions & 10 deletions openhands/llm/llm.py
@@ -122,6 +122,9 @@ def __init__(
             drop_params=self.config.drop_params,
         )

+        with warnings.catch_warnings():
+            warnings.simplefilter('ignore')
+            self.init_model_info()
         if self.vision_is_active():
             logger.debug('LLM: model has vision enabled')
         if self.is_caching_prompt_active():
@@ -143,16 +146,6 @@ def __init__(
             drop_params=self.config.drop_params,
         )

-        with warnings.catch_warnings():
-            warnings.simplefilter('ignore')
-            self.init_model_info()
-        if self.vision_is_active():
-            logger.debug('LLM: model has vision enabled')
-        if self.is_caching_prompt_active():
-            logger.debug('LLM: caching prompt enabled')
-        if self.is_function_calling_active():
-            logger.debug('LLM: model supports function calling')
-
         self._completion_unwrapped = self._completion

         @self.retry_decorator(
@@ -343,6 +336,13 @@ def init_model_info(self):
             pass
         logger.debug(f'Model info: {self.model_info}')

+        if self.config.model.startswith('huggingface'):
+            # HF doesn't support the OpenAI default value for top_p (1)
+            logger.debug(
+                f'Setting top_p to 0.9 for Hugging Face model: {self.config.model}'
+            )
+            self.config.top_p = 0.9 if self.config.top_p == 1 else self.config.top_p
+
         # Set the max tokens in an LM-specific way if not set
         if self.config.max_input_tokens is None:
             if (
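
To make that guard concrete, here is a hedged sketch of the top_p clamp (`ModelConfig` is a made-up stand-in for the real LLM config object, and the model name is hypothetical). Only the OpenAI-style default of 1 is nudged down; an explicitly chosen value is left untouched:

```python
from dataclasses import dataclass


@dataclass
class ModelConfig:
    model: str
    top_p: float = 1.0  # OpenAI-style default


def clamp_top_p_for_hf(config: ModelConfig) -> None:
    # HF doesn't support the OpenAI default top_p of 1, so lower it;
    # a user-set value such as 0.95 passes through unchanged
    if config.model.startswith('huggingface') and config.top_p == 1:
        config.top_p = 0.9


cfg = ModelConfig(model='huggingface/meta-llama/Llama-3.1-8B-Instruct')
clamp_top_p_for_hf(cfg)
print(cfg.top_p)  # 0.9
```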
@@ -567,6 +567,7 @@ def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
         for message in messages:
             message.cache_enabled = self.is_caching_prompt_active()
             message.vision_enabled = self.vision_is_active()
+            message.function_calling_enabled = self.is_function_calling_active()
xingyaoww (Collaborator):
This will only be True for the selected set of models -- but we still send messages WITH tool_ids to the LLM.

enyst (Collaborator, Author):
Ahh, so we need the part of the serializer that adds tool_call_id. OK, sure, I'll revert that. Sigh, this is messy. The attempt to separate things when function calling is native is there because we need to send a single thing, not a list.

(I'll follow up on litellm too; they actually do this kind of conversion, so we don't have to play whack-a-mole anymore.)

enyst (Collaborator, Author), Nov 21, 2024:
@xingyaoww I actually restored this check, but it does send tool ids now. It just sends them as a string for non-native function calling, and as a list for native (because this check now applies).

I started working on refactoring the Message class. I think it needs to include the tool calls/ids conversion (at least for a single Message), do full serialization, and become closer to litellm's Message. I started on it on this branch, but in the meantime I found out that people rely on this PR and use it via Docker, because it solves what's broken with non-native function calling.

So I'll make another branch to work on the refactoring, and this PR is up for review again. If it's not too terrible, I'd say let's just fix the problems first and refactor later.

(I'll clean up some leftover comments too, but I'd prefer to do that once it's approved, because I think it would again break people's use of the last commit via Docker.)


         # let pydantic handle the serialization
         return [message.model_dump() for message in messages]
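
For context on that last line: because `Message` defines its serialization with pydantic's `@model_serializer`, a plain `model_dump()` call already produces the provider-ready dict. A minimal standalone illustration of that pydantic v2 mechanism (not openhands code):

```python
from pydantic import BaseModel, model_serializer


class Greeting(BaseModel):
    role: str = 'user'
    text: str = 'hi'

    @model_serializer
    def serialize_model(self) -> dict:
        # model_dump() is routed through this method instead of the
        # default field-by-field dump
        return {'role': self.role, 'content': self.text}


print(Greeting().model_dump())  # {'role': 'user', 'content': 'hi'}
```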