diff --git a/list.txt b/list.txt deleted file mode 100644 index 265e23731df8..000000000000 --- a/list.txt +++ /dev/null @@ -1 +0,0 @@ -tofarr diff --git a/openhands/server/listen.py b/openhands/server/listen.py index b3988de819fc..ecce27e1d9d7 100644 --- a/openhands/server/listen.py +++ b/openhands/server/listen.py @@ -1,3 +1,4 @@ +from contextlib import asynccontextmanager import os import re import tempfile @@ -73,10 +74,25 @@ config = load_app_config() file_store = get_file_store(config.file_store, config.file_store_path) -session_manager = SessionManager(config, file_store) +client_manager = None +redis_host = os.environ.get('REDIS_HOST') +if redis_host: + client_manager = socketio.AsyncRedisManager( + f'redis://{redis_host}', + redis_options={'password': os.environ.get('REDIS_PASSWORD')}, + ) +sio = socketio.AsyncServer( + async_mode='asgi', cors_allowed_origins='*', client_manager=client_manager +) +session_manager = SessionManager(sio, config, file_store) + +@asynccontextmanager +async def _lifespan(app: FastAPI): + async with session_manager: + yield -app = FastAPI() +app = FastAPI(lifespan=_lifespan) app.add_middleware( LocalhostCORSMiddleware, allow_credentials=True, @@ -840,16 +856,6 @@ async def get_response(self, path: str, scope): app.mount('/', SPAStaticFiles(directory='./frontend/build', html=True), name='dist') -client_manager = None -redis_host = os.environ.get('REDIS_HOST') -if redis_host: - client_manager = socketio.AsyncRedisManager( - f'redis://{redis_host}', - redis_options={'password': os.environ.get('REDIS_PASSWORD')}, - ) -sio = socketio.AsyncServer( - async_mode='asgi', cors_allowed_origins='*', client_manager=client_manager -) app = socketio.ASGIApp(sio, other_asgi_app=app) @@ -923,9 +929,7 @@ async def oh_action(connection_id: str, data: dict): return logger.info(f'sio:oh_action:{connection_id}') - session = session_manager.get_local_session(connection_id) - await session.dispatch(data) - + await session_manager.send_to_event_stream(connection_id, data) async def init_connection(connection_id: str, data: dict): gh_token = data.pop('github_token', None) @@ -949,13 +953,11 @@ async def init_connection(connection_id: str, data: dict): latest_event_id = int(data.pop('latest_event_id', -1)) # The session in question should exist, but may not actually be running locally... - session = await session_manager.init_or_join_local_session( - sio, sid, connection_id, data - ) + event_stream = await session_manager.init_or_join_session(sid, connection_id, data) # Send events async_stream = AsyncEventStreamWrapper( - session.agent_session.event_stream, latest_event_id + 1 + event_stream, latest_event_id + 1 ) async for event in async_stream: if isinstance( @@ -973,4 +975,4 @@ async def init_connection(connection_id: str, data: dict): @sio.event async def disconnect(connection_id: str): logger.info(f'sio:disconnect:{connection_id}') - await session_manager.disconnect_from_local_session(connection_id) + await session_manager.disconnect_from_session(connection_id) diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 5b64187867bd..f0fdf247084a 100644 --- a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -14,6 +14,7 @@ from openhands.runtime.base import Runtime, RuntimeUnavailableError from openhands.security import SecurityAnalyzer, options from openhands.storage.files import FileStore +from openhands.utils.async_utils import call_async_from_sync class AgentSession: @@ -129,13 +130,8 @@ def close(self): """Closes the Agent session""" if self._closed: return - self._closed = True - - def inner_close(): - asyncio.run(self._close()) - - asyncio.get_event_loop().run_in_executor(None, inner_close) + call_async_from_sync(self._close) async def _close(self): if self.controller is not None: diff --git a/openhands/server/session/manager.py b/openhands/server/session/manager.py index 24db1d5bdcda..a0b2d7008c05 100644 --- a/openhands/server/session/manager.py +++ b/openhands/server/session/manager.py @@ -7,23 +7,57 @@ from openhands.core.config import AppConfig from openhands.core.logger import openhands_logger as logger from openhands.core.schema.agent import AgentState -from openhands.events.event import EventSource from openhands.events.observation.agent import AgentStateChangedObservation from openhands.events.serialization.event import event_to_dict -from openhands.events.stream import session_exists +from openhands.events.stream import EventStream, session_exists from openhands.runtime.base import RuntimeUnavailableError from openhands.server.session.conversation import Conversation -from openhands.server.session.session import Session +from openhands.server.session.session import ROOM_KEY, Session from openhands.storage.files import FileStore from openhands.utils.shutdown_listener import should_continue +_CONNECTION_KEY = "oh_session:{sid}" + @dataclass class SessionManager: + sio: socketio.AsyncServer config: AppConfig file_store: FileStore local_sessions_by_sid: dict[str, Session] = field(default_factory=dict) - local_sessions_by_connection_id: dict[str, Session] = field(default_factory=dict) + local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict) + _redis_listen: bool = False + + async def __aenter__(self): + redis_client = self._get_redis_client() + if redis_client: + self._redis_listen_task = asyncio.create_task(self._redis_subscribe()) + return self + + async def __aexit__(self, exc_type, exc_value, traceback): + self._redis_listen_task.cancel() + + def _get_redis_client(self): + redis_client = getattr(self.sio.manager, "redis") + return redis_client + + async def _redis_subscribe(self): + """ + We use a redis backchannel to send actions between server nodes + """ + redis_client = self._get_redis_client() + pubsub = redis_client.pubsub() + await pubsub.subscribe("oh_event") + while should_continue(): + try: + message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=None) + if message: + sid = message["sid"] + session = self.local_sessions_by_sid.get(sid) + if session: + session.dispatch(message["data"]) + except asyncio.CancelledError: + return async def attach_to_conversation(self, sid: str) -> Conversation | None: start_time = time.time() @@ -44,45 +78,98 @@ async def attach_to_conversation(self, sid: str) -> Conversation | None: async def detach_from_conversation(self, conversation: Conversation): await conversation.disconnect() - async def init_or_join_local_session(self, sio: socketio.AsyncServer, sid: str, connection_id: str, data: dict): - """ If there is no local session running, initialize one """ + async def init_or_join_session(self, sid: str, connection_id: str, data: dict): + await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid)) + self.local_connection_id_to_session_id[connection_id] = sid + + # If we have a local session running, use that session = self.local_sessions_by_sid.get(sid) - if not session: - # I think we need to rehydrate here, but it does not seem to be working - session = Session( - sid=sid, file_store=self.file_store, config=self.config, sio=sio - ) - session.connect(connection_id) - self.local_sessions_by_sid[sid] = session - self.local_sessions_by_connection_id[connection_id] = session - await session.initialize_agent(data) - else: - session.connect(connection_id) - self.local_sessions_by_connection_id[connection_id] = session - session.agent_session.event_stream.add_event(AgentStateChangedObservation('', AgentState.INIT), EventSource.ENVIRONMENT) - return session + if session: + self.sio.emit(event_to_dict(AgentStateChangedObservation('', AgentState.INIT)), to=connection_id) + return session.agent_session.event_stream + + # If there is a remote session running, mark a connection to that + redis_client = self._get_redis_client() + if redis_client: + num_connections = await redis_client.rpush(_CONNECTION_KEY.format(sid=sid), connection_id) + # More than one remote connection implies session is already running remotely... + if num_connections != 1: + await self.sio.emit(event_to_dict(AgentStateChangedObservation('', AgentState.INIT)), to=connection_id) + event_stream = EventStream(sid, self.file_store) + return event_stream + + # Start a new local session + session = Session( + sid=sid, file_store=self.file_store, config=self.config, sio=self.sio + ) + self.local_sessions_by_sid[sid] = session + await session.initialize_agent(data) + return session.agent_session.event_stream - def get_local_session(self, connection_id: str) -> Session: - return self.local_sessions_by_connection_id[connection_id] - async def disconnect_from_local_session(self, connection_id: str): - session = self.local_sessions_by_connection_id.pop(connection_id, None) - if not session: + async def send_to_event_stream(self, connection_id: str, data: dict): + # If there is a local session running, send to that + sid = self.local_connection_id_to_session_id[connection_id] + session = self.local_sessions_by_sid.get(sid) + if session: + await session.dispatch(data) + return + + # If there is a remote session running, send to that + redis_client = self._get_redis_client() + if redis_client: + await redis_client.publish("oh_event", { + "sid": sid, + "data": data + }) + return + + raise RuntimeError(f'no_connected_session:{sid}') + + async def disconnect_from_session(self, connection_id: str): + sid = self.local_connection_id_to_session_id.pop(connection_id, None) + if not sid: # This can occur if the init action was never run. logger.warning(f'disconnect_from_uninitialized_session:{connection_id}') return - if session.disconnect(connection_id): + + # Disconnect from redis if present + redis_client = self._get_redis_client() + if redis_client: + await redis_client.lrem(_CONNECTION_KEY.format(sid=sid), 0, connection_id) + + session = self.local_sessions_by_sid.get(sid) + if session: if should_continue(): - asyncio.create_task(self._check_and_close_session(session)) + asyncio.create_task(self._check_and_close_session_later(session)) else: await self._check_and_close_session(session) - async def _check_and_close_session(self, session: Session): + async def _check_and_close_session_later(self, session: Session): # Once there have been no connections to a session for a reasonable period, we close it try: await asyncio.sleep(self.config.sandbox.close_delay) finally: # If the sleep was cancelled, we still want to close these - if not session.connection_ids: + await self._check_and_close_session(session) + + async def _check_and_close_session(self, session: Session): + # Get local connections + has_connections_for_session = next(( + True for v in self.local_connection_id_to_session_id.values() + if v == session.sid + ), False) + + # If no local connections, get connections through redis + if not has_connections_for_session: + redis_client = self._get_redis_client() + if redis_client: + key = _CONNECTION_KEY.format(sid=session.sid) + has_connections_for_session = bool(await redis_client.get(key)) + if not has_connections_for_session: + await redis_client.delete(key) + + # If no connections, close session + if not has_connections_for_session: session.close() - self.local_sessions_by_sid.pop(session.sid, None) \ No newline at end of file + self.local_sessions_by_sid.pop(session.sid, None) diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index 74cd4fbd1103..58089ec56b2c 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -23,13 +23,14 @@ from openhands.llm.llm import LLM from openhands.server.session.agent_session import AgentSession from openhands.storage.files import FileStore -from openhands.utils.async_utils import wait_all +from openhands.utils.async_utils import call_coro_in_bg_thread + +ROOM_KEY = "room:{sid}" class Session: sid: str sio: socketio.AsyncServer | None - connection_ids: set[str] last_active_ts: int = 0 is_alive: bool = True agent_session: AgentSession @@ -49,16 +50,8 @@ def __init__( EventStreamSubscriber.SERVER, self.on_event, self.sid ) self.config = config - self.connection_ids = set() self.loop = asyncio.get_event_loop() - def connect(self, connection_id: str): - self.connection_ids.add(connection_id) - - def disconnect(self, connection_id: str) -> bool: - self.connection_ids.remove(connection_id) - return not self.connection_ids - def close(self): self.is_alive = False self.agent_session.close() @@ -163,19 +156,21 @@ async def dispatch(self, data: dict): self.agent_session.event_stream.add_event(event, EventSource.USER) async def send(self, data: dict[str, object]) -> bool: - task = self.loop.create_task(self._send(data)) - await task - return task.result() + if asyncio.get_running_loop() != self.loop: + # Complete hack. Server whines about different event loops. This seems to shut it up, + # but means we don't get the result of the operation. I think this is okay, because + # we don't seem to care either way + self.loop.create_task(self._send(data)) + return True + return await self._send(data) async def _send(self, data: dict[str, object]) -> bool: try: if not self.is_alive: return False if self.sio: - await wait_all( - self.sio.emit("oh_event", data, to=connection_id) - for connection_id in self.connection_ids - ) + #await self.loop.create_task(self.sio.emit("oh_event", data, to=ROOM_KEY.format(sid=self.sid))) + await self.sio.emit("oh_event", data, to=ROOM_KEY.format(sid=self.sid)) await asyncio.sleep(0.001) # This flushes the data to the client self.last_active_ts = int(time.time()) return True