feat: add vertex support #2429

Merged: 8 commits, Feb 13, 2025
@@ -0,0 +1,36 @@
"""Add message_buffer_autoclear option for AgentState

Revision ID: 7980d239ea08
Revises: dfafcf8210ca
Create Date: 2025-02-12 14:02:00.918226

"""

from typing import Sequence, Union

import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision: str = "7980d239ea08"
down_revision: Union[str, None] = "dfafcf8210ca"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None


def upgrade() -> None:
# Add the column with a temporary nullable=True so we can backfill
op.add_column("agents", sa.Column("message_buffer_autoclear", sa.Boolean(), nullable=True))

# Backfill existing rows to set message_buffer_autoclear to False where it's NULL
op.execute("UPDATE agents SET message_buffer_autoclear = false WHERE message_buffer_autoclear IS NULL")

# Now, enforce nullable=False after backfilling
op.alter_column("agents", "message_buffer_autoclear", nullable=False)


def downgrade() -> None:
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column("agents", "message_buffer_autoclear")
# ### end Alembic commands ###
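Reviewer note: the migration uses the three-step pattern (add as nullable, backfill, tighten to NOT NULL) so existing rows never violate the constraint. A one-step alternative sketch, assuming no backfill-specific logic is needed, uses a server-side default instead:

def upgrade() -> None:
    # server_default backfills existing rows in the same DDL statement
    op.add_column(
        "agents",
        sa.Column("message_buffer_autoclear", sa.Boolean(), nullable=False, server_default=sa.false()),
    )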
2 changes: 1 addition & 1 deletion letta/__init__.py
@@ -1,4 +1,4 @@
__version__ = "0.6.24"
__version__ = "0.6.25"

# import clients
from letta.client.client import LocalClient, RESTClient, create_client
22 changes: 19 additions & 3 deletions letta/agent.py
@@ -61,6 +61,7 @@
get_utc_time,
json_dumps,
json_loads,
log_telemetry,
parse_json,
printd,
validate_function_response,
@@ -306,7 +307,7 @@ def _get_ai_reply(
last_function_failed: bool = False,
) -> ChatCompletionResponse:
"""Get response from LLM API with robust retry mechanism."""

log_telemetry(self.logger, "_get_ai_reply start")
allowed_tool_names = self.tool_rules_solver.get_allowed_tool_names(last_function_response=self.last_function_response)
agent_state_tool_jsons = [t.json_schema for t in self.agent_state.tools]

@@ -337,6 +338,7 @@

for attempt in range(1, empty_response_retry_limit + 1):
try:
log_telemetry(self.logger, "_get_ai_reply create start")
response = create(
llm_config=self.agent_state.llm_config,
messages=message_sequence,
@@ -349,6 +351,7 @@
stream=stream,
stream_interface=self.interface,
)
log_telemetry(self.logger, "_get_ai_reply create finish")

# These bottom two are retryable
if len(response.choices) == 0 or response.choices[0] is None:
@@ -360,12 +363,13 @@
raise RuntimeError("Finish reason was length (maximum context length)")
else:
raise ValueError(f"Bad finish reason from API: {response.choices[0].finish_reason}")

log_telemetry(self.logger, "_handle_ai_response finish")
return response

except ValueError as ve:
if attempt >= empty_response_retry_limit:
warnings.warn(f"Retry limit reached. Final error: {ve}")
log_telemetry(self.logger, "_handle_ai_response finish ValueError")
raise Exception(f"Retries exhausted and no valid response received. Final error: {ve}")
else:
delay = min(backoff_factor * (2 ** (attempt - 1)), max_delay)
@@ -374,8 +378,10 @@

except Exception as e:
# For non-retryable errors, exit immediately
log_telemetry(self.logger, "_handle_ai_response finish generic Exception")
raise e

log_telemetry(self.logger, "_handle_ai_response finish catch-all exception")
raise Exception("Retries exhausted and no valid response received.")

def _handle_ai_response(
Expand All @@ -388,7 +394,7 @@ def _handle_ai_response(
response_message_id: Optional[str] = None,
) -> Tuple[List[Message], bool, bool]:
"""Handles parsing and function execution"""

log_telemetry(self.logger, "_handle_ai_response start")
# Hacky failsafe for now to make sure we didn't implement the streaming Message ID creation incorrectly
if response_message_id is not None:
assert response_message_id.startswith("message-"), response_message_id
@@ -506,7 +512,13 @@ def _handle_ai_response(
self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1])
try:
# handle tool execution (sandbox) and state updates
log_telemetry(
self.logger, "_handle_ai_response execute tool start", function_name=function_name, function_args=function_args
)
function_response, sandbox_run_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)
log_telemetry(
self.logger, "_handle_ai_response execute tool finish", function_name=function_name, function_args=function_args
)

if sandbox_run_result and sandbox_run_result.status == "error":
messages = self._handle_function_error_response(
@@ -597,6 +609,7 @@ def _handle_ai_response(
elif self.tool_rules_solver.is_terminal_tool(function_name):
heartbeat_request = False

log_telemetry(self.logger, "_handle_ai_response finish")
return messages, heartbeat_request, function_failed

def step(
@@ -684,6 +697,9 @@ def step(
else:
break

if self.agent_state.message_buffer_autoclear:
self.agent_manager.trim_all_in_context_messages_except_system(self.agent_state.id, actor=self.user)

return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
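For context, a hedged sketch of what the manager call above might do (the helper names get_agent_by_id and set_in_context_messages are assumptions, not shown in this diff):

def trim_all_in_context_messages_except_system(self, agent_id: str, actor) -> None:
    # Assumed convention: the first in-context message id is the system prompt
    agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
    system_message_id = agent_state.message_ids[0]
    self.set_in_context_messages(agent_id=agent_id, message_ids=[system_message_id], actor=actor)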

def inner_step(
5 changes: 5 additions & 0 deletions letta/client/client.py
@@ -73,6 +73,7 @@ def create_agent(
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
description: Optional[str] = None,
tags: Optional[List[str]] = None,
message_buffer_autoclear: bool = False,
) -> AgentState:
raise NotImplementedError

@@ -540,6 +541,7 @@ def create_agent(
description: Optional[str] = None,
initial_message_sequence: Optional[List[Message]] = None,
tags: Optional[List[str]] = None,
message_buffer_autoclear: bool = False,
) -> AgentState:
"""Create an agent

@@ -600,6 +602,7 @@ def create_agent(
"initial_message_sequence": initial_message_sequence,
"tags": tags,
"include_base_tools": include_base_tools,
"message_buffer_autoclear": message_buffer_autoclear,
}

# Only add name if it's not None
@@ -2353,6 +2356,7 @@ def create_agent(
description: Optional[str] = None,
initial_message_sequence: Optional[List[Message]] = None,
tags: Optional[List[str]] = None,
message_buffer_autoclear: bool = False,
) -> AgentState:
"""Create an agent

@@ -2404,6 +2408,7 @@ def create_agent(
"embedding_config": embedding_config if embedding_config else self._default_embedding_config,
"initial_message_sequence": initial_message_sequence,
"tags": tags,
"message_buffer_autoclear": message_buffer_autoclear,
}

# Only add name if it's not None
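A quick usage sketch for the new client flag (the agent name is a placeholder):

from letta.client.client import create_client

client = create_client()
agent = client.create_agent(
    name="autoclear-demo",  # placeholder
    message_buffer_autoclear=True,  # new in this PR: trim non-system messages after each step
)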
21 changes: 21 additions & 0 deletions letta/embeddings.py
@@ -188,6 +188,19 @@ def get_text_embedding(self, text: str):
return response_json["embedding"]["values"]


class GoogleVertexEmbeddings:

def __init__(self, model: str, project_id: str, region: str):
from google import genai

self.client = genai.Client(vertexai=True, project=project_id, location=region, http_options={"api_version": "v1"})
self.model = model

def get_text_embedding(self, text: str):
# google-genai exposes embeddings via client.models.embed_content;
# each returned ContentEmbedding carries the vector in .values
response = self.client.models.embed_content(model=self.model, contents=text)
return response.embeddings[0].values
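Usage sketch (model, project, and region values are placeholders; requires the google-genai package and GCP credentials):

embedder = GoogleVertexEmbeddings(model="text-embedding-004", project_id="my-gcp-project", region="us-central1")
vector = embedder.get_text_embedding("hello vertex")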


def query_embedding(embedding_model, query_text: str):
"""Generate padded embedding for querying database"""
query_vec = embedding_model.get_text_embedding(query_text)
@@ -267,5 +280,13 @@ def embedding_model(config: EmbeddingConfig, user_id: Optional[uuid.UUID] = None
)
return model

elif endpoint_type == "google_vertex":
model = GoogleVertexEmbeddings(
model=config.embedding_model,
api_key=model_settings.gemini_api_key,
base_url=model_settings.gemini_base_url,
)
return model

else:
raise ValueError(f"Unknown endpoint type {endpoint_type}")
29 changes: 28 additions & 1 deletion letta/functions/helpers.py
@@ -17,6 +17,7 @@
from letta.schemas.user import User
from letta.server.rest_api.utils import get_letta_server
from letta.settings import settings
from letta.utils import log_telemetry


# TODO: This is kind of hacky, as this is used to search up the action later on composio's side
@@ -341,10 +342,16 @@ async def async_send_message_with_retries(
timeout: int,
logging_prefix: Optional[str] = None,
) -> str:

logging_prefix = logging_prefix or "[async_send_message_with_retries]"
log_telemetry(sender_agent.logger, f"async_send_message_with_retries start", target_agent_id=target_agent_id)

for attempt in range(1, max_retries + 1):
try:
log_telemetry(
sender_agent.logger,
f"async_send_message_with_retries -> asyncio wait for send_message_to_agent_no_stream start",
target_agent_id=target_agent_id,
)
response = await asyncio.wait_for(
send_message_to_agent_no_stream(
server=server,
@@ -354,15 +361,24 @@
),
timeout=timeout,
)
log_telemetry(
sender_agent.logger,
f"async_send_message_with_retries -> asyncio wait for send_message_to_agent_no_stream finish",
target_agent_id=target_agent_id,
)

# Then parse out the assistant message
assistant_message = parse_letta_response_for_assistant_message(target_agent_id, response)
if assistant_message:
sender_agent.logger.info(f"{logging_prefix} - {assistant_message}")
log_telemetry(
sender_agent.logger, f"async_send_message_with_retries finish with assistant message", target_agent_id=target_agent_id
)
return assistant_message
else:
msg = f"(No response from agent {target_agent_id})"
sender_agent.logger.info(f"{logging_prefix} - {msg}")
log_telemetry(sender_agent.logger, f"async_send_message_with_retries finish no response", target_agent_id=target_agent_id)
return msg

except asyncio.TimeoutError:
@@ -380,6 +396,12 @@
await asyncio.sleep(backoff)
else:
sender_agent.logger.error(f"{logging_prefix} - Fatal error: {error_msg}")
log_telemetry(
sender_agent.logger,
f"async_send_message_with_retries finish fatal error",
target_agent_id=target_agent_id,
error_msg=error_msg,
)
raise Exception(error_msg)
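The timeout-plus-backoff shape above generalizes; a self-contained sketch under the same assumptions (the defaults here are illustrative, not taken from this diff):

import asyncio

async def call_with_retries(make_coro, max_retries: int = 3, timeout: float = 30.0, backoff_factor: float = 1.0, max_delay: float = 10.0):
    for attempt in range(1, max_retries + 1):
        try:
            return await asyncio.wait_for(make_coro(), timeout=timeout)
        except asyncio.TimeoutError:
            if attempt == max_retries:
                raise
            # Exponential backoff, capped so late retries do not stall the caller
            await asyncio.sleep(min(backoff_factor * (2 ** (attempt - 1)), max_delay))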


@@ -468,6 +490,7 @@ def runner():


async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent", message: str, tags: List[str]) -> List[str]:
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async start", message=message, tags=tags)
server = get_letta_server()

augmented_message = (
@@ -477,7 +500,9 @@ async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent",
)

# Retrieve up to 100 matching agents
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async listing agents start", message=message, tags=tags)
matching_agents = server.agent_manager.list_agents(actor=sender_agent.user, tags=tags, match_all_tags=True, limit=100)
log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async listing agents finish", message=message, tags=tags)

# Create a system message
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=sender_agent.agent_state.name)]
@@ -504,4 +529,6 @@ async def _send_single(agent_state):
final.append(str(r))
else:
final.append(r)

log_telemetry(sender_agent.logger, "_send_message_to_agents_matching_all_tags_async finish", message=message, tags=tags)
return final
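The collection loop above suggests a gather-based fan-out; a sketch assuming _send_single wraps one agent send and partial failures should be collected rather than raised:

import asyncio

async def fan_out(matching_agents, send_single):
    # One concurrent task per agent; return_exceptions=True keeps partial failures
    results = await asyncio.gather(*(send_single(a) for a in matching_agents), return_exceptions=True)
    # Stringify exceptions so callers always get List[str]
    return [str(r) if isinstance(r, BaseException) else r for r in results]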