Skip to content

Commit

Permalink
refactor: move session initialization to separate router
Browse files Browse the repository at this point in the history
- Create new session.py router for POST /api/conversation endpoint
- Move session initialization code to new router
- Fix middleware application to use protected_router instead of main router
- Add comment to clarify middleware application
  • Loading branch information
openhands-agent committed Dec 9, 2024
1 parent dcafd16 commit c26fb9f
Show file tree
Hide file tree
Showing 3 changed files with 61 additions and 51 deletions.
10 changes: 8 additions & 2 deletions openhands/server/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,15 @@
RateLimitMiddleware,
)
from openhands.server.routes.auth import app as auth_api_router
from openhands.server.routes.conversation import app as conversation_api_router
from openhands.server.routes.conversation import (
app as conversation_api_router,
protected_router as conversation_protected_router,
)
from openhands.server.routes.feedback import app as feedback_api_router
from openhands.server.routes.files import app as files_api_router
from openhands.server.routes.public import app as public_api_router
from openhands.server.routes.security import app as security_api_router
from openhands.server.routes.session import app as session_api_router
from openhands.server.shared import config
from openhands.utils.import_utils import get_impl

Expand Down Expand Up @@ -49,15 +53,17 @@ async def health():
app.include_router(conversation_api_router)
app.include_router(security_api_router)
app.include_router(feedback_api_router)
app.include_router(session_api_router)

AttachConversationMiddlewareImpl = get_impl(
AttachConversationMiddleware, config.attach_session_middleware_class
)
# Apply the middleware to protected routers
app.middleware('http')(
AttachConversationMiddlewareImpl(app, target_router=files_api_router)
)
app.middleware('http')(
AttachConversationMiddlewareImpl(app, target_router=conversation_api_router)
AttachConversationMiddlewareImpl(app, target_router=conversation_protected_router)
)
app.middleware('http')(
AttachConversationMiddlewareImpl(app, target_router=security_api_router)
Expand Down
50 changes: 1 addition & 49 deletions openhands/server/routes/conversation.py
Original file line number Diff line number Diff line change
@@ -1,60 +1,13 @@
from fastapi import APIRouter, HTTPException, Request, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from openhands.core.logger import openhands_logger as logger
from openhands.runtime.base import Runtime
from openhands.server.listen_socket import init_connection
from openhands.server.middleware import AttachConversationMiddleware
from openhands.server.session.session_init_data import SessionInitData

app = APIRouter(prefix='/api')
protected_router = APIRouter(prefix='/api/conversation/{convo_id}')


class InitSessionRequest(BaseModel):
token: str | None = None
github_token: str | None = None
latest_event_id: int = -1
args: dict | None = None
selected_repository: str | None = None


@app.post('/conversation')
async def init_session(request: Request, data: InitSessionRequest):
"""Initialize a new session or join an existing one.
This endpoint replaces the WebSocket INIT event with a REST API call.
After successful initialization, the client should connect to the WebSocket
using the returned token.
"""
kwargs = {k.lower(): v for k, v in (data.args or {}).items()}
session_init_data = SessionInitData(**kwargs)
session_init_data.github_token = data.github_token
session_init_data.selected_repository = data.selected_repository

# Generate a temporary connection ID for initialization
connection_id = f"temp_{data.token or ''}"

try:
token = await init_connection(
connection_id=connection_id,
token=data.token,
gh_token=data.github_token,
session_init_data=session_init_data,
latest_event_id=data.latest_event_id,
return_token_only=True
)
return JSONResponse(content={"token": token, "status": "ok"})
except RuntimeError as e:
if str(e) == str(status.WS_1008_POLICY_VIOLATION):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"error": "Authentication failed"}
)
raise


@protected_router.get('/config')
async def get_remote_runtime_config(request: Request):
"""Retrieve the runtime configuration.
Expand Down Expand Up @@ -153,6 +106,5 @@ async def search_events(
'has_more': has_more,
}

# Include the protected router and apply the middleware
# Include the protected router
app.include_router(protected_router)
app.middleware('http')(AttachConversationMiddleware(app, protected_router))
52 changes: 52 additions & 0 deletions openhands/server/routes/session.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
from fastapi import APIRouter, Request, status
from fastapi.responses import JSONResponse
from pydantic import BaseModel

from openhands.core.logger import openhands_logger as logger
from openhands.server.listen_socket import init_connection
from openhands.server.session.session_init_data import SessionInitData

app = APIRouter(prefix='/api')


class InitSessionRequest(BaseModel):
token: str | None = None
github_token: str | None = None
latest_event_id: int = -1
args: dict | None = None
selected_repository: str | None = None


@app.post('/conversation')
async def init_session(request: Request, data: InitSessionRequest):
"""Initialize a new session or join an existing one.
This endpoint replaces the WebSocket INIT event with a REST API call.
After successful initialization, the client should connect to the WebSocket
using the returned token.
"""
kwargs = {k.lower(): v for k, v in (data.args or {}).items()}
session_init_data = SessionInitData(**kwargs)
session_init_data.github_token = data.github_token
session_init_data.selected_repository = data.selected_repository

# Generate a temporary connection ID for initialization
connection_id = f"temp_{data.token or ''}"

try:
token = await init_connection(
connection_id=connection_id,
token=data.token,
gh_token=data.github_token,
session_init_data=session_init_data,
latest_event_id=data.latest_event_id,
return_token_only=True
)
return JSONResponse(content={"token": token, "status": "ok"})
except RuntimeError as e:
if str(e) == str(status.WS_1008_POLICY_VIOLATION):
return JSONResponse(
status_code=status.HTTP_401_UNAUTHORIZED,
content={"error": "Authentication failed"}
)
raise

0 comments on commit c26fb9f

Please sign in to comment.