Commit 5738d4f
Add code explaining comments to the ai ext
diksipav committed Nov 6, 2024
1 parent 5302b9a commit 5738d4f
Showing 2 changed files with 70 additions and 104 deletions.
edb/lib/ext/ai.edgeql: 6 changes (3 additions, 3 deletions)
@@ -86,7 +86,7 @@ CREATE EXTENSION PACKAGE ai VERSION '1.0' {
};

alter property api_url {
set default := 'https://api.openai.com/v1';
set default := 'https://api.openai.com/v1'
};

alter property api_style {
@@ -107,7 +107,7 @@ CREATE EXTENSION PACKAGE ai VERSION '1.0' {
};

alter property api_url {
set default := 'https://api.mistral.ai/v1';
set default := 'https://api.mistral.ai/v1'
};

alter property api_style {
@@ -128,7 +128,7 @@ CREATE EXTENSION PACKAGE ai VERSION '1.0' {
};

alter property api_url {
set default := 'https://api.anthropic.com/v1';
set default := 'https://api.anthropic.com/v1'
};

alter property api_style {
edb/server/protocol/ai_ext.py: 168 changes (67 additions, 101 deletions)
@@ -1228,6 +1228,8 @@ async def _start_openai_like_chat(
"stream": True,
}
) as event_source:
# We need to keep track of tool_index and finish_reason; they are needed
# in order to correctly send a "content_block_stop" chunk for all tool calls.
global tool_index
tool_index = 0
global finish_reason
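For context: content_block_stop is the Anthropic-style SSE frame that closes one streamed content block, and this proxy has to emit one per tool call. A minimal sketch of the frame builder, assuming only the tool_index tracked above (an illustrative helper, not part of the commit):

def content_block_stop_frame(tool_index: int) -> bytes:
    # Anthropic-style terminator for the content block at tool_index.
    return (
        b'event: content_block_stop\n'
        b'data: {"type": "content_block_stop",'
        b'"index": ' + str(tool_index).encode() + b'}\n\n'
    )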
@@ -1245,6 +1247,7 @@ async def _start_openai_like_chat(
continue

if sse.data == "[DONE]":
# Mistral doesn't send a finish_reason for tool calls
if finish_reason=="unknown":
event = (
b'event: content_block_stop\n'
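OpenAI-style streams end with a literal "[DONE]" data sentinel rather than a typed JSON event, so when Mistral never reported a finish_reason for its tool calls the handler has to close the last tool block itself. A reduced sketch of that fallback, with the frame shape assumed from the surrounding handler:

if sse.data == "[DONE]":
    if finish_reason == "unknown":  # Mistral sent no finish_reason
        # close the still-open tool-call block before the final events
        protocol.write_raw(
            b'event: content_block_stop\n'
            + b'data: {"type": "content_block_stop",'
            + b'"index": ' + str(tool_index).encode() + b'}\n\n'
        )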
@@ -1265,6 +1268,7 @@ async def _start_openai_like_chat(
delta = data.get("delta")
role = delta.get("role")
tool_calls = delta.get("tool_calls")

if role:
event_data = json.dumps({
"type": "message_start",
@@ -1280,8 +1284,8 @@ async def _start_openai_like_chat(
+ b'data: ' + event_data + b'\n\n'
)
protocol.write_raw(event)
if tool_calls:
# when there's only one OpenAI tool call, it shows up here
if tool_calls:
for tool_call in tool_calls:
tool_index = tool_call["index"]
event_data = json.dumps({
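The branch above handles the case where the streamed delta itself carries the tool calls. A hypothetical OpenAI-style delta with a single tool call, to illustrate the shape being read (the id and function name are made up):

delta = {
    "tool_calls": [{
        "index": 0,                   # OpenAI supplies this, Mistral doesn't
        "id": "call_abc123",          # made-up id
        "type": "function",
        "function": {"name": "search_docs", "arguments": ""},
    }],
}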
@@ -1300,13 +1304,16 @@ async def _start_openai_like_chat(
+ b'data:' + event_data + b'\n\n'
)
protocol.write_raw(event)
# when there is more than one OpenAI tool call, they show up here;
# Mistral tool calls always show up here
elif tool_calls:
# OpenAI provides index, Mistral doesn't
for index, tool_call in enumerate(tool_calls):
currentIndex = tool_call.get("index") or index
if tool_call.get("type")=="function" or "id" in tool_call:
if tool_call.get("type") == "function" or "id" in tool_call:
if currentIndex > 0:
tool_index = currentIndex
# send the stop chunk for the previous tool
event = (
b'event: content_block_stop\n'
+ b'data: {"type": "content_block_stop",'
@@ -1478,7 +1485,9 @@ async def _start_openai_chat(
tools
)


# Anthropic differs from OpenAI and Mistral in that there is no tool chunk:
# a tool call (tool_use) is part of the assistant chunk, and
# a tool_result is part of the user chunk.
async def _start_anthropic_chat(
protocol: protocol.HttpProtocol,
request: protocol.HttpRequest,
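To make the comment above concrete, here are hypothetical Anthropic-style messages (ids and values invented): the tool call lives inside the assistant message's content as a tool_use block, and the result arrives inside a user message as a tool_result block:

assistant_msg = {
    "role": "assistant",
    "content": [{
        "type": "tool_use",
        "id": "toolu_123",                    # made-up id
        "name": "search_docs",
        "input": {"query": "vector indexes"},
    }],
}
user_msg = {
    "role": "user",
    "content": [{
        "type": "tool_result",
        "tool_use_id": "toolu_123",
        "content": "3 matching documents",
    }],
}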
@@ -1517,88 +1526,77 @@ async def _start_anthropic_chat(
for message in messages:
if message["role"] == "system":
system_prompt_parts.append(message["content"])
else:
if message["role"] == "assistant" and "tool_calls" in message:
# in case Anthropic fixes their bug we can return to this;
# it doesn't work when you have a list of tool calls in an assistant msg
# msg = {
# "role": "assistant",
# "content": [
# {
# "id": tool_call["id"],
# "type": "tool_use",
# "name": tool_call["function"]["name"],
# "input": json.loads(tool_call["function"]["arguments"]),
# }
# for tool_call in message["tool_calls"]
# ],
# }
# anthropic_messages.append(msg)
for tool_call in message["tool_calls"]:
msg = {
"role": "assistant",
"content": [
{
"id": tool_call["id"],
"type": "tool_use",
"name": tool_call["function"]["name"],
"input": json.loads(tool_call["function"]["arguments"]),
}
],
}
anthropic_messages.append(msg)

# Check if message is a tool result
elif message["role"] == "tool":
tool_result = {
"role": "user",

elif message["role"] == "assistant" and "tool_calls" in message:
# Anthropic doesn't work when there is a list of tool calls in the same
# assistant chunk followed by a few tool_result chunks (or one user chunk
# with multiple tool_results). An assistant chunk should have only one
# tool_use and should be followed by exactly one tool_result chunk.
for tool_call in message["tool_calls"]:
msg = {
"role": "assistant",
"content": [
{
"type": "tool_result",
"tool_use_id": message["tool_call_id"],
"content": message["content"]
"id": tool_call["id"],
"type": "tool_use",
"name": tool_call["function"]["name"],
"input": json.loads(tool_call["function"]["arguments"]),
}
],
}
anthropic_messages.append(tool_result)

else:
anthropic_messages.append(message)
anthropic_messages.append(msg)
# Check if message is a tool result
elif message["role"] == "tool":
tool_result = {
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": message["tool_call_id"],
"content": message["content"]
}
],
}
anthropic_messages.append(tool_result)
else:
anthropic_messages.append(message)

system_prompt = "\n".join(system_prompt_parts)

# Separate tool_result messages by tool_use_id for faster access

# Make sure that every tool_use chunk is followed by an appropriate tool_result chunk
reordered_messages = []

# Separate tool_result messages by tool_use_id for faster access
tool_result_map = {
item["content"][0]["tool_use_id"]: item
for item in anthropic_messages
if item["role"] == "user" and isinstance(item["content"][0], dict) and item["content"][0]["type"] == "tool_result"
}

# Transform assistant messages and interleave with corresponding user messages
transformed_messages = []

for message in anthropic_messages:
if message["role"] == "assistant":
transformed_messages.append(message)
for item in message["content"]:
if item["type"]=="tool_use":
# Find the matching user tool_result message based on tool_use_id
tool_use_id = item["id"]
if tool_use_id in tool_result_map:
transformed_messages.append(tool_result_map[tool_use_id])
elif not (message["role"] == "user" and isinstance(message["content"][0], dict) and message["content"][0]["type"] == "tool_result"): transformed_messages.append(message)
reordered_messages.append(message)
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"]=="tool_use":
# find the matching user tool_result message based on tool_use_id
tool_use_id = item["id"]
if tool_use_id in tool_result_map:
reordered_messages.append(tool_result_map[tool_use_id])
# append a user message that is not a tool_result
elif not (message["role"] == "user"
          and isinstance(message["content"][0], dict)
          and message["content"][0]["type"] == "tool_result"):
    reordered_messages.append(message)

params = {
"model": model_name,
"messages": transformed_messages,
"messages": reordered_messages,
"system": system_prompt,
**({"temperature": temperature} if temperature is not None else {}),
**({"top_p": top_p} if top_p is not None else {}),
**{"max_tokens": max_tokens if max_tokens is not None else 4096},
**({"top_k": top_k} if top_k is not None else {}),
**({"tools": tools} if tools is not None else {}),
}

if stream:
async with aconnect_sse(
client,
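Taken together, the hunk above first splits multi-tool assistant messages and then reorders the results. A standalone sketch of that normalization (not the commit's code; it ignores the system-prompt extraction done above):

import json

def is_tool_result(msg: dict) -> bool:
    # user message whose first content item is a tool_result block
    return (
        msg["role"] == "user"
        and isinstance(msg["content"], list)
        and isinstance(msg["content"][0], dict)
        and msg["content"][0]["type"] == "tool_result"
    )

def normalize(messages: list[dict]) -> list[dict]:
    # 1) split: one Anthropic-style assistant message per tool_use,
    #    and each OpenAI-style "tool" message becomes a tool_result.
    split: list[dict] = []
    for msg in messages:
        if msg["role"] == "assistant" and "tool_calls" in msg:
            for call in msg["tool_calls"]:
                split.append({
                    "role": "assistant",
                    "content": [{
                        "id": call["id"],
                        "type": "tool_use",
                        "name": call["function"]["name"],
                        "input": json.loads(call["function"]["arguments"]),
                    }],
                })
        elif msg["role"] == "tool":
            split.append({
                "role": "user",
                "content": [{
                    "type": "tool_result",
                    "tool_use_id": msg["tool_call_id"],
                    "content": msg["content"],
                }],
            })
        else:
            split.append(msg)

    # 2) reorder: re-emit each tool_result right after the assistant
    #    message that invoked the tool.
    results = {m["content"][0]["tool_use_id"]: m
               for m in split if is_tool_result(m)}
    reordered: list[dict] = []
    for msg in split:
        if is_tool_result(msg):
            continue  # re-inserted next to its tool_use below
        reordered.append(msg)
        if isinstance(msg["content"], list):
            for item in msg["content"]:
                if item.get("type") == "tool_use" and item["id"] in results:
                    reordered.append(results[item["id"]])
    return reordered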
@@ -1691,13 +1689,12 @@ async def _start_anthropic_chat(

protocol.write_raw(event)
elif sse.event == "message_stop":

event = (
b'event: message_stop\n'
+ b'data: {"type": "message_stop"}\n\n'
)
protocol.write_raw(event)

protocol.close() # needed because stream doesn't close itself
protocol.close()
else:
result = await client.post(
@@ -1856,7 +1853,7 @@ async def _handle_rag_request(
ctx_max_obj_count = 5
elif not isinstance(ctx_max_obj_count, int) or ctx_max_obj_count <= 0:
raise TypeError(
'"context.max_object_count" must be an postitive integer')
'"context.max_object_count" must be a positive integer')

prompt_id = None
prompt_name = None
Expand Down Expand Up @@ -1889,14 +1886,13 @@ async def _handle_rag_request(
if (
not isinstance(entry, dict)
or not entry.get("role")
or not entry.get("content")
# or len(entry) > 3
or "content" not in entry # content can be empty string too
or len(entry) > 3
):
print("DIDI ERR2", custom_prompt)
# raise TypeError(
# "prompt.custom must be a list of {role, content} "
# "objects"
# )
raise TypeError(
"prompt.custom must be a list of {role, content} "
"objects"
)
custom_prompt_messages.append(entry)

except Exception as ex:
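The check above accepts dict entries with a truthy role, a "content" key (which may be an empty string), and at most three keys. Hypothetical entries that would pass (the tool_call_id key is an assumption based on how tool messages are shaped elsewhere in this module):

custom_prompt = [
    {"role": "user", "content": "What's on my reading list?"},
    {"role": "assistant", "content": ""},   # empty content is allowed
    {"role": "tool", "content": "3 items",
     "tool_call_id": "call_abc123"},        # made-up id
]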
@@ -1999,38 +1995,8 @@ async def _handle_rag_request(

prompt_messages.append(dict(role=role, content=content))

tool_messages = []
non_tool_messages = []
found_assistant_followed_by_tool = False

# when using Mistral, tool messages should come last and a user msg can't appear before the tool one
for i, message in enumerate(custom_prompt_messages):
if not found_assistant_followed_by_tool:
# Check if the next message is either a "tool" role or a "user" role with "tool_result" in content
next_message = custom_prompt_messages[i + 1] if i + 1 < len(custom_prompt_messages) else None
if (
message.get("role") == "assistant"
and next_message
and (
next_message.get("role") == "tool"
or (
next_message.get("role") == "user"
and next_message.get("content")
and next_message["content"][0].get("type") == "tool_result"
)
)
):
found_assistant_followed_by_tool = True
tool_messages.append(message)
else:
non_tool_messages.append(message)
else:
tool_messages.append(message)
messages = prompt_messages + custom_prompt_messages

messages = prompt_messages + non_tool_messages
messages.append(dict(role="user", content=query))
messages = messages + tool_messages

await _start_chat(
protocol,
request,
