Skip to content

Commit

Permalink
[Bug fix]: Standardize SecretStr use (#6660)
Browse files Browse the repository at this point in the history
Co-authored-by: Engel Nyst <[email protected]>
Co-authored-by: openhands <[email protected]>
  • Loading branch information
3 people authored Feb 10, 2025
1 parent 707cb07 commit 4a5891c
Show file tree
Hide file tree
Showing 14 changed files with 166 additions and 71 deletions.
3 changes: 3 additions & 0 deletions dev_config/python/ruff.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,6 @@ inline-quotes = "single"

[format]
quote-style = "single"

[lint.flake8-bugbear]
extend-immutable-calls = ["Depends", "fastapi.Depends", "fastapi.params.Depends"]
5 changes: 3 additions & 2 deletions openhands/runtime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from typing import Callable
from zipfile import ZipFile

from pydantic import SecretStr
from requests.exceptions import ConnectionError

from openhands.core.config import AppConfig, SandboxConfig
Expand Down Expand Up @@ -234,12 +235,12 @@ async def _handle_action(self, event: Action) -> None:
source = event.source if event.source else EventSource.AGENT
self.event_stream.add_event(observation, source) # type: ignore[arg-type]

def clone_repo(self, github_token: str, selected_repository: str) -> str:
def clone_repo(self, github_token: SecretStr, selected_repository: str) -> str:
if not github_token or not selected_repository:
raise ValueError(
'github_token and selected_repository must be provided to clone a repository'
)
url = f'https://{github_token}@github.com/{selected_repository}.git'
url = f'https://{github_token.get_secret_value()}@github.com/{selected_repository}.git'
dir_name = selected_repository.split('/')[1]
# add random branch name to avoid conflicts
random_str = ''.join(
Expand Down
3 changes: 2 additions & 1 deletion openhands/server/auth.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from fastapi import Request
from pydantic import SecretStr


def get_github_token(request: Request) -> str | None:
def get_github_token(request: Request) -> SecretStr | None:
return getattr(request.state, 'github_token', None)


Expand Down
2 changes: 1 addition & 1 deletion openhands/server/config/server_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ class ServerConfig(ServerConfigInterface):
)
conversation_manager_class: str = 'openhands.server.conversation_manager.standalone_conversation_manager.StandaloneConversationManager'

github_service_class: str = 'openhands.server.services.github_service.GitHubService'
github_service_class: str = 'openhands.services.github.github_service.GitHubService'

def verify_config(self):
if self.config_cls:
Expand Down
2 changes: 1 addition & 1 deletion openhands/server/middleware.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ async def __call__(self, request: Request, call_next: Callable):
# TODO: To avoid checks like this we should re-add the abilty to have completely different middleware in SAAS as in OSS
if getattr(request.state, 'github_token', None) is None:
if settings and settings.github_token:
request.state.github_token = settings.github_token.get_secret_value()
request.state.github_token = settings.github_token
else:
request.state.github_token = None

Expand Down
24 changes: 16 additions & 8 deletions openhands/server/routes/github.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from fastapi import APIRouter, Depends
from fastapi.responses import JSONResponse
from pydantic import SecretStr

from openhands.server.auth import get_user_id
from openhands.server.data_models.gh_types import GitHubRepository, GitHubUser
from openhands.server.services.github_service import GitHubService
from openhands.server.auth import get_github_token, get_user_id
from openhands.server.shared import server_config
from openhands.server.types import GhAuthenticationError, GHUnknownException
from openhands.services.github.github_service import (
GhAuthenticationError,
GHUnknownException,
GitHubService,
)
from openhands.services.github.github_types import GitHubRepository, GitHubUser
from openhands.utils.import_utils import get_impl

app = APIRouter(prefix='/api/github')
Expand All @@ -20,8 +24,9 @@ async def get_github_repositories(
sort: str = 'pushed',
installation_id: int | None = None,
github_user_id: str | None = Depends(get_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
):
client = GithubServiceImpl(github_user_id)
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
try:
repos: list[GitHubRepository] = await client.get_repositories(
page, per_page, sort, installation_id
Expand All @@ -44,8 +49,9 @@ async def get_github_repositories(
@app.get('/user')
async def get_github_user(
github_user_id: str | None = Depends(get_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
):
client = GithubServiceImpl(github_user_id)
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
try:
user: GitHubUser = await client.get_user()
return user
Expand All @@ -66,8 +72,9 @@ async def get_github_user(
@app.get('/installations')
async def get_github_installation_ids(
github_user_id: str | None = Depends(get_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
):
client = GithubServiceImpl(github_user_id)
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
try:
installations_ids: list[int] = await client.get_installation_ids()
return installations_ids
Expand All @@ -92,8 +99,9 @@ async def search_github_repositories(
sort: str = 'stars',
order: str = 'desc',
github_user_id: str | None = Depends(get_user_id),
github_user_token: SecretStr | None = Depends(get_github_token),
):
client = GithubServiceImpl(github_user_id)
client = GithubServiceImpl(user_id=github_user_id, token=github_user_token)
try:
repos: list[GitHubRepository] = await client.search_repositories(
query, per_page, sort, order
Expand Down
12 changes: 7 additions & 5 deletions openhands/server/routes/manage_conversations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@

from fastapi import APIRouter, Body, Request
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from pydantic import BaseModel, SecretStr

from openhands.core.logger import openhands_logger as logger
from openhands.events.action.message import MessageAction
from openhands.events.stream import EventStreamSubscriber
from openhands.runtime import get_runtime_cls
from openhands.server.auth import get_user_id
from openhands.server.auth import get_github_token, get_user_id
from openhands.server.routes.github import GithubServiceImpl
from openhands.server.session.conversation_init_data import ConversationInitData
from openhands.server.shared import (
Expand Down Expand Up @@ -44,7 +44,7 @@ class InitSessionRequest(BaseModel):

async def _create_new_conversation(
user_id: str | None,
token: str | None,
token: SecretStr | None,
selected_repository: str | None,
initial_user_msg: str | None,
image_urls: list[str] | None,
Expand Down Expand Up @@ -72,7 +72,7 @@ async def _create_new_conversation(
logger.warn('Settings not present, not starting conversation')
raise MissingSettingsError('Settings not found')

session_init_args['github_token'] = token or ''
session_init_args['github_token'] = token or SecretStr('')
session_init_args['selected_repository'] = selected_repository
conversation_init_data = ConversationInitData(**session_init_args)
logger.info('Loading conversation store')
Expand Down Expand Up @@ -131,7 +131,9 @@ async def new_conversation(request: Request, data: InitSessionRequest):
"""
logger.info('Initializing new conversation')
user_id = get_user_id(request)
github_token = GithubServiceImpl.get_gh_token(request)
github_service = GithubServiceImpl(user_id=user_id, token=get_github_token(request))
github_token = await github_service.get_latest_token()

selected_repository = data.selected_repository
initial_user_msg = data.initial_user_msg
image_urls = data.image_urls or []
Expand Down
11 changes: 6 additions & 5 deletions openhands/server/routes/settings.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from fastapi import APIRouter, Request, status
from fastapi.responses import JSONResponse
from pydantic import SecretStr

from openhands.core.logger import openhands_logger as logger
from openhands.server.auth import get_user_id
from openhands.server.services.github_service import GitHubService
from openhands.server.auth import get_github_token, get_user_id
from openhands.server.settings import GETSettingsModel, POSTSettingsModel, Settings
from openhands.server.shared import SettingsStoreImpl, config
from openhands.services.github.github_service import GitHubService

app = APIRouter(prefix='/api')

Expand All @@ -22,7 +23,7 @@ async def load_settings(request: Request) -> GETSettingsModel | None:
content={'error': 'Settings not found'},
)

token_is_set = bool(user_id) or bool(request.state.github_token)
token_is_set = bool(user_id) or bool(get_github_token(request))
settings_with_token_data = GETSettingsModel(
**settings.model_dump(),
github_token_is_set=token_is_set,
Expand Down Expand Up @@ -50,8 +51,8 @@ async def store_settings(
try:
# We check if the token is valid by getting the user
# If the token is invalid, this will raise an exception
github = GitHubService(None)
await github.validate_user(settings.github_token)
github = GitHubService(user_id=None, token=SecretStr(settings.github_token))
await github.get_user()

except Exception as e:
logger.warning(f'Invalid GitHub token: {e}')
Expand Down
10 changes: 6 additions & 4 deletions openhands/server/session/agent_session.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import time
from typing import Callable, Optional

from pydantic import SecretStr

from openhands.controller import AgentController
from openhands.controller.agent import Agent
from openhands.controller.state.state import State
Expand Down Expand Up @@ -69,7 +71,7 @@ async def start(
max_budget_per_task: float | None = None,
agent_to_llm_config: dict[str, LLMConfig] | None = None,
agent_configs: dict[str, AgentConfig] | None = None,
github_token: str | None = None,
github_token: SecretStr | None = None,
selected_repository: str | None = None,
initial_message: MessageAction | None = None,
):
Expand Down Expand Up @@ -113,7 +115,7 @@ async def start(
if github_token:
self.event_stream.set_secrets(
{
'github_token': github_token,
'github_token': github_token.get_secret_value(),
}
)
if initial_message:
Expand Down Expand Up @@ -177,7 +179,7 @@ async def _create_runtime(
runtime_name: str,
config: AppConfig,
agent: Agent,
github_token: str | None = None,
github_token: SecretStr | None = None,
selected_repository: str | None = None,
):
"""Creates a runtime instance
Expand All @@ -195,7 +197,7 @@ async def _create_runtime(
runtime_cls = get_runtime_cls(runtime_name)
env_vars = (
{
'GITHUB_TOKEN': github_token,
'GITHUB_TOKEN': github_token.get_secret_value(),
}
if github_token
else None
Expand Down
4 changes: 2 additions & 2 deletions openhands/server/session/conversation_init_data.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from pydantic import Field
from pydantic import Field, SecretStr

from openhands.server.settings import Settings

Expand All @@ -8,5 +8,5 @@ class ConversationInitData(Settings):
Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
"""

github_token: str | None = Field(default=None)
github_token: SecretStr | None = Field(default=None)
selected_repository: str | None = Field(default=None)
12 changes: 0 additions & 12 deletions openhands/server/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,15 +42,3 @@ class LLMAuthenticationError(ValueError):
"""Raised when there is an issue with LLM authentication."""

pass


class GhAuthenticationError(ValueError):
"""Raised when there is an issue with LLM authentication."""

pass


class GHUnknownException(ValueError):
"""Raised when there is an issue with LLM authentication."""

pass
Original file line number Diff line number Diff line change
@@ -1,41 +1,45 @@
from typing import Any

import httpx
from fastapi import Request
from pydantic import SecretStr

from openhands.server.auth import get_github_token
from openhands.server.data_models.gh_types import GitHubRepository, GitHubUser
from openhands.server.shared import SettingsStoreImpl, config, server_config
from openhands.server.types import AppMode, GhAuthenticationError, GHUnknownException
from openhands.services.github.github_types import (
GhAuthenticationError,
GHUnknownException,
GitHubRepository,
GitHubUser,
)


class GitHubService:
BASE_URL = 'https://api.github.com'
token: str = ''
token: SecretStr = SecretStr('')
refresh = False

def __init__(self, user_id: str | None):
def __init__(self, user_id: str | None = None, token: SecretStr | None = None):
self.user_id = user_id

async def _get_github_headers(self):
if token:
self.token = token

async def _get_github_headers(self) -> dict:
"""
Retrieve the GH Token from settings store to construct the headers
"""

settings_store = await SettingsStoreImpl.get_instance(config, self.user_id)
settings = await settings_store.load()
if settings and settings.github_token:
self.token = settings.github_token.get_secret_value()
if self.user_id and not self.token:
self.token = await self.get_latest_token()

return {
'Authorization': f'Bearer {self.token}',
'Authorization': f'Bearer {self.token.get_secret_value()}',
'Accept': 'application/vnd.github.v3+json',
}

def _has_token_expired(self, status_code: int):
def _has_token_expired(self, status_code: int) -> bool:
return status_code == 401

async def _get_latest_token(self):
pass
async def get_latest_token(self) -> SecretStr:
return self.token

async def _fetch_data(
self, url: str, params: dict | None = None
Expand All @@ -44,10 +48,8 @@ async def _fetch_data(
async with httpx.AsyncClient() as client:
github_headers = await self._get_github_headers()
response = await client.get(url, headers=github_headers, params=params)
if server_config.app_mode == AppMode.SAAS and self._has_token_expired(
response.status_code
):
await self._get_latest_token()
if self.refresh and self._has_token_expired(response.status_code):
await self.get_latest_token()
github_headers = await self._get_github_headers()
response = await client.get(
url, headers=github_headers, params=params
Expand All @@ -60,8 +62,10 @@ async def _fetch_data(

return response.json(), headers

except httpx.HTTPStatusError:
raise GhAuthenticationError('Invalid Github token')
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise GhAuthenticationError('Invalid Github token')
raise GHUnknownException('Unknown error')

except httpx.HTTPError:
raise GHUnknownException('Unknown error')
Expand All @@ -79,10 +83,6 @@ async def get_user(self) -> GitHubUser:
email=response.get('email'),
)

async def validate_user(self, token) -> GitHubUser:
self.token = token
return await self.get_user()

async def get_repositories(
self, page: int, per_page: int, sort: str, installation_id: int | None
) -> list[GitHubRepository]:
Expand Down Expand Up @@ -133,7 +133,3 @@ async def search_repositories(
]

return repos

@classmethod
def get_gh_token(cls, request: Request) -> str | None:
return get_github_token(request)
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,15 @@ class GitHubRepository(BaseModel):
full_name: str
stargazers_count: int | None = None
link_header: str | None = None


class GhAuthenticationError(ValueError):
"""Raised when there is an issue with GitHub authentication."""

pass


class GHUnknownException(ValueError):
"""Raised when there is an issue with GitHub communcation."""

pass
Loading

0 comments on commit 4a5891c

Please sign in to comment.