Skip to content

Commit

Permalink
Feat: Introduce class for SessionInitData rather than using a dict (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
tofarr authored Dec 5, 2024
1 parent 1146b62 commit de81020
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 45 deletions.
10 changes: 5 additions & 5 deletions config.template.toml
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,10 @@ workspace_base = "./workspace"
# AWS secret access key
#aws_secret_access_key = ""

# API key to use
# API key to use (For Headless / CLI only - In Web this is overridden by Session Init)
api_key = "your-api-key"

# API base URL
# API base URL (For Headless / CLI only - In Web this is overridden by Session Init)
#base_url = ""

# API version
Expand Down Expand Up @@ -131,7 +131,7 @@ embedding_model = "local"
# Maximum number of output tokens
#max_output_tokens = 0

# Model to use
# Model to use. (For Headless / CLI only - In Web this is overridden by Session Init)
model = "gpt-4o"

# Number of retries to attempt when an operation fails with the LLM.
Expand Down Expand Up @@ -237,10 +237,10 @@ llm_config = 'gpt3'
##############################################################################
[security]

# Enable confirmation mode
# Enable confirmation mode (For Headless / CLI only - In Web this is overridden by Session Init)
#confirmation_mode = false

# The security analyzer to use
# The security analyzer to use (For Headless / CLI only - In Web this is overridden by Session Init)
#security_analyzer = ""

#################################### Eval ####################################
Expand Down
9 changes: 5 additions & 4 deletions openhands/server/session/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from openhands.runtime.base import RuntimeUnavailableError
from openhands.server.session.conversation import Conversation
from openhands.server.session.session import ROOM_KEY, Session
from openhands.server.session.session_init_data import SessionInitData
from openhands.storage.files import FileStore
from openhands.utils.shutdown_listener import should_continue

Expand Down Expand Up @@ -141,7 +142,7 @@ async def attach_to_conversation(self, sid: str) -> Conversation | None:
async def detach_from_conversation(self, conversation: Conversation):
await conversation.disconnect()

async def init_or_join_session(self, sid: str, connection_id: str, data: dict):
async def init_or_join_session(self, sid: str, connection_id: str, session_init_data: SessionInitData):
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
self.local_connection_id_to_session_id[connection_id] = sid

Expand All @@ -156,7 +157,7 @@ async def init_or_join_session(self, sid: str, connection_id: str, data: dict):
if redis_client and await self._is_session_running_in_cluster(sid):
return EventStream(sid, self.file_store)

return await self.start_local_session(sid, data)
return await self.start_local_session(sid, session_init_data)

async def _is_session_running_in_cluster(self, sid: str) -> bool:
"""As the rest of the cluster if a session is running. Wait a for a short timeout for a reply"""
Expand Down Expand Up @@ -210,14 +211,14 @@ async def _has_remote_connections(self, sid: str) -> bool:
finally:
self._has_remote_connections_flags.pop(sid)

async def start_local_session(self, sid: str, data: dict):
async def start_local_session(self, sid: str, session_init_data: SessionInitData):
# Start a new local session
logger.info(f'start_new_local_session:{sid}')
session = Session(
sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
)
self.local_sessions_by_sid[sid] = session
await session.initialize_agent(data)
await session.initialize_agent(session_init_data)
return session.agent_session.event_stream

async def send_to_event_stream(self, connection_id: str, data: dict):
Expand Down
37 changes: 14 additions & 23 deletions openhands/server/session/session.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
from copy import deepcopy
import time

import socketio
Expand All @@ -21,6 +22,7 @@
from openhands.events.stream import EventStreamSubscriber
from openhands.llm.llm import LLM
from openhands.server.session.agent_session import AgentSession
from openhands.server.session.session_init_data import SessionInitData
from openhands.storage.files import FileStore

ROOM_KEY = 'room:{sid}'
Expand All @@ -34,7 +36,6 @@ class Session:
agent_session: AgentSession
loop: asyncio.AbstractEventLoop
config: AppConfig
settings: dict | None

def __init__(
self,
Expand All @@ -52,41 +53,31 @@ def __init__(
self.agent_session.event_stream.subscribe(
EventStreamSubscriber.SERVER, self.on_event, self.sid
)
self.config = config
# Copying this means that when we update variables they are not applied to the shared global configuration!
self.config = deepcopy(config)
self.loop = asyncio.get_event_loop()
self.settings = None

def close(self):
self.is_alive = False
self.agent_session.close()

async def initialize_agent(self, data: dict):
self.settings = data
async def initialize_agent(self, session_init_data: SessionInitData):
self.agent_session.event_stream.add_event(
AgentStateChangedObservation('', AgentState.LOADING),
EventSource.ENVIRONMENT,
)
# Extract the agent-relevant arguments from the request
args = {key: value for key, value in data.get('args', {}).items()}
agent_cls = args.get(ConfigType.AGENT, self.config.default_agent)
self.config.security.confirmation_mode = args.get(
ConfigType.CONFIRMATION_MODE, self.config.security.confirmation_mode
)
self.config.security.security_analyzer = data.get('args', {}).get(
ConfigType.SECURITY_ANALYZER, self.config.security.security_analyzer
)
max_iterations = args.get(ConfigType.MAX_ITERATIONS, self.config.max_iterations)
agent_cls = session_init_data.agent or self.config.default_agent
self.config.security.confirmation_mode = self.config.security.confirmation_mode if session_init_data.confirmation_mode is None else session_init_data.confirmation_mode
self.config.security.security_analyzer = session_init_data.security_analyzer or self.config.security.security_analyzer
max_iterations = session_init_data.max_iterations or self.config.max_iterations
# override default LLM config


default_llm_config = self.config.get_llm_config()
default_llm_config.model = args.get(
ConfigType.LLM_MODEL, default_llm_config.model
)
default_llm_config.api_key = args.get(
ConfigType.LLM_API_KEY, default_llm_config.api_key
)
default_llm_config.base_url = args.get(
ConfigType.LLM_BASE_URL, default_llm_config.base_url
)
default_llm_config.model = session_init_data.llm_model or default_llm_config.model
default_llm_config.api_key = session_init_data.llm_api_key or default_llm_config.api_key
default_llm_config.base_url = session_init_data.llm_base_url or default_llm_config.base_url

# TODO: override other LLM config & agent config groups (#2075)

Expand Down
18 changes: 18 additions & 0 deletions openhands/server/session/session_init_data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@


from dataclasses import dataclass


@dataclass
class SessionInitData:
"""
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
"""
language: str | None = None
agent: str | None = None
max_iterations: int | None = None
security_analyzer: str | None = None
confirmation_mode: bool | None = None
llm_model: str | None = None
llm_api_key: str | None = None
llm_base_url: str | None = None
26 changes: 19 additions & 7 deletions openhands/server/socket.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from openhands.events.stream import AsyncEventStreamWrapper
from openhands.server.auth import get_sid_from_token, sign_token
from openhands.server.github_utils import authenticate_github_user
from openhands.server.session.session_init_data import SessionInitData
from openhands.server.shared import config, session_manager, sio


Expand All @@ -26,19 +27,30 @@ async def oh_action(connection_id: str, data: dict):
# If it's an init, we do it here.
action = data.get('action', '')
if action == ActionType.INIT:
await init_connection(connection_id, data)
token = data.pop('token', None)
github_token = data.pop('github_token', None)
latest_event_id = int(data.pop('latest_event_id', -1))
kwargs = {k.lower(): v for k, v in (data.get('args') or {}).items()}
session_init_data = SessionInitData(**kwargs)
await init_connection(
connection_id, token, github_token, session_init_data, latest_event_id
)
return

logger.info(f'sio:oh_action:{connection_id}')
await session_manager.send_to_event_stream(connection_id, data)


async def init_connection(connection_id: str, data: dict):
gh_token = data.pop('github_token', None)
async def init_connection(
connection_id: str,
token: str | None,
gh_token: str | None,
session_init_data: SessionInitData,
latest_event_id: int,
):
if not await authenticate_github_user(gh_token):
raise RuntimeError(status.WS_1008_POLICY_VIOLATION)

token = data.pop('token', None)
if token:
sid = get_sid_from_token(token, config.jwt_secret)
if sid == '':
Expand All @@ -52,10 +64,10 @@ async def init_connection(connection_id: str, data: dict):
token = sign_token({'sid': sid}, config.jwt_secret)
await sio.emit('oh_event', {'token': token, 'status': 'ok'}, to=connection_id)

latest_event_id = int(data.pop('latest_event_id', -1))

# The session in question should exist, but may not actually be running locally...
event_stream = await session_manager.init_or_join_session(sid, connection_id, data)
event_stream = await session_manager.init_or_join_session(
sid, connection_id, session_init_data
)

# Send events
agent_state_changed = None
Expand Down
13 changes: 7 additions & 6 deletions tests/unit/test_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from openhands.core.config.app_config import AppConfig
from openhands.server.session.manager import SessionManager
from openhands.server.session.session_init_data import SessionInitData
from openhands.storage.memory import InMemoryFileStore


Expand Down Expand Up @@ -100,7 +101,7 @@ async def test_init_new_local_session():
sio, AppConfig(), InMemoryFileStore()
) as session_manager:
await session_manager.init_or_join_session(
'new-session-id', 'new-session-id', {'type': 'mock-settings'}
'new-session-id', 'new-session-id', SessionInitData()
)
assert session_instance.initialize_agent.call_count == 1
assert sio.enter_room.await_count == 1
Expand Down Expand Up @@ -132,11 +133,11 @@ async def test_join_local_session():
) as session_manager:
# First call initializes
await session_manager.init_or_join_session(
'new-session-id', 'new-session-id', {'type': 'mock-settings'}
'new-session-id', 'new-session-id', SessionInitData()
)
# Second call joins
await session_manager.init_or_join_session(
'new-session-id', 'extra-connection-id', {'type': 'mock-settings'}
'new-session-id', 'extra-connection-id', SessionInitData()
)
assert session_instance.initialize_agent.call_count == 1
assert sio.enter_room.await_count == 2
Expand Down Expand Up @@ -168,7 +169,7 @@ async def test_join_cluster_session():
) as session_manager:
# First call initializes
await session_manager.init_or_join_session(
'new-session-id', 'new-session-id', {'type': 'mock-settings'}
'new-session-id', 'new-session-id', SessionInitData()
)
assert session_instance.initialize_agent.call_count == 0
assert sio.enter_room.await_count == 1
Expand Down Expand Up @@ -199,7 +200,7 @@ async def test_add_to_local_event_stream():
sio, AppConfig(), InMemoryFileStore()
) as session_manager:
await session_manager.init_or_join_session(
'new-session-id', 'connection-id', {'type': 'mock-settings'}
'new-session-id', 'connection-id', SessionInitData()
)
await session_manager.send_to_event_stream(
'connection-id', {'event_type': 'some_event'}
Expand Down Expand Up @@ -232,7 +233,7 @@ async def test_add_to_cluster_event_stream():
sio, AppConfig(), InMemoryFileStore()
) as session_manager:
await session_manager.init_or_join_session(
'new-session-id', 'connection-id', {'type': 'mock-settings'}
'new-session-id', 'connection-id', SessionInitData()
)
await session_manager.send_to_event_stream(
'connection-id', {'event_type': 'some_event'}
Expand Down

0 comments on commit de81020

Please sign in to comment.