From 5b063cc11bb27fab722d20c21ae79bfde72ee2be Mon Sep 17 00:00:00 2001 From: Engel Nyst Date: Sat, 22 Feb 2025 20:18:34 +0100 Subject: [PATCH] add tests --- tests/unit/test_message_utils.py | 120 ++++++++++++++++++++++++++++++- 1 file changed, 118 insertions(+), 2 deletions(-) diff --git a/tests/unit/test_message_utils.py b/tests/unit/test_message_utils.py index d3114519c82b..34c52d861a23 100644 --- a/tests/unit/test_message_utils.py +++ b/tests/unit/test_message_utils.py @@ -3,13 +3,18 @@ import pytest from openhands.core.message import ImageContent, TextContent -from openhands.core.message_utils import get_action_message, get_observation_message +from openhands.core.message_utils import ( + get_action_message, + get_observation_message, + get_single_tokens_usage_for_event, + get_tokens_usage_for_event_id, +) from openhands.events.action import ( AgentFinishAction, CmdRunAction, MessageAction, ) -from openhands.events.event import EventSource, FileEditSource, FileReadSource +from openhands.events.event import Event, EventSource, FileEditSource, FileReadSource from openhands.events.observation.browse import BrowserOutputObservation from openhands.events.observation.commands import ( CmdOutputMetadata, @@ -21,6 +26,7 @@ from openhands.events.observation.files import FileEditObservation, FileReadObservation from openhands.events.observation.reject import UserRejectObservation from openhands.events.tool import ToolCallMetadata +from openhands.llm.metrics import Metrics, TokensUsage def test_cmd_output_observation_message(): @@ -269,3 +275,113 @@ def test_agent_finish_action_with_tool_metadata(): assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) assert 'Initial thought\nTask completed' in result.content[0].text + + +def test_get_single_tokens_usage_for_event(): + """Test that we get the single matching usage record (if any) based on the event's model_response.id.""" + metrics = Metrics(model_name='test-model') + usage_record = TokensUsage( + model='test-model', + prompt_tokens=10, + completion_tokens=5, + cache_read_tokens=2, + cache_write_tokens=1, + response_id='test-response-id', + ) + metrics.add_tokens_usage( + prompt_tokens=usage_record.prompt_tokens, + completion_tokens=usage_record.completion_tokens, + cache_read_tokens=usage_record.cache_read_tokens, + cache_write_tokens=usage_record.cache_write_tokens, + response_id=usage_record.response_id, + ) + + # Create an event referencing that response_id + event = Event() + mock_tool_call_metadata = ToolCallMetadata( + tool_call_id='test-tool-call', + function_name='fake_function', + model_response={'id': 'test-response-id'}, + total_calls_in_response=1, + ) + event._tool_call_metadata = ( + mock_tool_call_metadata # normally you'd do event.tool_call_metadata = ... + ) + + # We should find that usage record + found = get_single_tokens_usage_for_event(event, metrics) + assert found is not None + assert found.prompt_tokens == 10 + assert found.response_id == 'test-response-id' + + # If we change the event's response ID, we won't find anything + mock_tool_call_metadata.model_response['id'] = 'some-other-id' + found2 = get_single_tokens_usage_for_event(event, metrics) + assert found2 is None + + # If the event has no tool_call_metadata, also returns None + event._tool_call_metadata = None + found3 = get_single_tokens_usage_for_event(event, metrics) + assert found3 is None + + +def test_get_tokens_usage_for_event_id(): + """ + Test that we search backward from the event with the given id, + finding the first usage record that matches a response_id in that or previous events. + """ + metrics = Metrics(model_name='test-model') + usage_1 = TokensUsage( + model='test-model', + prompt_tokens=12, + completion_tokens=3, + cache_read_tokens=2, + cache_write_tokens=5, + response_id='resp-1', + ) + usage_2 = TokensUsage( + model='test-model', + prompt_tokens=7, + completion_tokens=2, + cache_read_tokens=1, + cache_write_tokens=3, + response_id='resp-2', + ) + metrics._tokens_usages.append(usage_1) + metrics._tokens_usages.append(usage_2) + + # Build a list of events + events = [] + for i in range(5): + e = Event() + e._id = i + # We'll attach usage_1 to event 1, usage_2 to event 3 + if i == 1: + e._tool_call_metadata = ToolCallMetadata( + tool_call_id='tid1', + function_name='fn1', + model_response={'id': 'resp-1'}, + total_calls_in_response=1, + ) + elif i == 3: + e._tool_call_metadata = ToolCallMetadata( + tool_call_id='tid2', + function_name='fn2', + model_response={'id': 'resp-2'}, + total_calls_in_response=1, + ) + events.append(e) + + # If we ask for event_id=3, we find usage_2 immediately + found_3 = get_tokens_usage_for_event_id(events, 3, metrics) + assert found_3 is not None + assert found_3.response_id == 'resp-2' + + # If we ask for event_id=2, no usage in event2, so we check event1 -> usage_1 found + found_2 = get_tokens_usage_for_event_id(events, 2, metrics) + assert found_2 is not None + assert found_2.response_id == 'resp-1' + + # If we ask for event_id=0, no usage in event0 or earlier, so return None + found_0 = get_tokens_usage_for_event_id(events, 0, metrics) + assert found_0 is None