Fix: Mocking LLM proxy in unit tests (#5639)
tofarr authored Dec 16, 2024
1 parent 239619a commit d76e83b
Showing 2 changed files with 27 additions and 9 deletions.
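
The diff below wraps LLM construction in the unit tests in a context manager that patches openhands.llm.llm.requests.get with a MagicMock returning a canned {'data': [...]} model list, which suggests that building an LLM otherwise makes a real HTTP call to the LLM proxy; it also raises _REDIS_POLL_TIMEOUT to 0.1 in the session-manager tests. Below is a minimal, self-contained sketch of the same HTTP-mocking pattern, assuming a constructor-time proxy lookup; fetch_proxy_models and the proxy URL are hypothetical stand-ins, not OpenHands APIs, and only the patching technique mirrors the diff.

# Sketch of the pattern the commit applies, under the assumption that the
# code under test fetches a model list from the LLM proxy via requests.get.
# fetch_proxy_models and the URL are hypothetical; only the patching
# technique mirrors the diff.
from contextlib import contextmanager
from unittest.mock import patch

import requests


def fetch_proxy_models(base_url: str) -> list[str]:
    # Hypothetical stand-in for the proxy lookup performed at construction time.
    response = requests.get(f'{base_url}/models')
    return [item['model_name'] for item in response.json()['data']]


@contextmanager
def patched_proxy():
    # Replace requests.get with a mock whose response .json() returns canned
    # model data, so no real HTTP request leaves the test process.
    with patch('requests.get') as mock_get:
        mock_get.return_value.json.return_value = {
            'data': [
                {'model_name': 'some_model'},
                {'model_name': 'another_model'},
            ]
        }
        yield mock_get


def test_fetch_proxy_models_is_mocked():
    with patched_proxy():
        assert fetch_proxy_models('http://proxy.invalid') == [
            'some_model',
            'another_model',
        ]
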
tests/unit/test_acompletion.py (32 changes: 25 additions & 7 deletions)
@@ -1,5 +1,7 @@
 import asyncio
-from unittest.mock import AsyncMock, patch
+from contextlib import contextmanager
+from typing import Type
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 
@@ -14,8 +16,12 @@
 
 @pytest.fixture
 def test_llm():
-    # Create a mock config for testing
-    return LLM(config=config.get_llm_config())
+    return _get_llm(LLM)
+
+
+def _get_llm(type_: Type[LLM]):
+    with _patch_http():
+        return type_(config=config.get_llm_config())
 
 
 @pytest.fixture
@@ -39,14 +45,26 @@ def mock_response():
     ]
 
 
+@contextmanager
+def _patch_http():
+    with patch('openhands.llm.llm.requests.get', MagicMock()) as mock_http:
+        mock_http.json.return_value = {
+            'data': [
+                {'model_name': 'some_model'},
+                {'model_name': 'another_model'},
+            ]
+        }
+        yield
+
+
 @pytest.mark.asyncio
 async def test_acompletion_non_streaming():
     with patch.object(AsyncLLM, '_call_acompletion') as mock_call_acompletion:
         mock_response = {
             'choices': [{'message': {'content': 'This is a test message.'}}]
         }
         mock_call_acompletion.return_value = mock_response
-        test_llm = AsyncLLM(config=config.get_llm_config())
+        test_llm = _get_llm(AsyncLLM)
         response = await test_llm.async_completion(
             messages=[{'role': 'user', 'content': 'Hello!'}],
             stream=False,
@@ -60,7 +78,7 @@ async def test_acompletion_non_streaming():
 async def test_acompletion_streaming(mock_response):
     with patch.object(StreamingLLM, '_call_acompletion') as mock_call_acompletion:
         mock_call_acompletion.return_value.__aiter__.return_value = iter(mock_response)
-        test_llm = StreamingLLM(config=config.get_llm_config())
+        test_llm = _get_llm(StreamingLLM)
         async for chunk in test_llm.async_streaming_completion(
             messages=[{'role': 'user', 'content': 'Hello!'}], stream=True
         ):
@@ -109,7 +127,7 @@ async def mock_acompletion(*args, **kwargs):
         AsyncLLM, '_call_acompletion', new_callable=AsyncMock
     ) as mock_call_acompletion:
         mock_call_acompletion.side_effect = mock_acompletion
-        test_llm = AsyncLLM(config=config.get_llm_config())
+        test_llm = _get_llm(AsyncLLM)
 
         async def cancel_after_delay():
             print(f'Starting cancel_after_delay with delay {cancel_delay}')
@@ -171,7 +189,7 @@ async def mock_acompletion(*args, **kwargs):
         AsyncLLM, '_call_acompletion', new_callable=AsyncMock
     ) as mock_call_acompletion:
         mock_call_acompletion.return_value = mock_acompletion()
-        test_llm = StreamingLLM(config=config.get_llm_config())
+        test_llm = _get_llm(StreamingLLM)
 
         received_chunks = []
         with pytest.raises(UserCancelledError):
tests/unit/test_manager.py (4 changes: 2 additions & 2 deletions)
@@ -60,7 +60,7 @@ async def test_session_is_running_in_cluster():
         )
     )
     with (
-        patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.05),
+        patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
     ):
         async with SessionManager(
             sio, AppConfig(), InMemoryFileStore()
@@ -87,7 +87,7 @@ async def test_init_new_local_session():
     is_session_running_in_cluster_mock.return_value = False
     with (
         patch('openhands.server.session.manager.Session', mock_session),
-        patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.01),
+        patch('openhands.server.session.manager._REDIS_POLL_TIMEOUT', 0.1),
         patch(
             'openhands.server.session.manager.SessionManager._redis_subscribe',
             AsyncMock(),
