diff --git a/examples/langchain_tool_usage.py b/examples/langchain_tool_usage.py index cf55d12084..3ce4eb3996 100644 --- a/examples/langchain_tool_usage.py +++ b/examples/langchain_tool_usage.py @@ -73,7 +73,7 @@ def main(): print(f"Created agent: {agent_state.name} with ID {str(agent_state.id)}") # Send a message to the agent - send_message_response = client.user_message(agent_id=agent_state.id, message="How do you pronounce Albert Einstein's name?") + send_message_response = client.user_message(agent_id=agent_state.id, message="Tell me a fun fact about Albert Einstein!") for message in send_message_response.messages: response_json = json.dumps(message.model_dump(), indent=4) print(f"{response_json}\n") diff --git a/letta/agent.py b/letta/agent.py index 8f1760f87b..4284e86f4f 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -38,6 +38,7 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage from letta.schemas.openai.chat_completion_response import UsageStatistics +from letta.schemas.sandbox_config import SandboxRunResult from letta.schemas.tool import Tool from letta.schemas.tool_rule import TerminalToolRule from letta.schemas.usage import LettaUsageStatistics @@ -198,7 +199,9 @@ def update_memory_if_changed(self, new_memory: Memory) -> bool: return True return False - def execute_tool_and_persist_state(self, function_name: str, function_args: dict, target_letta_tool: Tool): + def execute_tool_and_persist_state( + self, function_name: str, function_args: dict, target_letta_tool: Tool + ) -> tuple[str, Optional[SandboxRunResult]]: """ Execute tool modifications and persist the state of the agent. 
Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data @@ -242,6 +245,7 @@ def execute_tool_and_persist_state(self, function_name: str, function_args: dict assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" if updated_agent_state is not None: self.update_memory_if_changed(updated_agent_state.memory) + return function_response, sandbox_run_result except Exception as e: # Need to catch error here, or else trunction wont happen # TODO: modify to function execution error @@ -249,7 +253,44 @@ def execute_tool_and_persist_state(self, function_name: str, function_args: dict function_name=function_name, exception_name=type(e).__name__, exception_message=str(e) ) - return function_response + return function_response, None + + def _handle_function_error_response( + self, + error_msg: str, + tool_call_id: str, + function_name: str, + function_response: str, + messages: List[Message], + include_function_failed_message: bool = False, + ) -> List[Message]: + """ + Handle error from function call response + """ + # Update tool rules + self.last_function_response = function_response + self.tool_rules_solver.update_tool_usage(function_name) + + # Extend conversation with function response + function_response = package_function_response(False, error_msg) + new_message = Message.dict_to_message( + agent_id=self.agent_state.id, + user_id=self.agent_state.created_by_id, + model=self.model, + openai_message_dict={ + "role": "tool", + "name": function_name, + "content": function_response, + "tool_call_id": tool_call_id, + }, + ) + messages.append(new_message) + self.interface.function_message(f"Error: {error_msg}", msg_obj=new_message) + if include_function_failed_message: + self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=new_message) + + # Return updated messages + return messages def _get_ai_reply( self, @@ -261,6 +302,7 @@ def 
_get_ai_reply( backoff_factor: float = 0.5, # delay multiplier for exponential backoff max_delay: float = 10.0, # max delay between retries step_count: Optional[int] = None, + last_function_failed: bool = False, ) -> ChatCompletionResponse: """Get response from LLM API with robust retry mechanism.""" @@ -273,6 +315,12 @@ def _get_ai_reply( else [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names] ) + # Don't allow a tool to be called if it failed last time + if last_function_failed and self.tool_rules_solver.last_tool_name: + allowed_functions = [f for f in allowed_functions if f["name"] != self.tool_rules_solver.last_tool_name] + if not allowed_functions: + return None + # For the first message, force the initial tool if one is specified force_tool_call = None if ( @@ -285,6 +333,7 @@ def _get_ai_reply( # Force a tool call if exactly one tool is specified elif step_count is not None and step_count > 0 and len(allowed_tool_names) == 1: force_tool_call = allowed_tool_names[0] + for attempt in range(1, empty_response_retry_limit + 1): try: response = create( @@ -409,21 +458,8 @@ def _handle_ai_response( if not target_letta_tool: error_msg = f"No function named {function_name}" - function_response = package_function_response(False, error_msg) - messages.append( - Message.dict_to_message( - agent_id=self.agent_state.id, - user_id=self.agent_state.created_by_id, - model=self.model, - openai_message_dict={ - "role": "tool", - "name": function_name, - "content": function_response, - "tool_call_id": tool_call_id, - }, - ) - ) # extend conversation with function response - self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1]) + function_response = "None" # more like "never ran?" 
+ messages = self._handle_function_error_response(error_msg, tool_call_id, function_name, function_response, messages) return messages, False, True # force a heartbeat to allow agent to handle error # Failure case 2: function name is OK, but function args are bad JSON @@ -432,21 +468,8 @@ def _handle_ai_response( function_args = parse_json(raw_function_args) except Exception: error_msg = f"Error parsing JSON for function '{function_name}' arguments: {function_call.arguments}" - function_response = package_function_response(False, error_msg) - messages.append( - Message.dict_to_message( - agent_id=self.agent_state.id, - user_id=self.agent_state.created_by_id, - model=self.model, - openai_message_dict={ - "role": "tool", - "name": function_name, - "content": function_response, - "tool_call_id": tool_call_id, - }, - ) - ) # extend conversation with function response - self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1]) + function_response = "None" # more like "never ran?" 
+ messages = self._handle_function_error_response(error_msg, tool_call_id, function_name, function_response, messages) return messages, False, True # force a heartbeat to allow agent to handle error # Check if inner thoughts is in the function call arguments (possible apparently if you are using Azure) @@ -479,7 +502,12 @@ def _handle_ai_response( self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1]) try: # handle tool execution (sandbox) and state updates - function_response = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool) + function_response, sandbox_run_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool) + + if sandbox_run_result and sandbox_run_result.status == "error": + error_msg = f"Error calling function {function_name} with args {function_args}: {sandbox_run_result.stderr}" + messages = self._handle_function_error_response(error_msg, tool_call_id, function_name, function_response, messages) + return messages, False, True # force a heartbeat to allow agent to handle error # handle trunction if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]: @@ -505,45 +533,17 @@ def _handle_ai_response( error_msg = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)) error_msg_user = f"{error_msg}\n{traceback.format_exc()}" self.logger.error(error_msg_user) - function_response = package_function_response(False, error_msg) - self.last_function_response = function_response - # TODO: truncate error message somehow - messages.append( - Message.dict_to_message( - agent_id=self.agent_state.id, - user_id=self.agent_state.created_by_id, - model=self.model, - openai_message_dict={ - "role": "tool", - "name": function_name, - "content": function_response, - "tool_call_id": tool_call_id, - }, - ) - ) # extend conversation with function response - 
self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1]) - self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1]) + messages = self._handle_function_error_response( + error_msg, tool_call_id, function_name, function_response, messages, include_function_failed_message=True + ) return messages, False, True # force a heartbeat to allow agent to handle error # Step 4: check if function response is an error if function_response_string.startswith(ERROR_MESSAGE_PREFIX): - function_response = package_function_response(False, function_response_string) - # TODO: truncate error message somehow - messages.append( - Message.dict_to_message( - agent_id=self.agent_state.id, - user_id=self.agent_state.created_by_id, - model=self.model, - openai_message_dict={ - "role": "tool", - "name": function_name, - "content": function_response, - "tool_call_id": tool_call_id, - }, - ) - ) # extend conversation with function response - self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1]) - self.interface.function_message(f"Error: {function_response_string}", msg_obj=messages[-1]) + error_msg = function_response_string + messages = self._handle_function_error_response( + error_msg, tool_call_id, function_name, function_response, messages, include_function_failed_message=True + ) return messages, False, True # force a heartbeat to allow agent to handle error # If no failures happened along the way: ... 
@@ -607,9 +607,11 @@ def step( counter = 0 total_usage = UsageStatistics() step_count = 0 + function_failed = False while True: kwargs["first_message"] = False kwargs["step_count"] = step_count + kwargs["last_function_failed"] = function_failed step_response = self.inner_step( messages=next_input_message, **kwargs, @@ -689,6 +691,7 @@ def inner_step( step_count: Optional[int] = None, metadata: Optional[dict] = None, summarize_attempt_count: int = 0, + last_function_failed: bool = False, ) -> AgentStepResponse: """Runs a single step in the agent loop (generates at most one LLM call)""" @@ -723,7 +726,17 @@ def inner_step( first_message=first_message, stream=stream, step_count=step_count, + last_function_failed=last_function_failed, ) + if not response: + # EDGE CASE: Function call failed AND there's no tools left for agent to call -> return early + return AgentStepResponse( + messages=input_message_sequence, + heartbeat_request=False, + function_failed=False, # NOTE: this is different from other function fails. 
We force to return early + in_context_memory_warning=False, + usage=UsageStatistics(), + ) # Step 3: check if LLM wanted to call a function # (if yes) Step 4: call the function diff --git a/letta/client/client.py b/letta/client/client.py index a2b62fa7cf..413e6b64c6 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -2950,18 +2950,11 @@ def load_langchain_tool(self, langchain_tool: "LangChainBaseTool", additional_im langchain_tool=langchain_tool, additional_imports_module_attr_map=additional_imports_module_attr_map, ) - return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user) - - def load_crewai_tool(self, crewai_tool: "CrewAIBaseTool", additional_imports_module_attr_map: dict[str, str] = None) -> Tool: - tool_create = ToolCreate.from_crewai( - crewai_tool=crewai_tool, - additional_imports_module_attr_map=additional_imports_module_attr_map, - ) - return self.server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user) + return self.server.tool_manager.create_or_update_langchain_tool(tool_create=tool_create, actor=self.user) def load_composio_tool(self, action: "ActionType") -> Tool: tool_create = ToolCreate.from_composio(action_name=action.name) - return self.server.tool_manager.create_or_update_composio_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=self.user) + return self.server.tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=self.user) def create_tool( self, diff --git a/letta/functions/helpers.py b/letta/functions/helpers.py index bd2643c493..8c232cd587 100644 --- a/letta/functions/helpers.py +++ b/letta/functions/helpers.py @@ -230,9 +230,7 @@ def generate_imported_tool_instantiation_call_str(obj: Any) -> Optional[str]: def is_base_model(obj: Any): - from langchain_core.pydantic_v1 import BaseModel as LangChainBaseModel - - return isinstance(obj, BaseModel) or isinstance(obj, LangChainBaseModel) + 
return isinstance(obj, BaseModel) def generate_import_code(module_attr_map: Optional[dict]): diff --git a/letta/local_llm/utils.py b/letta/local_llm/utils.py index f5d5417401..21be45c5c3 100644 --- a/letta/local_llm/utils.py +++ b/letta/local_llm/utils.py @@ -11,8 +11,11 @@ import letta.local_llm.llm_chat_completion_wrappers.dolphin as dolphin import letta.local_llm.llm_chat_completion_wrappers.llama3 as llama3 import letta.local_llm.llm_chat_completion_wrappers.zephyr as zephyr +from letta.log import get_logger from letta.schemas.openai.chat_completion_request import Tool, ToolCall +logger = get_logger(__name__) + def post_json_auth_request(uri, json_payload, auth_type, auth_key): """Send a POST request with a JSON payload and optional authentication""" @@ -126,8 +129,11 @@ def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"): function_tokens += 2 if isinstance(v["items"], dict) and "type" in v["items"]: function_tokens += len(encoding.encode(v["items"]["type"])) + elif field == "default": + function_tokens += 2 + function_tokens += len(encoding.encode(str(v["default"]))) else: - warnings.warn(f"num_tokens_from_functions: Unsupported field {field} in function {function}") + logger.warning(f"num_tokens_from_functions: Unsupported field {field} in function {function}") function_tokens += 11 num_tokens += function_tokens diff --git a/letta/orm/agent.py b/letta/orm/agent.py index 39db57d235..a4d08f719c 100644 --- a/letta/orm/agent.py +++ b/letta/orm/agent.py @@ -4,7 +4,6 @@ from sqlalchemy import JSON, Index, String from sqlalchemy.orm import Mapped, mapped_column, relationship -from letta.constants import MULTI_AGENT_TOOLS from letta.orm.block import Block from letta.orm.custom_columns import EmbeddingConfigColumn, LLMConfigColumn, ToolRulesColumn from letta.orm.message import Message @@ -121,12 +120,7 @@ def to_pydantic(self) -> PydanticAgentState: # add default rule for having send_message be a terminal tool tool_rules = self.tool_rules if not 
tool_rules: - tool_rules = [ - TerminalToolRule(tool_name="send_message"), - ] - - for tool_name in MULTI_AGENT_TOOLS: - tool_rules.append(TerminalToolRule(tool_name=tool_name)) + tool_rules = [TerminalToolRule(tool_name="send_message"), TerminalToolRule(tool_name="send_message_to_agent_async")] state = { "id": self.id, diff --git a/letta/orm/enums.py b/letta/orm/enums.py index e87d28d233..d3aac7ab69 100644 --- a/letta/orm/enums.py +++ b/letta/orm/enums.py @@ -7,6 +7,7 @@ class ToolType(str, Enum): LETTA_MEMORY_CORE = "letta_memory_core" LETTA_MULTI_AGENT_CORE = "letta_multi_agent_core" EXTERNAL_COMPOSIO = "external_composio" + EXTERNAL_LANGCHAIN = "external_langchain" class JobType(str, Enum): diff --git a/letta/schemas/agent.py b/letta/schemas/agent.py index ec9880d425..9269742deb 100644 --- a/letta/schemas/agent.py +++ b/letta/schemas/agent.py @@ -143,6 +143,9 @@ class CreateAgent(BaseModel, validate_assignment=True): # None, description="The environment variables for tool execution specific to this agent." ) memory_variables: Optional[Dict[str, str]] = Field(None, description="The variables that should be set for the agent.") + project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.") + template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.") + base_template_id: Optional[str] = Field(None, description="The base template id of the agent.") @field_validator("name") @classmethod @@ -210,6 +213,9 @@ class UpdateAgent(BaseModel): tool_exec_environment_variables: Optional[Dict[str, str]] = Field( None, description="The environment variables for tool execution specific to this agent." 
) + project_id: Optional[str] = Field(None, description="The id of the project the agent belongs to.") + template_id: Optional[str] = Field(None, description="The id of the template the agent belongs to.") + base_template_id: Optional[str] = Field(None, description="The base template id of the agent.") class Config: extra = "ignore" # Ignores extra fields diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index 8d38ad4c38..1fad894368 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -279,7 +279,7 @@ def list_embedding_models(self) -> List[EmbeddingConfig]: from letta.llm_api.openai import openai_get_model_list # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models' - MODEL_ENDPOINT_URL = f"{self.base_url}/api/v0" + MODEL_ENDPOINT_URL = f"{self.base_url.strip('/v1')}/api/v0" response = openai_get_model_list(MODEL_ENDPOINT_URL) """ diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index ab4736e605..f17498c1f7 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -79,7 +79,7 @@ def refresh_source_code_and_json_schema(self): self.json_schema = get_json_schema_from_module(module_name=LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name=self.name) elif self.tool_type == ToolType.EXTERNAL_COMPOSIO: # If it is a composio tool, we generate both the source code and json schema on the fly here - # TODO: This is brittle, need to think long term about how to improve this + # TODO: Deriving the composio action name is brittle, need to think long term about how to improve this try: composio_action = generate_composio_action_from_func_name(self.name) tool_create = ToolCreate.from_composio(composio_action) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 458e8fe450..9a5661a69c 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -47,6 +47,9 @@ def list_agents( after: Optional[str] = 
Query(None, description="Cursor for pagination"), limit: Optional[int] = Query(None, description="Limit for pagination"), query_text: Optional[str] = Query(None, description="Search agents by name"), + project_id: Optional[str] = Query(None, description="Search agents by project id"), + template_id: Optional[str] = Query(None, description="Search agents by template id"), + base_template_id: Optional[str] = Query(None, description="Search agents by base template id"), ): """ List all agents associated with a given user. @@ -58,16 +61,25 @@ def list_agents( kwargs = { key: value for key, value in { - "tags": tags, - "match_all_tags": match_all_tags, "name": name, - "query_text": query_text, + "project_id": project_id, + "template_id": template_id, + "base_template_id": base_template_id, }.items() if value is not None } # Call list_agents with the dynamic kwargs - agents = server.agent_manager.list_agents(actor=actor, before=before, after=after, limit=limit, **kwargs) + agents = server.agent_manager.list_agents( + actor=actor, + before=before, + after=after, + limit=limit, + query_text=query_text, + tags=tags, + match_all_tags=match_all_tags, + **kwargs, + ) return agents diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index 65a403c836..912503fe64 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -66,7 +66,7 @@ def list_tools( try: actor = server.user_manager.get_user_or_default(user_id=user_id) if name is not None: - tool = server.tool_manager.get_tool_by_name(name=name, actor=actor) + tool = server.tool_manager.get_tool_by_name(tool_name=name, actor=actor) return [tool] if tool else [] return server.tool_manager.list_tools(actor=actor, after=after, limit=limit) except Exception as e: @@ -231,7 +231,7 @@ def add_composio_tool( try: tool_create = ToolCreate.from_composio(action_name=composio_action_name) - return 
server.tool_manager.create_or_update_composio_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=actor) + return server.tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=actor) except EnumStringNotFound as e: raise HTTPException( status_code=400, # Bad Request diff --git a/letta/server/server.py b/letta/server/server.py index b8578984bd..4a02b74ee5 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -189,7 +189,7 @@ def db_error_handler(): if settings.letta_pg_uri_no_default: - print("Creating postgres engine", settings.letta_pg_uri) + print("Creating postgres engine") config.recall_storage_type = "postgres" config.recall_storage_uri = settings.letta_pg_uri_no_default config.archival_storage_type = "postgres" @@ -1000,8 +1000,8 @@ def load_data( return passage_count, document_count def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]: - warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning) - return [] + # TODO: move this query into PassageManager + return self.agent_manager.list_passages(actor=self.user_manager.get_user_or_default(user_id=user_id), source_id=source_id) def list_all_sources(self, actor: User) -> List[Source]: """List all sources (w/ extra metadata) belonging to a user""" diff --git a/letta/server/startup.sh b/letta/server/startup.sh index 2e9d7c301d..d4523cce9a 100755 --- a/letta/server/startup.sh +++ b/letta/server/startup.sh @@ -14,7 +14,7 @@ wait_for_postgres() { # Check if we're configured for external Postgres if [ -n "$LETTA_PG_URI" ]; then - echo "External Postgres configuration detected, using $LETTA_PG_URI" + echo "External Postgres configuration detected, using env var LETTA_PG_URI" else echo "No external Postgres configuration detected, starting internal PostgreSQL..." 
# Start PostgreSQL using the base image's entrypoint script diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index f4aa472657..3faf50fc6f 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -120,6 +120,9 @@ def create_agent( metadata=agent_create.metadata, tool_rules=agent_create.tool_rules, actor=actor, + project_id=agent_create.project_id, + template_id=agent_create.template_id, + base_template_id=agent_create.base_template_id, ) # If there are provided environment variables, add them in @@ -179,6 +182,9 @@ def _create_agent( description: Optional[str] = None, metadata: Optional[Dict] = None, tool_rules: Optional[List[PydanticToolRule]] = None, + project_id: Optional[str] = None, + template_id: Optional[str] = None, + base_template_id: Optional[str] = None, ) -> PydanticAgentState: """Create a new agent.""" with self.session_maker() as session: @@ -193,6 +199,9 @@ def _create_agent( "description": description, "metadata_": metadata, "tool_rules": tool_rules, + "project_id": project_id, + "template_id": template_id, + "base_template_id": base_template_id, } # Create the new agent using SqlalchemyBase.create @@ -242,7 +251,19 @@ def _update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: Pydanti agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor) # Update scalar fields directly - scalar_fields = {"name", "system", "llm_config", "embedding_config", "message_ids", "tool_rules", "description", "metadata"} + scalar_fields = { + "name", + "system", + "llm_config", + "embedding_config", + "message_ids", + "tool_rules", + "description", + "metadata", + "project_id", + "template_id", + "base_template_id", + } for field in scalar_fields: value = getattr(agent_update, field, None) if value is not None: diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 5a2f0f5d6f..7bcf1bc341 100644 --- a/letta/services/passage_manager.py +++ 
b/letta/services/passage_manager.py @@ -1,6 +1,8 @@ from datetime import datetime from typing import List, Optional +from openai import OpenAI + from letta.embeddings import embedding_model, parse_and_chunk_text from letta.orm.errors import NoResultFound from letta.orm.passage import AgentPassage, SourcePassage @@ -86,14 +88,31 @@ def insert_passage( """Insert passage(s) into archival memory""" embedding_chunk_size = agent_state.embedding_config.embedding_chunk_size - embed_model = embedding_model(agent_state.embedding_config) + + # TODO eventually migrate off of llama-index for embeddings? + # Already causing pain for OpenAI proxy endpoints like LM Studio... + if agent_state.embedding_config.embedding_endpoint_type != "openai": + embed_model = embedding_model(agent_state.embedding_config) passages = [] try: # breakup string into passages for text in parse_and_chunk_text(text, embedding_chunk_size): - embedding = embed_model.get_text_embedding(text) + + if agent_state.embedding_config.embedding_endpoint_type != "openai": + embedding = embed_model.get_text_embedding(text) + else: + # TODO should have the settings passed in via the server call + from letta.settings import model_settings + + # Simple OpenAI client code + client = OpenAI( + api_key=model_settings.openai_api_key, base_url=agent_state.embedding_config.embedding_endpoint, max_retries=0 + ) + response = client.embeddings.create(input=text, model=agent_state.embedding_config.embedding_model) + embedding = response.data[0].embedding + if isinstance(embedding, dict): try: embedding = embedding["data"][0]["embedding"] diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index 5e6fbce5f3..49dbf316ba 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -1,4 +1,4 @@ -import datetime +from datetime import datetime from typing import List, Literal, Optional from sqlalchemy import select diff --git a/letta/services/tool_manager.py 
b/letta/services/tool_manager.py index 0aa0820666..211b3e33f3 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -11,7 +11,7 @@ from letta.orm.errors import NoResultFound from letta.orm.tool import Tool as ToolModel from letta.schemas.tool import Tool as PydanticTool -from letta.schemas.tool import ToolUpdate +from letta.schemas.tool import ToolCreate, ToolUpdate from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types, printd @@ -57,9 +57,16 @@ def create_or_update_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser return tool @enforce_types - def create_or_update_composio_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: - pydantic_tool.tool_type = ToolType.EXTERNAL_COMPOSIO - return self.create_or_update_tool(pydantic_tool, actor) + def create_or_update_composio_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool: + return self.create_or_update_tool( + PydanticTool(tool_type=ToolType.EXTERNAL_COMPOSIO, name=tool_create.json_schema["name"], **tool_create.model_dump()), actor + ) + + @enforce_types + def create_or_update_langchain_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool: + return self.create_or_update_tool( + PydanticTool(tool_type=ToolType.EXTERNAL_LANGCHAIN, name=tool_create.json_schema["name"], **tool_create.model_dump()), actor + ) @enforce_types def create_tool(self, pydantic_tool: PydanticTool, actor: PydanticUser) -> PydanticTool: diff --git a/tests/integration_test_agent_tool_graph.py b/tests/integration_test_agent_tool_graph.py index 3a24e29bd0..025f751b48 100644 --- a/tests/integration_test_agent_tool_graph.py +++ b/tests/integration_test_agent_tool_graph.py @@ -659,3 +659,64 @@ def test_simple_tool_rule(mock_e2b_api_key_none): assert tool_calls[flip_coin_call_index + 1].tool_call.name == secret_word, "Fourth secret word should be called after flip_coin" cleanup(client, agent_uuid=agent_state.id) + + 
+def test_init_tool_rule_always_fails_one_tool(): + """ + Test an init tool rule that always fails when called. The agent has only one tool available. + + Once that tool fails and the agent removes that tool, the agent should have 0 tools available. + + This means that the agent should return from `step` early. + """ + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + + # Create tools + bad_tool = client.create_or_update_tool(auto_error) + + # Create tool rule: InitToolRule + tool_rule = InitToolRule( + tool_name=bad_tool.name, + ) + + # Set up agent with the tool rule + claude_config = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json" + agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=False) + + # Start conversation + response = client.user_message(agent_id=agent_state.id, message="blah blah blah") + + # Verify the tool calls + tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)] + assert len(tool_calls) >= 1 # Should have at least flip_coin and fourth_secret_word calls + assert_invoked_function_call(response.messages, bad_tool.name) + + +def test_init_tool_rule_always_fails_multiple_tools(): + """ + Test an init tool rule that always fails when called. The agent has only 1+ tools available. + Once that tool fails and the agent removes that tool, the agent should have other tools available. 
+ """ + client = create_client() + cleanup(client=client, agent_uuid=agent_uuid) + + # Create tools + bad_tool = client.create_or_update_tool(auto_error) + + # Create tool rule: InitToolRule + tool_rule = InitToolRule( + tool_name=bad_tool.name, + ) + + # Set up agent with the tool rule + claude_config = "tests/configs/llm_model_configs/claude-3-sonnet-20240229.json" + agent_state = setup_agent(client, claude_config, agent_uuid, tool_rules=[tool_rule], tool_ids=[bad_tool.id], include_base_tools=True) + + # Start conversation + response = client.user_message(agent_id=agent_state.id, message="blah blah blah") + + # Verify the tool calls + tool_calls = [msg for msg in response.messages if isinstance(msg, ToolCallMessage)] + assert len(tool_calls) >= 1 # Should have at least flip_coin and fourth_secret_word calls + assert_invoked_function_call(response.messages, bad_tool.name) diff --git a/tests/integration_test_tool_execution_sandbox.py b/tests/integration_test_tool_execution_sandbox.py index 8418ac4798..ea3e6473da 100644 --- a/tests/integration_test_tool_execution_sandbox.py +++ b/tests/integration_test_tool_execution_sandbox.py @@ -190,7 +190,7 @@ def create_list(): def composio_github_star_tool(test_user): tool_manager = ToolManager() tool_create = ToolCreate.from_composio(action_name="GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER") - tool = tool_manager.create_or_update_composio_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=test_user) + tool = tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=test_user) yield tool @@ -198,7 +198,7 @@ def composio_github_star_tool(test_user): def composio_gmail_get_profile_tool(test_user): tool_manager = ToolManager() tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE") - tool = tool_manager.create_or_update_composio_tool(pydantic_tool=Tool(**tool_create.model_dump()), actor=test_user) + tool = tool_manager.create_or_update_composio_tool(tool_create=tool_create, 
actor=test_user) yield tool diff --git a/tests/test_client.py b/tests/test_client.py index 721a293fb7..c9cfae4a02 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -480,11 +480,14 @@ def always_error(): assert response_message.status == "error" if isinstance(client, RESTClient): - assert response_message.tool_return == "Error executing function always_error: ZeroDivisionError: division by zero" + assert ( + response_message.tool_return.startswith("Error calling function always_error") + and "ZeroDivisionError" in response_message.tool_return + ) else: response_json = json.loads(response_message.tool_return) assert response_json["status"] == "Failed" - assert response_json["message"] == "Error executing function always_error: ZeroDivisionError: division by zero" + assert "Error calling function always_error" in response_json["message"] and "ZeroDivisionError" in response_json["message"] client.delete_agent(agent_id=agent.id) diff --git a/tests/test_managers.py b/tests/test_managers.py index 0b6d629b95..a4d8adce9d 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -198,7 +198,7 @@ def print_tool(message: str): @pytest.fixture def composio_github_star_tool(server, default_user): tool_create = ToolCreate.from_composio(action_name="GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER") - tool = server.tool_manager.create_or_update_composio_tool(pydantic_tool=PydanticTool(**tool_create.model_dump()), actor=default_user) + tool = server.tool_manager.create_or_update_composio_tool(tool_create=tool_create, actor=default_user) yield tool diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index 5d56904156..f01f431ecb 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -440,7 +440,12 @@ def always_error(): assert response_message, "ToolReturnMessage message not found in response" assert response_message.status == "error" - assert response_message.tool_return == "Error executing function always_error: 
@pytest.mark.parametrize("openai_model", ["gpt-4o-mini"])
@pytest.mark.parametrize("structured_output", [True])
def test_langchain_tool_schema_generation(openai_model: str, structured_output: bool):
    """Test that we can generate the schemas for some Langchain tools.

    Wraps the Langchain Wikipedia query tool in a ``ToolCreate``, checks that a
    JSON schema was produced, and then passes that schema to
    ``_openai_payload`` together with the model name and structured-output flag.

    Args:
        openai_model: Name of the OpenAI model passed to ``_openai_payload``.
        structured_output: Flag forwarded to ``_openai_payload``.

    Raises:
        Exception: Whatever ``_openai_payload`` raises is re-raised after the
            failure message is printed, so the test still fails.
    """
    # Imported locally so the module can be collected without langchain installed.
    from langchain_community.tools import WikipediaQueryRun
    from langchain_community.utilities import WikipediaAPIWrapper

    api_wrapper = WikipediaAPIWrapper(top_k_results=1, doc_content_chars_max=500)
    langchain_tool = WikipediaQueryRun(api_wrapper=api_wrapper)

    tool_create = ToolCreate.from_langchain(
        langchain_tool=langchain_tool,
        additional_imports_module_attr_map={"langchain_community.utilities": "WikipediaAPIWrapper"},
    )

    assert tool_create.json_schema
    schema = tool_create.json_schema
    print(f"The schema for {langchain_tool.name}: {json.dumps(schema, indent=4)}\n\n")

    try:
        _openai_payload(openai_model, schema, structured_output)
    except Exception:
        # Only real errors are reported as OpenAI failures; KeyboardInterrupt and
        # SystemExit propagate untouched instead of being mislabelled.
        print(f"Failed to call OpenAI using schema {schema} generated from {langchain_tool.name}\n\n")
        raise
    else:
        print(f"Successfully called OpenAI using schema {schema} generated from {langchain_tool.name}\n\n")