diff --git a/openhands/events/stream.py b/openhands/events/stream.py index 8a6cdfb35ae2..ca1025d0f163 100644 --- a/openhands/events/stream.py +++ b/openhands/events/stream.py @@ -384,7 +384,7 @@ def get_matching_events( start_id: int = 0, limit: int = 100, reverse: bool = False, - ) -> list: + ) -> list[type[Event]]: """Get matching events from the event stream based on filters. Args: @@ -414,7 +414,7 @@ def get_matching_events( ): continue - matching_events.append(event_to_dict(event)) + matching_events.append(event) # Stop if we have enough events if len(matching_events) >= limit: diff --git a/openhands/server/routes/conversation.py b/openhands/server/routes/conversation.py index c0cef6c4ae01..b91c8070b56f 100644 --- a/openhands/server/routes/conversation.py +++ b/openhands/server/routes/conversation.py @@ -3,7 +3,7 @@ from openhands.core.logger import openhands_logger as logger from openhands.events.event import Event -from openhands.events.serialization.event import event_from_dict +from openhands.events.serialization.event import event_from_dict, event_to_dict from openhands.runtime.base import Runtime app = APIRouter(prefix='/api/conversations/{conversation_id}') @@ -156,6 +156,8 @@ async def search_events( has_more = len(matching_events) > limit if has_more: matching_events = matching_events[:limit] # Remove the extra event + + matching_events = [event_to_dict(event) for event in matching_events] return { 'events': matching_events, 'has_more': has_more, diff --git a/tests/unit/test_event_stream.py b/tests/unit/test_event_stream.py index f414bc6e2994..4ac4faaa2561 100644 --- a/tests/unit/test_event_stream.py +++ b/tests/unit/test_event_stream.py @@ -3,6 +3,7 @@ import pytest from pytest import TempPathFactory +from openhands.core.schema.observation import ObservationType from openhands.events import EventSource, EventStream from openhands.events.action import ( NullAction, @@ -78,12 +79,15 @@ def test_get_matching_events_type_filter(temp_dir: str): # Filter by NullAction events = event_stream.get_matching_events(event_types=(NullAction,)) assert len(events) == 2 - assert all(e['action'] == 'null' for e in events) + assert all(isinstance(e, NullAction) for e in events) # Filter by NullObservation events = event_stream.get_matching_events(event_types=(NullObservation,)) assert len(events) == 1 - assert events[0]['observation'] == 'null' + assert ( + isinstance(events[0], NullObservation) + and events[0].observation == ObservationType.NULL + ) # Filter by NullAction and MessageAction events = event_stream.get_matching_events(event_types=(NullAction, MessageAction)) @@ -91,7 +95,7 @@ def test_get_matching_events_type_filter(temp_dir: str): # Filter in reverse events = event_stream.get_matching_events(reverse=True, limit=1) - assert events[0]['message'] == 'test' + assert isinstance(events[0], MessageAction) and events[0].content == 'test' def test_get_matching_events_query_search(temp_dir: str): @@ -126,12 +130,17 @@ def test_get_matching_events_source_filter(temp_dir: str): # Filter by AGENT source events = event_stream.get_matching_events(source='agent') assert len(events) == 2 - assert all(e['source'] == 'agent' for e in events) + assert all( + isinstance(e, NullObservation) and e.source == EventSource.AGENT for e in events + ) # Filter by ENVIRONMENT source events = event_stream.get_matching_events(source='environment') assert len(events) == 1 - assert events[0]['source'] == 'environment' + assert ( + isinstance(events[0], NullObservation) + and events[0].source == EventSource.ENVIRONMENT + ) def test_get_matching_events_pagination(temp_dir: str): @@ -149,13 +158,13 @@ def test_get_matching_events_pagination(temp_dir: str): # Test start_id events = event_stream.get_matching_events(start_id=2) assert len(events) == 3 - assert events[0]['content'] == 'test2' + assert isinstance(events[0], NullObservation) and events[0].content == 'test2' # Test combination of start_id and limit events = event_stream.get_matching_events(start_id=1, limit=2) assert len(events) == 2 - assert events[0]['content'] == 'test1' - assert events[1]['content'] == 'test2' + assert isinstance(events[0], NullObservation) and events[0].content == 'test1' + assert isinstance(events[1], NullObservation) and events[1].content == 'test2' def test_get_matching_events_limit_validation(temp_dir: str):