Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Sandboxing for tool execution #2040

Open
wants to merge 30 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
30 commits
Select commit Hold shift + click to select a range
b41a659
Enable basic sandboxed tool run functionality
Oct 24, 2024
2233cce
Clean up and add helpers
Oct 25, 2024
ea8f140
Modularize and handle more argument types
Oct 26, 2024
1936a53
Add unit tests and minor fixes
Oct 27, 2024
a57d2d5
Basic api key handling
Oct 27, 2024
5a2fb5c
Add auto install support and unit test for composio
Oct 28, 2024
8b66e50
Merge main
mattzh72 Nov 13, 2024
ceab4b9
Install e2b and add to poetry
mattzh72 Nov 13, 2024
a06566e
Finish sandbox manager
mattzh72 Nov 14, 2024
7f8f6dd
Finish multi sandbox environment
mattzh72 Nov 14, 2024
687c3e8
Create migration scripts
mattzh72 Nov 14, 2024
fccb899
Rename local sandbox type
mattzh72 Nov 14, 2024
6dfab00
Fix tests
mattzh72 Nov 14, 2024
2b1b229
Write unit test suite for sandboxing
mattzh72 Nov 15, 2024
8db71de
Ignore venvs
mattzh72 Nov 15, 2024
2aaf85a
Ignore venvs
mattzh72 Nov 15, 2024
a0c05fb
remove test venv
mattzh72 Nov 15, 2024
0d8f921
Add venv creation dynamically to tests
mattzh72 Nov 15, 2024
4f62c94
Add sandbox_config_id to env vars and update alembic migration
mattzh72 Nov 15, 2024
845e550
Add tool execution sandbox tests to CI
mattzh72 Nov 15, 2024
3e93874
add routes and some supplementary safety logic like state hashing
mattzh72 Nov 15, 2024
64ca45e
Fix tool sandbox tests
mattzh72 Nov 15, 2024
a5db371
download cloud tool sandbox for CI tests
mattzh72 Nov 15, 2024
ac96871
download cloud tool sandbox for more CI tests
mattzh72 Nov 15, 2024
3cf97b7
Finish unit testing as well
mattzh72 Nov 15, 2024
0d59f84
Add extra integration e2b tests
mattzh72 Nov 15, 2024
5e120ac
Try again
mattzh72 Nov 15, 2024
01ac30a
Remove test tools
mattzh72 Nov 15, 2024
0bd12e8
Disable tool graph tests
mattzh72 Nov 15, 2024
91cb3c4
Merge main
mattzh72 Nov 16, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/docker-integration-tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ jobs:
run: |
pipx install poetry==1.8.2
poetry install -E dev -E postgres
poetry run pytest -s tests/test_client.py
poetry run pytest -s tests/test_client_legacy.py
- name: Print docker logs if tests fail
if: failure()
Expand Down
12 changes: 7 additions & 5 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ env:
ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
GEMINI_API_KEY: ${{ secrets.GEMINI_API_KEY }}
GROQ_API_KEY: ${{ secrets.GROQ_API_KEY }}
E2B_API_KEY: ${{ secrets.E2B_API_KEY }}

on:
push:
Expand All @@ -21,14 +22,15 @@ jobs:
fail-fast: false
matrix:
test_suite:
- "test_local_client.py"
- "test_client.py"
- "test_local_client.py"
- "test_client_legacy.py"
- "test_server.py"
- "test_managers.py"
- "test_tools.py"
- "test_o1_agent.py"
- "test_tool_rule_solver.py"
- "test_agent_tool_graph.py"
- "test_tool_execution_sandbox.py"
services:
qdrant:
image: qdrant/qdrant
Expand Down Expand Up @@ -57,7 +59,7 @@ jobs:
with:
python-version: "3.12"
poetry-version: "1.8.2"
install-args: "-E dev -E postgres -E milvus -E external-tools -E tests"
install-args: "-E dev -E postgres -E milvus -E external-tools -E tests -E cloud-tool-sandbox"
- name: Migrate database
env:
LETTA_PG_PORT: 5432
Expand Down Expand Up @@ -110,7 +112,7 @@ jobs:
with:
python-version: "3.12"
poetry-version: "1.8.2"
install-args: "-E dev -E postgres -E milvus -E external-tools -E tests"
install-args: "-E dev -E postgres -E milvus -E external-tools -E tests -E cloud-tool-sandbox"
- name: Migrate database
env:
LETTA_PG_PORT: 5432
Expand All @@ -131,4 +133,4 @@ jobs:
LETTA_SERVER_PASS: test_server_token
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
run: |
poetry run pytest -s -vv -k "not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_perfomance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests
poetry run pytest -s -vv -k "not test_client.py and not test_tool_execution_sandbox.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_performance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client_legacy.py" tests
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -1019,7 +1019,6 @@ letta/.pytest_cache/
memgpy/pytest.ini
**/**/pytest_cache


# local sandbox venvs
letta/services/tool_sandbox_env/*
tests/test_tool_sandbox/*
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
"""Create sandbox config and sandbox env var tables

Revision ID: f81ceea2c08d
Revises: c85a3d07c028
Create Date: 2024-11-14 17:51:27.263561

"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "f81ceea2c08d"
down_revision: Union[str, None] = "c85a3d07c028"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.create_table(
"sandbox_configs",
sa.Column("id", sa.String(), nullable=False),
sa.Column("type", sa.Enum("E2B", "LOCAL", name="sandboxtype"), nullable=False),
sa.Column("config", sa.JSON(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("type", "organization_id", name="uix_type_organization"),
)
op.create_table(
"sandbox_environment_variables",
sa.Column("id", sa.String(), nullable=False),
sa.Column("key", sa.String(), nullable=False),
sa.Column("value", sa.String(), nullable=False),
sa.Column("description", sa.String(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
sa.Column("_created_by_id", sa.String(), nullable=True),
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
sa.Column("organization_id", sa.String(), nullable=False),
sa.Column("sandbox_config_id", sa.String(), nullable=False),
sa.ForeignKeyConstraint(
["organization_id"],
["organizations.id"],
),
sa.ForeignKeyConstraint(
["sandbox_config_id"],
["sandbox_configs.id"],
),
sa.PrimaryKeyConstraint("id"),
sa.UniqueConstraint("key", "sandbox_config_id", name="uix_key_sandbox_config"),
)
# ### end Alembic commands ###


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table("sandbox_environment_variables")
op.drop_table("sandbox_configs")
# ### end Alembic commands ###
2 changes: 1 addition & 1 deletion examples/docs/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@


# define a function with a docstring
def roll_d20(self) -> str:
def roll_d20() -> str:
"""
Simulate the roll of a 20-sided die (d20).

Expand Down
10 changes: 5 additions & 5 deletions examples/tool_rule_usage.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,14 +31,14 @@
"""Contrived tools for this test case"""


def first_secret_word(self: "Agent"):
def first_secret_word():
"""
Call this to retrieve the first secret word, which you will need for the second_secret_word function.
"""
return "v0iq020i0g"


def second_secret_word(self: "Agent", prev_secret_word: str):
def second_secret_word(prev_secret_word: str):
"""
Call this to retrieve the second secret word, which you will need for the third_secret_word function. If you get the word wrong, this function will error.

Expand All @@ -51,7 +51,7 @@ def second_secret_word(self: "Agent", prev_secret_word: str):
return "4rwp2b4gxq"


def third_secret_word(self: "Agent", prev_secret_word: str):
def third_secret_word(prev_secret_word: str):
"""
Call this to retrieve the third secret word, which you will need for the fourth_secret_word function. If you get the word wrong, this function will error.

Expand All @@ -64,7 +64,7 @@ def third_secret_word(self: "Agent", prev_secret_word: str):
return "hj2hwibbqm"


def fourth_secret_word(self: "Agent", prev_secret_word: str):
def fourth_secret_word(prev_secret_word: str):
"""
Call this to retrieve the last secret word, which you will need to output in a send_message later. If you get the word wrong, this function will error.

Expand All @@ -77,7 +77,7 @@ def fourth_secret_word(self: "Agent", prev_secret_word: str):
return "banana"


def auto_error(self: "Agent"):
def auto_error():
"""
If you call this function, it will throw an error automatically.
"""
Expand Down
13 changes: 11 additions & 2 deletions letta/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from letta.agent_store.storage import StorageConnector
from letta.constants import (
BASE_CORE_MEMORY_TOOLS,
BASE_TOOLS,
CLI_WARNING_PREFIX,
FIRST_MESSAGE_ATTEMPTS,
FUNC_FAILED_HEARTBEAT_MESSAGE,
Expand Down Expand Up @@ -47,6 +49,7 @@
from letta.schemas.tool_rule import TerminalToolRule
from letta.schemas.usage import LettaUsageStatistics
from letta.services.source_manager import SourceManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.user_manager import UserManager
from letta.streaming_interface import StreamingRefreshCLIInterface
from letta.system import (
Expand Down Expand Up @@ -720,9 +723,15 @@ def _handle_ai_response(
if isinstance(function_args[name], dict):
function_args[name] = spec[name](**function_args[name])

function_args["self"] = self # need to attach self to arg since it's dynamically linked
# TODO: This needs to be rethought, how do we allow functions that modify agent state/db?
# TODO: There should probably be two types of tools: stateless/stateful

if function_name in BASE_TOOLS or function_name in BASE_CORE_MEMORY_TOOLS:
function_args["self"] = self # need to attach self to arg since it's dynamically linked
function_response = function_to_call(**function_args)
else:
function_response = ToolExecutionSandbox(function_name, function_args, self.agent_state.user_id).run()

function_response = function_to_call(**function_args)
if function_name in ["conversation_search", "conversation_search_date", "archival_memory_search"]:
# with certain functions we rely on the paging mechanism to handle overflow
truncate = False
Expand Down
7 changes: 4 additions & 3 deletions letta/agent_store/db.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,9 +380,10 @@ def __init__(self, table_type: str, config: LettaConfig, user_id, agent_id=None)
else:
raise ValueError(f"Table type {table_type} not implemented")

for c in self.db_model.__table__.columns:
if c.name == "embedding":
assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"
if settings.pg_uri:
for c in self.db_model.__table__.columns:
if c.name == "embedding":
assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"

from letta.server.server import db_context

Expand Down
Loading
Loading