Commit: new engines
vinid committed Sep 8, 2024
1 parent 467aa09 commit 6fd2a36
Showing 6 changed files with 293 additions and 16 deletions.
16 changes: 0 additions & 16 deletions textgrad/engine/engine_utils.py

This file was deleted.

2 changes: 2 additions & 0 deletions textgrad/engine_experimental/__init__.py
@@ -0,0 +1,2 @@
from textgrad.engine_experimental.openai import OpenAIEngine
from textgrad.engine_experimental.litellm import LiteLLMEngine
82 changes: 82 additions & 0 deletions textgrad/engine_experimental/base.py
@@ -0,0 +1,82 @@
from functools import wraps
from abc import ABC, abstractmethod
import hashlib
from typing import List, Union
import diskcache as dc
import platformdirs
import os



def cached(func):
@wraps(func)
def wrapper(self, *args, **kwargs):

if self.cache is False:
return func(self, *args, **kwargs)

        # Build a deterministic cache key from the string representation of args and kwargs;
        # the built-in hash() is salted per process, which would defeat a persistent disk cache.
        key = hashlib.sha256((str(args) + str(kwargs)).encode()).hexdigest()

if key in self.cache:
return self.cache[key]

result = func(self, *args, **kwargs)
self.cache[key] = result
return result

return wrapper


class EngineLM(ABC):
system_prompt: str = "You are a helpful, creative, and smart assistant."
model_string: str
is_multimodal: bool
cache: Union[dc.Cache, bool]

    def __init__(self, model_string: str,
                 system_prompt: str = "You are a helpful, creative, and smart assistant.",
                 is_multimodal: bool = False,
                 cache: Union[dc.Cache, bool] = False):

root = platformdirs.user_cache_dir("textgrad")
default_cache_path = os.path.join(root, f"cache_model_{model_string}.db")

self.model_string = model_string
self.system_prompt = system_prompt
self.is_multimodal = is_multimodal

if isinstance(cache, dc.Cache):
self.cache = cache
elif cache is True:
self.cache = dc.Cache(default_cache_path)
elif cache is False:
self.cache = False
else:
raise ValueError("Cache argument must be a diskcache.Cache object or a boolean.")

    # Implementations must name the first parameter `content`, since generate() passes it by keyword.
    @abstractmethod
    def _generate_from_multiple_input(self, content, system_prompt=None, **kwargs) -> str:
        pass

    @abstractmethod
    def _generate_from_single_prompt(self, content, system_prompt=None, **kwargs) -> str:
        pass

    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt: str = None, **kwargs):
sys_prompt_arg = system_prompt if system_prompt else self.system_prompt

if isinstance(content, str):
return self._generate_from_single_prompt(content=content, system_prompt=sys_prompt_arg, **kwargs)

elif isinstance(content, list):
has_multimodal_input = any(isinstance(item, bytes) for item in content)
if has_multimodal_input and not self.is_multimodal:
raise NotImplementedError("Multimodal generation flag is not set, but multimodal input is provided. "
"Is this model multimodal?")

return self._generate_from_multiple_input(content=content, system_prompt=sys_prompt_arg, **kwargs)

    def __call__(self, *args, **kwargs):
        # Engines are callable; by default a call is routed through generate().
        return self.generate(*args, **kwargs)
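For orientation, here is a minimal sketch of how a concrete engine would plug into this base class. The `EchoEngine` name and its trivial method bodies are hypothetical illustrations, not part of this commit:

# Hypothetical sketch (not part of this commit): a minimal concrete engine built on EngineLM.
from textgrad.engine_experimental.base import EngineLM, cached

class EchoEngine(EngineLM):
    def __init__(self, cache=False):
        # Pass cache=False to skip disk caching, or a diskcache.Cache instance to persist responses.
        super().__init__(model_string="echo", is_multimodal=False, cache=cache)

    @cached
    def _generate_from_single_prompt(self, content, system_prompt=None, **kwargs):
        # A real engine would call a model API here; this one just echoes its input.
        return f"[{system_prompt}] {content}"

    def _generate_from_multiple_input(self, content, system_prompt=None, **kwargs):
        return self._generate_from_single_prompt(" ".join(str(c) for c in content), system_prompt=system_prompt)

engine = EchoEngine(cache=False)
print(engine.generate("hello"))  # "[You are a helpful, creative, and smart assistant.] hello"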
42 changes: 42 additions & 0 deletions textgrad/engine_experimental/engine_utils.py
@@ -0,0 +1,42 @@
from typing import List, Union
import base64

def is_jpeg(data):
jpeg_signature = b'\xFF\xD8\xFF'
return data.startswith(jpeg_signature)

def is_png(data):
png_signature = b'\x89\x50\x4E\x47\x0D\x0A\x1A\x0A'
return data.startswith(png_signature)

def get_image_type_from_bytes(data):
if is_jpeg(data):
return "jpeg"
elif is_png(data):
return "png"
else:
raise ValueError("Image type not supported, only jpeg and png supported.")

def open_ai_like_formatting(content: List[Union[str, bytes]]) -> List[dict]:
"""Helper function to format a list of strings and bytes into a list of dictionaries to pass as messages to the API.
"""
formatted_content = []
for item in content:
if isinstance(item, bytes):
# For now, bytes are assumed to be images
image_type = get_image_type_from_bytes(item)
base64_image = base64.b64encode(item).decode('utf-8')
formatted_content.append({
"type": "image_url",
"image_url": {
"url": f"data:image/{image_type};base64,{base64_image}"
}
})
elif isinstance(item, str):
formatted_content.append({
"type": "text",
"text": item
})
else:
raise ValueError(f"Unsupported input type: {type(item)}")
return formatted_content
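A quick illustration of what the helper returns; the file name below is a placeholder and not part of the commit:

# Hypothetical usage sketch of the formatting helper.
from textgrad.engine_experimental.engine_utils import open_ai_like_formatting

with open("chart.png", "rb") as f:  # assumed local PNG file, for illustration only
    image_bytes = f.read()

parts = open_ai_like_formatting(["What does this chart show?", image_bytes])
# parts is now:
# [{"type": "text", "text": "What does this chart show?"},
#  {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}]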
61 changes: 61 additions & 0 deletions textgrad/engine_experimental/litellm.py
@@ -0,0 +1,61 @@
from litellm import completion
from textgrad.engine_experimental.base import EngineLM, cached
import diskcache as dc
from typing import Union, List
from .engine_utils import open_ai_like_formatting
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
)

class LiteLLMEngine(EngineLM):
def lite_llm_generate(self, content, system_prompt=None, **kwargs) -> str:
messages = [
{"role": "system", "content": system_prompt},
{"role": "user", "content": content},
]

return completion(model=self.model_string,
messages=messages)['choices'][0]['message']['content']

DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."

    def __init__(self,
                 model_string: str,
                 system_prompt: str = DEFAULT_SYSTEM_PROMPT,
                 is_multimodal: bool = True,
                 cache: Union[dc.Cache, bool] = False):

super().__init__(
model_string=model_string,
system_prompt=system_prompt,
is_multimodal=is_multimodal,
cache=cache
)

@cached
@retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(3))
def _generate_from_single_prompt(
self, content: str, system_prompt: str = None, temperature=0, max_tokens=2000, top_p=0.99
):

return self.lite_llm_generate(content, system_prompt)

@cached
@retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(3))
def _generate_from_multiple_input(
self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
):
formatted_content = open_ai_like_formatting(content)

return self.lite_llm_generate(formatted_content, system_prompt)

def __call__(self, content, **kwargs):
return self.generate(content, **kwargs)
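A short usage sketch, not part of the commit; the model string is an assumption, and litellm is expected to pick up the matching provider key (for example OPENAI_API_KEY) from the environment:

# Hypothetical usage sketch of LiteLLMEngine.
from textgrad.engine_experimental.litellm import LiteLLMEngine

engine = LiteLLMEngine(model_string="gpt-4o-mini", cache=False)  # any litellm-supported model string
print(engine("Summarize gradient descent in one sentence."))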






106 changes: 106 additions & 0 deletions textgrad/engine_experimental/openai.py
@@ -0,0 +1,106 @@
try:
from openai import OpenAI
except ImportError:
raise ImportError("If you'd like to use OpenAI models, please install the openai package by running `pip install openai`, and add 'OPENAI_API_KEY' to your environment variables.")

import os
from typing import List, Union
from textgrad.engine_experimental.engine_utils import open_ai_like_formatting
from textgrad.engine_experimental.base import EngineLM, cached
import diskcache as dc
from tenacity import (
retry,
stop_after_attempt,
wait_random_exponential,
)

class OpenAIEngine(EngineLM):
DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."

    def __init__(self, model_string: str,
                 system_prompt: str = DEFAULT_SYSTEM_PROMPT,
                 is_multimodal: bool = False,
                 cache: Union[dc.Cache, bool] = False):

        # A subclass (e.g. OpenAICompatibleEngine) may have attached its own client already;
        # only check for an API key and build a default client when none was provided.
        if getattr(self, "client", None) is None:
            self.validate()
            self.client = OpenAI(
                api_key=os.getenv("OPENAI_API_KEY")
            )

        super().__init__(
            model_string=model_string,
            system_prompt=system_prompt,
            is_multimodal=is_multimodal,
            cache=cache
        )

def validate(self) -> None:
if os.getenv("OPENAI_API_KEY") is None:
raise ValueError(
"Please set the OPENAI_API_KEY environment variable if you'd like to use OpenAI models.")

def openai_call(self, user_content, system_prompt, temperature, max_tokens, top_p):
response = self.client.chat.completions.create(
model=self.model_string,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_content},
],
frequency_penalty=0,
presence_penalty=0,
stop=None,
temperature=temperature,
max_tokens=max_tokens,
top_p=top_p,
)

return response.choices[0].message.content

@cached
@retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(3))
def _generate_from_single_prompt(
self, content: str, system_prompt: str = None, temperature=0, max_tokens=2000, top_p=0.99
):

return self.openai_call(content, system_prompt, temperature, max_tokens, top_p)

@cached
@retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(3))
def _generate_from_multiple_input(
self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
):
formatted_content = open_ai_like_formatting(content)

return self.openai_call(formatted_content, system_prompt, temperature, max_tokens, top_p)

def __call__(self, content, **kwargs):
return self.generate(content, **kwargs)



class OpenAICompatibleEngine(OpenAIEngine):
"""
This is the same as engine.openai.ChatOpenAI, but we pass in an external OpenAI client.
"""

DEFAULT_SYSTEM_PROMPT = "You are a helpful, creative, and smart assistant."
client = None

    def __init__(self,
                 client,
                 model_string: str,
                 system_prompt: str = DEFAULT_SYSTEM_PROMPT,
                 is_multimodal: bool = False,
                 cache: Union[dc.Cache, bool] = False):

self.client = client

super().__init__(
model_string=model_string,
system_prompt=system_prompt,
is_multimodal=is_multimodal,
cache=cache
)
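A minimal sketch of reusing this engine with an externally constructed client, assuming a local OpenAI-compatible server; the URL, key, and model name are placeholders, not part of the commit:

# Hypothetical sketch: injecting an external client into OpenAICompatibleEngine.
from openai import OpenAI
from textgrad.engine_experimental.openai import OpenAICompatibleEngine

client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")  # assumed local OpenAI-compatible server
engine = OpenAICompatibleEngine(client=client, model_string="my-local-model", cache=False)
print(engine("Hello!"))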

