Skip to content

Commit

Permalink
support saving images locally in image generation tool
Browse files Browse the repository at this point in the history
  • Loading branch information
whimo committed Apr 19, 2024
1 parent 0e0064c commit c396619
Show file tree
Hide file tree
Showing 4 changed files with 626 additions and 599 deletions.
20 changes: 11 additions & 9 deletions examples/image_generation_crewai.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,14 @@
configure_logging(verbose=True)

image_generator_tool = DallEImageGeneratorTool()
# For saving images locally use the line below
# image_generator_tool = DallEImageGeneratorTool(images_directory="images")

writer = CrewAIMotleyAgent(
role="Short stories writer",
goal="Write short stories for children",
backstory="""You are an accomplished children's writer, known for your funny and interesting short stories.
Many parents around the world read your books to their children.""",
backstory="You are an accomplished children's writer, known for your funny and interesting short stories.\n"
"Many parents around the world read your books to their children.",
verbose=True,
delegation=True,
)
Expand All @@ -23,8 +25,8 @@
role="Illustrator",
goal="Create beautiful and insightful illustrations",
backstory="You are an expert in creating illustrations for all sorts of concepts and articles. "
"You do it by skillfully prompting a text-to-image model.\n"
"Your final answer MUST be the exact URL of the illustration.",
"You do it by skillfully prompting a text-to-image model.\n"
"Your final answer MUST be the exact URL or filename of the illustration.",
verbose=True,
delegation=False,
tools=[image_generator_tool],
Expand All @@ -36,11 +38,11 @@
crew=crew,
name="write a short story about a cat",
description="Creatively write a short story of about 4 paragraphs "
"about a house cat that was smarter than its owners. \n"
"Write it in a cool and simple language, "
"making it intriguing yet suitable for children's comprehension.\n"
"You must include a fun illustration.\n"
"Your final answer MUST be the full story with the illustration URL attached.",
"about a house cat that was smarter than its owners. \n"
"Write it in a cool and simple language, "
"making it intriguing yet suitable for children's comprehension.\n"
"You must include a fun illustration.\n"
"Your final answer MUST be the full story with the illustration URL or filename attached.",
agent=writer,
)

Expand Down
28 changes: 24 additions & 4 deletions motleycrew/common/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,15 @@
from typing import Sequence
from typing import Optional, Sequence
import logging
import hashlib
from urllib.parse import urlparse
from langchain_core.messages import BaseMessage


def configure_logging(verbose: bool = False):
level = logging.INFO if verbose else logging.WARNING
logging.basicConfig(format="%(asctime)s - %(levelname)s - %(message)s", level=level)


def to_str(value: str | BaseMessage | Sequence[str] | Sequence[BaseMessage]) -> str:
if isinstance(value, str):
return value
Expand All @@ -17,6 +24,19 @@ def to_str(value: str | BaseMessage | Sequence[str] | Sequence[BaseMessage]) ->
)


def configure_logging(verbose: bool = False):
level = logging.INFO if verbose else logging.WARNING
logging.basicConfig(format='%(asctime)s - %(levelname)s - %(message)s', level=level)
def is_http_url(url):
try:
parsed_url = urlparse(url)
return parsed_url.scheme in ["http", "https"]
except ValueError:
return False


def generate_hex_hash(data: str, length: Optional[int] = None):
hash_obj = hashlib.sha256()
hash_obj.update(data.encode("utf-8"))
hex_hash = hash_obj.hexdigest()

if length is not None:
hex_hash = hex_hash[:length]
return hex_hash
74 changes: 68 additions & 6 deletions motleycrew/tool/image_generation.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,89 @@
import logging
from typing import Optional

import os
import requests
import mimetypes

from langchain.agents import Tool
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_community.utilities.dalle_image_generator import DallEAPIWrapper

from .tool import MotleyTool
import motleycrew.common.utils as motley_utils


def download_image(url: str, file_path: str) -> Optional[str]:
response = requests.get(url, stream=True)
if response.status_code == requests.codes.ok:
content_type = response.headers.get("content-type")
extension = mimetypes.guess_extension(content_type)
if not extension:
extension = ".png" # default to .png if content-type is not recognized

file_path_with_extension = file_path + extension
logging.info("Downloading image %s to %s", url, file_path_with_extension)

with open(file_path_with_extension, "wb") as f:
for chunk in response:
f.write(chunk)

return file_path_with_extension
else:
logging.error("Failed to download image. Status code: %s", response.status_code)


class DallEImageGeneratorTool(MotleyTool):
def __init__(self):
langchain_tool = create_dalle_image_generator_langchain_tool()
def __init__(self, images_directory: Optional[str] = None):
langchain_tool = create_dalle_image_generator_langchain_tool(
images_directory=images_directory
)
super().__init__(langchain_tool)


class DallEToolInput(BaseModel):
"""Input for the Dall-E tool."""

query: str = Field(description="image generation query")
description: str = Field(description="image description")


def run_dalle_and_save_images(
description: str, images_directory: Optional[str] = None, file_name_length: int = 8
) -> Optional[list[str]]:
dalle_api = DallEAPIWrapper()
dalle_result = dalle_api.run(query=description)
logging.info("Dall-E API output: %s", dalle_result)

urls = dalle_result.split(dalle_api.separator)
if not len(urls) or not motley_utils.is_http_url(urls[0]):
logging.error("Dall-E API did not return a valid url: %s", dalle_result)
return

if images_directory:
os.makedirs(images_directory, exist_ok=True)
file_paths = []
for url in urls:
file_name = motley_utils.generate_hex_hash(url, length=file_name_length)
file_path = os.path.join(images_directory, file_name)

file_path_with_extension = download_image(url=url, file_path=file_path)
file_paths.append(file_path_with_extension)
return file_paths
else:
logging.info("Images directory is not provided, returning URLs")
return urls


def create_dalle_image_generator_langchain_tool(images_directory: Optional[str] = None):
def run_dalle_and_save_images_partial(description: str):
return run_dalle_and_save_images(
description=description, images_directory=images_directory
)

def create_dalle_image_generator_langchain_tool():
return Tool(
name="Dall-E-Image-Generator",
func=DallEAPIWrapper().run,
func=run_dalle_and_save_images_partial,
description="A wrapper around OpenAI DALL-E API. Useful for when you need to generate images from a text description. "
"Input should be an image description.",
"Input should be an image description.",
args_schema=DallEToolInput,
)
Loading

0 comments on commit c396619

Please sign in to comment.