forked from All-Hands-AI/OpenHands
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add filter_hidden parameter to trajectory API
- Add hidden property to Event class - Add hidden parameter to EventStream.add_event method - Add tests for filtering hidden events in trajectory API
- Loading branch information
1 parent
d3fa9ab
commit 125c051
Showing
3 changed files
with
100 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,88 @@ | ||
import pytest | ||
from fastapi import Request | ||
from fastapi.responses import JSONResponse | ||
|
||
from openhands.events import EventSource, EventStream | ||
from openhands.events.action import NullAction | ||
from openhands.events.observation import NullObservation | ||
from openhands.storage import get_file_store | ||
from openhands.server.routes.trajectory import get_trajectory | ||
|
||
pytestmark = pytest.mark.asyncio | ||
|
||
|
||
@pytest.fixture | ||
def temp_dir(tmp_path_factory: pytest.TempPathFactory) -> str: | ||
return str(tmp_path_factory.mktemp("test_trajectory")) | ||
|
||
|
||
@pytest.fixture | ||
def event_stream(temp_dir: str): | ||
file_store = get_file_store("local", temp_dir) | ||
stream = EventStream("test_conversation", file_store) | ||
# Add a mix of hidden and non-hidden events | ||
stream.add_event(NullAction(), EventSource.AGENT, hidden=True) | ||
stream.add_event(NullObservation("visible1"), EventSource.AGENT, hidden=False) | ||
stream.add_event(NullAction(), EventSource.AGENT, hidden=True) | ||
stream.add_event(NullObservation("visible2"), EventSource.AGENT, hidden=False) | ||
return stream | ||
|
||
|
||
@pytest.fixture | ||
def mock_request(event_stream): | ||
class MockRequest: | ||
def __init__(self): | ||
self.state = type( | ||
"State", | ||
(), | ||
{ | ||
"conversation": type( | ||
"Conversation", (), {"event_stream": event_stream} | ||
)() | ||
}, | ||
) | ||
|
||
return MockRequest() | ||
|
||
|
||
async def test_get_trajectory_filter_hidden(mock_request): | ||
# Test with filter_hidden=True (default) | ||
response = await get_trajectory(mock_request) | ||
assert isinstance(response, JSONResponse) | ||
assert response.status_code == 200 | ||
|
||
content = response.body.decode() | ||
assert "visible1" in content | ||
assert "visible2" in content | ||
# Hidden events should not be in the response | ||
assert ( | ||
len(response.body.decode().split("NullAction")) == 1 | ||
) # Only in the class name | ||
|
||
|
||
async def test_get_trajectory_show_hidden(mock_request): | ||
# Test with filter_hidden=False | ||
response = await get_trajectory(mock_request, filter_hidden=False) | ||
assert isinstance(response, JSONResponse) | ||
assert response.status_code == 200 | ||
|
||
content = response.body.decode() | ||
assert "visible1" in content | ||
assert "visible2" in content | ||
# Hidden events should be in the response | ||
# Count the number of "action":"null" occurrences which represent NullAction events | ||
assert ( | ||
len(content.split('"action":"null"')) > 2 | ||
) # More occurrences due to hidden events | ||
|
||
|
||
async def test_get_trajectory_error_handling(): | ||
# Test error handling with a broken request | ||
class BrokenRequest: | ||
def __init__(self): | ||
self.state = None | ||
|
||
response = await get_trajectory(BrokenRequest()) | ||
assert isinstance(response, JSONResponse) | ||
assert response.status_code == 500 | ||
assert "error" in response.body.decode() |