-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #1847 from h2oai/image_generation_tool
More tools for Agents
- Loading branch information
Showing
22 changed files
with
561 additions
and
39 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import os | ||
import argparse | ||
import uuid | ||
|
||
from openai import OpenAI | ||
|
||
|
||
def main():
    """Transcribe an audio file via an OpenAI-compatible speech-to-text endpoint.

    Reads STT_OPENAI_BASE_URL (required) plus STT_OPENAI_API_KEY and
    STT_OPENAI_MODEL from the environment, sends the audio file given by
    --file_path for transcription, and writes the text to --output.
    """
    parser = argparse.ArgumentParser(description="Get transcription of an audio file")
    # Model
    parser.add_argument("--model", type=str, required=False, help="Model name")
    # Input audio file (bug fix: the code below reads args.file_path, but the
    # argument was never declared, so the script always crashed)
    parser.add_argument("--file_path", type=str, required=True, help="Path to the audio file to transcribe")
    # Output file name (bug fix: help text wrongly described this as the audio file)
    parser.add_argument("--output", type=str, default='', required=False,
                        help="Path (ensure unique) for the output transcription text file")
    args = parser.parse_args()
    ##
    stt_url = os.getenv("STT_OPENAI_BASE_URL", None)
    assert stt_url is not None, "STT_OPENAI_BASE_URL environment variable is not set"
    stt_api_key = os.getenv('STT_OPENAI_API_KEY', 'EMPTY')

    if not args.model:
        # fall back to the model configured in the environment
        stt_model = os.getenv('STT_OPENAI_MODEL')
        assert stt_model is not None, "STT_OPENAI_MODEL environment variable is not set"
        args.model = stt_model

    # Read the audio file; 'with' ensures the handle is closed even on error
    client = OpenAI(base_url=stt_url, api_key=stt_api_key)
    with open(args.file_path, "rb") as audio_file:
        transcription = client.audio.transcriptions.create(
            model=args.model,
            file=audio_file
        )
    # Choose a unique default output name when none was given
    if not args.output:
        args.output = f"transcription_{str(uuid.uuid4())[:6]}.txt"
    # Write the transcription to a file
    with open(args.output, "wt") as txt_file:
        txt_file.write(transcription.text)

    full_path = os.path.abspath(args.output)
    print(f"Transcription successfully saved to the file: {full_path}")
    # generally too much, have agent read if too long for context of LLM
    if len(transcription.text) < 1024:
        print(f"Audio file successfully transcribed as follows:\n\n{transcription.text}")


if __name__ == "__main__":
    main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
import ast | ||
import base64 | ||
import os | ||
import argparse | ||
import tempfile | ||
import uuid | ||
|
||
|
||
def main():
    """Generate an image from a text prompt via an OpenAI-compatible endpoint.

    Three server flavors are detected from IMAGEGEN_OPENAI_BASE_URL: the
    h2o.ai hosted API, OpenAI / Azure OpenAI (DALL-E), and any other
    OpenAI-compatible image server.  The image is saved to --output (or a
    unique generated filename) with the correct extension for its format.
    """
    parser = argparse.ArgumentParser(description="Generate images from text prompts")
    parser.add_argument("--prompt", type=str, required=True, help="User prompt")
    parser.add_argument("--model", type=str, required=False, help="Model name")
    parser.add_argument("--output", type=str, required=False, default="", help="Name (unique) of the output file")
    parser.add_argument("--quality", type=str, required=False, choices=['standard', 'hd', 'quick', 'manual'], default='standard',
                        help="Image quality")
    parser.add_argument("--size", type=str, required=False, default="1024x1024", help="Image size (height x width)")

    imagegen_url = os.getenv("IMAGEGEN_OPENAI_BASE_URL", '')
    # bug fix: the default above is '', so the old `is not None` assert could
    # never fire; a truthiness check actually catches the unset/empty case
    assert imagegen_url, "IMAGEGEN_OPENAI_BASE_URL environment variable is not set"
    server_api_key = os.getenv('IMAGEGEN_OPENAI_API_KEY', 'EMPTY')

    # optional override of the model list, as a string holding a Python list
    models_override = os.getenv('IMAGEGEN_OPENAI_MODELS')

    generation_params = {}

    is_openai = False
    if imagegen_url == "https://api.gpt.h2o.ai/v1":
        # h2o.ai hosted diffusion models: accept diffusion-specific knobs
        parser.add_argument("--guidance_scale", type=float, help="Guidance scale for image generation")
        parser.add_argument("--num_inference_steps", type=int, help="Number of inference steps")
        args = parser.parse_args()
        from openai import OpenAI
        client = OpenAI(base_url=imagegen_url, api_key=server_api_key)
        available_models = ['flux.1-schnell', 'playv2']
        if models_override:
            # allow override
            available_models = ast.literal_eval(models_override)
        # fall back to the first available model when unset or invalid
        if not args.model:
            args.model = available_models[0]
        if args.model not in available_models:
            args.model = available_models[0]
    elif imagegen_url == "https://api.openai.com/v1" or 'openai.azure.com' in imagegen_url:
        is_openai = True
        parser.add_argument("--style", type=str, choices=['vivid', 'natural', 'artistic'], default='vivid',
                            help="Image style")
        args = parser.parse_args()
        # https://platform.openai.com/docs/api-reference/images/create
        available_models = ['dall-e-3', 'dall-e-2']
        # assumes deployment name matches model name, unless override
        if models_override:
            # allow override
            available_models = ast.literal_eval(models_override)
        if not args.model:
            args.model = available_models[0]
        if args.model not in available_models:
            args.model = available_models[0]

        if 'openai.azure.com' in imagegen_url:
            # https://learn.microsoft.com/en-us/azure/ai-services/openai/dall-e-quickstart?tabs=dalle3%2Ccommand-line%2Ctypescript&pivots=programming-language-python
            from openai import AzureOpenAI
            client = AzureOpenAI(
                api_version="2024-02-01" if args.model == 'dall-e-3' else '2023-06-01-preview',
                api_key=os.environ["IMAGEGEN_OPENAI_API_KEY"],
                # like base_url, but Azure endpoint like https://PROJECT.openai.azure.com/
                azure_endpoint=os.environ['IMAGEGEN_OPENAI_BASE_URL']
            )
        else:
            from openai import OpenAI
            client = OpenAI(base_url=imagegen_url, api_key=server_api_key)

        dalle2aliases = ['dall-e-2', 'dalle2', 'dalle-2']
        # DALL-E 2 caps prompts at 1000 chars, DALL-E 3 at 4000
        max_chars = 1000 if args.model in dalle2aliases else 4000
        args.prompt = args.prompt[:max_chars]

        if args.model in dalle2aliases:
            valid_sizes = ['256x256', '512x512', '1024x1024']
        else:
            valid_sizes = ['1024x1024', '1792x1024', '1024x1792']

        if args.size not in valid_sizes:
            args.size = valid_sizes[0]

        # OpenAI accepts only these quality/style values; coerce anything else
        args.quality = 'standard' if args.quality not in ['standard', 'hd'] else args.quality
        args.style = 'vivid' if args.style not in ['vivid', 'natural'] else args.style
        generation_params.update({
            "style": args.style,
        })
    else:
        # any other OpenAI-compatible image generation server
        parser.add_argument("--guidance_scale", type=float, help="Guidance scale for image generation")
        parser.add_argument("--num_inference_steps", type=int, help="Number of inference steps")
        args = parser.parse_args()

        from openai import OpenAI
        client = OpenAI(base_url=imagegen_url, api_key=server_api_key)
        assert models_override, "IMAGEGEN_OPENAI_MODELS environment variable is not set"
        available_models = ast.literal_eval(models_override)  # must be string of list of strings
        assert available_models, "IMAGEGEN_OPENAI_MODELS environment variable is not set, must be for this server"
        if args.model is None:
            args.model = available_models[0]
        if args.model not in available_models:
            args.model = available_models[0]

    # for azure, args.model use assume deployment name matches model name (i.e. dall-e-3 not dalle3) unless IMAGEGEN_OPENAI_MODELS set
    generation_params.update({
        "prompt": args.prompt,
        "model": args.model,
        "quality": args.quality,
        "size": args.size,
        "response_format": "b64_json",
    })

    if not is_openai:
        # diffusion-only knobs travel via extra_body so the OpenAI client
        # forwards them without complaining about unknown parameters
        extra_body = {}
        if args.guidance_scale:
            extra_body["guidance_scale"] = args.guidance_scale
        if args.num_inference_steps:
            extra_body["num_inference_steps"] = args.num_inference_steps
        if extra_body:
            generation_params["extra_body"] = extra_body

    response = client.images.generate(**generation_params)

    # DALL-E 3 may rewrite the prompt; surface that to the caller
    if hasattr(response.data[0], 'revised_prompt') and response.data[0].revised_prompt:
        print("Image Generator revised the prompt (this is expected): %s" % response.data[0].revised_prompt)

    assert response.data[0].b64_json is not None or response.data[0].url is not None, "No image data returned"

    if response.data[0].b64_json:
        image_data_base64 = response.data[0].b64_json
        image_data = base64.b64decode(image_data_base64)
    else:
        # server returned a URL instead of inline data; download, read, clean up
        from src.utils import download_simple
        dest = download_simple(response.data[0].url, overwrite=True)
        with open(dest, "rb") as f:
            image_data = f.read()
        os.remove(dest)

    # Determine file type and name
    image_format = get_image_format(image_data)
    if not args.output:
        args.output = f"image_{str(uuid.uuid4())[:6]}.{image_format}"
    else:
        # If an output path is provided, ensure it has the correct extension
        base, ext = os.path.splitext(args.output)
        if ext.lower() != f".{image_format}":
            args.output = f"{base}.{image_format}"

    # Write the image data to a file
    with open(args.output, "wb") as img_file:
        img_file.write(image_data)

    full_path = os.path.abspath(args.output)
    print(f"Image successfully saved to the file: {full_path}")

    # NOTE: Could provide stats like image size, etc.
||
def get_image_format(image_data):
    """Return the lowercase format name (e.g. 'png', 'jpeg') of raw image bytes.

    Uses PIL to sniff the actual encoded format rather than trusting any
    filename extension.
    """
    import io
    from PIL import Image
    stream = io.BytesIO(image_data)
    with Image.open(stream) as image:
        return image.format.lower()


if __name__ == "__main__":
    main()
Oops, something went wrong.