
Fix formatting in the ai ext
diksipav committed Nov 6, 2024
1 parent bcfed1c commit b71927d
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions edb/server/protocol/ai_ext.py
@@ -1279,8 +1279,8 @@ async def _start_openai_like_chat(
"stream": True,
},
) as event_source:
- # we need to keep track of tool_index and finish_reason,
- # need them in order to correctly send "content_block_stop" chunk for all tool calls
+ # we need tool_index and finish_reason to correctly
+ # send 'content_block_stop' chunk for tool call messages
global tool_index
tool_index = 0
global finish_reason
@@ -1341,7 +1341,7 @@ async def _start_openai_like_chat(
+ b'\n\n'
)
protocol.write_raw(event)
- # when there's only one openai tool call it shows up here
+ # when there's only 1 openai tool call it shows up here
if tool_calls:
for tool_call in tool_calls:
tool_index = tool_call["index"]
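For context, a minimal sketch (not part of this commit) of the shape a streamed OpenAI tool-call delta takes, per the public chat-completions streaming format; the concrete values here are illustrative only:

    # first fragment for a tool call carries index, id, and the function name
    delta = {
        "tool_calls": [{
            "index": 0,
            "id": "call_abc123",
            "function": {"name": "get_weather", "arguments": ""},
        }]
    }
    # later fragments carry only the index plus argument pieces, e.g.
    # {"tool_calls": [{"index": 0, "function": {"arguments": "{\"city\": \"Oslo\"}"}}]}
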
@@ -1369,7 +1369,7 @@ async def _start_openai_like_chat(
+ b'\n\n'
)
protocol.write_raw(event)
- # when there are more than one openai tool calls, they show up here
+ # if there are few openai tool calls, they show up here
# mistral tool calls always show up here
elif tool_calls:
# OpenAI provides index, Mistral doesn't
@@ -1384,7 +1384,8 @@ async def _start_openai_like_chat(
# send the stop chunk for the previous tool
event = (
b'event: content_block_stop\n'
- + b'data: {"type": "content_block_stop",'
+ + b'data: '
+ + b'{"type": "content_block_stop",'
+ b'"index": '
+ str(currentIndex - 1).encode()
+ b'}\n\n'
@@ -1620,9 +1621,10 @@ async def _start_anthropic_chat(
system_prompt_parts.append(message["content"])

elif message["role"] == "assistant" and "tool_calls" in message:
- # Anthropic doesn't work when there is a list of tool calls in the
- # same assistant chunk followed by a few tool_result chunks (or one user chunk with multiple tool_results).
- # Assistant chunk should only have 1 tool_use, and should be followed by one tool_result chunk.
+ # Anthropic fails when an assistant chunk has multiple tool calls
+ # and is followed by several tool_result chunks (or a user chunk
+ # with multiple tool_results). Each assistant chunk should have
+ # only 1 tool_use, followed by 1 tool_result chunk.
for tool_call in message["tool_calls"]:
msg = {
"role": "assistant",
@@ -1656,7 +1658,8 @@ async def _start_anthropic_chat(

system_prompt = "\n".join(system_prompt_parts)

- # Make sure that every tool_use chunk is followed by an appropriate tool_result chunk
+ # Make sure that every tool_use chunk is
+ # followed by an appropriate tool_result chunk
reordered_messages = []

# Separate tool_result messages by tool_use_id for faster access
@@ -1674,7 +1677,7 @@ async def _start_anthropic_chat(
if isinstance(message["content"], list):
for item in message["content"]:
if item["type"] == "tool_use":
- # find the matching user tool_result message based on tool_use_id
+ # find matching tool_result msg based on tool_use_id
tool_use_id = item["id"]
if tool_use_id in tool_result_map:
reordered_messages.append(
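A minimal sketch of the reordering pass, assuming tool_result_map maps each tool_use_id to its matching user tool_result message (the map name and item fields follow the diff; the surrounding loop structure is assumed):

    reordered_messages = []
    for message in messages:
        reordered_messages.append(message)
        if isinstance(message.get("content"), list):
            for item in message["content"]:
                if item.get("type") == "tool_use":
                    result = tool_result_map.get(item["id"])
                    if result is not None:
                        # place the matching tool_result
                        # right after its tool_use
                        reordered_messages.append(result)
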
@@ -1746,8 +1749,6 @@ async def _start_anthropic_chat(
+ json.dumps(sse_data).encode("utf-8")
+ b'\n\n'
)
- # we don't send content_block_stop when msg text content ends,
- # it should be okay since we don't consume that event in the provider
data = sse.json()
if (
data.get("content_block")
@@ -1800,9 +1801,7 @@ async def _start_anthropic_chat(
event_data = json.dumps(
{
"type": "message_delta",
"delta": message[
"delta"
], # should send stop reason
"delta": message["delta"],
"usage": {
"completion_tokens": message["usage"][
"output_tokens"
@@ -1824,7 +1823,8 @@ async def _start_anthropic_chat(
+ b'data: {"type": "message_stop"}\n\n'
)
protocol.write_raw(event)
- protocol.close() # needed because stream doesn't close itself
+ # needed because stream doesn't close itself
+ protocol.close()
protocol.close()
else:
result = await client.post("/messages", json={**params})
