Skip to content

Commit

Permalink
Feat: Ability to filter events by multiple types (All-Hands-AI#6484)
Browse files Browse the repository at this point in the history
  • Loading branch information
malhotra5 authored and idagelic committed Feb 12, 2025
1 parent 115c000 commit 271e7a3
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 27 deletions.
32 changes: 16 additions & 16 deletions openhands/events/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def _should_filter_event(
self,
event,
query: str | None = None,
event_type: type[Event] | None = None,
event_types: tuple[type[Event], ...] | None = None,
source: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
Expand All @@ -328,16 +328,16 @@ def _should_filter_event(
Args:
event: The event to check
query (str, optional): Text to search for in event content
event_type (type[Event], optional): Filter by event type class (e.g., FileReadAction)
source (str, optional): Filter by event source
start_date (str, optional): Filter events after this date (ISO format)
end_date (str, optional): Filter events before this date (ISO format)
query: Text to search for in event content
event_type: Filter by event type classes (e.g., (FileReadAction, ) ).
source: Filter by event source
start_date: Filter events after this date (ISO format)
end_date: Filter events before this date (ISO format)
Returns:
bool: True if the event should be filtered out, False if it matches all criteria
"""
if event_type and not isinstance(event, event_type):
if event_types and not isinstance(event, event_types):
return True

if source and not event.source.value == source:
Expand All @@ -361,7 +361,7 @@ def _should_filter_event(
def get_matching_events(
self,
query: str | None = None,
event_type: type[Event] | None = None,
event_types: tuple[type[Event], ...] | None = None,
source: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
Expand All @@ -371,13 +371,13 @@ def get_matching_events(
"""Get matching events from the event stream based on filters.
Args:
query (str, optional): Text to search for in event content
event_type (type[Event], optional): Filter by event type class (e.g., FileReadAction)
source (str, optional): Filter by event source
start_date (str, optional): Filter events after this date (ISO format)
end_date (str, optional): Filter events before this date (ISO format)
start_id (int): Starting ID in the event stream. Defaults to 0
limit (int): Maximum number of events to return. Must be between 1 and 100. Defaults to 100
query: Text to search for in event content
event_types: Filter by event type classes (e.g., (FileReadAction, ) ).
source: Filter by event source
start_date: Filter events after this date (ISO format)
end_date: Filter events before this date (ISO format)
start_id: Starting ID in the event stream. Defaults to 0
limit: Maximum number of events to return. Must be between 1 and 100. Defaults to 100
Returns:
list: List of matching events (as dicts)
Expand All @@ -392,7 +392,7 @@ def get_matching_events(

for event in self.get_events(start_id=start_id):
if self._should_filter_event(
event, query, event_type, source, start_date, end_date
event, query, event_types, source, start_date, end_date
):
continue

Expand Down
18 changes: 9 additions & 9 deletions openhands/server/routes/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,14 +119,14 @@ async def search_events(
):
"""Search through the event stream with filtering and pagination.
Args:
request (Request): The incoming request object
query (str, optional): Text to search for in event content
start_id (int): Starting ID in the event stream. Defaults to 0
limit (int): Maximum number of events to return. Must be between 1 and 100. Defaults to 20
event_type (str, optional): Filter by event type (e.g., "FileReadAction")
source (str, optional): Filter by event source
start_date (str, optional): Filter events after this date (ISO format)
end_date (str, optional): Filter events before this date (ISO format)
request: The incoming request object
query: Text to search for in event content
start_id: Starting ID in the event stream. Defaults to 0
limit: Maximum number of events to return. Must be between 1 and 100. Defaults to 20
event_type: Filter by event type (e.g., "FileReadAction")
source: Filter by event source
start_date: Filter events after this date (ISO format)
end_date: Filter events before this date (ISO format)
Returns:
dict: Dictionary containing:
- events: List of matching events
Expand All @@ -145,7 +145,7 @@ async def search_events(
cast_event_type = str_to_event_type(event_type)
matching_events = event_stream.get_matching_events(
query=query,
event_type=cast_event_type,
event_types=(cast_event_type),
source=source,
start_date=start_date,
end_date=end_date,
Expand Down
10 changes: 8 additions & 2 deletions tests/unit/test_event_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from openhands.events.action import (
NullAction,
)
from openhands.events.action.message import MessageAction
from openhands.events.observation import NullObservation
from openhands.storage import get_file_store

Expand Down Expand Up @@ -72,17 +73,22 @@ def test_get_matching_events_type_filter(temp_dir: str):
event_stream.add_event(NullAction(), EventSource.AGENT)
event_stream.add_event(NullObservation('test'), EventSource.AGENT)
event_stream.add_event(NullAction(), EventSource.AGENT)
event_stream.add_event(MessageAction(content='test'), EventSource.AGENT)

# Filter by NullAction
events = event_stream.get_matching_events(event_type=NullAction)
events = event_stream.get_matching_events(event_types=(NullAction,))
assert len(events) == 2
assert all(e['action'] == 'null' for e in events)

# Filter by NullObservation
events = event_stream.get_matching_events(event_type=NullObservation)
events = event_stream.get_matching_events(event_types=(NullObservation,))
assert len(events) == 1
assert events[0]['observation'] == 'null'

# Filter by NullAction and MessageAction
events = event_stream.get_matching_events(event_types=(NullAction, MessageAction))
assert len(events) == 3


def test_get_matching_events_query_search(temp_dir: str):
file_store = get_file_store('local', temp_dir)
Expand Down

0 comments on commit 271e7a3

Please sign in to comment.