Merge pull request #1847 from h2oai/image_generation_tool

More tools for Agents

pseudotensor committed Sep 19, 2024
2 parents 49dd0ed + 133bf80 commit dd8ad17
Showing 22 changed files with 561 additions and 39 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -32,7 +32,7 @@ https://github.com/h2oai/h2ogpt/assets/2249614/2f805035-2c85-42fb-807f-fd0bca79a
- **Gradio UI** or CLI with streaming of all models
- **Upload** and **View** documents through the UI (control multiple collaborative or personal collections)
- **Vision Models** LLaVa, Claude-3, Gemini-Pro-Vision, GPT-4-Vision
-- **Image Generation** Stable Diffusion (sdxl-turbo, sdxl, SD3) and PlaygroundAI (playv2)
+- **Image Generation** Stable Diffusion (sdxl-turbo, sdxl, SD3), PlaygroundAI (playv2), and Flux
- **Voice STT** using Whisper with streaming audio conversion
- **Voice TTS** using MIT-Licensed Microsoft Speech T5 with multiple voices and Streaming audio conversion
- **Voice TTS** using MPL2-Licensed TTS including Voice Cloning and Streaming audio conversion
10 changes: 9 additions & 1 deletion gradio_utils/grclient.py
@@ -15,7 +15,7 @@
from enum import Enum
from functools import lru_cache
from pathlib import Path
-from typing import Callable, Generator, Any, Union, List, Dict, Literal
+from typing import Callable, Generator, Any, Union, List, Dict, Literal, Tuple
import ast
import inspect
import numpy as np
@@ -798,6 +798,10 @@ def query_or_summarize_or_extract(
            tts_language: str = "autodetect",
            tts_speed: float = 1.0,
            visible_image_models: List[str] = [],
+            image_size: str = "1024x1024",
+            image_quality: str = 'standard',
+            image_guidance_scale: float = 3.0,
+            image_num_inference_steps: int = 30,
            visible_models: Union[str, int, list] = None,
            # don't use the below (no doc string stuff) block
            num_return_sequences: int = None,
@@ -976,6 +980,10 @@ def query_or_summarize_or_extract(
        :param tts_speed: Default speed of TTS, < 1.0 (needs rubberband) for slower than normal, > 1.0 for faster. Tries to keep fixed pitch.
        :param visible_image_models: Which image gen models to include
+        :param image_size: Image size as 'HEIGHTxWIDTH', e.g. '1024x1024'
+        :param image_quality: Image quality: 'standard', 'hd', or (where supported) 'quick' or 'manual'
+        :param image_guidance_scale: Guidance scale for image generation, used when image_quality is 'manual' (default 3.0)
+        :param image_num_inference_steps: Number of inference steps for image generation, used when image_quality is 'manual' (default 30)
        :param visible_models: Which models in model_lock list to show by default
            Takes integers of position in model_lock (model_states) list or strings of base_model names
            Ignored if model_lock not used
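
A minimal sketch of exercising the new parameters from the Gradio client (hypothetical usage, not part of this diff: it assumes a running h2oGPT server at http://localhost:7860 and that `instruction` is the prompt kwarg of this method):

```python
# Hypothetical usage sketch, not part of this commit.
# Assumes an h2oGPT Gradio server at http://localhost:7860 and that
# query_or_summarize_or_extract accepts an `instruction` prompt kwarg.
from gradio_utils.grclient import GradioClient

client = GradioClient("http://localhost:7860")
reply = client.query_or_summarize_or_extract(
    instruction="Generate an image of a watercolor fox at dawn",
    visible_image_models=["sdxl"],    # which image gen models to include
    image_size="1024x1024",           # 'HEIGHTxWIDTH'
    image_quality="manual",           # 'manual' enables the two knobs below
    image_guidance_scale=3.0,         # 0.0 to 10.0, default 3.0
    image_num_inference_steps=30,     # 1 (fast) to 50 (high quality)
)
```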
98 changes: 94 additions & 4 deletions openai_server/agent_prompting.py
@@ -1,3 +1,4 @@
+import ast
import os
import tempfile
import time
@@ -76,10 +77,12 @@ def agent_system_prompt(agent_code_writer_system_message, agent_system_site_pack
* You DO have access to the internet.{serp}{papers_search}{wolframalpha}{news_api}
* Example Public APIs (not limited to these): wttr.in (weather) or research papers (arxiv).
* Only generate code with API code that uses publicly available APIs or uses API keys already given.
-* Do not generate code that requires any API keys or credentials that were not already given."""
+* Do not generate code that requires any API keys or credentials that were not already given.
+* You CAN use APIs and API keys given to you by the user or in any document context, and you CAN run code using those API keys."""
    else:
        apis = """\nAPIs and external services instructions:
-* You DO NOT have access to the internet. You cannot use any APIs that require internet access."""
+* You DO NOT have access to the internet. You cannot use any APIs that require broad internet access.
+* You CAN use APIs and API keys given to you by the user or in any document context, and you CAN run code using those API keys."""
    agent_code_writer_system_message = f"""You are a helpful AI assistant. Solve tasks using your coding and language skills.
* {date_str}
Query understanding instructions:
@@ -448,7 +451,6 @@ def get_image_query_helper(base_url, api_key, model):
    model_list = client.models.list()
    image_models = [x.id for x in model_list if x.model_extra['actually_image']]
    we_are_vision_model = len([x for x in model_list if x.id == model]) > 0
-    image_query_helper = ''
    if we_are_vision_model:
        vision_model = model
    elif not we_are_vision_model and len(image_models) > 0:
@@ -501,13 +503,100 @@ def get_mermaid_renderer_helper():
    return mmdc


def get_image_generation_helper():
    imagegen_url = os.getenv("IMAGEGEN_OPENAI_BASE_URL", '')
    if imagegen_url:
        cwd = os.path.abspath(os.getcwd())

        quality_string = "[--quality {quality}]"
        if imagegen_url == "https://api.gpt.h2o.ai/v1":
            if os.getenv("IMAGEGEN_OPENAI_MODELS"):
                models = ast.literal_eval(os.getenv("IMAGEGEN_OPENAI_MODELS"))
            else:
                models = "['flux.1-schnell', 'playv2']"
            quality_options = "['standard', 'hd', 'quick', 'manual']"
            style_options = "* Choose playv2 model for more artistic renderings, flux.1-schnell for more accurate renderings."
            guidance_steps_string = """
* Only applicable if quality is set to manual. guidance_scale is 3.0 by default, can be 0.0 to 10.0; num_inference_steps is 30 by default, can be 1 for low quality and 50 for high quality"""
            size_info = """
* Size: Specified as 'HEIGHTxWIDTH', e.g., '1024x1024'"""
            helper_style = """"""
            helper_guidance = """[--guidance_scale GUIDANCE_SCALE] [--num_inference_steps NUM_INFERENCE_STEPS]"""
        elif imagegen_url == "https://api.openai.com/v1" or 'openai.azure.com' in imagegen_url:
            if os.getenv("IMAGEGEN_OPENAI_MODELS"):
                models = ast.literal_eval(os.getenv("IMAGEGEN_OPENAI_MODELS"))
            else:
                models = "['dall-e-2', 'dall-e-3']"
            quality_options = "['standard', 'hd']"
            style_options = """
* Style options: ['vivid', 'natural']"""
            guidance_steps_string = ''
            size_info = """
* Size allowed for dall-e-2: ['256x256', '512x512', '1024x1024']
* Size allowed for dall-e-3: ['1024x1024', '1792x1024', '1024x1792']"""
            helper_style = """[--style STYLE]"""
            helper_guidance = """"""
        else:
            models = ast.literal_eval(os.getenv("IMAGEGEN_OPENAI_MODELS"))  # must be set then
            quality_options = "['standard', 'hd', 'quick', 'manual']"
            style_options = ""
            # probably local host or local pod, so allow
            guidance_steps_string = """
* Only applicable if quality is set to manual. guidance_scale is 3.0 by default, can be 0.0 to 10.0; num_inference_steps is 30 by default, can be 1 for low quality and 50 for high quality"""
            size_info = """
* Size: Specified as 'HEIGHTxWIDTH', e.g., '1024x1024'"""
            helper_style = """"""
            helper_guidance = """[--guidance_scale GUIDANCE_SCALE] [--num_inference_steps NUM_INFERENCE_STEPS]"""

        image_generation = f"""\n* Image generation using python. Use for generating images from a prompt.
* For image generation, it is recommended to use the existing pre-built python code, e.g.:
```sh
# filename: my_image_generation.sh
# execution: true
python {cwd}/openai_server/agent_tools/image_generation.py --prompt "PROMPT"
```
* usage: python {cwd}/openai_server/agent_tools/image_generation.py [-h] --prompt PROMPT [--output OUTPUT_FILE_NAME] [--model MODEL] {quality_string} {helper_style} {helper_guidance}
* Available models: {models}
* Quality options: {quality_options}{size_info}{style_options}{guidance_steps_string}
* As a helpful assistant, you will convert the user's requested image generation prompt into an excellent prompt, unless the user directly requests a specific prompt be used for image generation.
* Image generation takes about 10-20s per image, so do not automatically generate too many images at once.
* However, if the user directly requests many images or anything related to images, then you MUST follow their instructions no matter what.
* Do not do an image_query on the generated image unless the user directly asks for an analysis of it or for automatic improvement of it.
"""
    else:
        image_generation = ''
    return image_generation


def get_audio_transcription_helper():
    stt_url = os.getenv("STT_OPENAI_BASE_URL", '')
    if stt_url:
        if not os.getenv("STT_OPENAI_MODEL"):
            os.environ["STT_OPENAI_MODEL"] = "whisper-1"
        cwd = os.path.abspath(os.getcwd())
        audio_transcription = f"""\n* Audio transcription using python. Use for transcribing audio files to text.
* For audio transcription, it is recommended to use the existing pre-built python code, e.g.:
```sh
# filename: my_audio_transcription.sh
# execution: true
python {cwd}/openai_server/agent_tools/audio_transcription.py --file_path "./audio.wav"
```
* usage: python {cwd}/openai_server/agent_tools/audio_transcription.py [-h] --file_path FILE_PATH
"""
    else:
        audio_transcription = ''
    return audio_transcription


def get_full_system_prompt(agent_code_writer_system_message, agent_system_site_packages, system_prompt, base_url,
                           api_key, model, text_context_list, image_file, temp_dir, query):
    agent_code_writer_system_message = agent_system_prompt(agent_code_writer_system_message,
                                                           agent_system_site_packages)

    image_query_helper = get_image_query_helper(base_url, api_key, model)
    mermaid_renderer_helper = get_mermaid_renderer_helper()
+    image_generation_helper = get_image_generation_helper()
+    audio_transcription_helper = get_audio_transcription_helper()

    chat_doc_query, internal_file_names = get_chat_doc_context(text_context_list, image_file,
                                                               temp_dir,
@@ -524,5 +613,6 @@ def get_full_system_prompt(agent_code_writer_system_message, agent_system_site_p

    agent_tools_note = f"\nDo not hallucinate agent_tools tools. The only files in the {path_agent_tools} directory are as follows: {list_dir}\n"

-    system_message = agent_code_writer_system_message + image_query_helper + mermaid_renderer_helper + agent_tools_note + chat_doc_query
+    system_message = agent_code_writer_system_message + image_query_helper + mermaid_renderer_helper + image_generation_helper + audio_transcription_helper + agent_tools_note + chat_doc_query
+    # TODO: Also return image_generation_helper and audio_transcription_helper?
    return system_message, internal_file_names, chat_doc_query, image_query_helper, mermaid_renderer_helper
47 changes: 47 additions & 0 deletions openai_server/agent_tools/audio_transcription.py
@@ -0,0 +1,47 @@
import os
import argparse
import uuid

from openai import OpenAI


def main():
    parser = argparse.ArgumentParser(description="Get transcription of an audio file")
    # Model
    parser.add_argument("--model", type=str, required=False, help="Model name")
    # Input audio file
    parser.add_argument("--file_path", type=str, required=True, help="Path to the audio file to transcribe")
    # Output file name
    parser.add_argument("--output", type=str, default='', required=False, help="Path (ensure unique) for the output transcription file")
    args = parser.parse_args()
    ##
    stt_url = os.getenv("STT_OPENAI_BASE_URL", None)
    assert stt_url is not None, "STT_OPENAI_BASE_URL environment variable is not set"
    stt_api_key = os.getenv('STT_OPENAI_API_KEY', 'EMPTY')

    if not args.model:
        stt_model = os.getenv('STT_OPENAI_MODEL')
        assert stt_model is not None, "STT_OPENAI_MODEL environment variable is not set"
        args.model = stt_model

    # Read the audio file
    audio_file = open(args.file_path, "rb")
    client = OpenAI(base_url=stt_url, api_key=stt_api_key)
    transcription = client.audio.transcriptions.create(
        model=args.model,
        file=audio_file
    )
    # Choose a unique default output name if none given
    if not args.output:
        args.output = f"transcription_{str(uuid.uuid4())[:6]}.txt"
    # Write the transcription to a file
    with open(args.output, "wt") as txt_file:
        txt_file.write(transcription.text)

    full_path = os.path.abspath(args.output)
    print(f"Transcription successfully saved to the file: {full_path}")
    # generally too much, have agent read if too long for context of LLM
    if len(transcription.text) < 1024:
        print(f"Audio file successfully transcribed as follows:\n\n{transcription.text}")


if __name__ == "__main__":
    main()
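
A quick way to smoke-test this tool from the repo root (hypothetical invocation, not part of this diff: the STT endpoint URL and audio file are assumptions):

```python
# Hypothetical smoke test, not part of this commit: assumes an
# OpenAI-compatible STT server and a local audio file.
import os
import subprocess

os.environ["STT_OPENAI_BASE_URL"] = "http://localhost:5000/v1"  # assumed endpoint
os.environ["STT_OPENAI_API_KEY"] = "EMPTY"
os.environ["STT_OPENAI_MODEL"] = "whisper-1"

subprocess.run(
    ["python", "openai_server/agent_tools/audio_transcription.py",
     "--file_path", "audio.wav", "--output", "transcript.txt"],
    check=True,
)
```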
164 changes: 164 additions & 0 deletions openai_server/agent_tools/image_generation.py
@@ -0,0 +1,164 @@
import ast
import base64
import os
import argparse
import tempfile
import uuid


def main():
    parser = argparse.ArgumentParser(description="Generate images from text prompts")
    parser.add_argument("--prompt", type=str, required=True, help="User prompt")
    parser.add_argument("--model", type=str, required=False, help="Model name")
    parser.add_argument("--output", type=str, required=False, default="", help="Name (unique) of the output file")
    parser.add_argument("--quality", type=str, required=False, choices=['standard', 'hd', 'quick', 'manual'], default='standard',
                        help="Image quality")
    parser.add_argument("--size", type=str, required=False, default="1024x1024", help="Image size (height x width)")

    imagegen_url = os.getenv("IMAGEGEN_OPENAI_BASE_URL", '')
    assert imagegen_url, "IMAGEGEN_OPENAI_BASE_URL environment variable is not set"
    server_api_key = os.getenv('IMAGEGEN_OPENAI_API_KEY', 'EMPTY')

    generation_params = {}

    is_openai = False
    if imagegen_url == "https://api.gpt.h2o.ai/v1":
        parser.add_argument("--guidance_scale", type=float, help="Guidance scale for image generation")
        parser.add_argument("--num_inference_steps", type=int, help="Number of inference steps")
        args = parser.parse_args()
        from openai import OpenAI
        client = OpenAI(base_url=imagegen_url, api_key=server_api_key)
        available_models = ['flux.1-schnell', 'playv2']
        if os.getenv('IMAGEGEN_OPENAI_MODELS'):
            # allow override
            available_models = ast.literal_eval(os.getenv('IMAGEGEN_OPENAI_MODELS'))
        if not args.model:
            args.model = available_models[0]
        if args.model not in available_models:
            args.model = available_models[0]
    elif imagegen_url == "https://api.openai.com/v1" or 'openai.azure.com' in imagegen_url:
        is_openai = True
        parser.add_argument("--style", type=str, choices=['vivid', 'natural', 'artistic'], default='vivid',
                            help="Image style")
        args = parser.parse_args()
        # https://platform.openai.com/docs/api-reference/images/create
        available_models = ['dall-e-3', 'dall-e-2']
        # assumes deployment name matches model name, unless override
        if os.getenv('IMAGEGEN_OPENAI_MODELS'):
            # allow override
            available_models = ast.literal_eval(os.getenv('IMAGEGEN_OPENAI_MODELS'))
        if not args.model:
            args.model = available_models[0]
        if args.model not in available_models:
            args.model = available_models[0]

        if 'openai.azure.com' in imagegen_url:
            # https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line%2Ctypescript&pivots=programming-language-python
            from openai import AzureOpenAI
            client = AzureOpenAI(
                api_version="2024-02-01" if args.model == 'dall-e-3' else '2023-06-01-preview',
                api_key=os.environ["IMAGEGEN_OPENAI_API_KEY"],
                # like base_url, but Azure endpoint like https://PROJECT.openai.azure.com/
                azure_endpoint=os.environ['IMAGEGEN_OPENAI_BASE_URL']
            )
        else:
            from openai import OpenAI
            client = OpenAI(base_url=imagegen_url, api_key=server_api_key)

        dalle2aliases = ['dall-e-2', 'dalle2', 'dalle-2']
        max_chars = 1000 if args.model in dalle2aliases else 4000
        args.prompt = args.prompt[:max_chars]

        if args.model in dalle2aliases:
            valid_sizes = ['256x256', '512x512', '1024x1024']
        else:
            valid_sizes = ['1024x1024', '1792x1024', '1024x1792']

        if args.size not in valid_sizes:
            args.size = valid_sizes[0]

        args.quality = 'standard' if args.quality not in ['standard', 'hd'] else args.quality
        args.style = 'vivid' if args.style not in ['vivid', 'natural'] else args.style
        generation_params.update({
            "style": args.style,
        })
    else:
        parser.add_argument("--guidance_scale", type=float, help="Guidance scale for image generation")
        parser.add_argument("--num_inference_steps", type=int, help="Number of inference steps")
        args = parser.parse_args()

        from openai import OpenAI
        client = OpenAI(base_url=imagegen_url, api_key=server_api_key)
        assert os.getenv('IMAGEGEN_OPENAI_MODELS'), "IMAGEGEN_OPENAI_MODELS environment variable is not set"
        available_models = ast.literal_eval(os.getenv('IMAGEGEN_OPENAI_MODELS'))  # must be string of list of strings
        assert available_models, "IMAGEGEN_OPENAI_MODELS environment variable is not set, must be for this server"
        if args.model is None:
            args.model = available_models[0]
        if args.model not in available_models:
            args.model = available_models[0]

    # for azure, args.model assumes deployment name matches model name (i.e. dall-e-3 not dalle3) unless IMAGEGEN_OPENAI_MODELS set
    generation_params.update({
        "prompt": args.prompt,
        "model": args.model,
        "quality": args.quality,
        "size": args.size,
        "response_format": "b64_json",
    })

    if not is_openai:
        extra_body = {}
        if args.guidance_scale:
            extra_body["guidance_scale"] = args.guidance_scale
        if args.num_inference_steps:
            extra_body["num_inference_steps"] = args.num_inference_steps
        if extra_body:
            generation_params["extra_body"] = extra_body

    response = client.images.generate(**generation_params)

    if hasattr(response.data[0], 'revised_prompt') and response.data[0].revised_prompt:
        print("Image Generator revised the prompt (this is expected): %s" % response.data[0].revised_prompt)

    assert response.data[0].b64_json is not None or response.data[0].url is not None, "No image data returned"

    if response.data[0].b64_json:
        image_data_base64 = response.data[0].b64_json
        image_data = base64.b64decode(image_data_base64)
    else:
        from src.utils import download_simple
        dest = download_simple(response.data[0].url, overwrite=True)
        with open(dest, "rb") as f:
            image_data = f.read()
        os.remove(dest)

    # Determine file type and name
    image_format = get_image_format(image_data)
    if not args.output:
        args.output = f"image_{str(uuid.uuid4())[:6]}.{image_format}"
    else:
        # If an output path is provided, ensure it has the correct extension
        base, ext = os.path.splitext(args.output)
        if ext.lower() != f".{image_format}":
            args.output = f"{base}.{image_format}"

    # Write the image data to a file
    with open(args.output, "wb") as img_file:
        img_file.write(image_data)

    full_path = os.path.abspath(args.output)
    print(f"Image successfully saved to the file: {full_path}")

    # NOTE: Could provide stats like image size, etc.


def get_image_format(image_data):
    from PIL import Image
    import io
    # Use PIL to determine the image format
    with Image.open(io.BytesIO(image_data)) as img:
        return img.format.lower()


if __name__ == "__main__":
    main()
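
For completeness, a hypothetical invocation of this tool against a non-OpenAI server (not part of this diff: the endpoint and model list are assumptions; quality 'manual' activates the guidance knobs routed via extra_body):

```python
# Hypothetical smoke test, not part of this commit: assumes a local
# OpenAI-compatible image server exposing the listed models.
import os
import subprocess

os.environ["IMAGEGEN_OPENAI_BASE_URL"] = "http://localhost:7861/v1"  # assumed endpoint
os.environ["IMAGEGEN_OPENAI_API_KEY"] = "EMPTY"
os.environ["IMAGEGEN_OPENAI_MODELS"] = "['sdxl']"  # required for non-OpenAI servers

subprocess.run(
    ["python", "openai_server/agent_tools/image_generation.py",
     "--prompt", "a watercolor fox at dawn",
     "--quality", "manual",             # enables the two knobs below
     "--guidance_scale", "3.0",         # 0.0 to 10.0
     "--num_inference_steps", "30"],    # 1 (fast) to 50 (high quality)
    check=True,
)
```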