Skip to content

Commit

Permalink
fix tests
Browse files Browse the repository at this point in the history
  • Loading branch information
malhotra5 committed Jan 28, 2025
1 parent ea9411a commit 3d0fded
Showing 1 changed file with 17 additions and 8 deletions.
25 changes: 17 additions & 8 deletions tests/unit/test_event_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import pytest
from pytest import TempPathFactory

from openhands.core.schema.observation import ObservationType
from openhands.events import EventSource, EventStream
from openhands.events.action import (
NullAction,
Expand Down Expand Up @@ -78,20 +79,23 @@ def test_get_matching_events_type_filter(temp_dir: str):
# Filter by NullAction
events = event_stream.get_matching_events(event_types=(NullAction,))
assert len(events) == 2
assert all(e['action'] == 'null' for e in events)
assert all(isinstance(e, NullAction) for e in events)

# Filter by NullObservation
events = event_stream.get_matching_events(event_types=(NullObservation,))
assert len(events) == 1
assert events[0]['observation'] == 'null'
assert (
isinstance(events[0], NullObservation)
and events[0].observation == ObservationType.NULL
)

# Filter by NullAction and MessageAction
events = event_stream.get_matching_events(event_types=(NullAction, MessageAction))
assert len(events) == 3

# Filter in reverse
events = event_stream.get_matching_events(reverse=True, limit=1)
assert events[0]['message'] == 'test'
assert isinstance(events[0], MessageAction) and events[0].content == 'test'


def test_get_matching_events_query_search(temp_dir: str):
Expand Down Expand Up @@ -126,12 +130,17 @@ def test_get_matching_events_source_filter(temp_dir: str):
# Filter by AGENT source
events = event_stream.get_matching_events(source='agent')
assert len(events) == 2
assert all(e['source'] == 'agent' for e in events)
assert all(
isinstance(e, NullObservation) and e.source == EventSource.AGENT for e in events
)

# Filter by ENVIRONMENT source
events = event_stream.get_matching_events(source='environment')
assert len(events) == 1
assert events[0]['source'] == 'environment'
assert (
isinstance(events[0], NullObservation)
and events[0].source == EventSource.ENVIRONMENT
)


def test_get_matching_events_pagination(temp_dir: str):
Expand All @@ -149,13 +158,13 @@ def test_get_matching_events_pagination(temp_dir: str):
# Test start_id
events = event_stream.get_matching_events(start_id=2)
assert len(events) == 3
assert events[0]['content'] == 'test2'
assert isinstance(events[0], NullObservation) and events[0].content == 'test2'

# Test combination of start_id and limit
events = event_stream.get_matching_events(start_id=1, limit=2)
assert len(events) == 2
assert events[0]['content'] == 'test1'
assert events[1]['content'] == 'test2'
assert isinstance(events[0], NullObservation) and events[0].content == 'test1'
assert isinstance(events[1], NullObservation) and events[1].content == 'test2'


def test_get_matching_events_limit_validation(temp_dir: str):
Expand Down

0 comments on commit 3d0fded

Please sign in to comment.