Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Simplify prompt caching for new Anthropic API #6860

Merged
merged 2 commits into from
Feb 20, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 8 additions & 11 deletions openhands/core/message_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,17 +351,14 @@ def get_observation_message(


def apply_prompt_caching(messages: list[Message]) -> None:
"""Applies caching breakpoints to the messages."""
"""Applies caching breakpoints to the messages.

For new Anthropic API, we only need to mark the last user or tool message as cacheable.
"""
# NOTE: this is only needed for anthropic
# following logic here:
# https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py#L241-L262
breakpoints_remaining = 3 # remaining 1 for system/tool
for message in reversed(messages):
if message.role in ('user', 'tool'):
if breakpoints_remaining > 0:
message.content[
-1
].cache_prompt = True # Last item inside the message content
breakpoints_remaining -= 1
else:
break
message.content[
-1
].cache_prompt = True # Last item inside the message content
break
15 changes: 7 additions & 8 deletions tests/unit/test_prompt_caching.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,12 +84,12 @@ def test_get_messages(codeact_agent: CodeActAgent):
assert messages[0].content[0].cache_prompt # system message
assert messages[1].role == 'user'
assert messages[1].content[0].text.endswith('Initial user message')
# we add cache breakpoint to the last 3 user messages
assert messages[1].content[0].cache_prompt
# we add cache breakpoint to only the last user message
assert not messages[1].content[0].cache_prompt

assert messages[3].role == 'user'
assert messages[3].content[0].text == ('Hello, agent!')
assert messages[3].content[0].cache_prompt
assert not messages[3].content[0].cache_prompt
assert messages[4].role == 'assistant'
assert messages[4].content[0].text == 'Hello, user!'
assert not messages[4].content[0].cache_prompt
Expand Down Expand Up @@ -121,10 +121,9 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
if msg.role in ('user', 'system') and msg.content[0].cache_prompt
]
assert (
len(cached_user_messages) == 4
) # Including the initial system+user + 2 last user message
len(cached_user_messages) == 2
) # Including the initial system+user + last user message

# Verify that these are indeed the last two user messages (from start)
# Verify that these are indeed the last user message (from start)
assert cached_user_messages[0].content[0].text.startswith('You are OpenHands agent')
assert cached_user_messages[2].content[0].text.startswith('User message 1')
assert cached_user_messages[3].content[0].text.startswith('User message 1')
assert cached_user_messages[1].content[0].text.startswith('User message 14')
Loading