From acfd4bbee62d0db85a88af3000627582ef6e3188 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 20 Dec 2024 18:55:11 +0100 Subject: [PATCH 1/4] add ChatMessage.from_openai_dict_format --- haystack/dataclasses/chat_message.py | 63 +++++++++++++++++++++ test/dataclasses/test_chat_message.py | 80 +++++++++++++++++++++++++++ 2 files changed, 143 insertions(+) diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index 0a028f101e..1d1e8f4194 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -8,6 +8,10 @@ from enum import Enum from typing import Any, Dict, List, Optional, Sequence, Union +from haystack import logging + +logger = logging.getLogger(__name__) + LEGACY_INIT_PARAMETERS = {"role", "content", "meta", "name"} @@ -426,3 +430,62 @@ def to_openai_dict_format(self) -> Dict[str, Any]: ) openai_msg["tool_calls"] = openai_tool_calls return openai_msg + + @classmethod + def from_openai_dict_format(cls, message: Dict[str, Any]) -> "ChatMessage": + """ + Create a ChatMessage from a dictionary in the format expected by OpenAI's Chat API. + + NOTE: While OpenAI's API requires `tool_call_id` in both tool calls and tool messages, this method + accepts messages without it to support shallow OpenAI-compatible APIs. However, if you plan to use the + resulting ChatMessage with OpenAI, you must include `tool_call_id` or you'll encounter validation errors. + + :param message: + The OpenAI dictionary to build the ChatMessage object. + :returns: + The created ChatMessage object. + + :raises ValueError: + If the message dictionary is missing required fields. + """ + + if "role" not in message: + raise ValueError("The `role` field is required in the message dictionary.") + role = message["role"] + + content = message.get("content") + name = message.get("name") + tool_calls = message.get("tool_calls") + tool_call_id = message.get("tool_call_id") + + if role == "assistant": + if not content and not tool_calls: + raise ValueError("For assistant messages, either `content` or `tool_calls` must be present.") + + haystack_tool_calls = None + if tool_calls: + haystack_tool_calls = [] + for tc in tool_calls: + if "function" not in tc: + raise ValueError("Tool calls must contain the `function` field") + haystack_tc = ToolCall( + id=tc.get("id"), + tool_name=tc["function"]["name"], + arguments=json.loads(tc["function"]["arguments"]), + ) + haystack_tool_calls.append(haystack_tc) + return cls.from_assistant(text=content, name=name, tool_calls=haystack_tool_calls) + + if not content: + raise ValueError(f"The `content` field is required for {role} messages.") + + if role == "user": + return cls.from_user(text=content, name=name) + if role in ["system", "developer"]: + return cls.from_system(text=content, name=name) + if role == "tool": + return cls.from_tool( + tool_result=content, origin=ToolCall(id=tool_call_id, tool_name="", arguments={}), error=False + ) + + raise ValueError(f"Unsupported role: {role}") diff --git a/test/dataclasses/test_chat_message.py b/test/dataclasses/test_chat_message.py index 2209af998f..23a214ca29 100644 --- a/test/dataclasses/test_chat_message.py +++ b/test/dataclasses/test_chat_message.py @@ -288,6 +288,86 @@ def test_to_openai_dict_format_invalid(): message.to_openai_dict_format() +def test_from_openai_dict_format_user_message(): + openai_msg = {"role": "user", "content": "Hello, how are you?", "name": "John"} + message = ChatMessage.from_openai_dict_format(openai_msg) + assert message.role.value == "user" + assert message.text == "Hello, how are you?" + assert message.name == "John" + + +def test_from_openai_dict_format_system_message(): + openai_msg = {"role": "system", "content": "You are a helpful assistant"} + message = ChatMessage.from_openai_dict_format(openai_msg) + assert message.role.value == "system" + assert message.text == "You are a helpful assistant" + + +def test_from_openai_dict_format_assistant_message_with_content(): + openai_msg = {"role": "assistant", "content": "I can help with that"} + message = ChatMessage.from_openai_dict_format(openai_msg) + assert message.role.value == "assistant" + assert message.text == "I can help with that" + + +def test_from_openai_dict_format_assistant_message_with_tool_calls(): + openai_msg = { + "role": "assistant", + "content": None, + "tool_calls": [{"id": "call_123", "function": {"name": "get_weather", "arguments": '{"location": "Berlin"}'}}], + } + message = ChatMessage.from_openai_dict_format(openai_msg) + assert message.role.value == "assistant" + assert message.text is None + assert len(message.tool_calls) == 1 + tool_call = message.tool_calls[0] + assert tool_call.id == "call_123" + assert tool_call.tool_name == "get_weather" + assert tool_call.arguments == {"location": "Berlin"} + + +def test_from_openai_dict_format_tool_message(): + openai_msg = {"role": "tool", "content": "The weather is sunny", "tool_call_id": "call_123"} + message = ChatMessage.from_openai_dict_format(openai_msg) + assert message.role.value == "tool" + assert message.tool_call_result.result == "The weather is sunny" + assert message.tool_call_result.origin.id == "call_123" + + +def test_from_openai_dict_format_tool_without_id(): + openai_msg = {"role": "tool", "content": "The weather is sunny"} + message = ChatMessage.from_openai_dict_format(openai_msg) + assert message.role.value == "tool" + assert message.tool_call_result.result == "The weather is sunny" + assert message.tool_call_result.origin.id is None + + +def test_from_openai_dict_format_missing_role(): + with pytest.raises(ValueError): + ChatMessage.from_openai_dict_format({"content": "test"}) + + +def test_from_openai_dict_format_missing_content(): + with pytest.raises(ValueError): + ChatMessage.from_openai_dict_format({"role": "user"}) + + +def test_from_openai_dict_format_invalid_tool_calls(): + openai_msg = {"role": "assistant", "tool_calls": [{"invalid": "format"}]} + with pytest.raises(ValueError): + ChatMessage.from_openai_dict_format(openai_msg) + + +def test_from_openai_dict_format_unsupported_role(): + with pytest.raises(ValueError): + ChatMessage.from_openai_dict_format({"role": "invalid", "content": "test"}) + + +def test_from_openai_dict_format_assistant_missing_content_and_tool_calls(): + with pytest.raises(ValueError): + ChatMessage.from_openai_dict_format({"role": "assistant", "irrelevant": "irrelevant"}) + + @pytest.mark.integration def test_apply_chat_templating_on_chat_message(): messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")] From 5cda7a01881b2e39b29e01df9936add135cee51e Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 20 Dec 2024 19:01:03 +0100 Subject: [PATCH 2/4] remove print --- haystack/components/generators/chat/openai.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/haystack/components/generators/chat/openai.py b/haystack/components/generators/chat/openai.py index 932fc3345b..09e7d9a1fa 100644 --- a/haystack/components/generators/chat/openai.py +++ b/haystack/components/generators/chat/openai.py @@ -308,10 +308,6 @@ def _prepare_api_call( # noqa: PLR0913 } def _handle_stream_response(self, chat_completion: Stream, callback: StreamingCallbackT) -> List[ChatMessage]: - print("callback") - print(callback) - print("-" * 100) - chunks: List[StreamingChunk] = [] chunk = None From 5acdd106019b75d76d775db86dab1c7aa9829382 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 20 Dec 2024 19:06:08 +0100 Subject: [PATCH 3/4] release note --- .../notes/chatmsg-from-openai-dict-f15b50d38bdf9abb.yaml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 releasenotes/notes/chatmsg-from-openai-dict-f15b50d38bdf9abb.yaml diff --git a/releasenotes/notes/chatmsg-from-openai-dict-f15b50d38bdf9abb.yaml b/releasenotes/notes/chatmsg-from-openai-dict-f15b50d38bdf9abb.yaml new file mode 100644 index 0000000000..c51ec7f0fd --- /dev/null +++ b/releasenotes/notes/chatmsg-from-openai-dict-f15b50d38bdf9abb.yaml @@ -0,0 +1,5 @@ +--- +enhancements: + - | + Add the `from_openai_dict_format` class method to the `ChatMessage` class. It allows you to create a `ChatMessage` + from a dictionary in the format expected by OpenAI's Chat API. From edc68d129c8b2163eb7e4272599e3aebae0ed261 Mon Sep 17 00:00:00 2001 From: anakin87 Date: Fri, 20 Dec 2024 19:12:17 +0100 Subject: [PATCH 4/4] improve docstring --- haystack/dataclasses/chat_message.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/haystack/dataclasses/chat_message.py b/haystack/dataclasses/chat_message.py index 1d1e8f4194..76d1be8a42 100644 --- a/haystack/dataclasses/chat_message.py +++ b/haystack/dataclasses/chat_message.py @@ -437,8 +437,9 @@ def from_openai_dict_format(cls, message: Dict[str, Any]) -> "ChatMessage": Create a ChatMessage from a dictionary in the format expected by OpenAI's Chat API. NOTE: While OpenAI's API requires `tool_call_id` in both tool calls and tool messages, this method - accepts messages without it to support shallow OpenAI-compatible APIs. However, if you plan to use the - resulting ChatMessage with OpenAI, you must include `tool_call_id` or you'll encounter validation errors. + accepts messages without it to support shallow OpenAI-compatible APIs. + If you plan to use the resulting ChatMessage with OpenAI, you must include `tool_call_id` or you'll + encounter validation errors. :param message: The OpenAI dictionary to build the ChatMessage object.