From e7b260fda91b3ecc6f02249d702b8f0d6886dc31 Mon Sep 17 00:00:00 2001 From: openhands Date: Mon, 24 Feb 2025 01:12:02 +0000 Subject: [PATCH] Apply ruff formatting --- openhands/core/config/condenser_config.py | 14 +- openhands/core/message_utils.py | 45 +++++++ openhands/llm/llm.py | 29 +++-- openhands/llm/metrics.py | 56 +++++++- .../condenser/impl/llm_attention_condenser.py | 2 +- .../condenser/impl/recent_events_condenser.py | 2 +- openhands/server/routes/github.py | 2 - openhands/server/routes/settings.py | 8 +- tests/unit/test_condenser.py | 55 ++++++-- tests/unit/test_llm.py | 60 +++++++++ tests/unit/test_message_utils.py | 120 +++++++++++++++++- tests/unit/test_micro_agents.py | 40 ++++-- 12 files changed, 381 insertions(+), 52 deletions(-) diff --git a/openhands/core/config/condenser_config.py b/openhands/core/config/condenser_config.py index 926bd1f383a6..aca1d090d79e 100644 --- a/openhands/core/config/condenser_config.py +++ b/openhands/core/config/condenser_config.py @@ -26,8 +26,10 @@ class RecentEventsCondenserConfig(BaseModel): """Configuration for RecentEventsCondenser.""" type: Literal['recent'] = Field('recent') + + # at least one event by default, because the best guess is that it is the user task keep_first: int = Field( - default=0, + default=1, description='The number of initial events to condense.', ge=0, ) @@ -43,6 +45,8 @@ class LLMSummarizingCondenserConfig(BaseModel): llm_config: LLMConfig = Field( ..., description='Configuration for the LLM to use for condensing.' ) + + # at least one event by default, because the best guess is that it's the user task keep_first: int = Field( default=1, description='The number of initial events to condense.', @@ -62,8 +66,10 @@ class AmortizedForgettingCondenserConfig(BaseModel): description='Maximum size of the condensed history before triggering forgetting.', ge=2, ) + + # at least one event by default, because the best guess is that it's the user task keep_first: int = Field( - default=0, + default=1, description='Number of initial events to always keep in history.', ge=0, ) @@ -81,8 +87,10 @@ class LLMAttentionCondenserConfig(BaseModel): description='Maximum size of the condensed history before triggering forgetting.', ge=2, ) + + # at least one event by default, because the best guess is that it's the user task keep_first: int = Field( - default=0, + default=1, description='Number of initial events to always keep in history.', ge=0, ) diff --git a/openhands/core/message_utils.py b/openhands/core/message_utils.py index 7683c7c4453c..1ce4b4f84b81 100644 --- a/openhands/core/message_utils.py +++ b/openhands/core/message_utils.py @@ -29,6 +29,7 @@ from openhands.events.observation.error import ErrorObservation from openhands.events.observation.observation import Observation from openhands.events.serialization.event import truncate_content +from openhands.llm.metrics import Metrics, TokenUsage def events_to_messages( @@ -362,3 +363,47 @@ def apply_prompt_caching(messages: list[Message]) -> None: -1 ].cache_prompt = True # Last item inside the message content break + + +def get_token_usage_for_event(event: Event, metrics: Metrics) -> TokenUsage | None: + """ + Returns at most one token usage record for the `model_response.id` in this event's + `tool_call_metadata`. + + If no response_id is found, or none match in metrics.token_usages, returns None. 
+ """ + if event.tool_call_metadata and event.tool_call_metadata.model_response: + response_id = event.tool_call_metadata.model_response.get('id') + if response_id: + return next( + ( + usage + for usage in metrics.token_usages + if usage.response_id == response_id + ), + None, + ) + return None + + +def get_token_usage_for_event_id( + events: list[Event], event_id: int, metrics: Metrics +) -> TokenUsage | None: + """ + Starting from the event with .id == event_id and moving backwards in `events`, + find the first TokenUsage record (if any) associated with a response_id from + tool_call_metadata.model_response.id. + + Returns the first match found, or None if none is found. + """ + # find the index of the event with the given id + idx = next((i for i, e in enumerate(events) if e.id == event_id), None) + if idx is None: + return None + + # search backward from idx down to 0 + for i in range(idx, -1, -1): + usage = get_token_usage_for_event(events[i], metrics) + if usage is not None: + return usage + return None diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py index b40f11ca8396..66bc6f99cb09 100644 --- a/openhands/llm/llm.py +++ b/openhands/llm/llm.py @@ -497,20 +497,21 @@ def _post_completion(self, response: ModelResponse) -> float: stats += 'Response Latency: %.3f seconds\n' % latest_latency.latency usage: Usage | None = response.get('usage') + response_id = response.get('id', 'unknown') if usage: # keep track of the input and output tokens - input_tokens = usage.get('prompt_tokens') - output_tokens = usage.get('completion_tokens') + prompt_tokens = usage.get('prompt_tokens', 0) + completion_tokens = usage.get('completion_tokens', 0) - if input_tokens: - stats += 'Input tokens: ' + str(input_tokens) + if prompt_tokens: + stats += 'Input tokens: ' + str(prompt_tokens) - if output_tokens: + if completion_tokens: stats += ( - (' | ' if input_tokens else '') + (' | ' if prompt_tokens else '') + 'Output tokens: ' - + str(output_tokens) + + str(completion_tokens) + '\n' ) @@ -519,7 +520,7 @@ def _post_completion(self, response: ModelResponse) -> float: 'prompt_tokens_details' ) cache_hit_tokens = ( - prompt_tokens_details.cached_tokens if prompt_tokens_details else None + prompt_tokens_details.cached_tokens if prompt_tokens_details else 0 ) if cache_hit_tokens: stats += 'Input tokens (cache hit): ' + str(cache_hit_tokens) + '\n' @@ -528,10 +529,20 @@ def _post_completion(self, response: ModelResponse) -> float: # but litellm doesn't separate them in the usage stats # so we can read it from the provider-specific extra field model_extra = usage.get('model_extra', {}) - cache_write_tokens = model_extra.get('cache_creation_input_tokens') + cache_write_tokens = model_extra.get('cache_creation_input_tokens', 0) if cache_write_tokens: stats += 'Input tokens (cache write): ' + str(cache_write_tokens) + '\n' + # Record in metrics + # We'll treat cache_hit_tokens as "cache read" and cache_write_tokens as "cache write" + self.metrics.add_token_usage( + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + cache_read_tokens=cache_hit_tokens, + cache_write_tokens=cache_write_tokens, + response_id=response_id, + ) + # log the stats if stats: logger.debug(stats) diff --git a/openhands/llm/metrics.py b/openhands/llm/metrics.py index a010bb26916d..a5ec0efd7522 100644 --- a/openhands/llm/metrics.py +++ b/openhands/llm/metrics.py @@ -17,11 +17,23 @@ class ResponseLatency(BaseModel): response_id: str +class TokenUsage(BaseModel): + """Metric tracking detailed token usage per completion 
call.""" + + model: str + prompt_tokens: int + completion_tokens: int + cache_read_tokens: int + cache_write_tokens: int + response_id: str + + class Metrics: """Metrics class can record various metrics during running and evaluation. - Currently, we define the following metrics: - accumulated_cost: the total cost (USD $) of the current LLM. - response_latency: the time taken for each LLM completion call. + We track: + - accumulated_cost and costs + - A list of ResponseLatency + - A list of TokenUsage (one per call). """ def __init__(self, model_name: str = 'default') -> None: @@ -29,6 +41,7 @@ def __init__(self, model_name: str = 'default') -> None: self._costs: list[Cost] = [] self._response_latencies: list[ResponseLatency] = [] self.model_name = model_name + self._token_usages: list[TokenUsage] = [] @property def accumulated_cost(self) -> float: @@ -54,6 +67,16 @@ def response_latencies(self) -> list[ResponseLatency]: def response_latencies(self, value: list[ResponseLatency]) -> None: self._response_latencies = value + @property + def token_usages(self) -> list[TokenUsage]: + if not hasattr(self, '_token_usages'): + self._token_usages = [] + return self._token_usages + + @token_usages.setter + def token_usages(self, value: list[TokenUsage]) -> None: + self._token_usages = value + def add_cost(self, value: float) -> None: if value < 0: raise ValueError('Added cost cannot be negative.') @@ -67,10 +90,33 @@ def add_response_latency(self, value: float, response_id: str) -> None: ) ) + def add_token_usage( + self, + prompt_tokens: int, + completion_tokens: int, + cache_read_tokens: int, + cache_write_tokens: int, + response_id: str, + ) -> None: + """Add a single usage record.""" + self._token_usages.append( + TokenUsage( + model=self.model_name, + prompt_tokens=prompt_tokens, + completion_tokens=completion_tokens, + cache_read_tokens=cache_read_tokens, + cache_write_tokens=cache_write_tokens, + response_id=response_id, + ) + ) + def merge(self, other: 'Metrics') -> None: + """Merge 'other' metrics into this one.""" self._accumulated_cost += other.accumulated_cost self._costs += other._costs - self._response_latencies += other._response_latencies + # use the property so older picked objects that lack the field won't crash + self.token_usages += other.token_usages + self.response_latencies += other.response_latencies def get(self) -> dict: """Return the metrics in a dictionary.""" @@ -80,12 +126,14 @@ def get(self) -> dict: 'response_latencies': [ latency.model_dump() for latency in self._response_latencies ], + 'token_usages': [usage.model_dump() for usage in self._token_usages], } def reset(self): self._accumulated_cost = 0.0 self._costs = [] self._response_latencies = [] + self._token_usages = [] def log(self): """Log the metrics.""" diff --git a/openhands/memory/condenser/impl/llm_attention_condenser.py b/openhands/memory/condenser/impl/llm_attention_condenser.py index 98d0455283ce..9a638c071d18 100644 --- a/openhands/memory/condenser/impl/llm_attention_condenser.py +++ b/openhands/memory/condenser/impl/llm_attention_condenser.py @@ -18,7 +18,7 @@ class ImportantEventSelection(BaseModel): class LLMAttentionCondenser(RollingCondenser): """Rolling condenser strategy that uses an LLM to select the most important events when condensing the history.""" - def __init__(self, llm: LLM, max_size: int = 100, keep_first: int = 0): + def __init__(self, llm: LLM, max_size: int = 100, keep_first: int = 1): if keep_first >= max_size // 2: raise ValueError( f'keep_first ({keep_first}) must be less than 
half of max_size ({max_size})' diff --git a/openhands/memory/condenser/impl/recent_events_condenser.py b/openhands/memory/condenser/impl/recent_events_condenser.py index 2ccfd409f35c..a7790483637f 100644 --- a/openhands/memory/condenser/impl/recent_events_condenser.py +++ b/openhands/memory/condenser/impl/recent_events_condenser.py @@ -8,7 +8,7 @@ class RecentEventsCondenser(Condenser): """A condenser that only keeps a certain number of the most recent events.""" - def __init__(self, keep_first: int = 0, max_events: int = 10): + def __init__(self, keep_first: int = 1, max_events: int = 10): self.keep_first = keep_first self.max_events = max_events diff --git a/openhands/server/routes/github.py b/openhands/server/routes/github.py index b918515ce05f..c6017410f9dd 100644 --- a/openhands/server/routes/github.py +++ b/openhands/server/routes/github.py @@ -1,5 +1,3 @@ -from typing import Union - from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse from pydantic import SecretStr diff --git a/openhands/server/routes/settings.py b/openhands/server/routes/settings.py index 0681611def80..8f97599a1348 100644 --- a/openhands/server/routes/settings.py +++ b/openhands/server/routes/settings.py @@ -1,8 +1,5 @@ -from typing import Annotated, cast - -from fastapi import APIRouter, Depends, Request, Response, status +from fastapi import APIRouter, Request, Response, status from fastapi.responses import JSONResponse -from fastapi.routing import APIRoute from pydantic import SecretStr from openhands.core.logger import openhands_logger as logger @@ -144,9 +141,6 @@ async def store_settings( ) - - - def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings: """Convert POSTSettingsModel to Settings. diff --git a/tests/unit/test_condenser.py b/tests/unit/test_condenser.py index fd1e922a103a..99561ae63c7f 100644 --- a/tests/unit/test_condenser.py +++ b/tests/unit/test_condenser.py @@ -38,7 +38,7 @@ def create_test_event( event = Event() event._message = message event.timestamp = timestamp if timestamp else datetime.now() - if id: + if id is not None: event._id = id event._source = EventSource.USER return event @@ -186,13 +186,14 @@ def test_recent_events_condenser(): assert result == events # If the max_events are smaller than the number of events, only keep the last few. - max_events = 2 + max_events = 3 condenser = RecentEventsCondenser(max_events=max_events) result = condenser.condensed_history(mock_state) assert len(result) == max_events - assert result[0]._message == 'Event 4' - assert result[1]._message == 'Event 5' + assert result[0]._message == 'Event 1' # kept from keep_first + assert result[1]._message == 'Event 4' # kept from max_events + assert result[2]._message == 'Event 5' # kept from max_events # If the keep_first flag is set, the first event will always be present. 
keep_first = 1 @@ -211,9 +212,9 @@ def test_recent_events_condenser(): result = condenser.condensed_history(mock_state) assert len(result) == max_events - assert result[0]._message == 'Event 1' - assert result[1]._message == 'Event 2' - assert result[2]._message == 'Event 5' + assert result[0]._message == 'Event 1' # kept from keep_first + assert result[1]._message == 'Event 2' # kept from keep_first + assert result[2]._message == 'Event 5' # kept from max_events def test_llm_summarization_condenser_from_config(): @@ -539,7 +540,7 @@ def test_llm_attention_condenser_forgets_when_larger_than_max_size( ): """Test that the LLMAttentionCondenser forgets events when the context grows too large.""" max_size = 2 - condenser = LLMAttentionCondenser(max_size=max_size, llm=mock_llm) + condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) for i in range(max_size * 10): event = create_test_event(f'Event {i}', id=i) @@ -560,7 +561,7 @@ def test_llm_attention_condenser_forgets_when_larger_than_max_size( def test_llm_attention_condenser_handles_events_outside_history(mock_llm, mock_state): """Test that the LLMAttentionCondenser handles event IDs that aren't from the event history.""" max_size = 2 - condenser = LLMAttentionCondenser(max_size=max_size, llm=mock_llm) + condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) for i in range(max_size * 10): event = create_test_event(f'Event {i}', id=i) @@ -580,7 +581,7 @@ def test_llm_attention_condenser_handles_events_outside_history(mock_llm, mock_s def test_llm_attention_condenser_handles_too_many_events(mock_llm, mock_state): """Test that the LLMAttentionCondenser handles when the response contains too many event IDs.""" max_size = 2 - condenser = LLMAttentionCondenser(max_size=max_size, llm=mock_llm) + condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) for i in range(max_size * 10): event = create_test_event(f'Event {i}', id=i) @@ -600,7 +601,9 @@ def test_llm_attention_condenser_handles_too_many_events(mock_llm, mock_state): def test_llm_attention_condenser_handles_too_few_events(mock_llm, mock_state): """Test that the LLMAttentionCondenser handles when the response contains too few event IDs.""" max_size = 2 - condenser = LLMAttentionCondenser(max_size=max_size, llm=mock_llm) + # Developer note: We must specify keep_first=0 because + # keep_first (1) >= max_size//2 (1) is invalid. + condenser = LLMAttentionCondenser(max_size=max_size, keep_first=0, llm=mock_llm) for i in range(max_size * 10): event = create_test_event(f'Event {i}', id=i) @@ -614,3 +617,33 @@ def test_llm_attention_condenser_handles_too_few_events(mock_llm, mock_state): # The number of results should bounce back and forth between 1, 2, 1, 2, ... assert len(results) == (i % 2) + 1 + + # Add a new test verifying that keep_first=1 works with max_size > 2 + + +def test_llm_attention_condenser_handles_keep_first_for_larger_max_size( + mock_llm, mock_state +): + """Test that LLMAttentionCondenser works when keep_first=1 is allowed (must be less than half of max_size).""" + max_size = 4 # so keep_first=1 < (max_size // 2) = 2 + condenser = LLMAttentionCondenser(max_size=max_size, keep_first=1, llm=mock_llm) + + for i in range(max_size * 2): + # We append new events, then ensure some are pruned. 
+ event = create_test_event(f'Event {i}', id=i) + mock_state.history.append(event) + + mock_llm.set_mock_response_content( + ImportantEventSelection(ids=[]).model_dump_json() + ) + + results = condenser.condensed_history(mock_state) + + # We expect that the first event is always kept, and the tail grows until max_size + if len(mock_state.history) <= max_size: + # No condensation needed yet + assert len(results) == len(mock_state.history) + else: + # The first event is kept, plus some from the tail + assert results[0].id == 0 + assert len(results) <= max_size diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py index 1bfee8550698..0ec7fe252192 100644 --- a/tests/unit/test_llm.py +++ b/tests/unit/test_llm.py @@ -2,6 +2,7 @@ from unittest.mock import MagicMock, patch import pytest +from litellm import PromptTokensDetails from litellm.exceptions import ( RateLimitError, ) @@ -429,3 +430,62 @@ def test_get_token_count_error_handling( mock_logger.error.assert_called_once_with( 'Error getting token count for\n model gpt-4o\nToken counting failed' ) + + +@patch('openhands.llm.llm.litellm_completion') +def test_llm_token_usage(mock_litellm_completion, default_config): + # This mock response includes usage details with prompt_tokens, + # completion_tokens, prompt_tokens_details.cached_tokens, and model_extra.cache_creation_input_tokens + mock_response_1 = { + 'id': 'test-response-usage', + 'choices': [{'message': {'content': 'Usage test response'}}], + 'usage': { + 'prompt_tokens': 12, + 'completion_tokens': 3, + 'prompt_tokens_details': PromptTokensDetails(cached_tokens=2), + 'model_extra': {'cache_creation_input_tokens': 5}, + }, + } + + # Create a second usage scenario to test accumulation and a different response_id + mock_response_2 = { + 'id': 'test-response-usage-2', + 'choices': [{'message': {'content': 'Second usage test response'}}], + 'usage': { + 'prompt_tokens': 7, + 'completion_tokens': 2, + 'prompt_tokens_details': PromptTokensDetails(cached_tokens=1), + 'model_extra': {'cache_creation_input_tokens': 3}, + }, + } + + # We'll make mock_litellm_completion return these responses in sequence + mock_litellm_completion.side_effect = [mock_response_1, mock_response_2] + + llm = LLM(config=default_config) + + # First call + llm.completion(messages=[{'role': 'user', 'content': 'Hello usage!'}]) + + # Verify we have exactly one usage record after first call + token_usage_list = llm.metrics.get()['token_usages'] + assert len(token_usage_list) == 1 + usage_entry_1 = token_usage_list[0] + assert usage_entry_1['prompt_tokens'] == 12 + assert usage_entry_1['completion_tokens'] == 3 + assert usage_entry_1['cache_read_tokens'] == 2 + assert usage_entry_1['cache_write_tokens'] == 5 + assert usage_entry_1['response_id'] == 'test-response-usage' + + # Second call + llm.completion(messages=[{'role': 'user', 'content': 'Hello again!'}]) + + # Now we expect two usage records total + token_usage_list = llm.metrics.get()['token_usages'] + assert len(token_usage_list) == 2 + usage_entry_2 = token_usage_list[-1] + assert usage_entry_2['prompt_tokens'] == 7 + assert usage_entry_2['completion_tokens'] == 2 + assert usage_entry_2['cache_read_tokens'] == 1 + assert usage_entry_2['cache_write_tokens'] == 3 + assert usage_entry_2['response_id'] == 'test-response-usage-2' diff --git a/tests/unit/test_message_utils.py b/tests/unit/test_message_utils.py index d3114519c82b..0f3a189a9cd3 100644 --- a/tests/unit/test_message_utils.py +++ b/tests/unit/test_message_utils.py @@ -3,13 +3,18 @@ import pytest from 
openhands.core.message import ImageContent, TextContent -from openhands.core.message_utils import get_action_message, get_observation_message +from openhands.core.message_utils import ( + get_action_message, + get_observation_message, + get_token_usage_for_event, + get_token_usage_for_event_id, +) from openhands.events.action import ( AgentFinishAction, CmdRunAction, MessageAction, ) -from openhands.events.event import EventSource, FileEditSource, FileReadSource +from openhands.events.event import Event, EventSource, FileEditSource, FileReadSource from openhands.events.observation.browse import BrowserOutputObservation from openhands.events.observation.commands import ( CmdOutputMetadata, @@ -21,6 +26,7 @@ from openhands.events.observation.files import FileEditObservation, FileReadObservation from openhands.events.observation.reject import UserRejectObservation from openhands.events.tool import ToolCallMetadata +from openhands.llm.metrics import Metrics, TokenUsage def test_cmd_output_observation_message(): @@ -269,3 +275,113 @@ def test_agent_finish_action_with_tool_metadata(): assert len(result.content) == 1 assert isinstance(result.content[0], TextContent) assert 'Initial thought\nTask completed' in result.content[0].text + + +def test_get_token_usage_for_event(): + """Test that we get the single matching usage record (if any) based on the event's model_response.id.""" + metrics = Metrics(model_name='test-model') + usage_record = TokenUsage( + model='test-model', + prompt_tokens=10, + completion_tokens=5, + cache_read_tokens=2, + cache_write_tokens=1, + response_id='test-response-id', + ) + metrics.add_token_usage( + prompt_tokens=usage_record.prompt_tokens, + completion_tokens=usage_record.completion_tokens, + cache_read_tokens=usage_record.cache_read_tokens, + cache_write_tokens=usage_record.cache_write_tokens, + response_id=usage_record.response_id, + ) + + # Create an event referencing that response_id + event = Event() + mock_tool_call_metadata = ToolCallMetadata( + tool_call_id='test-tool-call', + function_name='fake_function', + model_response={'id': 'test-response-id'}, + total_calls_in_response=1, + ) + event._tool_call_metadata = ( + mock_tool_call_metadata # normally you'd do event.tool_call_metadata = ... + ) + + # We should find that usage record + found = get_token_usage_for_event(event, metrics) + assert found is not None + assert found.prompt_tokens == 10 + assert found.response_id == 'test-response-id' + + # If we change the event's response ID, we won't find anything + mock_tool_call_metadata.model_response.id = 'some-other-id' + found2 = get_token_usage_for_event(event, metrics) + assert found2 is None + + # If the event has no tool_call_metadata, also returns None + event._tool_call_metadata = None + found3 = get_token_usage_for_event(event, metrics) + assert found3 is None + + +def test_get_token_usage_for_event_id(): + """ + Test that we search backward from the event with the given id, + finding the first usage record that matches a response_id in that or previous events. 
+ """ + metrics = Metrics(model_name='test-model') + usage_1 = TokenUsage( + model='test-model', + prompt_tokens=12, + completion_tokens=3, + cache_read_tokens=2, + cache_write_tokens=5, + response_id='resp-1', + ) + usage_2 = TokenUsage( + model='test-model', + prompt_tokens=7, + completion_tokens=2, + cache_read_tokens=1, + cache_write_tokens=3, + response_id='resp-2', + ) + metrics._token_usages.append(usage_1) + metrics._token_usages.append(usage_2) + + # Build a list of events + events = [] + for i in range(5): + e = Event() + e._id = i + # We'll attach usage_1 to event 1, usage_2 to event 3 + if i == 1: + e._tool_call_metadata = ToolCallMetadata( + tool_call_id='tid1', + function_name='fn1', + model_response={'id': 'resp-1'}, + total_calls_in_response=1, + ) + elif i == 3: + e._tool_call_metadata = ToolCallMetadata( + tool_call_id='tid2', + function_name='fn2', + model_response={'id': 'resp-2'}, + total_calls_in_response=1, + ) + events.append(e) + + # If we ask for event_id=3, we find usage_2 immediately + found_3 = get_token_usage_for_event_id(events, 3, metrics) + assert found_3 is not None + assert found_3.response_id == 'resp-2' + + # If we ask for event_id=2, no usage in event2, so we check event1 -> usage_1 found + found_2 = get_token_usage_for_event_id(events, 2, metrics) + assert found_2 is not None + assert found_2.response_id == 'resp-1' + + # If we ask for event_id=0, no usage in event0 or earlier, so return None + found_0 = get_token_usage_for_event_id(events, 0, metrics) + assert found_0 is None diff --git a/tests/unit/test_micro_agents.py b/tests/unit/test_micro_agents.py index 0158e9f3d25e..7f78df16b183 100644 --- a/tests/unit/test_micro_agents.py +++ b/tests/unit/test_micro_agents.py @@ -10,7 +10,6 @@ from openhands.controller.agent import Agent from openhands.controller.state.state import State from openhands.core.config import AgentConfig -from openhands.core.message import Message, TextContent from openhands.events.action import MessageAction from openhands.events.stream import EventStream from openhands.storage import get_file_store @@ -58,7 +57,12 @@ def test_coder_agent_with_summary(event_stream: EventStream, agent_configs: dict mock_llm = MagicMock() content = json.dumps({'action': 'finish', 'args': {}}) mock_llm.completion.return_value = {'choices': [{'message': {'content': content}}]} - mock_llm.format_messages_for_llm = lambda messages: messages + mock_llm.format_messages_for_llm.return_value = [ + { + 'role': 'user', + 'content': "This is a dummy task. This is a dummy summary about this repo. 
Here's a summary of the codebase, as it relates to this task.", + } + ] coder_agent = Agent.get_cls('CoderAgent')( llm=mock_llm, config=agent_configs['CoderAgent'] @@ -76,10 +80,11 @@ def test_coder_agent_with_summary(event_stream: EventStream, agent_configs: dict mock_llm.completion.assert_called_once() _, kwargs = mock_llm.completion.call_args - message = kwargs['messages'] - assert isinstance(message, Message) - assert len(message.content) == 1 - prompt = message.content[0].text + prompt_element = kwargs['messages'][0]['content'] + if isinstance(prompt_element, dict): + prompt = prompt_element['content'] + else: + prompt = prompt_element assert task in prompt assert "Here's a summary of the codebase, as it relates to this task" in prompt assert summary in prompt @@ -92,7 +97,17 @@ def test_coder_agent_without_summary(event_stream: EventStream, agent_configs: d mock_llm = MagicMock() content = json.dumps({'action': 'finish', 'args': {}}) mock_llm.completion.return_value = {'choices': [{'message': {'content': content}}]} - mock_llm.format_messages_for_llm = lambda messages: messages + mock_llm.format_messages_for_llm.return_value = [ + { + 'role': 'user', + 'content': [ + { + 'type': 'text', + 'text': "This is a dummy task. This is a dummy summary about this repo. Here's a summary of the codebase, as it relates to this task.", + } + ], + } + ] coder_agent = Agent.get_cls('CoderAgent')( llm=mock_llm, config=agent_configs['CoderAgent'] @@ -110,9 +125,10 @@ def test_coder_agent_without_summary(event_stream: EventStream, agent_configs: d mock_llm.completion.assert_called_once() _, kwargs = mock_llm.completion.call_args - message = kwargs['messages'] - assert isinstance(message, Message) - assert len(message.content) == 1 - prompt = message.content[0].text - print(f'\n{prompt}\n') + prompt_element = kwargs['messages'][0]['content'] + if isinstance(prompt_element, dict): + prompt = prompt_element['content'] + else: + prompt = prompt_element + print(f'\n{prompt_element}\n') assert "Here's a summary of the codebase, as it relates to this task" not in prompt
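
Usage note (not part of the patch): a minimal sketch of how the token-usage pieces introduced above fit together. `Metrics.add_token_usage` stores one `TokenUsage` per `response_id` (this is what `LLM._post_completion` now does after every completion call), and `get_token_usage_for_event_id` walks backwards from a given event to the nearest event whose `tool_call_metadata.model_response` id has a matching record. The wiring below mirrors the unit tests in this diff; the literal ids and token counts are illustrative only.

    from openhands.core.message_utils import get_token_usage_for_event_id
    from openhands.events.event import Event
    from openhands.events.tool import ToolCallMetadata
    from openhands.llm.metrics import Metrics

    metrics = Metrics(model_name='test-model')

    # One record per LLM response, keyed by the provider's response id.
    metrics.add_token_usage(
        prompt_tokens=12,
        completion_tokens=3,
        cache_read_tokens=2,
        cache_write_tokens=5,
        response_id='resp-1',
    )

    # An event whose tool_call_metadata points back at that response id,
    # constructed the same way as in tests/unit/test_message_utils.py.
    event = Event()
    event._id = 0
    event._tool_call_metadata = ToolCallMetadata(
        tool_call_id='tid-1',
        function_name='fn-1',
        model_response={'id': 'resp-1'},
        total_calls_in_response=1,
    )

    # Looking up usage for event 0 finds the record added above.
    usage = get_token_usage_for_event_id([event], 0, metrics)
    assert usage is not None
    assert usage.prompt_tokens == 12 and usage.cache_write_tokens == 5

Because the lookup searches backwards through the event list, observations that carry no `tool_call_metadata` of their own still resolve to the usage of the most recent preceding LLM call, which is the behavior exercised by `test_get_token_usage_for_event_id` above.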