Skip to content

Commit

Permalink
Reorder disconnect
Browse files Browse the repository at this point in the history
  • Loading branch information
tofarr committed Nov 22, 2024
1 parent 38462d8 commit dd5fc6f
Showing 1 changed file with 6 additions and 11 deletions.
17 changes: 6 additions & 11 deletions openhands/server/session/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,30 +145,24 @@ async def disconnect_from_session(self, connection_id: str):
# This can occur if the init action was never run.
logger.warning(f'disconnect_from_uninitialized_session:{connection_id}')
return

# Disconnect from redis if present
redis_client = self._get_redis_client()
if redis_client:
logger.info(f'disconnect_connection_from_session:{connection_id}:{sid}')
await redis_client.lrem(_CONNECTION_KEY.format(sid=sid), 0, connection_id)

session = self.local_sessions_by_sid.get(sid)
if session:
logger.info(f'close_session:{connection_id}:{sid}')
if should_continue():
asyncio.create_task(self._close_orphaned_session_later(session))
asyncio.create_task(self._cleanup_session_later(session, connection_id))
else:
await self._close_orphaned_session(session, True)
await self._cleanup_session(session, connection_id, True)

async def _close_orphaned_session_later(self, session: Session):
async def _cleanup_session_later(self, session: Session, connection_id: str):
# Once there have been no connections to a session for a reasonable period, we close it
try:
await asyncio.sleep(self.config.sandbox.close_delay)
finally:
# If the sleep was cancelled, we still want to close these
await self._close_orphaned_session(session, False)
await self._cleanup_session(session, connection_id, False)

async def _close_orphaned_session(self, session: Session, force: bool):
async def _cleanup_session(self, session: Session, connection_id: str, force: bool):
# Get local connections
has_local_connections = next((
True for v in self.local_connection_id_to_session_id.values()
Expand All @@ -180,6 +174,7 @@ async def _close_orphaned_session(self, session: Session, force: bool):
redis_client = self._get_redis_client()
if redis_client:
key = _CONNECTION_KEY.format(sid=session.sid)
await redis_client.lrem(key, 0, connection_id)
redis_connections = await redis_client.lrange(key, 0, -1)
redis_connections = [
c.decode() for c in redis_connections
Expand Down

0 comments on commit dd5fc6f

Please sign in to comment.