-
Notifications
You must be signed in to change notification settings - Fork 189
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(drivers): add GriptapeCloudPromptDriver
- Loading branch information
1 parent
89c32d4
commit 542878e
Showing
6 changed files
with
586 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
19 changes: 19 additions & 0 deletions
19
docs/griptape-framework/drivers/src/prompt_drivers_griptape_cloud.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
import os | ||
|
||
from griptape.drivers.prompt.openai import OpenAiChatPromptDriver | ||
from griptape.rules import Rule | ||
from griptape.structures import Agent | ||
|
||
agent = Agent( | ||
prompt_driver=OpenAiChatPromptDriver( | ||
api_key=os.environ["GT_CLOUD_API_KEY"], | ||
model="gpt-4o", | ||
), | ||
rules=[ | ||
Rule( | ||
"You will be provided with a product description and seed words, and your task is to generate product names.", | ||
), | ||
], | ||
) | ||
|
||
agent.run("Product description: A home milkshake maker. Seed words: fast, healthy, compact.") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,3 @@ | ||
from griptape.drivers.prompt.griptape_cloud_prompt_driver import GriptapeCloudPromptDriver | ||
|
||
__all__ = ["GriptapeCloudPromptDriver"] |
109 changes: 109 additions & 0 deletions
109
griptape/drivers/prompt/griptape_cloud_prompt_driver.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
from __future__ import annotations | ||
|
||
import logging | ||
import os | ||
from typing import TYPE_CHECKING | ||
from urllib.parse import urljoin | ||
|
||
import requests | ||
from attrs import Factory, define, field | ||
|
||
from griptape.common import DeltaMessage, Message, PromptStack, observable | ||
from griptape.configs.defaults_config import Defaults | ||
from griptape.drivers.prompt import BasePromptDriver | ||
from griptape.tokenizers import BaseTokenizer, SimpleTokenizer | ||
|
||
if TYPE_CHECKING: | ||
from collections.abc import Iterator | ||
|
||
from griptape.tools.base_tool import BaseTool | ||
|
||
|
||
logger = logging.getLogger(Defaults.logging_config.logger_name) | ||
|
||
|
||
@define | ||
class GriptapeCloudPromptDriver(BasePromptDriver): | ||
base_url: str = field( | ||
default=Factory(lambda: os.getenv("GT_CLOUD_BASE_URL", "https://cloud.griptape.ai")), | ||
) | ||
## temp for testing, no auth | ||
api_key: str = field(default=Factory(lambda: os.environ["GT_CLOUD_API_KEY"])) | ||
headers: dict = field( | ||
default=Factory( | ||
lambda self: {"Authorization": f"Bearer {self.api_key}"} if self.api_key is not None else {}, | ||
takes_self=True, | ||
), | ||
kw_only=True, | ||
) | ||
## end temp no auth | ||
tokenizer: BaseTokenizer = field( | ||
default=Factory( | ||
lambda self: SimpleTokenizer( | ||
characters_per_token=4, | ||
max_input_tokens=2000, | ||
max_output_tokens=self.max_tokens, | ||
), | ||
takes_self=True, | ||
), | ||
kw_only=True, | ||
) | ||
|
||
@observable | ||
def try_run(self, prompt_stack: PromptStack) -> Message: | ||
url = urljoin(self.base_url.strip("/"), "/api/chat/messages") | ||
|
||
params = self._base_params(prompt_stack) | ||
logger.debug(params) | ||
response = requests.post(url, headers=self.headers, json=params) | ||
response.raise_for_status() | ||
response_json = response.json() | ||
logger.debug(response_json) | ||
|
||
return Message.from_dict(response_json) | ||
|
||
@observable | ||
def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]: | ||
url = urljoin(self.base_url.strip("/"), "/api/chat/messages/stream") | ||
params = self._base_params(prompt_stack) | ||
logger.debug(params) | ||
with requests.post(url, headers=self.headers, json=params, stream=True) as response: | ||
response.raise_for_status() | ||
for line in response.iter_lines(): | ||
if line: | ||
decoded_line = line.decode("utf-8") | ||
if decoded_line.startswith("data:"): | ||
delta_message_payload = decoded_line.removeprefix("data:").strip() | ||
logger.debug(delta_message_payload) | ||
yield DeltaMessage.from_json(delta_message_payload) | ||
|
||
def _base_params(self, prompt_stack: PromptStack) -> dict: | ||
return { | ||
"prompt_stack": prompt_stack.to_dict(), | ||
"model": self.model, | ||
"tools": self.__to_griptape_tools__(prompt_stack.tools), | ||
**({"output_schema": prompt_stack.to_output_json_schema()} if prompt_stack.output_schema else {}), | ||
"params": { | ||
"max_tokens": self.max_tokens, | ||
"use_native_tools": self.use_native_tools, | ||
"temperature": self.temperature, | ||
"structured_output_strategy": self.structured_output_strategy, | ||
"extra_params": self.extra_params, | ||
}, | ||
} | ||
|
||
def __to_griptape_tools__(self, tools: list[BaseTool]) -> list[dict]: | ||
return [ | ||
{ | ||
"name": tool.name, | ||
"activities": [ | ||
{ | ||
"name": activity.__name__, | ||
"description": tool.activity_description(activity), | ||
"json_schema": tool.to_activity_json_schema(activity, "Schema"), | ||
} | ||
for activity in tool.activities() | ||
], | ||
} | ||
for tool in tools | ||
] |
Oops, something went wrong.