Skip to content

Commit

Permalink
fix: Handle draft_editor and fallback values in LLM configs
Browse files Browse the repository at this point in the history
- Adds proper fallback mechanism from generic LLM config to custom configs
- Adds special handling for draft_editor field:
  - Falls back to generic config value if not specified
  - Can be set to None using 'null' in TOML
  - Can be overridden with custom value
  • Loading branch information
openhands-agent committed Dec 14, 2024
1 parent 3d1cd83 commit bf8feff
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 10 deletions.
19 changes: 15 additions & 4 deletions openhands/core/config/llm_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,8 +130,19 @@ def from_dict(cls, llm_config_dict: dict) -> 'LLMConfig':
This function is used to create an LLMConfig object from a dictionary,
with the exception of the 'draft_editor' key, which is a nested LLMConfig object.
"""
args = {k: v for k, v in llm_config_dict.items() if not isinstance(v, dict)}
if 'draft_editor' in llm_config_dict:
draft_editor_config = LLMConfig(**llm_config_dict['draft_editor'])
args['draft_editor'] = draft_editor_config
# Keep None values to preserve defaults, filter out other dicts
args = {
k: v
for k, v in llm_config_dict.items()
if not isinstance(v, dict) or v is None
}
if (
'draft_editor' in llm_config_dict
and llm_config_dict['draft_editor'] is not None
):
if isinstance(llm_config_dict['draft_editor'], LLMConfig):
args['draft_editor'] = llm_config_dict['draft_editor']
else:
draft_editor_config = LLMConfig(**llm_config_dict['draft_editor'])
args['draft_editor'] = draft_editor_config
return cls(**args)
37 changes: 31 additions & 6 deletions openhands/core/config/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,18 +140,28 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
logger.openhands_logger.debug(
'Attempt to load default LLM config from config toml'
)
# Extract generic LLM fields
generic_llm_fields = {
k: v for k, v in value.items() if not isinstance(v, dict)
}
# Extract generic LLM fields, keeping draft_editor
generic_llm_fields = {}
for k, v in value.items():
if not isinstance(v, dict) or k == 'draft_editor':
generic_llm_fields[k] = v
logger.openhands_logger.debug(
f'Generic LLM fields: {generic_llm_fields}'
)
generic_llm_config = LLMConfig.from_dict(generic_llm_fields)
logger.openhands_logger.debug(
f'Generic LLM config dict: {generic_llm_config.__dict__}'
)
cfg.set_llm_config(generic_llm_config, 'llm')

# Process custom named LLM configs
for nested_key, nested_value in value.items():
if isinstance(nested_value, dict):
logger.openhands_logger.debug(
f'Attempt to load group {nested_key} from config toml as LLM config'
f'Processing custom LLM config "{nested_key}":'
)
logger.openhands_logger.debug(
f' Nested value: {nested_value}'
)
# Apply generic LLM config with custom LLM overrides, e.g.
# [llm]
Expand All @@ -160,8 +170,23 @@ def load_from_toml(cfg: AppConfig, toml_file: str = 'config.toml'):
# [llm.claude]
# model="claude-3-5-sonnet"
# results in num_retries APPLIED to claude-3-5-sonnet
custom_fields = {}
for k, v in nested_value.items():
if not isinstance(v, dict) or k == 'draft_editor':
custom_fields[k] = v
merged_llm_dict = generic_llm_config.__dict__.copy()
merged_llm_dict.update(nested_value)
merged_llm_dict.update(custom_fields)
# Handle draft_editor with fallback values:
# - If draft_editor is "null", use None
# - If draft_editor is in custom fields, use that value
# - If draft_editor is not specified, fall back to generic config value
if 'draft_editor' in custom_fields:
if custom_fields['draft_editor'] == 'null':
merged_llm_dict['draft_editor'] = None
else:
merged_llm_dict['draft_editor'] = (
generic_llm_config.draft_editor
)
custom_llm_config = LLMConfig.from_dict(merged_llm_dict)
cfg.set_llm_config(custom_llm_config, nested_key)
elif key is not None and key.lower() == 'security':
Expand Down
92 changes: 92 additions & 0 deletions tests/unit/test_llm_draft_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
import pathlib

import pytest

from openhands.core.config import AppConfig
from openhands.core.config.utils import load_from_toml


@pytest.fixture
def draft_llm_toml(tmp_path: pathlib.Path) -> str:
toml_content = """
[core]
workspace_base = "./workspace"
[llm]
model = "base-model"
api_key = "base-api-key"
draft_editor = { model = "draft-model", api_key = "draft-api-key" }
[llm.custom1]
model = "custom-model-1"
api_key = "custom-api-key-1"
# Should use draft_editor from [llm] as fallback
[llm.custom2]
model = "custom-model-2"
api_key = "custom-api-key-2"
draft_editor = { model = "custom-draft", api_key = "custom-draft-key" }
[llm.custom3]
model = "custom-model-3"
api_key = "custom-api-key-3"
draft_editor = "null" # Explicitly set to null in TOML
"""
toml_file = tmp_path / 'llm_config.toml'
toml_file.write_text(toml_content)
return str(toml_file)


def test_draft_editor_fallback(draft_llm_toml):
"""Test that draft_editor is correctly handled in different scenarios:
- Falls back to generic [llm] section value
- Uses custom value when specified
- Can be explicitly set to null
"""
config = AppConfig()

# Verify default draft_editor is None
default_llm = config.get_llm_config('llm')
assert default_llm.draft_editor is None

# Load config from TOML
load_from_toml(config, draft_llm_toml)

# Verify generic LLM draft_editor
generic_llm = config.get_llm_config('llm')
assert generic_llm.draft_editor is not None
assert generic_llm.draft_editor.model == 'draft-model'
assert generic_llm.draft_editor.api_key == 'draft-api-key'

# Verify custom1 uses draft_editor from generic as fallback
custom1 = config.get_llm_config('custom1')
assert custom1.model == 'custom-model-1'
assert custom1.draft_editor is not None
assert custom1.draft_editor.model == 'draft-model'
assert custom1.draft_editor.api_key == 'draft-api-key'

# Verify custom2 overrides draft_editor
custom2 = config.get_llm_config('custom2')
assert custom2.model == 'custom-model-2'
assert custom2.draft_editor is not None
assert custom2.draft_editor.model == 'custom-draft'
assert custom2.draft_editor.api_key == 'custom-draft-key'

# Verify custom3 has draft_editor explicitly set to None
custom3 = config.get_llm_config('custom3')
assert custom3.model == 'custom-model-3'
assert custom3.draft_editor is None


def test_draft_editor_defaults(draft_llm_toml):
"""Test that draft_editor uses default values from LLMConfig when not specified"""
config = AppConfig()
load_from_toml(config, draft_llm_toml)

generic_llm = config.get_llm_config('llm')
assert generic_llm.draft_editor.num_retries == 8 # Default from LLMConfig
assert generic_llm.draft_editor.embedding_model == 'local' # Default from LLMConfig

custom2 = config.get_llm_config('custom2')
assert custom2.draft_editor.num_retries == 8 # Default from LLMConfig
assert custom2.draft_editor.embedding_model == 'local' # Default from LLMConfig

0 comments on commit bf8feff

Please sign in to comment.