From 138443c570c2b97d7815fc0ae8b6de2ca4725555 Mon Sep 17 00:00:00 2001 From: vinid Date: Sun, 7 Jul 2024 17:09:24 -0400 Subject: [PATCH] additional tests and solving small issue --- tests/test_basics.py | 71 +++++++++++++++++++++++++++++++++++ textgrad/utils/image_utils.py | 3 -- textgrad/variable.py | 5 ++- 3 files changed, 74 insertions(+), 5 deletions(-) diff --git a/tests/test_basics.py b/tests/test_basics.py index f9039e2..ac5225c 100644 --- a/tests/test_basics.py +++ b/tests/test_basics.py @@ -1,5 +1,6 @@ import os import pytest +from typing import Union, List import logging @@ -18,6 +19,26 @@ def generate(self, prompt, system_prompt=None, **kwargs): def __call__(self, prompt, system_prompt=None): return self.generate(prompt) +class DummyMultimodalEngine(EngineLM): + + def __init__(self, is_multimodal=False): + self.is_multimodal = is_multimodal + self.model_string = "gpt-4o" # fake + + def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt: str = None, **kwargs): + if isinstance(content, str): + return "Hello Text" + + elif isinstance(content, list): + has_multimodal_input = any(isinstance(item, bytes) for item in content) + if (has_multimodal_input) and (not self.is_multimodal): + raise NotImplementedError("Multimodal generation is only supported for Claude-3 and beyond.") + + return "Hello Text from Image" + + def __call__(self, prompt, system_prompt=None): + return self.generate(prompt) + # Idempotent engine that returns the prompt as is class IdempotentEngine(EngineLM): def generate(self, prompt, system_prompt=None, **kwargs): @@ -124,3 +145,53 @@ def test_formattedllmcall(): assert inputs["question"] in output.predecessors assert inputs["prediction"] in output.predecessors assert output.get_role_description() == "test response" + + +def test_multimodal(): + from textgrad.autograd import MultimodalLLMCall, LLMCall + from textgrad import Variable + import httpx + + image_url = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg" + image_data = httpx.get(image_url).content + + os.environ['OPENAI_API_KEY'] = "fake_key" + engine = DummyMultimodalEngine(is_multimodal=True) + + image_variable = Variable(image_data, + role_description="image to answer a question about", requires_grad=False) + + text = Variable("Hello", role_description="A variable") + question_variable = Variable("What do you see in this image?", role_description="question", requires_grad=False) + response = MultimodalLLMCall(engine=engine)([image_variable, question_variable]) + + assert response.value == "Hello Text from Image" + + response = LLMCall(engine=engine)(text) + + assert response.value == "Hello Text" + + ## llm call cannot handle images + with pytest.raises(AttributeError): + response = LLMCall(engine=engine)([text, image_variable]) + + # this is just to check the content, we can't really have int variables but + # it's just for testing purposes + with pytest.raises(AssertionError): + response = MultimodalLLMCall(engine=engine)([Variable(4, role_description="tst"), + Variable(5, role_description="tst")]) + +def test_multimodal_from_url(): + from textgrad import Variable + import httpx + + image_url = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg" + image_data = httpx.get(image_url).content + + image_variable = Variable(image_path=image_url, + role_description="image to answer a question about", requires_grad=False) + + image_variable_2 = Variable(image_data, + role_description="image to answer a question about", requires_grad=False) + + assert image_variable_2.value == image_variable.value \ No newline at end of file diff --git a/textgrad/utils/image_utils.py b/textgrad/utils/image_utils.py index fb75c2d..717284b 100644 --- a/textgrad/utils/image_utils.py +++ b/textgrad/utils/image_utils.py @@ -1,10 +1,7 @@ import os import requests import hashlib -from urllib.parse import urlparse -from typing import Union import platformdirs - from urllib.parse import urlparse def is_valid_url(url): diff --git a/textgrad/variable.py b/textgrad/variable.py index 7f3af88..1dfd249 100644 --- a/textgrad/variable.py +++ b/textgrad/variable.py @@ -49,8 +49,9 @@ def __init__( if image_path != "": if is_valid_url(image_path): self.value = httpx.get(image_path).content - with open(image_path, 'rb') as file: - self.value = file.read() + else: + with open(image_path, 'rb') as file: + self.value = file.read() else: self.value = value