diff --git a/evaluation/benchmarks/the_agent_company/run_infer.py b/evaluation/benchmarks/the_agent_company/run_infer.py
index 8f8a1b599e6f..376df6c47cfb 100644
--- a/evaluation/benchmarks/the_agent_company/run_infer.py
+++ b/evaluation/benchmarks/the_agent_company/run_infer.py
@@ -80,7 +80,7 @@ def load_dependencies(runtime: Runtime) -> List[str]:
 def init_task_env(runtime: Runtime, hostname: str, env_llm_config: LLMConfig):
     command = (
         f'SERVER_HOSTNAME={hostname} '
-        f'LITELLM_API_KEY={env_llm_config.api_key} '
+        f'LITELLM_API_KEY={env_llm_config.api_key.get_secret_value() if env_llm_config.api_key else None} '
         f'LITELLM_BASE_URL={env_llm_config.base_url} '
         f'LITELLM_MODEL={env_llm_config.model} '
         'bash /utils/init.sh'
@@ -165,7 +165,7 @@ def run_evaluator(
     runtime: Runtime, env_llm_config: LLMConfig, trajectory_path: str, result_path: str
 ):
     command = (
-        f'LITELLM_API_KEY={env_llm_config.api_key} '
+        f'LITELLM_API_KEY={env_llm_config.api_key.get_secret_value() if env_llm_config.api_key else None} '
         f'LITELLM_BASE_URL={env_llm_config.base_url} '
         f'LITELLM_MODEL={env_llm_config.model} '
         f"DECRYPTION_KEY='theagentcompany is all you need' "  # Hardcoded Key
diff --git a/evaluation/utils/shared.py b/evaluation/utils/shared.py
index 23c5b319d3c2..86de1a01d2cd 100644
--- a/evaluation/utils/shared.py
+++ b/evaluation/utils/shared.py
@@ -52,30 +52,6 @@ class EvalMetadata(BaseModel):
     details: dict[str, Any] | None = None
     condenser_config: CondenserConfig | None = None

-    def model_dump(self, *args, **kwargs):
-        dumped_dict = super().model_dump(*args, **kwargs)
-        # avoid leaking sensitive information
-        dumped_dict['llm_config'] = self.llm_config.to_safe_dict()
-        if hasattr(self.condenser_config, 'llm_config'):
-            dumped_dict['condenser_config']['llm_config'] = (
-                self.condenser_config.llm_config.to_safe_dict()
-            )
-
-        return dumped_dict
-
-    def model_dump_json(self, *args, **kwargs):
-        dumped = super().model_dump_json(*args, **kwargs)
-        dumped_dict = json.loads(dumped)
-        # avoid leaking sensitive information
-        dumped_dict['llm_config'] = self.llm_config.to_safe_dict()
-        if hasattr(self.condenser_config, 'llm_config'):
-            dumped_dict['condenser_config']['llm_config'] = (
-                self.condenser_config.llm_config.to_safe_dict()
-            )
-
-        logger.debug(f'Dumped metadata: {dumped_dict}')
-        return json.dumps(dumped_dict)
-

 class EvalOutput(BaseModel):
     # NOTE: User-specified
@@ -98,23 +74,6 @@ class EvalOutput(BaseModel):
     # Optionally save the input test instance
     instance: dict[str, Any] | None = None

-    def model_dump(self, *args, **kwargs):
-        dumped_dict = super().model_dump(*args, **kwargs)
-        # Remove None values
-        dumped_dict = {k: v for k, v in dumped_dict.items() if v is not None}
-        # Apply custom serialization for metadata (to avoid leaking sensitive information)
-        if self.metadata is not None:
-            dumped_dict['metadata'] = self.metadata.model_dump()
-        return dumped_dict
-
-    def model_dump_json(self, *args, **kwargs):
-        dumped = super().model_dump_json(*args, **kwargs)
-        dumped_dict = json.loads(dumped)
-        # Apply custom serialization for metadata (to avoid leaking sensitive information)
-        if 'metadata' in dumped_dict:
-            dumped_dict['metadata'] = json.loads(self.metadata.model_dump_json())
-        return json.dumps(dumped_dict)
-

 class EvalException(Exception):
     pass
@@ -314,7 +273,7 @@ def update_progress(
         logger.info(
             f'Finished evaluation for instance {result.instance_id}: {str(result.test_result)[:300]}...\n'
         )
-        output_fp.write(json.dumps(result.model_dump()) + '\n')
+        output_fp.write(result.model_dump_json() + '\n')
         output_fp.flush()
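The deleted `model_dump`/`model_dump_json` overrides above existed only to scrub secrets out of serialized metadata. A minimal sketch (illustrative model, not part of the patch) of why they are no longer needed once secrets are typed as `SecretStr`:

```python
# Once secrets are pydantic.SecretStr, the default serializers mask them,
# so hand-rolled "safe dump" overrides become dead weight.
from pydantic import BaseModel, SecretStr


class DemoLLMConfig(BaseModel):  # hypothetical stand-in for LLMConfig
    model: str = 'claude-3-5-sonnet-20241022'
    api_key: SecretStr | None = None


cfg = DemoLLMConfig(api_key='sk-secret')
print(cfg.model_dump_json())  # {"model":"claude-3-5-sonnet-20241022","api_key":"**********"}
print(repr(cfg))              # api_key=SecretStr('**********') -- the raw value never appears
```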
diff --git a/openhands/core/config/README.md b/openhands/core/config/README.md
index 5e3abae5b13a..c612a0824403 100644
--- a/openhands/core/config/README.md
+++ b/openhands/core/config/README.md
@@ -37,21 +37,17 @@ export SANDBOX_TIMEOUT='300'

 ## Type Handling

-The `load_from_env` function attempts to cast environment variable values to the types specified in the dataclasses. It handles:
+The `load_from_env` function attempts to cast environment variable values to the types specified in the models. It handles:

 - Basic types (str, int, bool)
 - Optional types (e.g., `str | None`)
-- Nested dataclasses
+- Nested models

 If type casting fails, an error is logged, and the default value is retained.

 ## Default Values

-If an environment variable is not set, the default value specified in the dataclass is used.
-
-## Nested Configurations
-
-The `AppConfig` class contains nested configurations like `LLMConfig` and `AgentConfig`. The `load_from_env` function handles these by recursively processing nested dataclasses with updated prefixes.
+If an environment variable is not set, the default value specified in the model is used.

 ## Security Considerations
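A rough sketch of the type handling the README describes, with simplified helper names that are assumptions, not the real `load_from_env` API (nested models recurse with a `<field>_` prefix, so `LLMConfig.base_url` is matched by `LLM_BASE_URL`):

```python
# Cast an env var string to a field's annotated type, unwrapping `T | None`.
from types import UnionType
from typing import get_args


def cast_env_value(raw: str, field_type):
    if isinstance(field_type, UnionType):  # e.g. str | None
        field_type = next(t for t in get_args(field_type) if t is not type(None))
    if field_type is bool:
        return raw.lower() in ('true', '1')
    return field_type(raw)  # int('300') -> 300; str passes through


assert cast_env_value('300', int | None) == 300
assert cast_env_value('true', bool) is True
```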
""" - codeact_enable_browsing: bool = True - codeact_enable_llm_editor: bool = False - codeact_enable_jupyter: bool = True - micro_agent_name: str | None = None - memory_enabled: bool = False - memory_max_threads: int = 3 - llm_config: str | None = None - enable_prompt_extensions: bool = True - disabled_microagents: list[str] | None = None - condenser: CondenserConfig = field(default_factory=NoOpCondenserConfig) # type: ignore - - def defaults_to_dict(self) -> dict: - """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional.""" - result = {} - for f in fields(self): - result[f.name] = get_field_info(f) - return result + codeact_enable_browsing: bool = Field(default=True) + codeact_enable_llm_editor: bool = Field(default=False) + codeact_enable_jupyter: bool = Field(default=True) + micro_agent_name: str | None = Field(default=None) + memory_enabled: bool = Field(default=False) + memory_max_threads: int = Field(default=3) + llm_config: str | None = Field(default=None) + enable_prompt_extensions: bool = Field(default=False) + disabled_microagents: list[str] | None = Field(default=None) + condenser: CondenserConfig = Field(default_factory=NoOpCondenserConfig) diff --git a/openhands/core/config/app_config.py b/openhands/core/config/app_config.py index a59f4fb1b865..468d37572fe2 100644 --- a/openhands/core/config/app_config.py +++ b/openhands/core/config/app_config.py @@ -1,20 +1,20 @@ -from dataclasses import dataclass, field, fields, is_dataclass from typing import ClassVar +from pydantic import BaseModel, Field, SecretStr + from openhands.core import logger from openhands.core.config.agent_config import AgentConfig from openhands.core.config.config_utils import ( OH_DEFAULT_AGENT, OH_MAX_ITERATIONS, - get_field_info, + model_defaults_to_dict, ) from openhands.core.config.llm_config import LLMConfig from openhands.core.config.sandbox_config import SandboxConfig from openhands.core.config.security_config import SecurityConfig -@dataclass -class AppConfig: +class AppConfig(BaseModel): """Configuration for the app. Attributes: @@ -46,37 +46,39 @@ class AppConfig: input is read line by line. When enabled, input continues until /exit command. 
""" - llms: dict[str, LLMConfig] = field(default_factory=dict) - agents: dict = field(default_factory=dict) - default_agent: str = OH_DEFAULT_AGENT - sandbox: SandboxConfig = field(default_factory=SandboxConfig) - security: SecurityConfig = field(default_factory=SecurityConfig) - runtime: str = 'docker' - file_store: str = 'local' - file_store_path: str = '/tmp/openhands_file_store' - save_trajectory_path: str | None = None - workspace_base: str | None = None - workspace_mount_path: str | None = None - workspace_mount_path_in_sandbox: str = '/workspace' - workspace_mount_rewrite: str | None = None - cache_dir: str = '/tmp/cache' - run_as_openhands: bool = True - max_iterations: int = OH_MAX_ITERATIONS - max_budget_per_task: float | None = None - e2b_api_key: str = '' - modal_api_token_id: str = '' - modal_api_token_secret: str = '' - disable_color: bool = False - jwt_secret: str = '' - debug: bool = False - file_uploads_max_file_size_mb: int = 0 - file_uploads_restrict_file_types: bool = False - file_uploads_allowed_extensions: list[str] = field(default_factory=lambda: ['.*']) - runloop_api_key: str | None = None - cli_multiline_input: bool = False + llms: dict[str, LLMConfig] = Field(default_factory=dict) + agents: dict = Field(default_factory=dict) + default_agent: str = Field(default=OH_DEFAULT_AGENT) + sandbox: SandboxConfig = Field(default_factory=SandboxConfig) + security: SecurityConfig = Field(default_factory=SecurityConfig) + runtime: str = Field(default='docker') + file_store: str = Field(default='local') + file_store_path: str = Field(default='/tmp/openhands_file_store') + save_trajectory_path: str | None = Field(default=None) + workspace_base: str | None = Field(default=None) + workspace_mount_path: str | None = Field(default=None) + workspace_mount_path_in_sandbox: str = Field(default='/workspace') + workspace_mount_rewrite: str | None = Field(default=None) + cache_dir: str = Field(default='/tmp/cache') + run_as_openhands: bool = Field(default=True) + max_iterations: int = Field(default=OH_MAX_ITERATIONS) + max_budget_per_task: float | None = Field(default=None) + e2b_api_key: SecretStr | None = Field(default=None) + modal_api_token_id: SecretStr | None = Field(default=None) + modal_api_token_secret: SecretStr | None = Field(default=None) + disable_color: bool = Field(default=False) + jwt_secret: SecretStr | None = Field(default=None) + debug: bool = Field(default=False) + file_uploads_max_file_size_mb: int = Field(default=0) + file_uploads_restrict_file_types: bool = Field(default=False) + file_uploads_allowed_extensions: list[str] = Field(default_factory=lambda: ['.*']) + runloop_api_key: SecretStr | None = Field(default=None) + cli_multiline_input: bool = Field(default=False) defaults_dict: ClassVar[dict] = {} + model_config = {'extra': 'forbid'} + def get_llm_config(self, name='llm') -> LLMConfig: """'llm' is the name for default config (for backward compatibility prior to 0.8).""" if name in self.llms: @@ -115,42 +117,7 @@ def get_llm_config_from_agent(self, name='agent') -> LLMConfig: def get_agent_configs(self) -> dict[str, AgentConfig]: return self.agents - def __post_init__(self): + def model_post_init(self, __context): """Post-initialization hook, called when the instance is created with only default values.""" - AppConfig.defaults_dict = self.defaults_to_dict() - - def defaults_to_dict(self) -> dict: - """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional.""" - result = {} - for f in fields(self): - field_value = 
diff --git a/openhands/core/config/config_utils.py b/openhands/core/config/config_utils.py
index 38c3c1d03df5..44893e119b5a 100644
--- a/openhands/core/config/config_utils.py
+++ b/openhands/core/config/config_utils.py
@@ -1,19 +1,22 @@
 from types import UnionType
-from typing import get_args, get_origin
+from typing import Any, get_args, get_origin
+
+from pydantic import BaseModel
+from pydantic.fields import FieldInfo

 OH_DEFAULT_AGENT = 'CodeActAgent'
 OH_MAX_ITERATIONS = 500


-def get_field_info(f):
+def get_field_info(field: FieldInfo) -> dict[str, Any]:
     """Extract information about a dataclass field: type, optional, and default.

     Args:
-        f: The field to extract information from.
+        field: The field to extract information from.

     Returns:
         A dict with the field's type, whether it's optional, and its default value.
     """
-    field_type = f.type
+    field_type = field.annotation
     optional = False

     # for types like str | None, find the non-None type and set optional to True
@@ -33,7 +36,21 @@ def get_field_info(f):
     )

     # default is always present
-    default = f.default
+    default = field.default

     # return a schema with the useful info for frontend
     return {'type': type_name.lower(), 'optional': optional, 'default': default}
+
+
+def model_defaults_to_dict(model: BaseModel) -> dict[str, Any]:
+    """Serialize field information in a dict for the frontend, including type hints, defaults, and whether it's optional."""
+    result = {}
+    for name, field in model.model_fields.items():
+        field_value = getattr(model, name)
+
+        if isinstance(field_value, BaseModel):
+            result[name] = model_defaults_to_dict(field_value)
+        else:
+            result[name] = get_field_info(field)
+
+    return result
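A quick illustration, with made-up models, of what iterating `model_fields` gives `model_defaults_to_dict` to work with. One caveat worth seeing: factory defaults surface as `PydanticUndefined` rather than a concrete value.

```python
# FieldInfo carries the annotation and default that get_field_info reads.
from pydantic import BaseModel, Field


class Inner(BaseModel):
    retries: int = Field(default=3)


class Outer(BaseModel):
    name: str | None = Field(default=None)
    inner: Inner = Field(default_factory=Inner)


for field_name, field_info in Outer.model_fields.items():
    print(field_name, field_info.annotation, field_info.default)
# name str | None None
# inner <class '...Inner'> PydanticUndefined  (factory defaults have no plain default)
```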
""" - model: str = 'claude-3-5-sonnet-20241022' - api_key: str | None = None - base_url: str | None = None - api_version: str | None = None - embedding_model: str = 'local' - embedding_base_url: str | None = None - embedding_deployment_name: str | None = None - aws_access_key_id: str | None = None - aws_secret_access_key: str | None = None - aws_region_name: str | None = None - openrouter_site_url: str = 'https://docs.all-hands.dev/' - openrouter_app_name: str = 'OpenHands' - num_retries: int = 8 - retry_multiplier: float = 2 - retry_min_wait: int = 15 - retry_max_wait: int = 120 - timeout: int | None = None - max_message_chars: int = 30_000 # maximum number of characters in an observation's content when sent to the llm - temperature: float = 0.0 - top_p: float = 1.0 - custom_llm_provider: str | None = None - max_input_tokens: int | None = None - max_output_tokens: int | None = None - input_cost_per_token: float | None = None - output_cost_per_token: float | None = None - ollama_base_url: str | None = None + model: str = Field(default='claude-3-5-sonnet-20241022') + api_key: SecretStr | None = Field(default=None) + base_url: str | None = Field(default=None) + api_version: str | None = Field(default=None) + embedding_model: str = Field(default='local') + embedding_base_url: str | None = Field(default=None) + embedding_deployment_name: str | None = Field(default=None) + aws_access_key_id: SecretStr | None = Field(default=None) + aws_secret_access_key: SecretStr | None = Field(default=None) + aws_region_name: str | None = Field(default=None) + openrouter_site_url: str = Field(default='https://docs.all-hands.dev/') + openrouter_app_name: str = Field(default='OpenHands') + num_retries: int = Field(default=8) + retry_multiplier: float = Field(default=2) + retry_min_wait: int = Field(default=15) + retry_max_wait: int = Field(default=120) + timeout: int | None = Field(default=None) + max_message_chars: int = Field( + default=30_000 + ) # maximum number of characters in an observation's content when sent to the llm + temperature: float = Field(default=0.0) + top_p: float = Field(default=1.0) + custom_llm_provider: str | None = Field(default=None) + max_input_tokens: int | None = Field(default=None) + max_output_tokens: int | None = Field(default=None) + input_cost_per_token: float | None = Field(default=None) + output_cost_per_token: float | None = Field(default=None) + ollama_base_url: str | None = Field(default=None) # This setting can be sent in each call to litellm - drop_params: bool = True + drop_params: bool = Field(default=True) # Note: this setting is actually global, unlike drop_params - modify_params: bool = True - disable_vision: bool | None = None - reasoning_effort: str | None = None - caching_prompt: bool = True - log_completions: bool = False - log_completions_folder: str = os.path.join(LOG_DIR, 'completions') - custom_tokenizer: str | None = None - native_tool_calling: bool | None = None - - def defaults_to_dict(self) -> dict: - """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional.""" - result = {} - for f in fields(self): - result[f.name] = get_field_info(f) - return result + modify_params: bool = Field(default=True) + disable_vision: bool | None = Field(default=None) + caching_prompt: bool = Field(default=True) + log_completions: bool = Field(default=False) + log_completions_folder: str = Field(default=os.path.join(LOG_DIR, 'completions')) + custom_tokenizer: str | None = Field(default=None) + native_tool_calling: bool | None = 
diff --git a/openhands/core/config/sandbox_config.py b/openhands/core/config/sandbox_config.py
index 0ea40f29faab..44bd7ee0afda 100644
--- a/openhands/core/config/sandbox_config.py
+++ b/openhands/core/config/sandbox_config.py
@@ -1,11 +1,9 @@
 import os
-from dataclasses import dataclass, field, fields

-from openhands.core.config.config_utils import get_field_info
+from pydantic import BaseModel, Field


-@dataclass
-class SandboxConfig:
+class SandboxConfig(BaseModel):
     """Configuration for the sandbox.

     Attributes:
@@ -39,48 +37,32 @@ class SandboxConfig:
         This should be a JSON string that will be parsed into a dictionary.
     """

-    remote_runtime_api_url: str = 'http://localhost:8000'
-    local_runtime_url: str = 'http://localhost'
-    keep_runtime_alive: bool = False
-    rm_all_containers: bool = False
-    api_key: str | None = None
-    base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22'  # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime
-    runtime_container_image: str | None = None
-    user_id: int = os.getuid() if hasattr(os, 'getuid') else 1000
-    timeout: int = 120
-    remote_runtime_init_timeout: int = 180
-    enable_auto_lint: bool = (
-        False  # once enabled, OpenHands would lint files after editing
+    remote_runtime_api_url: str = Field(default='http://localhost:8000')
+    local_runtime_url: str = Field(default='http://localhost')
+    keep_runtime_alive: bool = Field(default=False)
+    rm_all_containers: bool = Field(default=False)
+    api_key: str | None = Field(default=None)
+    base_container_image: str = Field(
+        default='nikolaik/python-nodejs:python3.12-nodejs22'
     )
-    use_host_network: bool = False
-    runtime_extra_build_args: list[str] | None = None
-    initialize_plugins: bool = True
-    force_rebuild_runtime: bool = False
-    runtime_extra_deps: str | None = None
-    runtime_startup_env_vars: dict[str, str] = field(default_factory=dict)
-    browsergym_eval_env: str | None = None
-    platform: str | None = None
-    close_delay: int = 15
-    remote_runtime_resource_factor: int = 1
-    enable_gpu: bool = False
-    docker_runtime_kwargs: str | None = None
-
-    def defaults_to_dict(self) -> dict:
-        """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
-        dict = {}
-        for f in fields(self):
-            dict[f.name] = get_field_info(f)
-        return dict
-
-    def __str__(self):
-        attr_str = []
-        for f in fields(self):
-            attr_name = f.name
-            attr_value = getattr(self, f.name)
-
-            attr_str.append(f'{attr_name}={repr(attr_value)}')
-
-        return f"SandboxConfig({', '.join(attr_str)})"
+    runtime_container_image: str | None = Field(default=None)
+    user_id: int = Field(default=os.getuid() if hasattr(os, 'getuid') else 1000)
+    timeout: int = Field(default=120)
+    remote_runtime_init_timeout: int = Field(default=180)
+    enable_auto_lint: bool = Field(
+        default=False  # once enabled, OpenHands would lint files after editing
+    )
+    use_host_network: bool = Field(default=False)
+    runtime_extra_build_args: list[str] | None = Field(default=None)
+    initialize_plugins: bool = Field(default=True)
+    force_rebuild_runtime: bool = Field(default=False)
+    runtime_extra_deps: str | None = Field(default=None)
+    runtime_startup_env_vars: dict[str, str] = Field(default_factory=dict)
+    browsergym_eval_env: str | None = Field(default=None)
+    platform: str | None = Field(default=None)
+    close_delay: int = Field(default=900)
+    remote_runtime_resource_factor: int = Field(default=1)
+    enable_gpu: bool = Field(default=False)
+    docker_runtime_kwargs: str | None = Field(default=None)

-    def __repr__(self):
-        return self.__str__()
+    model_config = {'extra': 'forbid'}
""" - remote_runtime_api_url: str = 'http://localhost:8000' - local_runtime_url: str = 'http://localhost' - keep_runtime_alive: bool = False - rm_all_containers: bool = False - api_key: str | None = None - base_container_image: str = 'nikolaik/python-nodejs:python3.12-nodejs22' # default to nikolaik/python-nodejs:python3.12-nodejs22 for eventstream runtime - runtime_container_image: str | None = None - user_id: int = os.getuid() if hasattr(os, 'getuid') else 1000 - timeout: int = 120 - remote_runtime_init_timeout: int = 180 - enable_auto_lint: bool = ( - False # once enabled, OpenHands would lint files after editing + remote_runtime_api_url: str = Field(default='http://localhost:8000') + local_runtime_url: str = Field(default='http://localhost') + keep_runtime_alive: bool = Field(default=False) + rm_all_containers: bool = Field(default=False) + api_key: str | None = Field(default=None) + base_container_image: str = Field( + default='nikolaik/python-nodejs:python3.12-nodejs22' ) - use_host_network: bool = False - runtime_extra_build_args: list[str] | None = None - initialize_plugins: bool = True - force_rebuild_runtime: bool = False - runtime_extra_deps: str | None = None - runtime_startup_env_vars: dict[str, str] = field(default_factory=dict) - browsergym_eval_env: str | None = None - platform: str | None = None - close_delay: int = 15 - remote_runtime_resource_factor: int = 1 - enable_gpu: bool = False - docker_runtime_kwargs: str | None = None - - def defaults_to_dict(self) -> dict: - """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional.""" - dict = {} - for f in fields(self): - dict[f.name] = get_field_info(f) - return dict - - def __str__(self): - attr_str = [] - for f in fields(self): - attr_name = f.name - attr_value = getattr(self, f.name) - - attr_str.append(f'{attr_name}={repr(attr_value)}') - - return f"SandboxConfig({', '.join(attr_str)})" + runtime_container_image: str | None = Field(default=None) + user_id: int = Field(default=os.getuid() if hasattr(os, 'getuid') else 1000) + timeout: int = Field(default=120) + remote_runtime_init_timeout: int = Field(default=180) + enable_auto_lint: bool = Field( + default=False # once enabled, OpenHands would lint files after editing + ) + use_host_network: bool = Field(default=False) + runtime_extra_build_args: list[str] | None = Field(default=None) + initialize_plugins: bool = Field(default=True) + force_rebuild_runtime: bool = Field(default=False) + runtime_extra_deps: str | None = Field(default=None) + runtime_startup_env_vars: dict[str, str] = Field(default_factory=dict) + browsergym_eval_env: str | None = Field(default=None) + platform: str | None = Field(default=None) + close_delay: int = Field(default=900) + remote_runtime_resource_factor: int = Field(default=1) + enable_gpu: bool = Field(default=False) + docker_runtime_kwargs: str | None = Field(default=None) - def __repr__(self): - return self.__str__() + model_config = {'extra': 'forbid'} diff --git a/openhands/core/config/security_config.py b/openhands/core/config/security_config.py index 60645f305736..a4805e3ab85f 100644 --- a/openhands/core/config/security_config.py +++ b/openhands/core/config/security_config.py @@ -1,10 +1,7 @@ -from dataclasses import dataclass, fields +from pydantic import BaseModel, Field -from openhands.core.config.config_utils import get_field_info - -@dataclass -class SecurityConfig: +class SecurityConfig(BaseModel): """Configuration for security related functionalities. 
diff --git a/openhands/core/config/utils.py b/openhands/core/config/utils.py
index ccdd9494a84d..f543ab9682ac 100644
--- a/openhands/core/config/utils.py
+++ b/openhands/core/config/utils.py
@@ -3,13 +3,13 @@
 import pathlib
 import platform
 import sys
-from dataclasses import is_dataclass
 from types import UnionType
 from typing import Any, MutableMapping, get_args, get_origin
 from uuid import uuid4

 import toml
 from dotenv import load_dotenv
+from pydantic import BaseModel, ValidationError

 from openhands.core import logger
 from openhands.core.config.agent_config import AgentConfig
@@ -43,17 +43,19 @@ def get_optional_type(union_type: UnionType) -> Any:
         return next((t for t in types if t is not type(None)), None)

     # helper function to set attributes based on env vars
-    def set_attr_from_env(sub_config: Any, prefix=''):
-        """Set attributes of a config dataclass based on environment variables."""
-        for field_name, field_type in sub_config.__annotations__.items():
+    def set_attr_from_env(sub_config: BaseModel, prefix=''):
+        """Set attributes of a config model based on environment variables."""
+        for field_name, field_info in sub_config.model_fields.items():
+            field_value = getattr(sub_config, field_name)
+            field_type = field_info.annotation
+
             # compute the expected env var name from the prefix and field name
             # e.g. LLM_BASE_URL
             env_var_name = (prefix + field_name).upper()

-            if is_dataclass(field_type):
-                # nested dataclass
-                nested_sub_config = getattr(sub_config, field_name)
-                set_attr_from_env(nested_sub_config, prefix=field_name + '_')
+            if isinstance(field_value, BaseModel):
+                set_attr_from_env(field_value, prefix=field_name + '_')
+
             elif env_var_name in env_or_toml_dict:
                 # convert the env var to the correct type and set it
                 value = env_or_toml_dict[env_var_name]
@@ -125,22 +127,40 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
         if isinstance(value, dict):
             try:
                 if key is not None and key.lower() == 'agent':
+                    # Every entry here is either a field for the default `agent` config group, or itself a group
+                    # The best way to tell the difference is to try to parse it as an AgentConfig object
+                    agent_group_ids: set[str] = set()
+                    for nested_key, nested_value in value.items():
+                        if isinstance(nested_value, dict):
+                            try:
+                                agent_config = AgentConfig(**nested_value)
+                            except ValidationError:
+                                continue
+                            agent_group_ids.add(nested_key)
+                            cfg.set_agent_config(agent_config, nested_key)
+
                     logger.openhands_logger.debug(
                         'Attempt to load default agent config from config toml'
                     )
-                    non_dict_fields = {
-                        k: v for k, v in value.items() if not isinstance(v, dict)
+                    value_without_groups = {
+                        k: v for k, v in value.items() if k not in agent_group_ids
                     }
-                    agent_config = AgentConfig(**non_dict_fields)
+                    agent_config = AgentConfig(**value_without_groups)
                     cfg.set_agent_config(agent_config, 'agent')
+
+                elif key is not None and key.lower() == 'llm':
+                    # Every entry here is either a field for the default `llm` config group, or itself a group
+                    # The best way to tell the difference is to try to parse it as an LLMConfig object
+                    llm_group_ids: set[str] = set()
                     for nested_key, nested_value in value.items():
                         if isinstance(nested_value, dict):
-                            logger.openhands_logger.debug(
-                                f'Attempt to load group {nested_key} from config toml as agent config'
-                            )
-                            agent_config = AgentConfig(**nested_value)
-                            cfg.set_agent_config(agent_config, nested_key)
-                elif key is not None and key.lower() == 'llm':
+                            try:
+                                llm_config = LLMConfig(**nested_value)
+                            except ValidationError:
+                                continue
+                            llm_group_ids.add(nested_key)
+                            cfg.set_llm_config(llm_config, nested_key)
+
                     logger.openhands_logger.debug(
                         'Attempt to load default LLM config from config toml'
                     )
@@ -150,7 +170,7 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
                     for k, v in value.items():
                         if not isinstance(v, dict):
                             generic_llm_fields[k] = v
-                    generic_llm_config = LLMConfig.from_dict(generic_llm_fields)
+                    generic_llm_config = LLMConfig(**generic_llm_fields)
                     cfg.set_llm_config(generic_llm_config, 'llm')

                     # Process custom named LLM configs
@@ -170,22 +190,23 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
                             for k, v in nested_value.items():
                                 if not isinstance(v, dict):
                                     custom_fields[k] = v
-                            merged_llm_dict = generic_llm_config.__dict__.copy()
+                            merged_llm_dict = generic_llm_fields.copy()
                             merged_llm_dict.update(custom_fields)
-
-                            custom_llm_config = LLMConfig.from_dict(merged_llm_dict)
+
+                            custom_llm_config = LLMConfig(**merged_llm_dict)
                             cfg.set_llm_config(custom_llm_config, nested_key)
+
                 elif key is not None and key.lower() == 'security':
                     logger.openhands_logger.debug(
                         'Attempt to load security config from config toml'
                     )
-                    security_config = SecurityConfig.from_dict(value)
+                    security_config = SecurityConfig(**value)
                     cfg.security = security_config
                 elif not key.startswith('sandbox') and key.lower() != 'core':
                     logger.openhands_logger.warning(
                         f'Unknown key in {toml_file}: "{key}"'
                     )
-            except (TypeError, KeyError) as e:
+            except (TypeError, KeyError, ValidationError) as e:
                 logger.openhands_logger.warning(
                     f'Cannot parse [{key}] config from toml, values have not been applied.\nError: {e}',
                 )
@@ -221,7 +242,7 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
                 logger.openhands_logger.warning(
                     f'Unknown config key "{key}" in [core] section'
                 )
-    except (TypeError, KeyError) as e:
+    except (TypeError, KeyError, ValidationError) as e:
         logger.openhands_logger.warning(
             f'Cannot parse [sandbox] config from toml, values have not been applied.\nError: {e}',
         )
@@ -324,7 +345,7 @@ def get_llm_config_arg(
     # update the llm config with the specified section
     if 'llm' in toml_config and llm_config_arg in toml_config['llm']:
-        return LLMConfig.from_dict(toml_config['llm'][llm_config_arg])
+        return LLMConfig(**toml_config['llm'][llm_config_arg])

     logger.openhands_logger.debug(f'Loading from toml failed for {llm_config_arg}')
     return None
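The new "is this TOML entry a group or a field?" heuristic above is worth seeing in isolation: try to parse each nested table as a config model, and treat a `ValidationError` as "not a group" rather than aborting the load. A sketch with an assumed stand-in model (requires Python 3.11+ for `tomllib`):

```python
import tomllib

from pydantic import BaseModel, ValidationError


class MiniAgentConfig(BaseModel):  # hypothetical stand-in for AgentConfig
    model_config = {'extra': 'forbid'}
    memory_enabled: bool = False


raw = tomllib.loads("""
[agent]
memory_enabled = true

[agent.BrowsingAgent]
memory_enabled = false
""")

groups: dict[str, MiniAgentConfig] = {}
for name, table in raw['agent'].items():
    if isinstance(table, dict):
        try:
            groups[name] = MiniAgentConfig(**table)  # parses -> it's a group
        except ValidationError:
            continue  # not a valid group; leave it for the default config

defaults = {k: v for k, v in raw['agent'].items() if k not in groups}
print(groups)    # {'BrowsingAgent': MiniAgentConfig(memory_enabled=False)}
print(defaults)  # {'memory_enabled': True}
```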
diff --git a/openhands/llm/async_llm.py b/openhands/llm/async_llm.py
index f553ae173fd6..805240c2b19f 100644
--- a/openhands/llm/async_llm.py
+++ b/openhands/llm/async_llm.py
@@ -23,7 +23,9 @@ def __init__(self, *args, **kwargs):
         self._async_completion = partial(
             self._call_acompletion,
             model=self.config.model,
-            api_key=self.config.api_key,
+            api_key=self.config.api_key.get_secret_value()
+            if self.config.api_key
+            else None,
             base_url=self.config.base_url,
             api_version=self.config.api_version,
             custom_llm_provider=self.config.custom_llm_provider,
diff --git a/openhands/llm/llm.py b/openhands/llm/llm.py
index 88cda96c5f00..8f9ac12b7063 100644
--- a/openhands/llm/llm.py
+++ b/openhands/llm/llm.py
@@ -141,7 +141,9 @@ def __init__(
         self._completion = partial(
             litellm_completion,
             model=self.config.model,
-            api_key=self.config.api_key,
+            api_key=self.config.api_key.get_secret_value()
+            if self.config.api_key
+            else None,
             base_url=self.config.base_url,
             api_version=self.config.api_version,
             custom_llm_provider=self.config.custom_llm_provider,
@@ -331,7 +333,9 @@ def init_model_info(self):
             # GET {base_url}/v1/model/info with litellm_model_id as path param
             response = requests.get(
                 f'{self.config.base_url}/v1/model/info',
-                headers={'Authorization': f'Bearer {self.config.api_key}'},
+                headers={
+                    'Authorization': f'Bearer {self.config.api_key.get_secret_value() if self.config.api_key else None}'
+                },
             )
             resp_json = response.json()
             if 'data' not in resp_json:
diff --git a/openhands/llm/streaming_llm.py b/openhands/llm/streaming_llm.py
index 10925b9564cf..2a0e5b2d9dbb 100644
--- a/openhands/llm/streaming_llm.py
+++ b/openhands/llm/streaming_llm.py
@@ -17,7 +17,9 @@ def __init__(self, *args, **kwargs):
         self._async_streaming_completion = partial(
             self._call_acompletion,
             model=self.config.model,
-            api_key=self.config.api_key,
+            api_key=self.config.api_key.get_secret_value()
+            if self.config.api_key
+            else None,
             base_url=self.config.base_url,
             api_version=self.config.api_version,
             custom_llm_provider=self.config.custom_llm_provider,
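The recurring pattern in these LLM-client hunks, sketched standalone: keep the key wrapped as `SecretStr` everywhere inside the process, and unwrap with `get_secret_value()` only at the last moment, when handing it to an external client (the `fake_completion` stand-in below is an assumption, not litellm's API):

```python
from functools import partial

from pydantic import SecretStr


def fake_completion(model: str, api_key: str | None) -> str:  # stand-in for litellm
    return f'called {model} with key={"set" if api_key else "unset"}'


api_key: SecretStr | None = SecretStr('sk-secret')
completion = partial(
    fake_completion,
    model='gpt-4o',
    api_key=api_key.get_secret_value() if api_key else None,  # unwrap at the boundary
)
print(completion())  # called gpt-4o with key=set
```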
diff --git a/openhands/runtime/impl/modal/modal_runtime.py b/openhands/runtime/impl/modal/modal_runtime.py
index 6c2be0739615..61e72205a7f8 100644
--- a/openhands/runtime/impl/modal/modal_runtime.py
+++ b/openhands/runtime/impl/modal/modal_runtime.py
@@ -59,7 +59,8 @@ def __init__(
         self.sandbox = None

         self.modal_client = modal.Client.from_credentials(
-            config.modal_api_token_id, config.modal_api_token_secret
+            config.modal_api_token_id.get_secret_value(),
+            config.modal_api_token_secret.get_secret_value(),
         )
         self.app = modal.App.lookup(
             'openhands', create_if_missing=True, client=self.modal_client
diff --git a/openhands/runtime/impl/runloop/runloop_runtime.py b/openhands/runtime/impl/runloop/runloop_runtime.py
index 51628f54056d..add4619aea81 100644
--- a/openhands/runtime/impl/runloop/runloop_runtime.py
+++ b/openhands/runtime/impl/runloop/runloop_runtime.py
@@ -40,7 +40,7 @@ def __init__(
         self.devbox: DevboxView | None = None
         self.config = config
         self.runloop_api_client = Runloop(
-            bearer_token=config.runloop_api_key,
+            bearer_token=config.runloop_api_key.get_secret_value(),
         )
         self.container_name = CONTAINER_NAME_PREFIX + sid
         super().__init__(
diff --git a/openhands/server/routes/public.py b/openhands/server/routes/public.py
index a057bebb3692..ff3ab378f9e2 100644
--- a/openhands/server/routes/public.py
+++ b/openhands/server/routes/public.py
@@ -51,8 +51,8 @@ async def get_litellm_models() -> list[str]:
     ):
         bedrock_model_list = bedrock.list_foundation_models(
             llm_config.aws_region_name,
-            llm_config.aws_access_key_id,
-            llm_config.aws_secret_access_key,
+            llm_config.aws_access_key_id.get_secret_value(),
+            llm_config.aws_secret_access_key.get_secret_value(),
         )
     model_list = litellm_model_list_without_bedrock + bedrock_model_list
     for llm_config in config.llms.values():
diff --git a/openhands/server/routes/settings.py b/openhands/server/routes/settings.py
index ca45c142ff1c..9abad14f8547 100644
--- a/openhands/server/routes/settings.py
+++ b/openhands/server/routes/settings.py
@@ -30,9 +30,6 @@ async def load_settings(request: Request) -> Settings | None:
                 status_code=status.HTTP_404_NOT_FOUND,
                 content={'error': 'Settings not found'},
             )
-
-        # For security reasons we don't ever send the api key to the client
-        settings.llm_api_key = 'SET' if settings.llm_api_key else None
         return settings
     except Exception as e:
         logger.warning(f'Invalid token: {e}')
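Why the manual `'SET'` masking in the settings route can go: assuming the route returns the `Settings` model directly, the default JSON serialization of a `SecretStr` field is already masked. A hedged sketch of the serialization step only (stand-in model, not the real route):

```python
from pydantic import BaseModel, SecretStr


class MiniSettings(BaseModel):  # hypothetical stand-in for Settings
    llm_model: str | None = None
    llm_api_key: SecretStr | None = None


settings = MiniSettings(llm_model='gpt-4o', llm_api_key='test-key')
print(settings.model_dump_json())
# {"llm_model":"gpt-4o","llm_api_key":"**********"} -- the raw key never leaves
```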
diff --git a/openhands/server/session/conversation_init_data.py b/openhands/server/session/conversation_init_data.py
index 218d72addd14..82979e91fd96 100644
--- a/openhands/server/session/conversation_init_data.py
+++ b/openhands/server/session/conversation_init_data.py
@@ -1,13 +1,12 @@
-from dataclasses import dataclass
+from pydantic import Field

 from openhands.server.settings import Settings


-@dataclass
 class ConversationInitData(Settings):
     """
     Session initialization data for the web environment - a deep copy of the global config is made and then overridden with this data.
     """

-    github_token: str | None = None
-    selected_repository: str | None = None
+    github_token: str | None = Field(default=None)
+    selected_repository: str | None = Field(default=None)
diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py
index b24b29702086..c6a4dd31f04c 100644
--- a/openhands/server/session/session.py
+++ b/openhands/server/session/session.py
@@ -91,6 +91,9 @@ async def initialize_agent(self, settings: Settings, initial_user_msg: str | Non
         )
         max_iterations = settings.max_iterations or self.config.max_iterations
+        # This is a shallow copy of the default LLM config, so changes here will
+        # persist if we retrieve the default LLM config again when constructing
+        # the agent
         default_llm_config = self.config.get_llm_config()
         default_llm_config.model = settings.llm_model or ''
         default_llm_config.api_key = settings.llm_api_key
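An illustration of the aliasing pitfall the new comment in `session.py` warns about (assumed behavior, shown on a toy model): `get_llm_config` hands back the stored object itself, so mutations persist; Pydantic's `model_copy()` would isolate them.

```python
from pydantic import BaseModel


class MiniLLMConfig(BaseModel):
    model: str = ''


llms = {'llm': MiniLLMConfig()}

default = llms['llm']                  # same object as the stored default
default.model = 'gpt-4o'
assert llms['llm'].model == 'gpt-4o'   # the change persists

detached = llms['llm'].model_copy()    # a copy would isolate changes
detached.model = 'o1'
assert llms['llm'].model == 'gpt-4o'
```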
""" - github_token: str | None = None - selected_repository: str | None = None + github_token: str | None = Field(default=None) + selected_repository: str | None = Field(default=None) diff --git a/openhands/server/session/session.py b/openhands/server/session/session.py index b24b29702086..c6a4dd31f04c 100644 --- a/openhands/server/session/session.py +++ b/openhands/server/session/session.py @@ -91,6 +91,9 @@ async def initialize_agent(self, settings: Settings, initial_user_msg: str | Non ) max_iterations = settings.max_iterations or self.config.max_iterations + # This is a shallow copy of the default LLM config, so changes here will + # persist if we retrieve the default LLM config again when constructing + # the agent default_llm_config = self.config.get_llm_config() default_llm_config.model = settings.llm_model or '' default_llm_config.api_key = settings.llm_api_key diff --git a/openhands/server/settings.py b/openhands/server/settings.py index 57c879e49d45..bdb755fd5c20 100644 --- a/openhands/server/settings.py +++ b/openhands/server/settings.py @@ -1,8 +1,8 @@ -from dataclasses import dataclass +from pydantic import BaseModel, SecretStr, SerializationInfo, field_serializer +from pydantic.json import pydantic_encoder -@dataclass -class Settings: +class Settings(BaseModel): """ Persisted settings for OpenHands sessions """ @@ -13,6 +13,20 @@ class Settings: security_analyzer: str | None = None confirmation_mode: bool | None = None llm_model: str | None = None - llm_api_key: str | None = None + llm_api_key: SecretStr | None = None llm_base_url: str | None = None remote_runtime_resource_factor: int | None = None + + @field_serializer('llm_api_key') + def llm_api_key_serializer(self, llm_api_key: SecretStr, info: SerializationInfo): + """Custom serializer for the LLM API key. + + To serialize the API key instead of `"********"`, set `expose_secrets` to True in the serialization context. 
diff --git a/openhands/utils/embeddings.py b/openhands/utils/embeddings.py
index 7e251f0e5022..6791787d3204 100644
--- a/openhands/utils/embeddings.py
+++ b/openhands/utils/embeddings.py
@@ -90,7 +90,9 @@ def get_embedding_model(strategy: str, llm_config: LLMConfig) -> 'BaseEmbedding'
             return OpenAIEmbedding(
                 model='text-embedding-ada-002',
-                api_key=llm_config.api_key,
+                api_key=llm_config.api_key.get_secret_value()
+                if llm_config.api_key
+                else None,
             )
         elif strategy == 'azureopenai':
             from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding
diff --git a/tests/unit/test_acompletion.py b/tests/unit/test_acompletion.py
index b6753759be3d..cca18bbb5b29 100644
--- a/tests/unit/test_acompletion.py
+++ b/tests/unit/test_acompletion.py
@@ -109,9 +109,6 @@ async def mock_on_cancel_requested():
         print(f'Cancel requested: {is_set}')
         return is_set

-    config = load_app_config()
-    config.on_cancel_requested_fn = mock_on_cancel_requested
-
     async def mock_acompletion(*args, **kwargs):
         print('Starting mock_acompletion')
         for i in range(20):  # Increased iterations for longer running task
@@ -153,13 +150,6 @@ async def cancel_after_delay():
 async def test_async_streaming_completion_with_user_cancellation(cancel_after_chunks):
     cancel_requested = False

-    async def mock_on_cancel_requested():
-        nonlocal cancel_requested
-        return cancel_requested
-
-    config = load_app_config()
-    config.on_cancel_requested_fn = mock_on_cancel_requested
-
     test_messages = [
         'This is ',
         'a test ',
diff --git a/tests/unit/test_codeact_agent.py b/tests/unit/test_codeact_agent.py
index 84f0b8fc1993..26fa4428826e 100644
--- a/tests/unit/test_codeact_agent.py
+++ b/tests/unit/test_codeact_agent.py
@@ -60,7 +60,6 @@ def mock_state() -> State:


 def test_cmd_output_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = CmdOutputObservation(
         command='echo hello',
         content='Command output',
@@ -82,7 +81,6 @@ def test_cmd_output_observation_message(agent: CodeActAgent):


 def test_ipython_run_cell_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = IPythonRunCellObservation(
         code='plt.plot()',
         content='IPython output\n![image](data:image/png;base64,ABC123)',
@@ -105,7 +103,6 @@ def test_ipython_run_cell_observation_message(agent: CodeActAgent):


 def test_agent_delegate_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = AgentDelegateObservation(
         content='Content', outputs={'content': 'Delegated agent output'}
     )
@@ -122,7 +119,6 @@ def test_agent_delegate_observation_message(agent: CodeActAgent):


 def test_error_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = ErrorObservation('Error message')

     results = agent.get_observation_message(obs, tool_call_id_to_message={})
@@ -145,7 +141,6 @@ def test_unknown_observation_message(agent: CodeActAgent):


 def test_file_edit_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = FileEditObservation(
         path='/test/file.txt',
         prev_exist=True,
@@ -167,7 +162,6 @@ def test_file_edit_observation_message(agent: CodeActAgent):


 def test_file_read_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = FileReadObservation(
         path='/test/file.txt',
         content='File content',
@@ -186,7 +180,6 @@ def test_file_read_observation_message(agent: CodeActAgent):


 def test_browser_output_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = BrowserOutputObservation(
         url='http://example.com',
         trigger_by_action='browse',
@@ -207,7 +200,6 @@ def test_browser_output_observation_message(agent: CodeActAgent):


 def test_user_reject_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = False
     obs = UserRejectObservation('Action rejected')

     results = agent.get_observation_message(obs, tool_call_id_to_message={})
@@ -223,7 +215,6 @@ def test_user_reject_observation_message(agent: CodeActAgent):


 def test_function_calling_observation_message(agent: CodeActAgent):
-    agent.config.function_calling = True
     mock_response = {
         'id': 'mock_id',
         'total_calls_in_response': 1,
diff --git a/tests/unit/test_condenser.py b/tests/unit/test_condenser.py
index 4aa5afcf7543..91878c86baa1 100644
--- a/tests/unit/test_condenser.py
+++ b/tests/unit/test_condenser.py
@@ -226,7 +226,7 @@ def test_llm_condenser_from_config():

     assert isinstance(condenser, LLMSummarizingCondenser)
     assert condenser.llm.config.model == 'gpt-4o'
-    assert condenser.llm.config.api_key == 'test_key'
+    assert condenser.llm.config.api_key.get_secret_value() == 'test_key'


 def test_llm_condenser(mock_llm, mock_state):
@@ -381,7 +381,7 @@ def test_llm_attention_condenser_from_config():

     assert isinstance(condenser, LLMAttentionCondenser)
     assert condenser.llm.config.model == 'gpt-4o'
-    assert condenser.llm.config.api_key == 'test_key'
+    assert condenser.llm.config.api_key.get_secret_value() == 'test_key'
     assert condenser.max_size == 50
     assert condenser.keep_first == 10
diff --git a/tests/unit/test_config.py b/tests/unit/test_config.py
index 44a76145cf6e..5edfd64cda90 100644
--- a/tests/unit/test_config.py
+++ b/tests/unit/test_config.py
@@ -63,7 +63,7 @@ def test_compat_env_to_config(monkeypatch, setup_env):
     assert config.workspace_base == '/repos/openhands/workspace'
     assert isinstance(config.get_llm_config(), LLMConfig)
-    assert config.get_llm_config().api_key == 'sk-proj-rgMV0...'
+    assert config.get_llm_config().api_key.get_secret_value() == 'sk-proj-rgMV0...'
     assert config.get_llm_config().model == 'gpt-4o'
     assert isinstance(config.get_agent_config(), AgentConfig)
     assert isinstance(config.get_agent_config().memory_max_threads, int)
@@ -83,7 +83,7 @@ def test_load_from_old_style_env(monkeypatch, default_config):

     load_from_env(default_config, os.environ)

-    assert default_config.get_llm_config().api_key == 'test-api-key'
+    assert default_config.get_llm_config().api_key.get_secret_value() == 'test-api-key'
     assert default_config.get_agent_config().memory_enabled is True
     assert default_config.default_agent == 'BrowsingAgent'
     assert default_config.workspace_base == '/opt/files/workspace'
@@ -126,7 +126,7 @@ def test_load_from_new_style_toml(default_config, temp_toml_file):
     # default llm & agent configs
     assert default_config.default_agent == 'TestAgent'
     assert default_config.get_llm_config().model == 'test-model'
-    assert default_config.get_llm_config().api_key == 'toml-api-key'
+    assert default_config.get_llm_config().api_key.get_secret_value() == 'toml-api-key'
     assert default_config.get_agent_config().memory_enabled is True

     # undefined agent config inherits default ones
@@ -291,7 +291,7 @@ def test_env_overrides_compat_toml(monkeypatch, default_config, temp_toml_file):
     assert default_config.get_llm_config().model == 'test-model'
     assert default_config.get_llm_config('llm').model == 'test-model'
     assert default_config.get_llm_config_from_agent().model == 'test-model'
-    assert default_config.get_llm_config().api_key == 'env-api-key'
+    assert default_config.get_llm_config().api_key.get_secret_value() == 'env-api-key'

     # after we set workspace_base to 'UNDEFINED' in the environment,
     # workspace_base should be set to that
@@ -336,7 +336,7 @@ def test_env_overrides_sandbox_toml(monkeypatch, default_config, temp_toml_file)
     assert default_config.workspace_mount_path is None

     # before load_from_env, values are set to the values from the toml file
-    assert default_config.get_llm_config().api_key == 'toml-api-key'
+    assert default_config.get_llm_config().api_key.get_secret_value() == 'toml-api-key'
     assert default_config.sandbox.timeout == 500
     assert default_config.sandbox.user_id == 1001

@@ -345,7 +345,7 @@ def test_env_overrides_sandbox_toml(monkeypatch, default_config, temp_toml_file)
     # values from env override values from toml
     assert os.environ.get('LLM_MODEL') is None
     assert default_config.get_llm_config().model == 'test-model'
-    assert default_config.get_llm_config().api_key == 'env-api-key'
+    assert default_config.get_llm_config().api_key.get_secret_value() == 'env-api-key'

     assert default_config.sandbox.timeout == 1000
     assert default_config.sandbox.user_id == 1002
@@ -412,7 +412,7 @@ def test_security_config_from_dict():

     # Test with all fields
     config_dict = {'confirmation_mode': True, 'security_analyzer': 'some_analyzer'}
-    security_config = SecurityConfig.from_dict(config_dict)
+    security_config = SecurityConfig(**config_dict)

     # Verify all fields are correctly set
     assert security_config.confirmation_mode is True
@@ -560,10 +560,7 @@ def test_load_from_toml_partial_invalid(default_config, temp_toml_file, caplog):
         assert 'Cannot parse [llm] config from toml' in log_content
         assert 'values have not been applied' in log_content
         # Error: LLMConfig.__init__() got an unexpected keyword argume
-        assert (
-            'Error: LLMConfig.__init__() got an unexpected keyword argume'
-            in log_content
-        )
+        assert 'Error: 1 validation error for LLMConfig' in log_content
         assert 'invalid_field' in log_content

         # invalid [sandbox] config
@@ -635,12 +632,14 @@ def test_api_keys_repr_str():
         aws_access_key_id='my_access_key',
         aws_secret_access_key='my_secret_key',
     )
-    assert "api_key='******'" in repr(llm_config)
-    assert "aws_access_key_id='******'" in repr(llm_config)
-    assert "aws_secret_access_key='******'" in repr(llm_config)
-    assert "api_key='******'" in str(llm_config)
-    assert "aws_access_key_id='******'" in str(llm_config)
-    assert "aws_secret_access_key='******'" in str(llm_config)
+
+    # Check that no secret keys are emitted in representations of the config object
+    assert 'my_api_key' not in repr(llm_config)
+    assert 'my_api_key' not in str(llm_config)
+    assert 'my_access_key' not in repr(llm_config)
+    assert 'my_access_key' not in str(llm_config)
+    assert 'my_secret_key' not in repr(llm_config)
+    assert 'my_secret_key' not in str(llm_config)

     # Check that no other attrs in LLMConfig have 'key' or 'token' in their name
     # This will fail when new attrs are added, and attract attention
@@ -652,7 +651,7 @@ def test_api_keys_repr_str():
         'output_cost_per_token',
         'custom_tokenizer',
     ]
-    for attr_name in dir(LLMConfig):
+    for attr_name in LLMConfig.model_fields.keys():
         if (
             not attr_name.startswith('__')
             and attr_name not in known_key_token_attrs_llm
@@ -667,7 +666,7 @@ def test_api_keys_repr_str():
     # Test AgentConfig
     # No attrs in AgentConfig have 'key' or 'token' in their name
     agent_config = AgentConfig(memory_enabled=True, memory_max_threads=4)
-    for attr_name in dir(AgentConfig):
+    for attr_name in AgentConfig.model_fields.keys():
         if not attr_name.startswith('__'):
             assert (
                 'key' not in attr_name.lower()
@@ -686,16 +685,16 @@ def test_api_keys_repr_str():
         modal_api_token_secret='my_modal_api_token_secret',
         runloop_api_key='my_runloop_api_key',
     )
-    assert "e2b_api_key='******'" in repr(app_config)
-    assert "e2b_api_key='******'" in str(app_config)
-    assert "jwt_secret='******'" in repr(app_config)
-    assert "jwt_secret='******'" in str(app_config)
-    assert "modal_api_token_id='******'" in repr(app_config)
-    assert "modal_api_token_id='******'" in str(app_config)
-    assert "modal_api_token_secret='******'" in repr(app_config)
-    assert "modal_api_token_secret='******'" in str(app_config)
-    assert "runloop_api_key='******'" in repr(app_config)
-    assert "runloop_api_key='******'" in str(app_config)
+    assert 'my_e2b_api_key' not in repr(app_config)
+    assert 'my_e2b_api_key' not in str(app_config)
+    assert 'my_jwt_secret' not in repr(app_config)
+    assert 'my_jwt_secret' not in str(app_config)
+    assert 'my_modal_api_token_id' not in repr(app_config)
+    assert 'my_modal_api_token_id' not in str(app_config)
+    assert 'my_modal_api_token_secret' not in repr(app_config)
+    assert 'my_modal_api_token_secret' not in str(app_config)
+    assert 'my_runloop_api_key' not in repr(app_config)
+    assert 'my_runloop_api_key' not in str(app_config)

     # Check that no other attrs in AppConfig have 'key' or 'token' in their name
     # This will fail when new attrs are added, and attract attention
@@ -705,7 +704,7 @@ def test_api_keys_repr_str():
         'modal_api_token_secret',
         'runloop_api_key',
     ]
-    for attr_name in dir(AppConfig):
+    for attr_name in AppConfig.model_fields.keys():
         if (
             not attr_name.startswith('__')
             and attr_name not in known_key_token_attrs_app
diff --git a/tests/unit/test_file_settings_store.py b/tests/unit/test_file_settings_store.py
index c20153c9fdce..347754035aa1 100644
--- a/tests/unit/test_file_settings_store.py
+++ b/tests/unit/test_file_settings_store.py
@@ -1,4 +1,3 @@
-import json
 from unittest.mock import MagicMock, patch

 import pytest
@@ -43,7 +42,7 @@ async def test_store_and_load_data(file_settings_store):
     await file_settings_store.store(init_data)

     # Verify store called with correct JSON
-    expected_json = json.dumps(init_data.__dict__)
+    expected_json = init_data.model_dump_json(context={'expose_secrets': True})
     file_settings_store.file_store.write.assert_called_once_with(
         'settings.json', expected_json
     )
@@ -60,7 +59,12 @@ async def test_store_and_load_data(file_settings_store):
     assert loaded_data.security_analyzer == init_data.security_analyzer
     assert loaded_data.confirmation_mode == init_data.confirmation_mode
     assert loaded_data.llm_model == init_data.llm_model
-    assert loaded_data.llm_api_key == init_data.llm_api_key
+    assert loaded_data.llm_api_key
+    assert init_data.llm_api_key
+    assert (
+        loaded_data.llm_api_key.get_secret_value()
+        == init_data.llm_api_key.get_secret_value()
+    )
     assert loaded_data.llm_base_url == init_data.llm_base_url
diff --git a/tests/unit/test_llm.py b/tests/unit/test_llm.py
index edf82d8aa41b..227b0006b020 100644
--- a/tests/unit/test_llm.py
+++ b/tests/unit/test_llm.py
@@ -40,7 +40,7 @@ def default_config():
 def test_llm_init_with_default_config(default_config):
     llm = LLM(default_config)
     assert llm.config.model == 'gpt-4o'
-    assert llm.config.api_key == 'test_key'
+    assert llm.config.api_key.get_secret_value() == 'test_key'
     assert isinstance(llm.metrics, Metrics)
     assert llm.metrics.model_name == 'gpt-4o'

@@ -77,7 +77,7 @@ def test_llm_init_with_custom_config():
     )
     llm = LLM(custom_config)
     assert llm.config.model == 'custom-model'
-    assert llm.config.api_key == 'custom_key'
+    assert llm.config.api_key.get_secret_value() == 'custom_key'
     assert llm.config.max_input_tokens == 5000
     assert llm.config.max_output_tokens == 1500
     assert llm.config.temperature == 0.8
diff --git a/tests/unit/test_llm_config.py b/tests/unit/test_llm_config.py
index 2fc22d6f2232..342112a44316 100644
--- a/tests/unit/test_llm_config.py
+++ b/tests/unit/test_llm_config.py
@@ -59,28 +59,28 @@ def test_load_from_toml_llm_with_fallback(
     # Verify generic LLM configuration
     generic_llm = default_config.get_llm_config('llm')
     assert generic_llm.model == 'base-model'
-    assert generic_llm.api_key == 'base-api-key'
+    assert generic_llm.api_key.get_secret_value() == 'base-api-key'
     assert generic_llm.embedding_model == 'base-embedding'
     assert generic_llm.num_retries == 3

     # Verify custom1 LLM falls back 'num_retries' from base
     custom1 = default_config.get_llm_config('custom1')
     assert custom1.model == 'custom-model-1'
-    assert custom1.api_key == 'custom-api-key-1'
+    assert custom1.api_key.get_secret_value() == 'custom-api-key-1'
     assert custom1.embedding_model == 'base-embedding'
     assert custom1.num_retries == 3  # from [llm]

     # Verify custom2 LLM overrides 'num_retries'
     custom2 = default_config.get_llm_config('custom2')
     assert custom2.model == 'custom-model-2'
-    assert custom2.api_key == 'custom-api-key-2'
+    assert custom2.api_key.get_secret_value() == 'custom-api-key-2'
     assert custom2.embedding_model == 'base-embedding'
     assert custom2.num_retries == 5  # overridden value

     # Verify custom3 LLM inherits all attributes except 'model' and 'api_key'
     custom3 = default_config.get_llm_config('custom3')
     assert custom3.model == 'custom-model-3'
-    assert custom3.api_key == 'custom-api-key-3'
+    assert custom3.api_key.get_secret_value() == 'custom-api-key-3'
     assert custom3.embedding_model == 'base-embedding'
     assert custom3.num_retries == 3  # from [llm]

@@ -113,14 +113,14 @@ def test_load_from_toml_llm_custom_overrides_all(
     # Verify generic LLM configuration remains unchanged
     generic_llm = default_config.get_llm_config('llm')
     assert generic_llm.model == 'base-model'
-    assert generic_llm.api_key == 'base-api-key'
+    assert generic_llm.api_key.get_secret_value() == 'base-api-key'
     assert generic_llm.embedding_model == 'base-embedding'
     assert generic_llm.num_retries == 3

     # Verify custom_full LLM overrides all attributes
     custom_full = default_config.get_llm_config('custom_full')
     assert custom_full.model == 'full-custom-model'
-    assert custom_full.api_key == 'full-custom-api-key'
+    assert custom_full.api_key.get_secret_value() == 'full-custom-api-key'
     assert custom_full.embedding_model == 'full-custom-embedding'
     assert custom_full.num_retries == 10  # overridden value

@@ -136,14 +136,14 @@ def test_load_from_toml_llm_custom_partial_override(
     # Verify custom1 LLM overrides 'model' and 'api_key' but inherits 'num_retries'
     custom1 = default_config.get_llm_config('custom1')
     assert custom1.model == 'custom-model-1'
-    assert custom1.api_key == 'custom-api-key-1'
+    assert custom1.api_key.get_secret_value() == 'custom-api-key-1'
     assert custom1.embedding_model == 'base-embedding'
     assert custom1.num_retries == 3  # from [llm]

     # Verify custom2 LLM overrides 'model', 'api_key', and 'num_retries'
     custom2 = default_config.get_llm_config('custom2')
     assert custom2.model == 'custom-model-2'
-    assert custom2.api_key == 'custom-api-key-2'
+    assert custom2.api_key.get_secret_value() == 'custom-api-key-2'
     assert custom2.embedding_model == 'base-embedding'
     assert custom2.num_retries == 5  # Overridden value

@@ -159,7 +159,7 @@ def test_load_from_toml_llm_custom_no_override(
     # Verify custom3 LLM inherits 'embedding_model' and 'num_retries' from generic
     custom3 = default_config.get_llm_config('custom3')
     assert custom3.model == 'custom-model-3'
-    assert custom3.api_key == 'custom-api-key-3'
+    assert custom3.api_key.get_secret_value() == 'custom-api-key-3'
     assert custom3.embedding_model == 'base-embedding'
     assert custom3.num_retries == 3  # from [llm]

@@ -186,7 +186,7 @@ def test_load_from_toml_llm_missing_generic(
     # Verify custom_only LLM uses its own attributes and defaults for others
     custom_only = default_config.get_llm_config('custom_only')
     assert custom_only.model == 'custom-only-model'
-    assert custom_only.api_key == 'custom-only-api-key'
+    assert custom_only.api_key.get_secret_value() == 'custom-only-api-key'
     assert custom_only.embedding_model == 'local'  # default value
     assert custom_only.num_retries == 8  # default value

@@ -217,12 +217,12 @@ def test_load_from_toml_llm_invalid_config(
     # Verify generic LLM is loaded correctly
     generic_llm = default_config.get_llm_config('llm')
     assert generic_llm.model == 'base-model'
-    assert generic_llm.api_key == 'base-api-key'
+    assert generic_llm.api_key.get_secret_value() == 'base-api-key'
     assert generic_llm.num_retries == 3

     # Verify invalid_custom LLM does not override generic attributes
     custom_invalid = default_config.get_llm_config('invalid_custom')
     assert custom_invalid.model == 'base-model'
-    assert custom_invalid.api_key == 'base-api-key'
+    assert custom_invalid.api_key.get_secret_value() == 'base-api-key'
     assert custom_invalid.num_retries == 3  # default value
     assert custom_invalid.embedding_model == 'local'  # default value
diff --git a/tests/unit/test_llm_draft_config.py b/tests/unit/test_llm_draft_config.py
index a9803782b2de..9a9b7dd6878d 100644
--- a/tests/unit/test_llm_draft_config.py
+++ b/tests/unit/test_llm_draft_config.py
@@ -86,7 +86,7 @@ def test_draft_editor_as_named_llm(config_toml_with_draft_editor):
     draft_llm = config.get_llm_config('draft_editor')
     assert draft_llm is not None
     assert draft_llm.model == 'draft-model'
-    assert draft_llm.api_key == 'draft-api-key'
+    assert draft_llm.api_key.get_secret_value() == 'draft-api-key'


 def test_draft_editor_fallback(config_toml_with_draft_editor):
diff --git a/tests/unit/test_settings_api.py b/tests/unit/test_settings_api.py
index a8e52a239010..ed7c02eb1403 100644
--- a/tests/unit/test_settings_api.py
+++ b/tests/unit/test_settings_api.py
@@ -2,6 +2,7 @@

 import pytest
 from fastapi.testclient import TestClient
+from pydantic import SecretStr

 from openhands.core.config.sandbox_config import SandboxConfig
 from openhands.server.app import app
@@ -50,7 +51,7 @@ async def test_settings_api_runtime_factor(test_client, mock_settings_store):
         'security_analyzer': 'default',
         'confirmation_mode': True,
         'llm_model': 'test-model',
-        'llm_api_key': None,
+        'llm_api_key': 'test-key',
         'llm_base_url': 'https://test.com',
         'remote_runtime_resource_factor': 2,
     }
@@ -83,3 +84,36 @@ async def test_settings_api_runtime_factor(test_client, mock_settings_store):
     mock_settings_store.store.assert_called()
     stored_settings = mock_settings_store.store.call_args[0][0]
     assert stored_settings.remote_runtime_resource_factor == 2
+
+    assert isinstance(stored_settings.llm_api_key, SecretStr)
+    assert stored_settings.llm_api_key.get_secret_value() == 'test-key'
+
+
+@pytest.mark.asyncio
+async def test_settings_llm_api_key(test_client, mock_settings_store):
+    # Mock the settings store to return None initially (no existing settings)
+    mock_settings_store.load.return_value = None
+
+    # Test data with an LLM API key
+    settings_data = {'llm_api_key': 'test-key'}
+
+    # The test_client fixture already handles authentication
+
+    # Make the POST request to store settings
+    response = test_client.post('/api/settings', json=settings_data)
+    assert response.status_code == 200
+
+    # Verify the settings were stored with the correct secret API key
+    stored_settings = mock_settings_store.store.call_args[0][0]
+    assert isinstance(stored_settings.llm_api_key, SecretStr)
+    assert stored_settings.llm_api_key.get_secret_value() == 'test-key'
+
+    # Mock settings store to return our settings for the GET request
+    mock_settings_store.load.return_value = Settings(**settings_data)
+
+    # Make a GET request to retrieve settings
+    response = test_client.get('/api/settings')
+    assert response.status_code == 200
+
+    # We should never expose the API key in the response
+    # (check the raw body: `in response.json()` would only test dict keys)
+    assert 'test-key' not in response.text