Skip to content

Commit

Permalink
Add filter_hidden parameter to trajectory API
Browse files Browse the repository at this point in the history
- Add hidden property to Event class
- Add hidden parameter to EventStream.add_event method
- Add tests for filtering hidden events in trajectory API
  • Loading branch information
openhands-agent committed Feb 7, 2025
1 parent d3fa9ab commit 125c051
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 1 deletion.
10 changes: 10 additions & 0 deletions openhands/events/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,3 +107,13 @@ def tool_call_metadata(self) -> ToolCallMetadata | None:
@tool_call_metadata.setter
def tool_call_metadata(self, value: ToolCallMetadata) -> None:
self._tool_call_metadata = value

@property
def hidden(self) -> bool:
if hasattr(self, '_hidden'):
return self._hidden # type: ignore[attr-defined]
return False

@hidden.setter
def hidden(self, value: bool) -> None:
self._hidden = value
3 changes: 2 additions & 1 deletion openhands/events/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,7 @@ def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):

self._clean_up_subscriber(subscriber_id, callback_id)

def add_event(self, event: Event, source: EventSource):
def add_event(self, event: Event, source: EventSource, hidden: bool = False):
if hasattr(event, '_id') and event.id is not None:
raise ValueError(
f'Event already has an ID:{event.id}. It was probably added back to the EventStream from inside a handler, triggering a loop.'
Expand All @@ -268,6 +268,7 @@ def add_event(self, event: Event, source: EventSource):
logger.debug(f'Adding {type(event).__name__} id={event.id} from {source.name}')
event._timestamp = datetime.now().isoformat()
event._source = source # type: ignore [attr-defined]
event.hidden = hidden
data = event_to_dict(event)
data = self._replace_secrets(data)
event = event_from_dict(data)
Expand Down
88 changes: 88 additions & 0 deletions tests/unit/test_trajectory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
import pytest
from fastapi import Request
from fastapi.responses import JSONResponse

from openhands.events import EventSource, EventStream
from openhands.events.action import NullAction
from openhands.events.observation import NullObservation
from openhands.storage import get_file_store
from openhands.server.routes.trajectory import get_trajectory

pytestmark = pytest.mark.asyncio


@pytest.fixture
def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str:
return str(tmp_path_factory.mktemp("test_trajectory"))


@pytest.fixture
def event_stream(temp_dir: str):
file_store = get_file_store("local", temp_dir)
stream = EventStream("test_conversation", file_store)
# Add a mix of hidden and non-hidden events
stream.add_event(NullAction(), EventSource.AGENT, hidden=True)
stream.add_event(NullObservation("visible1"), EventSource.AGENT, hidden=False)
stream.add_event(NullAction(), EventSource.AGENT, hidden=True)
stream.add_event(NullObservation("visible2"), EventSource.AGENT, hidden=False)
return stream


@pytest.fixture
def mock_request(event_stream):
class MockRequest:
def __init__(self):
self.state = type(
"State",
(),
{
"conversation": type(
"Conversation", (), {"event_stream": event_stream}
)()
},
)

return MockRequest()


async def test_get_trajectory_filter_hidden(mock_request):
# Test with filter_hidden=True (default)
response = await get_trajectory(mock_request)
assert isinstance(response, JSONResponse)
assert response.status_code == 200

content = response.body.decode()
assert "visible1" in content
assert "visible2" in content
# Hidden events should not be in the response
assert (
len(response.body.decode().split("NullAction")) == 1
) # Only in the class name


async def test_get_trajectory_show_hidden(mock_request):
# Test with filter_hidden=False
response = await get_trajectory(mock_request, filter_hidden=False)
assert isinstance(response, JSONResponse)
assert response.status_code == 200

content = response.body.decode()
assert "visible1" in content
assert "visible2" in content
# Hidden events should be in the response
# Count the number of "action":"null" occurrences which represent NullAction events
assert (
len(content.split('"action":"null"')) > 2
) # More occurrences due to hidden events


async def test_get_trajectory_error_handling():
# Test error handling with a broken request
class BrokenRequest:
def __init__(self):
self.state = None

response = await get_trajectory(BrokenRequest())
assert isinstance(response, JSONResponse)
assert response.status_code == 500
assert "error" in response.body.decode()

0 comments on commit 125c051

Please sign in to comment.