Commit

fix: various bugfixes (#2411)

sarahwooders authored Feb 5, 2025
2 parents cae4178 + 7de7b28 commit 600ab1b
Showing 25 changed files with 283 additions and 115 deletions.
2 changes: 1 addition & 1 deletion examples/langchain_tool_usage.py
@@ -73,7 +73,7 @@ def main():
print(f"Created agent: {agent_state.name} with ID {str(agent_state.id)}")

# Send a message to the agent
send_message_response = client.user_message(agent_id=agent_state.id, message="How do you pronounce Albert Einstein's name?")
send_message_response = client.user_message(agent_id=agent_state.id, message="Tell me a fun fact about Albert Einstein!")
for message in send_message_response.messages:
response_json = json.dumps(message.model_dump(), indent=4)
print(f"{response_json}\n")
149 changes: 81 additions & 68 deletions letta/agent.py
@@ -38,6 +38,7 @@
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.sandbox_config import SandboxRunResult
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
@@ -198,7 +199,9 @@ def update_memory_if_changed(self, new_memory: Memory) -> bool:
return True
return False

def execute_tool_and_persist_state(self, function_name: str, function_args: dict, target_letta_tool: Tool):
def execute_tool_and_persist_state(
self, function_name: str, function_args: dict, target_letta_tool: Tool
) -> tuple[str, Optional[SandboxRunResult]]:
"""
Execute tool modifications and persist the state of the agent.
Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data
@@ -242,14 +245,52 @@ def execute_tool_and_persist_state(self, function_name: str, function_args: dict, target_letta_tool: Tool):
assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
if updated_agent_state is not None:
self.update_memory_if_changed(updated_agent_state.memory)
return function_response, sandbox_run_result
except Exception as e:
# Need to catch error here, or else truncation won't happen
# TODO: modify to function execution error
function_response = get_friendly_error_msg(
function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)
)

return function_response
return function_response, None

def _handle_function_error_response(
self,
error_msg: str,
tool_call_id: str,
function_name: str,
function_response: str,
messages: List[Message],
function_args: Optional[dict] = None,
include_function_failed_message: bool = False,
) -> List[Message]:
"""
Handle an error from a function call: record it, package it as a tool message, and notify the interface.
(function_args must be passed when include_function_failed_message is True, since the failure notice echoes the call.)
"""
# Update tool rules
self.last_function_response = function_response
self.tool_rules_solver.update_tool_usage(function_name)

# Extend conversation with function response
function_response = package_function_response(False, error_msg)
new_message = Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict={
"role": "tool",
"name": function_name,
"content": function_response,
"tool_call_id": tool_call_id,
},
)
messages.append(new_message)
self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message)
if include_function_failed_message:
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=new_message)

# Return updated messages
return messages

def _get_ai_reply(
self,
@@ -261,6 +302,7 @@
backoff_factor: float = 0.5, # delay multiplier for exponential backoff
max_delay: float = 10.0, # max delay between retries
step_count: Optional[int] = None,
last_function_failed: bool = False,
) -> Optional[ChatCompletionResponse]:
"""Get response from LLM API with robust retry mechanism."""

@@ -273,6 +315,12 @@
else [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
)

# Don't allow a tool to be called if it failed last time
if last_function_failed and self.tool_rules_solver.last_tool_name:
allowed_functions = [f for f in allowed_functions if f["name"] != self.tool_rules_solver.last_tool_name]
if not allowed_functions:
return None

# For the first message, force the initial tool if one is specified
force_tool_call = None
if (
@@ -285,6 +333,7 @@
# Force a tool call if exactly one tool is specified
elif step_count is not None and step_count > 0 and len(allowed_tool_names) == 1:
force_tool_call = allowed_tool_names[0]

for attempt in range(1, empty_response_retry_limit + 1):
try:
response = create(
@@ -409,21 +458,8 @@ def _handle_ai_response(

if not target_letta_tool:
error_msg = f"No function named {function_name}"
function_response = package_function_response(False, error_msg)
messages.append(
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict={
"role": "tool",
"name": function_name,
"content": function_response,
"tool_call_id": tool_call_id,
},
)
) # extend conversation with function response
self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1])
function_response = "None" # more like "never ran?"
messages = self._handle_function_error_response(error_msg, tool_call_id, function_name, function_response, messages)
return messages, False, True # force a heartbeat to allow agent to handle error

# Failure case 2: function name is OK, but function args are bad JSON
@@ -432,21 +468,8 @@
function_args = parse_json(raw_function_args)
except Exception:
error_msg = f"Error parsing JSON for function '{function_name}' arguments: {function_call.arguments}"
function_response = package_function_response(False, error_msg)
messages.append(
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict={
"role": "tool",
"name": function_name,
"content": function_response,
"tool_call_id": tool_call_id,
},
)
) # extend conversation with function response
self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1])
function_response = "None" # more like "never ran?"
messages = self._handle_function_error_response(error_msg, tool_call_id, function_name, function_response, messages)
return messages, False, True # force a heartbeat to allow agent to handle error

# Check if inner thoughts is in the function call arguments (possible apparently if you are using Azure)
@@ -479,7 +502,12 @@ def _handle_ai_response(
self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1])
try:
# handle tool execution (sandbox) and state updates
function_response = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)
function_response, sandbox_run_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)

if sandbox_run_result and sandbox_run_result.status == "error":
error_msg = f"Error calling function {function_name} with args {function_args}: {sandbox_run_result.stderr}"
messages = self._handle_function_error_response(error_msg, tool_call_id, function_name, function_response, messages)
return messages, False, True # force a heartbeat to allow agent to handle error

# handle truncation
if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
@@ -505,45 +533,17 @@
error_msg = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))
error_msg_user = f"{error_msg}\n{traceback.format_exc()}"
self.logger.error(error_msg_user)
function_response = package_function_response(False, error_msg)
self.last_function_response = function_response
# TODO: truncate error message somehow
messages.append(
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict={
"role": "tool",
"name": function_name,
"content": function_response,
"tool_call_id": tool_call_id,
},
)
) # extend conversation with function response
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1])
messages = self._handle_function_error_response(
error_msg, tool_call_id, function_name, function_response, messages, function_args=function_args, include_function_failed_message=True
)
return messages, False, True # force a heartbeat to allow agent to handle error

# Step 4: check if function response is an error
if function_response_string.startswith(ERROR_MESSAGE_PREFIX):
function_response = package_function_response(False, function_response_string)
# TODO: truncate error message somehow
messages.append(
Message.dict_to_message(
agent_id=self.agent_state.id,
user_id=self.agent_state.created_by_id,
model=self.model,
openai_message_dict={
"role": "tool",
"name": function_name,
"content": function_response,
"tool_call_id": tool_call_id,
},
)
) # extend conversation with function response
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1])
self.interface.function_message(f"Error: {function_response_string}", msg_obj=messages[-1])
error_msg = function_response_string
messages = self._handle_function_error_response(
error_msg, tool_call_id, function_name, function_response, messages, function_args=function_args, include_function_failed_message=True
)
return messages, False, True # force a heartbeat to allow agent to handle error

# If no failures happened along the way: ...
@@ -607,9 +607,11 @@ def step(
counter = 0
total_usage = UsageStatistics()
step_count = 0
function_failed = False
while True:
kwargs["first_message"] = False
kwargs["step_count"] = step_count
kwargs["last_function_failed"] = function_failed
step_response = self.inner_step(
messages=next_input_message,
**kwargs,
@@ -689,6 +691,7 @@ def inner_step(
step_count: Optional[int] = None,
metadata: Optional[dict] = None,
summarize_attempt_count: int = 0,
last_function_failed: bool = False,
) -> AgentStepResponse:
"""Runs a single step in the agent loop (generates at most one LLM call)"""

@@ -723,7 +726,17 @@
first_message=first_message,
stream=stream,
step_count=step_count,
last_function_failed=last_function_failed,
)
if not response:
# EDGE CASE: function call failed AND there are no tools left for the agent to call -> return early
return AgentStepResponse(
messages=input_message_sequence,
heartbeat_request=False,
function_failed=False,  # NOTE: different from other function failures; here we force an early return instead
in_context_memory_warning=False,
usage=UsageStatistics(),
)

# Step 3: check if LLM wanted to call a function
# (if yes) Step 4: call the function
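
Taken together, the agent.py changes make tool failures observable and non-looping: execute_tool_and_persist_state now returns a (response, SandboxRunResult) tuple so sandbox errors surface, the duplicated error paths collapse into _handle_function_error_response, and a last_function_failed flag threads from step() into _get_ai_reply(), which drops the tool that just failed from the allowed set and returns None when nothing is left to call. A minimal, self-contained sketch of that gating logic (the helper name and dict shape below are illustrative, not Letta's actual API):

from typing import List, Optional

def filter_tools_after_failure(
    allowed_functions: List[dict],
    last_tool_name: Optional[str],
    last_function_failed: bool,
) -> Optional[List[dict]]:
    """Drop the tool that just failed; return None if no tools remain."""
    if last_function_failed and last_tool_name:
        allowed_functions = [f for f in allowed_functions if f["name"] != last_tool_name]
        if not allowed_functions:
            return None  # caller ends the step early instead of retrying the broken tool
    return allowed_functions

# Example: send_message failed on the previous step, so only web_search may run.
tools = [{"name": "send_message"}, {"name": "web_search"}]
print(filter_tools_after_failure(tools, "send_message", last_function_failed=True))
# [{'name': 'web_search'}]
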
11 changes: 2 additions & 9 deletions letta/client/client.py
@@ -2950,18 +2950,11 @@ def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
langchain_tool=langchain_tool,
additional_imports_module_attr_map=additional_imports_module_attr_map,
)
return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user)

def load_crewai_tool(self, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool:
tool_create = ToolCreate.from_crewai(
crewai_tool=crewai_tool,
additional_imports_module_attr_map=additional_imports_module_attr_map,
)
return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user)
return self.server.tool_manager.create_or_update_langchain_tool(tool_create=tool_create, actor=self.user)

def load_composio_tool(self, action: "ActionType") -> Tool:
tool_create = ToolCreate.from_composio(action_name=action.name)
return self.server.tool_manager.create_or_update_composio_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user)
return self.server.tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=self.user)

def create_tool(
self,
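
With this refactor, the client hands the ToolCreate schema straight to dedicated tool-manager methods (create_or_update_langchain_tool, create_or_update_composio_tool) instead of building the Tool object itself, which lines up with the new EXTERNAL_LANGCHAIN tool type added in letta/orm/enums.py below. A usage sketch, assuming a running Letta server reachable via create_client() and the langchain-community package installed (the Wikipedia tool is just an example):

from langchain_community.tools import WikipediaQueryRun
from langchain_community.utilities import WikipediaAPIWrapper

from letta import create_client

client = create_client()

# Wrap any LangChain tool; the extra-imports map tells Letta what to import
# when it regenerates the tool's source server-side.
langchain_tool = WikipediaQueryRun(api_wrapper=WikipediaAPIWrapper())
tool = client.load_langchain_tool(
    langchain_tool,
    additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"},
)
print(f"Registered tool: {tool.name}")
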
4 changes: 1 addition & 3 deletions letta/functions/helpers.py
@@ -230,9 +230,7 @@ def generate_imported_tool_instantiation_call_str(obj: Any) -> Optional[str]:


def is_base_model(obj: Any):
from langchain_core.pydantic_v1 import BaseModel as LangChainBaseModel

return isinstance(obj, BaseModel) or isinstance(obj, LangChainBaseModel)
return isinstance(obj, BaseModel)


def generate_import_code(module_attr_map: Optional[dict]):
8 changes: 7 additions & 1 deletion letta/local_llm/utils.py
@@ -11,8 +11,11 @@
import letta.local_llm.llm_chat_completion_wrappers.dolphin as dolphin
import letta.local_llm.llm_chat_completion_wrappers.llama3 as llama3
import letta.local_llm.llm_chat_completion_wrappers.zephyr as zephyr
from letta.log import get_logger
from letta.schemas.openai.chat_completion_request import Tool, ToolCall

logger = get_logger(__name__)


def post_json_auth_request(uri, json_payload, auth_type, auth_key):
"""Send a POST request with a JSON payload and optional authentication"""
@@ -126,8 +129,11 @@ def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"):
function_tokens += 2
if isinstance(v["items"], dict) and "type" in v["items"]:
function_tokens += len(encoding.encode(v["items"]["type"]))
elif field == "default":
function_tokens += 2
function_tokens += len(encoding.encode(str(v["default"])))
else:
warnings.warn(f"num_tokens_from_functions: Unsupported field {field} in function {function}")
logger.warning(f"num_tokens_from_functions: Unsupported field {field} in function {function}")
function_tokens += 11

num_tokens += function_tokens
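
num_tokens_from_functions now prices a property's default field (2 tokens of overhead plus the encoded value, mirroring the other fields) instead of emitting an "unsupported field" warning, and remaining warnings go through Letta's logger rather than the warnings module. A rough, standalone illustration of the accounting for one property — exact counts depend on the model's encoding, so treat the numbers as indicative:

import tiktoken

encoding = tiktoken.encoding_for_model("gpt-4")

prop = {"type": "string", "description": "Output language", "default": "english"}

tokens = 0
for field, value in prop.items():
    if field in ("type", "description"):
        tokens += 2 + len(encoding.encode(str(value)))
    elif field == "default":
        # New in this commit: defaults are counted instead of warned about.
        tokens += 2 + len(encoding.encode(str(value)))

print(f"~{tokens} tokens for this property")
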
8 changes: 1 addition & 7 deletions letta/orm/agent.py
@@ -4,7 +4,6 @@
from sqlalchemy import JSON, Index, String
from sqlalchemy.orm import Mapped, mapped_column, relationship

from letta.constants import MULTI_AGENT_TOOLS
from letta.orm.block import Block
from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn
from letta.orm.message import Message
@@ -121,12 +120,7 @@ def to_pydantic(self) -> PydanticAgentState:
# add default rule for having send_message be a terminal tool
tool_rules = self.tool_rules
if not tool_rules:
tool_rules = [
TerminalToolRule(tool_name="send_message"),
]

for tool_name in MULTI_AGENT_TOOLS:
tool_rules.append(TerminalToolRule(tool_name=tool_name))
tool_rules = [TerminalToolRule(tool_name="send_message"), TerminalToolRule(tool_name="send_message_to_agent_async")]

state = {
"id": self.id,
1 change: 1 addition & 0 deletions letta/orm/enums.py
@@ -7,6 +7,7 @@ class ToolType(str, Enum):
LETTA_MEMORY_CORE = "letta_memory_core"
LETTA_MULTI_AGENT_CORE = "letta_multi_agent_core"
EXTERNAL_COMPOSIO = "external_composio"
EXTERNAL_LANGCHAIN = "external_langchain"


class JobType(str, Enum):
6 changes: 6 additions & 0 deletions letta/schemas/agent.py
@@ -143,6 +143,9 @@ class CreateAgent(BaseModel, validate_assignment=True): #
None, description="The environment variables for tool execution specific to this agent."
)
memory_variables: Optional[Dict[str, str]] = Field(None, description="The variables that should be set for the agent.")
project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.")
template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.")
base_template_id: Optional[str] = Field(None, description="The base template id of the agent.")

@field_validator("name")
@classmethod
@@ -210,6 +213,9 @@ class UpdateAgent(BaseModel):
tool_exec_environment_variables: Optional[Dict[str, str]] = Field(
None, description="The environment variables for tool execution specific to this agent."
)
project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.")
template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.")
base_template_id: Optional[str] = Field(None, description="The base template id of the agent.")

class Config:
extra = "ignore" # Ignores extra fields
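
Both CreateAgent and UpdateAgent pick up optional project_id, template_id, and base_template_id fields, so deployment metadata can be attached at creation time or patched later. A small construction sketch — the ids are placeholders, and the other fields are assumed to fall back to their schema defaults:

from letta.schemas.agent import CreateAgent

request = CreateAgent(
    name="support-agent",
    project_id="project-123",              # new optional field
    template_id="template-456",            # new optional field
    base_template_id="base-template-789",  # new optional field
)
print(request.model_dump(exclude_none=True))
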
2 changes: 1 addition & 1 deletion letta/schemas/providers.py
@@ -279,7 +279,7 @@ def list_embedding_models(self) -> List[EmbeddingConfig]:
from letta.llm_api.openai import openai_get_model_list

# For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models'
MODEL_ENDPOINT_URL = f"{self.base_url}/api/v0"
MODEL_ENDPOINT_URL = f"{self.base_url.removesuffix('/v1')}/api/v0"  # strip('/v1') would trim characters, not the suffix
response = openai_get_model_list(MODEL_ENDPOINT_URL)

"""
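
A note on the URL handling above: Python's str.strip('/v1') treats its argument as a set of characters to trim from both ends, not as a substring, so it can silently eat trailing host or port characters; str.removesuffix (Python 3.9+) removes only the literal suffix, which is the intent here. A quick demonstration of the difference:

base_url = "http://localhost:1231/v1"

# strip() trims any of the characters '/', 'v', '1' from both ends:
print(base_url.strip("/v1"))         # http://localhost:123  <- port digit eaten
# removesuffix() removes only the exact trailing substring:
print(base_url.removesuffix("/v1"))  # http://localhost:1231
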