
Commit d09c713
better function names, a bit more typing
mertyg committed Jul 7, 2024
1 parent 6565dee commit d09c713
Showing 3 changed files with 32 additions and 27 deletions.
1 change: 1 addition & 0 deletions textgrad/engine/__init__.py
@@ -4,6 +4,7 @@
     "opus": "claude-3-opus-20240229",
     "haiku": "claude-3-haiku-20240307",
     "sonnet": "claude-3-sonnet-20240229",
+    "sonnet-3.5": "claude-3-5-sonnet-20240620",
     "together-llama-3-70b": "together-meta-llama/Llama-3-70b-chat-hf",
 }

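The new alias makes Claude 3.5 Sonnet addressable through textgrad's model shortcuts. A minimal usage sketch, assuming the get_engine helper resolves these aliases and that ANTHROPIC_API_KEY is set in the environment:

    import textgrad as tg

    # "sonnet-3.5" now resolves to "claude-3-5-sonnet-20240620";
    # assumes ANTHROPIC_API_KEY is exported in the environment.
    engine = tg.get_engine("sonnet-3.5")
    print(engine.generate("In one sentence, what does TextGrad do?"))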
23 changes: 12 additions & 11 deletions textgrad/engine/anthropic.py
@@ -21,9 +21,9 @@ class ChatAnthropic(EngineLM, CachedEngine):

     def __init__(
         self,
-        model_string="claude-3-opus-20240229",
-        system_prompt=SYSTEM_PROMPT,
-        is_multimodal=False,
+        model_string: str="claude-3-opus-20240229",
+        system_prompt: str=SYSTEM_PROMPT,
+        is_multimodal: bool=False,
     ):
         root = platformdirs.user_cache_dir("textgrad")
         cache_path = os.path.join(root, f"cache_anthropic_{model_string}.db")
@@ -43,18 +43,19 @@ def __call__(self, prompt, **kwargs):
         return self.generate(prompt, **kwargs)

     @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
-    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt=None, **kwargs):
-        if any(isinstance(item, bytes) for item in content):
-            return self._generate_text(content, system_prompt=system_prompt, **kwargs)
+    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt: str=None, **kwargs):
+        if isinstance(content, str):
+            return self._generate_from_single_prompt(content, system_prompt=system_prompt, **kwargs)

         elif isinstance(content, list):
-            if not self.is_multimodal:
+            has_multimodal_input = any(isinstance(item, bytes) for item in content)
+            if (has_multimodal_input) and (not self.is_multimodal):
                 raise NotImplementedError("Multimodal generation is only supported for Claude-3 and beyond.")

-            return self._generate_multimodal(content, system_prompt=system_prompt, **kwargs)
+            return self._generate_from_multiple_input(content, system_prompt=system_prompt, **kwargs)

-    def _generate_text(
-        self, prompt, system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
+    def _generate_from_single_prompt(
+        self, prompt: str, system_prompt: str=None, temperature=0, max_tokens=2000, top_p=0.99
     ):

         sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
@@ -105,7 +106,7 @@ def _format_content(self, content: List[Union[str, bytes]]) -> List[dict]:
             raise ValueError(f"Unsupported input type: {type(item)}")
         return formatted_content

-    def _generate_multimodal(
+    def _generate_from_multiple_input(
         self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
     ):
         sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
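With the renames in place, ChatAnthropic.generate dispatches on input type: a bare string takes the _generate_from_single_prompt path, while a list takes the _generate_from_multiple_input path and raises if it contains image bytes but the engine was not constructed with is_multimodal=True. A sketch of both branches, assuming a multimodal-capable model and a hypothetical local image file:

    from textgrad.engine.anthropic import ChatAnthropic

    engine = ChatAnthropic(model_string="claude-3-5-sonnet-20240620", is_multimodal=True)

    # A plain string routes to _generate_from_single_prompt.
    summary = engine.generate("Summarize the TextGrad approach in one sentence.")

    # A list of strings and image bytes routes to _generate_from_multiple_input.
    with open("figure.png", "rb") as f:  # hypothetical image file
        image_bytes = f.read()
    caption = engine.generate(["Describe this figure:", image_bytes])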
35 changes: 19 additions & 16 deletions textgrad/engine/openai.py
@@ -22,8 +22,8 @@ class ChatOpenAI(EngineLM, CachedEngine):

     def __init__(
         self,
-        model_string="gpt-3.5-turbo-0613",
-        system_prompt=DEFAULT_SYSTEM_PROMPT,
+        model_string: str="gpt-3.5-turbo-0613",
+        system_prompt: str=DEFAULT_SYSTEM_PROMPT,
         is_multimodal: bool=False,
         **kwargs):
         """
@@ -32,8 +32,6 @@ def __init__(
         """
         root = platformdirs.user_cache_dir("textgrad")
         cache_path = os.path.join(root, f"cache_openai_{model_string}.db")
-        self.image_cache_dir = os.path.join(root, "image_cache")
-        os.makedirs(self.image_cache_dir, exist_ok=True)

         super().__init__(cache_path=cache_path)
@@ -48,17 +46,19 @@ def __init__(
         self.is_multimodal = is_multimodal

     @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
-    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt=None, **kwargs):
-        if any(isinstance(item, bytes) for item in content):
-            if not self.is_multimodal:
-                raise NotImplementedError("Multimodal generation is only supported for GPT-4 models.")
-
-            return self._generate_multimodal(content, system_prompt=system_prompt, **kwargs)
-
-        return self._generate_text(content, system_prompt=system_prompt, **kwargs)
-
-    def _generate_text(
-        self, prompt, system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
+    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt: str=None, **kwargs):
+        if isinstance(content, str):
+            return self._generate_from_single_prompt(content, system_prompt=system_prompt, **kwargs)
+
+        elif isinstance(content, list):
+            has_multimodal_input = any(isinstance(item, bytes) for item in content)
+            if (has_multimodal_input) and (not self.is_multimodal):
+                raise NotImplementedError("Multimodal generation is only supported for Claude-3 and beyond.")
+
+            return self._generate_from_multiple_input(content, system_prompt=system_prompt, **kwargs)
+
+    def _generate_from_single_prompt(
+        self, prompt: str, system_prompt: str=None, temperature=0, max_tokens=2000, top_p=0.99
     ):

         sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
@@ -89,9 +89,12 @@ def __call__(self, prompt, **kwargs):
         return self.generate(prompt, **kwargs)

     def _format_content(self, content: List[Union[str, bytes]]) -> List[dict]:
+        """Helper function to format a list of strings and bytes into a list of dictionaries to pass as messages to the API.
+        """
         formatted_content = []
         for item in content:
             if isinstance(item, bytes):
+                # For now, bytes are assumed to be images
                 image_type = get_image_type_from_bytes(item)
                 base64_image = base64.b64encode(item).decode('utf-8')
                 formatted_content.append({
@@ -109,7 +112,7 @@ def _format_content(self, content: List[Union[str, bytes]]) -> List[dict]:
             raise ValueError(f"Unsupported input type: {type(item)}")
         return formatted_content

-    def _generate_multimodal(
+    def _generate_from_multiple_input(
         self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
     ):
         sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
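ChatOpenAI gets the same dispatch structure, and the new docstring and comment in _format_content spell out the convention: bytes items are assumed to be images and are base64-encoded into image_url message parts, while strings become text parts. (The error message raised for non-multimodal engines still reads "Claude-3 and beyond", apparently carried over from anthropic.py.) A usage sketch under the same caveats, assuming a vision-capable model string such as gpt-4o, OPENAI_API_KEY in the environment, and a hypothetical chart.png:

    from textgrad.engine.openai import ChatOpenAI

    engine = ChatOpenAI(model_string="gpt-4o", is_multimodal=True)

    with open("chart.png", "rb") as f:  # hypothetical image file
        image_bytes = f.read()

    # _format_content turns the mixed list into message parts roughly like:
    #   {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
    #   {"type": "text", "text": "What trend does this chart show?"}
    answer = engine.generate([image_bytes, "What trend does this chart show?"])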
