# test_acompletion.py
import asyncio
from contextlib import contextmanager
from typing import Type
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

from openhands.core.config import load_app_config
from openhands.core.exceptions import UserCancelledError
from openhands.llm.async_llm import AsyncLLM
from openhands.llm.llm import LLM
from openhands.llm.streaming_llm import StreamingLLM

config = load_app_config()
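# Note: load_app_config() is evaluated once at import time; the tests below
# only use it (via config.get_llm_config()) to construct LLM objects whose
# network calls are then mocked out.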

@pytest.fixture
def test_llm():
    return _get_llm(LLM)


def _get_llm(type_: Type[LLM]):
    with _patch_http():
        return type_(config=config.get_llm_config())


@pytest.fixture
def mock_response():
    return [
        {'choices': [{'delta': {'content': 'This is a'}}]},
        {'choices': [{'delta': {'content': ' test'}}]},
        {'choices': [{'delta': {'content': ' message.'}}]},
        {'choices': [{'delta': {'content': ' It is'}}]},
        {'choices': [{'delta': {'content': ' a bit'}}]},
        {'choices': [{'delta': {'content': ' longer'}}]},
        {'choices': [{'delta': {'content': ' than'}}]},
        {'choices': [{'delta': {'content': ' the'}}]},
        {'choices': [{'delta': {'content': ' previous'}}]},
        {'choices': [{'delta': {'content': ' one,'}}]},
        {'choices': [{'delta': {'content': ' but'}}]},
        {'choices': [{'delta': {'content': ' hopefully'}}]},
        {'choices': [{'delta': {'content': ' still'}}]},
        {'choices': [{'delta': {'content': ' short'}}]},
        {'choices': [{'delta': {'content': ' enough.'}}]},
    ]

@contextmanager
def _patch_http():
    with patch('openhands.llm.llm.requests.get', MagicMock()) as mock_http:
        mock_http.json.return_value = {
            'data': [
                {'model_name': 'some_model'},
                {'model_name': 'another_model'},
            ]
        }
        yield
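# Note: _patch_http stubs out requests.get in openhands.llm.llm, presumably
# the call made during LLM construction to fetch the provider's model list,
# so these unit tests never touch the network. The two model names in the
# canned payload are arbitrary placeholders.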

@pytest.mark.asyncio
async def test_acompletion_non_streaming():
    """A non-streaming async completion should return the mocked message."""
    with patch.object(AsyncLLM, '_call_acompletion') as mock_call_acompletion:
        mock_response = {
            'choices': [{'message': {'content': 'This is a test message.'}}]
        }
        mock_call_acompletion.return_value = mock_response
        test_llm = _get_llm(AsyncLLM)
        response = await test_llm.async_completion(
            messages=[{'role': 'user', 'content': 'Hello!'}],
            stream=False,
            drop_params=True,
        )
        # Assertions for non-streaming completion
        assert response['choices'][0]['message']['content'] != ''

@pytest.mark.asyncio
async def test_acompletion_streaming(mock_response):
    """Streaming completion should yield only chunks from the mocked stream."""
    with patch.object(StreamingLLM, '_call_acompletion') as mock_call_acompletion:
        mock_call_acompletion.return_value.__aiter__.return_value = iter(mock_response)
        test_llm = _get_llm(StreamingLLM)
        async for chunk in test_llm.async_streaming_completion(
            messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
        ):
            print(f"Chunk: {chunk['choices'][0]['delta']['content']}")
            # Assertions for streaming completion
            assert chunk['choices'][0]['delta']['content'] in [
                r['choices'][0]['delta']['content'] for r in mock_response
            ]
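# Note on the mock wiring above: MagicMock special-cases __aiter__, so
# assigning an iterator to return_value.__aiter__.return_value makes
# `async for` over the mocked _call_acompletion result yield the canned
# chunks in order; no hand-written async generator is needed here (compare
# the streaming cancellation test below, which does use one).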

@pytest.mark.asyncio
async def test_completion(test_llm):
    """The synchronous completion path should return the mocked message."""
    with patch.object(LLM, 'completion') as mock_completion:
        mock_completion.return_value = {
            'choices': [{'message': {'content': 'This is a test message.'}}]
        }
        response = test_llm.completion(messages=[{'role': 'user', 'content': 'Hello!'}])
        assert response['choices'][0]['message']['content'] == 'This is a test message.'

@pytest.mark.asyncio
@pytest.mark.parametrize('cancel_delay', [0.1, 0.3, 0.5, 0.7, 0.9])
async def test_async_completion_with_user_cancellation(cancel_delay):
    """Cancelling mid-request should surface UserCancelledError."""
    cancel_event = asyncio.Event()

    async def mock_on_cancel_requested():
        is_set = cancel_event.is_set()
        print(f'Cancel requested: {is_set}')
        return is_set

    async def mock_acompletion(*args, **kwargs):
        print('Starting mock_acompletion')
        for i in range(20):  # Increased iterations for a longer-running task
            print(f'mock_acompletion iteration {i}')
            await asyncio.sleep(0.1)
            if await mock_on_cancel_requested():
                print('Cancellation detected in mock_acompletion')
                raise UserCancelledError('LLM request cancelled by user')
        print('Completing mock_acompletion without cancellation')
        return {'choices': [{'message': {'content': 'This is a test message.'}}]}

    with patch.object(
        AsyncLLM, '_call_acompletion', new_callable=AsyncMock
    ) as mock_call_acompletion:
        mock_call_acompletion.side_effect = mock_acompletion
        test_llm = _get_llm(AsyncLLM)

        async def cancel_after_delay():
            print(f'Starting cancel_after_delay with delay {cancel_delay}')
            await asyncio.sleep(cancel_delay)
            print('Setting cancel event')
            cancel_event.set()

        with pytest.raises(UserCancelledError):
            await asyncio.gather(
                test_llm.async_completion(
                    messages=[{'role': 'user', 'content': 'Hello!'}],
                    stream=False,
                ),
                cancel_after_delay(),
            )

        # Ensure the mock was called
        mock_call_acompletion.assert_called_once()
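# Note on the gather above: with the default return_exceptions=False, the
# first exception raised by either coroutine propagates out of asyncio.gather,
# so the UserCancelledError from mock_acompletion is exactly what
# pytest.raises sees. The parametrized delays exercise cancellation at
# different points in the 20 x 0.1 s polling loop.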

@pytest.mark.asyncio
@pytest.mark.parametrize('cancel_after_chunks', [1, 3, 5, 7, 9])
async def test_async_streaming_completion_with_user_cancellation(cancel_after_chunks):
    """Cancelling a stream should deliver exactly the chunks sent so far."""
    cancel_requested = False
    test_messages = [
        'This is ',
        'a test ',
        'message ',
        'with ',
        'multiple ',
        'chunks ',
        'to ',
        'simulate ',
        'a ',
        'longer ',
        'streaming ',
        'response.',
    ]

    async def mock_acompletion(*args, **kwargs):
        for i, content in enumerate(test_messages):
            yield {'choices': [{'delta': {'content': content}}]}
            if i + 1 == cancel_after_chunks:
                nonlocal cancel_requested
                cancel_requested = True
            if cancel_requested:
                raise UserCancelledError('LLM request cancelled by user')
            await asyncio.sleep(0.05)  # Simulate some delay between chunks

    with patch.object(
        AsyncLLM, '_call_acompletion', new_callable=AsyncMock
    ) as mock_call_acompletion:
        mock_call_acompletion.return_value = mock_acompletion()
        test_llm = _get_llm(StreamingLLM)

        received_chunks = []
        with pytest.raises(UserCancelledError):
            async for chunk in test_llm.async_streaming_completion(
                messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
            ):
                received_chunks.append(chunk['choices'][0]['delta']['content'])
                print(f"Chunk: {chunk['choices'][0]['delta']['content']}")

        # Assert that we received the expected number of chunks before cancellation
        assert len(received_chunks) == cancel_after_chunks
        assert received_chunks == test_messages[:cancel_after_chunks]

        # Ensure the mock was called
        mock_call_acompletion.assert_called_once()
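# These tests are fully self-contained: every network and LLM call is mocked,
# so a plain `pytest test_acompletion.py` run should pass, assuming the
# pytest-asyncio plugin is installed to honor the @pytest.mark.asyncio marks.
# The exact invocation path depends on where the file lives in the repository.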