Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor: Don't serialize matching events when searching event stream #6509

Merged
merged 2 commits into from
Jan 28, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions openhands/events/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -384,7 +384,7 @@ def get_matching_events(
start_id: int = 0,
limit: int = 100,
reverse: bool = False,
) -> list:
) -> list[type[Event]]:
"""Get matching events from the event stream based on filters.

Args:
Expand Down Expand Up @@ -414,7 +414,7 @@ def get_matching_events(
):
continue

matching_events.append(event_to_dict(event))
matching_events.append(event)

# Stop if we have enough events
if len(matching_events) >= limit:
Expand Down
4 changes: 3 additions & 1 deletion openhands/server/routes/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from openhands.core.logger import openhands_logger as logger
from openhands.events.event import Event
from openhands.events.serialization.event import event_from_dict
from openhands.events.serialization.event import event_from_dict, event_to_dict
from openhands.runtime.base import Runtime

app = APIRouter(prefix='/api/conversations/{conversation_id}')
Expand Down Expand Up @@ -156,6 +156,8 @@ async def search_events(
has_more = len(matching_events) > limit
if has_more:
matching_events = matching_events[:limit] # Remove the extra event

matching_events = [event_to_dict(event) for event in matching_events]
return {
'events': matching_events,
'has_more': has_more,
Expand Down
25 changes: 17 additions & 8 deletions tests/unit/test_event_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from pytest import TempPathFactory

from openhands.core.schema.observation import ObservationType
from openhands.events import EventSource, EventStream
from openhands.events.action import (
NullAction,
Expand Down Expand Up @@ -78,20 +79,23 @@ def test_get_matching_events_type_filter(temp_dir: str):
# Filter by NullAction
events = event_stream.get_matching_events(event_types=(NullAction,))
assert len(events) == 2
assert all(e['action'] == 'null' for e in events)
assert all(isinstance(e, NullAction) for e in events)

# Filter by NullObservation
events = event_stream.get_matching_events(event_types=(NullObservation,))
assert len(events) == 1
assert events[0]['observation'] == 'null'
assert (
isinstance(events[0], NullObservation)
and events[0].observation == ObservationType.NULL
)

# Filter by NullAction and MessageAction
events = event_stream.get_matching_events(event_types=(NullAction, MessageAction))
assert len(events) == 3

# Filter in reverse
events = event_stream.get_matching_events(reverse=True, limit=1)
assert events[0]['message'] == 'test'
assert isinstance(events[0], MessageAction) and events[0].content == 'test'


def test_get_matching_events_query_search(temp_dir: str):
Expand Down Expand Up @@ -126,12 +130,17 @@ def test_get_matching_events_source_filter(temp_dir: str):
# Filter by AGENT source
events = event_stream.get_matching_events(source='agent')
assert len(events) == 2
assert all(e['source'] == 'agent' for e in events)
assert all(
isinstance(e, NullObservation) and e.source == EventSource.AGENT for e in events
)

# Filter by ENVIRONMENT source
events = event_stream.get_matching_events(source='environment')
assert len(events) == 1
assert events[0]['source'] == 'environment'
assert (
isinstance(events[0], NullObservation)
and events[0].source == EventSource.ENVIRONMENT
)


def test_get_matching_events_pagination(temp_dir: str):
Expand All @@ -149,13 +158,13 @@ def test_get_matching_events_pagination(temp_dir: str):
# Test start_id
events = event_stream.get_matching_events(start_id=2)
assert len(events) == 3
assert events[0]['content'] == 'test2'
assert isinstance(events[0], NullObservation) and events[0].content == 'test2'

# Test combination of start_id and limit
events = event_stream.get_matching_events(start_id=1, limit=2)
assert len(events) == 2
assert events[0]['content'] == 'test1'
assert events[1]['content'] == 'test2'
assert isinstance(events[0], NullObservation) and events[0].content == 'test1'
assert isinstance(events[1], NullObservation) and events[1].content == 'test2'


def test_get_matching_events_limit_validation(temp_dir: str):
Expand Down