Skip to content

Commit

Permalink
Merge branch 'main' into feature/condenser-refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
csmith49 authored Jan 6, 2025
2 parents 5beba03 + 343b864 commit c6acf01
Show file tree
Hide file tree
Showing 12 changed files with 58 additions and 64 deletions.
9 changes: 1 addition & 8 deletions frontend/src/context/ws-client-provider.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -62,16 +62,13 @@ const WsClientContext = React.createContext<UseWsClient>({

interface WsClientProviderProps {
conversationId: string;
ghToken: string | null;
}

export function WsClientProvider({
ghToken,
conversationId,
children,
}: React.PropsWithChildren<WsClientProviderProps>) {
const sioRef = React.useRef<Socket | null>(null);
const ghTokenRef = React.useRef<string | null>(ghToken);
const [status, setStatus] = React.useState(
WsClientProviderStatus.DISCONNECTED,
);
Expand Down Expand Up @@ -141,9 +138,6 @@ export function WsClientProvider({

sio = io(baseUrl, {
transports: ["websocket"],
auth: {
github_token: ghToken || undefined,
},
query,
});
sio.on("connect", handleConnect);
Expand All @@ -153,7 +147,6 @@ export function WsClientProvider({
sio.on("disconnect", handleDisconnect);

sioRef.current = sio;
ghTokenRef.current = ghToken;

return () => {
sio.off("connect", handleConnect);
Expand All @@ -162,7 +155,7 @@ export function WsClientProvider({
sio.off("connect_failed", handleError);
sio.off("disconnect", handleDisconnect);
};
}, [ghToken, conversationId]);
}, [conversationId]);

React.useEffect(
() => () => {
Expand Down
2 changes: 1 addition & 1 deletion frontend/src/routes/_oh.app/route.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ function AppContent() {
}

return (
<WsClientProvider ghToken={gitHubToken} conversationId={conversationId}>
<WsClientProvider conversationId={conversationId}>
<EventHandler>
<div data-testid="app-route" className="flex flex-col h-full gap-3">
<div className="flex h-full overflow-auto">{renderMain()}</div>
Expand Down
5 changes: 5 additions & 0 deletions openhands/server/auth.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,14 @@
import jwt
from fastapi import Request
from jwt.exceptions import InvalidTokenError

from openhands.core.logger import openhands_logger as logger


def get_user_id(request: Request) -> int:
return getattr(request.state, 'github_user_id', 0)


def get_sid_from_token(token: str, jwt_secret: str) -> str:
"""Retrieves the session id from a JWT token.
Expand Down
25 changes: 12 additions & 13 deletions openhands/server/listen_socket.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from urllib.parse import parse_qs

from github import Github
import jwt
from socketio.exceptions import ConnectionRefusedError

from openhands.core.logger import openhands_logger as logger
Expand All @@ -18,7 +18,6 @@
from openhands.server.session.manager import ConversationDoesNotExistError
from openhands.server.shared import config, openhands_config, session_manager, sio
from openhands.server.types import AppMode
from openhands.utils.async_utils import call_sync_from_async


@sio.event
Expand All @@ -31,20 +30,20 @@ async def connect(connection_id: str, environ, auth):
logger.error('No conversation_id in query params')
raise ConnectionRefusedError('No conversation_id in query params')

github_token = ''
user_id = -1
if openhands_config.app_mode != AppMode.OSS:
user_id = ''
if auth and 'github_token' in auth:
github_token = auth['github_token']
with Github(github_token) as g:
gh_user = await call_sync_from_async(g.get_user)
user_id = gh_user.id
cookies_str = environ.get('HTTP_COOKIE', '')
cookies = dict(cookie.split('=', 1) for cookie in cookies_str.split('; '))
signed_token = cookies.get('github_auth', '')
if not signed_token:
logger.error('No github_auth cookie')
raise ConnectionRefusedError('No github_auth cookie')
decoded = jwt.decode(signed_token, config.jwt_secret, algorithms=['HS256'])
user_id = decoded['github_user_id']

logger.info(f'User {user_id} is connecting to conversation {conversation_id}')

conversation_store = await ConversationStoreImpl.get_instance(
config, github_token
)
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
metadata = await conversation_store.get_metadata(conversation_id)
if metadata.github_user_id != user_id:
logger.error(
Expand All @@ -54,7 +53,7 @@ async def connect(connection_id: str, environ, auth):
f'User {user_id} is not allowed to join conversation {conversation_id}'
)

settings_store = await SettingsStoreImpl.get_instance(config, github_token)
settings_store = await SettingsStoreImpl.get_instance(config, user_id)
settings = await settings_store.load()

if not settings:
Expand Down
52 changes: 24 additions & 28 deletions openhands/server/routes/manage_conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,11 @@

from fastapi import APIRouter, Body, Request
from fastapi.responses import JSONResponse
from github import Github
from pydantic import BaseModel

from openhands.core.logger import openhands_logger as logger
from openhands.events.stream import EventStreamSubscriber
from openhands.server.auth import get_user_id
from openhands.server.routes.settings import ConversationStoreImpl, SettingsStoreImpl
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.shared import config, session_manager
Expand All @@ -21,7 +21,6 @@
from openhands.utils.async_utils import (
GENERAL_TIMEOUT,
call_async_from_sync,
call_sync_from_async,
wait_all,
)

Expand All @@ -43,22 +42,24 @@ async def new_conversation(request: Request, data: InitSessionRequest):
using the returned conversation ID
"""
logger.info('Initializing new conversation')
github_token = data.github_token or ''

logger.info('Loading settings')
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
settings_store = await SettingsStoreImpl.get_instance(config, get_user_id(request))
settings = await settings_store.load()
logger.info('Settings loaded')

session_init_args: dict = {}
if settings:
session_init_args = {**settings.__dict__, **session_init_args}

github_token = getattr(request.state, 'github_token', '')
session_init_args['github_token'] = github_token
session_init_args['selected_repository'] = data.selected_repository
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request)
)
logger.info('Conversation store loaded')

conversation_id = uuid.uuid4().hex
Expand All @@ -67,18 +68,11 @@ async def new_conversation(request: Request, data: InitSessionRequest):
conversation_id = uuid.uuid4().hex
logger.info(f'New conversation ID: {conversation_id}')

user_id = ''
if data.github_token:
logger.info('Fetching Github user ID')
with Github(data.github_token) as g:
gh_user = await call_sync_from_async(g.get_user)
user_id = gh_user.id

logger.info(f'Saving metadata for conversation {conversation_id}')
await conversation_store.save_metadata(
ConversationMetadata(
conversation_id=conversation_id,
github_user_id=user_id,
github_user_id=get_user_id(request),
selected_repository=data.selected_repository,
)
)
Expand All @@ -90,9 +84,7 @@ async def new_conversation(request: Request, data: InitSessionRequest):
try:
event_stream.subscribe(
EventStreamSubscriber.SERVER,
_create_conversation_update_callback(
data.github_token or '', conversation_id
),
_create_conversation_update_callback(get_user_id(request), conversation_id),
UPDATED_AT_CALLBACK_ID,
)
except ValueError:
Expand All @@ -107,8 +99,9 @@ async def search_conversations(
page_id: str | None = None,
limit: int = 20,
) -> ConversationInfoResultSet:
github_token = getattr(request.state, 'github_token', '') or ''
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request)
)
conversation_metadata_result_set = await conversation_store.search(page_id, limit)
conversation_ids = set(
conversation.conversation_id
Expand All @@ -134,8 +127,9 @@ async def search_conversations(
async def get_conversation(
conversation_id: str, request: Request
) -> ConversationInfo | None:
github_token = getattr(request.state, 'github_token', '') or ''
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request)
)
try:
metadata = await conversation_store.get_metadata(conversation_id)
is_running = await session_manager.is_agent_loop_running(conversation_id)
Expand All @@ -149,8 +143,9 @@ async def get_conversation(
async def update_conversation(
request: Request, conversation_id: str, title: str = Body(embed=True)
) -> bool:
github_token = getattr(request.state, 'github_token', '') or ''
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request)
)
metadata = await conversation_store.get_metadata(conversation_id)
if not metadata:
return False
Expand All @@ -164,8 +159,9 @@ async def delete_conversation(
conversation_id: str,
request: Request,
) -> bool:
github_token = getattr(request.state, 'github_token', '') or ''
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
conversation_store = await ConversationStoreImpl.get_instance(
config, get_user_id(request)
)
try:
await conversation_store.get_metadata(conversation_id)
except FileNotFoundError:
Expand Down Expand Up @@ -205,21 +201,21 @@ async def _get_conversation_info(


def _create_conversation_update_callback(
github_token: str, conversation_id: str
user_id: int, conversation_id: str
) -> Callable:
def callback(*args, **kwargs):
call_async_from_sync(
_update_timestamp_for_conversation,
GENERAL_TIMEOUT,
github_token,
user_id,
conversation_id,
)

return callback


async def _update_timestamp_for_conversation(github_token: str, conversation_id: str):
conversation_store = await ConversationStoreImpl.get_instance(config, github_token)
async def _update_timestamp_for_conversation(user_id: int, conversation_id: str):
conversation_store = await ConversationStoreImpl.get_instance(config, user_id)
conversation = await conversation_store.get_metadata(conversation_id)
conversation.last_updated_at = datetime.now()
await conversation_store.save_metadata(conversation)
13 changes: 7 additions & 6 deletions openhands/server/routes/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from fastapi.responses import JSONResponse

from openhands.core.logger import openhands_logger as logger
from openhands.server.auth import get_user_id
from openhands.server.settings import Settings
from openhands.server.shared import config, openhands_config
from openhands.storage.conversation.conversation_store import ConversationStore
Expand All @@ -19,9 +20,10 @@

@app.get('/settings')
async def load_settings(request: Request) -> Settings | None:
github_token = getattr(request.state, 'github_token', '') or ''
try:
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
settings = await settings_store.load()
if not settings:
return JSONResponse(
Expand All @@ -45,11 +47,10 @@ async def store_settings(
request: Request,
settings: Settings,
) -> JSONResponse:
github_token = ''
if hasattr(request.state, 'github_token'):
github_token = request.state.github_token
try:
settings_store = await SettingsStoreImpl.get_instance(config, github_token)
settings_store = await SettingsStoreImpl.get_instance(
config, get_user_id(request)
)
existing_settings = await settings_store.load()

if existing_settings:
Expand Down
4 changes: 1 addition & 3 deletions openhands/storage/conversation/conversation_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,5 @@ async def search(

@classmethod
@abstractmethod
async def get_instance(
cls, config: AppConfig, token: str | None
) -> ConversationStore:
async def get_instance(cls, config: AppConfig, user_id: int) -> ConversationStore:
"""Get a store for the user represented by the token given"""
4 changes: 3 additions & 1 deletion openhands/storage/conversation/file_conversation_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,9 @@ def get_conversation_metadata_filename(self, conversation_id: str) -> str:
return get_conversation_metadata_filename(conversation_id)

@classmethod
async def get_instance(cls, config: AppConfig, token: str | None):
async def get_instance(
cls, config: AppConfig, user_id: int
) -> FileConversationStore:
file_store = get_file_store(config.file_store, config.file_store_path)
return FileConversationStore(file_store)

Expand Down
2 changes: 1 addition & 1 deletion openhands/storage/data_models/conversation_metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
@dataclass
class ConversationMetadata:
conversation_id: str
github_user_id: int | str
github_user_id: int
selected_repository: str | None
title: str | None = None
last_updated_at: datetime | None = None
Expand Down
2 changes: 1 addition & 1 deletion openhands/storage/settings/file_settings_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,6 @@ async def store(self, settings: Settings):
await call_sync_from_async(self.file_store.write, self.path, json_str)

@classmethod
async def get_instance(cls, config: AppConfig, token: str | None):
async def get_instance(cls, config: AppConfig, user_id: int) -> FileSettingsStore:
file_store = get_file_store(config.file_store, config.file_store_path)
return FileSettingsStore(file_store)
2 changes: 1 addition & 1 deletion openhands/storage/settings/settings_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,5 +21,5 @@ async def store(self, settings: Settings):

@classmethod
@abstractmethod
async def get_instance(cls, config: AppConfig, token: str | None) -> SettingsStore:
async def get_instance(cls, config: AppConfig, user_id: int) -> SettingsStore:
"""Get a store for the user represented by the token given"""
2 changes: 1 addition & 1 deletion tests/unit/test_conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def _patch_store():
'title': 'Some Conversation',
'selected_repository': 'foobar',
'conversation_id': 'some_conversation_id',
'github_user_id': 'github_user',
'github_user_id': 12345,
'created_at': '2025-01-01T00:00:00',
'last_updated_at': '2025-01-01T00:01:00',
}
Expand Down

0 comments on commit c6acf01

Please sign in to comment.