From fe009ea78f49755cec80527fecb08e49483304ae Mon Sep 17 00:00:00 2001
From: Engel Nyst
Date: Thu, 14 Nov 2024 20:22:36 +0100
Subject: [PATCH 1/2] fix gemini prompt caching

---
 .../agenthub/codeact_agent/codeact_agent.py | 33 +++++++++++--------
 openhands/llm/llm.py                        | 14 ++++++++
 2 files changed, 34 insertions(+), 13 deletions(-)

diff --git a/openhands/agenthub/codeact_agent/codeact_agent.py b/openhands/agenthub/codeact_agent/codeact_agent.py
index 39b9e69247be..9da720d31630 100644
--- a/openhands/agenthub/codeact_agent/codeact_agent.py
+++ b/openhands/agenthub/codeact_agent/codeact_agent.py
@@ -187,7 +187,9 @@ def get_action_message(
                 )
             ]
         elif isinstance(action, CmdRunAction) and action.source == 'user':
-            content = [TextContent(text=f'User executed the command:\n{action.command}')]
+            content = [
+                TextContent(text=f'User executed the command:\n{action.command}')
+            ]
             return [
                 Message(
                     role='user',
@@ -452,18 +454,23 @@ def _get_messages(self, state: State) -> list[Message]:
             messages.append(message)
 
         if self.llm.is_caching_prompt_active():
-            # NOTE: this is only needed for anthropic
-            # following logic here:
-            # https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py#L241-L262
-            breakpoints_remaining = 3  # remaining 1 for system/tool
-            for message in reversed(messages):
-                if message.role == 'user' or message.role == 'tool':
-                    if breakpoints_remaining > 0:
-                        message.content[
-                            -1
-                        ].cache_prompt = True  # Last item inside the message content
-                        breakpoints_remaining -= 1
-                    else:
+            # For models that only support one checkpoint, just cache the last user/tool message
+            if self.llm.is_single_checkpoint_model():
+                for message in reversed(messages):
+                    if message.role == 'user' or message.role == 'tool':
+                        message.content[-1].cache_prompt = True
                         break
+            else:
+                # NOTE: this is only needed for anthropic
+                # following logic here:
+                # https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py#L241-L262
+                breakpoints_remaining = 3  # remaining 1 for system/tool
+                for message in reversed(messages):
+                    if message.role == 'user' or message.role == 'tool':
+                        if breakpoints_remaining > 0:
+                            message.content[-1].cache_prompt = True
+                            breakpoints_remaining -= 1
+                        else:
+                            break
 
         return messages
diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index 0590945995c1..61c9233d6936 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -70,6 +70,10 @@
     'gpt-4o',
 ]
 
+SINGLE_CHECKPOINT_MODELS = [
+    'gemini',  # Gemini only supports one checkpoint for prompt caching
+]
+
 
 class LLM(RetryMixin, DebugMixin):
     """The LLM class represents a Language Model instance.
@@ -570,3 +574,13 @@ def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dic
 
         # let pydantic handle the serialization
         return [message.model_dump() for message in messages]
+
+    def is_single_checkpoint_model(self) -> bool:
+        """Check if model only supports a single prompt cache checkpoint.
+
+        Returns:
+            bool: True if model only supports one checkpoint (e.g. Gemini)
+        """
+        return any(
+            model in self.config.model.lower() for model in SINGLE_CHECKPOINT_MODELS
+        )

From b24dfd3e24381a49d7c099e3068d834a858a6116 Mon Sep 17 00:00:00 2001
From: Engel Nyst
Date: Thu, 14 Nov 2024 20:47:57 +0100
Subject: [PATCH 2/2] with gemini support we don't need a list

---
 openhands/llm/llm.py | 23 ++---------------------
 1 file changed, 2 insertions(+), 21 deletions(-)

diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index 61c9233d6936..fe55c2e9d3f6 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -50,15 +50,6 @@
     ServiceUnavailableError,
 )
 
-# cache prompt supporting models
-# remove this when we gemini and deepseek are supported
-CACHE_PROMPT_SUPPORTED_MODELS = [
-    'claude-3-5-sonnet-20241022',
-    'claude-3-5-sonnet-20240620',
-    'claude-3-5-haiku-20241022',
-    'claude-3-haiku-20240307',
-    'claude-3-opus-20240229',
-]
 
 # function calling supporting models
 FUNCTION_CALLING_SUPPORTED_MODELS = [
@@ -126,13 +117,6 @@ def __init__(
             drop_params=self.config.drop_params,
         )
 
-        if self.vision_is_active():
-            logger.debug('LLM: model has vision enabled')
-        if self.is_caching_prompt_active():
-            logger.debug('LLM: caching prompt enabled')
-        if self.is_function_calling_active():
-            logger.debug('LLM: model supports function calling')
-
         self._completion = partial(
             litellm_completion,
             model=self.config.model,
@@ -407,11 +391,8 @@ def is_caching_prompt_active(self) -> bool:
         """
         return (
             self.config.caching_prompt is True
-            and (
-                self.config.model in CACHE_PROMPT_SUPPORTED_MODELS
-                or self.config.model.split('/')[-1] in CACHE_PROMPT_SUPPORTED_MODELS
-            )
-            # We don't need to look-up model_info, because only Anthropic models needs the explicit caching breakpoint
+            and self.model_info is not None
+            and self.model_info.get('supports_prompt_caching', False)
         )
 
     def is_function_calling_active(self) -> bool:
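
Note (review aid appended after the series, not part of either commit): the
_get_messages() hunk in PATCH 1/2 chooses between two caching strategies — up
to three cache breakpoints on the most recent user/tool messages for
Anthropic, versus a single checkpoint on only the last user/tool message for
Gemini-style models. A minimal, self-contained sketch of that selection
logic, using stand-in Message/TextContent classes rather than the real
openhands types:

    from dataclasses import dataclass, field

    @dataclass
    class TextContent:
        text: str
        cache_prompt: bool = False

    @dataclass
    class Message:
        role: str
        content: list[TextContent] = field(default_factory=list)

    def mark_cache_breakpoints(messages: list[Message], single_checkpoint: bool) -> None:
        """Flag content items for prompt caching, mirroring the patched logic."""
        if single_checkpoint:
            # Gemini-style: only the last user/tool message gets a checkpoint.
            for message in reversed(messages):
                if message.role in ('user', 'tool'):
                    message.content[-1].cache_prompt = True
                    break
        else:
            # Anthropic-style: up to 3 breakpoints, leaving 1 for system/tools.
            breakpoints_remaining = 3
            for message in reversed(messages):
                if message.role in ('user', 'tool'):
                    if breakpoints_remaining > 0:
                        message.content[-1].cache_prompt = True
                        breakpoints_remaining -= 1
                    else:
                        break

    msgs = [
        Message('user', [TextContent('a')]),
        Message('assistant', [TextContent('b')]),
        Message('user', [TextContent('c')]),
    ]
    mark_cache_breakpoints(msgs, single_checkpoint=True)
    assert [c.cache_prompt for m in msgs for c in m.content] == [False, False, True]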
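
A second note on PATCH 2/2: it drops the hard-coded
CACHE_PROMPT_SUPPORTED_MODELS allow-list in favor of litellm's per-model
capability flag (supports_prompt_caching), while Gemini detection stays a
simple substring match on config.model. The sketch below assumes
self.model_info in llm.py is the dict returned by litellm.get_model_info();
the model names in the usage lines are illustrative only:

    import litellm

    SINGLE_CHECKPOINT_MODELS = ['gemini']  # matched as a substring of the model name

    def supports_prompt_caching(model: str) -> bool:
        """Capability lookup instead of a hard-coded allow-list (sketch)."""
        try:
            info = litellm.get_model_info(model)
        except Exception:
            return False  # unknown to litellm -> assume no prompt caching
        return bool(info.get('supports_prompt_caching', False))

    def is_single_checkpoint_model(model: str) -> bool:
        """True when the provider honours only one cache checkpoint (e.g. Gemini)."""
        return any(name in model.lower() for name in SINGLE_CHECKPOINT_MODELS)

    print(is_single_checkpoint_model('gemini/gemini-1.5-pro'))                 # True
    print(is_single_checkpoint_model('anthropic/claude-3-5-sonnet-20241022'))  # False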