forked from All-Hands-AI/OpenHands
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathtest_prompt_caching.py
130 lines (110 loc) · 4.43 KB
/
test_prompt_caching.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
from unittest.mock import Mock
import pytest
from litellm import ModelResponse
from openhands.agenthub.codeact_agent.codeact_agent import CodeActAgent
from openhands.core.config import AgentConfig, LLMConfig
from openhands.events.action import MessageAction
from openhands.llm.llm import LLM
@pytest.fixture
def mock_llm():
    """Provide an ``LLM`` configured with prompt caching switched on.

    The API key is a placeholder value; the tests visible in this module only
    call the agent's ``reset()`` and ``_get_messages()``.
    """
    caching_config = LLMConfig(
        model='claude-3-5-sonnet-20241022',
        api_key='fake',
        caching_prompt=True,
    )
    return LLM(caching_config)
@pytest.fixture
def codeact_agent(mock_llm):
    """Provide a ``CodeActAgent`` built from ``mock_llm`` and a default ``AgentConfig``."""
    return CodeActAgent(mock_llm, AgentConfig())
def response_mock(content: str, tool_call_id: str):
    """Build a ``ModelResponse`` with one choice carrying an ``execute_bash`` tool call.

    Args:
        content: Text stored under the message's ``content`` key.
        tool_call_id: Value stored under the ``id`` key of the tool call's
            ``function`` entry.

    Returns:
        A ``ModelResponse`` constructed from a single-element ``choices`` list.
    """
    # NOTE(review): ``id`` is nested inside ``function`` rather than sitting on
    # the tool call itself -- confirm this is the shape consumers expect.
    tool_call = {
        'function': {
            'id': tool_call_id,
            'name': 'execute_bash',
            'arguments': '{}',
        }
    }
    message = {'content': content, 'tool_calls': [tool_call]}
    return ModelResponse(choices=[{'message': message}])
def test_get_messages(codeact_agent: CodeActAgent):
    """Check roles, text and cache flags of the messages built from a short chat."""
    # Conversation turns as (text, source), oldest first.
    turns = [
        ('Initial user message', 'user'),
        ('Sure!', 'assistant'),
        ('Hello, agent!', 'user'),
        ('Hello, user!', 'assistant'),
        ('Laaaaaaaast!', 'user'),
    ]
    history = []
    for text, source in turns:
        action = MessageAction(text)
        action._source = source
        history.append(action)

    codeact_agent.reset()
    state = Mock(history=history, max_iterations=5, iteration=0, extra_data={})
    messages = codeact_agent._get_messages(state)

    # Expected: the system message plus one message per turn above.
    assert len(messages) == 6
    # The third entry (index 2) is not inspected by this test.
    system_msg, first_user, _, second_user, second_agent, last_user = messages

    assert system_msg.content[0].cache_prompt

    assert first_user.role == 'user'
    assert first_user.content[0].text.endswith('Initial user message')
    # Cache breakpoints are added to the last 3 user messages.
    assert first_user.content[0].cache_prompt

    assert second_user.role == 'user'
    assert second_user.content[0].text == 'Hello, agent!'
    assert second_user.content[0].cache_prompt

    # Assistant messages must not carry a cache breakpoint.
    assert second_agent.role == 'assistant'
    assert second_agent.content[0].text == 'Hello, user!'
    assert not second_agent.content[0].cache_prompt

    assert last_user.role == 'user'
    assert last_user.content[0].text.startswith('Laaaaaaaast!')
    assert last_user.content[0].cache_prompt
def test_get_messages_prompt_caching(codeact_agent: CodeActAgent):
    """Check which messages carry a cache breakpoint in a long conversation.

    Feeds 15 user/assistant exchanges to the agent and expects exactly four
    cached system/user messages, the final two being the last two user
    messages (indices 13 and 14).
    """
    history = []
    # Add multiple user and agent messages, alternating, oldest first.
    for i in range(15):
        user_action = MessageAction(f'User message {i}')
        user_action._source = 'user'
        history.append(user_action)

        agent_action = MessageAction(f'Agent message {i}')
        agent_action._source = 'assistant'
        history.append(agent_action)

    codeact_agent.reset()
    messages = codeact_agent._get_messages(
        Mock(history=history, max_iterations=10, iteration=5, extra_data={})
    )

    # System and user messages whose first content item is flagged for caching.
    cached_messages = [
        msg
        for msg in messages
        if msg.role in ('user', 'system') and msg.content[0].cache_prompt
    ]

    # The system prompt plus three further cached messages; entry [1] is not
    # inspected here.
    assert len(cached_messages) == 4
    assert cached_messages[0].content[0].text.startswith('You are OpenHands agent')

    # Pin the exact indices: the bare prefix 'User message 1' would also match
    # user messages 1 and 10-12, so it could not prove these are the last two.
    assert cached_messages[2].content[0].text.startswith('User message 13')
    assert cached_messages[3].content[0].text.startswith('User message 14')