Skip to content

Commit

Permalink
Add tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mattzh72 committed Nov 7, 2024
1 parent a61a507 commit d6f7c86
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 3 deletions.
14 changes: 11 additions & 3 deletions letta/server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -379,6 +379,9 @@ def _load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterfac
except NoResultFound:
warnings.warn(f"Tried to retrieve a tool with name {name} from the agent_state, but does not exist in tool db.")

# set agent_state tools to only the names of the available tools
agent_state.tools = [t.name for t in tool_objs]

# Make sure the memory is a memory object
assert isinstance(agent_state.memory, Memory)

Expand Down Expand Up @@ -807,12 +810,17 @@ def create_agent(
llm_config = request.llm_config
embedding_config = request.embedding_config

# get tools + make sure they exist
# get tools + only add if they exist
tool_objs = []
if request.tools:
for tool_name in request.tools:
tool_obj = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
tool_objs.append(tool_obj)
try:
tool_obj = self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=actor)
tool_objs.append(tool_obj)
except NoResultFound:
warnings.warn(f"Attempted to add a nonexistent tool {tool_name} to agent {request.name}, skipping.")
# reset the request.tools to only valid tools
request.tools = [t.name for t in tool_objs]

assert request.memory is not None
memory_functions = get_memory_functions(request.memory)
Expand Down
30 changes: 30 additions & 0 deletions tests/test_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -541,3 +541,33 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id:
+ overview.num_tokens_functions_definitions
+ overview.num_tokens_external_memory_summary
)


def test_load_agent_with_nonexistent_tool_names_does_not_error(server: SyncServer, user_id: str):
fake_tool_name = "blahblahblah"
tools = BASE_TOOLS + [fake_tool_name]
agent_state = server.create_agent(
request=CreateAgent(
name="nonexistent_tools_agent",
tools=tools,
memory=ChatMemory(
human="Sarah",
persona="I am a helpful assistant",
),
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
actor=server.get_user_or_default(user_id),
)

# Check that the tools in agent_state do NOT include the fake name
assert fake_tool_name not in agent_state.tools
assert set(BASE_TOOLS).issubset(set(agent_state.tools))

# Load the agent from the database and check that it doesn't error / tools are correct
saved_tools = server.get_tools_from_agent(agent_id=agent_state.id, user_id=user_id)
assert fake_tool_name not in agent_state.tools
assert set(BASE_TOOLS).issubset(set(agent_state.tools))

# cleanup
server.delete_agent(user_id, agent_state.id)

0 comments on commit d6f7c86

Please sign in to comment.