diff --git a/docs/griptape-framework/drivers/prompt-drivers.md b/docs/griptape-framework/drivers/prompt-drivers.md index 4b2daf128..b6d3d1988 100644 --- a/docs/griptape-framework/drivers/prompt-drivers.md +++ b/docs/griptape-framework/drivers/prompt-drivers.md @@ -98,6 +98,14 @@ This driver uses [Azure OpenAi function calling](https://learn.microsoft.com/en- --8<-- "docs/griptape-framework/drivers/src/prompt_drivers_5.py" ``` +### Griptape Cloud + +The [GriptapeCloudPromptDriver](../../reference/griptape/drivers/prompt/griptape_cloud_prompt_driver.md) connects to the [Griptape Cloud](https://www.griptape.ai/cloud) chat messages API. + +```python +--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_griptape_cloud.py" +``` + ### Cohere The [CoherePromptDriver](../../reference/griptape/drivers/prompt/cohere_prompt_driver.md) connects to the Cohere [Chat](https://docs.cohere.com/docs/chat-api) API. diff --git a/docs/griptape-framework/drivers/src/prompt_drivers_griptape_cloud.py b/docs/griptape-framework/drivers/src/prompt_drivers_griptape_cloud.py new file mode 100644 index 000000000..6f04edb17 --- /dev/null +++ b/docs/griptape-framework/drivers/src/prompt_drivers_griptape_cloud.py @@ -0,0 +1,19 @@ +import os + +from griptape.drivers.prompt.openai import OpenAiChatPromptDriver +from griptape.rules import Rule +from griptape.structures import Agent + +agent = Agent( + prompt_driver=OpenAiChatPromptDriver( + api_key=os.environ["GT_CLOUD_API_KEY"], + model="gpt-4o", + ), + rules=[ + Rule( + "You will be provided with a product description and seed words, and your task is to generate product names.", + ), + ], +) + +agent.run("Product description: A home milkshake maker. Seed words: fast, healthy, compact.") diff --git a/griptape/drivers/__init__.py b/griptape/drivers/__init__.py index fb15f8ff3..d646e623f 100644 --- a/griptape/drivers/__init__.py +++ b/griptape/drivers/__init__.py @@ -13,6 +13,7 @@ from .prompt.google import GooglePromptDriver from .prompt.dummy import DummyPromptDriver from .prompt.ollama import OllamaPromptDriver +from .prompt.griptape_cloud import GriptapeCloudPromptDriver from .memory.conversation import BaseConversationMemoryDriver from .memory.conversation.local import LocalConversationMemoryDriver @@ -139,6 +140,7 @@ "GooglePromptDriver", "DummyPromptDriver", "OllamaPromptDriver", + "GriptapeCloudPromptDriver", "BaseConversationMemoryDriver", "LocalConversationMemoryDriver", "AmazonDynamoDbConversationMemoryDriver", diff --git a/griptape/drivers/prompt/griptape_cloud/__init__.py b/griptape/drivers/prompt/griptape_cloud/__init__.py new file mode 100644 index 000000000..022104964 --- /dev/null +++ b/griptape/drivers/prompt/griptape_cloud/__init__.py @@ -0,0 +1,3 @@ +from griptape.drivers.prompt.griptape_cloud_prompt_driver import GriptapeCloudPromptDriver + +__all__ = ["GriptapeCloudPromptDriver"] diff --git a/griptape/drivers/prompt/griptape_cloud_prompt_driver.py b/griptape/drivers/prompt/griptape_cloud_prompt_driver.py new file mode 100644 index 000000000..97cae8494 --- /dev/null +++ b/griptape/drivers/prompt/griptape_cloud_prompt_driver.py @@ -0,0 +1,109 @@ +from __future__ import annotations + +import logging +import os +from typing import TYPE_CHECKING +from urllib.parse import urljoin + +import requests +from attrs import Factory, define, field + +from griptape.common import DeltaMessage, Message, PromptStack, observable +from griptape.configs.defaults_config import Defaults +from griptape.drivers.prompt import BasePromptDriver +from griptape.tokenizers import BaseTokenizer, SimpleTokenizer + +if TYPE_CHECKING: + from collections.abc import Iterator + + from griptape.tools.base_tool import BaseTool + + +logger = logging.getLogger(Defaults.logging_config.logger_name) + + +@define +class GriptapeCloudPromptDriver(BasePromptDriver): + base_url: str = field( + default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), + ) + ## temp for testing, no auth + api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"])) + headers: dict = field( + default=Factory( + lambda self: {"Authorization": f"Bearer {self.api_key}"} if self.api_key is not None else {}, + takes_self=True, + ), + kw_only=True, + ) + ## end temp no auth + tokenizer: BaseTokenizer = field( + default=Factory( + lambda self: SimpleTokenizer( + characters_per_token=4, + max_input_tokens=2000, + max_output_tokens=self.max_tokens, + ), + takes_self=True, + ), + kw_only=True, + ) + + @observable + def try_run(self, prompt_stack: PromptStack) -> Message: + url = urljoin(self.base_url.strip("/"), "/api/chat/messages") + + params = self._base_params(prompt_stack) + logger.debug(params) + response = requests.post(url, headers=self.headers, json=params) + response.raise_for_status() + response_json = response.json() + logger.debug(response_json) + + return Message.from_dict(response_json) + + @observable + def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: + url = urljoin(self.base_url.strip("/"), "/api/chat/messages/stream") + params = self._base_params(prompt_stack) + logger.debug(params) + with requests.post(url, headers=self.headers, json=params, stream=True) as response: + response.raise_for_status() + for line in response.iter_lines(): + if line: + decoded_line = line.decode("utf-8") + if decoded_line.startswith("data:"): + delta_message_payload = decoded_line.removeprefix("data:").strip() + logger.debug(delta_message_payload) + yield DeltaMessage.from_json(delta_message_payload) + + def _base_params(self, prompt_stack: PromptStack) -> dict: + return { + "prompt_stack": prompt_stack.to_dict(), + "model": self.model, + "tools": self.__to_griptape_tools__(prompt_stack.tools), + **({"output_schema": prompt_stack.to_output_json_schema()} if prompt_stack.output_schema else {}), + "params": { + "max_tokens": self.max_tokens, + "use_native_tools": self.use_native_tools, + "temperature": self.temperature, + "structured_output_strategy": self.structured_output_strategy, + "extra_params": self.extra_params, + }, + } + + def __to_griptape_tools__(self, tools: list[BaseTool]) -> list[dict]: + return [ + { + "name": tool.name, + "activities": [ + { + "name": activity.__name__, + "description": tool.activity_description(activity), + "json_schema": tool.to_activity_json_schema(activity, "Schema"), + } + for activity in tool.activities() + ], + } + for tool in tools + ] diff --git a/tests/unit/drivers/prompt/test_griptape_cloud_prompt_driver.py b/tests/unit/drivers/prompt/test_griptape_cloud_prompt_driver.py new file mode 100644 index 000000000..b61dec696 --- /dev/null +++ b/tests/unit/drivers/prompt/test_griptape_cloud_prompt_driver.py @@ -0,0 +1,445 @@ +import json +import time +from unittest.mock import ANY, MagicMock + +import pytest +from schema import Schema + +from griptape.artifacts import ActionArtifact, AudioArtifact, ImageArtifact, ListArtifact, TextArtifact +from griptape.common import ActionCallDeltaMessageContent, PromptStack, TextDeltaMessageContent, ToolAction +from griptape.common.prompt_stack.contents.audio_delta_message_content import AudioDeltaMessageContent +from griptape.drivers.prompt.griptape_cloud import GriptapeCloudPromptDriver +from tests.mocks.mock_tool.tool import MockTool + + +class TestGriptapeCloudPromptDriver: + GRIPTAPE_CLOUD_STRUCTURED_OUTPUT_SCHEMA = { + "$id": "Output Format", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {"foo": {"type": "string"}}, + "required": ["foo"], + "type": "object", + } + GRIPTAPE_CLOUD_TOOLS = [ + { + "activities": [ + { + "description": "test description: foo", + "json_schema": { + "$id": "Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "description": "Test input", + "properties": { + "test": {"type": "string"}, + }, + "required": ["test"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + "name": "test", + }, + { + "description": "test description", + "json_schema": { + "$id": "Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "description": "Test input", + "properties": { + "test": {"type": "string"}, + }, + "required": ["test"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + "name": "test_callable_schema", + }, + { + "description": "test description: foo", + "json_schema": { + "$id": "Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "description": "Test input", + "properties": { + "test": {"type": "string"}, + }, + "required": ["test"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + "name": "test_error", + }, + { + "description": "test description: foo", + "json_schema": { + "$id": "Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "description": "Test input", + "properties": { + "test": {"type": "string"}, + }, + "required": ["test"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + "name": "test_exception", + }, + { + "description": "test description", + "json_schema": { + "$id": "Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {}, + "required": [], + "type": "object", + }, + "name": "test_list_output", + }, + { + "description": "test description", + "json_schema": { + "$id": "Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": {}, + "required": [], + "type": "object", + }, + "name": "test_no_schema", + }, + { + "description": "test description: foo", + "json_schema": { + "$id": "Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "description": "Test input", + "properties": { + "test": {"type": "string"}, + }, + "required": ["test"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + "name": "test_str_output", + }, + { + "description": "test description", + "json_schema": { + "$id": "Schema", + "$schema": "http://json-schema.org/draft-07/schema#", + "additionalProperties": False, + "properties": { + "values": { + "additionalProperties": False, + "description": "Test input", + "properties": { + "test": {"type": "string"}, + }, + "required": ["test"], + "type": "object", + }, + }, + "required": ["values"], + "type": "object", + }, + "name": "test_without_default_memory", + }, + ], + "name": "MockTool", + }, + ] + + @pytest.fixture(autouse=True) + def mock_post(self, mocker): + def request(*args, **kwargs): + mock_response = mocker.Mock() + if "chat/messages/stream" in args[0]: + mock_response.iter_lines.return_value = [ + f"data: {json.dumps(event)}".encode() + for event in [ + { + "type": "DeltaMessage", + "content": {"type": "TextDeltaMessageContent", "text": "model-output"}, + "role": "assistant", + }, + { + "type": "DeltaMessage", + "content": { + "type": "ActionCallDeltaMessageContent", + "tag": "MockTool_test", + "name": "MockTool", + "path": "test", + "partial_input": json.dumps({"foo": "bar"}), + "index": 0, + }, + "role": "assistant", + }, + { + "type": "DeltaMessage", + "content": { + "type": "ActionCallDeltaMessageContent", + "tag": "MockTool_test", + "name": "MockTool", + "path": "test", + "partial_input": json.dumps({"foo": "bar"}), + "index": 1, + }, + "role": "assistant", + }, + { + "type": "DeltaMessage", + "content": {"type": "AudioDeltaMessageContent", "data": "YXNzaXN0YW50LWF1ZGlvLWRhdGE="}, + "role": "assistant", + }, + { + "type": "DeltaMessage", + "content": { + "type": "AudioDeltaMessageContent", + "expires_at": time.time() + 1000, + "transcript": "assistant-audio-transcription", + }, + "role": "assistant", + }, + ] + ] + + mock_response.__enter__ = MagicMock(return_value=mock_response) + mock_response.__exit__ = MagicMock() + return mock_response + elif "chat/messages" in args[0]: + mock_response.json.return_value = { + "role": "assistant", + "content": [ + { + "type": "TextMessageContent", + "artifact": {"type": "TextArtifact", "value": "text-model-output"}, + }, + { + "type": "ActionCallMessageContent", + "artifact": { + "type": "ActionArtifact", + "value": { + "type": "ToolAction", + "tag": "MockTool_test", + "name": "MockTool", + "path": "test", + "input": {"foo": "bar"}, + }, + }, + }, + { + "type": "AudioMessageContent", + "artifact": { + "type": "AudioArtifact", + "value": b"YXVkaW8tbW9kZWwtb3V0cHV0", + "format": "wav", + }, + }, + ], + } + return mock_response + else: + return mocker.Mock( + raise_for_status=lambda: None, + ) + + return mocker.patch("requests.post", side_effect=request) + + @pytest.fixture() + def prompt_stack(self): + prompt_stack = PromptStack() + prompt_stack.output_schema = Schema({"foo": str}) + prompt_stack.tools = [MockTool()] + prompt_stack.add_system_message("system-input") + prompt_stack.add_user_message("user-input") + prompt_stack.add_user_message( + ListArtifact( + [TextArtifact("user-input"), ImageArtifact(value=b"image-data", format="png", width=100, height=100)] + ) + ) + prompt_stack.add_assistant_message("assistant-input") + prompt_stack.add_assistant_message( + ListArtifact( + [ + TextArtifact(""), + ActionArtifact(ToolAction(tag="MockTool_test", name="MockTool", path="test", input={"foo": "bar"})), + ] + ) + ) + prompt_stack.add_user_message( + ListArtifact( + [ + TextArtifact("keep-going"), + ActionArtifact( + ToolAction( + tag="MockTool_test", + name="MockTool", + path="test", + input={"foo": "bar"}, + output=TextArtifact("tool-output"), + ) + ), + ] + ) + ) + return prompt_stack + + def test_init(self): + assert GriptapeCloudPromptDriver(api_key="foo", model="gpt-4o") + + @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"]) + def test_try_run( + self, + mock_post, + prompt_stack, + use_native_tools, + structured_output_strategy, + ): + # Given + driver = GriptapeCloudPromptDriver( + api_key="foo", + model="gpt-4o", + use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, + extra_params={"foo": "bar"}, + ) + + # When + message = driver.try_run(prompt_stack) + + # Then + mock_post.assert_called_once_with( + "https://cloud.griptape.ai/api/chat/messages", + headers={"Authorization": f"Bearer {driver.api_key}"}, + json={ + "prompt_stack": prompt_stack.to_dict(), + "model": driver.model, + "tools": self.GRIPTAPE_CLOUD_TOOLS, + "output_schema": self.GRIPTAPE_CLOUD_STRUCTURED_OUTPUT_SCHEMA, + "params": { + "max_tokens": driver.max_tokens, + "use_native_tools": use_native_tools, + "temperature": driver.temperature, + "structured_output_strategy": structured_output_strategy, + "extra_params": {"foo": "bar"}, + }, + }, + ) + assert isinstance(message.value[0], TextArtifact) + assert message.value[0].value == "text-model-output" + assert isinstance(message.value[1], ActionArtifact) + assert message.value[1].value.tag == "MockTool_test" + assert message.value[1].value.name == "MockTool" + assert message.value[1].value.path == "test" + assert message.value[1].value.input == {"foo": "bar"} + + assert isinstance(message.value[2], AudioArtifact) + assert message.value[2].value == b"audio-model-output" + assert message.value[2].format == "wav" + + @pytest.mark.parametrize("use_native_tools", [True, False]) + @pytest.mark.parametrize("structured_output_strategy", ["native", "tool", "rule", "foo"]) + def test_try_stream_run( + self, + mock_post, + prompt_stack, + use_native_tools, + structured_output_strategy, + ): + # Given + driver = GriptapeCloudPromptDriver( + api_key="foo", + model="gpt-4o", + stream=True, + use_native_tools=use_native_tools, + structured_output_strategy=structured_output_strategy, + extra_params={"foo": "bar"}, + ) + + # When + stream = driver.try_stream(prompt_stack) + event = next(stream) + + # Then + mock_post.assert_called_once_with( + "https://cloud.griptape.ai/api/chat/messages/stream", + headers={"Authorization": f"Bearer {driver.api_key}"}, + json={ + "prompt_stack": prompt_stack.to_dict(), + "model": driver.model, + "tools": self.GRIPTAPE_CLOUD_TOOLS, + "output_schema": self.GRIPTAPE_CLOUD_STRUCTURED_OUTPUT_SCHEMA, + "params": { + "max_tokens": driver.max_tokens, + "use_native_tools": use_native_tools, + "temperature": driver.temperature, + "structured_output_strategy": structured_output_strategy, + "extra_params": {"foo": "bar"}, + }, + }, + stream=True, + ) + assert isinstance(event.content, TextDeltaMessageContent) + assert event.content.text == "model-output" + + event = next(stream) + assert isinstance(event.content, ActionCallDeltaMessageContent) + assert event.content.index == 0 + assert event.content.tag == "MockTool_test" + assert event.content.name == "MockTool" + assert event.content.path == "test" + assert event.content.partial_input == json.dumps({"foo": "bar"}) + + event = next(stream) + assert isinstance(event.content, ActionCallDeltaMessageContent) + assert event.content.index == 1 + assert event.content.tag == "MockTool_test" + assert event.content.name == "MockTool" + assert event.content.path == "test" + assert event.content.partial_input == json.dumps({"foo": "bar"}) + + event = next(stream) + assert isinstance(event.content, AudioDeltaMessageContent) + assert event.content.data == "YXNzaXN0YW50LWF1ZGlvLWRhdGE=" + + event = next(stream) + assert isinstance(event.content, AudioDeltaMessageContent) + assert event.content.expires_at == ANY + assert event.content.transcript == "assistant-audio-transcription"