diff --git a/textgrad/engine/engine_utils.py b/textgrad/engine/engine_utils.py
deleted file mode 100644
index 51caff9..0000000
--- a/textgrad/engine/engine_utils.py
+++ /dev/null
@@ -1,16 +0,0 @@
-def is_jpeg(data):
-    jpeg_signature = b'\xFF\xD8\xFF'
-    return data.startswith(jpeg_signature)
-
-def is_png(data):
-    png_signature = b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'
-    return data.startswith(png_signature)
-
-
-def get_image_type_from_bytes(data):
-    if is_jpeg(data):
-        return "jpeg"
-    elif is_png(data):
-        return "png"
-    else:
-        raise ValueError("Image type not supported, only jpeg and png supported.")
\ No newline at end of file
diff --git a/textgrad/engine_experimental/__init__.py b/textgrad/engine_experimental/__init__.py
new file mode 100644
index 0000000..abe0de9
--- /dev/null
+++ b/textgrad/engine_experimental/__init__.py
@@ -0,0 +1,2 @@
+from textgrad.engine_experimental.openai import OpenAIEngine
+from textgrad.engine_experimental.litellm import LiteLLMEngine
\ No newline at end of file
diff --git a/textgrad/engine_experimental/base.py b/textgrad/engine_experimental/base.py
new file mode 100644
index 0000000..3203a85
--- /dev/null
+++ b/textgrad/engine_experimental/base.py
@@ -0,0 +1,82 @@
+from functools import wraps
+from abc import ABC, abstractmethod
+import hashlib
+from typing import List, Union
+import diskcache as dc
+import platformdirs
+import os
+
+
+def cached(func):
+    """Cache generation results on disk, keyed by the call's arguments."""
+    @wraps(func)
+    def wrapper(self, *args, **kwargs):
+        if self.cache is False:
+            return func(self, *args, **kwargs)
+
+        # Build a deterministic cache key from the string representation of args and kwargs.
+        key = hashlib.sha256((str(args) + str(kwargs)).encode()).hexdigest()
+
+        if key in self.cache:
+            return self.cache[key]
+
+        result = func(self, *args, **kwargs)
+        self.cache[key] = result
+        return result
+
+    return wrapper
+
+
+class EngineLM(ABC):
+    """Base class for the experimental engines, with optional on-disk response caching."""
+
+    system_prompt: str = "You are a helpful, creative, and smart assistant."
+    model_string: str
+    is_multimodal: bool
+    cache: Union[dc.Cache, bool]
+
+    def __init__(self, model_string: str,
+                 system_prompt: str = "You are a helpful, creative, and smart assistant.",
+                 is_multimodal: bool = False,
+                 cache: Union[dc.Cache, bool] = True):
+
+        root = platformdirs.user_cache_dir("textgrad")
+        default_cache_path = os.path.join(root, f"cache_model_{model_string}.db")
+
+        self.model_string = model_string
+        self.system_prompt = system_prompt
+        self.is_multimodal = is_multimodal
+
+        if isinstance(cache, dc.Cache):
+            self.cache = cache
+        elif cache is True:
+            self.cache = dc.Cache(default_cache_path)
+        elif cache is False:
+            self.cache = False
+        else:
+            raise ValueError("Cache argument must be a diskcache.Cache object or a boolean.")
+
+    @abstractmethod
+    def _generate_from_multiple_input(self, prompt, system_prompt=None, **kwargs) -> str:
+        pass
+
+    @abstractmethod
+    def _generate_from_single_prompt(self, prompt, system_prompt=None, **kwargs) -> str:
+        pass
+
+    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt: str = None, **kwargs):
+        sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
+
+        if isinstance(content, str):
+            return self._generate_from_single_prompt(content=content, system_prompt=sys_prompt_arg, **kwargs)
+
+        elif isinstance(content, list):
+            has_multimodal_input = any(isinstance(item, bytes) for item in content)
+            if has_multimodal_input and not self.is_multimodal:
+                raise NotImplementedError("Multimodal generation flag is not set, but multimodal input is provided. "
+                                          "Is this model multimodal?")
+
+            return self._generate_from_multiple_input(content=content, system_prompt=sys_prompt_arg, **kwargs)
+
+    def __call__(self, *args, **kwargs):
+        return self.generate(*args, **kwargs)
diff --git a/textgrad/engine_experimental/engine_utils.py b/textgrad/engine_experimental/engine_utils.py
new file mode 100644
index 0000000..9e6b4de
--- /dev/null
+++ b/textgrad/engine_experimental/engine_utils.py
@@ -0,0 +1,42 @@
+from typing import List, Union
+import base64
+
+def is_jpeg(data):
+    jpeg_signature = b'\xFF\xD8\xFF'
+    return data.startswith(jpeg_signature)
+
+def is_png(data):
+    png_signature = b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'
+    return data.startswith(png_signature)
+
+def get_image_type_from_bytes(data):
+    if is_jpeg(data):
+        return "jpeg"
+    elif is_png(data):
+        return "png"
+    else:
+        raise ValueError("Image type not supported, only jpeg and png supported.")
+
+def open_ai_like_formatting(content: List[Union[str, bytes]]) -> List[dict]:
+    """Format a list of strings and image bytes into the message-content parts expected by OpenAI-like APIs."""
+    formatted_content = []
+    for item in content:
+        if isinstance(item, bytes):
+            # For now, bytes are assumed to be images
+            image_type = get_image_type_from_bytes(item)
+            base64_image = base64.b64encode(item).decode('utf-8')
+            formatted_content.append({
+                "type": "image_url",
+                "image_url": {
+                    "url": f"data:image/{image_type};base64,{base64_image}"
+                }
+            })
+        elif isinstance(item, str):
+            formatted_content.append({
+                "type": "text",
+                "text": item
+            })
+        else:
+            raise ValueError(f"Unsupported input type: {type(item)}")
+    return formatted_content
\ No newline at end of file
diff --git a/textgrad/engine_experimental/litellm.py b/textgrad/engine_experimental/litellm.py
new file mode 100644
index 0000000..48b4c10
--- /dev/null
+++ b/textgrad/engine_experimental/litellm.py
@@ -0,0 +1,61 @@
+from litellm import completion
+from textgrad.engine_experimental.base import EngineLM, cached
+import diskcache as dc
+from typing import Union, List
+from .engine_utils import open_ai_like_formatting
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_random_exponential,
+)
+
+class LiteLLMEngine(EngineLM):
+    def lite_llm_generate(self, content, system_prompt=None, **kwargs) -> str:
+        messages = [
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": content},
+        ]
+
+        return completion(model=self.model_string,
+                          messages=messages)['choices'][0]['message']['content']
+
+    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
+
+    def __init__(self,
+                 model_string: str,
+                 system_prompt: str = DEFAULT_SYSTEM_PROMPT,
+                 is_multimodal: bool = True,
+                 cache: Union[dc.Cache, bool] = True):
+
+        super().__init__(
+            model_string=model_string,
+            system_prompt=system_prompt,
+            is_multimodal=is_multimodal,
+            cache=cache
+        )
+
+    @cached
+    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(3))
+    def _generate_from_single_prompt(
+            self, content: str, system_prompt: str = None, temperature=0, max_tokens=2000, top_p=0.99
+    ):
+
+        return self.lite_llm_generate(content, system_prompt)
+
+    @cached
+    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(3))
+    def _generate_from_multiple_input(
+            self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
+    ):
+        formatted_content = open_ai_like_formatting(content)
+
+        return self.lite_llm_generate(formatted_content, system_prompt)
+
+    def __call__(self, content, **kwargs):
+        return self.generate(content, **kwargs)
diff --git a/textgrad/engine_experimental/openai.py b/textgrad/engine_experimental/openai.py
new file mode 100644
index 0000000..2fea80f
--- /dev/null
+++ b/textgrad/engine_experimental/openai.py
@@ -0,0 +1,106 @@
+try:
+    from openai import OpenAI
+except ImportError:
+    raise ImportError("If you'd like to use OpenAI models, please install the openai package by running `pip install openai`, and add 'OPENAI_API_KEY' to your environment variables.")
+
+import os
+from typing import List, Union
+from textgrad.engine_experimental.engine_utils import open_ai_like_formatting
+from textgrad.engine_experimental.base import EngineLM, cached
+import diskcache as dc
+from tenacity import (
+    retry,
+    stop_after_attempt,
+    wait_random_exponential,
+)
+
+class OpenAIEngine(EngineLM):
+    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
+
+    def __init__(self, model_string: str,
+                 system_prompt: str = DEFAULT_SYSTEM_PROMPT,
+                 is_multimodal: bool = False,
+                 cache: Union[dc.Cache, bool] = True):
+
+        self.validate()
+
+        super().__init__(
+            model_string=model_string,
+            system_prompt=system_prompt,
+            is_multimodal=is_multimodal,
+            cache=cache
+        )
+
+        self.client = OpenAI(
+            api_key=os.getenv("OPENAI_API_KEY")
+        )
+
+    def validate(self) -> None:
+        if os.getenv("OPENAI_API_KEY") is None:
+            raise ValueError(
+                "Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")
+
+    def openai_call(self, user_content, system_prompt, temperature, max_tokens, top_p):
+        response = self.client.chat.completions.create(
+            model=self.model_string,
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_content},
+            ],
+            frequency_penalty=0,
+            presence_penalty=0,
+            stop=None,
+            temperature=temperature,
+            max_tokens=max_tokens,
+            top_p=top_p,
+        )
+
+        return response.choices[0].message.content
+
+    @cached
+    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(3))
+    def _generate_from_single_prompt(
+            self, content: str, system_prompt: str = None, temperature=0, max_tokens=2000, top_p=0.99
+    ):
+
+        return self.openai_call(content, system_prompt, temperature, max_tokens, top_p)
+
+    @cached
+    @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(3))
+    def _generate_from_multiple_input(
+            self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
+    ):
+        formatted_content = open_ai_like_formatting(content)
+
+        return self.openai_call(formatted_content, system_prompt, temperature, max_tokens, top_p)
+
+    def __call__(self, content, **kwargs):
+        return self.generate(content, **kwargs)
+
+
+class OpenAICompatibleEngine(OpenAIEngine):
+    """
+    The same as OpenAIEngine, but an externally constructed OpenAI client is passed in and used for the API calls.
+    """
+
+    DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
+    client = None
+
+    def __init__(self,
+                 client,
+                 model_string: str,
+                 system_prompt: str = DEFAULT_SYSTEM_PROMPT,
+                 is_multimodal: bool = False,
+                 cache: Union[dc.Cache, bool] = True):
+
+        super().__init__(
+            model_string=model_string,
+            system_prompt=system_prompt,
+            is_multimodal=is_multimodal,
+            cache=cache
+        )
+
+        # Assign after super().__init__ so the external client is not overwritten
+        # by the default client that OpenAIEngine.__init__ creates.
+        self.client = client
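For context, a minimal usage sketch of the new engine_experimental API is below. It assumes the `openai`, `litellm`, and `diskcache` dependencies are installed and that `OPENAI_API_KEY` is set; the model string `gpt-4o` and the file `photo.png` are illustrative placeholders only, not values prescribed by this change.

```python
from textgrad.engine_experimental import LiteLLMEngine, OpenAIEngine

# Text-only generation. Disk caching is on by default; cache=False turns it off.
engine = LiteLLMEngine(model_string="gpt-4o", cache=False)
print(engine("What is the capital of France?"))

# Multimodal generation: pass a list mixing text and raw image bytes (jpeg/png only).
with open("photo.png", "rb") as f:
    image_bytes = f.read()

vision_engine = OpenAIEngine(model_string="gpt-4o", is_multimodal=True, cache=False)
print(vision_engine(["Describe this image.", image_bytes]))
```

Passing a `diskcache.Cache` instance instead of a boolean for `cache` lets several engines share a single cache location.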