Skip to content

Commit

Permalink
feature: Condenser Interface and Defaults (#5306)
Browse files Browse the repository at this point in the history
Co-authored-by: openhands <[email protected]>
Co-authored-by: Calvin Smith <[email protected]>
Co-authored-by: Engel Nyst <[email protected]>
  • Loading branch information
4 people authored Jan 7, 2025
1 parent 561f308 commit 6e4ff56
Show file tree
Hide file tree
Showing 17 changed files with 2,683 additions and 1,805 deletions.
4 changes: 3 additions & 1 deletion evaluation/benchmarks/swe_bench/run_infer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
EvalOutput,
assert_and_raise,
codeact_user_response,
get_metrics,
is_fatal_evaluation_error,
make_metadata,
prepare_dataset,
Expand Down Expand Up @@ -148,6 +149,7 @@ def get_config(
codeact_enable_jupyter=False,
codeact_enable_browsing=RUN_WITH_BROWSING,
codeact_enable_llm_editor=False,
condenser=metadata.condenser_config,
)
config.set_agent_config(agent_config)
return config
Expand Down Expand Up @@ -448,7 +450,7 @@ def process_instance(

# NOTE: this is NO LONGER the event stream, but an agent history that includes delegate agent's events
histories = [event_to_dict(event) for event in state.history]
metrics = state.metrics.get() if state.metrics else None
metrics = get_metrics(state)

# Save the output
output = EvalOutput(
Expand Down
27 changes: 27 additions & 0 deletions evaluation/utils/shared.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@

from openhands.controller.state.state import State
from openhands.core.config import LLMConfig
from openhands.core.config.condenser_config import (
CondenserConfig,
NoOpCondenserConfig,
)
from openhands.core.exceptions import (
AgentRuntimeBuildError,
AgentRuntimeDisconnectedError,
Expand All @@ -33,6 +37,7 @@
from openhands.events.event import Event
from openhands.events.serialization.event import event_to_dict
from openhands.events.utils import get_pairs_from_events
from openhands.memory.condenser import get_condensation_metadata


class EvalMetadata(BaseModel):
Expand All @@ -45,18 +50,29 @@ class EvalMetadata(BaseModel):
dataset: str | None = None
data_split: str | None = None
details: dict[str, Any] | None = None
condenser_config: CondenserConfig | None = None

def model_dump(self, *args, **kwargs):
dumped_dict = super().model_dump(*args, **kwargs)
# avoid leaking sensitive information
dumped_dict['llm_config'] = self.llm_config.to_safe_dict()
if hasattr(self.condenser_config, 'llm_config'):
dumped_dict['condenser_config']['llm_config'] = (
self.condenser_config.llm_config.to_safe_dict()
)

return dumped_dict

def model_dump_json(self, *args, **kwargs):
dumped = super().model_dump_json(*args, **kwargs)
dumped_dict = json.loads(dumped)
# avoid leaking sensitive information
dumped_dict['llm_config'] = self.llm_config.to_safe_dict()
if hasattr(self.condenser_config, 'llm_config'):
dumped_dict['condenser_config']['llm_config'] = (
self.condenser_config.llm_config.to_safe_dict()
)

logger.debug(f'Dumped metadata: {dumped_dict}')
return json.dumps(dumped_dict)

Expand Down Expand Up @@ -192,6 +208,7 @@ def make_metadata(
eval_output_dir: str,
data_split: str | None = None,
details: dict[str, Any] | None = None,
condenser_config: CondenserConfig | None = None,
) -> EvalMetadata:
model_name = llm_config.model.split('/')[-1]
model_path = model_name.replace(':', '_').replace('@', '-')
Expand Down Expand Up @@ -222,6 +239,9 @@ def make_metadata(
dataset=dataset_name,
data_split=data_split,
details=details,
condenser_config=condenser_config
if condenser_config
else NoOpCondenserConfig(),
)
metadata_json = metadata.model_dump_json()
logger.info(f'Metadata: {metadata_json}')
Expand Down Expand Up @@ -551,3 +571,10 @@ def is_fatal_evaluation_error(error: str | None) -> bool:
return True

return False


def get_metrics(state: State) -> dict[str, Any]:
"""Extract metrics from the state."""
metrics = state.metrics.get() if state.metrics else {}
metrics['condenser'] = get_condensation_metadata(state)
return metrics
13 changes: 12 additions & 1 deletion openhands/agenthub/codeact_agent/codeact_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
MessageAction,
)
from openhands.events.observation import (
AgentCondensationObservation,
AgentDelegateObservation,
BrowserOutputObservation,
CmdOutputObservation,
Expand All @@ -36,6 +37,7 @@
from openhands.events.observation.observation import Observation
from openhands.events.serialization.event import truncate_content
from openhands.llm.llm import LLM
from openhands.memory.condenser import Condenser
from openhands.runtime.plugins import (
AgentSkillsRequirement,
JupyterRequirement,
Expand Down Expand Up @@ -115,6 +117,9 @@ def __init__(
disabled_microagents=self.config.disabled_microagents,
)

self.condenser = Condenser.from_config(self.config.condenser)
logger.debug(f'Using condenser: {self.condenser}')

def get_action_message(
self,
action: Action,
Expand Down Expand Up @@ -322,6 +327,9 @@ def get_observation_message(
text = 'OBSERVATION:\n' + truncate_content(obs.content, max_message_chars)
text += '\n[Last action has been rejected by the user]'
message = Message(role='user', content=[TextContent(text=text)])
elif isinstance(obs, AgentCondensationObservation):
text = truncate_content(obs.content, max_message_chars)
message = Message(role='user', content=[TextContent(text=text)])
else:
# If an observation message is not returned, it will cause an error
# when the LLM tries to return the next message
Expand Down Expand Up @@ -442,7 +450,10 @@ def _get_messages(self, state: State) -> list[Message]:

pending_tool_call_action_messages: dict[str, Message] = {}
tool_call_id_to_message: dict[str, Message] = {}
events = list(state.history)

# Condense the events from the state.
events = self.condenser.condensed_history(state)

for event in events:
# create a regular message from an event
if isinstance(event, Action):
Expand Down
5 changes: 4 additions & 1 deletion openhands/core/config/agent_config.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from dataclasses import dataclass, fields
from dataclasses import dataclass, field, fields

from openhands.core.config.condenser_config import CondenserConfig, NoOpCondenserConfig
from openhands.core.config.config_utils import get_field_info


Expand All @@ -18,6 +19,7 @@ class AgentConfig:
llm_config: The name of the llm config to use. If specified, this will override global llm config.
use_microagents: Whether to use microagents at all. Default is True.
disabled_microagents: A list of microagents to disable. Default is None.
condenser: Configuration for the memory condenser. Default is NoOpCondenserConfig.
"""

codeact_enable_browsing: bool = True
Expand All @@ -29,6 +31,7 @@ class AgentConfig:
llm_config: str | None = None
use_microagents: bool = True
disabled_microagents: list[str] | None = None
condenser: CondenserConfig = field(default_factory=NoOpCondenserConfig) # type: ignore

def defaults_to_dict(self) -> dict:
"""Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
Expand Down
90 changes: 90 additions & 0 deletions openhands/core/config/condenser_config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
from typing import Literal

from pydantic import BaseModel, Field

from openhands.core.config.llm_config import LLMConfig


class NoOpCondenserConfig(BaseModel):
"""Configuration for NoOpCondenser."""

type: Literal['noop'] = Field('noop')


class ObservationMaskingCondenserConfig(BaseModel):
"""Configuration for ObservationMaskingCondenser."""

type: Literal['observation_masking'] = Field('observation_masking')
attention_window: int = Field(
default=10,
description='The number of most-recent events where observations will not be masked.',
ge=1,
)


class RecentEventsCondenserConfig(BaseModel):
"""Configuration for RecentEventsCondenser."""

type: Literal['recent'] = Field('recent')
keep_first: int = Field(
default=0,
description='The number of initial events to condense.',
ge=0,
)
max_events: int = Field(
default=10, description='Maximum number of events to keep.', ge=1
)


class LLMSummarizingCondenserConfig(BaseModel):
"""Configuration for LLMCondenser."""

type: Literal['llm'] = Field('llm')
llm_config: LLMConfig = Field(
..., description='Configuration for the LLM to use for condensing.'
)


class AmortizedForgettingCondenserConfig(BaseModel):
"""Configuration for AmortizedForgettingCondenser."""

type: Literal['amortized'] = Field('amortized')
max_size: int = Field(
default=100,
description='Maximum size of the condensed history before triggering forgetting.',
ge=2,
)
keep_first: int = Field(
default=0,
description='Number of initial events to always keep in history.',
ge=0,
)


class LLMAttentionCondenserConfig(BaseModel):
"""Configuration for LLMAttentionCondenser."""

type: Literal['llm_attention'] = Field('llm_attention')
llm_config: LLMConfig = Field(
..., description='Configuration for the LLM to use for attention.'
)
max_size: int = Field(
default=100,
description='Maximum size of the condensed history before triggering forgetting.',
ge=2,
)
keep_first: int = Field(
default=0,
description='Number of initial events to always keep in history.',
ge=0,
)


CondenserConfig = (
NoOpCondenserConfig
| ObservationMaskingCondenserConfig
| RecentEventsCondenserConfig
| LLMSummarizingCondenserConfig
| AmortizedForgettingCondenserConfig
| LLMAttentionCondenserConfig
)
3 changes: 3 additions & 0 deletions openhands/core/schema/observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,5 +44,8 @@ class ObservationTypeSchema(BaseModel):

USER_REJECTED: str = Field(default='user_rejected')

CONDENSE: str = Field(default='condense')
"""Result of a condensation operation."""


ObservationType = ObservationTypeSchema()
6 changes: 5 additions & 1 deletion openhands/events/observation/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.events.observation.agent import (
AgentCondensationObservation,
AgentStateChangedObservation,
)
from openhands.events.observation.browse import BrowserOutputObservation
from openhands.events.observation.commands import (
CmdOutputMetadata,
Expand Down Expand Up @@ -32,4 +35,5 @@
'AgentDelegateObservation',
'SuccessObservation',
'UserRejectObservation',
'AgentCondensationObservation',
]
11 changes: 11 additions & 0 deletions openhands/events/observation/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,14 @@ class AgentStateChangedObservation(Observation):
@property
def message(self) -> str:
return ''


@dataclass
class AgentCondensationObservation(Observation):
"""The output of a condensation action."""

observation: str = ObservationType.CONDENSE

@property
def message(self) -> str:
return self.content
6 changes: 5 additions & 1 deletion openhands/events/serialization/observation.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import copy

from openhands.events.observation.agent import AgentStateChangedObservation
from openhands.events.observation.agent import (
AgentCondensationObservation,
AgentStateChangedObservation,
)
from openhands.events.observation.browse import BrowserOutputObservation
from openhands.events.observation.commands import (
CmdOutputMetadata,
Expand Down Expand Up @@ -32,6 +35,7 @@
ErrorObservation,
AgentStateChangedObservation,
UserRejectObservation,
AgentCondensationObservation,
)

OBSERVATION_TYPE_TO_CLASS = {
Expand Down
4 changes: 2 additions & 2 deletions openhands/memory/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from openhands.memory.condenser import MemoryCondenser
from openhands.memory.condenser import Condenser
from openhands.memory.memory import LongTermMemory

__all__ = ['LongTermMemory', 'MemoryCondenser']
__all__ = ['LongTermMemory', 'Condenser']
Loading

0 comments on commit 6e4ff56

Please sign in to comment.