Showing 6 changed files with 293 additions and 16 deletions.
This file was deleted.
@@ -0,0 +1,2 @@
from textgrad.engine_experimental.openai import OpenAIEngine
from textgrad.engine_experimental.litellm import LiteLLMEngine
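This two-line hunk re-exports both experimental engines so downstream code can import them from one place. A minimal import sketch, assuming the hunk belongs to the package initializer (the diff does not name the file):

    from textgrad.engine_experimental import OpenAIEngine, LiteLLMEngine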
@@ -0,0 +1,82 @@
from functools import wraps
from abc import ABC, abstractmethod
import hashlib
from typing import List, Union
import diskcache as dc
import platformdirs
import os


def cached(func):
    """Cache generation results on the engine's disk cache, keyed by the call's args and kwargs."""
    @wraps(func)
    def wrapper(self, *args, **kwargs):

        if self.cache is False:
            return func(self, *args, **kwargs)

        # get string representation from args and kwargs
        key = hash(str(args) + str(kwargs))
        key = hashlib.sha256(f"{key}".encode()).hexdigest()

        if key in self.cache:
            return self.cache[key]

        result = func(self, *args, **kwargs)
        self.cache[key] = result
        return result

    return wrapper


class EngineLM(ABC):
    """Abstract base class for experimental LM engines with optional disk caching."""

    system_prompt: str = "You are a helpful, creative, and smart assistant."
    model_string: str
    is_multimodal: bool
    cache: Union[dc.Cache, bool]

    def __init__(self, model_string: str,
                 system_prompt: str = "You are a helpful, creative, and smart assistant.",
                 is_multimodal: bool = False,
                 cache: Union[dc.Cache, bool] = True):

        root = platformdirs.user_cache_dir("textgrad")
        default_cache_path = os.path.join(root, f"cache_model_{model_string}.db")

        self.model_string = model_string
        self.system_prompt = system_prompt
        self.is_multimodal = is_multimodal

        if isinstance(cache, dc.Cache):
            self.cache = cache
        elif cache is True:
            self.cache = dc.Cache(default_cache_path)
        elif cache is False:
            self.cache = False
        else:
            raise ValueError("Cache argument must be a diskcache.Cache object or a boolean.")

    @abstractmethod
    def _generate_from_multiple_input(self, content, system_prompt=None, **kwargs) -> str:
        pass

    @abstractmethod
    def _generate_from_single_prompt(self, content, system_prompt=None, **kwargs) -> str:
        pass

    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt: str = None, **kwargs):
        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt

        if isinstance(content, str):
            return self._generate_from_single_prompt(content=content, system_prompt=sys_prompt_arg, **kwargs)

        elif isinstance(content, list):
            has_multimodal_input = any(isinstance(item, bytes) for item in content)
            if has_multimodal_input and not self.is_multimodal:
                raise NotImplementedError("Multimodal generation flag is not set, but multimodal input is provided. "
                                          "Is this model multimodal?")

            return self._generate_from_multiple_input(content=content, system_prompt=sys_prompt_arg, **kwargs)

    def __call__(self, *args, **kwargs):
        # Subclasses override __call__ to delegate to generate().
        pass
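To show how the cached decorator and the EngineLM base class fit together, here is a minimal sketch of a concrete subclass. EchoEngine and its trivial responses are hypothetical and only illustrate the contract the abstract methods impose:

    from textgrad.engine_experimental.base import EngineLM, cached

    class EchoEngine(EngineLM):
        # Hypothetical engine that just echoes the prompt; used only to show the API shape.
        def __init__(self, cache: bool = True):
            super().__init__(model_string="echo", is_multimodal=False, cache=cache)

        @cached
        def _generate_from_single_prompt(self, content, system_prompt=None, **kwargs) -> str:
            return f"[{system_prompt}] {content}"

        @cached
        def _generate_from_multiple_input(self, content, system_prompt=None, **kwargs) -> str:
            return " ".join(str(item) for item in content)

    engine = EchoEngine(cache=False)   # cache=False bypasses diskcache entirely
    print(engine.generate("hello"))    # str input dispatches to _generate_from_single_prompt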
@@ -0,0 +1,42 @@
from typing import List, Union
import base64


def is_jpeg(data):
    jpeg_signature = b'\xFF\xD8\xFF'
    return data.startswith(jpeg_signature)


def is_png(data):
    png_signature = b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'
    return data.startswith(png_signature)


def get_image_type_from_bytes(data):
    if is_jpeg(data):
        return "jpeg"
    elif is_png(data):
        return "png"
    else:
        raise ValueError("Image type not supported, only jpeg and png supported.")


def open_ai_like_formatting(content: List[Union[str, bytes]]) -> List[dict]:
    """Helper function to format a list of strings and bytes into a list of dictionaries to pass as messages to the API."""
    formatted_content = []
    for item in content:
        if isinstance(item, bytes):
            # For now, bytes are assumed to be images
            image_type = get_image_type_from_bytes(item)
            base64_image = base64.b64encode(item).decode('utf-8')
            formatted_content.append({
                "type": "image_url",
                "image_url": {
                    "url": f"data:image/{image_type};base64,{base64_image}"
                }
            })
        elif isinstance(item, str):
            formatted_content.append({
                "type": "text",
                "text": item
            })
        else:
            raise ValueError(f"Unsupported input type: {type(item)}")
    return formatted_content
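A short sketch of what open_ai_like_formatting produces for mixed input; the byte string below carries only the PNG signature (not a real image), which is enough for the type check:

    from textgrad.engine_experimental.engine_utils import open_ai_like_formatting

    png_bytes = b'\x89PNG\r\n\x1a\n' + b'\x00' * 8
    parts = open_ai_like_formatting(["Describe this image.", png_bytes])
    # parts[0] -> {"type": "text", "text": "Describe this image."}
    # parts[1] -> {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}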
@@ -0,0 +1,61 @@
from litellm import completion
from textgrad.engine_experimental.base import EngineLM, cached
import diskcache as dc
from typing import Union, List
from .engine_utils import open_ai_like_formatting
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)


class LiteLLMEngine(EngineLM):
    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."

    def __init__(self,
                 model_string: str,
                 system_prompt: str = DEFAULT_SYSTEM_PROMPT,
                 is_multimodal: bool = True,
                 cache: Union[dc.Cache, bool] = True):

        super().__init__(
            model_string=model_string,
            system_prompt=system_prompt,
            is_multimodal=is_multimodal,
            cache=cache
        )

    def lite_llm_generate(self, content, system_prompt=None, **kwargs) -> str:
        messages = [
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": content},
        ]

        return completion(model=self.model_string,
                          messages=messages)['choices'][0]['message']['content']

    @cached
    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(3))
    def _generate_from_single_prompt(
        self, content: str, system_prompt: str = None, temperature=0, max_tokens=2000, top_p=0.99
    ):
        # temperature, max_tokens and top_p are accepted for interface parity but not yet forwarded to completion().
        return self.lite_llm_generate(content, system_prompt)

    @cached
    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(3))
    def _generate_from_multiple_input(
        self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
    ):
        formatted_content = open_ai_like_formatting(content)

        return self.lite_llm_generate(formatted_content, system_prompt)

    def __call__(self, content, **kwargs):
        return self.generate(content, **kwargs)
@@ -0,0 +1,106 @@
try:
    from openai import OpenAI
except ImportError:
    raise ImportError("If you'd like to use OpenAI models, please install the openai package by running `pip install openai`, and add 'OPENAI_API_KEY' to your environment variables.")

import os
from typing import List, Union
from textgrad.engine_experimental.engine_utils import open_ai_like_formatting
from textgrad.engine_experimental.base import EngineLM, cached
import diskcache as dc
from tenacity import (
    retry,
    stop_after_attempt,
    wait_random_exponential,
)


class OpenAIEngine(EngineLM):
    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."

    def __init__(self, model_string: str,
                 system_prompt: str = DEFAULT_SYSTEM_PROMPT,
                 is_multimodal: bool = False,
                 cache: Union[dc.Cache, bool] = True):

        self.validate()

        super().__init__(
            model_string=model_string,
            system_prompt=system_prompt,
            is_multimodal=is_multimodal,
            cache=cache
        )

        # Only create a default client if a subclass (e.g. OpenAICompatibleEngine)
        # has not already injected one.
        if getattr(self, "client", None) is None:
            self.client = OpenAI(
                api_key=os.getenv("OPENAI_API_KEY")
            )

    def validate(self) -> None:
        if os.getenv("OPENAI_API_KEY") is None:
            raise ValueError(
                "Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")

    def openai_call(self, user_content, system_prompt, temperature, max_tokens, top_p):
        response = self.client.chat.completions.create(
            model=self.model_string,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_content},
            ],
            frequency_penalty=0,
            presence_penalty=0,
            stop=None,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
        )

        return response.choices[0].message.content

    @cached
    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(3))
    def _generate_from_single_prompt(
        self, content: str, system_prompt: str = None, temperature=0, max_tokens=2000, top_p=0.99
    ):

        return self.openai_call(content, system_prompt, temperature, max_tokens, top_p)

    @cached
    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(3))
    def _generate_from_multiple_input(
        self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
    ):
        formatted_content = open_ai_like_formatting(content)

        return self.openai_call(formatted_content, system_prompt, temperature, max_tokens, top_p)

    def __call__(self, content, **kwargs):
        return self.generate(content, **kwargs)


class OpenAICompatibleEngine(OpenAIEngine):
    """
    This is the same as engine.openai.ChatOpenAI, but we pass in an external OpenAI client.
    """

    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
    client = None

    def __init__(self,
                 client,
                 model_string: str,
                 system_prompt: str = DEFAULT_SYSTEM_PROMPT,
                 is_multimodal: bool = False,
                 cache: Union[dc.Cache, bool] = True):

        # Set the injected client before calling the parent constructor so it is not replaced.
        self.client = client

        super().__init__(
            model_string=model_string,
            system_prompt=system_prompt,
            is_multimodal=is_multimodal,
            cache=cache
        )