fix: Fix Azure provider and add complex e2e testing #1842

Merged Oct 8, 2024 (10 commits; changes shown from 9 commits)
7 changes: 7 additions & 0 deletions configs/llm_model_configs/azure-gpt-4o-mini.json
@@ -0,0 +1,7 @@
{
"context_window": 128000,
"model": "gpt-4o-mini",
"model_endpoint_type": "azure",
"api_version": "2023-03-15-preview",
"model_wrapper": null
}
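
As a quick sanity check, the new config can be loaded like the other model configs in the repo. A minimal sketch, assuming LLMConfig (imported in the Azure provider changes below) accepts these JSON fields directly as keyword arguments:

import json

from letta.schemas.llm_config import LLMConfig

# Load the Azure config added in this PR; the path is repo-relative.
with open("configs/llm_model_configs/azure-gpt-4o-mini.json") as f:
    raw = json.load(f)

# Assumption: LLMConfig is a pydantic model that takes these fields as kwargs.
config = LLMConfig(**raw)
print(config.model, config.api_version)  # gpt-4o-mini 2023-03-15-preview
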
28 changes: 19 additions & 9 deletions letta/agent.py
@@ -18,7 +18,7 @@
MESSAGE_SUMMARY_WARNING_FRAC,
)
from letta.interface import AgentInterface
from letta.llm_api.llm_api_tools import create, is_context_overflow_error
from letta.llm_api.llm_api_tools import create
from letta.memory import ArchivalMemory, RecallMemory, summarize_messages
from letta.metadata import MetadataStore
from letta.persistence_manager import LocalStateManager
@@ -56,6 +56,7 @@
)

from .errors import LLMError
from .llm_api.helpers import is_context_overflow_error


def compile_memory_metadata_block(
@@ -207,7 +208,7 @@ def step(
recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreate the 'created_at' field
stream: bool = False, # TODO move to config?
timestamp: Optional[datetime.datetime] = None,
inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
ms: Optional[MetadataStore] = None,
) -> AgentStepResponse:
"""
@@ -223,7 +224,7 @@ def update_state(self) -> AgentState:
class Agent(BaseAgent):
def __init__(
self,
interface: AgentInterface,
interface: Optional[AgentInterface],
# agents can be created from providing agent_state
agent_state: AgentState,
tools: List[Tool],
@@ -460,7 +461,7 @@ def _get_ai_reply(
function_call: str = "auto",
first_message: bool = False, # hint
stream: bool = False, # TODO move to config?
inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
) -> ChatCompletionResponse:
"""Get response from LLM API"""
try:
@@ -478,7 +479,7 @@ def _get_ai_reply(
stream=stream,
stream_inferface=self.interface,
# putting inner thoughts in func args or not
inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
)

if len(response.choices) == 0:
@@ -560,6 +561,8 @@ def _handle_ai_response(
function_call = (
response_message.function_call if response_message.function_call is not None else response_message.tool_calls[0].function
)

# Get the name of the function
function_name = function_call.name
printd(f"Request to call function {function_name} with tool_call_id: {tool_call_id}")

@@ -608,6 +611,13 @@ def _handle_ai_response(
self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1])
return messages, False, True # force a heartbeat to allow agent to handle error

# Check if inner thoughts are in the function call arguments (apparently possible when using Azure)
if "inner_thoughts" in function_args:
response_message.content = function_args.pop("inner_thoughts")
# The content is then the internal monologue, not chat
if response_message.content:
self.interface.internal_monologue(response_message.content, msg_obj=messages[-1])

# (Still parsing function args)
# Handle requests for immediate heartbeat
heartbeat_request = function_args.pop("request_heartbeat", None)
@@ -716,7 +726,7 @@ def step(
recreate_message_timestamp: bool = True, # if True, when input is a Message type, recreate the 'created_at' field
stream: bool = False, # TODO move to config?
timestamp: Optional[datetime.datetime] = None,
inner_thoughts_in_kwargs: OptionState = OptionState.DEFAULT,
inner_thoughts_in_kwargs_option: OptionState = OptionState.DEFAULT,
ms: Optional[MetadataStore] = None,
) -> AgentStepResponse:
"""Top-level event message handler for the Letta agent"""
@@ -795,7 +805,7 @@ def step(
message_sequence=input_message_sequence,
first_message=True, # passed through to the prompt formatter
stream=stream,
inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
)
if verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
break
@@ -808,7 +818,7 @@
response = self._get_ai_reply(
message_sequence=input_message_sequence,
stream=stream,
inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
)

# Step 3: check if LLM wanted to call a function
@@ -892,7 +902,7 @@ def step(
recreate_message_timestamp=recreate_message_timestamp,
stream=stream,
timestamp=timestamp,
inner_thoughts_in_kwargs=inner_thoughts_in_kwargs,
inner_thoughts_in_kwargs_option=inner_thoughts_in_kwargs_option,
ms=ms,
)

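
To make the new inner-thoughts handling in _handle_ai_response concrete, here is a self-contained sketch of the same pop-and-promote logic (the dict contents are illustrative, not real agent state):

# Illustrative: Azure has been observed returning the inner monologue inside
# the tool-call arguments, so the agent pops it out and promotes it to the
# message content before treating the remaining args as real parameters.
function_args = {
    "inner_thoughts": "The user said hello; I should greet them back.",
    "message": "Hello!",
    "request_heartbeat": False,
}

content = None
if "inner_thoughts" in function_args:
    content = function_args.pop("inner_thoughts")

assert content == "The user said hello; I should greet them back."
assert set(function_args) == {"message", "request_heartbeat"}
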
6 changes: 3 additions & 3 deletions letta/credentials.py
@@ -30,14 +30,14 @@ class LettaCredentials:

# azure config
azure_auth_type: str = "api_key"
azure_key: Optional[str] = None
azure_key: Optional[str] = os.getenv("AZURE_OPENAI_API_KEY")

# groq config
groq_key: Optional[str] = os.getenv("GROQ_API_KEY")

# base llm / model
azure_version: Optional[str] = None
azure_endpoint: Optional[str] = None
azure_version: Optional[str] = "2023-03-15-preview" # None
azure_endpoint: Optional[str] = "letta" # None
azure_deployment: Optional[str] = None
# embeddings
azure_embedding_version: Optional[str] = None
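
One subtlety with the new os.getenv defaults: they are evaluated once, when the class body runs at import time, not when an instance is created. A minimal sketch of the pattern (the class here is a hypothetical stand-in for LettaCredentials):

import os
from dataclasses import dataclass
from typing import Optional

@dataclass
class ExampleCredentials:  # hypothetical stand-in for LettaCredentials
    # Read once at class-definition time; exporting AZURE_OPENAI_API_KEY
    # after this module is imported will not change the default.
    azure_key: Optional[str] = os.getenv("AZURE_OPENAI_API_KEY")
    azure_version: Optional[str] = "2023-03-15-preview"

creds = ExampleCredentials()
print(creds.azure_key)  # None unless the env var was set before import
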
2 changes: 1 addition & 1 deletion letta/errors.py
@@ -56,7 +56,7 @@ def construct_error_message(messages: List[Union["Message", "LettaMessage"]], er
error_msg += f" (Explanation: {explanation})"

# Pretty print out message JSON
message_json = json.dumps([message.model_dump_json(indent=4) for message in messages], indent=4)
message_json = json.dumps([message.model_dump() for message in messages], indent=4)
return f"{error_msg}\n\n{message_json}"


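
This one-line change fixes real double encoding: model_dump_json() already returns a JSON string, so passing a list of those strings to json.dumps produces escaped string blobs instead of nested objects. A quick pydantic illustration:

import json
from pydantic import BaseModel

class Msg(BaseModel):
    role: str

msgs = [Msg(role="user")]

# Old behavior: a JSON array of escaped JSON *strings*.
print(json.dumps([m.model_dump_json(indent=4) for m in msgs], indent=4))
# -> [ "{\n    \"role\": \"user\"\n}" ]

# New behavior: a JSON array of actual objects, pretty-printed once.
print(json.dumps([m.model_dump() for m in msgs], indent=4))
# -> [ { "role": "user" } ]
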
34 changes: 15 additions & 19 deletions letta/llm_api/azure_openai.py
@@ -2,8 +2,11 @@

import requests

from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.openai.chat_completions import ChatCompletionRequest
from letta.schemas.openai.embedding_response import EmbeddingResponse
from letta.settings import ModelSettings
from letta.utils import smart_urljoin

MODEL_TO_AZURE_ENGINE = {
@@ -13,17 +16,16 @@
"gpt-3.5": "gpt-35-turbo",
"gpt-3.5-turbo": "gpt-35-turbo",
"gpt-3.5-turbo-16k": "gpt-35-turbo-16k",
"gpt-4o-mini": "gpt-4o-mini",
}


def clean_azure_endpoint(raw_endpoint_name: str) -> str:
"""Make sure the endpoint is of format 'https://YOUR_RESOURCE_NAME.openai.azure.com'"""
if raw_endpoint_name is None:
raise ValueError(raw_endpoint_name)
endpoint_address = raw_endpoint_name.strip("/").replace(".openai.azure.com", "")
endpoint_address = endpoint_address.replace("http://", "")
endpoint_address = endpoint_address.replace("https://", "")
return endpoint_address
def get_azure_endpoint(llm_config: LLMConfig, model_settings: ModelSettings):
assert llm_config.api_version, "Missing model version! This field must be provided in the LLM config for Azure."
assert llm_config.model in MODEL_TO_AZURE_ENGINE, f"{llm_config.model} not in supported models: {list(MODEL_TO_AZURE_ENGINE.keys())}"

model = MODEL_TO_AZURE_ENGINE[llm_config.model]
return f"{model_settings.azure_base_url}/openai/deployments/{model}/chat/completions?api-version={llm_config.api_version}"


def azure_openai_get_model_list(url: str, api_key: Union[str, None], api_version: str) -> dict:
@@ -72,19 +74,15 @@ def azure_openai_get_model_list(url: str, api_key: Union[str, None], api_version


def azure_openai_chat_completions_request(
resource_name: str, deployment_id: str, api_version: str, api_key: str, data: dict
model_settings: ModelSettings, llm_config: LLMConfig, api_key: str, chat_completion_request: ChatCompletionRequest
) -> ChatCompletionResponse:
"""https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions"""
from letta.utils import printd

assert resource_name is not None, "Missing required field when calling Azure OpenAI"
assert deployment_id is not None, "Missing required field when calling Azure OpenAI"
assert api_version is not None, "Missing required field when calling Azure OpenAI"
assert api_key is not None, "Missing required field when calling Azure OpenAI"

resource_name = clean_azure_endpoint(resource_name)
url = f"https://{resource_name}.openai.azure.com/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}"
headers = {"Content-Type": "application/json", "api-key": f"{api_key}"}
data = chat_completion_request.model_dump(exclude_none=True)

# If functions == None, strip from the payload
if "functions" in data and data["functions"] is None:
@@ -95,11 +93,10 @@ def azure_openai_chat_completions_request(
data.pop("tools")
data.pop("tool_choice", None) # extra safe, should exist always (default="auto")

printd(f"Sending request to {url}")
model_endpoint = get_azure_endpoint(llm_config, model_settings)
printd(f"Sending request to {model_endpoint}")
try:
data["messages"] = [i.to_openai_dict() for i in data["messages"]]
response = requests.post(url, headers=headers, json=data)
printd(f"response = {response}")
response = requests.post(model_endpoint, headers=headers, json=data)
response.raise_for_status() # Raises HTTPError for 4XX/5XX status
response = response.json() # convert to dict from string
printd(f"response.json = {response}")
@@ -128,7 +125,6 @@ def azure_openai_embeddings_request(
"""https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#embeddings"""
from letta.utils import printd

resource_name = clean_azure_endpoint(resource_name)
url = f"https://{resource_name}.openai.azure.com/openai/deployments/{deployment_id}/embeddings?api-version={api_version}"
headers = {"Content-Type": "application/json", "api-key": f"{api_key}"}

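
The net effect of replacing clean_azure_endpoint with get_azure_endpoint is that the request URL is now assembled from settings plus the engine mapping rather than by scrubbing a resource name. A sketch of the URL shape it produces (the base URL and version values are illustrative placeholders):

# Sketch of what get_azure_endpoint() builds; values are placeholders.
model = "gpt-35-turbo"  # MODEL_TO_AZURE_ENGINE["gpt-3.5-turbo"]
azure_base_url = "https://YOUR_RESOURCE_NAME.openai.azure.com"
api_version = "2023-03-15-preview"

url = f"{azure_base_url}/openai/deployments/{model}/chat/completions?api-version={api_version}"
print(url)
# https://YOUR_RESOURCE_NAME.openai.azure.com/openai/deployments/gpt-35-turbo/chat/completions?api-version=2023-03-15-preview
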
153 changes: 153 additions & 0 deletions letta/llm_api/helpers.py
@@ -0,0 +1,153 @@
import copy
import json
import warnings
from typing import List, Union

import requests

from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from letta.schemas.enums import OptionState
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice
from letta.utils import json_dumps


# TODO update to use better types
def add_inner_thoughts_to_functions(
functions: List[dict],
inner_thoughts_key: str,
inner_thoughts_description: str,
inner_thoughts_required: bool = True,
# inner_thoughts_to_front: bool = True, TODO support sorting somewhere, probably in the to_dict?
) -> List[dict]:
"""Add an inner_thoughts kwarg to every function in the provided list"""
# return copies
new_functions = []

# functions is a list of dicts in the OpenAI schema (https://platform.openai.com/docs/api-reference/chat/create)
for function_object in functions:
function_params = function_object["parameters"]["properties"]
required_params = list(function_object["parameters"]["required"])

# if the inner thoughts arg doesn't exist, add it
if inner_thoughts_key not in function_params:
function_params[inner_thoughts_key] = {
"type": "string",
"description": inner_thoughts_description,
}

# make sure it's tagged as required
new_function_object = copy.deepcopy(function_object)
if inner_thoughts_required and inner_thoughts_key not in required_params:
required_params.append(inner_thoughts_key)
new_function_object["parameters"]["required"] = required_params

new_functions.append(new_function_object)

# return a list of copies
return new_functions


def unpack_all_inner_thoughts_from_kwargs(
response: ChatCompletionResponse,
inner_thoughts_key: str,
) -> ChatCompletionResponse:
"""Strip the inner thoughts out of the tool call and put it in the message content"""
if len(response.choices) == 0:
raise ValueError("Unpacking inner thoughts from an empty response is not supported")

new_choices = []
for choice in response.choices:
new_choices.append(unpack_inner_thoughts_from_kwargs(choice, inner_thoughts_key))

# return an updated copy
new_response = response.model_copy(deep=True)
new_response.choices = new_choices
return new_response


def unpack_inner_thoughts_from_kwargs(choice: Choice, inner_thoughts_key: str) -> Choice:
message = choice.message
if message.role == "assistant" and message.tool_calls and len(message.tool_calls) >= 1:
if len(message.tool_calls) > 1:
warnings.warn(f"Unpacking inner thoughts from more than one tool call ({len(message.tool_calls)}) is not supported")
# TODO support multiple tool calls
tool_call = message.tool_calls[0]

try:
# Sadly we need to parse the JSON since args are in string format
func_args = dict(json.loads(tool_call.function.arguments))
if inner_thoughts_key in func_args:
# extract the inner thoughts
inner_thoughts = func_args.pop(inner_thoughts_key)

# replace the kwargs
new_choice = choice.model_copy(deep=True)
new_choice.message.tool_calls[0].function.arguments = json_dumps(func_args)
# also replace the message content
if new_choice.message.content is not None:
warnings.warn(f"Overwriting existing inner monologue ({new_choice.message.content}) with kwarg ({inner_thoughts})")
new_choice.message.content = inner_thoughts

return new_choice
else:
warnings.warn(f"Did not find inner thoughts in tool call: {str(tool_call)}")

except json.JSONDecodeError as e:
warnings.warn(f"Failed to strip inner thoughts from kwargs: {e}")
raise e


def is_context_overflow_error(exception: Union[requests.exceptions.RequestException, Exception]) -> bool:
"""Checks if an exception is due to context overflow (based on common OpenAI response messages)"""
from letta.utils import printd

match_string = OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING

# Backwards compatibility with openai python package/client v0.28 (pre-v1 client migration)
if match_string in str(exception):
printd(f"Found '{match_string}' in str(exception)={(str(exception))}")
return True

# Based on python requests + OpenAI REST API (/v1)
elif isinstance(exception, requests.exceptions.HTTPError):
if exception.response is not None and "application/json" in exception.response.headers.get("Content-Type", ""):
try:
error_details = exception.response.json()
if "error" not in error_details:
printd(f"HTTPError occurred, but couldn't find error field: {error_details}")
return False
else:
error_details = error_details["error"]

# Check for the specific error code
if error_details.get("code") == "context_length_exceeded":
printd(f"HTTPError occurred, caught error code {error_details.get('code')}")
return True
# Soft-check for "maximum context length" inside of the message
elif error_details.get("message") and "maximum context length" in error_details.get("message"):
printd(f"HTTPError occurred, found '{match_string}' in error message contents ({error_details})")
return True
else:
printd(f"HTTPError occurred, but unknown error message: {error_details}")
return False
except ValueError:
# JSON decoding failed
printd(f"HTTPError occurred ({exception}), but no JSON error message.")

# Generic fail
else:
return False


def derive_inner_thoughts_in_kwargs(inner_thoughts_in_kwargs_option: OptionState, model: str):
if inner_thoughts_in_kwargs_option == OptionState.DEFAULT:
# models known to not use the `content` field on tool calls
inner_thoughts_in_kwargs = "gpt-4o" in model or "gpt-4-turbo" in model or "gpt-3.5-turbo" in model
else:
inner_thoughts_in_kwargs = True if inner_thoughts_in_kwargs_option == OptionState.YES else False

if not isinstance(inner_thoughts_in_kwargs, bool):
warnings.warn(f"Bad type detected: {type(inner_thoughts_in_kwargs)}")
inner_thoughts_in_kwargs = bool(inner_thoughts_in_kwargs)

return inner_thoughts_in_kwargs
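
Putting the new helpers together end to end: inject the inner_thoughts kwarg into a tool schema before the request, and let the option derivation decide whether to do so for a given model. A usage sketch (the key and description strings are illustrative, not the constants the codebase actually uses):

from letta.llm_api.helpers import (
    add_inner_thoughts_to_functions,
    derive_inner_thoughts_in_kwargs,
)
from letta.schemas.enums import OptionState

functions = [{
    "name": "send_message",
    "parameters": {
        "type": "object",
        "properties": {"message": {"type": "string"}},
        "required": ["message"],
    },
}]

patched = add_inner_thoughts_to_functions(
    functions,
    inner_thoughts_key="inner_thoughts",
    inner_thoughts_description="Private reasoning, never shown to the user.",
)

# The kwarg is injected and marked required on the returned copies.
assert "inner_thoughts" in patched[0]["parameters"]["properties"]
assert "inner_thoughts" in patched[0]["parameters"]["required"]

# DEFAULT resolves per model family: gpt-4o models get kwargs mode turned on.
assert derive_inner_thoughts_in_kwargs(OptionState.DEFAULT, "gpt-4o-mini") is True
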