From 271e7a396b3c72f6627638ae5b6b5941a14b204c Mon Sep 17 00:00:00 2001 From: Rohit Malhotra Date: Mon, 27 Jan 2025 17:09:16 -0500 Subject: [PATCH] Feat: Ability to filter events by multiple types (#6484) --- openhands/events/stream.py | 32 ++++++++++++------------- openhands/server/routes/conversation.py | 18 +++++++------- tests/unit/test_event_stream.py | 10 ++++++-- 3 files changed, 33 insertions(+), 27 deletions(-) diff --git a/openhands/events/stream.py b/openhands/events/stream.py index 50b24ed84810..c1df06335d01 100644 --- a/openhands/events/stream.py +++ b/openhands/events/stream.py @@ -319,7 +319,7 @@ def _should_filter_event( self, event, query: str | None = None, - event_type: type[Event] | None = None, + event_types: tuple[type[Event], ...] | None = None, source: str | None = None, start_date: str | None = None, end_date: str | None = None, @@ -328,16 +328,16 @@ def _should_filter_event( Args: event: The event to check - query (str, optional): Text to search for in event content - event_type (type[Event], optional): Filter by event type class (e.g., FileReadAction) - source (str, optional): Filter by event source - start_date (str, optional): Filter events after this date (ISO format) - end_date (str, optional): Filter events before this date (ISO format) + query: Text to search for in event content + event_type: Filter by event type classes (e.g., (FileReadAction, ) ). + source: Filter by event source + start_date: Filter events after this date (ISO format) + end_date: Filter events before this date (ISO format) Returns: bool: True if the event should be filtered out, False if it matches all criteria """ - if event_type and not isinstance(event, event_type): + if event_types and not isinstance(event, event_types): return True if source and not event.source.value == source: @@ -361,7 +361,7 @@ def _should_filter_event( def get_matching_events( self, query: str | None = None, - event_type: type[Event] | None = None, + event_types: tuple[type[Event], ...] | None = None, source: str | None = None, start_date: str | None = None, end_date: str | None = None, @@ -371,13 +371,13 @@ def get_matching_events( """Get matching events from the event stream based on filters. Args: - query (str, optional): Text to search for in event content - event_type (type[Event], optional): Filter by event type class (e.g., FileReadAction) - source (str, optional): Filter by event source - start_date (str, optional): Filter events after this date (ISO format) - end_date (str, optional): Filter events before this date (ISO format) - start_id (int): Starting ID in the event stream. Defaults to 0 - limit (int): Maximum number of events to return. Must be between 1 and 100. Defaults to 100 + query: Text to search for in event content + event_types: Filter by event type classes (e.g., (FileReadAction, ) ). + source: Filter by event source + start_date: Filter events after this date (ISO format) + end_date: Filter events before this date (ISO format) + start_id: Starting ID in the event stream. Defaults to 0 + limit: Maximum number of events to return. Must be between 1 and 100. Defaults to 100 Returns: list: List of matching events (as dicts) @@ -392,7 +392,7 @@ def get_matching_events( for event in self.get_events(start_id=start_id): if self._should_filter_event( - event, query, event_type, source, start_date, end_date + event, query, event_types, source, start_date, end_date ): continue diff --git a/openhands/server/routes/conversation.py b/openhands/server/routes/conversation.py index d5fab4515a9d..c0cef6c4ae01 100644 --- a/openhands/server/routes/conversation.py +++ b/openhands/server/routes/conversation.py @@ -119,14 +119,14 @@ async def search_events( ): """Search through the event stream with filtering and pagination. Args: - request (Request): The incoming request object - query (str, optional): Text to search for in event content - start_id (int): Starting ID in the event stream. Defaults to 0 - limit (int): Maximum number of events to return. Must be between 1 and 100. Defaults to 20 - event_type (str, optional): Filter by event type (e.g., "FileReadAction") - source (str, optional): Filter by event source - start_date (str, optional): Filter events after this date (ISO format) - end_date (str, optional): Filter events before this date (ISO format) + request: The incoming request object + query: Text to search for in event content + start_id: Starting ID in the event stream. Defaults to 0 + limit: Maximum number of events to return. Must be between 1 and 100. Defaults to 20 + event_type: Filter by event type (e.g., "FileReadAction") + source: Filter by event source + start_date: Filter events after this date (ISO format) + end_date: Filter events before this date (ISO format) Returns: dict: Dictionary containing: - events: List of matching events @@ -145,7 +145,7 @@ async def search_events( cast_event_type = str_to_event_type(event_type) matching_events = event_stream.get_matching_events( query=query, - event_type=cast_event_type, + event_types=(cast_event_type), source=source, start_date=start_date, end_date=end_date, diff --git a/tests/unit/test_event_stream.py b/tests/unit/test_event_stream.py index d9ce963bf638..e666340a54e1 100644 --- a/tests/unit/test_event_stream.py +++ b/tests/unit/test_event_stream.py @@ -7,6 +7,7 @@ from openhands.events.action import ( NullAction, ) +from openhands.events.action.message import MessageAction from openhands.events.observation import NullObservation from openhands.storage import get_file_store @@ -72,17 +73,22 @@ def test_get_matching_events_type_filter(temp_dir: str): event_stream.add_event(NullAction(), EventSource.AGENT) event_stream.add_event(NullObservation('test'), EventSource.AGENT) event_stream.add_event(NullAction(), EventSource.AGENT) + event_stream.add_event(MessageAction(content='test'), EventSource.AGENT) # Filter by NullAction - events = event_stream.get_matching_events(event_type=NullAction) + events = event_stream.get_matching_events(event_types=(NullAction,)) assert len(events) == 2 assert all(e['action'] == 'null' for e in events) # Filter by NullObservation - events = event_stream.get_matching_events(event_type=NullObservation) + events = event_stream.get_matching_events(event_types=(NullObservation,)) assert len(events) == 1 assert events[0]['observation'] == 'null' + # Filter by NullAction and MessageAction + events = event_stream.get_matching_events(event_types=(NullAction, MessageAction)) + assert len(events) == 3 + def test_get_matching_events_query_search(temp_dir: str): file_store = get_file_store('local', temp_dir)