Skip to content

Commit

Permalink
Merge branch 'enyst/history-error' of github.com:All-Hands-AI/OpenHands into enyst/history-error
Browse files Browse the repository at this point in the history
  • Loading branch information
enyst committed Dec 31, 2024
2 parents 184b1ea + bff70ef commit f5689b4
Show file tree
Hide file tree
Showing 12 changed files with 385 additions and 169 deletions.
13 changes: 1 addition & 12 deletions openhands/agenthub/codeact_agent/codeact_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -482,18 +482,7 @@ def _get_messages(self, state: State) -> list[Message]:
if message:
if message.role == 'user':
self.prompt_manager.enhance_message(message)
# handle error if the message is the SAME role as the previous message
# litellm.exceptions.BadRequestError: litellm.BadRequestError: OpenAIException - Error code: 400 - {'detail': 'Only supports u/a/u/a/u...'}
# there shouldn't be two consecutive messages from the same role
# NOTE: we shouldn't combine tool messages because each of them has a different tool_call_id
if (
messages
and messages[-1].role == message.role
and message.role != 'tool'
):
messages[-1].content.extend(message.content)
else:
messages.append(message)
messages.append(message)

if self.llm.is_caching_prompt_active():
# NOTE: this is only needed for anthropic
Expand Down
91 changes: 61 additions & 30 deletions openhands/controller/agent_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@
)
from openhands.events.serialization.event import truncate_content
from openhands.llm.llm import LLM
from openhands.utils.shutdown_listener import should_continue

# note: RESUME is only available on web GUI
TRAFFIC_CONTROL_REMINDER = (
Expand All @@ -64,7 +63,6 @@ class AgentController:
confirmation_mode: bool
agent_to_llm_config: dict[str, LLMConfig]
agent_configs: dict[str, AgentConfig]
agent_task: asyncio.Future | None = None
parent: 'AgentController | None' = None
delegate: 'AgentController | None' = None
_pending_action: Action | None = None
Expand Down Expand Up @@ -109,7 +107,6 @@ def __init__(
headless_mode: Whether the agent is run in headless mode.
status_callback: Optional callback function to handle status updates.
"""
self._step_lock = asyncio.Lock()
self.id = sid
self.agent = agent
self.headless_mode = headless_mode
Expand Down Expand Up @@ -199,32 +196,44 @@ async def _react_to_exception(
err_id = 'STATUS$ERROR_LLM_AUTHENTICATION'
self.status_callback('error', err_id, type(e).__name__ + ': ' + str(e))

async def start_step_loop(self):
"""The main loop for the agent's step-by-step execution."""
self.log('info', 'Starting step loop...')
while True:
if not self._is_awaiting_observation() and not should_continue():
break
if self._closed:
break
try:
await self._step()
except asyncio.CancelledError:
self.log('debug', 'AgentController task was cancelled')
break
except Exception as e:
traceback.print_exc()
self.log('error', f'Error while running the agent: {e}')
await self._react_to_exception(e)
def step(self):
asyncio.create_task(self._step_with_exception_handling())

async def _step_with_exception_handling(self):
try:
await self._step()
except Exception as e:
traceback.print_exc()
self.log('error', f'Error while running the agent: {e}')
reported = RuntimeError(
'There was an unexpected error while running the agent.'
)
if isinstance(e, litellm.LLMError):
reported = e
await self._react_to_exception(reported)

await asyncio.sleep(0.1)
def should_step(self, event: Event) -> bool:
if isinstance(event, Action):
if isinstance(event, MessageAction) and event.source == EventSource.USER:
return True
return False
if isinstance(event, Observation):
if isinstance(event, NullObservation) or isinstance(
event, AgentStateChangedObservation
):
return False
return True
return False

async def on_event(self, event: Event) -> None:
def on_event(self, event: Event) -> None:
"""Callback from the event stream. Notifies the controller of incoming events.
Args:
event (Event): The incoming event to process.
"""
asyncio.get_event_loop().run_until_complete(self._on_event(event))

async def _on_event(self, event: Event) -> None:
if hasattr(event, 'hidden') and event.hidden:
return

Expand All @@ -237,6 +246,9 @@ async def on_event(self, event: Event) -> None:
elif isinstance(event, Observation):
await self._handle_observation(event)

if self.should_step(event):
self.step()

async def _handle_action(self, action: Action) -> None:
"""Handles actions from the event stream.
Expand Down Expand Up @@ -335,6 +347,28 @@ async def _handle_message_action(self, action: MessageAction) -> None:
def _reset(self) -> None:
"""Resets the agent controller"""

# make sure there is an Observation with the tool call metadata to be recognized by the agent
# otherwise the pending action is found in history, but it's incomplete without an obs with tool result
if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'):
# find out if there already is an observation with the same tool call metadata
found_observation = False
for event in self.state.history:
if (
isinstance(event, Observation)
and event.tool_call_metadata
== self._pending_action.tool_call_metadata
):
found_observation = True
break

# make a new ErrorObservation with the tool call metadata
if not found_observation:
obs = ErrorObservation(content='The action has not been executed.')
obs.tool_call_metadata = self._pending_action.tool_call_metadata
obs._cause = self._pending_action.id # type: ignore[attr-defined]
self.event_stream.add_event(obs, EventSource.AGENT)

# reset the pending action, this will be called when the agent is STOPPED or ERROR
self._pending_action = None
self.agent.reset()

Expand Down Expand Up @@ -465,19 +499,16 @@ async def start_delegate(self, action: AgentDelegateAction) -> None:
async def _step(self) -> None:
"""Executes a single step of the parent or delegate agent. Detects stuck agents and limits on the number of iterations and the task budget."""
if self.get_agent_state() != AgentState.RUNNING:
await asyncio.sleep(1)
return

if self._pending_action:
await asyncio.sleep(1)
return

if self.delegate is not None:
assert self.delegate != self
if self.delegate.get_agent_state() == AgentState.PAUSED:
# no need to check too often
await asyncio.sleep(1)
else:
# TODO this conditional will always be false, because the parent controllers are unsubscribed
# remove if it's still useless when delegation is reworked
if self.delegate.get_agent_state() != AgentState.PAUSED:
await self._delegate_step()
return

Expand All @@ -487,7 +518,6 @@ async def _step(self) -> None:
extra={'msg_type': 'STEP'},
)

# check if agent hit the resources limit
stop_step = False
if self.state.iteration >= self.state.max_iterations:
stop_step = await self._handle_traffic_control(
Expand All @@ -500,6 +530,7 @@ async def _step(self) -> None:
'budget', current_cost, self.max_budget_per_task
)
if stop_step:
logger.warning('Stopping agent due to traffic control')
return

if self._is_stuck():
Expand Down Expand Up @@ -954,7 +985,7 @@ def __repr__(self):
return (
f'AgentController(id={self.id}, agent={self.agent!r}, '
f'event_stream={self.event_stream!r}, '
f'state={self.state!r}, agent_task={self.agent_task!r}, '
f'state={self.state!r}, '
f'delegate={self.delegate!r}, _pending_action={self._pending_action!r})'
)

Expand Down
8 changes: 0 additions & 8 deletions openhands/core/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ async def run_agent_until_done(
the agent until it reaches a terminal state.
Note that runtime must be connected before being passed in here.
"""
controller.agent_task = asyncio.create_task(controller.start_step_loop())

def status_callback(msg_type, msg_id, msg):
if msg_type == 'error':
Expand All @@ -41,10 +40,3 @@ def status_callback(msg_type, msg_id, msg):

while controller.state.agent_state not in end_states:
await asyncio.sleep(1)

if not controller.agent_task.done():
controller.agent_task.cancel()
try:
await controller.agent_task
except asyncio.CancelledError:
pass
58 changes: 39 additions & 19 deletions openhands/events/stream.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
import asyncio
import threading
from dataclasses import dataclass, field
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
from queue import Queue
from typing import Callable, Iterable

from openhands.core.logger import openhands_logger as logger
Expand Down Expand Up @@ -52,15 +53,26 @@ async def __aiter__(self):
yield await loop.run_in_executor(None, lambda e=event: e) # type: ignore


@dataclass
class EventStream:
sid: str
file_store: FileStore
# For each subscriber ID, there is a map of callback functions - useful
# when there are multiple listeners
_subscribers: dict[str, dict[str, Callable]] = field(default_factory=dict)
_subscribers: dict[str, dict[str, Callable]]
_cur_id: int = 0
_lock: threading.Lock = field(default_factory=threading.Lock)
_lock: threading.Lock

def __init__(self, sid: str, file_store: FileStore, num_workers: int = 1):
self.sid = sid
self.file_store = file_store
self._queue: Queue[Event] = Queue()
self._thread_pools: dict[str, dict[str, ThreadPoolExecutor]] = {}
self._queue_thread = threading.Thread(target=self._run_queue_loop)
self._queue_thread.daemon = True
self._queue_thread.start()
self._subscribers = {}
self._lock = threading.Lock()
self._cur_id = 0

def __post_init__(self) -> None:
try:
Expand All @@ -76,6 +88,10 @@ def __post_init__(self) -> None:
if id >= self._cur_id:
self._cur_id = id + 1

def _init_thread_loop(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

def _get_filename_for_id(self, id: int) -> str:
return get_conversation_event_filename(self.sid, id)

Expand Down Expand Up @@ -157,15 +173,18 @@ def get_latest_event_id(self) -> int:
def subscribe(
self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
):
pool = ThreadPoolExecutor(max_workers=1, initializer=self._init_thread_loop)
if subscriber_id not in self._subscribers:
self._subscribers[subscriber_id] = {}
self._thread_pools[subscriber_id] = {}

if callback_id in self._subscribers[subscriber_id]:
raise ValueError(
f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
)

self._subscribers[subscriber_id][callback_id] = callback
self._thread_pools[subscriber_id][callback_id] = pool

def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
if subscriber_id not in self._subscribers:
Expand All @@ -179,13 +198,6 @@ def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
del self._subscribers[subscriber_id][callback_id]

def add_event(self, event: Event, source: EventSource):
try:
asyncio.get_running_loop().create_task(self._async_add_event(event, source))
except RuntimeError:
# No event loop running...
asyncio.run(self._async_add_event(event, source))

async def _async_add_event(self, event: Event, source: EventSource):
if hasattr(event, '_id') and event.id is not None:
raise ValueError(
'Event already has an ID. It was probably added back to the EventStream from inside a handler, trigging a loop.'
Expand All @@ -199,14 +211,22 @@ async def _async_add_event(self, event: Event, source: EventSource):
data = event_to_dict(event)
if event.id is not None:
self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
tasks = []
for key in sorted(self._subscribers.keys()):
callbacks = self._subscribers[key]
for callback_id in callbacks:
callback = callbacks[callback_id]
tasks.append(asyncio.create_task(callback(event)))
if tasks:
await asyncio.wait(tasks)
self._queue.put(event)

def _run_queue_loop(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._process_queue())

async def _process_queue(self):
while should_continue():
event = self._queue.get()
for key in sorted(self._subscribers.keys()):
callbacks = self._subscribers[key]
for callback_id in callbacks:
callback = callbacks[callback_id]
pool = self._thread_pools[key][callback_id]
pool.submit(callback, event)

def _callback(self, callback: Callable, event: Event):
asyncio.run(callback(event))
Expand Down
6 changes: 4 additions & 2 deletions openhands/llm/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,8 @@
warnings.simplefilter('ignore')
import litellm

from litellm import ChatCompletionMessageToolCall, ModelInfo, PromptTokensDetails
from litellm import Message as LiteLLMMessage
from litellm import ModelInfo, PromptTokensDetails
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
Expand Down Expand Up @@ -246,7 +246,9 @@ def wrapper(*args, **kwargs):
resp.choices[0].message = fn_call_response_message

message_back: str = resp['choices'][0]['message']['content'] or ''
tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
tool_calls: list[ChatCompletionMessageToolCall] = resp['choices'][0][
'message'
].get('tool_calls', [])
if tool_calls:
for tool_call in tool_calls:
fn_name = tool_call.function.name
Expand Down
Loading

0 comments on commit f5689b4

Please sign in to comment.