Skip to content

Commit

Permalink
feat(drivers-prompt-openai): add audio input/output support to OpenAiC…
Browse files Browse the repository at this point in the history
…hatPromptDriver
  • Loading branch information
collindutter committed Feb 5, 2025
1 parent ce01b8b commit f023372
Show file tree
Hide file tree
Showing 20 changed files with 535 additions and 98 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- Tool streaming support to `OllamaPromptDriver`.
- `DateTimeTool.add_timedelta` and `DateTimeTool.get_datetime_diff` for basic datetime arithmetic.
- Support for `pydantic.BaseModel`s anywhere `schema.Schema` is supported.
- Support for `AudioArtifact` inputs/outputs in `OpenAiChatPromptDriver`.

### Changed

Expand Down
4 changes: 4 additions & 0 deletions griptape/common/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
from .prompt_stack.contents.base_message_content import BaseMessageContent
from .prompt_stack.contents.base_delta_message_content import BaseDeltaMessageContent
from .prompt_stack.contents.text_delta_message_content import TextDeltaMessageContent
from .prompt_stack.contents.audio_delta_message_content import AudioDeltaMessageContent
from .prompt_stack.contents.text_message_content import TextMessageContent
from .prompt_stack.contents.image_message_content import ImageMessageContent
from .prompt_stack.contents.audio_message_content import AudioMessageContent
from .prompt_stack.contents.action_call_delta_message_content import ActionCallDeltaMessageContent
from .prompt_stack.contents.action_call_message_content import ActionCallMessageContent
from .prompt_stack.contents.action_result_message_content import ActionResultMessageContent
Expand All @@ -30,8 +32,10 @@
"DeltaMessage",
"Message",
"TextDeltaMessageContent",
"AudioDeltaMessageContent",
"TextMessageContent",
"ImageMessageContent",
"AudioMessageContent",
"GenericMessageContent",
"ActionCallDeltaMessageContent",
"ActionCallMessageContent",
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from __future__ import annotations

from typing import Optional

from attrs import define, field

from griptape.common import BaseDeltaMessageContent


@define(kw_only=True)
class AudioDeltaMessageContent(BaseDeltaMessageContent):
    """A delta message content for audio data.

    Attributes:
        id: The ID of the audio data.
        data: Base64 encoded audio data.
        transcript: The transcript of the audio data.
        expires_at: The Unix timestamp (in seconds) for when this audio data will no longer be accessible.
    """

    # Every field defaults to None because a streaming delta may carry any subset of them.
    id: Optional[str] = field(default=None, metadata={"serializable": True})
    data: Optional[str] = field(default=None, metadata={"serializable": True})
    transcript: Optional[str] = field(default=None, metadata={"serializable": True})
    expires_at: Optional[int] = field(default=None, metadata={"serializable": True})
43 changes: 43 additions & 0 deletions griptape/common/prompt_stack/contents/audio_message_content.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
from __future__ import annotations

import base64
from typing import TYPE_CHECKING

from attrs import define, field

from griptape.artifacts import AudioArtifact
from griptape.common import (
AudioDeltaMessageContent,
BaseDeltaMessageContent,
BaseMessageContent,
)

if TYPE_CHECKING:
from collections.abc import Sequence


@define
class AudioMessageContent(BaseMessageContent):
    """A message content holding audio data as an `AudioArtifact`.

    Attributes:
        artifact: The audio artifact. Its `meta` carries `audio_id`, `expires_at`, and `transcript`.
    """

    artifact: AudioArtifact = field(metadata={"serializable": True})

    @classmethod
    def from_deltas(cls, deltas: Sequence[BaseDeltaMessageContent]) -> AudioMessageContent:
        """Assemble a complete audio message content from streamed deltas.

        Args:
            deltas: Delta contents from a streamed response; non-audio deltas are ignored.

        Returns:
            An `AudioMessageContent` whose artifact concatenates all decoded audio chunks,
            with the id, expiry, and joined transcript stored in the artifact metadata.
        """
        audio_deltas = [delta for delta in deltas if isinstance(delta, AudioDeltaMessageContent)]
        audio_data = [delta.data for delta in audio_deltas if delta.data is not None]
        transcript_data = [delta.transcript for delta in audio_deltas if delta.transcript is not None]
        # Use a default of None so an incomplete delta stream (e.g. interrupted mid-response)
        # doesn't raise StopIteration when no delta carried an expiry or id.
        expires_at = next((delta.expires_at for delta in audio_deltas if delta.expires_at is not None), None)
        audio_id = next((delta.id for delta in audio_deltas if delta.id is not None), None)

        artifact = AudioArtifact(
            value=b"".join(base64.b64decode(data) for data in audio_data),
            # NOTE(review): format is hard-coded; assumes upstream decodes/wraps to wav -- TODO confirm
            format="wav",
            meta={
                "audio_id": audio_id,
                "expires_at": expires_at,
                "transcript": "".join(transcript_data),
            },
        )

        return cls(artifact=artifact)
4 changes: 4 additions & 0 deletions griptape/common/prompt_stack/prompt_stack.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from griptape.artifacts import (
ActionArtifact,
AudioArtifact,
BaseArtifact,
GenericArtifact,
ImageArtifact,
Expand All @@ -17,6 +18,7 @@
from griptape.common import (
ActionCallMessageContent,
ActionResultMessageContent,
AudioMessageContent,
BaseMessageContent,
GenericMessageContent,
ImageMessageContent,
Expand Down Expand Up @@ -91,6 +93,8 @@ def __to_message_content(self, artifact: str | BaseArtifact) -> list[BaseMessage
return [TextMessageContent(artifact)]
elif isinstance(artifact, ImageArtifact):
return [ImageMessageContent(artifact)]
elif isinstance(artifact, AudioArtifact):
return [AudioMessageContent(artifact)]
elif isinstance(artifact, GenericArtifact):
return [GenericMessageContent(artifact)]
elif isinstance(artifact, ActionArtifact):
Expand Down
8 changes: 8 additions & 0 deletions griptape/drivers/prompt/base_prompt_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
from griptape.common import (
ActionCallDeltaMessageContent,
ActionCallMessageContent,
AudioDeltaMessageContent,
AudioMessageContent,
BaseDeltaMessageContent,
DeltaMessage,
Message,
Expand All @@ -19,6 +21,7 @@
)
from griptape.events import (
ActionChunkEvent,
AudioChunkEvent,
EventBus,
FinishPromptEvent,
StartPromptEvent,
Expand Down Expand Up @@ -177,6 +180,8 @@ def __process_stream(self, prompt_stack: PromptStack) -> Message:
delta_contents[content.index] = [content]
if isinstance(content, TextDeltaMessageContent):
EventBus.publish_event(TextChunkEvent(token=content.text, index=content.index))
elif isinstance(content, AudioDeltaMessageContent) and content.data is not None:
EventBus.publish_event(AudioChunkEvent(data=content.data))
elif isinstance(content, ActionCallDeltaMessageContent):
EventBus.publish_event(
ActionChunkEvent(
Expand All @@ -197,10 +202,13 @@ def __build_message(
content = []
for delta_content in delta_contents:
text_deltas = [delta for delta in delta_content if isinstance(delta, TextDeltaMessageContent)]
audio_deltas = [delta for delta in delta_content if isinstance(delta, AudioDeltaMessageContent)]
action_deltas = [delta for delta in delta_content if isinstance(delta, ActionCallDeltaMessageContent)]

if text_deltas:
content.append(TextMessageContent.from_deltas(text_deltas))
if audio_deltas:
content.append(AudioMessageContent.from_deltas(audio_deltas))
if action_deltas:
content.append(ActionCallMessageContent.from_deltas(action_deltas))

Expand Down
107 changes: 81 additions & 26 deletions griptape/drivers/prompt/openai_chat_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,21 @@
from __future__ import annotations

import base64
import json
import logging
import time
from typing import TYPE_CHECKING, Optional

import openai
from attrs import Factory, define, field

from griptape.artifacts import ActionArtifact, TextArtifact
from griptape.artifacts import ActionArtifact, AudioArtifact, TextArtifact
from griptape.common import (
ActionCallDeltaMessageContent,
ActionCallMessageContent,
ActionResultMessageContent,
AudioDeltaMessageContent,
AudioMessageContent,
BaseDeltaMessageContent,
BaseMessageContent,
DeltaMessage,
Expand Down Expand Up @@ -93,6 +97,10 @@ class OpenAiChatPromptDriver(BasePromptDriver):
),
kw_only=True,
)
modalities: list[str] = field(default=Factory(lambda: ["text"]), kw_only=True, metadata={"serializable": True})
audio: dict = field(
default=Factory(lambda: {"voice": "alloy", "format": "pcm16"}), kw_only=True, metadata={"serializable": True}
)
_client: openai.OpenAI = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})

@lazy_property()
Expand Down Expand Up @@ -143,14 +151,19 @@ def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
choice = chunk.choices[0]
delta = choice.delta

yield DeltaMessage(content=self.__to_prompt_stack_delta_message_content(delta))
content = self.__to_prompt_stack_delta_message_content(delta)

if content is not None:
yield DeltaMessage(content=content)

def _base_params(self, prompt_stack: PromptStack) -> dict:
params = {
"model": self.model,
"temperature": self.temperature,
"user": self.user,
"seed": self.seed,
"modalities": self.modalities,
"audio": self.audio,
**({"stop": self.tokenizer.stop_sequences} if self.tokenizer.stop_sequences else {}),
**({"max_tokens": self.max_tokens} if self.max_tokens is not None else {}),
**({"stream_options": {"include_usage": True}} if self.stream else {}),
Expand Down Expand Up @@ -196,44 +209,47 @@ def __to_openai_messages(self, messages: list[Message]) -> list[dict]:

for message in messages:
# If the message only contains textual content we can send it as a single content.
if message.is_text():
if message.has_all_content_type(TextMessageContent):
openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
# Action results must be sent as separate messages.
elif message.has_any_content_type(ActionResultMessageContent):
elif action_result_contents := message.get_content_type(ActionResultMessageContent):
openai_messages.extend(
{
"role": self.__to_openai_role(message, action_result),
"content": self.__to_openai_message_content(action_result),
"tool_call_id": action_result.action.tag,
"role": self.__to_openai_role(message, action_result_content),
"content": self.__to_openai_message_content(action_result_content),
"tool_call_id": action_result_content.action.tag,
}
for action_result in message.get_content_type(ActionResultMessageContent)
for action_result_content in action_result_contents
)

if message.has_any_content_type(TextMessageContent):
openai_messages.append({"role": self.__to_openai_role(message), "content": message.to_text()})
else:
openai_message = {
"role": self.__to_openai_role(message),
"content": [
self.__to_openai_message_content(content)
for content in [
content for content in message.content if not isinstance(content, ActionCallMessageContent)
]
],
"content": [],
}

for content in message.content:
if isinstance(content, ActionCallMessageContent):
if "tool_calls" not in openai_message:
openai_message["tool_calls"] = []
openai_message["tool_calls"].append(self.__to_openai_message_content(content))
elif (
isinstance(content, AudioMessageContent)
and message.is_assistant()
and time.time() < content.artifact.meta.get("expires_at", float("inf"))
):
openai_message["audio"] = {
"id": content.artifact.meta["audio_id"],
}
else:
openai_message["content"].append(self.__to_openai_message_content(content))

# Some OpenAi-compatible services don't accept an empty array for content
if not openai_message["content"]:
openai_message["content"] = ""

# Action calls must be attached to the message, not sent as content.
action_call_content = [
content for content in message.content if isinstance(content, ActionCallMessageContent)
]
if action_call_content:
openai_message["tool_calls"] = [
self.__to_openai_message_content(action_call) for action_call in action_call_content
]

openai_messages.append(openai_message)

return openai_messages
Expand Down Expand Up @@ -271,6 +287,23 @@ def __to_openai_message_content(self, content: BaseMessageContent) -> str | dict
"type": "image_url",
"image_url": {"url": f"data:{content.artifact.mime_type};base64,{content.artifact.base64}"},
}
elif isinstance(content, AudioMessageContent):
artifact = content.artifact

# We can't send the audio if it's expired.
if int(time.time()) > artifact.meta.get("expires_at", float("inf")):
return {
"type": "text",
"text": artifact.meta.get("transcript"),
}
else:
return {
"type": "input_audio",
"input_audio": {
"data": base64.b64encode(artifact.value).decode("utf-8"),
"format": artifact.format,
},
}
elif isinstance(content, ActionCallMessageContent):
action = content.artifact.value

Expand All @@ -289,6 +322,20 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->

if response.content is not None:
content.append(TextMessageContent(TextArtifact(response.content)))
if response.audio is not None:
content.append(
AudioMessageContent(
AudioArtifact(
value=base64.b64decode(response.audio.data),
format="wav",
meta={
"audio_id": response.audio.id,
"transcript": response.audio.transcript,
"expires_at": response.audio.expires_at,
},
)
)
)
if response.tool_calls is not None:
content.extend(
[
Expand All @@ -308,7 +355,7 @@ def __to_prompt_stack_message_content(self, response: ChatCompletionMessage) ->

return content

def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> BaseDeltaMessageContent:
def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) -> Optional[BaseDeltaMessageContent]:
if content_delta.content is not None:
return TextDeltaMessageContent(content_delta.content)
elif content_delta.tool_calls is not None:
Expand All @@ -333,5 +380,13 @@ def __to_prompt_stack_delta_message_content(self, content_delta: ChoiceDelta) ->
raise ValueError(f"Unsupported tool call delta: {tool_call}")
else:
raise ValueError(f"Unsupported tool call delta length: {len(tool_calls)}")
else:
return TextDeltaMessageContent("")
# OpenAi doesn't have types for audio deltas so we need to use hasattr and getattr.
elif hasattr(content_delta, "audio") and getattr(content_delta, "audio") is not None:
audio_chunk: dict = getattr(content_delta, "audio")
return AudioDeltaMessageContent(
id=audio_chunk.get("id"),
data=audio_chunk.get("data"),
expires_at=audio_chunk.get("expires_at"),
transcript=audio_chunk.get("transcript"),
)
return None
2 changes: 2 additions & 0 deletions griptape/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from .finish_structure_run_event import FinishStructureRunEvent
from .base_chunk_event import BaseChunkEvent
from .text_chunk_event import TextChunkEvent
from .audio_chunk_event import AudioChunkEvent
from .action_chunk_event import ActionChunkEvent
from .event_listener import EventListener
from .start_image_generation_event import StartImageGenerationEvent
Expand Down Expand Up @@ -41,6 +42,7 @@
"FinishStructureRunEvent",
"BaseChunkEvent",
"TextChunkEvent",
"AudioChunkEvent",
"ActionChunkEvent",
"EventListener",
"StartImageGenerationEvent",
Expand Down
17 changes: 17 additions & 0 deletions griptape/events/audio_chunk_event.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
from attrs import define, field

from griptape.events.base_chunk_event import BaseChunkEvent


@define
class AudioChunkEvent(BaseChunkEvent):
    """Stores a chunk of audio data.

    Attributes:
        data: Base64 encoded audio data.
    """

    data: str = field(kw_only=True, metadata={"serializable": True})

    def __str__(self) -> str:
        # The string form is the raw base64 payload so the event can be streamed/printed directly.
        return self.data
Loading

0 comments on commit f023372

Please sign in to comment.