diff --git a/dev_config/python/ruff.toml b/dev_config/python/ruff.toml index af56e7e9d342..2ffc222fd980 100644 --- a/dev_config/python/ruff.toml +++ b/dev_config/python/ruff.toml @@ -24,3 +24,6 @@ inline-quotes = "single" [format] quote-style = "single" + +[lint.flake8-bugbear] +extend-immutable-calls = ["Depends", "fastapi.Depends", "fastapi.params.Depends"] diff --git a/openhands/runtime/base.py b/openhands/runtime/base.py index 6b10ac07c9bf..f511a8b92fd8 100644 --- a/openhands/runtime/base.py +++ b/openhands/runtime/base.py @@ -12,6 +12,7 @@ from typing import Callable from zipfile import ZipFile +from pydantic import SecretStr from requests.exceptions import ConnectionError from openhands.core.config import AppConfig, SandboxConfig @@ -234,12 +235,12 @@ async def _handle_action(self, event: Action) -> None: source = event.source if event.source else EventSource.AGENT self.event_stream.add_event(observation, source) # type: ignore[arg-type] - def clone_repo(self, github_token: str, selected_repository: str) -> str: + def clone_repo(self, github_token: SecretStr, selected_repository: str) -> str: if not github_token or not selected_repository: raise ValueError( 'github_token and selected_repository must be provided to clone a repository' ) - url = f'https://{github_token}@github.com/{selected_repository}.git' + url = f'https://{github_token.get_secret_value()}@github.com/{selected_repository}.git' dir_name = selected_repository.split('/')[1] # add random branch name to avoid conflicts random_str = ''.join( diff --git a/openhands/server/auth.py b/openhands/server/auth.py index d54577a66524..fa28dafbf45e 100644 --- a/openhands/server/auth.py +++ b/openhands/server/auth.py @@ -1,7 +1,8 @@ from fastapi import Request +from pydantic import SecretStr -def get_github_token(request: Request) -> str | None: +def get_github_token(request: Request) -> SecretStr | None: return getattr(request.state, 'github_token', None) diff --git a/openhands/server/config/server_config.py 
b/openhands/server/config/server_config.py index 456567ef5783..32b2deab4ed0 100644 --- a/openhands/server/config/server_config.py +++ b/openhands/server/config/server_config.py @@ -18,7 +18,7 @@ class ServerConfig(ServerConfigInterface): ) conversation_manager_class: str = 'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager' - github_service_class: str = 'openhands.server.services.github_service.GitHubService' + github_service_class: str = 'openhands.services.github.github_service.GitHubService' def verify_config(self): if self.config_cls: diff --git a/openhands/server/middleware.py b/openhands/server/middleware.py index cf72579197ab..734d52004bc5 100644 --- a/openhands/server/middleware.py +++ b/openhands/server/middleware.py @@ -196,7 +196,7 @@ async def __call__(self, request: Request, call_next: Callable): # TODO: To avoid checks like this we should re-add the abilty to have completely different middleware in SAAS as in OSS if getattr(request.state, 'github_token', None) is None: if settings and settings.github_token: - request.state.github_token = settings.github_token.get_secret_value() + request.state.github_token = settings.github_token else: request.state.github_token = None diff --git a/openhands/server/routes/github.py b/openhands/server/routes/github.py index 889b7a30da90..c50c1b2ca72b 100644 --- a/openhands/server/routes/github.py +++ b/openhands/server/routes/github.py @@ -1,11 +1,15 @@ from fastapi import APIRouter, Depends from fastapi.responses import JSONResponse +from pydantic import SecretStr -from openhands.server.auth import get_user_id -from openhands.server.data_models.gh_types import GitHubRepository, GitHubUser -from openhands.server.services.github_service import GitHubService +from openhands.server.auth import get_github_token, get_user_id from openhands.server.shared import server_config -from openhands.server.types import GhAuthenticationError, GHUnknownException +from 
openhands.services.github.github_service import ( + GhAuthenticationError, + GHUnknownException, + GitHubService, +) +from openhands.services.github.github_types import GitHubRepository, GitHubUser from openhands.utils.import_utils import get_impl app = APIRouter(prefix='/api/github') @@ -20,8 +24,9 @@ async def get_github_repositories( sort: str = 'pushed', installation_id: int | None = None, github_user_id: str | None = Depends(get_user_id), + github_user_token: SecretStr | None = Depends(get_github_token), ): - client = GithubServiceImpl(github_user_id) + client = GithubServiceImpl(user_id=github_user_id, token=github_user_token) try: repos: list[GitHubRepository] = await client.get_repositories( page, per_page, sort, installation_id @@ -44,8 +49,9 @@ async def get_github_repositories( @app.get('/user') async def get_github_user( github_user_id: str | None = Depends(get_user_id), + github_user_token: SecretStr | None = Depends(get_github_token), ): - client = GithubServiceImpl(github_user_id) + client = GithubServiceImpl(user_id=github_user_id, token=github_user_token) try: user: GitHubUser = await client.get_user() return user @@ -66,8 +72,9 @@ async def get_github_user( @app.get('/installations') async def get_github_installation_ids( github_user_id: str | None = Depends(get_user_id), + github_user_token: SecretStr | None = Depends(get_github_token), ): - client = GithubServiceImpl(github_user_id) + client = GithubServiceImpl(user_id=github_user_id, token=github_user_token) try: installations_ids: list[int] = await client.get_installation_ids() return installations_ids @@ -92,8 +99,9 @@ async def search_github_repositories( sort: str = 'stars', order: str = 'desc', github_user_id: str | None = Depends(get_user_id), + github_user_token: SecretStr | None = Depends(get_github_token), ): - client = GithubServiceImpl(github_user_id) + client = GithubServiceImpl(user_id=github_user_id, token=github_user_token) try: repos: list[GitHubRepository] = await 
client.search_repositories( query, per_page, sort, order diff --git a/openhands/server/routes/manage_conversations.py b/openhands/server/routes/manage_conversations.py index 2a453e30f1d3..41c79a2d25fa 100644 --- a/openhands/server/routes/manage_conversations.py +++ b/openhands/server/routes/manage_conversations.py @@ -4,13 +4,13 @@ from fastapi import APIRouter, Body, Request from fastapi.responses import JSONResponse -from pydantic import BaseModel +from pydantic import BaseModel, SecretStr from openhands.core.logger import openhands_logger as logger from openhands.events.action.message import MessageAction from openhands.events.stream import EventStreamSubscriber from openhands.runtime import get_runtime_cls -from openhands.server.auth import get_user_id +from openhands.server.auth import get_github_token, get_user_id from openhands.server.routes.github import GithubServiceImpl from openhands.server.session.conversation_init_data import ConversationInitData from openhands.server.shared import ( @@ -44,7 +44,7 @@ class InitSessionRequest(BaseModel): async def _create_new_conversation( user_id: str | None, - token: str | None, + token: SecretStr | None, selected_repository: str | None, initial_user_msg: str | None, image_urls: list[str] | None, @@ -72,7 +72,7 @@ async def _create_new_conversation( logger.warn('Settings not present, not starting conversation') raise MissingSettingsError('Settings not found') - session_init_args['github_token'] = token or '' + session_init_args['github_token'] = token or SecretStr('') session_init_args['selected_repository'] = selected_repository conversation_init_data = ConversationInitData(**session_init_args) logger.info('Loading conversation store') @@ -131,7 +131,9 @@ async def new_conversation(request: Request, data: InitSessionRequest): """ logger.info('Initializing new conversation') user_id = get_user_id(request) - github_token = GithubServiceImpl.get_gh_token(request) + github_service = GithubServiceImpl(user_id=user_id, 
token=get_github_token(request)) + github_token = await github_service.get_latest_token() + selected_repository = data.selected_repository initial_user_msg = data.initial_user_msg image_urls = data.image_urls or [] diff --git a/openhands/server/routes/settings.py b/openhands/server/routes/settings.py index d14c113a6549..8fb7b5f06c4d 100644 --- a/openhands/server/routes/settings.py +++ b/openhands/server/routes/settings.py @@ -1,11 +1,12 @@ from fastapi import APIRouter, Request, status from fastapi.responses import JSONResponse +from pydantic import SecretStr from openhands.core.logger import openhands_logger as logger -from openhands.server.auth import get_user_id -from openhands.server.services.github_service import GitHubService +from openhands.server.auth import get_github_token, get_user_id from openhands.server.settings import GETSettingsModel, POSTSettingsModel, Settings from openhands.server.shared import SettingsStoreImpl, config +from openhands.services.github.github_service import GitHubService app = APIRouter(prefix='/api') @@ -22,7 +23,7 @@ async def load_settings(request: Request) -> GETSettingsModel | None: content={'error': 'Settings not found'}, ) - token_is_set = bool(user_id) or bool(request.state.github_token) + token_is_set = bool(user_id) or bool(get_github_token(request)) settings_with_token_data = GETSettingsModel( **settings.model_dump(), github_token_is_set=token_is_set, @@ -50,8 +51,8 @@ async def store_settings( try: # We check if the token is valid by getting the user # If the token is invalid, this will raise an exception - github = GitHubService(None) - await github.validate_user(settings.github_token) + github = GitHubService(user_id=None, token=SecretStr(settings.github_token)) + await github.get_user() except Exception as e: logger.warning(f'Invalid GitHub token: {e}') diff --git a/openhands/server/session/agent_session.py b/openhands/server/session/agent_session.py index 59afdf141c10..37202e2bcbac 100644 --- 
a/openhands/server/session/agent_session.py +++ b/openhands/server/session/agent_session.py @@ -2,6 +2,8 @@ import time from typing import Callable, Optional +from pydantic import SecretStr + from openhands.controller import AgentController from openhands.controller.agent import Agent from openhands.controller.state.state import State @@ -69,7 +71,7 @@ async def start( max_budget_per_task: float | None = None, agent_to_llm_config: dict[str, LLMConfig] | None = None, agent_configs: dict[str, AgentConfig] | None = None, - github_token: str | None = None, + github_token: SecretStr | None = None, selected_repository: str | None = None, initial_message: MessageAction | None = None, ): @@ -113,7 +115,7 @@ async def start( if github_token: self.event_stream.set_secrets( { - 'github_token': github_token, + 'github_token': github_token.get_secret_value(), } ) if initial_message: @@ -177,7 +179,7 @@ async def _create_runtime( runtime_name: str, config: AppConfig, agent: Agent, - github_token: str | None = None, + github_token: SecretStr | None = None, selected_repository: str | None = None, ): """Creates a runtime instance @@ -195,7 +197,7 @@ async def _create_runtime( runtime_cls = get_runtime_cls(runtime_name) env_vars = ( { - 'GITHUB_TOKEN': github_token, + 'GITHUB_TOKEN': github_token.get_secret_value(), } if github_token else None diff --git a/openhands/server/session/conversation_init_data.py b/openhands/server/session/conversation_init_data.py index 82979e91fd96..8773f48b326a 100644 --- a/openhands/server/session/conversation_init_data.py +++ b/openhands/server/session/conversation_init_data.py @@ -1,4 +1,4 @@ -from pydantic import Field +from pydantic import Field, SecretStr from openhands.server.settings import Settings @@ -8,5 +8,5 @@ class ConversationInitData(Settings): Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data. 
""" - github_token: str | None = Field(default=None) + github_token: SecretStr | None = Field(default=None) selected_repository: str | None = Field(default=None) diff --git a/openhands/server/types.py b/openhands/server/types.py index da115de8cda5..4c8c1dc96a1c 100644 --- a/openhands/server/types.py +++ b/openhands/server/types.py @@ -42,15 +42,3 @@ class LLMAuthenticationError(ValueError): """Raised when there is an issue with LLM authentication.""" pass - - -class GhAuthenticationError(ValueError): - """Raised when there is an issue with LLM authentication.""" - - pass - - -class GHUnknownException(ValueError): - """Raised when there is an issue with LLM authentication.""" - - pass diff --git a/openhands/server/services/github_service.py b/openhands/services/github/github_service.py similarity index 72% rename from openhands/server/services/github_service.py rename to openhands/services/github/github_service.py index 9ade12f8a852..92d14c5653ed 100644 --- a/openhands/server/services/github_service.py +++ b/openhands/services/github/github_service.py @@ -1,41 +1,45 @@ from typing import Any import httpx -from fastapi import Request +from pydantic import SecretStr -from openhands.server.auth import get_github_token -from openhands.server.data_models.gh_types import GitHubRepository, GitHubUser -from openhands.server.shared import SettingsStoreImpl, config, server_config -from openhands.server.types import AppMode, GhAuthenticationError, GHUnknownException +from openhands.services.github.github_types import ( + GhAuthenticationError, + GHUnknownException, + GitHubRepository, + GitHubUser, +) class GitHubService: BASE_URL = 'https://api.github.com' - token: str = '' + token: SecretStr = SecretStr('') + refresh = False - def __init__(self, user_id: str | None): + def __init__(self, user_id: str | None = None, token: SecretStr | None = None): self.user_id = user_id - async def _get_github_headers(self): + if token: + self.token = token + + async def 
_get_github_headers(self) -> dict: """ Retrieve the GH Token from settings store to construct the headers """ - settings_store = await SettingsStoreImpl.get_instance(config, self.user_id) - settings = await settings_store.load() - if settings and settings.github_token: - self.token = settings.github_token.get_secret_value() + if self.user_id and not self.token: + self.token = await self.get_latest_token() return { - 'Authorization': f'Bearer {self.token}', + 'Authorization': f'Bearer {self.token.get_secret_value()}', 'Accept': 'application/vnd.github.v3+json', } - def _has_token_expired(self, status_code: int): + def _has_token_expired(self, status_code: int) -> bool: return status_code == 401 - async def _get_latest_token(self): - pass + async def get_latest_token(self) -> SecretStr: + return self.token async def _fetch_data( self, url: str, params: dict | None = None @@ -44,10 +48,8 @@ async def _fetch_data( async with httpx.AsyncClient() as client: github_headers = await self._get_github_headers() response = await client.get(url, headers=github_headers, params=params) - if server_config.app_mode == AppMode.SAAS and self._has_token_expired( - response.status_code - ): - await self._get_latest_token() + if self.refresh and self._has_token_expired(response.status_code): + await self.get_latest_token() github_headers = await self._get_github_headers() response = await client.get( url, headers=github_headers, params=params @@ -60,8 +62,10 @@ async def _fetch_data( return response.json(), headers - except httpx.HTTPStatusError: - raise GhAuthenticationError('Invalid Github token') + except httpx.HTTPStatusError as e: + if e.response.status_code == 401: + raise GhAuthenticationError('Invalid Github token') + raise GHUnknownException('Unknown error') except httpx.HTTPError: raise GHUnknownException('Unknown error') @@ -79,10 +83,6 @@ async def get_user(self) -> GitHubUser: email=response.get('email'), ) - async def validate_user(self, token) -> GitHubUser: - self.token 
= token - return await self.get_user() - async def get_repositories( self, page: int, per_page: int, sort: str, installation_id: int | None ) -> list[GitHubRepository]: @@ -133,7 +133,3 @@ async def search_repositories( ] return repos - - @classmethod - def get_gh_token(cls, request: Request) -> str | None: - return get_github_token(request) diff --git a/openhands/server/data_models/gh_types.py b/openhands/services/github/github_types.py similarity index 58% rename from openhands/server/data_models/gh_types.py rename to openhands/services/github/github_types.py index e6b67392bcca..d1958c9bbd19 100644 --- a/openhands/server/data_models/gh_types.py +++ b/openhands/services/github/github_types.py @@ -15,3 +15,15 @@ class GitHubRepository(BaseModel): full_name: str stargazers_count: int | None = None link_header: str | None = None + + +class GhAuthenticationError(ValueError): + """Raised when there is an issue with GitHub authentication.""" + + pass + + +class GHUnknownException(ValueError): + """Raised when there is an issue with GitHub communication.""" + + pass diff --git a/tests/unit/test_github_service.py b/tests/unit/test_github_service.py new file mode 100644 index 000000000000..e24faf6cb12a --- /dev/null +++ b/tests/unit/test_github_service.py @@ -0,0 +1,81 @@ +from unittest.mock import AsyncMock, Mock, patch + +import httpx +import pytest +from pydantic import SecretStr + +from openhands.services.github.github_service import GitHubService +from openhands.services.github.github_types import GhAuthenticationError + + +@pytest.mark.asyncio +async def test_github_service_token_handling(): + # Test initialization with SecretStr token + token = SecretStr('test-token') + service = GitHubService(user_id=None, token=token) + assert service.token == token + assert service.token.get_secret_value() == 'test-token' + + # Test headers contain the token correctly + headers = await service._get_github_headers() + assert headers['Authorization'] == 'Bearer test-token' + assert 
headers['Accept'] == 'application/vnd.github.v3+json' + + # Test initialization without token + service = GitHubService(user_id='test-user') + assert service.token == SecretStr('') + + +@pytest.mark.asyncio +async def test_github_service_token_refresh(): + # Test that token refresh is only attempted when refresh=True + token = SecretStr('test-token') + service = GitHubService(user_id=None, token=token) + assert not service.refresh + + # Test token expiry detection + assert service._has_token_expired(401) + assert not service._has_token_expired(200) + assert not service._has_token_expired(404) + + # Test get_latest_token returns the current token + latest_token = await service.get_latest_token() + assert isinstance(latest_token, SecretStr) + assert latest_token.get_secret_value() == 'test-token' # Compare with known value + + +@pytest.mark.asyncio +async def test_github_service_fetch_data(): + # Mock httpx.AsyncClient for testing API calls + mock_response = AsyncMock() + mock_response.status_code = 200 + mock_response.json.return_value = {'login': 'test-user'} + mock_response.raise_for_status = Mock() + + mock_client = AsyncMock() + mock_client.get.return_value = mock_response + mock_client.__aenter__.return_value = mock_client + mock_client.__aexit__.return_value = None + + with patch('httpx.AsyncClient', return_value=mock_client): + service = GitHubService(user_id=None, token=SecretStr('test-token')) + _ = await service._fetch_data('https://api.github.com/user') + + # Verify the request was made with correct headers + mock_client.get.assert_called_once() + call_args = mock_client.get.call_args + headers = call_args[1]['headers'] + assert headers['Authorization'] == 'Bearer test-token' + + # Test error handling with 401 status code + mock_response.status_code = 401 + mock_response.raise_for_status.side_effect = httpx.HTTPStatusError( + message='401 Unauthorized', request=Mock(), response=mock_response + ) + + # Reset the mock to test error handling + 
mock_client.get.reset_mock() + mock_client.get.return_value = mock_response + + with pytest.raises(GhAuthenticationError): + _ = await service._fetch_data('https://api.github.com/user')