Skip to content

feat: introduce class method to create ChatMessage from the OpenAI dictionary format #8670

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 8 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions haystack/components/generators/chat/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,6 @@ def _prepare_api_call( # noqa: PLR0913
}

def _handle_stream_response(self, chat_completion: Stream, callback: StreamingCallbackT) -> List[ChatMessage]:
print("callback")
print(callback)
print("-" * 100)

chunks: List[StreamingChunk] = []
chunk = None

Expand Down
78 changes: 78 additions & 0 deletions haystack/dataclasses/chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,3 +426,81 @@ def to_openai_dict_format(self) -> Dict[str, Any]:
)
openai_msg["tool_calls"] = openai_tool_calls
return openai_msg

@staticmethod
def _validate_openai_message(message: Dict[str, Any]) -> None:
    """
    Validate that a message dictionary follows OpenAI's Chat API format.

    The following checks are performed:

    - the `role` field is present and is one of `assistant`, `user`, `system`, `developer` or `tool`;
    - assistant messages carry `content`, `tool_calls`, or both;
    - every assistant tool call has a `function` dictionary containing `name` and `arguments`;
    - messages of any other role carry a non-empty `content`.

    :param message: The message dictionary to validate
    :raises ValueError: If the message format is invalid
    """
    if "role" not in message:
        raise ValueError("The `role` field is required in the message dictionary.")

    role = message["role"]
    content = message.get("content")
    tool_calls = message.get("tool_calls")

    if role not in ["assistant", "user", "system", "developer", "tool"]:
        raise ValueError(f"Unsupported role: {role}")

    if role == "assistant":
        if not content and not tool_calls:
            raise ValueError("For assistant messages, either `content` or `tool_calls` must be present.")
        for tc in tool_calls or []:
            if "function" not in tc:
                raise ValueError("Tool calls must contain the `function` field")
            # The conversion code indexes function["name"] and function["arguments"] directly,
            # so reject malformed tool calls here with a ValueError instead of letting a
            # KeyError/TypeError surface later.
            function = tc["function"]
            if not isinstance(function, dict) or "name" not in function or "arguments" not in function:
                raise ValueError("The `function` field of a tool call must contain `name` and `arguments`.")
    elif not content:
        raise ValueError(f"The `content` field is required for {role} messages.")

@classmethod
def from_openai_dict_format(cls, message: Dict[str, Any]) -> "ChatMessage":
    """
    Build a ChatMessage out of a dictionary shaped like an OpenAI Chat API message.

    The dictionary is validated first. Assistant messages may carry text, tool calls, or both;
    tool-call arguments are decoded from their JSON string. `system` and `developer` messages
    both become system messages. Tool messages are wrapped in a tool result whose originating
    ToolCall carries only the `tool_call_id` (empty tool name and arguments).

    NOTE: OpenAI's API requires `tool_call_id` in both tool calls and tool messages, but this method
    tolerates its absence so that less strict OpenAI-compatible APIs are supported.
    If the resulting ChatMessage is later sent to OpenAI, `tool_call_id` must be present or
    validation errors will occur.

    :param message:
        The OpenAI dictionary to build the ChatMessage object.
    :returns:
        The created ChatMessage object.

    :raises ValueError:
        If the message dictionary is missing required fields.
    """
    cls._validate_openai_message(message)

    role = message["role"]
    content = message.get("content")
    name = message.get("name")

    if role == "assistant":
        raw_tool_calls = message.get("tool_calls")
        converted_calls = None
        if raw_tool_calls:
            converted_calls = [
                ToolCall(
                    id=raw.get("id"),
                    tool_name=raw["function"]["name"],
                    arguments=json.loads(raw["function"]["arguments"]),
                )
                for raw in raw_tool_calls
            ]
        return cls.from_assistant(text=content, name=name, tool_calls=converted_calls)

    assert content is not None  # ensured by _validate_openai_message, but we need to make mypy happy

    if role == "user":
        return cls.from_user(text=content, name=name)
    if role in ("system", "developer"):
        return cls.from_system(text=content, name=name)

    # Only the "tool" role is left at this point.
    origin = ToolCall(id=message.get("tool_call_id"), tool_name="", arguments={})
    return cls.from_tool(tool_result=content, origin=origin, error=False)
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
enhancements:
- |
Add the `from_openai_dict_format` class method to the `ChatMessage` class. It allows you to create a `ChatMessage`
from a dictionary in the format expected by OpenAI's Chat API.
80 changes: 80 additions & 0 deletions test/dataclasses/test_chat_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,6 +288,86 @@ def test_to_openai_dict_format_invalid():
message.to_openai_dict_format()


def test_from_openai_dict_format_user_message():
    """A user message keeps its role, text and sender name after conversion."""
    payload = {"role": "user", "content": "Hello, how are you?", "name": "John"}

    converted = ChatMessage.from_openai_dict_format(payload)

    assert converted.role.value == "user"
    assert converted.text == "Hello, how are you?"
    assert converted.name == "John"


def test_from_openai_dict_format_system_message():
    """A system message is converted with its role and text intact."""
    payload = {"role": "system", "content": "You are a helpful assistant"}

    converted = ChatMessage.from_openai_dict_format(payload)

    assert converted.role.value == "system"
    assert converted.text == "You are a helpful assistant"


def test_from_openai_dict_format_assistant_message_with_content():
    """An assistant message that only has text is converted with that text."""
    payload = {"role": "assistant", "content": "I can help with that"}

    converted = ChatMessage.from_openai_dict_format(payload)

    assert converted.role.value == "assistant"
    assert converted.text == "I can help with that"


def test_from_openai_dict_format_assistant_message_with_tool_calls():
    """An assistant message without text but with one tool call yields one parsed tool call."""
    raw_call = {"id": "call_123", "function": {"name": "get_weather", "arguments": '{"location": "Berlin"}'}}
    payload = {"role": "assistant", "content": None, "tool_calls": [raw_call]}

    converted = ChatMessage.from_openai_dict_format(payload)

    assert converted.role.value == "assistant"
    assert converted.text is None
    assert len(converted.tool_calls) == 1

    parsed_call = converted.tool_calls[0]
    assert parsed_call.id == "call_123"
    assert parsed_call.tool_name == "get_weather"
    # The JSON-encoded arguments string must have been decoded into a dict.
    assert parsed_call.arguments == {"location": "Berlin"}


def test_from_openai_dict_format_tool_message():
    """A tool message becomes a tool result linked to the given tool call id."""
    payload = {"role": "tool", "content": "The weather is sunny", "tool_call_id": "call_123"}

    converted = ChatMessage.from_openai_dict_format(payload)

    assert converted.role.value == "tool"
    outcome = converted.tool_call_result
    assert outcome.result == "The weather is sunny"
    assert outcome.origin.id == "call_123"


def test_from_openai_dict_format_tool_without_id():
    """A tool message lacking `tool_call_id` is accepted and its origin id is None."""
    payload = {"role": "tool", "content": "The weather is sunny"}

    converted = ChatMessage.from_openai_dict_format(payload)

    assert converted.role.value == "tool"
    outcome = converted.tool_call_result
    assert outcome.result == "The weather is sunny"
    assert outcome.origin.id is None


def test_from_openai_dict_format_missing_role():
    """A dictionary without `role` is rejected with ValueError."""
    payload = {"content": "test"}
    with pytest.raises(ValueError):
        ChatMessage.from_openai_dict_format(payload)


def test_from_openai_dict_format_missing_content():
    """A user message without `content` is rejected with ValueError."""
    payload = {"role": "user"}
    with pytest.raises(ValueError):
        ChatMessage.from_openai_dict_format(payload)


def test_from_openai_dict_format_invalid_tool_calls():
    """A tool call entry without the `function` field is rejected with ValueError."""
    malformed_call = {"invalid": "format"}
    payload = {"role": "assistant", "tool_calls": [malformed_call]}
    with pytest.raises(ValueError):
        ChatMessage.from_openai_dict_format(payload)


def test_from_openai_dict_format_unsupported_role():
    """A role outside the supported set is rejected with ValueError."""
    payload = {"role": "invalid", "content": "test"}
    with pytest.raises(ValueError):
        ChatMessage.from_openai_dict_format(payload)


def test_from_openai_dict_format_assistant_missing_content_and_tool_calls():
    """An assistant message with neither `content` nor `tool_calls` is rejected with ValueError."""
    payload = {"role": "assistant", "irrelevant": "irrelevant"}
    with pytest.raises(ValueError):
        ChatMessage.from_openai_dict_format(payload)


@pytest.mark.integration
def test_apply_chat_templating_on_chat_message():
messages = [ChatMessage.from_system("You are good assistant"), ChatMessage.from_user("I have a question")]
Expand Down
Loading