Skip to content

Commit

Permalink
Refactor: Use type[Event] instead of str to filter events (#6480)
Browse files Browse the repository at this point in the history
Co-authored-by: openhands <[email protected]>
  • Loading branch information
malhotra5 and openhands-agent authored Jan 27, 2025
1 parent 4bde644 commit 6045349
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 8 deletions.
10 changes: 5 additions & 5 deletions openhands/events/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ def _should_filter_event(
self,
event,
query: str | None = None,
event_type: str | None = None,
event_type: type[Event] | None = None,
source: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
Expand All @@ -329,15 +329,15 @@ def _should_filter_event(
Args:
event: The event to check
query (str, optional): Text to search for in event content
event_type (str, optional): Filter by event type (e.g., "FileReadAction")
event_type (type[Event], optional): Filter by event type class (e.g., FileReadAction)
source (str, optional): Filter by event source
start_date (str, optional): Filter events after this date (ISO format)
end_date (str, optional): Filter events before this date (ISO format)
Returns:
bool: True if the event should be filtered out, False if it matches all criteria
"""
if event_type and not event.__class__.__name__ == event_type:
if event_type and not isinstance(event, event_type):
return True

if source and not event.source.value == source:
Expand All @@ -361,7 +361,7 @@ def _should_filter_event(
def get_matching_events(
self,
query: str | None = None,
event_type: str | None = None,
event_type: type[Event] | None = None,
source: str | None = None,
start_date: str | None = None,
end_date: str | None = None,
Expand All @@ -372,7 +372,7 @@ def get_matching_events(
Args:
query (str, optional): Text to search for in event content
event_type (str, optional): Filter by event type (e.g., "FileReadAction")
event_type (type[Event], optional): Filter by event type class (e.g., FileReadAction)
source (str, optional): Filter by event source
start_date (str, optional): Filter events after this date (ISO format)
end_date (str, optional): Filter events before this date (ISO format)
Expand Down
19 changes: 18 additions & 1 deletion openhands/server/routes/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,26 @@
from fastapi.responses import JSONResponse

from openhands.core.logger import openhands_logger as logger
from openhands.events.event import Event
from openhands.events.serialization.event import event_from_dict
from openhands.runtime.base import Runtime

app = APIRouter(prefix='/api/conversations/{conversation_id}')


def str_to_event_type(event: str | None) -> Event | None:
if not event:
return None

for event_type in ['observation', 'action']:
try:
return event_from_dict({event_type: event})
except Exception:
continue

return None


@app.get('/config')
async def get_remote_runtime_config(request: Request):
"""Retrieve the runtime configuration.
Expand Down Expand Up @@ -126,9 +141,11 @@ async def search_events(
)
# Get matching events from the stream
event_stream = request.state.conversation.event_stream

cast_event_type = str_to_event_type(event_type)
matching_events = event_stream.get_matching_events(
query=query,
event_type=event_type,
event_type=cast_event_type,
source=source,
start_date=start_date,
end_date=end_date,
Expand Down
4 changes: 2 additions & 2 deletions tests/unit/test_event_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,12 +74,12 @@ def test_get_matching_events_type_filter(temp_dir: str):
event_stream.add_event(NullAction(), EventSource.AGENT)

# Filter by NullAction
events = event_stream.get_matching_events(event_type='NullAction')
events = event_stream.get_matching_events(event_type=NullAction)
assert len(events) == 2
assert all(e['action'] == 'null' for e in events)

# Filter by NullObservation
events = event_stream.get_matching_events(event_type='NullObservation')
events = event_stream.get_matching_events(event_type=NullObservation)
assert len(events) == 1
assert events[0]['observation'] == 'null'

Expand Down

0 comments on commit 6045349

Please sign in to comment.