Gemini prompt caching #5005

Closed
wants to merge 2 commits
33 changes: 20 additions & 13 deletions openhands/agenthub/codeact_agent/codeact_agent.py
@@ -187,7 +187,9 @@ def get_action_message(
                 )
             ]
         elif isinstance(action, CmdRunAction) and action.source == 'user':
-            content = [TextContent(text=f'User executed the command:\n{action.command}')]
+            content = [
+                TextContent(text=f'User executed the command:\n{action.command}')
+            ]
             return [
                 Message(
                     role='user',
@@ -452,18 +454,23 @@ def _get_messages(self, state: State) -> list[Message]:
             messages.append(message)

         if self.llm.is_caching_prompt_active():
-            # NOTE: this is only needed for anthropic
-            # following logic here:
-            # https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py#L241-L262
-            breakpoints_remaining = 3  # remaining 1 for system/tool
-            for message in reversed(messages):
-                if message.role == 'user' or message.role == 'tool':
-                    if breakpoints_remaining > 0:
-                        message.content[
-                            -1
-                        ].cache_prompt = True  # Last item inside the message content
-                        breakpoints_remaining -= 1
-                    else:
+            # For models that only support one checkpoint, just cache the last user/tool message
+            if self.llm.is_single_checkpoint_model():
+                for message in reversed(messages):
+                    if message.role == 'user' or message.role == 'tool':
+                        message.content[-1].cache_prompt = True
                         break
+            else:
+                # NOTE: this is only needed for anthropic
+                # following logic here:
+                # https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py#L241-L262
+                breakpoints_remaining = 3  # remaining 1 for system/tool
+                for message in reversed(messages):
+                    if message.role == 'user' or message.role == 'tool':
+                        if breakpoints_remaining > 0:
+                            message.content[-1].cache_prompt = True
+                            breakpoints_remaining -= 1
+                        else:
+                            break

         return messages
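
To illustrate the two caching paths above, here is a minimal, self-contained sketch (not part of the PR); the TextContent/Message dataclasses and the mark_cache_breakpoints helper are simplified stand-ins for the real OpenHands types:

from dataclasses import dataclass, field


@dataclass
class TextContent:
    text: str
    cache_prompt: bool = False  # stand-in for the real pydantic field


@dataclass
class Message:
    role: str
    content: list[TextContent] = field(default_factory=list)


def mark_cache_breakpoints(messages: list[Message], single_checkpoint: bool) -> None:
    """Flag cacheable content the way the branch above does."""
    if single_checkpoint:
        # Gemini-style: only the last user/tool message gets a cache breakpoint.
        for message in reversed(messages):
            if message.role in ('user', 'tool'):
                message.content[-1].cache_prompt = True
                break
    else:
        # Anthropic-style: up to 3 breakpoints, reserving 1 for system/tools.
        breakpoints_remaining = 3
        for message in reversed(messages):
            if message.role in ('user', 'tool'):
                if breakpoints_remaining > 0:
                    message.content[-1].cache_prompt = True
                    breakpoints_remaining -= 1
                else:
                    break


messages = [
    Message('user', [TextContent('fix the failing test')]),
    Message('assistant', [TextContent('running pytest')]),
    Message('tool', [TextContent('2 passed')]),
    Message('user', [TextContent('great, commit it')]),
]
mark_cache_breakpoints(messages, single_checkpoint=True)
print([m.content[-1].cache_prompt for m in messages])  # [False, False, False, True]
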
37 changes: 16 additions & 21 deletions openhands/llm/llm.py
@@ -50,15 +50,6 @@
     ServiceUnavailableError,
 )

-# cache prompt supporting models
-# remove this when we gemini and deepseek are supported
-CACHE_PROMPT_SUPPORTED_MODELS = [
-    'claude-3-5-sonnet-20241022',
-    'claude-3-5-sonnet-20240620',
-    'claude-3-5-haiku-20241022',
-    'claude-3-haiku-20240307',
-    'claude-3-opus-20240229',
-]

 # function calling supporting models
 FUNCTION_CALLING_SUPPORTED_MODELS = [
@@ -70,6 +61,10 @@
     'gpt-4o',
 ]

+SINGLE_CHECKPOINT_MODELS = [
+    'gemini',  # Gemini only supports one checkpoint for prompt caching
+]
+

 class LLM(RetryMixin, DebugMixin):
     """The LLM class represents a Language Model instance.
@@ -122,13 +117,6 @@ def __init__(
             drop_params=self.config.drop_params,
         )

-        if self.vision_is_active():
-            logger.debug('LLM: model has vision enabled')
-        if self.is_caching_prompt_active():
-            logger.debug('LLM: caching prompt enabled')
-        if self.is_function_calling_active():
-            logger.debug('LLM: model supports function calling')
-
         self._completion = partial(
             litellm_completion,
             model=self.config.model,
@@ -403,11 +391,8 @@ def is_caching_prompt_active(self) -> bool:
         """
         return (
             self.config.caching_prompt is True
-            and (
-                self.config.model in CACHE_PROMPT_SUPPORTED_MODELS
-                or self.config.model.split('/')[-1] in CACHE_PROMPT_SUPPORTED_MODELS
-            )
-            # We don't need to look-up model_info, because only Anthropic models needs the explicit caching breakpoint
+            and self.model_info is not None
+            and self.model_info.get('supports_prompt_caching', False)
         )

     def is_function_calling_active(self) -> bool:
@@ -570,3 +555,13 @@ def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dic

         # let pydantic handle the serialization
         return [message.model_dump() for message in messages]
+
+    def is_single_checkpoint_model(self) -> bool:
+        """Check if model only supports a single prompt cache checkpoint.
+
+        Returns:
+            bool: True if model only supports one checkpoint (e.g. Gemini)
+        """
+        return any(
+            model in self.config.model.lower() for model in SINGLE_CHECKPOINT_MODELS
+        )
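
As a rough standalone check of the substring matching in is_single_checkpoint_model() (the model names below are illustrative, not taken from the PR):

SINGLE_CHECKPOINT_MODELS = ['gemini']


def is_single_checkpoint_model(model: str) -> bool:
    # Case-insensitive substring match, mirroring the method above
    return any(name in model.lower() for name in SINGLE_CHECKPOINT_MODELS)


for model in (
    'gemini/gemini-1.5-pro',
    'vertex_ai/gemini-1.5-flash',
    'claude-3-5-sonnet-20241022',
):
    print(model, is_single_checkpoint_model(model))
# gemini/gemini-1.5-pro True
# vertex_ai/gemini-1.5-flash True
# claude-3-5-sonnet-20241022 False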