Skip to content

Commit

Permalink
additional tests and solving small issue
Browse files Browse the repository at this point in the history
  • Loading branch information
vinid committed Jul 7, 2024
1 parent 19fe2df commit 138443c
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 5 deletions.
71 changes: 71 additions & 0 deletions tests/test_basics.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import pytest
from typing import Union, List
import logging


Expand All @@ -18,6 +19,26 @@ def generate(self, prompt, system_prompt=None, **kwargs):
def __call__(self, prompt, system_prompt=None):
return self.generate(prompt)

class DummyMultimodalEngine(EngineLM):

def __init__(self, is_multimodal=False):
self.is_multimodal = is_multimodal
self.model_string = "gpt-4o" # fake

def generate(self, content: Union[str, List[Union[str, bytes]]], system_prompt: str = None, **kwargs):
if isinstance(content, str):
return "Hello Text"

elif isinstance(content, list):
has_multimodal_input = any(isinstance(item, bytes) for item in content)
if (has_multimodal_input) and (not self.is_multimodal):
raise NotImplementedError("Multimodal generation is only supported for Claude-3 and beyond.")

return "Hello Text from Image"

def __call__(self, prompt, system_prompt=None):
return self.generate(prompt)

# Idempotent engine that returns the prompt as is
class IdempotentEngine(EngineLM):
def generate(self, prompt, system_prompt=None, **kwargs):
Expand Down Expand Up @@ -124,3 +145,53 @@ def test_formattedllmcall():
assert inputs["question"] in output.predecessors
assert inputs["prediction"] in output.predecessors
assert output.get_role_description() == "test response"


def test_multimodal():
from textgrad.autograd import MultimodalLLMCall, LLMCall
from textgrad import Variable
import httpx

image_url = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg"
image_data = httpx.get(image_url).content

os.environ['OPENAI_API_KEY'] = "fake_key"
engine = DummyMultimodalEngine(is_multimodal=True)

image_variable = Variable(image_data,
role_description="image to answer a question about", requires_grad=False)

text = Variable("Hello", role_description="A variable")
question_variable = Variable("What do you see in this image?", role_description="question", requires_grad=False)
response = MultimodalLLMCall(engine=engine)([image_variable, question_variable])

assert response.value == "Hello Text from Image"

response = LLMCall(engine=engine)(text)

assert response.value == "Hello Text"

## llm call cannot handle images
with pytest.raises(AttributeError):
response = LLMCall(engine=engine)([text, image_variable])

# this is just to check the content, we can't really have int variables but
# it's just for testing purposes
with pytest.raises(AssertionError):
response = MultimodalLLMCall(engine=engine)([Variable(4, role_description="tst"),
Variable(5, role_description="tst")])

def test_multimodal_from_url():
from textgrad import Variable
import httpx

image_url = "https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg"
image_data = httpx.get(image_url).content

image_variable = Variable(image_path=image_url,
role_description="image to answer a question about", requires_grad=False)

image_variable_2 = Variable(image_data,
role_description="image to answer a question about", requires_grad=False)

assert image_variable_2.value == image_variable.value
3 changes: 0 additions & 3 deletions textgrad/utils/image_utils.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import os
import requests
import hashlib
from urllib.parse import urlparse
from typing import Union
import platformdirs

from urllib.parse import urlparse

def is_valid_url(url):
Expand Down
5 changes: 3 additions & 2 deletions textgrad/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,9 @@ def __init__(
if image_path != "":
if is_valid_url(image_path):
self.value = httpx.get(image_path).content
with open(image_path, 'rb') as file:
self.value = file.read()
else:
with open(image_path, 'rb') as file:
self.value = file.read()
else:
self.value = value

Expand Down

0 comments on commit 138443c

Please sign in to comment.