diff --git a/edb/server/protocol/ai_ext.py b/edb/server/protocol/ai_ext.py
index cf131db05a41..b163c293aa76 100644
--- a/edb/server/protocol/ai_ext.py
+++ b/edb/server/protocol/ai_ext.py
@@ -1279,8 +1279,8 @@ async def _start_openai_like_chat(
                 "stream": True,
             },
         ) as event_source:
-            # we need to keep track of tool_index and finish_reason,
-            # need them in order to correctly send "content_block_stop" chunk for all tool calls
+            # we need tool_index and finish_reason to correctly
+            # send 'content_block_stop' chunk for tool call messages
             global tool_index
             tool_index = 0
             global finish_reason
@@ -1341,7 +1341,7 @@ async def _start_openai_like_chat(
                             + b'\n\n'
                         )
                         protocol.write_raw(event)
-                    # when there's only one openai tool call it shows up here
+                    # when there's only 1 openai tool call it shows up here
                     if tool_calls:
                         for tool_call in tool_calls:
                             tool_index = tool_call["index"]
@@ -1369,7 +1369,7 @@ async def _start_openai_like_chat(
                             + b'\n\n'
                         )
                         protocol.write_raw(event)
-                # when there are more than one openai tool calls, they show up here
+                # multiple openai tool calls show up here
                 # mistral tool calls always show up here
                 elif tool_calls:
                     # OpenAI provides index, Mistral doesn't
@@ -1384,7 +1384,8 @@ async def _start_openai_like_chat(
                         # send the stop chunk for the previous tool
                         event = (
                             b'event: content_block_stop\n'
-                            + b'data: {"type": "content_block_stop",'
+                            + b'data: '
+                            + b'{"type": "content_block_stop",'
                             + b'"index": '
                             + str(currentIndex - 1).encode()
                             + b'}\n\n'
@@ -1620,9 +1621,10 @@ async def _start_anthropic_chat(
             system_prompt_parts.append(message["content"])
 
         elif message["role"] == "assistant" and "tool_calls" in message:
-            # Anthropic doesn't work when there is a list of tool calls in the
-            # same assistant chunk followed by a few tool_result chunks (or one user chunk with multiple tool_results).
-            # Assistant chunk should only have 1 tool_use, and should be followed by one tool_result chunk.
+            # Anthropic fails when an assistant chunk has multiple tool calls
+            # and is followed by several tool_result chunks (or a user chunk
+            # with multiple tool_results). Each assistant chunk should have
+            # only 1 tool_use, followed by 1 tool_result chunk.
             for tool_call in message["tool_calls"]:
                 msg = {
                     "role": "assistant",
@@ -1656,7 +1658,8 @@ async def _start_anthropic_chat(
 
     system_prompt = "\n".join(system_prompt_parts)
 
-    # Make sure that every tool_use chunk is followed by an appropriate tool_result chunk
+    # Make sure that every tool_use chunk is
+    # followed by an appropriate tool_result chunk
     reordered_messages = []
 
     # Separate tool_result messages by tool_use_id for faster access
@@ -1674,7 +1677,7 @@ async def _start_anthropic_chat(
         if isinstance(message["content"], list):
             for item in message["content"]:
                 if item["type"] == "tool_use":
-                    # find the matching user tool_result message based on tool_use_id
+                    # find matching tool_result msg based on tool_use_id
                     tool_use_id = item["id"]
                     if tool_use_id in tool_result_map:
                         reordered_messages.append(
@@ -1746,8 +1749,6 @@ async def _start_anthropic_chat(
                         + json.dumps(sse_data).encode("utf-8")
                         + b'\n\n'
                     )
-                    # we don't send content_block_stop when msg text content ends,
-                    # it should be okay since we don't consume that event in the provider
                     data = sse.json()
                     if (
                         data.get("content_block")
@@ -1800,9 +1801,7 @@ async def _start_anthropic_chat(
                         event_data = json.dumps(
                             {
                                 "type": "message_delta",
-                                "delta": message[
-                                    "delta"
-                                ],  # should send stop reason
+                                "delta": message["delta"],
                                 "usage": {
                                     "completion_tokens": message["usage"][
                                         "output_tokens"
                                     ]
@@ -1824,7 +1823,8 @@ async def _start_anthropic_chat(
                 + b'data: {"type": "message_stop"}\n\n'
             )
             protocol.write_raw(event)
-            protocol.close()  # needed because stream doesn't close itself
+            # needed because stream doesn't close itself
+            protocol.close()
             protocol.close()
         else:
            result = await client.post("/messages", json={**params})
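The _start_anthropic_chat hunks above hinge on one constraint: each assistant chunk sent to Anthropic should carry exactly one tool_use block and be immediately followed by the user chunk holding its matching tool_result. The sketch below illustrates that reordering idea under assumed Anthropic-style message dicts; the function name reorder_for_anthropic and the exact message shapes are hypothetical, not the module's actual API.

    # Illustrative sketch only (not part of the patch): names and message
    # shapes here are assumptions.
    def reorder_for_anthropic(messages: list[dict]) -> list[dict]:
        # Index tool_result chunks by the tool_use_id they answer.
        tool_result_map: dict[str, dict] = {}
        for message in messages:
            content = message.get("content")
            if message["role"] == "user" and isinstance(content, list):
                for item in content:
                    if item.get("type") == "tool_result":
                        tool_result_map[item["tool_use_id"]] = {
                            "role": "user",
                            "content": [item],
                        }

        reordered: list[dict] = []
        for message in messages:
            content = message.get("content")
            if message["role"] == "assistant" and isinstance(content, list):
                # Split the assistant message: one chunk per content item,
                # and put each tool_use right before its tool_result chunk.
                for item in content:
                    reordered.append({"role": "assistant", "content": [item]})
                    if item.get("type") == "tool_use":
                        result = tool_result_map.get(item["id"])
                        if result is not None:
                            reordered.append(result)
            elif (
                message["role"] == "user"
                and isinstance(content, list)
                and content
                and all(item.get("type") == "tool_result" for item in content)
            ):
                # Standalone tool_result chunks were re-emitted above.
                continue
            else:
                reordered.append(message)
        return reordered

Keying tool_result chunks by tool_use_id turns the pairing into a dict lookup per tool call, which is what the "Separate tool_result messages by tool_use_id for faster access" comment in the hunk at -1656 is getting at.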