From 5738d4f02b4626564240f0b082929b9e93b78c71 Mon Sep 17 00:00:00 2001
From: Dijana Pavlovic
Date: Wed, 6 Nov 2024 11:52:22 +0000
Subject: [PATCH] Add explanatory code comments to the ai extension

---
 edb/server/protocol/ai_ext.py | 168 ++++++++++++++--------------------
 1 file changed, 67 insertions(+), 101 deletions(-)

diff --git a/edb/server/protocol/ai_ext.py b/edb/server/protocol/ai_ext.py
index f184dc8bf08..2d5e6396086 100644
--- a/edb/server/protocol/ai_ext.py
+++ b/edb/server/protocol/ai_ext.py
@@ -1228,6 +1228,8 @@ async def _start_openai_like_chat(
                 "stream": True,
             }
         ) as event_source:
+            # Track tool_index and finish_reason: we need them to emit the
+            # "content_block_stop" chunk correctly for every tool call.
             global tool_index
             tool_index = 0
             global finish_reason
@@ -1245,6 +1247,7 @@
                        continue

                    if sse.data == "[DONE]":
+                        # Mistral doesn't send a finish_reason for tool calls.
                        if finish_reason=="unknown":
                            event = (
                                b'event: content_block_stop\n'
@@ -1265,6 +1268,7 @@
                    delta = data.get("delta")
                    role = delta.get("role")
                    tool_calls = delta.get("tool_calls")
+
                    if role:
                        event_data = json.dumps({
                            "type": "message_start",
@@ -1280,8 +1284,8 @@
                            + b'data: ' + event_data + b'\n\n'
                        )
                        protocol.write_raw(event)
-
-        if tool_calls:
+                        # When there is only one OpenAI tool call, it shows up here.
+                        if tool_calls:
                            for tool_call in tool_calls:
                                tool_index = tool_call["index"]
                                event_data = json.dumps({
@@ -1300,13 +1304,16 @@
                            + b'data:' + event_data + b'\n\n'
                        )
                        protocol.write_raw(event)
+                    # When there is more than one OpenAI tool call, they show up here.
+                    # Mistral tool calls always show up here.
                    elif tool_calls:
                        # OpenAI provides index, Mistral doesn't
                        for index, tool_call in enumerate(tool_calls):
                            currentIndex = tool_call.get("index") or index
-                            if tool_call.get("type")=="function" or "id" in tool_call:
+                            if tool_call.get("type") == "function" or "id" in tool_call:
                                if currentIndex > 0:
                                    tool_index = currentIndex
+                                    # Send the stop chunk for the previous tool call.
                                    event = (
                                        b'event: content_block_stop\n'
                                        + b'data: {"type": "content_block_stop",'
@@ -1478,7 +1485,9 @@ async def _start_openai_chat(
        tools
    )

-
+# Anthropic differs from OpenAI and Mistral in that there is no tool chunk:
+# tool_call (tool_use) is part of the assistant chunk, and
+# tool_result is part of the user chunk.
 async def _start_anthropic_chat(
     protocol: protocol.HttpProtocol,
     request: protocol.HttpRequest,
@@ -1517,80 +1526,69 @@ async def _start_anthropic_chat(
    for message in messages:
        if message["role"] == "system":
            system_prompt_parts.append(message["content"])
-        else:
-            if message["role"] == "assistant" and "tool_calls" in message:
-                # in case Anthropic fix their bag we can return to this
-                # it doesn't work when u have list of tool-calls in an assistant msg
-                # msg = {
-                #     "role": "assistant",
-                #     "content": [
-                #         {
-                #             "id": tool_call["id"],
-                #             "type": "tool_use",
-                #             "name": tool_call["function"]["name"],
-                #             "input": json.loads(tool_call["function"]["arguments"]),
-                #         }
-                #         for tool_call in message["tool_calls"]
-                #     ],
-                # }
-                # anthropic_messages.append(msg)
-                for tool_call in message["tool_calls"]:
-                    msg = {
-                        "role": "assistant",
-                        "content": [
-                            {
-                                "id": tool_call["id"],
-                                "type": "tool_use",
-                                "name": tool_call["function"]["name"],
-                                "input": json.loads(tool_call["function"]["arguments"]),
-                            }
-                        ],
-                    }
-                    anthropic_messages.append(msg)
-
-            # Check if message is a tool result
-            elif message["role"] == "tool":
-                tool_result = {
-                    "role": "user",
-                    "content": [
-                        {
-                            "type": "tool_result",
-                            "tool_use_id": message["tool_call_id"],
-                            "content": message["content"]
-                        }
-                    ],
-                }
-                anthropic_messages.append(tool_result)
-
-            else:
-                anthropic_messages.append(message)
+
+        elif message["role"] == "assistant" and "tool_calls" in message:
+            # Anthropic can't handle an assistant chunk that carries a list
+            # of tool calls followed by several tool_result chunks (or one
+            # user chunk with multiple tool_results): each assistant chunk
+            # must hold exactly one tool_use, followed by one tool_result.
+            for tool_call in message["tool_calls"]:
+                msg = {
+                    "role": "assistant",
+                    "content": [
+                        {
+                            "id": tool_call["id"],
+                            "type": "tool_use",
+                            "name": tool_call["function"]["name"],
+                            "input": json.loads(
+                                tool_call["function"]["arguments"]),
+                        }
+                    ],
+                }
+                anthropic_messages.append(msg)
+        # Check if the message is a tool result.
+        elif message["role"] == "tool":
+            tool_result = {
+                "role": "user",
+                "content": [
+                    {
+                        "type": "tool_result",
+                        "tool_use_id": message["tool_call_id"],
+                        "content": message["content"]
+                    }
+                ],
+            }
+            anthropic_messages.append(tool_result)
+        else:
+            anthropic_messages.append(message)

    system_prompt = "\n".join(system_prompt_parts)
-
-    # Separate tool_result messages by tool_use_id for faster access
+
+    # Make sure every tool_use chunk is followed by its tool_result chunk.
+    reordered_messages = []
+
+    # Index tool_result messages by tool_use_id for faster access.
    tool_result_map = {
        item["content"][0]["tool_use_id"]: item
        for item in anthropic_messages
        if item["role"] == "user"
        and isinstance(item["content"][0], dict)
        and item["content"][0]["type"] == "tool_result"
    }
-
-    # Transform assistant messages and interleave with corresponding user messages
-    transformed_messages = []
    for message in anthropic_messages:
        if message["role"] == "assistant":
-            transformed_messages.append(message)
-            for item in message["content"]:
-                if item["type"]=="tool_use":
-                    # Find the matching user tool_result message based on tool_use_id
-                    tool_use_id = item["id"]
-                    if tool_use_id in tool_result_map:
-                        transformed_messages.append(tool_result_map[tool_use_id])
-        elif not (message["role"] == "user" and isinstance(message["content"][0], dict) and message["content"][0]["type"] == "tool_result"):
-            transformed_messages.append(message)
+            reordered_messages.append(message)
+            if isinstance(message["content"], list):
+                for item in message["content"]:
+                    if item["type"] == "tool_use":
+                        # Find the matching user tool_result message by its
+                        # tool_use_id.
+                        tool_use_id = item["id"]
+                        if tool_use_id in tool_result_map:
+                            reordered_messages.append(
+                                tool_result_map[tool_use_id])
+        # Append any user message that is not a tool_result.
+        elif not (
+            message["role"] == "user"
+            and isinstance(message["content"][0], dict)
+            and message["content"][0]["type"] == "tool_result"
+        ):
+            reordered_messages.append(message)

    params = {
        "model": model_name,
-        "messages": transformed_messages,
+        "messages": reordered_messages,
        "system": system_prompt,
        **({"temperature": temperature} if temperature is not None else {}),
        **({"top_p": top_p} if top_p is not None else {}),
@@ -1598,7 +1596,7 @@
        **({"top_k": top_k} if top_k is not None else {}),
        **({"tools": tools} if tools is not None else {}),
    }
-    
+
    if stream:
        async with aconnect_sse(
            client,
@@ -1691,13 +1689,12 @@
                        protocol.write_raw(event)

                    elif sse.event == "message_stop":
-
                        event = (
                            b'event: message_stop\n'
                            + b'data: {"type": "message_stop"}\n\n'
                        )
                        protocol.write_raw(event)
-
+                protocol.close()  # needed because the stream doesn't close itself
            protocol.close()
        else:
            result = await client.post(
@@ -1856,7 +1853,7 @@ async def _handle_rag_request(
        ctx_max_obj_count = 5
    elif not isinstance(ctx_max_obj_count, int) or ctx_max_obj_count <= 0:
        raise TypeError(
-            '"context.max_object_count" must be an postitive integer')
+            '"context.max_object_count" must be a positive integer')

    prompt_id = None
    prompt_name = None
@@ -1889,14 +1886,13 @@
                if (
                    not isinstance(entry, dict)
                    or not entry.get("role")
-                    or not entry.get("content")
-                    # or len(entry) > 3
+                    or "content" not in entry  # content can be an empty string
+                    or len(entry) > 3
                ):
-                    print("DIDI ERR2", custom_prompt)
-                    # raise TypeError(
-                    #     "prompt.custom must be a list of {role, content} "
-                    #     "objects"
-                    # )
+                    raise TypeError(
+                        "prompt.custom must be a list of {role, content} "
+                        "objects"
+                    )
                custom_prompt_messages.append(entry)

        except Exception as ex:
@@ -1999,38 +1995,8 @@
            prompt_messages.append(dict(role=role, content=content))

-    tool_messages = []
-    non_tool_messages = []
-    found_assistant_followed_by_tool = False
-
-    # when using mistral tool messages should be the last and user msg can't appear before the tool one
-    for i, message in enumerate(custom_prompt_messages):
-        if not found_assistant_followed_by_tool:
-            # Check if the next message is either a "tool" role or a "user" role with "tool_result" in content
-            next_message = custom_prompt_messages[i + 1] if i + 1 < len(custom_prompt_messages) else None
-            if (
-                message.get("role") == "assistant"
-                and next_message
-                and (
-                    next_message.get("role") == "tool"
-                    or (
-                        next_message.get("role") == "user"
-                        and next_message.get("content")
-                        and next_message["content"][0].get("type") == "tool_result"
-                    )
-                )
-            ):
-                found_assistant_followed_by_tool = True
-                tool_messages.append(message)
-            else:
-                non_tool_messages.append(message)
-        else:
-            tool_messages.append(message)
+    messages = prompt_messages + custom_prompt_messages

-    messages = prompt_messages + non_tool_messages
-    messages.append(dict(role="user", content=query))
-    messages = messages + tool_messages
-
    await _start_chat(
        protocol,
        request,
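
Note (illustration, not part of the diff): the Anthropic branch above splits a
multi-tool-call assistant message into one assistant chunk per tool_use. A
minimal standalone sketch of that splitting, using a hypothetical input
message whose field names mirror the patched code:

import json

# Hypothetical OpenAI-style assistant message with two tool calls.
assistant_message = {
    "role": "assistant",
    "tool_calls": [
        {"id": "call_1",
         "function": {"name": "get_weather",
                      "arguments": '{"city": "Ljubljana"}'}},
        {"id": "call_2",
         "function": {"name": "get_time",
                      "arguments": '{"city": "Ljubljana"}'}},
    ],
}

anthropic_messages = []
for tool_call in assistant_message["tool_calls"]:
    # One assistant chunk per tool_use, as Anthropic expects.
    anthropic_messages.append({
        "role": "assistant",
        "content": [{
            "id": tool_call["id"],
            "type": "tool_use",
            "name": tool_call["function"]["name"],
            "input": json.loads(tool_call["function"]["arguments"]),
        }],
    })

assert len(anthropic_messages) == 2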
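
Likewise, a minimal standalone sketch of the tool_use/tool_result reordering
(the sample messages are hypothetical; the loop mirrors the patched one):

# Hypothetical history: the tool_result does not yet follow its tool_use.
anthropic_messages = [
    {"role": "assistant", "content": [
        {"id": "call_1", "type": "tool_use",
         "name": "get_weather", "input": {"city": "Ljubljana"}},
    ]},
    {"role": "user", "content": "What is the weather like?"},
    {"role": "user", "content": [
        {"type": "tool_result", "tool_use_id": "call_1", "content": "sunny"},
    ]},
]

# Index tool_result messages by tool_use_id for O(1) lookup.
tool_result_map = {
    item["content"][0]["tool_use_id"]: item
    for item in anthropic_messages
    if item["role"] == "user"
    and isinstance(item["content"][0], dict)
    and item["content"][0]["type"] == "tool_result"
}

reordered_messages = []
for message in anthropic_messages:
    if message["role"] == "assistant":
        reordered_messages.append(message)
        if isinstance(message["content"], list):
            for item in message["content"]:
                if item["type"] == "tool_use":
                    # Pull the matching tool_result in right behind it.
                    match = tool_result_map.get(item["id"])
                    if match is not None:
                        reordered_messages.append(match)
    elif not (
        message["role"] == "user"
        and isinstance(message["content"][0], dict)
        and message["content"][0]["type"] == "tool_result"
    ):
        reordered_messages.append(message)

# Result: assistant(tool_use call_1), user(tool_result call_1), user(question).
assert reordered_messages[1]["content"][0]["tool_use_id"] == "call_1"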