From d09c7132d0449b2cb471a0000de1c2583af07f02 Mon Sep 17 00:00:00 2001
From: mertyg
Date: Sun, 7 Jul 2024 11:21:17 -0700
Subject: [PATCH] better function names, a bit more typing

---
 textgrad/engine/__init__.py  |  1 +
 textgrad/engine/anthropic.py | 23 ++++++++++++-----------
 textgrad/engine/openai.py    | 35 +++++++++++++++++++----------------
 3 files changed, 32 insertions(+), 27 deletions(-)

diff --git a/textgrad/engine/__init__.py b/textgrad/engine/__init__.py
index b07aff5..5f3b1f8 100644
--- a/textgrad/engine/__init__.py
+++ b/textgrad/engine/__init__.py
@@ -4,6 +4,7 @@
     "opus": "claude-3-opus-20240229",
     "haiku": "claude-3-haiku-20240307",
     "sonnet": "claude-3-sonnet-20240229",
+    "sonnet-3.5": "claude-3-5-sonnet-20240620",
     "together-llama-3-70b": "together-meta-llama/Llama-3-70b-chat-hf",
 }
 
diff --git a/textgrad/engine/anthropic.py b/textgrad/engine/anthropic.py
index 6a88d00..112db70 100644
--- a/textgrad/engine/anthropic.py
+++ b/textgrad/engine/anthropic.py
@@ -21,9 +21,9 @@ class ChatAnthropic(EngineLM, CachedEngine):
 
     def __init__(
         self,
-        model_string="claude-3-opus-20240229",
-        system_prompt=SYSTEM_PROMPT,
-        is_multimodal=False,
+        model_string: str="claude-3-opus-20240229",
+        system_prompt: str=SYSTEM_PROMPT,
+        is_multimodal: bool=False,
     ):
         root = platformdirs.user_cache_dir("textgrad")
         cache_path = os.path.join(root, f"cache_anthropic_{model_string}.db")
@@ -43,18 +43,19 @@ def __call__(self, prompt, **kwargs):
         return self.generate(prompt, **kwargs)
 
     @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
-    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt=None, **kwargs):
-        if any(isinstance(item, bytes) for item in content):
-            return self._generate_text(content, system_prompt=system_prompt, **kwargs)
+    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt: str=None, **kwargs):
+        if isinstance(content, str):
+            return self._generate_from_single_prompt(content, system_prompt=system_prompt, **kwargs)
 
         elif isinstance(content, list):
-            if not self.is_multimodal:
+            has_multimodal_input = any(isinstance(item, bytes) for item in content)
+            if (has_multimodal_input) and (not self.is_multimodal):
                 raise NotImplementedError("Multimodal generation is only supported for Claude-3 and beyond.")
 
-            return self._generate_multimodal(content, system_prompt=system_prompt, **kwargs)
+            return self._generate_from_multiple_input(content, system_prompt=system_prompt, **kwargs)
 
-    def _generate_text(
-        self, prompt, system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
+    def _generate_from_single_prompt(
+        self, prompt: str, system_prompt: str=None, temperature=0, max_tokens=2000, top_p=0.99
     ):
 
         sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
@@ -105,7 +106,7 @@ def _format_content(self, content: List[Union[str, bytes]]) -> List[dict]:
                 raise ValueError(f"Unsupported input type: {type(item)}")
         return formatted_content
 
-    def _generate_multimodal(
+    def _generate_from_multiple_input(
         self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
     ):
         sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
diff --git a/textgrad/engine/openai.py b/textgrad/engine/openai.py
index 3723317..49f2219 100644
--- a/textgrad/engine/openai.py
+++ b/textgrad/engine/openai.py
@@ -22,8 +22,8 @@ class ChatOpenAI(EngineLM, CachedEngine):
 
     def __init__(
         self,
-        model_string="gpt-3.5-turbo-0613",
-        system_prompt=DEFAULT_SYSTEM_PROMPT,
+        model_string: str="gpt-3.5-turbo-0613",
+        system_prompt: str=DEFAULT_SYSTEM_PROMPT,
         is_multimodal: bool=False,
         **kwargs):
         """
@@ -32,8 +32,6 @@ def __init__(
         """
         root = platformdirs.user_cache_dir("textgrad")
         cache_path = os.path.join(root, f"cache_openai_{model_string}.db")
-        self.image_cache_dir = os.path.join(root, "image_cache")
-        os.makedirs(self.image_cache_dir, exist_ok=True)
 
         super().__init__(cache_path=cache_path)
 
@@ -48,17 +46,19 @@ def __init__(
         self.is_multimodal = is_multimodal
 
     @retry(wait=wait_random_exponential(min=1, max=5), stop=stop_after_attempt(5))
-    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt=None, **kwargs):
-        if any(isinstance(item, bytes) for item in content):
-            if not self.is_multimodal:
-                raise NotImplementedError("Multimodal generation is only supported for GPT-4 models.")
-
-            return self._generate_multimodal(content, system_prompt=system_prompt, **kwargs)
-
-        return self._generate_text(content, system_prompt=system_prompt, **kwargs)
-
-    def _generate_text(
-        self, prompt, system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
+    def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt: str=None, **kwargs):
+        if isinstance(content, str):
+            return self._generate_from_single_prompt(content, system_prompt=system_prompt, **kwargs)
+
+        elif isinstance(content, list):
+            has_multimodal_input = any(isinstance(item, bytes) for item in content)
+            if (has_multimodal_input) and (not self.is_multimodal):
+                raise NotImplementedError("Multimodal generation is only supported for GPT-4 models.")
+
+            return self._generate_from_multiple_input(content, system_prompt=system_prompt, **kwargs)
+
+    def _generate_from_single_prompt(
+        self, prompt: str, system_prompt: str=None, temperature=0, max_tokens=2000, top_p=0.99
     ):
 
         sys_prompt_arg = system_prompt if system_prompt else self.system_prompt
@@ -89,9 +89,12 @@ def __call__(self, prompt, **kwargs):
         return self.generate(prompt, **kwargs)
 
     def _format_content(self, content: List[Union[str, bytes]]) -> List[dict]:
+        """Helper function to format a list of strings and bytes into a list of dictionaries to pass as messages to the API.
+        """
         formatted_content = []
         for item in content:
             if isinstance(item, bytes):
+                # For now, bytes are assumed to be images
                 image_type = get_image_type_from_bytes(item)
                 base64_image = base64.b64encode(item).decode('utf-8')
                 formatted_content.append({
@@ -109,7 +112,7 @@ def _format_content(self, content: List[Union[str, bytes]]) -> List[dict]:
                 raise ValueError(f"Unsupported input type: {type(item)}")
         return formatted_content
 
-    def _generate_multimodal(
+    def _generate_from_multiple_input(
         self, content: List[Union[str, bytes]], system_prompt=None, temperature=0, max_tokens=2000, top_p=0.99
     ):
         sys_prompt_arg = system_prompt if system_prompt else self.system_prompt