From 5f746e075da1fdd1f46d57a2b6adb9f2fe8bf638 Mon Sep 17 00:00:00 2001 From: yanzewu Date: Fri, 1 Nov 2024 04:08:30 +0800 Subject: [PATCH] add PuLID-v1.1 --- README.md | 14 +- app_v1_1.py | 252 ++++++++++++++ docs/pulid_v1.1.md | 15 + ...coders_flux.py => encoders_transformer.py} | 4 +- pulid/pipeline_flux.py | 2 +- pulid/pipeline_v1_1.py | 324 ++++++++++++++++++ pulid/utils.py | 171 +++++++++ requirements.txt | 3 +- 8 files changed, 779 insertions(+), 6 deletions(-) create mode 100644 app_v1_1.py create mode 100644 docs/pulid_v1.1.md rename pulid/{encoders_flux.py => encoders_transformer.py} (98%) create mode 100644 pulid/pipeline_v1_1.py diff --git a/README.md b/README.md index 73f2fcc..b877bd4 100644 --- a/README.md +++ b/README.md @@ -7,15 +7,16 @@ > ByteDance Inc
### :triangular_flag_on_post: Updates +* **2024.10.31**: 💥 We are happy to release our latest [models](https://huggingface.co/guozinan/PuLID), **PuLID-v1.1** and **PuLID-FLUX-v0.9.1**! * **2024.09.26**: 🎉 PuLID accepted by NeurIPS 2024 -* **2024.09.12**: 💥 We're thrilled to announce the release of the **PuLID-FLUX-v0.9.0 model**. Enjoy exploring its capabilities! 😊 [Learn more about this model](docs/pulid_for_flux.md) +* **2024.09.12**: We're thrilled to announce the release of the **PuLID-FLUX-v0.9.0 model**. Enjoy exploring its capabilities! 😊 [Learn more about this model](docs/pulid_for_flux.md) * **2024.05.23**: share the [preview of our upcoming v1.1 model](docs/v1.1_preview.md), please stay tuned * **2024.05.01**: release v1 codes&models, also the [🤗HuggingFace Demo](https://huggingface.co/spaces/yanze/PuLID) * **2024.04.25**: release arXiv paper. ### :soon: update plan -- [ ] release PuLID-FLUX-v0.9.1 model in 2024.10 -- [ ] release PuLID v1.1 (for SDXL) model in 2024.10 +- [x] release PuLID-FLUX-v0.9.1 model in 2024.10 +- [x] release PuLID v1.1 (for SDXL) model in 2024.10 ## PuLID for FLUX Please check the doc and demo of PuLID-FLUX [here](docs/pulid_for_flux.md). @@ -28,6 +29,7 @@ We will actively update and maintain this repository in the near future, so plea - [x] We have optimized the codes to support consumer-grade GPUS, and now **PuLID-FLUX can run on a 16GB graphic card**. Check the details [here](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md#local-gradio-demo) - [x] (Community Implementation) Online Replicate demo is ready now [![Replicate](https://replicate.com/zsxkib/flux-pulid/badge)](https://replicate.com/zsxkib/flux-pulid) - [x] Local gradio demo supports 12GB graphic card now +- [x] v0.9.1 is ready now Below results are generated with PuLID-FLUX. 
@@ -62,7 +64,13 @@ pip install -r requirements_fp8.txt ## :zap: Quick Inference ### Local Gradio Demo ```bash +# for v1 version python app.py + +# for v1.1 version +python app_v1.1.py --base BASE_MODEL +Usage: + -base: can be RunDiffusion/Juggernaut-XL-v9 or Lykon/dreamshaper-xl-lightning ``` ### Online HuggingFace Demo diff --git a/app_v1_1.py b/app_v1_1.py new file mode 100644 index 0000000..05815b7 --- /dev/null +++ b/app_v1_1.py @@ -0,0 +1,252 @@ +import argparse + +import gradio as gr +import numpy as np +import torch + +from pulid import attention_processor as attention +from pulid.pipeline_v1_1 import PuLIDPipeline +from pulid.utils import resize_numpy_image_long + +torch.set_grad_enabled(False) + +parser = argparse.ArgumentParser() +parser.add_argument( + '--base', + type=str, + default='RunDiffusion/Juggernaut-XL-v9', + choices=[ + 'Lykon/dreamshaper-xl-lightning', + # 'SG161222/RealVisXL_V4.0', will add it later + 'RunDiffusion/Juggernaut-XL-v9', + ], +) +# parser.add_argument('--sampler', type=str, default='dpmpp_2m', choices=['dpmpp_sde', 'dpmpp_2m']) +parser.add_argument('--port', type=int, default=7860) +args = parser.parse_args() + +use_lightning_model = 'lightning' in args.base.lower() +# currently we only support two commonly used sampler +args.sampler = 'dpmpp_sde' if use_lightning_model else 'dpmpp_2m' +if use_lightning_model: + default_cfg = 2.0 + default_steps = 5 +else: + default_cfg = 7.0 + default_steps = 25 + +pipeline = PuLIDPipeline(sdxl_repo=args.base, sampler=args.sampler) + +# other params +DEFAULT_NEGATIVE_PROMPT = ( + 'flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,' + 'artifacts noise, text, watermark, glitch, deformed, mutated, ugly, disfigured, hands, ' + 'low resolution, partially rendered objects, deformed or partially rendered eyes, ' + 'deformed, deformed eyeballs, cross-eyed,blurry' +) + +dreamshaper_example_inps = [ + ['portrait, blacklight', 'example_inputs/liuyifei.png', 42, 0.8, 10], + ['pixel art, 1boy', 'example_inputs/lecun.jpg', 42, 0.8, 10], + [ + 'cinematic film still, close up, photo of redheaded girl near grasses, fictional landscapes, (intense sunlight:1.4), realist detail, brooding mood, ue5, detailed character expressions, light amber and red, amazing quality, wallpaper, analog film grain', + 'example_inputs/liuyifei.png', + 42, + 0.8, + 10, + ], + [ + 'A minimalist line art depiction of an Artificial Intelligence being\'s thought process, lines and nodes forming intricate patterns.', + 'example_inputs/hinton.jpeg', + 42, + 0.8, + 10, + ], + [ + 'instagram photo, photo of 23 y.o man in black sweater, pale skin, (smile:0.4), hard shadows', + 'example_inputs/pengwei.jpg', + 42, + 0.8, + 10, + ], + [ + 'by Tsutomu Nihei,(strange but extremely beautiful:1.4),(masterpiece, best quality:1.4),in the style of nicola samori,The Joker,', + 'example_inputs/lecun.jpg', + 1675432759740519133, + 0.8, + 10, + ], +] + +jugger_example_inps = [ + [ + 'robot,simple robot,robot with glass face,ellipse head robot,(made partially out of glass),hexagonal shapes,ferns growing inside head,butterflies on head,butterflies flying around', + 'example_inputs/hinton.jpeg', + 15022214902832471291, + 0.8, + 20, + ], + ['sticker art, 1girl', 'example_inputs/liuyifei.png', 42, 0.8, 20], + [ + '1girl, cute model, Long thick Maxi Skirt, Knit sweater, swept back hair, alluring smile, working at a clothing store, perfect eyes, highly detailed beautiful expressive eyes, detailed eyes, 35mm photograph, film, bokeh, professional, 4k, 
highly detailed dynamic lighting, photorealistic, 8k, raw, rich, intricate details,', + 'example_inputs/liuyifei.png', + 42, + 0.8, + 20, + ], + ['Chinese paper-cut, 1girl', 'example_inputs/liuyifei.png', 42, 0.8, 20], + ['Studio Ghibli, 1boy', 'example_inputs/hinton.jpeg', 42, 0.8, 20], + ['1man made of ice sculpture', 'example_inputs/lecun.jpg', 42, 0.8, 20], + ['portrait of green-skinned shrek, wearing lacoste purple sweater', 'example_inputs/lecun.jpg', 42, 0.8, 20], + ['1990s Japanese anime, 1girl', 'example_inputs/liuyifei.png', 42, 0.8, 20], + ['made of little stones, portrait', 'example_inputs/hinton.jpeg', 42, 0.8, 20], +] + + +@torch.inference_mode() +def run(*args): + id_image = args[0] + supp_images = args[1:4] + prompt, neg_prompt, scale, seed, steps, H, W, id_scale, num_zero, ortho = args[4:] + seed = int(seed) + if seed == -1: + seed = torch.Generator(device="cpu").seed() + + pipeline.debug_img_list = [] + + attention.NUM_ZERO = num_zero + if ortho == 'v2': + attention.ORTHO = False + attention.ORTHO_v2 = True + elif ortho == 'v1': + attention.ORTHO = True + attention.ORTHO_v2 = False + else: + attention.ORTHO = False + attention.ORTHO_v2 = False + + if id_image is not None: + id_image = resize_numpy_image_long(id_image, 1024) + supp_id_image_list = [ + resize_numpy_image_long(supp_id_image, 1024) for supp_id_image in supp_images if supp_id_image is not None + ] + id_image_list = [id_image] + supp_id_image_list + uncond_id_embedding, id_embedding = pipeline.get_id_embedding(id_image_list) + else: + uncond_id_embedding = None + id_embedding = None + + img = pipeline.inference( + prompt, (1, H, W), neg_prompt, id_embedding, uncond_id_embedding, id_scale, scale, steps, seed + )[0] + + return np.array(img), str(seed), pipeline.debug_img_list + + +_HEADER_ = ''' +

+Official Gradio Demo
+
+PuLID: Pure and Lightning ID Customization via Contrastive Alignment

+ +**PuLID** is a tuning-free ID customization approach. PuLID maintains high ID fidelity while effectively reducing interference with the original model’s behavior. + +Code: GitHub. Paper: ArXiv. + +❗️❗️❗️**Tips:** +- we provide some examples in the bottom, you can try these example prompts first +- a single ID image is usually sufficient, you can also supplement with additional auxiliary images +- You can adjust the trade-off between ID fidelity and editability in the advanced options, but generally, the default settings are good enough. + +''' # noqa E501 + +_CITE_ = r""" +If PuLID is helpful, please help to ⭐ the Github Repo. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/ToTheBeginning/PuLID?style=social)](https://github.com/ToTheBeginning/PuLID) +--- +📧 **Contact** +If you have any questions, feel free to open a discussion or contact us at wuyanze123@gmail.com or guozinan.1@bytedance.com. +""" # noqa E501 + + +with gr.Blocks(title="PuLID", css=".gr-box {border-color: #8136e2}") as demo: + gr.Markdown(_HEADER_) + with gr.Row(): + with gr.Column(): + with gr.Row(): + face_image = gr.Image(label="ID image (main)", height=256) + supp_image1 = gr.Image(label="Additional ID image (auxiliary)", height=256) + supp_image2 = gr.Image(label="Additional ID image (auxiliary)", height=256) + supp_image3 = gr.Image(label="Additional ID image (auxiliary)", height=256) + prompt = gr.Textbox(label="Prompt", value='portrait,color,cinematic,in garden,soft light,detailed face') + submit = gr.Button("Generate") + neg_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT) + scale = gr.Slider( + label="CFG (recommend 2 for lightning model and 7 for non-accelerated model)", + value=default_cfg, + minimum=1, + maximum=10, + step=0.1, + ) + seed = gr.Textbox(-1, label="Seed (-1 for random)") + steps = gr.Slider(label="Steps", value=default_steps, minimum=1, maximum=30, step=1) + with gr.Row(): + H = gr.Slider(label="Height", value=1152, minimum=512, maximum=2024, step=64) + W = gr.Slider(label="Width", value=896, minimum=512, maximum=2024, step=64) + with gr.Row(), gr.Accordion( + "Advanced Options (adjust the trade-off between ID fidelity and editability)", open=False + ): + id_scale = gr.Slider( + label="ID scale (Increasing it enhances ID similarity but reduces editability)", + minimum=0, + maximum=5, + step=0.05, + value=0.8, + interactive=True, + ) + num_zero = gr.Slider( + label="num zero (Increasing it enhances ID editability but reduces similarity)", + minimum=0, + maximum=80, + step=1, + value=20, + interactive=True, + ) + ortho = gr.Dropdown(label="ortho", choices=['off', 'v1', 'v2'], value='v2', visible=False) + + with gr.Column(): + output = gr.Image(label="Generated Image") + seed_output = gr.Textbox(label="Used Seed") + intermediate_output = gr.Gallery(label='DebugImage', elem_id="gallery", visible=False) + gr.Markdown(_CITE_) + + with gr.Row(), gr.Column(): + gr.Markdown("## Examples") + if args.base == 'Lykon/dreamshaper-xl-lightning': + gr.Examples( + examples=dreamshaper_example_inps, + inputs=[prompt, face_image, seed, id_scale, num_zero], + label='dreamshaper-xl-lightning examples', + ) + elif args.base == 'RunDiffusion/Juggernaut-XL-v9': + gr.Examples( + examples=jugger_example_inps, + inputs=[prompt, face_image, seed, id_scale, num_zero], + label='Juggernaut-XL-v9 examples', + ) + + inps = [ + face_image, + supp_image1, + supp_image2, + supp_image3, + prompt, + neg_prompt, + scale, + seed, + steps, + H, + W, + id_scale, + num_zero, + ortho, + ] + 
submit.click(fn=run, inputs=inps, outputs=[output, seed_output, intermediate_output]) + +demo.launch(server_name='0.0.0.0', server_port=args.port) diff --git a/docs/pulid_v1.1.md b/docs/pulid_v1.1.md new file mode 100644 index 0000000..5b07963 --- /dev/null +++ b/docs/pulid_v1.1.md @@ -0,0 +1,15 @@ +# PuLID v1.1 +Following are some examples generated by PuLID-v1.1, you can reproduce these results from our Gradio demo. +![release_v1 1](https://github.com/user-attachments/assets/d5bf3865-5147-428d-bb98-80812c680900) + +## How to use +### Online demo +We plan to upgrade the PuLID demo to v1.1 soon, please stay tuned. + +### Local demo +```bash +python app_v1.1.py --base BASE_MODEL +Usage: + -base: can be RunDiffusion/Juggernaut-XL-v9 or Lykon/dreamshaper-xl-lightning +``` + diff --git a/pulid/encoders_flux.py b/pulid/encoders_transformer.py similarity index 98% rename from pulid/encoders_flux.py rename to pulid/encoders_transformer.py index 7891fb3..afbdaf3 100644 --- a/pulid/encoders_flux.py +++ b/pulid/encoders_transformer.py @@ -190,8 +190,10 @@ def forward(self, x, y): latents = self.latents.repeat(x.size(0), 1, 1) + num_duotu = x.shape[1] if x.ndim == 3 else 1 + x = self.id_embedding_mapping(x) - x = x.reshape(-1, self.num_id_token, self.dim) + x = x.reshape(-1, self.num_id_token * num_duotu, self.dim) latents = torch.cat((latents, x), dim=1) diff --git a/pulid/pipeline_flux.py b/pulid/pipeline_flux.py index 019bcef..706514c 100644 --- a/pulid/pipeline_flux.py +++ b/pulid/pipeline_flux.py @@ -14,7 +14,7 @@ from eva_clip import create_model_and_transforms from eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD -from pulid.encoders_flux import IDFormer, PerceiverAttentionCA +from pulid.encoders_transformer import IDFormer, PerceiverAttentionCA from pulid.utils import img2tensor, tensor2img diff --git a/pulid/pipeline_v1_1.py b/pulid/pipeline_v1_1.py new file mode 100644 index 0000000..4ca68cf --- /dev/null +++ b/pulid/pipeline_v1_1.py @@ -0,0 +1,324 @@ +import gc + +import cv2 +import insightface +import numpy as np +import torch +import torch.nn as nn +from basicsr.utils import img2tensor, tensor2img +from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline +from facexlib.parsing import init_parsing_model +from facexlib.utils.face_restoration_helper import FaceRestoreHelper + +from huggingface_hub import hf_hub_download, snapshot_download +from insightface.app import FaceAnalysis +from safetensors.torch import load_file +from torchvision.transforms import InterpolationMode +from torchvision.transforms.functional import normalize, resize + +from eva_clip import create_model_and_transforms +from eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD +from pulid.encoders_transformer import IDFormer +from pulid.utils import is_torch2_available, sample_dpmpp_2m, sample_dpmpp_sde + +if is_torch2_available(): + from pulid.attention_processor import AttnProcessor2_0 as AttnProcessor + from pulid.attention_processor import IDAttnProcessor2_0 as IDAttnProcessor +else: + from pulid.attention_processor import AttnProcessor, IDAttnProcessor + + +class PuLIDPipeline: + def __init__(self, sdxl_repo='Lykon/dreamshaper-xl-lightning', sampler='dpmpp_sde', *args, **kwargs): + super().__init__() + self.device = 'cuda' + + # load base model + self.pipe = StableDiffusionXLPipeline.from_pretrained(sdxl_repo, torch_dtype=torch.float16, variant="fp16").to( + self.device + ) + self.pipe.watermark = None + self.hack_unet_attn_layers(self.pipe.unet) + + # scheduler + 
self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config) + + # ID adapters + self.id_adapter = IDFormer().to(self.device) + + # preprocessors + # face align and parsing + self.face_helper = FaceRestoreHelper( + upscale_factor=1, + face_size=512, + crop_ratio=(1, 1), + det_model='retinaface_resnet50', + save_ext='png', + device=self.device, + ) + self.face_helper.face_parse = None + self.face_helper.face_parse = init_parsing_model(model_name='bisenet', device=self.device) + # clip-vit backbone + model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', 'eva_clip', force_custom_clip=True) + model = model.visual + self.clip_vision_model = model.to(self.device) + eva_transform_mean = getattr(self.clip_vision_model, 'image_mean', OPENAI_DATASET_MEAN) + eva_transform_std = getattr(self.clip_vision_model, 'image_std', OPENAI_DATASET_STD) + if not isinstance(eva_transform_mean, (list, tuple)): + eva_transform_mean = (eva_transform_mean,) * 3 + if not isinstance(eva_transform_std, (list, tuple)): + eva_transform_std = (eva_transform_std,) * 3 + self.eva_transform_mean = eva_transform_mean + self.eva_transform_std = eva_transform_std + # antelopev2 + snapshot_download('DIAMONIK7777/antelopev2', local_dir='models/antelopev2') + self.app = FaceAnalysis( + name='antelopev2', root='.', providers=['CUDAExecutionProvider', 'CPUExecutionProvider'] + ) + self.app.prepare(ctx_id=0, det_size=(640, 640)) + self.handler_ante = insightface.model_zoo.get_model('models/antelopev2/glintr100.onnx') + self.handler_ante.prepare(ctx_id=0) + + gc.collect() + torch.cuda.empty_cache() + + self.load_pretrain() + + # other configs + self.debug_img_list = [] + + # karras schedule related code, borrow from lllyasviel/Omost + linear_start = 0.00085 + linear_end = 0.012 + timesteps = 1000 + betas = torch.linspace(linear_start**0.5, linear_end**0.5, timesteps, dtype=torch.float64) ** 2 + alphas = 1.0 - betas + alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32) + + self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5 + self.log_sigmas = self.sigmas.log() + self.sigma_data = 1.0 + + if sampler == 'dpmpp_sde': + self.sampler = sample_dpmpp_sde + elif sampler == 'dpmpp_2m': + self.sampler = sample_dpmpp_2m + else: + raise NotImplementedError(f'sampler {sampler} not implemented') + + @property + def sigma_min(self): + return self.sigmas[0] + + @property + def sigma_max(self): + return self.sigmas[-1] + + def timestep(self, sigma): + log_sigma = sigma.log() + dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None] + return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device) + + def get_sigmas_karras(self, n, rho=7.0): + ramp = torch.linspace(0, 1, n) + min_inv_rho = self.sigma_min ** (1 / rho) + max_inv_rho = self.sigma_max ** (1 / rho) + sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho + return torch.cat([sigmas, sigmas.new_zeros([1])]) + + def hack_unet_attn_layers(self, unet): + id_adapter_attn_procs = {} + for name, _ in unet.attn_processors.items(): + cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = unet.config.block_out_channels[-1] + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(unet.config.block_out_channels))[block_id] + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = 
unet.config.block_out_channels[block_id] + if cross_attention_dim is not None: + id_adapter_attn_procs[name] = IDAttnProcessor( + hidden_size=hidden_size, + cross_attention_dim=cross_attention_dim, + ).to(unet.device) + else: + id_adapter_attn_procs[name] = AttnProcessor() + unet.set_attn_processor(id_adapter_attn_procs) + self.id_adapter_attn_layers = nn.ModuleList(unet.attn_processors.values()) + + def load_pretrain(self): + hf_hub_download('guozinan/PuLID', 'pulid_v1.1.safetensors', local_dir='models') + ckpt_path = 'models/pulid_v1.1.safetensors' + state_dict = load_file(ckpt_path) + state_dict_dict = {} + for k, v in state_dict.items(): + module = k.split('.')[0] + state_dict_dict.setdefault(module, {}) + new_k = k[len(module) + 1 :] + state_dict_dict[module][new_k] = v + + for module in state_dict_dict: + print(f'loading from {module}') + getattr(self, module).load_state_dict(state_dict_dict[module], strict=True) + + def to_gray(self, img): + x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3] + x = x.repeat(1, 3, 1, 1) + return x + + def get_id_embedding(self, image_list): + """ + Args: + image in image_list: numpy rgb image, range [0, 255] + """ + id_cond_list = [] + id_vit_hidden_list = [] + for ii, image in enumerate(image_list): + self.face_helper.clean_all() + image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + # get antelopev2 embedding + face_info = self.app.get(image_bgr) + if len(face_info) > 0: + face_info = sorted( + face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1]) + )[ + -1 + ] # only use the maximum face + id_ante_embedding = face_info['embedding'] + self.debug_img_list.append( + image[ + int(face_info['bbox'][1]) : int(face_info['bbox'][3]), + int(face_info['bbox'][0]) : int(face_info['bbox'][2]), + ] + ) + else: + id_ante_embedding = None + + # using facexlib to detect and align face + self.face_helper.read_image(image_bgr) + self.face_helper.get_face_landmarks_5(only_center_face=True) + self.face_helper.align_warp_face() + if len(self.face_helper.cropped_faces) == 0: + raise RuntimeError('facexlib align face fail') + align_face = self.face_helper.cropped_faces[0] + # incase insightface didn't detect face + if id_ante_embedding is None: + print('fail to detect face using insightface, extract embedding on align face') + id_ante_embedding = self.handler_ante.get_feat(align_face) + + id_ante_embedding = torch.from_numpy(id_ante_embedding).to(self.device) + if id_ante_embedding.ndim == 1: + id_ante_embedding = id_ante_embedding.unsqueeze(0) + + # parsing + input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0 + input = input.to(self.device) + parsing_out = self.face_helper.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[ + 0 + ] + parsing_out = parsing_out.argmax(dim=1, keepdim=True) + bg_label = [0, 16, 18, 7, 8, 9, 14, 15] + bg = sum(parsing_out == i for i in bg_label).bool() + white_image = torch.ones_like(input) + # only keep the face features + face_features_image = torch.where(bg, white_image, self.to_gray(input)) + self.debug_img_list.append(tensor2img(face_features_image, rgb2bgr=False)) + + # transform img before sending to eva-clip-vit + face_features_image = resize( + face_features_image, self.clip_vision_model.image_size, InterpolationMode.BICUBIC + ) + face_features_image = normalize(face_features_image, self.eva_transform_mean, self.eva_transform_std) + id_cond_vit, id_vit_hidden = self.clip_vision_model( + face_features_image, return_all_features=False, 
return_hidden=True, shuffle=False + ) + id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True) + id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm) + + id_cond = torch.cat([id_ante_embedding, id_cond_vit], dim=-1) + + id_cond_list.append(id_cond) + id_vit_hidden_list.append(id_vit_hidden) + + id_uncond = torch.zeros_like(id_cond_list[0]) + id_vit_hidden_uncond = [] + for layer_idx in range(0, len(id_vit_hidden_list[0])): + id_vit_hidden_uncond.append(torch.zeros_like(id_vit_hidden_list[0][layer_idx])) + + id_cond = torch.stack(id_cond_list, dim=1) + id_vit_hidden = id_vit_hidden_list[0] + for i in range(1, len(image_list)): + for j, x in enumerate(id_vit_hidden_list[i]): + id_vit_hidden[j] = torch.cat([id_vit_hidden[j], x], dim=1) + id_embedding = self.id_adapter(id_cond, id_vit_hidden) + uncond_id_embedding = self.id_adapter(id_uncond, id_vit_hidden_uncond) + + # return id_embedding + return uncond_id_embedding, id_embedding + + def __call__(self, x, sigma, **extra_args): + x_ddim_space = x / (sigma[:, None, None, None] ** 2 + self.sigma_data**2) ** 0.5 + t = self.timestep(sigma) + cfg_scale = extra_args['cfg_scale'] + eps_positive = self.pipe.unet(x_ddim_space, t, return_dict=False, **extra_args['positive'])[0] + eps_negative = self.pipe.unet(x_ddim_space, t, return_dict=False, **extra_args['negative'])[0] + noise_pred = eps_negative + cfg_scale * (eps_positive - eps_negative) + return x - noise_pred * sigma[:, None, None, None] + + def inference( + self, + prompt, + size, + prompt_n='', + id_embedding=None, + uncond_id_embedding=None, + id_scale=1.0, + guidance_scale=1.2, + steps=4, + seed=-1, + ): + + # sigmas + sigmas = self.get_sigmas_karras(steps).to(self.device) + + # latents + noise = torch.randn((size[0], 4, size[1] // 8, size[2] // 8), device="cpu", generator=torch.manual_seed(seed)) + noise = noise.to(dtype=self.pipe.unet.dtype, device=self.device) + latents = noise * sigmas[0].to(noise) + + ( + prompt_embeds, + negative_prompt_embeds, + pooled_prompt_embeds, + negative_pooled_prompt_embeds, + ) = self.pipe.encode_prompt( + prompt=prompt, + negative_prompt=prompt_n, + ) + + add_time_ids = list((size[1], size[2]) + (0, 0) + (size[1], size[2])) + add_time_ids = torch.tensor([add_time_ids], dtype=self.pipe.unet.dtype, device=self.device) + add_neg_time_ids = add_time_ids.clone() + + sampler_kwargs = dict( + cfg_scale=guidance_scale, + positive=dict( + encoder_hidden_states=prompt_embeds, + added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids}, + cross_attention_kwargs={'id_embedding': id_embedding, 'id_scale': id_scale}, + ), + negative=dict( + encoder_hidden_states=negative_prompt_embeds, + added_cond_kwargs={"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids}, + cross_attention_kwargs={'id_embedding': uncond_id_embedding, 'id_scale': id_scale}, + ), + ) + + latents = self.sampler(self, latents, sigmas, extra_args=sampler_kwargs, disable=False) + latents = latents.to(dtype=self.pipe.vae.dtype, device=self.device) / self.pipe.vae.config.scaling_factor + images = self.pipe.vae.decode(latents).sample + images = self.pipe.image_processor.postprocess(images, output_type='pil') + + return images diff --git a/pulid/utils.py b/pulid/utils.py index 8097672..14b5099 100644 --- a/pulid/utils.py +++ b/pulid/utils.py @@ -7,7 +7,9 @@ import numpy as np import torch import torch.nn.functional as F +import torchsde from torchvision.utils import make_grid +from tqdm.auto import trange from transformers import PretrainedConfig @@ -164,3 
+166,172 @@ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)): if len(result) == 1: result = result[0] return result + + +# We didn't find a correct configuration to make the diffusers scheduler align with dpm++2m (karras) in ComfyUI, +# so we copied the ComfyUI code directly. + + +def append_dims(x, target_dims): + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less') + expanded = x[(...,) + (None,) * dims_to_append] + # MPS will get inf values if it tries to index into the new axes, but detaching fixes this. + # https://github.com/pytorch/pytorch/issues/84364 + return expanded.detach().clone() if expanded.device.type == 'mps' else expanded + + +def to_d(x, sigma, denoised): + """Converts a denoiser output to a Karras ODE derivative.""" + return (x - denoised) / append_dims(sigma, x.ndim) + + +def get_ancestral_step(sigma_from, sigma_to, eta=1.0): + """Calculates the noise level (sigma_down) to step down to and the amount + of noise to add (sigma_up) when doing an ancestral sampling step.""" + if not eta: + return sigma_to, 0.0 + sigma_up = min(sigma_to, eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5) + sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5 + return sigma_down, sigma_up + + +class BatchedBrownianTree: + """A wrapper around torchsde.BrownianTree that enables batches of entropy.""" + + def __init__(self, x, t0, t1, seed=None, **kwargs): + self.cpu_tree = True + if "cpu" in kwargs: + self.cpu_tree = kwargs.pop("cpu") + t0, t1, self.sign = self.sort(t0, t1) + w0 = kwargs.get('w0', torch.zeros_like(x)) + if seed is None: + seed = torch.randint(0, 2**63 - 1, []).item() + self.batched = True + try: + assert len(seed) == x.shape[0] + w0 = w0[0] + except TypeError: + seed = [seed] + self.batched = False + if self.cpu_tree: + self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed] + else: + self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed] + + @staticmethod + def sort(a, b): + return (a, b, 1) if a < b else (b, a, -1) + + def __call__(self, t0, t1): + t0, t1, sign = self.sort(t0, t1) + if self.cpu_tree: + w = torch.stack( + [tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees] + ) * (self.sign * sign) + else: + w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign) + + return w if self.batched else w[0] + + +class BrownianTreeNoiseSampler: + """A noise sampler backed by a torchsde.BrownianTree. + + Args: + x (Tensor): The tensor whose shape, device and dtype to use to generate + random samples. + sigma_min (float): The low end of the valid interval. + sigma_max (float): The high end of the valid interval. + seed (int or List[int]): The random seed. If a list of seeds is + supplied instead of a single integer, then the noise sampler will + use one BrownianTree per batch item, each with its own seed. + transform (callable): A function that maps sigma to the sampler's + internal timestep. 
+ """ + + def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False): + self.transform = transform + t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max)) + self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu) + + def __call__(self, sigma, sigma_next): + t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next)) + return self.tree(t0, t1) / (t1 - t0).abs().sqrt() + + +@torch.no_grad() +def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None): + """DPM-Solver++(2M).""" + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + old_denoised = None + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) + h = t_next - t + if old_denoised is None or sigmas[i + 1] == 0: + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised + else: + h_last = t - t_fn(sigmas[i - 1]) + r = h_last / h + denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised + x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d + old_denoised = denoised + return x + + +@torch.no_grad() +def sample_dpmpp_sde( + model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None, r=1 / 2 +): + """DPM-Solver++ (stochastic).""" + sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max() + seed = extra_args.get("seed", None) + noise_sampler = ( + BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=False) + if noise_sampler is None + else noise_sampler + ) + extra_args = {} if extra_args is None else extra_args + s_in = x.new_ones([x.shape[0]]) + sigma_fn = lambda t: t.neg().exp() + t_fn = lambda sigma: sigma.log().neg() + + for i in trange(len(sigmas) - 1, disable=disable): + denoised = model(x, sigmas[i] * s_in, **extra_args) + if callback is not None: + callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) + if sigmas[i + 1] == 0: + # Euler method + d = to_d(x, sigmas[i], denoised) + dt = sigmas[i + 1] - sigmas[i] + x = x + d * dt + else: + # DPM-Solver++ + t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1]) + h = t_next - t + s = t + h * r + fac = 1 / (2 * r) + + # Step 1 + sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta) + s_ = t_fn(sd) + x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised + x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su + denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args) + + # Step 2 + sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta) + t_next_ = t_fn(sd) + denoised_d = (1 - fac) * denoised + fac * denoised_2 + x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d + x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su + return x diff --git a/requirements.txt b/requirements.txt index 3b5c67e..081da6e 100644 --- a/requirements.txt +++ b/requirements.txt @@ -14,4 +14,5 @@ onnxruntime onnxruntime-gpu accelerate SentencePiece -safetensors \ No newline at end of file +safetensors +torchsde \ No newline at end of file