diff --git a/openhands/core/message_utils.py b/openhands/core/message_utils.py index 25be12873177..7683c7c4453c 100644 --- a/openhands/core/message_utils.py +++ b/openhands/core/message_utils.py @@ -351,17 +351,14 @@ def get_observation_message( def apply_prompt_caching(messages: list[Message]) -> None: - """Applies caching breakpoints to the messages.""" + """Applies caching breakpoints to the messages. + + For new Anthropic API, we only need to mark the last user or tool message as cacheable. + """ # NOTE: this is only needed for anthropic - # following logic here: - # https://github.com/anthropics/anthropic-quickstarts/blob/8f734fd08c425c6ec91ddd613af04ff87d70c5a0/computer-use-demo/computer_use_demo/loop.py#L241-L262 - breakpoints_remaining = 3 # remaining 1 for system/tool for message in reversed(messages): if message.role in ('user', 'tool'): - if breakpoints_remaining > 0: - message.content[ - -1 - ].cache_prompt = True # Last item inside the message content - breakpoints_remaining -= 1 - else: - break + message.content[ + -1 + ].cache_prompt = True # Last item inside the message content + break diff --git a/tests/unit/test_prompt_caching.py b/tests/unit/test_prompt_caching.py index 3258fa486a9f..2e149c1a37fd 100644 --- a/tests/unit/test_prompt_caching.py +++ b/tests/unit/test_prompt_caching.py @@ -84,12 +84,12 @@ def test_get_messages(codeact_agent: CodeActAgent): assert messages[0].content[0].cache_prompt # system message assert messages[1].role == 'user' assert messages[1].content[0].text.endswith('Initial user message') - # we add cache breakpoint to the last 3 user messages - assert messages[1].content[0].cache_prompt + # we add cache breakpoint to only the last user message + assert not messages[1].content[0].cache_prompt assert messages[3].role == 'user' assert messages[3].content[0].text == ('Hello, agent!') - assert messages[3].content[0].cache_prompt + assert not messages[3].content[0].cache_prompt assert messages[4].role == 'assistant' assert messages[4].content[0].text == 'Hello, user!' assert not messages[4].content[0].cache_prompt @@ -121,10 +121,9 @@ def test_get_messages_prompt_caching(codeact_agent: CodeActAgent): if msg.role in ('user', 'system') and msg.content[0].cache_prompt ] assert ( - len(cached_user_messages) == 4 - ) # Including the initial system+user + 2 last user message + len(cached_user_messages) == 2 + ) # Including the initial system+user + last user message - # Verify that these are indeed the last two user messages (from start) + # Verify that these are indeed the last user message (from start) assert cached_user_messages[0].content[0].text.startswith('You are OpenHands agent') - assert cached_user_messages[2].content[0].text.startswith('User message 1') - assert cached_user_messages[3].content[0].text.startswith('User message 1') + assert cached_user_messages[1].content[0].text.startswith('User message 14')