From d346300162e15aab8f11de801bb353f7e185cbcb Mon Sep 17 00:00:00 2001 From: mi804 <1576993271@qq.com> Date: Fri, 24 Jan 2025 11:26:48 +0800 Subject: [PATCH] update eligen ui and readme --- apps/gradio/eligen_ui.py | 316 ---------------------- apps/gradio/entity_level_control.py | 390 ++++++++++++++++++++++++++++ examples/EntityControl/README.md | 15 +- 3 files changed, 398 insertions(+), 323 deletions(-) delete mode 100644 apps/gradio/eligen_ui.py create mode 100644 apps/gradio/entity_level_control.py diff --git a/apps/gradio/eligen_ui.py b/apps/gradio/eligen_ui.py deleted file mode 100644 index ada3e8e..0000000 --- a/apps/gradio/eligen_ui.py +++ /dev/null @@ -1,316 +0,0 @@ -import gradio as gr -from diffsynth import ModelManager, FluxImagePipeline, download_customized_models -import os, torch -from PIL import Image -import numpy as np -from PIL import ImageDraw, ImageFont -import random -import json - -def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'): - save_dir = os.path.join('workdirs/tmp_mask', random_dir) - print(f'save to {save_dir}') - os.makedirs(save_dir, exist_ok=True) - for i, mask in enumerate(masks): - save_path = os.path.join(save_dir, f'{i}.png') - mask.save(save_path) - sample = { - "global_prompt": global_prompt, - "mask_prompts": mask_prompts, - "seed": seed, - } - with open(os.path.join(save_dir, f"prompts.json"), 'w') as f: - json.dump(sample, f, indent=4) - -def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False): - # Create a blank image for overlays - overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) - colors = [ - (165, 238, 173, 80), - (76, 102, 221, 80), - (221, 160, 77, 80), - (204, 93, 71, 80), - (145, 187, 149, 80), - (134, 141, 172, 80), - (157, 137, 109, 80), - (153, 104, 95, 80), - (165, 238, 173, 80), - (76, 102, 221, 80), - (221, 160, 77, 80), - (204, 93, 71, 80), - (145, 187, 149, 80), - (134, 141, 172, 80), - (157, 137, 109, 80), - (153, 104, 95, 80), - ] - # Generate random colors for each mask - if use_random_colors: - colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] - # Font settings - try: - font = ImageFont.truetype("arial", font_size) # Adjust as needed - except IOError: - font = ImageFont.load_default(font_size) - # Overlay each mask onto the overlay image - for mask, mask_prompt, color in zip(masks, mask_prompts, colors): - # Convert mask to RGBA mode - mask_rgba = mask.convert('RGBA') - mask_data = mask_rgba.getdata() - new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] - mask_rgba.putdata(new_data) - # Draw the mask prompt text on the mask - draw = ImageDraw.Draw(mask_rgba) - mask_bbox = mask.getbbox() # Get the bounding box of the mask - text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position - draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) - # Alpha composite the overlay with this mask - overlay = Image.alpha_composite(overlay, mask_rgba) - # Composite the overlay onto the original image - result = Image.alpha_composite(image.convert('RGBA'), overlay) - return result - -config = { - "model_config": { - "FLUX": { - "model_folder": "models/FLUX", - "pipeline_class": FluxImagePipeline, - "default_parameters": { - "cfg_scale": 3.0, - "embedded_guidance": 3.5, - "num_inference_steps": 50, - } - }, - }, - "max_num_painter_layers": 8, - "max_num_model_cache": 1, -} - - -def 
load_model_list(model_type): - if model_type is None: - return [] - folder = config["model_config"][model_type]["model_folder"] - file_list = [i for i in os.listdir(folder) if i.endswith(".safetensors")] - if model_type in ["HunyuanDiT", "Kolors", "FLUX"]: - file_list += [i for i in os.listdir(folder) if os.path.isdir(os.path.join(folder, i))] - file_list = sorted(file_list) - return file_list - - -model_dict = {} - -def load_model(model_type, model_path): - global model_dict - model_key = f"{model_type}:{model_path}" - if model_key in model_dict: - return model_dict[model_key] - model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path) - model_manager = ModelManager() - if model_type == "FLUX": - model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) - model_manager.load_lora( - download_customized_models( - model_id="DiffSynth-Studio/Eligen", - origin_file_path="model_bf16.safetensors", - local_dir="models/lora/entity_control", - ), - lora_alpha=1, - ) - else: - model_manager.load_model(model_path) - pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager) - while len(model_dict) + 1 > config["max_num_model_cache"]: - key = next(iter(model_dict.keys())) - model_manager_to_release, _ = model_dict[key] - model_manager_to_release.to("cpu") - del model_dict[key] - torch.cuda.empty_cache() - model_dict[model_key] = model_manager, pipe - return model_manager, pipe - - -with gr.Blocks() as app: - gr.Markdown(""" - # 实体级控制文生图模型EliGen - **UI说明** - 1. **点击Load model读取模型**,然后左侧界面为文生图输入参数;右侧Painter为局部控制区域绘制区域,每个局部控制条件由其Local prompt和绘制的mask组成,支持精准控制文生图和Inpainting两种模式。 - 2. **精准控制生图模式:** 输入Globalprompt;激活并绘制一个或多个局部控制条件,点击Generate生成图像; Global Prompt推荐包含每个Local Prompt。 - 3. **Inpainting模式:** 你可以上传图像,或者将上一步生成的图像设置为Inpaint Input Image,采用类似的方式输入局部控制条件,进行局部重绘。 - 4. 尽情创造! - """) - gr.Markdown(""" - # Entity-Level Controlled Text-to-Image Model: EliGen - **UI Instructions** - 1. **Click "Load model" to load the model.** The left interface is for text-to-image input parameters; the right "Painter" is the area for drawing local control regions. Each local control condition consists of its Local Prompt and the drawn mask, supporting both precise control text-to-image and Inpainting modes. - 2. **Precise Control Image Generation Mode:** Enter the Global Prompt; activate and draw one or more local control conditions, then click "Generate" to create the image. It is recommended that the Global Prompt includes all Local Prompts. - 3. **Inpainting Mode:** You can upload an image or set the image generated in the previous step as the "Inpaint Input Image." Use a similar method to input local control conditions for local redrawing. - 4. Enjoy! 
- """) - with gr.Row(): - random_mask_dir = gr.State('') - with gr.Column(scale=382, min_width=100): - model_type = gr.State('FLUX') - model_path = gr.State('FLUX.1-dev') - with gr.Accordion(label="Model"): - load_model_button = gr.Button(value="Load model") - - with gr.Accordion(label="Global prompt"): - prompt = gr.Textbox(label="Prompt", lines=3) - negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw,", lines=1) - cfg_scale = gr.Slider(minimum=1.0, maximum=10.0, value=7.0, step=0.1, interactive=True, label="Classifier-free guidance scale") - embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=0.0, step=0.1, interactive=True, label="Embedded guidance scale") - - with gr.Accordion(label="Inference Options"): - num_inference_steps = gr.Slider(minimum=1, maximum=100, value=20, step=1, interactive=True, label="Inference steps") - height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height") - width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width") - return_with_mask = gr.Checkbox(value=True, interactive=True, label="show result with mask painting") - with gr.Column(): - use_fixed_seed = gr.Checkbox(value=True, interactive=False, label="Use fixed seed") - seed = gr.Number(minimum=0, maximum=10**9, value=0, interactive=True, label="Random seed", show_label=False) - with gr.Accordion(label="Inpaint Input Image (Testing)"): - input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil") - background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight") - - with gr.Column(): - reset_input_button = gr.Button(value="Reset Inpaint Input") - send_input_to_painter = gr.Button(value="Set as painter's background") - @gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click) - def reset_input_image(input_image): - return None - @gr.on( - inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask], - outputs=[prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, load_model_button, random_mask_dir], - triggers=load_model_button.click - ) - def model_path_to_default_params(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask): - load_model(model_type, model_path) - cfg_scale = config["model_config"][model_type]["default_parameters"].get("cfg_scale", cfg_scale) - embedded_guidance = config["model_config"][model_type]["default_parameters"].get("embedded_guidance", embedded_guidance) - num_inference_steps = config["model_config"][model_type]["default_parameters"].get("num_inference_steps", num_inference_steps) - height = config["model_config"][model_type]["default_parameters"].get("height", height) - width = config["model_config"][model_type]["default_parameters"].get("width", width) - return_with_mask = config["model_config"][model_type]["default_parameters"].get("return_with_mask", return_with_mask) - return prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, gr.update(value="Loaded FLUX"), gr.State(f'{random.randint(0, 1000000):08d}') - - - with gr.Column(scale=618, min_width=100): - with gr.Accordion(label="Painter"): - enable_local_prompt_list 
= [] - local_prompt_list = [] - mask_scale_list = [] - canvas_list = [] - for painter_layer_id in range(config["max_num_painter_layers"]): - with gr.Tab(label=f"Layer {painter_layer_id}"): - enable_local_prompt = gr.Checkbox(label="Enable", value=False, key=f"enable_local_prompt_{painter_layer_id}") - local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}") - mask_scale = gr.Slider(minimum=0.0, maximum=5.0, value=1.0, step=0.1, interactive=True, label="Mask scale", key=f"mask_scale_{painter_layer_id}") - canvas = gr.ImageEditor(canvas_size=(512, 1), sources=None, layers=False, interactive=True, image_mode="RGBA", - brush=gr.Brush(default_size=50, default_color="#000000", colors=["#000000"]), - label="Painter", key=f"canvas_{painter_layer_id}", width=width, height=height) - @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear, enable_local_prompt.change], show_progress="hidden") - def resize_canvas(height, width, canvas): - h, w = canvas["background"].shape[:2] - if h != height or width != w: - return np.ones((height, width, 3), dtype=np.uint8) * 255 - else: - return canvas - - enable_local_prompt_list.append(enable_local_prompt) - local_prompt_list.append(local_prompt) - mask_scale_list.append(mask_scale) - canvas_list.append(canvas) - with gr.Accordion(label="Results"): - run_button = gr.Button(value="Generate", variant="primary") - output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil") - with gr.Row(): - with gr.Column(): - output_to_painter_button = gr.Button(value="Set as painter's background") - with gr.Column(): - output_to_input_button = gr.Button(value="Set as input image") - real_output = gr.State(None) - mask_out = gr.State(None) - - @gr.on( - inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + enable_local_prompt_list + local_prompt_list + mask_scale_list + canvas_list, - outputs=[output_image, real_output, mask_out], - triggers=run_button.click - ) - def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()): - _, pipe = load_model(model_type, model_path) - input_params = { - "prompt": prompt, - "negative_prompt": negative_prompt, - "cfg_scale": cfg_scale, - "num_inference_steps": num_inference_steps, - "height": height, - "width": width, - "progress_bar_cmd": progress.tqdm, - } - if isinstance(pipe, FluxImagePipeline): - input_params["embedded_guidance"] = embedded_guidance - if input_image is not None: - input_params["input_image"] = input_image.resize((width, height)).convert("RGB") - input_params["enable_eligen_inpaint"] = True - - enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list = ( - args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]], - args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]], - args[2 * config["max_num_painter_layers"]: 3 * config["max_num_painter_layers"]], - args[3 * config["max_num_painter_layers"]: 4 * config["max_num_painter_layers"]] - ) - local_prompts, masks, mask_scales = [], [], [] - for enable_local_prompt, local_prompt, mask_scale, canvas in zip( - enable_local_prompt_list, local_prompt_list, mask_scale_list, canvas_list - ): - if 
enable_local_prompt: - local_prompts.append(local_prompt) - masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB")) - mask_scales.append(mask_scale) - entity_masks = None if len(masks) == 0 else masks - entity_prompts = None if len(local_prompts) == 0 else local_prompts - input_params.update({ - "eligen_entity_prompts": entity_prompts, - "eligen_entity_masks": entity_masks, - }) - torch.manual_seed(seed) - image = pipe(**input_params) - # visualize masks - masks = [mask.resize(image.size) for mask in masks] - image_with_mask = visualize_masks(image, masks, local_prompts) - # save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir.value) - - real_output = gr.State(image) - mask_out = gr.State(image_with_mask) - - if return_with_mask: - return image_with_mask, real_output, mask_out - return image, real_output, mask_out - - @gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click) - def send_input_to_painter_background(input_image, *canvas_list): - if input_image is None: - return tuple(canvas_list) - for canvas in canvas_list: - h, w = canvas["background"].shape[:2] - canvas["background"] = input_image.resize((w, h)) - return tuple(canvas_list) - @gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click) - def send_output_to_painter_background(real_output, *canvas_list): - if real_output is None: - return tuple(canvas_list) - for canvas in canvas_list: - h, w = canvas["background"].shape[:2] - canvas["background"] = real_output.value.resize((w, h)) - return tuple(canvas_list) - @gr.on(inputs=[return_with_mask, real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden") - def show_output(return_with_mask, real_output, mask_out): - if return_with_mask: - return mask_out.value - else: - return real_output.value - @gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click) - def send_output_to_pipe_input(real_output): - return real_output.value - -app.launch() diff --git a/apps/gradio/entity_level_control.py b/apps/gradio/entity_level_control.py new file mode 100644 index 0000000..58f4722 --- /dev/null +++ b/apps/gradio/entity_level_control.py @@ -0,0 +1,390 @@ +import os +import torch +import numpy as np +from PIL import Image, ImageDraw, ImageFont +import random +import json +import gradio as gr +from diffsynth import ModelManager, FluxImagePipeline, download_customized_models +from modelscope import dataset_snapshot_download + + +dataset_snapshot_download(dataset_id="DiffSynth-Studio/examples_in_diffsynth", local_dir="./", allow_file_pattern=f"data/examples/eligen/entity_control/*") +example_json = 'data/examples/eligen/entity_control/ui_examples.json' +with open(example_json, 'r') as f: + examples = json.load(f)['examples'] + +for idx in range(len(examples)): + example_id = examples[idx]['example_id'] + entity_prompts = examples[idx]['local_prompt_list'] + examples[idx]['mask_lists'] = [Image.open(f"data/examples/eligen/entity_control/example_{example_id}/{i}.png").convert('RGB') for i in range(len(entity_prompts))] + +def create_canvas_data(background, masks): + if background.shape[-1] == 3: + background = np.dstack([background, np.full(background.shape[:2], 255, dtype=np.uint8)]) + layers = [] + for mask in masks: + if mask is not None: + mask_single_channel = mask if mask.ndim == 2 else mask[..., 0] + layer = np.zeros((mask_single_channel.shape[0], mask_single_channel.shape[1], 4), 
dtype=np.uint8) + layer[..., -1] = mask_single_channel + layers.append(layer) + else: + layers.append(np.zeros_like(background)) + + composite = background.copy() + for layer in layers: + if layer.size > 0: + composite = np.where(layer[..., -1:] > 0, layer, composite) + return { + "background": background, + "layers": layers, + "composite": composite, + } + +def load_example(load_example_button): + example_idx = int(load_example_button.split()[-1]) - 1 + example = examples[example_idx] + result = [ + 50, + example["global_prompt"], + example["negative_prompt"], + example["seed"], + *example["local_prompt_list"], + ] + num_entities = len(example["local_prompt_list"]) + result += [""] * (config["max_num_painter_layers"] - num_entities) + masks = [] + for mask in example["mask_lists"]: + mask_single_channel = np.array(mask.convert("L")) + masks.append(mask_single_channel) + for _ in range(config["max_num_painter_layers"] - len(masks)): + blank_mask = np.zeros_like(masks[0]) if masks else np.zeros((512, 512), dtype=np.uint8) + masks.append(blank_mask) + background = np.ones((masks[0].shape[0], masks[0].shape[1], 4), dtype=np.uint8) * 255 + canvas_data_list = [] + for mask in masks: + canvas_data = create_canvas_data(background, [mask]) + canvas_data_list.append(canvas_data) + result.extend(canvas_data_list) + return result + +def save_mask_prompts(masks, mask_prompts, global_prompt, seed=0, random_dir='0000000'): + save_dir = os.path.join('workdirs/tmp_mask', random_dir) + print(f'save to {save_dir}') + os.makedirs(save_dir, exist_ok=True) + for i, mask in enumerate(masks): + save_path = os.path.join(save_dir, f'{i}.png') + mask.save(save_path) + sample = { + "global_prompt": global_prompt, + "mask_prompts": mask_prompts, + "seed": seed, + } + with open(os.path.join(save_dir, f"prompts.json"), 'w') as f: + json.dump(sample, f, indent=4) + +def visualize_masks(image, masks, mask_prompts, font_size=35, use_random_colors=False): + # Create a blank image for overlays + overlay = Image.new('RGBA', image.size, (0, 0, 0, 0)) + colors = [ + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + (165, 238, 173, 80), + (76, 102, 221, 80), + (221, 160, 77, 80), + (204, 93, 71, 80), + (145, 187, 149, 80), + (134, 141, 172, 80), + (157, 137, 109, 80), + (153, 104, 95, 80), + ] + # Generate random colors for each mask + if use_random_colors: + colors = [(random.randint(0, 255), random.randint(0, 255), random.randint(0, 255), 80) for _ in range(len(masks))] + # Font settings + try: + font = ImageFont.truetype("arial", font_size) # Adjust as needed + except IOError: + font = ImageFont.load_default(font_size) + # Overlay each mask onto the overlay image + for mask, mask_prompt, color in zip(masks, mask_prompts, colors): + if mask is None: + continue + # Convert mask to RGBA mode + mask_rgba = mask.convert('RGBA') + mask_data = mask_rgba.getdata() + new_data = [(color if item[:3] == (255, 255, 255) else (0, 0, 0, 0)) for item in mask_data] + mask_rgba.putdata(new_data) + # Draw the mask prompt text on the mask + draw = ImageDraw.Draw(mask_rgba) + mask_bbox = mask.getbbox() # Get the bounding box of the mask + if mask_bbox is None: + continue + text_position = (mask_bbox[0] + 10, mask_bbox[1] + 10) # Adjust text position based on mask position + draw.text(text_position, mask_prompt, fill=(255, 255, 255, 255), font=font) + # Alpha composite the overlay with this mask + overlay = 
Image.alpha_composite(overlay, mask_rgba) + # Composite the overlay onto the original image + result = Image.alpha_composite(image.convert('RGBA'), overlay) + return result + +config = { + "model_config": { + "FLUX": { + "model_folder": "models/FLUX", + "pipeline_class": FluxImagePipeline, + "default_parameters": { + "cfg_scale": 3.0, + "embedded_guidance": 3.5, + "num_inference_steps": 30, + } + }, + }, + "max_num_painter_layers": 8, + "max_num_model_cache": 1, +} + +model_dict = {} + +def load_model(model_type='FLUX', model_path='FLUX.1-dev'): + global model_dict + model_key = f"{model_type}:{model_path}" + if model_key in model_dict: + return model_dict[model_key] + model_path = os.path.join(config["model_config"][model_type]["model_folder"], model_path) + model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"]) + model_manager.load_lora( + download_customized_models( + model_id="DiffSynth-Studio/Eligen", + origin_file_path="model_bf16.safetensors", + local_dir="models/lora/entity_control", + ), + lora_alpha=1, + ) + pipe = config["model_config"][model_type]["pipeline_class"].from_model_manager(model_manager) + model_dict[model_key] = model_manager, pipe + return model_manager, pipe + + +with gr.Blocks() as app: + gr.Markdown( + """## EliGen: Entity-Level Controllable Text-to-Image Model + 1. On the left, input the **global prompt** for the overall image, such as "a person stands by the river." + 2. On the right, input the **local prompt** for each entity, such as "person," and draw the corresponding mask in the **Entity Mask Painter**. Generally, solid rectangular masks yield better results. + 3. Click the **Generate** button to create the image. By selecting different **random seeds**, you can generate diverse images. + 4. **You can directly click the "Load Example" button on any sample at the bottom to load example inputs.** + """ + ) + + loading_status = gr.Textbox(label="Loading Model...", value="Loading model... 
Please wait...", visible=True) + main_interface = gr.Column(visible=False) + + def initialize_model(): + try: + load_model() + return { + loading_status: gr.update(value="Model loaded successfully!", visible=False), + main_interface: gr.update(visible=True), + } + except Exception as e: + print(f'Failed to load model with error: {e}') + return { + loading_status: gr.update(value=f"Failed to load model: {str(e)}", visible=True), + main_interface: gr.update(visible=True), + } + + app.load(initialize_model, inputs=None, outputs=[loading_status, main_interface]) + + with main_interface: + with gr.Row(): + local_prompt_list = [] + canvas_list = [] + random_mask_dir = gr.State(f'{random.randint(0, 1000000):08d}') + with gr.Column(scale=382, min_width=100): + model_type = gr.State('FLUX') + model_path = gr.State('FLUX.1-dev') + with gr.Accordion(label="Global prompt"): + prompt = gr.Textbox(label="Global Prompt", lines=3) + negative_prompt = gr.Textbox(label="Negative prompt", value="worst quality, low quality, monochrome, zombie, interlocked fingers, Aissist, cleavage, nsfw, blur,", lines=3) + with gr.Accordion(label="Inference Options", open=True): + seed = gr.Number(minimum=0, maximum=10**9, value=42, interactive=True, label="Random seed", show_label=True) + num_inference_steps = gr.Slider(minimum=1, maximum=100, value=30, step=1, interactive=True, label="Inference steps") + cfg_scale = gr.Slider(minimum=2.0, maximum=10.0, value=3.0, step=0.1, interactive=True, label="Classifier-free guidance scale") + embedded_guidance = gr.Slider(minimum=0.0, maximum=10.0, value=3.5, step=0.1, interactive=True, label="Embedded guidance scale") + height = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Height") + width = gr.Slider(minimum=64, maximum=2048, value=1024, step=64, interactive=True, label="Width") + with gr.Accordion(label="Inpaint Input Image", open=False): + input_image = gr.Image(sources=None, show_label=False, interactive=True, type="pil") + background_weight = gr.Slider(minimum=0.0, maximum=1000., value=0., step=1, interactive=False, label="background_weight", visible=False) + + with gr.Column(): + reset_input_button = gr.Button(value="Reset Inpaint Input") + send_input_to_painter = gr.Button(value="Set as painter's background") + @gr.on(inputs=[input_image], outputs=[input_image], triggers=reset_input_button.click) + def reset_input_image(input_image): + return None + + with gr.Column(scale=618, min_width=100): + with gr.Accordion(label="Entity Painter"): + for painter_layer_id in range(config["max_num_painter_layers"]): + with gr.Tab(label=f"Entity {painter_layer_id}"): + local_prompt = gr.Textbox(label="Local prompt", key=f"local_prompt_{painter_layer_id}") + canvas = gr.ImageEditor( + canvas_size=(512, 512), + sources=None, + layers=False, + interactive=True, + image_mode="RGBA", + brush=gr.Brush( + default_size=50, + default_color="#000000", + colors=["#000000"], + ), + label="Entity Mask Painter", + key=f"canvas_{painter_layer_id}", + width=width, + height=height, + ) + @gr.on(inputs=[height, width, canvas], outputs=canvas, triggers=[height.change, width.change, canvas.clear], show_progress="hidden") + def resize_canvas(height, width, canvas): + h, w = canvas["background"].shape[:2] + if h != height or width != w: + return np.ones((height, width, 3), dtype=np.uint8) * 255 + else: + return canvas + local_prompt_list.append(local_prompt) + canvas_list.append(canvas) + with gr.Accordion(label="Results"): + run_button = gr.Button(value="Generate", 
variant="primary") + output_image = gr.Image(sources=None, show_label=False, interactive=False, type="pil") + with gr.Row(): + with gr.Column(): + output_to_painter_button = gr.Button(value="Set as painter's background") + with gr.Column(): + return_with_mask = gr.Checkbox(value=False, interactive=True, label="show result with mask painting") + output_to_input_button = gr.Button(value="Set as input image", visible=False, interactive=False) + real_output = gr.State(None) + mask_out = gr.State(None) + + @gr.on( + inputs=[model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir] + local_prompt_list + canvas_list, + outputs=[output_image, real_output, mask_out], + triggers=run_button.click + ) + def generate_image(model_type, model_path, prompt, negative_prompt, cfg_scale, embedded_guidance, num_inference_steps, height, width, return_with_mask, seed, input_image, background_weight, random_mask_dir, *args, progress=gr.Progress()): + _, pipe = load_model(model_type, model_path) + input_params = { + "prompt": prompt, + "negative_prompt": negative_prompt, + "cfg_scale": cfg_scale, + "num_inference_steps": num_inference_steps, + "height": height, + "width": width, + "progress_bar_cmd": progress.tqdm, + } + if isinstance(pipe, FluxImagePipeline): + input_params["embedded_guidance"] = embedded_guidance + if input_image is not None: + input_params["input_image"] = input_image.resize((width, height)).convert("RGB") + input_params["enable_eligen_inpaint"] = True + + local_prompt_list, canvas_list = ( + args[0 * config["max_num_painter_layers"]: 1 * config["max_num_painter_layers"]], + args[1 * config["max_num_painter_layers"]: 2 * config["max_num_painter_layers"]], + ) + local_prompts, masks = [], [] + for local_prompt, canvas in zip(local_prompt_list, canvas_list): + if isinstance(local_prompt, str) and len(local_prompt) > 0: + local_prompts.append(local_prompt) + masks.append(Image.fromarray(canvas["layers"][0][:, :, -1]).convert("RGB")) + entity_masks = None if len(masks) == 0 else masks + entity_prompts = None if len(local_prompts) == 0 else local_prompts + input_params.update({ + "eligen_entity_prompts": entity_prompts, + "eligen_entity_masks": entity_masks, + }) + torch.manual_seed(seed) + # save_mask_prompts(masks, local_prompts, prompt, seed, random_mask_dir) + image = pipe(**input_params) + masks = [mask.resize(image.size) for mask in masks] + image_with_mask = visualize_masks(image, masks, local_prompts) + + real_output = gr.State(image) + mask_out = gr.State(image_with_mask) + + if return_with_mask: + return image_with_mask, real_output, mask_out + return image, real_output, mask_out + + @gr.on(inputs=[input_image] + canvas_list, outputs=canvas_list, triggers=send_input_to_painter.click) + def send_input_to_painter_background(input_image, *canvas_list): + if input_image is None: + return tuple(canvas_list) + for canvas in canvas_list: + h, w = canvas["background"].shape[:2] + canvas["background"] = input_image.resize((w, h)) + return tuple(canvas_list) + @gr.on(inputs=[real_output] + canvas_list, outputs=canvas_list, triggers=output_to_painter_button.click) + def send_output_to_painter_background(real_output, *canvas_list): + if real_output is None: + return tuple(canvas_list) + for canvas in canvas_list: + h, w = canvas["background"].shape[:2] + canvas["background"] = real_output.value.resize((w, h)) + return tuple(canvas_list) + @gr.on(inputs=[return_with_mask, 
real_output, mask_out], outputs=[output_image], triggers=[return_with_mask.change], show_progress="hidden") + def show_output(return_with_mask, real_output, mask_out): + if return_with_mask: + return mask_out.value + else: + return real_output.value + @gr.on(inputs=[real_output], outputs=[input_image], triggers=output_to_input_button.click) + def send_output_to_pipe_input(real_output): + return real_output.value + + with gr.Column(): + gr.Markdown("## Examples") + for i in range(0, len(examples), 2): + with gr.Row(): + if i < len(examples): + example = examples[i] + with gr.Column(): + example_image = gr.Image( + value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png", + label=example["description"], + interactive=False, + width=1024, + height=512 + ) + load_example_button = gr.Button(value=f"Load Example {example['example_id']}") + load_example_button.click( + load_example, + inputs=[load_example_button], + outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list + ) + + if i + 1 < len(examples): + example = examples[i + 1] + with gr.Column(): + example_image = gr.Image( + value=f"data/examples/eligen/entity_control/example_{example['example_id']}/example_image.png", + label=example["description"], + interactive=False, + width=1024, + height=512 + ) + load_example_button = gr.Button(value=f"Load Example {example['example_id']}") + load_example_button.click( + load_example, + inputs=[load_example_button], + outputs=[num_inference_steps, prompt, negative_prompt, seed] + local_prompt_list + canvas_list + ) +app.config["show_progress"] = "hidden" +app.launch() diff --git a/examples/EntityControl/README.md index 92c220d..85f9cfb 100644 --- a/examples/EntityControl/README.md +++ b/examples/EntityControl/README.md @@ -7,11 +7,12 @@ We propose EliGen, a novel approach that leverages fine-grained entity-level inf * Paper: [EliGen: Entity-Level Controlled Image Generation with Regional Attention](https://arxiv.org/abs/2501.01097) * GitHub: [DiffSynth-Studio](https://github.com/modelscope/DiffSynth-Studio) * Model: [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) +* Online Demo: [ModelScope EliGen Studio](https://www.modelscope.cn/studios/DiffSynth-Studio/EliGen) * Training dataset: Coming soon ## Methodology -![regional-attention](https://github.com/user-attachments/assets/9a147201-15ab-421f-a6c5-701075754478) +![regional-attention](https://github.com/user-attachments/assets/bef5ae2b-cc03-404e-b9c8-0c037ac66190) We introduce a regional attention mechanism within the DiT framework to effectively process the conditions of each entity. This mechanism enables the local prompt associated with each entity to semantically influence specific regions through regional attention. To further enhance the layout control capabilities of EliGen, we meticulously contribute an entity-annotated dataset and fine-tune the model using the LoRA framework. @@ -32,7 +33,7 @@ We introduce a regional attention mechanism within the DiT framework to effectiv 4. **Entity Transfer** We have provided an example of how to integrate EliGen with In-Context LoRA, which achieves interesting entity transfer results. See [./entity_transfer.py](./entity_transfer.py) for usage. 5. 
**Play with EliGen using UI** - Download the checkpoint of EliGen from [ModelScope](https://www.modelscope.cn/models/DiffSynth-Studio/Eligen) to `models/lora/entity_control` and run the following command to try interactive UI: + Run the following command to try the interactive UI: ```bash python apps/gradio/entity_level_control.py ``` @@ -41,7 +42,7 @@ We introduce a regional attention mechanism within the DiT framework to effectiv 1. The effect of generating images with continuously changing entity positions. -https://github.com/user-attachments/assets/4fc76df1-b26a-46e8-a950-865cdf02a38d +https://github.com/user-attachments/assets/54a048c8-b663-4262-8c40-43c87c266d4b 2. The image generation results for complex entity combinations, demonstrating the strong generalization of EliGen. See [./entity_control.py](./entity_control.py) `example_1-6` for generation prompts. @@ -67,17 +68,17 @@ https://github.com/user-attachments/assets/4fc76df1-b26a-46e8-a950-865cdf02a38d Demonstration of the inpainting mode of EliGen; see [./entity_inpaint.py](./entity_inpaint.py) for generation prompts. |Inpainting Input|Inpainting Output| |-|-| -|![image_2_base](https://github.com/user-attachments/assets/5ef499f3-3d8a-49cc-8ceb-86af7f5cb9f8)|![image_2_enhance](https://github.com/user-attachments/assets/88fc3bde-0984-4b3c-8ca9-d63de660855b)| -|![image_1_base](https://github.com/user-attachments/assets/5f74c710-bf30-4db1-ae40-a1e1995ccef6)|![image_1_enhance](https://github.com/user-attachments/assets/1cd71177-e956-46d3-86ce-06f774c96efd)| +|![inpaint_i1](https://github.com/user-attachments/assets/5ef499f3-3d8a-49cc-8ceb-86af7f5cb9f8)|![inpaint_o1](https://github.com/user-attachments/assets/88fc3bde-0984-4b3c-8ca9-d63de660855b)| +|![inpaint_i2](https://github.com/user-attachments/assets/5f74c710-bf30-4db1-ae40-a1e1995ccef6)|![inpaint_o2](https://github.com/user-attachments/assets/7c3b4857-b774-47ea-b163-34d49e7c976d)| ### Styled Entity Control Demonstration of the styled entity control results with EliGen and IP-Adapter; see [./entity_control_ipadapter.py](./entity_control_ipadapter.py) for generation prompts. |Style Reference|Entity Control Variant 1|Entity Control Variant 2|Entity Control Variant 3| |-|-|-|-| |![image_1_base](https://github.com/user-attachments/assets/5e2dd3ab-37d3-4f58-8e02-ee2f9b238604)|![result1](https://github.com/user-attachments/assets/0f6711a2-572a-41b3-938a-95deff6d732d)|![result2](https://github.com/user-attachments/assets/ce2e66e5-1fdf-44e8-bca7-555d805a50b1)|![result3](https://github.com/user-attachments/assets/ad2da233-2f7c-4065-ab57-b2d84dc2c0e2)| ### Entity Transfer Demonstration of the entity transfer results with EliGen and In-Context LoRA; see [./entity_transfer.py](./entity_transfer.py) for generation prompts. 
|Entity to Transfer|Transfer Target Image|Transfer Example 1|Transfer Example 2| |-|-|-|-| -|![image_1_base](https://github.com/user-attachments/assets/0d40ef22-0a09-420d-bd5a-bfb93120b60d)|![image_1_enhance](https://github.com/user-attachments/assets/f6c58ef2-54c1-4d86-8429-dad2eb0e0685)|![image_1_enhance](https://github.com/user-attachments/assets/05eed2e3-097d-40af-8aae-1e0c75051f32)|![image_1_enhance](https://github.com/user-attachments/assets/54314d16-244b-411e-8a91-96c500efa5f5)| +|![source](https://github.com/user-attachments/assets/0d40ef22-0a09-420d-bd5a-bfb93120b60d)|![target](https://github.com/user-attachments/assets/f6c58ef2-54c1-4d86-8429-dad2eb0e0685)|![result1](https://github.com/user-attachments/assets/05eed2e3-097d-40af-8aae-1e0c75051f32)|![result2](https://github.com/user-attachments/assets/54314d16-244b-411e-8a91-96c500efa5f5)| \ No newline at end of file
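
For reference, each "Generate" click in the UI above reduces to a single pipeline call. Below is a minimal programmatic sketch assembled from the calls in `apps/gradio/entity_level_control.py`; the mask paths, prompts, and seed value are illustrative placeholders, not files or settings shipped with the repository:

```python
# Minimal sketch of entity-level controlled generation with EliGen,
# mirroring load_model() and generate_image() in entity_level_control.py.
import torch
from PIL import Image
from diffsynth import ModelManager, FluxImagePipeline, download_customized_models

# Load FLUX.1-dev and apply the EliGen entity-control LoRA.
model_manager = ModelManager(torch_dtype=torch.bfloat16, device="cuda", model_id_list=["FLUX.1-dev"])
model_manager.load_lora(
    download_customized_models(
        model_id="DiffSynth-Studio/Eligen",
        origin_file_path="model_bf16.safetensors",
        local_dir="models/lora/entity_control",
    ),
    lora_alpha=1,
)
pipe = FluxImagePipeline.from_model_manager(model_manager)

# One white-on-black RGB mask per entity, matching what the UI extracts from
# each painter canvas. "masks/0.png" and "masks/1.png" are placeholder paths.
entity_prompts = ["a man", "a woman"]
entity_masks = [Image.open(f"masks/{i}.png").convert("RGB") for i in range(2)]

torch.manual_seed(42)  # fixed seed, as set via the UI's "Random seed" field
image = pipe(
    prompt="a man and a woman stand by the river",  # global prompt
    negative_prompt="worst quality, low quality, monochrome",
    cfg_scale=3.0,
    embedded_guidance=3.5,
    num_inference_steps=30,
    height=1024,
    width=1024,
    eligen_entity_prompts=entity_prompts,
    eligen_entity_masks=entity_masks,
)
image.save("eligen_entity_control.png")
```

Each local prompt is paired positionally with one mask, and solid white regions on a black background mark where the corresponding entity should appear, just as in the Entity Mask Painter.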