Skip to content

Commit

Permalink
Refactor: Don't serialize matching events when searching event stream (
Browse files Browse the repository at this point in the history
  • Loading branch information
malhotra5 authored Jan 28, 2025
1 parent 3534606 commit eb760f3
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 11 deletions.
4 changes: 2 additions & 2 deletions openhands/events/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def get_matching_events(
start_id: int = 0,
limit: int = 100,
reverse: bool = False,
) -> list:
) -> list[type[Event]]:
"""Get matching events from the event stream based on filters.
Args:
Expand Down Expand Up @@ -414,7 +414,7 @@ def get_matching_events(
):
continue

matching_events.append(event_to_dict(event))
matching_events.append(event)

# Stop if we have enough events
if len(matching_events) >= limit:
Expand Down
4 changes: 3 additions & 1 deletion openhands/server/routes/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from openhands.core.logger import openhands_logger as logger
from openhands.events.event import Event
from openhands.events.serialization.event import event_from_dict
from openhands.events.serialization.event import event_from_dict, event_to_dict
from openhands.runtime.base import Runtime

app = APIRouter(prefix='/api/conversations/{conversation_id}')
Expand Down Expand Up @@ -156,6 +156,8 @@ async def search_events(
has_more = len(matching_events) > limit
if has_more:
matching_events = matching_events[:limit] # Remove the extra event

matching_events = [event_to_dict(event) for event in matching_events]
return {
'events': matching_events,
'has_more': has_more,
Expand Down
25 changes: 17 additions & 8 deletions tests/unit/test_event_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from pytest import TempPathFactory

from openhands.core.schema.observation import ObservationType
from openhands.events import EventSource, EventStream
from openhands.events.action import (
NullAction,
Expand Down Expand Up @@ -78,20 +79,23 @@ def test_get_matching_events_type_filter(temp_dir: str):
# Filter by NullAction
events = event_stream.get_matching_events(event_types=(NullAction,))
assert len(events) == 2
assert all(e['action'] == 'null' for e in events)
assert all(isinstance(e, NullAction) for e in events)

# Filter by NullObservation
events = event_stream.get_matching_events(event_types=(NullObservation,))
assert len(events) == 1
assert events[0]['observation'] == 'null'
assert (
isinstance(events[0], NullObservation)
and events[0].observation == ObservationType.NULL
)

# Filter by NullAction and MessageAction
events = event_stream.get_matching_events(event_types=(NullAction, MessageAction))
assert len(events) == 3

# Filter in reverse
events = event_stream.get_matching_events(reverse=True, limit=1)
assert events[0]['message'] == 'test'
assert isinstance(events[0], MessageAction) and events[0].content == 'test'


def test_get_matching_events_query_search(temp_dir: str):
Expand Down Expand Up @@ -126,12 +130,17 @@ def test_get_matching_events_source_filter(temp_dir: str):
# Filter by AGENT source
events = event_stream.get_matching_events(source='agent')
assert len(events) == 2
assert all(e['source'] == 'agent' for e in events)
assert all(
isinstance(e, NullObservation) and e.source == EventSource.AGENT for e in events
)

# Filter by ENVIRONMENT source
events = event_stream.get_matching_events(source='environment')
assert len(events) == 1
assert events[0]['source'] == 'environment'
assert (
isinstance(events[0], NullObservation)
and events[0].source == EventSource.ENVIRONMENT
)


def test_get_matching_events_pagination(temp_dir: str):
Expand All @@ -149,13 +158,13 @@ def test_get_matching_events_pagination(temp_dir: str):
# Test start_id
events = event_stream.get_matching_events(start_id=2)
assert len(events) == 3
assert events[0]['content'] == 'test2'
assert isinstance(events[0], NullObservation) and events[0].content == 'test2'

# Test combination of start_id and limit
events = event_stream.get_matching_events(start_id=1, limit=2)
assert len(events) == 2
assert events[0]['content'] == 'test1'
assert events[1]['content'] == 'test2'
assert isinstance(events[0], NullObservation) and events[0].content == 'test1'
assert isinstance(events[1], NullObservation) and events[1].content == 'test2'


def test_get_matching_events_limit_validation(temp_dir: str):
Expand Down

0 comments on commit eb760f3

Please sign in to comment.