Skip to content

Commit

Permalink
refactor: make Agent.step() multi-step (#1884)
Browse files Browse the repository at this point in the history
  • Loading branch information
cpacker authored Oct 15, 2024
1 parent 94d2a18 commit 4fd82ee
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 125 deletions.
118 changes: 103 additions & 15 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,17 @@
from letta.constants import (
CLI_WARNING_PREFIX,
FIRST_MESSAGE_ATTEMPTS,
FUNC_FAILED_HEARTBEAT_MESSAGE,
IN_CONTEXT_MEMORY_KEYWORD,
LLM_MAX_TOKENS,
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC,
MESSAGE_SUMMARY_WARNING_FRAC,
REQ_HEARTBEAT_MESSAGE,
)
from letta.errors import LLMError
from letta.interface import AgentInterface
from letta.llm_api.helpers import is_context_overflow_error
from letta.llm_api.llm_api_tools import create
from letta.memory import ArchivalMemory, RecallMemory, summarize_messages
from letta.metadata import MetadataStore
Expand All @@ -32,11 +36,15 @@
from letta.schemas.openai.chat_completion_response import (
Message as ChatCompletionMessage,
)
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.passage import Passage
from letta.schemas.tool import Tool
from letta.schemas.usage import LettaUsageStatistics
from letta.system import (
get_heartbeat,
get_initial_boot_messages,
get_login_event,
get_token_limit_warning,
package_function_response,
package_summarize_message,
package_user_message,
Expand All @@ -56,9 +64,6 @@
verify_first_message_correctness,
)

from .errors import LLMError
from .llm_api.helpers import is_context_overflow_error


def compile_memory_metadata_block(
memory_edit_timestamp: datetime.datetime,
Expand Down Expand Up @@ -202,7 +207,7 @@ class BaseAgent(ABC):
def step(
self,
messages: Union[Message, List[Message]],
) -> AgentStepResponse:
) -> LettaUsageStatistics:
"""
Top-level event message handler for the agent.
"""
Expand Down Expand Up @@ -721,18 +726,105 @@ def _handle_ai_response(
return messages, heartbeat_request, function_failed

def step(
    self,
    messages: Union[Message, List[Message]],
    # additional args
    chaining: bool = True,
    max_chaining_steps: Optional[int] = None,
    ms: Optional[MetadataStore] = None,
    **kwargs,
) -> LettaUsageStatistics:
    """Run `inner_step` in a loop, chaining follow-up steps as needed.

    After each inner step the loop decides whether to run another one:

    * if the step reported an in-context memory warning, a token-limit
      warning message is fed back in;
    * otherwise, if the function call failed, a "function failed" heartbeat
      message is fed back in;
    * otherwise, if the agent requested a heartbeat, a heartbeat message is
      fed back in;
    * otherwise the loop ends.

    Args:
        messages: A single message or a list of messages used as input to
            the first inner step.
        chaining: When False, exactly one inner step is run.
        max_chaining_steps: When set, the loop stops once the number of
            completed steps exceeds this value (i.e. up to
            ``max_chaining_steps + 1`` inner steps may run).
        ms: Optional metadata store. When truthy, the agent is saved via
            `save_agent` after every inner step. It is also forwarded to
            `inner_step`.
        **kwargs: Forwarded to `inner_step`. Any caller-supplied
            ``first_message`` or ``ms`` entry is overridden.

    Returns:
        LettaUsageStatistics built from the usage accumulated over all inner
        steps plus the number of steps executed.
    """
    next_input_message = messages if isinstance(messages, list) else [messages]
    total_usage = UsageStatistics()
    step_count = 0

    # These are the same for every iteration, so set them once up front.
    kwargs["ms"] = ms
    kwargs["first_message"] = False

    while True:
        step_response = self.inner_step(
            messages=next_input_message,
            **kwargs,
        )

        step_count += 1
        total_usage += step_response.usage
        self.interface.step_complete()

        # save updated state
        if ms:
            save_agent(self, ms)

        # Chain stops
        if not chaining:
            printd("No chaining, stopping after one step")
            break
        if max_chaining_steps is not None and step_count > max_chaining_steps:
            printd(f"Hit max chaining steps, stopping after {step_count} steps")
            break

        # Chain handlers: pick the content of the follow-up message.
        # The order matters: memory warning > function failure > heartbeat.
        if step_response.in_context_memory_warning:
            followup_content = get_token_limit_warning()
        elif step_response.function_failed:
            followup_content = get_heartbeat(FUNC_FAILED_HEARTBEAT_MESSAGE)
        elif step_response.heartbeat_request:
            followup_content = get_heartbeat(REQ_HEARTBEAT_MESSAGE)
        else:
            # Letta no-op / yield
            break

        # Narrow the optional user_id before building the follow-up message.
        assert self.agent_state.user_id is not None
        next_input_message = Message.dict_to_message(
            agent_id=self.agent_state.id,
            user_id=self.agent_state.user_id,
            model=self.model,
            openai_message_dict={
                "role": "user",  # TODO: change to system?
                "content": followup_content,
            },
        )

    return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)

def inner_step(
self,
messages: Union[Message, List[Message]],
first_message: bool = False,
first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS,
skip_verify: bool = False,
return_dicts: bool = True,
# recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreated the 'created_at' field
stream: bool = False, # TODO move to config?
inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
ms: Optional[MetadataStore] = None,
) -> AgentStepResponse:
"""Top-level event message handler for the Letta agent"""
"""Runs a single step in the agent loop (generates at most one LLM call)"""

try:

Expand Down Expand Up @@ -834,13 +926,12 @@ def step(
)

self._append_to_messages(all_new_messages)
messages_to_return = [msg.to_openai_dict() for msg in all_new_messages] if return_dicts else all_new_messages

# update state after each step
self.update_state()

return AgentStepResponse(
messages=messages_to_return,
messages=all_new_messages,
heartbeat_request=heartbeat_request,
function_failed=function_failed,
in_context_memory_warning=active_memory_warning,
Expand All @@ -856,15 +947,12 @@ def step(
self.summarize_messages_inplace()

# Try step again
return self.step(
return self.inner_step(
messages=messages,
first_message=first_message,
first_message_retry_limit=first_message_retry_limit,
skip_verify=skip_verify,
return_dicts=return_dicts,
# recreate_message_timestamp=recreate_message_timestamp,
stream=stream,
# timestamp=timestamp,
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
ms=ms,
)
Expand Down Expand Up @@ -905,7 +993,7 @@ def step_user_message(self, user_message_str: str, **kwargs) -> AgentStepRespons
# created_at=timestamp,
)

return self.step(messages=[user_message], **kwargs)
return self.inner_step(messages=[user_message], **kwargs)

def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True):
assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})"
Expand Down Expand Up @@ -1326,7 +1414,7 @@ def retry_message(self) -> List[Message]:
self.pop_until_user()
user_message = self.pop_message(count=1)[0]
assert user_message.text is not None, "User message text is None"
step_response = self.step_user_message(user_message_str=user_message.text, return_dicts=False)
step_response = self.step_user_message(user_message_str=user_message.text)
messages = step_response.messages

assert messages is not None
Expand Down
23 changes: 12 additions & 11 deletions letta/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -747,8 +747,9 @@ def send_message(
# simplify messages
if not include_full_message:
messages = []
for message in response.messages:
messages += message.to_letta_message()
for m in response.messages:
assert isinstance(m, Message)
messages += m.to_letta_message()
response.messages = messages

return response
Expand Down Expand Up @@ -1677,7 +1678,7 @@ def get_agent(self, agent_id: str) -> AgentState:
self.interface.clear()
return self.server.get_agent_state(user_id=self.user_id, agent_id=agent_id)

def get_agent_id(self, agent_name: str) -> AgentState:
def get_agent_id(self, agent_name: str) -> Optional[str]:
"""
Get the ID of an agent by name (names are unique per user)
Expand Down Expand Up @@ -1767,6 +1768,7 @@ def send_message(
self,
message: str,
role: str,
name: Optional[str] = None,
agent_id: Optional[str] = None,
agent_name: Optional[str] = None,
stream_steps: bool = False,
Expand All @@ -1790,19 +1792,18 @@ def send_message(
# lookup agent by name
assert agent_name, f"Either agent_id or agent_name must be provided"
agent_id = self.get_agent_id(agent_name=agent_name)

agent_state = self.get_agent(agent_id=agent_id)
assert agent_id, f"Agent with name {agent_name} not found"

if stream_steps or stream_tokens:
# TODO: implement streaming with stream=True/False
raise NotImplementedError
self.interface.clear()
if role == "system":
usage = self.server.system_message(user_id=self.user_id, agent_id=agent_id, message=message)
elif role == "user":
usage = self.server.user_message(user_id=self.user_id, agent_id=agent_id, message=message)
else:
raise ValueError(f"Role {role} not supported")

usage = self.server.send_messages(
user_id=self.user_id,
agent_id=agent_id,
messages=[MessageCreate(role=MessageRole(role), text=message, name=name)],
)

# auto-save
if self.auto_save:
Expand Down
10 changes: 6 additions & 4 deletions letta/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,8 +361,10 @@ def run_agent_loop(
skip_next_user_input = False

def process_agent_step(user_message, no_verify):
# TODO(charles): update to use agent.step() instead of inner_step()

if user_message is None:
step_response = letta_agent.step(
step_response = letta_agent.inner_step(
messages=[],
first_message=False,
skip_verify=no_verify,
Expand Down Expand Up @@ -402,15 +404,15 @@ def process_agent_step(user_message, no_verify):
while True:
try:
if strip_ui:
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
_, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
break
else:
if stream:
# Don't display the "Thinking..." if streaming
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
_, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
else:
with console.status("[bold cyan]Thinking...") as status:
new_messages, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
_, user_message, skip_next_user_input = process_agent_step(user_message, no_verify)
break
except KeyboardInterrupt:
print("User interrupt occurred.")
Expand Down
5 changes: 2 additions & 3 deletions letta/schemas/agent.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import uuid
from datetime import datetime
from enum import Enum
from typing import Dict, List, Optional, Union
from typing import Dict, List, Optional

from pydantic import BaseModel, Field, field_validator

Expand Down Expand Up @@ -121,8 +121,7 @@ class UpdateAgentState(BaseAgent):


class AgentStepResponse(BaseModel):
# TODO remove support for list of dicts
messages: Union[List[Message], List[dict]] = Field(..., description="The messages generated during the agent's step.")
messages: List[Message] = Field(..., description="The messages generated during the agent's step.")
heartbeat_request: bool = Field(..., description="Whether the agent requested a heartbeat (i.e. follow-up execution).")
function_failed: bool = Field(..., description="Whether the agent step ended because a function call failed.")
in_context_memory_warning: bool = Field(
Expand Down
2 changes: 1 addition & 1 deletion letta/server/rest_api/routers/openai/assistants/threads.py
Original file line number Diff line number Diff line change
Expand Up @@ -248,7 +248,7 @@ def create_run(
agent_id = thread_id
# TODO: override preset of agent with request.assistant_id
agent = server._get_or_load_agent(agent_id=agent_id)
agent.step(user_message=None) # already has messages added
agent.inner_step(messages=[]) # already has messages added
run_id = str(uuid.uuid4())
create_time = int(get_utc_time().timestamp())
return OpenAIRun(
Expand Down
Loading

0 comments on commit 4fd82ee

Please sign in to comment.