From d17df96fb2722d50a6138b2fb9f921f78b40435a Mon Sep 17 00:00:00 2001 From: openhands Date: Sun, 23 Feb 2025 15:40:28 +0000 Subject: [PATCH] fix: Add proper type hints to routes directory --- openhands/server/routes/feedback.py | 2 +- openhands/server/routes/github.py | 10 +- openhands/server/routes/public.py | 168 ++++++++++++++++++++++---- openhands/server/routes/security.py | 48 +++++++- openhands/server/routes/settings.py | 88 ++++++++++++-- openhands/server/routes/trajectory.py | 2 +- 6 files changed, 275 insertions(+), 43 deletions(-) diff --git a/openhands/server/routes/feedback.py b/openhands/server/routes/feedback.py index a18131581576..4974ea992dd1 100644 --- a/openhands/server/routes/feedback.py +++ b/openhands/server/routes/feedback.py @@ -11,7 +11,7 @@ @app.post('/submit-feedback') -async def submit_feedback(request: Request, conversation_id: str): +async def submit_feedback(request: Request, conversation_id: str) -> JSONResponse: """Submit user feedback. This function stores the provided feedback data. diff --git a/openhands/server/routes/github.py b/openhands/server/routes/github.py index d4bd1c8d5a05..34a2999f9a21 100644 --- a/openhands/server/routes/github.py +++ b/openhands/server/routes/github.py @@ -1,3 +1,5 @@ +from typing import Union + from fastapi import APIRouter, Depends, status from fastapi.responses import JSONResponse from pydantic import SecretStr @@ -22,7 +24,7 @@ async def get_github_repositories( installation_id: int | None = None, github_user_id: str | None = Depends(get_user_id), github_user_token: SecretStr | None = Depends(get_github_token), -): +) -> Union[list[GitHubRepository], JSONResponse]: client = GithubServiceImpl(user_id=github_user_id, token=github_user_token) try: repos: list[GitHubRepository] = await client.get_repositories( @@ -47,7 +49,7 @@ async def get_github_repositories( async def get_github_user( github_user_id: str | None = Depends(get_user_id), github_user_token: SecretStr | None = Depends(get_github_token), -): +) -> Union[GitHubUser, JSONResponse]: client = GithubServiceImpl(user_id=github_user_id, token=github_user_token) try: user: GitHubUser = await client.get_user() @@ -70,7 +72,7 @@ async def get_github_user( async def get_github_installation_ids( github_user_id: str | None = Depends(get_user_id), github_user_token: SecretStr | None = Depends(get_github_token), -): +) -> Union[list[int], JSONResponse]: client = GithubServiceImpl(user_id=github_user_id, token=github_user_token) try: installations_ids: list[int] = await client.get_installation_ids() @@ -97,7 +99,7 @@ async def search_github_repositories( order: str = 'desc', github_user_id: str | None = Depends(get_user_id), github_user_token: SecretStr | None = Depends(get_github_token), -): +) -> Union[list[GitHubRepository], JSONResponse]: client = GithubServiceImpl(user_id=github_user_id, token=github_user_token) try: repos: list[GitHubRepository] = await client.search_repositories( diff --git a/openhands/server/routes/public.py b/openhands/server/routes/public.py index 59e5c4e4efe6..e139c3da2511 100644 --- a/openhands/server/routes/public.py +++ b/openhands/server/routes/public.py @@ -1,6 +1,10 @@ import warnings +from typing import Annotated, Any, cast import requests +from fastapi import APIRouter, Depends, Response +from fastapi.responses import JSONResponse +from fastapi.routing import APIRoute from openhands.security.options import SecurityAnalyzers @@ -8,10 +12,6 @@ warnings.simplefilter('ignore') import litellm -from fastapi import ( - APIRouter, -) - from openhands.controller.agent import Agent from openhands.core.config import LLMConfig from openhands.core.logger import openhands_logger as logger @@ -21,20 +21,14 @@ app = APIRouter(prefix='/api/options') -@app.get('/models') -async def get_litellm_models() -> list[str]: +async def get_litellm_models_list() -> list[str]: """Get all models supported by LiteLLM. This function combines models from litellm and Bedrock, removing any error-prone Bedrock models. - To get the models: - ```sh - curl http://localhost:3000/api/litellm-models - ``` - Returns: - list: A sorted list of unique model names. + list[str]: A sorted list of unique model names. """ litellm_model_list = litellm.model_list + list(litellm.model_cost.keys()) litellm_model_list_without_bedrock = bedrock.remove_error_modelId( @@ -74,8 +68,65 @@ async def get_litellm_models() -> list[str]: return list(sorted(set(model_list))) -@app.get('/agents') -async def get_agents(): +def get_models_route() -> APIRoute: + """Get the route for getting models. + + Returns: + APIRoute: The route for getting models. + """ + return cast( + APIRoute, + app.get('/models', response_model=list[str]), + ) + + +async def get_litellm_models( + models: Annotated[list[str], Depends(get_litellm_models_list)], +) -> list[str]: + """Get all models supported by LiteLLM. + + To get the models: + ```sh + curl http://localhost:3000/api/litellm-models + ``` + + Args: + models (list[str]): The list of models from get_litellm_models_list. + + Returns: + list[str]: A sorted list of unique model names. + """ + return models + + +models_route = get_models_route() +models_route.endpoint = get_litellm_models + + +async def get_agents_list() -> list[str]: + """Get all agents supported by LiteLLM. + + Returns: + list[str]: A sorted list of agent names. + """ + return sorted(Agent.list_agents()) + + +def get_agents_route() -> APIRoute: + """Get the route for getting agents. + + Returns: + APIRoute: The route for getting agents. + """ + return cast( + APIRoute, + app.get('/agents', response_model=list[str]), + ) + + +async def get_agents( + agents: Annotated[list[str], Depends(get_agents_list)], +) -> list[str]: """Get all agents supported by LiteLLM. To get the agents: @@ -83,15 +134,43 @@ async def get_agents(): curl http://localhost:3000/api/agents ``` + Args: + agents (list[str]): The list of agents from get_agents_list. + Returns: - list: A sorted list of agent names. + list[str]: A sorted list of agent names. """ - agents = sorted(Agent.list_agents()) return agents -@app.get('/security-analyzers') -async def get_security_analyzers(): +agents_route = get_agents_route() +agents_route.endpoint = get_agents + + +async def get_security_analyzers_list() -> list[str]: + """Get all supported security analyzers. + + Returns: + list[str]: A sorted list of security analyzer names. + """ + return sorted(SecurityAnalyzers.keys()) + + +def get_analyzers_route() -> APIRoute: + """Get the route for getting security analyzers. + + Returns: + APIRoute: The route for getting security analyzers. + """ + return cast( + APIRoute, + app.get('/security-analyzers', response_model=list[str]), + ) + + +async def get_security_analyzers( + analyzers: Annotated[list[str], Depends(get_security_analyzers_list)], +) -> list[str]: """Get all supported security analyzers. To get the security analyzers: @@ -99,16 +178,57 @@ async def get_security_analyzers(): curl http://localhost:3000/api/security-analyzers ``` + Args: + analyzers (list[str]): The list of analyzers from get_security_analyzers_list. + Returns: - list: A sorted list of security analyzer names. + list[str]: A sorted list of security analyzer names. """ - return sorted(SecurityAnalyzers.keys()) + return analyzers + + +analyzers_route = get_analyzers_route() +analyzers_route.endpoint = get_security_analyzers + + +async def get_server_config() -> Response: + """Get current config. + + Returns: + Response: The current server configuration. + """ + config_data = server_config.get_config() + return JSONResponse( + status_code=200, + content=config_data, + ) -@app.get('/config') -async def get_config(): +def get_config_route() -> APIRoute: + """Get the route for getting config. + + Returns: + APIRoute: The route for getting config. """ - Get current config + return cast( + APIRoute, + app.get('/config', response_model=dict[str, Any]), + ) + + +async def get_config( + response: Annotated[Response, Depends(get_server_config)], +) -> Response: + """Get current config. + + Args: + response (Response): The response from get_server_config. + + Returns: + Response: The current server configuration. """ + return response + - return server_config.get_config() +config_route = get_config_route() +config_route.endpoint = get_config diff --git a/openhands/server/routes/security.py b/openhands/server/routes/security.py index 719cfa472807..ad69af3868fb 100644 --- a/openhands/server/routes/security.py +++ b/openhands/server/routes/security.py @@ -1,24 +1,26 @@ +from typing import Annotated, cast + from fastapi import ( APIRouter, + Depends, HTTPException, Request, + Response, status, ) +from fastapi.routing import APIRoute app = APIRouter(prefix='/api/conversations/{conversation_id}') -@app.route('/security/{path:path}', methods=['GET', 'POST', 'PUT', 'DELETE']) -async def security_api(request: Request): - """Catch-all route for security analyzer API requests. - - Each request is handled directly to the security analyzer. +async def get_response(request: Request) -> Response: + """Get response from security analyzer. Args: request (Request): The incoming FastAPI request object. Returns: - Any: The response from the security analyzer. + Response: The response from the security analyzer. Raises: HTTPException: If the security analyzer is not initialized. @@ -32,3 +34,37 @@ async def security_api(request: Request): return await request.state.conversation.security_analyzer.handle_api_request( request ) + + +def get_route() -> APIRoute: + """Get the route for the security API. + + Returns: + APIRoute: The route for the security API. + """ + return cast( + APIRoute, + app.api_route( + '/security/{path:path}', + methods=['GET', 'POST', 'PUT', 'DELETE'], + response_model=None, + ), + ) + + +async def security_api( + response: Annotated[Response, Depends(get_response)], +) -> Response: + """Catch-all route for security analyzer API requests. + + Args: + response (Response): The response from the security analyzer. + + Returns: + Response: The response from the security analyzer. + """ + return response + + +route = get_route() +route.endpoint = security_api diff --git a/openhands/server/routes/settings.py b/openhands/server/routes/settings.py index 66ed76a23e33..526838b5ee2e 100644 --- a/openhands/server/routes/settings.py +++ b/openhands/server/routes/settings.py @@ -1,5 +1,8 @@ -from fastapi import APIRouter, Request, status +from typing import Annotated, cast + +from fastapi import APIRouter, Depends, Request, Response, status from fastapi.responses import JSONResponse +from fastapi.routing import APIRoute from pydantic import SecretStr from openhands.core.logger import openhands_logger as logger @@ -11,8 +14,15 @@ app = APIRouter(prefix='/api') -@app.get('/settings') -async def load_settings(request: Request) -> GETSettingsModel | None: +async def get_settings(request: Request) -> Response: + """Load user settings. + + Args: + request (Request): The incoming FastAPI request object. + + Returns: + Response: The user settings or error response. + """ try: user_id = get_user_id(request) settings_store = await SettingsStoreImpl.get_instance(config, user_id) @@ -31,7 +41,10 @@ async def load_settings(request: Request) -> GETSettingsModel | None: settings_with_token_data.llm_api_key = settings.llm_api_key del settings_with_token_data.github_token - return settings_with_token_data + return JSONResponse( + status_code=status.HTTP_200_OK, + content=settings_with_token_data.model_dump(), + ) except Exception as e: logger.warning(f'Invalid token: {e}') return JSONResponse( @@ -40,13 +53,62 @@ async def load_settings(request: Request) -> GETSettingsModel | None: ) -@app.post('/settings') +def get_get_route() -> APIRoute: + """Get the route for loading settings. + + Returns: + APIRoute: The route for loading settings. + """ + return cast( + APIRoute, + app.get('/settings', response_model=GETSettingsModel), + ) + + +async def load_settings( + response: Annotated[Response, Depends(get_settings)], +) -> Response: + """Load user settings. + + Args: + response (Response): The response from get_settings. + + Returns: + Response: The user settings or error response. + """ + return response + + +get_route = get_get_route() +get_route.endpoint = load_settings + + +def get_post_route() -> APIRoute: + """Get the route for storing settings. + + Returns: + APIRoute: The route for storing settings. + """ + return cast( + APIRoute, + app.post('/settings', response_model=dict[str, str]), + ) + + async def store_settings( request: Request, settings: POSTSettingsModel, -) -> JSONResponse: - # Check if token is valid +) -> Response: + """Store user settings. + Args: + request (Request): The incoming FastAPI request object. + settings (POSTSettingsModel): The settings to store. + + Returns: + Response: Success or error response. + """ + # Check if token is valid if settings.github_token: try: # We check if the token is valid by getting the user @@ -110,7 +172,19 @@ async def store_settings( ) +post_route = get_post_route() +post_route.endpoint = store_settings + + def convert_to_settings(settings_with_token_data: POSTSettingsModel) -> Settings: + """Convert POSTSettingsModel to Settings. + + Args: + settings_with_token_data (POSTSettingsModel): The settings to convert. + + Returns: + Settings: The converted settings. + """ settings_data = settings_with_token_data.model_dump() # Filter out additional fields from `SettingsWithTokenData` diff --git a/openhands/server/routes/trajectory.py b/openhands/server/routes/trajectory.py index b9732ede0c28..3fbea8eb739f 100644 --- a/openhands/server/routes/trajectory.py +++ b/openhands/server/routes/trajectory.py @@ -9,7 +9,7 @@ @app.get('/trajectory') -async def get_trajectory(request: Request): +async def get_trajectory(request: Request) -> JSONResponse: """Get trajectory. This function retrieves the current trajectory and returns it.