Merge branch 'main' into feat-search-conversations
tofarr committed Dec 31, 2024
2 parents 3707971 + 2ec2f25 commit 8f70910
Showing 13 changed files with 421 additions and 174 deletions.
13 changes: 1 addition & 12 deletions openhands/agenthub/codeact_agent/codeact_agent.py
@@ -482,18 +482,7 @@ def _get_messages(self, state: State) -> list[Message]:
if message:
if message.role == 'user':
self.prompt_manager.enhance_message(message)
# handle error if the message is the SAME role as the previous message
# litellm.exceptions.BadRequestError: litellm.BadRequestError: OpenAIException - Error code: 400 - {'detail': 'Only supports u/a/u/a/u...'}
# there shouldn't be two consecutive messages from the same role
# NOTE: we shouldn't combine tool messages because each of them has a different tool_call_id
if (
messages
and messages[-1].role == message.role
and message.role != 'tool'
):
messages[-1].content.extend(message.content)
else:
messages.append(message)
messages.append(message)

if self.llm.is_caching_prompt_active():
# NOTE: this is only needed for anthropic
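Note on the hunk above: the agent no longer merges consecutive same-role messages before sending them to the LLM; it now simply appends every message. For readers comparing the two strategies, here is a minimal, self-contained sketch of the merging behavior that was removed (hypothetical Msg type, not the project's Message class):

from dataclasses import dataclass, field

@dataclass
class Msg:
    role: str
    content: list[str] = field(default_factory=list)

def append_with_merge(messages: list[Msg], message: Msg) -> None:
    # old behavior: fold a same-role (non-tool) message into the previous one so the
    # provider sees strict user/assistant alternation; tool messages are never merged
    # because each one carries its own tool_call_id
    if messages and messages[-1].role == message.role and message.role != 'tool':
        messages[-1].content.extend(message.content)
    else:
        messages.append(message)

def append_plain(messages: list[Msg], message: Msg) -> None:
    # new behavior: always append
    messages.append(message)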
102 changes: 71 additions & 31 deletions openhands/controller/agent_controller.py
@@ -47,7 +47,6 @@
)
from openhands.events.serialization.event import truncate_content
from openhands.llm.llm import LLM
from openhands.utils.shutdown_listener import should_continue

# note: RESUME is only available on web GUI
TRAFFIC_CONTROL_REMINDER = (
@@ -64,7 +63,6 @@ class AgentController:
confirmation_mode: bool
agent_to_llm_config: dict[str, LLMConfig]
agent_configs: dict[str, AgentConfig]
agent_task: asyncio.Future | None = None
parent: 'AgentController | None' = None
delegate: 'AgentController | None' = None
_pending_action: Action | None = None
@@ -109,7 +107,6 @@ def __init__(
headless_mode: Whether the agent is run in headless mode.
status_callback: Optional callback function to handle status updates.
"""
self._step_lock = asyncio.Lock()
self.id = sid
self.agent = agent
self.headless_mode = headless_mode
@@ -199,32 +196,44 @@ async def _react_to_exception(
err_id = 'STATUS$ERROR_LLM_AUTHENTICATION'
self.status_callback('error', err_id, type(e).__name__ + ': ' + str(e))

async def start_step_loop(self):
"""The main loop for the agent's step-by-step execution."""
self.log('info', 'Starting step loop...')
while True:
if not self._is_awaiting_observation() and not should_continue():
break
if self._closed:
break
try:
await self._step()
except asyncio.CancelledError:
self.log('debug', 'AgentController task was cancelled')
break
except Exception as e:
traceback.print_exc()
self.log('error', f'Error while running the agent: {e}')
await self._react_to_exception(e)
def step(self):
asyncio.create_task(self._step_with_exception_handling())

await asyncio.sleep(0.1)
async def _step_with_exception_handling(self):
try:
await self._step()
except Exception as e:
traceback.print_exc()
self.log('error', f'Error while running the agent: {e}')
reported = RuntimeError(
'There was an unexpected error while running the agent.'
)
if isinstance(e, litellm.LLMError):
reported = e
await self._react_to_exception(reported)

async def on_event(self, event: Event) -> None:
def should_step(self, event: Event) -> bool:
if isinstance(event, Action):
if isinstance(event, MessageAction) and event.source == EventSource.USER:
return True
return False
if isinstance(event, Observation):
if isinstance(event, NullObservation) or isinstance(
event, AgentStateChangedObservation
):
return False
return True
return False

def on_event(self, event: Event) -> None:
"""Callback from the event stream. Notifies the controller of incoming events.
Args:
event (Event): The incoming event to process.
"""
asyncio.get_event_loop().run_until_complete(self._on_event(event))

async def _on_event(self, event: Event) -> None:
if hasattr(event, 'hidden') and event.hidden:
return
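The hunks above replace the old polling start_step_loop with event-driven stepping: the synchronous on_event callback runs the async handler to completion, and should_step decides whether the event warrants another step. A rough, runnable sketch of that control flow under hypothetical names (not the project's classes; the real change schedules the step as a task rather than awaiting it):

import asyncio

class ReactiveStepper:
    def __init__(self) -> None:
        self.steps = 0

    def on_event(self, event: dict) -> None:
        # synchronous callback from the event stream; drive the async handler to completion
        asyncio.get_event_loop().run_until_complete(self._on_event(event))

    async def _on_event(self, event: dict) -> None:
        if self.should_step(event):
            await self._step()  # the real code uses asyncio.create_task here

    def should_step(self, event: dict) -> bool:
        # step on user messages and on observations that are not pure state changes
        return event.get('type') in ('user_message', 'observation')

    async def _step(self) -> None:
        self.steps += 1

stepper = ReactiveStepper()
stepper.on_event({'type': 'user_message'})
print(stepper.steps)  # prints 1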

@@ -237,6 +246,9 @@ async def on_event(self, event: Event) -> None:
elif isinstance(event, Observation):
await self._handle_observation(event)

if self.should_step(event):
self.step()

async def _handle_action(self, action: Action) -> None:
"""Handles actions from the event stream.
@@ -335,6 +347,28 @@ async def _handle_message_action(self, action: MessageAction) -> None:
def _reset(self) -> None:
"""Resets the agent controller"""

# make sure there is an Observation with the tool call metadata to be recognized by the agent
# otherwise the pending action is found in history, but it's incomplete without an obs with tool result
if self._pending_action and hasattr(self._pending_action, 'tool_call_metadata'):
# find out if there already is an observation with the same tool call metadata
found_observation = False
for event in self.state.history:
if (
isinstance(event, Observation)
and event.tool_call_metadata
== self._pending_action.tool_call_metadata
):
found_observation = True
break

# make a new ErrorObservation with the tool call metadata
if not found_observation:
obs = ErrorObservation(content='The action has not been executed.')
obs.tool_call_metadata = self._pending_action.tool_call_metadata
obs._cause = self._pending_action.id # type: ignore[attr-defined]
self.event_stream.add_event(obs, EventSource.AGENT)

# reset the pending action, this will be called when the agent is STOPPED or ERROR
self._pending_action = None
self.agent.reset()
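The _reset hunk above exists because function-calling providers expect every tool call to be matched by a tool result: if the controller stops while a pending action's tool_call_metadata has no observation, the stored history would be incomplete on the next completion. A hedged sketch of the same repair in plain OpenAI-style message terms (hypothetical helper, not the project's code):

def synthesize_missing_tool_results(history: list[dict]) -> list[dict]:
    # collect the ids of tool calls that already have a matching 'tool' reply
    answered = {m.get('tool_call_id') for m in history if m.get('role') == 'tool'}
    fixes = []
    for msg in history:
        for call in msg.get('tool_calls') or []:
            if call['id'] not in answered:
                # fabricate an error result so the conversation stays well-formed
                fixes.append({
                    'role': 'tool',
                    'tool_call_id': call['id'],
                    'content': 'The action has not been executed.',
                })
    return fixes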

@@ -465,19 +499,16 @@ async def start_delegate(self, action: AgentDelegateAction) -> None:
async def _step(self) -> None:
"""Executes a single step of the parent or delegate agent. Detects stuck agents and limits on the number of iterations and the task budget."""
if self.get_agent_state() != AgentState.RUNNING:
await asyncio.sleep(1)
return

if self._pending_action:
await asyncio.sleep(1)
return

if self.delegate is not None:
assert self.delegate != self
if self.delegate.get_agent_state() == AgentState.PAUSED:
# no need to check too often
await asyncio.sleep(1)
else:
# TODO this conditional will always be false, because the parent controllers are unsubscribed
# remove if it's still useless when delegation is reworked
if self.delegate.get_agent_state() != AgentState.PAUSED:
await self._delegate_step()
return

@@ -487,7 +518,6 @@ def _step(self) -> None:
extra={'msg_type': 'STEP'},
)

# check if agent hit the resources limit
stop_step = False
if self.state.iteration >= self.state.max_iterations:
stop_step = await self._handle_traffic_control(
@@ -500,6 +530,7 @@
'budget', current_cost, self.max_budget_per_task
)
if stop_step:
logger.warning('Stopping agent due to traffic control')
return

if self._is_stuck():
@@ -699,12 +730,20 @@ def set_initial_state(
# - the previous session, in which case it has history
# - from a parent agent, in which case it has no history
# - None / a new state

# If state is None, we create a brand new state and still load the event stream so we can restore the history
if state is None:
self.state = State(
inputs={},
max_iterations=max_iterations,
confirmation_mode=confirmation_mode,
)
self.state.start_id = 0

self.log(
'debug',
f'AgentController {self.id} - created new state. start_id: {self.state.start_id}',
)
else:
self.state = state

@@ -716,7 +755,8 @@
f'AgentController {self.id} initializing history from event {self.state.start_id}',
)

self._init_history()
# Always load from the event stream to avoid losing history
self._init_history()

def _init_history(self) -> None:
"""Initializes the agent's history from the event stream.
@@ -945,7 +985,7 @@ def __repr__(self):
return (
f'AgentController(id={self.id}, agent={self.agent!r}, '
f'event_stream={self.event_stream!r}, '
f'state={self.state!r}, agent_task={self.agent_task!r}, '
f'state={self.state!r}, '
f'delegate={self.delegate!r}, _pending_action={self._pending_action!r})'
)

8 changes: 0 additions & 8 deletions openhands/core/loop.py
@@ -16,7 +16,6 @@ async def run_agent_until_done(
the agent until it reaches a terminal state.
Note that runtime must be connected before being passed in here.
"""
controller.agent_task = asyncio.create_task(controller.start_step_loop())

def status_callback(msg_type, msg_id, msg):
if msg_type == 'error':
@@ -41,10 +40,3 @@ def status_callback(msg_type, msg_id, msg):

while controller.state.agent_state not in end_states:
await asyncio.sleep(1)

if not controller.agent_task.done():
controller.agent_task.cancel()
try:
await controller.agent_task
except asyncio.CancelledError:
pass
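Because the controller now steps in response to events, the runner no longer creates or cancels a background agent_task; it only waits for the agent to reach a terminal state. A minimal sketch of that waiting loop (names are illustrative, not the project's):

import asyncio

END_STATES = {'finished', 'error', 'rejected'}

async def wait_until_done(get_state, poll_seconds: float = 1.0) -> str:
    # poll the controller's reported state until it reaches a terminal value
    while (state := get_state()) not in END_STATES:
        await asyncio.sleep(poll_seconds)
    return state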
61 changes: 42 additions & 19 deletions openhands/events/stream.py
@@ -1,8 +1,9 @@
import asyncio
import threading
from dataclasses import dataclass, field
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from enum import Enum
from queue import Queue
from typing import Callable, Iterable

from openhands.core.logger import openhands_logger as logger
@@ -52,15 +53,29 @@ async def __aiter__(self):
yield await loop.run_in_executor(None, lambda e=event: e) # type: ignore


@dataclass
class EventStream:
sid: str
file_store: FileStore
# For each subscriber ID, there is a map of callback functions - useful
# when there are multiple listeners
_subscribers: dict[str, dict[str, Callable]] = field(default_factory=dict)
_subscribers: dict[str, dict[str, Callable]]
_cur_id: int = 0
_lock: threading.Lock = field(default_factory=threading.Lock)
_lock: threading.Lock

def __init__(self, sid: str, file_store: FileStore, num_workers: int = 1):
self.sid = sid
self.file_store = file_store
self._queue: Queue[Event] = Queue()
self._thread_pools: dict[str, dict[str, ThreadPoolExecutor]] = {}
self._queue_thread = threading.Thread(target=self._run_queue_loop)
self._queue_thread.daemon = True
self._queue_thread.start()
self._subscribers = {}
self._lock = threading.Lock()
self._cur_id = 0

# load the stream
self.__post_init__()

def __post_init__(self) -> None:
try:
@@ -76,6 +91,10 @@ def __post_init__(self) -> None:
if id >= self._cur_id:
self._cur_id = id + 1

def _init_thread_loop(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)

def _get_filename_for_id(self, id: int) -> str:
return get_conversation_event_filename(self.sid, id)

@@ -157,15 +176,18 @@ def get_latest_event_id(self) -> int:
def subscribe(
self, subscriber_id: EventStreamSubscriber, callback: Callable, callback_id: str
):
pool = ThreadPoolExecutor(max_workers=1, initializer=self._init_thread_loop)
if subscriber_id not in self._subscribers:
self._subscribers[subscriber_id] = {}
self._thread_pools[subscriber_id] = {}

if callback_id in self._subscribers[subscriber_id]:
raise ValueError(
f'Callback ID on subscriber {subscriber_id} already exists: {callback_id}'
)

self._subscribers[subscriber_id][callback_id] = callback
self._thread_pools[subscriber_id][callback_id] = pool

def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
if subscriber_id not in self._subscribers:
@@ -179,13 +201,6 @@ def unsubscribe(self, subscriber_id: EventStreamSubscriber, callback_id: str):
del self._subscribers[subscriber_id][callback_id]

def add_event(self, event: Event, source: EventSource):
try:
asyncio.get_running_loop().create_task(self._async_add_event(event, source))
except RuntimeError:
# No event loop running...
asyncio.run(self._async_add_event(event, source))

async def _async_add_event(self, event: Event, source: EventSource):
if hasattr(event, '_id') and event.id is not None:
raise ValueError(
                'Event already has an ID. It was probably added back to the EventStream from inside a handler, triggering a loop.'
@@ -199,14 +214,22 @@ async def _async_add_event(self, event: Event, source: EventSource):
data = event_to_dict(event)
if event.id is not None:
self.file_store.write(self._get_filename_for_id(event.id), json.dumps(data))
tasks = []
for key in sorted(self._subscribers.keys()):
callbacks = self._subscribers[key]
for callback_id in callbacks:
callback = callbacks[callback_id]
tasks.append(asyncio.create_task(callback(event)))
if tasks:
await asyncio.wait(tasks)
self._queue.put(event)

def _run_queue_loop(self):
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
loop.run_until_complete(self._process_queue())

async def _process_queue(self):
while should_continue():
event = self._queue.get()
for key in sorted(self._subscribers.keys()):
callbacks = self._subscribers[key]
for callback_id in callbacks:
callback = callbacks[callback_id]
pool = self._thread_pools[key][callback_id]
pool.submit(callback, event)

def _callback(self, callback: Callable, event: Event):
asyncio.run(callback(event))
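The EventStream changes above swap asyncio-task fan-out for a queue drained by a daemon thread, which hands each event to a single-worker ThreadPoolExecutor per subscriber callback: deliveries to one subscriber stay ordered, while subscribers do not block each other or the producer. A condensed, runnable sketch of that dispatch pattern (hypothetical TinyStream, not the real class):

import asyncio
import threading
from concurrent.futures import ThreadPoolExecutor
from queue import Queue
from typing import Callable

class TinyStream:
    def __init__(self) -> None:
        self._queue: Queue = Queue()
        self._callbacks: dict[str, Callable] = {}
        self._pools: dict[str, ThreadPoolExecutor] = {}
        threading.Thread(target=self._drain, daemon=True).start()

    def subscribe(self, callback_id: str, callback: Callable) -> None:
        # one single-worker pool per callback keeps that subscriber's events in order
        self._callbacks[callback_id] = callback
        self._pools[callback_id] = ThreadPoolExecutor(max_workers=1)

    def add_event(self, event) -> None:
        # producers only enqueue; the daemon thread fans the event out
        self._queue.put(event)

    def _drain(self) -> None:
        while True:
            event = self._queue.get()
            for callback_id, callback in self._callbacks.items():
                # each async callback runs via its own event loop on a worker thread
                self._pools[callback_id].submit(lambda cb=callback, e=event: asyncio.run(cb(e)))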
6 changes: 4 additions & 2 deletions openhands/llm/llm.py
@@ -13,8 +13,8 @@
warnings.simplefilter('ignore')
import litellm

from litellm import ChatCompletionMessageToolCall, ModelInfo, PromptTokensDetails
from litellm import Message as LiteLLMMessage
from litellm import ModelInfo, PromptTokensDetails
from litellm import completion as litellm_completion
from litellm import completion_cost as litellm_completion_cost
from litellm.exceptions import (
@@ -246,7 +246,9 @@ def wrapper(*args, **kwargs):
resp.choices[0].message = fn_call_response_message

message_back: str = resp['choices'][0]['message']['content'] or ''
tool_calls = resp['choices'][0]['message'].get('tool_calls', [])
tool_calls: list[ChatCompletionMessageToolCall] = resp['choices'][0][
'message'
].get('tool_calls', [])
if tool_calls:
for tool_call in tool_calls:
fn_name = tool_call.function.name
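The llm.py hunk only adds an explicit type for tool_calls using litellm's ChatCompletionMessageToolCall (imported above). For readers unfamiliar with the shape of that object, a hedged sketch of how such a response is typically unpacked; the json.loads of function.arguments is an assumption based on the usual OpenAI-style encoding, not something shown in this diff:

import json
from litellm import ChatCompletionMessageToolCall

def unpack_tool_calls(resp) -> list[tuple[str, dict]]:
    message = resp['choices'][0]['message']
    tool_calls: list[ChatCompletionMessageToolCall] = message.get('tool_calls', []) or []
    calls = []
    for tool_call in tool_calls:
        # function.arguments is a JSON-encoded string of keyword arguments
        calls.append((tool_call.function.name, json.loads(tool_call.function.arguments)))
    return calls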