Skip to content

Commit

Permalink
Redis integration
Browse files Browse the repository at this point in the history
  • Loading branch information
tofarr committed Nov 20, 2024
1 parent 84dd0f2 commit e71deb9
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 74 deletions.
1 change: 0 additions & 1 deletion list.txt

This file was deleted.

42 changes: 22 additions & 20 deletions openhands/server/listen.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from contextlib import asynccontextmanager
import os
import re
import tempfile
Expand Down Expand Up @@ -73,10 +74,25 @@

config = load_app_config()
file_store = get_file_store(config.file_store, config.file_store_path)
session_manager = SessionManager(config, file_store)
client_manager = None
redis_host = os.environ.get('REDIS_HOST')
if redis_host:
client_manager = socketio.AsyncRedisManager(
f'redis://{redis_host}',
redis_options={'password': os.environ.get('REDIS_PASSWORD')},
)
sio = socketio.AsyncServer(
async_mode='asgi', cors_allowed_origins='*', client_manager=client_manager
)
session_manager = SessionManager(sio, config, file_store)


@asynccontextmanager
async def _lifespan(app: FastAPI):
async with session_manager:
yield

app = FastAPI()
app = FastAPI(lifespan=_lifespan)
app.add_middleware(
LocalhostCORSMiddleware,
allow_credentials=True,
Expand Down Expand Up @@ -840,16 +856,6 @@ async def get_response(self, path: str, scope):

app.mount('/', SPAStaticFiles(directory='./frontend/build', html=True), name='dist')

client_manager = None
redis_host = os.environ.get('REDIS_HOST')
if redis_host:
client_manager = socketio.AsyncRedisManager(
f'redis://{redis_host}',
redis_options={'password': os.environ.get('REDIS_PASSWORD')},
)
sio = socketio.AsyncServer(
async_mode='asgi', cors_allowed_origins='*', client_manager=client_manager
)
app = socketio.ASGIApp(sio, other_asgi_app=app)


Expand Down Expand Up @@ -923,9 +929,7 @@ async def oh_action(connection_id: str, data: dict):
return

logger.info(f'sio:oh_action:{connection_id}')
session = session_manager.get_local_session(connection_id)
await session.dispatch(data)

await session_manager.send_to_event_stream(connection_id, data)

async def init_connection(connection_id: str, data: dict):
gh_token = data.pop('github_token', None)
Expand All @@ -949,13 +953,11 @@ async def init_connection(connection_id: str, data: dict):
latest_event_id = int(data.pop('latest_event_id', -1))

# The session in question should exist, but may not actually be running locally...
session = await session_manager.init_or_join_local_session(
sio, sid, connection_id, data
)
event_stream = await session_manager.init_or_join_session(sid, connection_id, data)

# Send events
async_stream = AsyncEventStreamWrapper(
session.agent_session.event_stream, latest_event_id + 1
event_stream, latest_event_id + 1
)
async for event in async_stream:
if isinstance(
Expand All @@ -973,4 +975,4 @@ async def init_connection(connection_id: str, data: dict):
@sio.event
async def disconnect(connection_id: str):
logger.info(f'sio:disconnect:{connection_id}')
await session_manager.disconnect_from_local_session(connection_id)
await session_manager.disconnect_from_session(connection_id)
8 changes: 2 additions & 6 deletions openhands/server/session/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from openhands.runtime.base import Runtime, RuntimeUnavailableError
from openhands.security import SecurityAnalyzer, options
from openhands.storage.files import FileStore
from openhands.utils.async_utils import call_async_from_sync


class AgentSession:
Expand Down Expand Up @@ -129,13 +130,8 @@ def close(self):
"""Closes the Agent session"""
if self._closed:
return

self._closed = True

def inner_close():
asyncio.run(self._close())

asyncio.get_event_loop().run_in_executor(None, inner_close)
call_async_from_sync(self._close)

async def _close(self):
if self.controller is not None:
Expand Down
147 changes: 117 additions & 30 deletions openhands/server/session/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,23 +7,57 @@
from openhands.core.config import AppConfig
from openhands.core.logger import openhands_logger as logger
from openhands.core.schema.agent import AgentState
from openhands.events.event import EventSource
from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.events.serialization.event import event_to_dict
from openhands.events.stream import session_exists
from openhands.events.stream import EventStream, session_exists
from openhands.runtime.base import RuntimeUnavailableError
from openhands.server.session.conversation import Conversation
from openhands.server.session.session import Session
from openhands.server.session.session import ROOM_KEY, Session
from openhands.storage.files import FileStore
from openhands.utils.shutdown_listener import should_continue

_CONNECTION_KEY = "oh_session:{sid}"


@dataclass
class SessionManager:
sio: socketio.AsyncServer
config: AppConfig
file_store: FileStore
local_sessions_by_sid: dict[str, Session] = field(default_factory=dict)
local_sessions_by_connection_id: dict[str, Session] = field(default_factory=dict)
local_connection_id_to_session_id: dict[str, str] = field(default_factory=dict)
_redis_listen: bool = False

async def __aenter__(self):
redis_client = self._get_redis_client()
if redis_client:
self._redis_listen_task = asyncio.create_task(self._redis_subscribe())
return self

async def __aexit__(self, exc_type, exc_value, traceback):
self._redis_listen_task.cancel()

def _get_redis_client(self):
redis_client = getattr(self.sio.manager, "redis")
return redis_client

async def _redis_subscribe(self):
"""
We use a redis backchannel to send actions between server nodes
"""
redis_client = self._get_redis_client()
pubsub = redis_client.pubsub()
await pubsub.subscribe("oh_event")
while should_continue():
try:
message = await pubsub.get_message(ignore_subscribe_messages=True, timeout=None)
if message:
sid = message["sid"]
session = self.local_sessions_by_sid.get(sid)
if session:
session.dispatch(message["data"])
except asyncio.CancelledError:
return

async def attach_to_conversation(self, sid: str) -> Conversation | None:
start_time = time.time()
Expand All @@ -44,45 +78,98 @@ async def attach_to_conversation(self, sid: str) -> Conversation | None:
async def detach_from_conversation(self, conversation: Conversation):
await conversation.disconnect()

async def init_or_join_local_session(self, sio: socketio.AsyncServer, sid: str, connection_id: str, data: dict):
""" If there is no local session running, initialize one """
async def init_or_join_session(self, sid: str, connection_id: str, data: dict):
await self.sio.enter_room(connection_id, ROOM_KEY.format(sid=sid))
self.local_connection_id_to_session_id[connection_id] = sid

# If we have a local session running, use that
session = self.local_sessions_by_sid.get(sid)
if not session:
# I think we need to rehydrate here, but it does not seem to be working
session = Session(
sid=sid, file_store=self.file_store, config=self.config, sio=sio
)
session.connect(connection_id)
self.local_sessions_by_sid[sid] = session
self.local_sessions_by_connection_id[connection_id] = session
await session.initialize_agent(data)
else:
session.connect(connection_id)
self.local_sessions_by_connection_id[connection_id] = session
session.agent_session.event_stream.add_event(AgentStateChangedObservation('', AgentState.INIT), EventSource.ENVIRONMENT)
return session
if session:
self.sio.emit(event_to_dict(AgentStateChangedObservation('', AgentState.INIT)), to=connection_id)
return session.agent_session.event_stream

# If there is a remote session running, mark a connection to that
redis_client = self._get_redis_client()
if redis_client:
num_connections = await redis_client.rpush(_CONNECTION_KEY.format(sid=sid), connection_id)
# More than one remote connection implies session is already running remotely...
if num_connections != 1:
await self.sio.emit(event_to_dict(AgentStateChangedObservation('', AgentState.INIT)), to=connection_id)
event_stream = EventStream(sid, self.file_store)
return event_stream

# Start a new local session
session = Session(
sid=sid, file_store=self.file_store, config=self.config, sio=self.sio
)
self.local_sessions_by_sid[sid] = session
await session.initialize_agent(data)
return session.agent_session.event_stream

def get_local_session(self, connection_id: str) -> Session:
return self.local_sessions_by_connection_id[connection_id]

async def disconnect_from_local_session(self, connection_id: str):
session = self.local_sessions_by_connection_id.pop(connection_id, None)
if not session:
async def send_to_event_stream(self, connection_id: str, data: dict):
# If there is a local session running, send to that
sid = self.local_connection_id_to_session_id[connection_id]
session = self.local_sessions_by_sid.get(sid)
if session:
await session.dispatch(data)
return

# If there is a remote session running, send to that
redis_client = self._get_redis_client()
if redis_client:
await redis_client.publish("oh_event", {
"sid": sid,
"data": data
})
return

raise RuntimeError(f'no_connected_session:{sid}')

async def disconnect_from_session(self, connection_id: str):
sid = self.local_connection_id_to_session_id.pop(connection_id, None)
if not sid:
# This can occur if the init action was never run.
logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
return
if session.disconnect(connection_id):

# Disconnect from redis if present
redis_client = self._get_redis_client()
if redis_client:
await redis_client.lrem(_CONNECTION_KEY.format(sid=sid), 0, connection_id)

session = self.local_sessions_by_sid.get(sid)
if session:
if should_continue():
asyncio.create_task(self._check_and_close_session(session))
asyncio.create_task(self._check_and_close_session_later(session))
else:
await self._check_and_close_session(session)

async def _check_and_close_session(self, session: Session):
async def _check_and_close_session_later(self, session: Session):
# Once there have been no connections to a session for a reasonable period, we close it
try:
await asyncio.sleep(self.config.sandbox.close_delay)
finally:
# If the sleep was cancelled, we still want to close these
if not session.connection_ids:
await self._check_and_close_session(session)

async def _check_and_close_session(self, session: Session):
# Get local connections
has_connections_for_session = next((
True for v in self.local_connection_id_to_session_id.values()
if v == session.sid
), False)

# If no local connections, get connections through redis
if not has_connections_for_session:
redis_client = self._get_redis_client()
if redis_client:
key = _CONNECTION_KEY.format(sid=session.sid)
has_connections_for_session = bool(await redis_client.get(key))
if not has_connections_for_session:
await redis_client.delete(key)

# If no connections, close session
if not has_connections_for_session:
session.close()
self.local_sessions_by_sid.pop(session.sid, None)
self.local_sessions_by_sid.pop(session.sid, None)
29 changes: 12 additions & 17 deletions openhands/server/session/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@
from openhands.llm.llm import LLM
from openhands.server.session.agent_session import AgentSession
from openhands.storage.files import FileStore
from openhands.utils.async_utils import wait_all
from openhands.utils.async_utils import call_coro_in_bg_thread

ROOM_KEY = "room:{sid}"


class Session:
sid: str
sio: socketio.AsyncServer | None
connection_ids: set[str]
last_active_ts: int = 0
is_alive: bool = True
agent_session: AgentSession
Expand All @@ -49,16 +50,8 @@ def __init__(
EventStreamSubscriber.SERVER, self.on_event, self.sid
)
self.config = config
self.connection_ids = set()
self.loop = asyncio.get_event_loop()

def connect(self, connection_id: str):
self.connection_ids.add(connection_id)

def disconnect(self, connection_id: str) -> bool:
self.connection_ids.remove(connection_id)
return not self.connection_ids

def close(self):
self.is_alive = False
self.agent_session.close()
Expand Down Expand Up @@ -163,19 +156,21 @@ async def dispatch(self, data: dict):
self.agent_session.event_stream.add_event(event, EventSource.USER)

async def send(self, data: dict[str, object]) -> bool:
task = self.loop.create_task(self._send(data))
await task
return task.result()
if asyncio.get_running_loop() != self.loop:
# Complete hack. Server whines about different event loops. This seems to shut it up,
# but means we don't get the result of the operation. I think this is okay, because
# we don't seem to care either way
self.loop.create_task(self._send(data))
return True
return await self._send(data)

async def _send(self, data: dict[str, object]) -> bool:
try:
if not self.is_alive:
return False
if self.sio:
await wait_all(
self.sio.emit("oh_event", data, to=connection_id)
for connection_id in self.connection_ids
)
#await self.loop.create_task(self.sio.emit("oh_event", data, to=ROOM_KEY.format(sid=self.sid)))
await self.sio.emit("oh_event", data, to=ROOM_KEY.format(sid=self.sid))
await asyncio.sleep(0.001) # This flushes the data to the client
self.last_active_ts = int(time.time())
return True
Expand Down

0 comments on commit e71deb9

Please sign in to comment.