Skip to content

Commit

Permalink
feat(drivers): add GriptapeCloudPromptDriver
Browse files Browse the repository at this point in the history
  • Loading branch information
collindutter committed Feb 19, 2025
1 parent 89c32d4 commit 542878e
Show file tree
Hide file tree
Showing 6 changed files with 586 additions and 0 deletions.
8 changes: 8 additions & 0 deletions docs/griptape-framework/drivers/prompt-drivers.md
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,14 @@ This driver uses [Azure OpenAi function calling](https://learn.microsoft.com/en-
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_5.py"
```

### Griptape Cloud

The [GriptapeCloudPromptDriver](../../reference/griptape/drivers/prompt/griptape_cloud_prompt_driver.md) connects to the [Griptape Cloud](https://www.griptape.ai/cloud) chat messages API.

```python
--8<-- "docs/griptape-framework/drivers/src/prompt_drivers_griptape_cloud.py"
```

### Cohere

The [CoherePromptDriver](../../reference/griptape/drivers/prompt/cohere_prompt_driver.md) connects to the Cohere [Chat](https://docs.cohere.com/docs/chat-api) API.
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
import os

from griptape.drivers.prompt.openai import OpenAiChatPromptDriver
from griptape.rules import Rule
from griptape.structures import Agent

agent = Agent(
prompt_driver=OpenAiChatPromptDriver(
api_key=os.environ["GT_CLOUD_API_KEY"],
model="gpt-4o",
),
rules=[
Rule(
"You will be provided with a product description and seed words, and your task is to generate product names.",
),
],
)

agent.run("Product description: A home milkshake maker. Seed words: fast, healthy, compact.")
2 changes: 2 additions & 0 deletions griptape/drivers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from .prompt.google import GooglePromptDriver
from .prompt.dummy import DummyPromptDriver
from .prompt.ollama import OllamaPromptDriver
from .prompt.griptape_cloud import GriptapeCloudPromptDriver

from .memory.conversation import BaseConversationMemoryDriver
from .memory.conversation.local import LocalConversationMemoryDriver
Expand Down Expand Up @@ -139,6 +140,7 @@
"GooglePromptDriver",
"DummyPromptDriver",
"OllamaPromptDriver",
"GriptapeCloudPromptDriver",
"BaseConversationMemoryDriver",
"LocalConversationMemoryDriver",
"AmazonDynamoDbConversationMemoryDriver",
Expand Down
3 changes: 3 additions & 0 deletions griptape/drivers/prompt/griptape_cloud/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
from griptape.drivers.prompt.griptape_cloud_prompt_driver import GriptapeCloudPromptDriver

__all__ = ["GriptapeCloudPromptDriver"]
109 changes: 109 additions & 0 deletions griptape/drivers/prompt/griptape_cloud_prompt_driver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
from __future__ import annotations

import logging
import os
from typing import TYPE_CHECKING
from urllib.parse import urljoin

import requests
from attrs import Factory, define, field

from griptape.common import DeltaMessage, Message, PromptStack, observable
from griptape.configs.defaults_config import Defaults
from griptape.drivers.prompt import BasePromptDriver
from griptape.tokenizers import BaseTokenizer, SimpleTokenizer

if TYPE_CHECKING:
from collections.abc import Iterator

from griptape.tools.base_tool import BaseTool


logger = logging.getLogger(Defaults.logging_config.logger_name)


@define
class GriptapeCloudPromptDriver(BasePromptDriver):
base_url: str = field(
default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")),
)
## temp for testing, no auth
api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"]))
headers: dict = field(
default=Factory(
lambda self: {"Authorization": f"Bearer {self.api_key}"} if self.api_key is not None else {},
takes_self=True,
),
kw_only=True,
)
## end temp no auth
tokenizer: BaseTokenizer = field(
default=Factory(
lambda self: SimpleTokenizer(
characters_per_token=4,
max_input_tokens=2000,
max_output_tokens=self.max_tokens,
),
takes_self=True,
),
kw_only=True,
)

@observable
def try_run(self, prompt_stack: PromptStack) -> Message:
url = urljoin(self.base_url.strip("/"), "/api/chat/messages")

params = self._base_params(prompt_stack)
logger.debug(params)
response = requests.post(url, headers=self.headers, json=params)
response.raise_for_status()
response_json = response.json()
logger.debug(response_json)

return Message.from_dict(response_json)

@observable
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
url = urljoin(self.base_url.strip("/"), "/api/chat/messages/stream")
params = self._base_params(prompt_stack)
logger.debug(params)
with requests.post(url, headers=self.headers, json=params, stream=True) as response:
response.raise_for_status()
for line in response.iter_lines():
if line:
decoded_line = line.decode("utf-8")
if decoded_line.startswith("data:"):
delta_message_payload = decoded_line.removeprefix("data:").strip()
logger.debug(delta_message_payload)
yield DeltaMessage.from_json(delta_message_payload)

def _base_params(self, prompt_stack: PromptStack) -> dict:
return {
"prompt_stack": prompt_stack.to_dict(),
"model": self.model,
"tools": self.__to_griptape_tools__(prompt_stack.tools),
**({"output_schema": prompt_stack.to_output_json_schema()} if prompt_stack.output_schema else {}),
"params": {
"max_tokens": self.max_tokens,
"use_native_tools": self.use_native_tools,
"temperature": self.temperature,
"structured_output_strategy": self.structured_output_strategy,
"extra_params": self.extra_params,
},
}

def __to_griptape_tools__(self, tools: list[BaseTool]) -> list[dict]:
return [
{
"name": tool.name,
"activities": [
{
"name": activity.__name__,
"description": tool.activity_description(activity),
"json_schema": tool.to_activity_json_schema(activity, "Schema"),
}
for activity in tool.activities()
],
}
for tool in tools
]
Loading

0 comments on commit 542878e

Please sign in to comment.