Commit

[eval] save eventstream & llm completions for SWE-Bench run_infer (#3923)
xingyaoww authored Sep 22, 2024
1 parent e0608af commit 714e46f
Showing 5 changed files with 33 additions and 5 deletions.
7 changes: 3 additions & 4 deletions evaluation/swe_bench/run_infer.py
@@ -30,6 +30,7 @@
 from openhands.core.main import create_runtime, run_controller
 from openhands.events.action import CmdRunAction
 from openhands.events.observation import CmdOutputObservation, ErrorObservation
+from openhands.events.serialization.event import event_to_dict
 from openhands.runtime.runtime import Runtime
 from openhands.runtime.utils.shutdown_listener import sleep_if_should_continue

@@ -383,10 +384,7 @@ def process_instance(
     if state is None:
         raise ValueError('State should not be None.')

-    # history is now available as a stream of events, rather than list of pairs of (Action, Observation)
-    # for compatibility with the existing output format, we can remake the pairs here
-    # remove when it becomes unnecessary
-    histories = state.history.compatibility_for_eval_history_pairs()
+    histories = [event_to_dict(event) for event in state.history.get_events()]
     metrics = state.metrics.get() if state.metrics else None

     # Save the output
@@ -398,6 +396,7 @@ def process_instance(
         metadata=metadata,
         history=histories,
         metrics=metrics,
+        llm_completions=state.extra_data.get('llm_completions', []),
        error=state.last_error if state and state.last_error else None,
     )
     return output
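With this change, `histories` is a flat list of serialized events rather than reconstructed (Action, Observation) pairs. Roughly what one entry might look like after `event_to_dict` (keys are illustrative; the authoritative schema lives in `openhands/events/serialization/event.py`, not in this diff):

```python
# Hypothetical result of [event_to_dict(event) for event in state.history.get_events()]
histories = [
    {'id': 0, 'action': 'run', 'args': {'command': 'ls'}},    # a serialized CmdRunAction
    {'id': 1, 'observation': 'run', 'content': 'README.md'},  # a serialized CmdOutputObservation
]
```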
7 changes: 6 additions & 1 deletion evaluation/utils/shared.py
@@ -58,7 +58,11 @@ class EvalOutput(BaseModel):

     # Interaction info
     metadata: EvalMetadata | None = None
-    history: list[tuple[dict[str, Any], dict[str, Any]]] | None = None
+    # list[tuple[dict[str, Any], dict[str, Any]]] - for compatibility with the old format
+    history: (
+        list[dict[str, Any]] | list[tuple[dict[str, Any], dict[str, Any]]] | None
+    ) = None
+    llm_completions: list[dict[str, Any]]
     metrics: dict[str, Any] | None = None
     error: str | None = None

@@ -278,6 +282,7 @@ def _process_instance_wrapper(
             + '-' * 10
         )
         # Raise an error after all retries & stop the evaluation
+        logger.exception(e)
         raise RuntimeError(
             f'Maximum error retries reached for instance {instance.instance_id}'
         ) from e
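Both history shapes now validate against the widened union, so older output files remain loadable. A data-only sketch of the two accepted forms (contents hypothetical):

```python
# Legacy shape: list of (action, observation) dict pairs.
old_style_history = [({'action': 'run'}, {'observation': 'run'})]

# New shape: flat list of serialized events, as produced in run_infer.py above.
new_style_history = [{'action': 'run'}, {'observation': 'run'}]
```

Note that `llm_completions` is declared without a default, so pydantic treats it as a required field on newly constructed `EvalOutput` instances.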
4 changes: 4 additions & 0 deletions openhands/controller/agent_controller.py
@@ -132,6 +132,10 @@ def update_state_before_step(self):
     async def update_state_after_step(self):
         # update metrics especially for cost
         self.state.local_metrics = self.agent.llm.metrics
+        if 'llm_completions' not in self.state.extra_data:
+            self.state.extra_data['llm_completions'] = []
+        self.state.extra_data['llm_completions'].extend(self.agent.llm.llm_completions)
+        self.agent.llm.llm_completions.clear()

     async def report_error(self, message: str, exception: Exception | None = None):
         """Reports an error to the user and sends the exception to the LLM next step, in the hope it can self-correct.
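The drain-and-clear pattern here is what prevents double counting: `update_state_after_step` runs after every agent step, and without the `clear()` the same buffered completions would be re-appended on the next step. A self-contained sketch of the invariant, with stand-ins for the state and LLM buffer:

```python
# Stand-ins for state.extra_data and agent.llm.llm_completions from the diff.
extra_data: dict = {}
llm_buffer = [{'messages': [], 'response': {}, 'timestamp': 0.0, 'cost': 0.0}]

# setdefault is equivalent to the 'if key not in' check above.
extra_data.setdefault('llm_completions', [])
extra_data['llm_completions'].extend(llm_buffer)
llm_buffer.clear()  # next step starts with an empty buffer

assert extra_data['llm_completions'] and not llm_buffer
```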
2 changes: 2 additions & 0 deletions openhands/core/config.py
@@ -53,6 +53,7 @@ class LLMConfig:
         drop_params: Drop any unmapped (unsupported) params without causing an exception.
         disable_vision: If model is vision capable, this option allows to disable image processing (useful for cost reduction).
         caching_prompt: Using the prompt caching feature provided by the LLM.
+        log_completions: Whether to log LLM completions to the state.
     """

     model: str = 'gpt-4o'
@@ -82,6 +83,7 @@ class LLMConfig:
     drop_params: bool | None = None
     disable_vision: bool | None = None
     caching_prompt: bool = False
+    log_completions: bool = False

     def defaults_to_dict(self) -> dict:
         """Serialize fields to a dict for the frontend, including type hints, defaults, and whether it's optional."""
18 changes: 18 additions & 0 deletions openhands/llm/llm.py
@@ -1,7 +1,9 @@
 import asyncio
 import copy
+import time
 import warnings
 from functools import partial
+from typing import Any

 from openhands.core.config import LLMConfig
 from openhands.runtime.utils.shutdown_listener import should_continue
@@ -73,6 +75,11 @@ def __init__(
         self.cost_metric_supported = True
         self.config = copy.deepcopy(config)

+        # list of LLM completions (for logging purposes). Each completion is a dict with the following keys:
+        # - 'messages': list of messages
+        # - 'response': response from the LLM
+        self.llm_completions: list[dict[str, Any]] = []
+
         # Set up config attributes with default values to prevent AttributeError
         LLMConfig.set_missing_attributes(self.config)

@@ -257,6 +264,16 @@ def wrapper(*args, **kwargs):
             logger.debug('No completion messages!')
             resp = {'choices': [{'message': {'content': ''}}]}

+        if self.config.log_completions:
+            self.llm_completions.append(
+                {
+                    'messages': messages,
+                    'response': resp,
+                    'timestamp': time.time(),
+                    'cost': self.completion_cost(resp),
+                }
+            )
+
         # log the response
         message_back = resp['choices'][0]['message']['content']
         if message_back:
@@ -659,6 +676,7 @@ def __repr__(self):

     def reset(self):
         self.metrics = Metrics()
+        self.llm_completions = []

     def format_messages_for_llm(self, messages: Message | list[Message]) -> list[dict]:
         if isinstance(messages, Message):
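End to end, the feature is driven entirely by the new config flag. A hedged usage sketch (constructor and `completion` call signatures assumed from how `LLM` is used elsewhere in the repo, not from this diff):

```python
from openhands.core.config import LLMConfig
from openhands.llm.llm import LLM

config = LLMConfig(model='gpt-4o', log_completions=True)
llm = LLM(config)

llm.completion(messages=[{'role': 'user', 'content': 'hello'}])

# Each logged entry has the keys documented in __init__ above.
for entry in llm.llm_completions:
    print(entry['timestamp'], entry['cost'])

llm.reset()  # clears both metrics and the completion log
```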
