diff --git a/README.md b/README.md
index 73f2fcc..b877bd4 100644
--- a/README.md
+++ b/README.md
@@ -7,15 +7,16 @@
> ByteDance Inc
### :triangular_flag_on_post: Updates
+* **2024.10.31**: 💥 We are happy to release our latest [models](https://huggingface.co/guozinan/PuLID), **PuLID-v1.1** and **PuLID-FLUX-v0.9.1**!
* **2024.09.26**: 🎉 PuLID accepted by NeurIPS 2024
-* **2024.09.12**: 💥 We're thrilled to announce the release of the **PuLID-FLUX-v0.9.0 model**. Enjoy exploring its capabilities! 😊 [Learn more about this model](docs/pulid_for_flux.md)
+* **2024.09.12**: We're thrilled to announce the release of the **PuLID-FLUX-v0.9.0 model**. Enjoy exploring its capabilities! 😊 [Learn more about this model](docs/pulid_for_flux.md)
* **2024.05.23**: share the [preview of our upcoming v1.1 model](docs/v1.1_preview.md), please stay tuned
* **2024.05.01**: release v1 codes&models, also the [🤗HuggingFace Demo](https://huggingface.co/spaces/yanze/PuLID)
* **2024.04.25**: release arXiv paper.
### :soon: update plan
-- [ ] release PuLID-FLUX-v0.9.1 model in 2024.10
-- [ ] release PuLID v1.1 (for SDXL) model in 2024.10
+- [x] release PuLID-FLUX-v0.9.1 model in 2024.10
+- [x] release PuLID v1.1 (for SDXL) model in 2024.10
## PuLID for FLUX
Please check the doc and demo of PuLID-FLUX [here](docs/pulid_for_flux.md).
@@ -28,6 +29,7 @@ We will actively update and maintain this repository in the near future, so plea
- [x] We have optimized the codes to support consumer-grade GPUS, and now **PuLID-FLUX can run on a 16GB graphic card**. Check the details [here](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md#local-gradio-demo)
- [x] (Community Implementation) Online Replicate demo is ready now [![Replicate](https://replicate.com/zsxkib/flux-pulid/badge)](https://replicate.com/zsxkib/flux-pulid)
- [x] Local gradio demo supports 12GB graphic card now
+- [x] PuLID-FLUX-v0.9.1 is ready now
Below results are generated with PuLID-FLUX.
@@ -62,7 +64,13 @@ pip install -r requirements_fp8.txt
## :zap: Quick Inference
### Local Gradio Demo
```bash
+# for v1 version
python app.py
+
+# for v1.1 version
+python app_v1_1.py --base BASE_MODEL
+# --base: RunDiffusion/Juggernaut-XL-v9 (default) or Lykon/dreamshaper-xl-lightning
```
### Online HuggingFace Demo
diff --git a/app_v1_1.py b/app_v1_1.py
new file mode 100644
index 0000000..05815b7
--- /dev/null
+++ b/app_v1_1.py
@@ -0,0 +1,252 @@
+import argparse
+
+import gradio as gr
+import numpy as np
+import torch
+
+from pulid import attention_processor as attention
+from pulid.pipeline_v1_1 import PuLIDPipeline
+from pulid.utils import resize_numpy_image_long
+
+torch.set_grad_enabled(False)
+
+parser = argparse.ArgumentParser()
+parser.add_argument(
+ '--base',
+ type=str,
+ default='RunDiffusion/Juggernaut-XL-v9',
+ choices=[
+ 'Lykon/dreamshaper-xl-lightning',
+ # 'SG161222/RealVisXL_V4.0', will add it later
+ 'RunDiffusion/Juggernaut-XL-v9',
+ ],
+)
+# parser.add_argument('--sampler', type=str, default='dpmpp_2m', choices=['dpmpp_sde', 'dpmpp_2m'])
+parser.add_argument('--port', type=int, default=7860)
+args = parser.parse_args()
+
+use_lightning_model = 'lightning' in args.base.lower()
+# currently we only support two commonly used samplers
+args.sampler = 'dpmpp_sde' if use_lightning_model else 'dpmpp_2m'
+if use_lightning_model:
+ default_cfg = 2.0
+ default_steps = 5
+else:
+ default_cfg = 7.0
+ default_steps = 25
+
+pipeline = PuLIDPipeline(sdxl_repo=args.base, sampler=args.sampler)
+
+# other params
+DEFAULT_NEGATIVE_PROMPT = (
+ 'flaws in the eyes, flaws in the face, flaws, lowres, non-HDRi, low quality, worst quality,'
+ 'artifacts noise, text, watermark, glitch, deformed, mutated, ugly, disfigured, hands, '
+ 'low resolution, partially rendered objects, deformed or partially rendered eyes, '
+ 'deformed, deformed eyeballs, cross-eyed,blurry'
+)
+
+dreamshaper_example_inps = [
+ ['portrait, blacklight', 'example_inputs/liuyifei.png', 42, 0.8, 10],
+ ['pixel art, 1boy', 'example_inputs/lecun.jpg', 42, 0.8, 10],
+ [
+ 'cinematic film still, close up, photo of redheaded girl near grasses, fictional landscapes, (intense sunlight:1.4), realist detail, brooding mood, ue5, detailed character expressions, light amber and red, amazing quality, wallpaper, analog film grain',
+ 'example_inputs/liuyifei.png',
+ 42,
+ 0.8,
+ 10,
+ ],
+ [
+ 'A minimalist line art depiction of an Artificial Intelligence being\'s thought process, lines and nodes forming intricate patterns.',
+ 'example_inputs/hinton.jpeg',
+ 42,
+ 0.8,
+ 10,
+ ],
+ [
+ 'instagram photo, photo of 23 y.o man in black sweater, pale skin, (smile:0.4), hard shadows',
+ 'example_inputs/pengwei.jpg',
+ 42,
+ 0.8,
+ 10,
+ ],
+ [
+ 'by Tsutomu Nihei,(strange but extremely beautiful:1.4),(masterpiece, best quality:1.4),in the style of nicola samori,The Joker,',
+ 'example_inputs/lecun.jpg',
+ 1675432759740519133,
+ 0.8,
+ 10,
+ ],
+]
+
+jugger_example_inps = [
+ [
+ 'robot,simple robot,robot with glass face,ellipse head robot,(made partially out of glass),hexagonal shapes,ferns growing inside head,butterflies on head,butterflies flying around',
+ 'example_inputs/hinton.jpeg',
+ 15022214902832471291,
+ 0.8,
+ 20,
+ ],
+ ['sticker art, 1girl', 'example_inputs/liuyifei.png', 42, 0.8, 20],
+ [
+ '1girl, cute model, Long thick Maxi Skirt, Knit sweater, swept back hair, alluring smile, working at a clothing store, perfect eyes, highly detailed beautiful expressive eyes, detailed eyes, 35mm photograph, film, bokeh, professional, 4k, highly detailed dynamic lighting, photorealistic, 8k, raw, rich, intricate details,',
+ 'example_inputs/liuyifei.png',
+ 42,
+ 0.8,
+ 20,
+ ],
+ ['Chinese paper-cut, 1girl', 'example_inputs/liuyifei.png', 42, 0.8, 20],
+ ['Studio Ghibli, 1boy', 'example_inputs/hinton.jpeg', 42, 0.8, 20],
+ ['1man made of ice sculpture', 'example_inputs/lecun.jpg', 42, 0.8, 20],
+ ['portrait of green-skinned shrek, wearing lacoste purple sweater', 'example_inputs/lecun.jpg', 42, 0.8, 20],
+ ['1990s Japanese anime, 1girl', 'example_inputs/liuyifei.png', 42, 0.8, 20],
+ ['made of little stones, portrait', 'example_inputs/hinton.jpeg', 42, 0.8, 20],
+]
+
+
+@torch.inference_mode()
+def run(*args):
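+    # positional args mirror the `inps` list assembled in the Gradio UI below:
+    # main ID image, three auxiliary ID images, then the text and sampling controls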
+ id_image = args[0]
+ supp_images = args[1:4]
+ prompt, neg_prompt, scale, seed, steps, H, W, id_scale, num_zero, ortho = args[4:]
+ seed = int(seed)
+ if seed == -1:
+ seed = torch.Generator(device="cpu").seed()
+
+ pipeline.debug_img_list = []
+
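+    # NUM_ZERO and the ORTHO flags are module-level switches in pulid.attention_processor,
+    # read by the ID attention processors during denoising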
+ attention.NUM_ZERO = num_zero
+ if ortho == 'v2':
+ attention.ORTHO = False
+ attention.ORTHO_v2 = True
+ elif ortho == 'v1':
+ attention.ORTHO = True
+ attention.ORTHO_v2 = False
+ else:
+ attention.ORTHO = False
+ attention.ORTHO_v2 = False
+
+ if id_image is not None:
+ id_image = resize_numpy_image_long(id_image, 1024)
+ supp_id_image_list = [
+ resize_numpy_image_long(supp_id_image, 1024) for supp_id_image in supp_images if supp_id_image is not None
+ ]
+ id_image_list = [id_image] + supp_id_image_list
+ uncond_id_embedding, id_embedding = pipeline.get_id_embedding(id_image_list)
+ else:
+ uncond_id_embedding = None
+ id_embedding = None
+
+ img = pipeline.inference(
+ prompt, (1, H, W), neg_prompt, id_embedding, uncond_id_embedding, id_scale, scale, steps, seed
+ )[0]
+
+ return np.array(img), str(seed), pipeline.debug_img_list
+
+
+_HEADER_ = '''
+
+## PuLID v1.1 Official Gradio Demo
+
+**PuLID** is a tuning-free ID customization approach. PuLID maintains high ID fidelity while effectively reducing interference with the original model’s behavior.
+
+Code: [GitHub](https://github.com/ToTheBeginning/PuLID). Paper: arXiv.
+
+❗️❗️❗️**Tips:**
+- We provide some examples at the bottom of the page; you can try these example prompts first.
+- A single ID image is usually sufficient, but you can also supplement it with additional auxiliary images.
+- You can adjust the trade-off between ID fidelity and editability in the advanced options, but the default settings are generally good enough.
+
+''' # noqa E501
+
+_CITE_ = r"""
+If PuLID is helpful, please help to ⭐ the Github Repo. Thanks! [![GitHub Stars](https://img.shields.io/github/stars/ToTheBeginning/PuLID?style=social)](https://github.com/ToTheBeginning/PuLID)
+---
+📧 **Contact**
+If you have any questions, feel free to open a discussion or contact us at wuyanze123@gmail.com or guozinan.1@bytedance.com.
+""" # noqa E501
+
+
+with gr.Blocks(title="PuLID", css=".gr-box {border-color: #8136e2}") as demo:
+ gr.Markdown(_HEADER_)
+ with gr.Row():
+ with gr.Column():
+ with gr.Row():
+ face_image = gr.Image(label="ID image (main)", height=256)
+ supp_image1 = gr.Image(label="Additional ID image (auxiliary)", height=256)
+ supp_image2 = gr.Image(label="Additional ID image (auxiliary)", height=256)
+ supp_image3 = gr.Image(label="Additional ID image (auxiliary)", height=256)
+ prompt = gr.Textbox(label="Prompt", value='portrait,color,cinematic,in garden,soft light,detailed face')
+ submit = gr.Button("Generate")
+ neg_prompt = gr.Textbox(label="Negative Prompt", value=DEFAULT_NEGATIVE_PROMPT)
+ scale = gr.Slider(
+ label="CFG (recommend 2 for lightning model and 7 for non-accelerated model)",
+ value=default_cfg,
+ minimum=1,
+ maximum=10,
+ step=0.1,
+ )
+ seed = gr.Textbox(-1, label="Seed (-1 for random)")
+ steps = gr.Slider(label="Steps", value=default_steps, minimum=1, maximum=30, step=1)
+ with gr.Row():
+ H = gr.Slider(label="Height", value=1152, minimum=512, maximum=2024, step=64)
+ W = gr.Slider(label="Width", value=896, minimum=512, maximum=2024, step=64)
+ with gr.Row(), gr.Accordion(
+ "Advanced Options (adjust the trade-off between ID fidelity and editability)", open=False
+ ):
+ id_scale = gr.Slider(
+ label="ID scale (Increasing it enhances ID similarity but reduces editability)",
+ minimum=0,
+ maximum=5,
+ step=0.05,
+ value=0.8,
+ interactive=True,
+ )
+ num_zero = gr.Slider(
+ label="num zero (Increasing it enhances ID editability but reduces similarity)",
+ minimum=0,
+ maximum=80,
+ step=1,
+ value=20,
+ interactive=True,
+ )
+ ortho = gr.Dropdown(label="ortho", choices=['off', 'v1', 'v2'], value='v2', visible=False)
+
+ with gr.Column():
+ output = gr.Image(label="Generated Image")
+ seed_output = gr.Textbox(label="Used Seed")
+ intermediate_output = gr.Gallery(label='DebugImage', elem_id="gallery", visible=False)
+ gr.Markdown(_CITE_)
+
+ with gr.Row(), gr.Column():
+ gr.Markdown("## Examples")
+ if args.base == 'Lykon/dreamshaper-xl-lightning':
+ gr.Examples(
+ examples=dreamshaper_example_inps,
+ inputs=[prompt, face_image, seed, id_scale, num_zero],
+ label='dreamshaper-xl-lightning examples',
+ )
+ elif args.base == 'RunDiffusion/Juggernaut-XL-v9':
+ gr.Examples(
+ examples=jugger_example_inps,
+ inputs=[prompt, face_image, seed, id_scale, num_zero],
+ label='Juggernaut-XL-v9 examples',
+ )
+
+ inps = [
+ face_image,
+ supp_image1,
+ supp_image2,
+ supp_image3,
+ prompt,
+ neg_prompt,
+ scale,
+ seed,
+ steps,
+ H,
+ W,
+ id_scale,
+ num_zero,
+ ortho,
+ ]
+ submit.click(fn=run, inputs=inps, outputs=[output, seed_output, intermediate_output])
+
+demo.launch(server_name='0.0.0.0', server_port=args.port)
diff --git a/docs/pulid_v1.1.md b/docs/pulid_v1.1.md
new file mode 100644
index 0000000..5b07963
--- /dev/null
+++ b/docs/pulid_v1.1.md
@@ -0,0 +1,15 @@
+# PuLID v1.1
+The following are some examples generated by PuLID-v1.1; you can reproduce these results with our Gradio demo.
+![release_v1 1](https://github.com/user-attachments/assets/d5bf3865-5147-428d-bb98-80812c680900)
+
+## How to use
+### Online demo
+We plan to upgrade the PuLID demo to v1.1 soon; please stay tuned.
+
+### Local demo
+```bash
+python app_v1_1.py --base BASE_MODEL
+# --base: RunDiffusion/Juggernaut-XL-v9 (default) or Lykon/dreamshaper-xl-lightning
+```
+
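+### Programmatic use (sketch)
+If you prefer to call the v1.1 pipeline from Python instead of the Gradio demo, the minimal sketch below shows one possible way to do it, based on the `PuLIDPipeline` interface in `pulid/pipeline_v1_1.py`. It assumes a CUDA GPU, uses `example_inputs/liuyifei.png` from this repo as the ID image, and mirrors the demo's default attention settings; the prompt, resolution, and sampling values are illustrative only.
+
+```python
+import cv2
+import torch
+
+from pulid import attention_processor as attention
+from pulid.pipeline_v1_1 import PuLIDPipeline
+from pulid.utils import resize_numpy_image_long
+
+torch.set_grad_enabled(False)
+
+# mirror the Gradio demo defaults (num_zero=20, ortho v2) for the Juggernaut base model
+attention.NUM_ZERO = 20
+attention.ORTHO_v2 = True
+
+pipeline = PuLIDPipeline(sdxl_repo='RunDiffusion/Juggernaut-XL-v9', sampler='dpmpp_2m')
+
+# get_id_embedding expects RGB numpy images in [0, 255]
+id_image = cv2.cvtColor(cv2.imread('example_inputs/liuyifei.png'), cv2.COLOR_BGR2RGB)
+id_image = resize_numpy_image_long(id_image, 1024)
+uncond_id_embedding, id_embedding = pipeline.get_id_embedding([id_image])
+
+images = pipeline.inference(
+    'portrait, color, cinematic, in garden, soft light, detailed face',  # prompt
+    (1, 1152, 896),              # (batch, height, width)
+    'lowres, blurry, deformed',  # negative prompt (illustrative)
+    id_embedding,
+    uncond_id_embedding,
+    0.8,                         # ID scale
+    7.0,                         # CFG scale
+    25,                          # steps
+    42,                          # seed
+)
+images[0].save('pulid_v1.1_output.png')
+```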
diff --git a/pulid/encoders_flux.py b/pulid/encoders_transformer.py
similarity index 98%
rename from pulid/encoders_flux.py
rename to pulid/encoders_transformer.py
index 7891fb3..afbdaf3 100644
--- a/pulid/encoders_flux.py
+++ b/pulid/encoders_transformer.py
@@ -190,8 +190,10 @@ def forward(self, x, y):
latents = self.latents.repeat(x.size(0), 1, 1)
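+    # x may carry several stacked ID images from get_id_embedding, shaped (batch, num_images, id_dim);
+    # widen the id-token axis accordingly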
+ num_duotu = x.shape[1] if x.ndim == 3 else 1
+
x = self.id_embedding_mapping(x)
- x = x.reshape(-1, self.num_id_token, self.dim)
+ x = x.reshape(-1, self.num_id_token * num_duotu, self.dim)
latents = torch.cat((latents, x), dim=1)
diff --git a/pulid/pipeline_flux.py b/pulid/pipeline_flux.py
index 019bcef..706514c 100644
--- a/pulid/pipeline_flux.py
+++ b/pulid/pipeline_flux.py
@@ -14,7 +14,7 @@
from eva_clip import create_model_and_transforms
from eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
-from pulid.encoders_flux import IDFormer, PerceiverAttentionCA
+from pulid.encoders_transformer import IDFormer, PerceiverAttentionCA
from pulid.utils import img2tensor, tensor2img
diff --git a/pulid/pipeline_v1_1.py b/pulid/pipeline_v1_1.py
new file mode 100644
index 0000000..4ca68cf
--- /dev/null
+++ b/pulid/pipeline_v1_1.py
@@ -0,0 +1,324 @@
+import gc
+
+import cv2
+import insightface
+import numpy as np
+import torch
+import torch.nn as nn
+from basicsr.utils import img2tensor, tensor2img
+from diffusers import DPMSolverMultistepScheduler, StableDiffusionXLPipeline
+from facexlib.parsing import init_parsing_model
+from facexlib.utils.face_restoration_helper import FaceRestoreHelper
+
+from huggingface_hub import hf_hub_download, snapshot_download
+from insightface.app import FaceAnalysis
+from safetensors.torch import load_file
+from torchvision.transforms import InterpolationMode
+from torchvision.transforms.functional import normalize, resize
+
+from eva_clip import create_model_and_transforms
+from eva_clip.constants import OPENAI_DATASET_MEAN, OPENAI_DATASET_STD
+from pulid.encoders_transformer import IDFormer
+from pulid.utils import is_torch2_available, sample_dpmpp_2m, sample_dpmpp_sde
+
+if is_torch2_available():
+ from pulid.attention_processor import AttnProcessor2_0 as AttnProcessor
+ from pulid.attention_processor import IDAttnProcessor2_0 as IDAttnProcessor
+else:
+ from pulid.attention_processor import AttnProcessor, IDAttnProcessor
+
+
+class PuLIDPipeline:
+ def __init__(self, sdxl_repo='Lykon/dreamshaper-xl-lightning', sampler='dpmpp_sde', *args, **kwargs):
+ super().__init__()
+ self.device = 'cuda'
+
+ # load base model
+ self.pipe = StableDiffusionXLPipeline.from_pretrained(sdxl_repo, torch_dtype=torch.float16, variant="fp16").to(
+ self.device
+ )
+ self.pipe.watermark = None
+ self.hack_unet_attn_layers(self.pipe.unet)
+
+ # scheduler
+ self.pipe.scheduler = DPMSolverMultistepScheduler.from_config(self.pipe.scheduler.config)
+
+ # ID adapters
+ self.id_adapter = IDFormer().to(self.device)
+
+ # preprocessors
+ # face align and parsing
+ self.face_helper = FaceRestoreHelper(
+ upscale_factor=1,
+ face_size=512,
+ crop_ratio=(1, 1),
+ det_model='retinaface_resnet50',
+ save_ext='png',
+ device=self.device,
+ )
+ self.face_helper.face_parse = None
+ self.face_helper.face_parse = init_parsing_model(model_name='bisenet', device=self.device)
+ # clip-vit backbone
+ model, _, _ = create_model_and_transforms('EVA02-CLIP-L-14-336', 'eva_clip', force_custom_clip=True)
+ model = model.visual
+ self.clip_vision_model = model.to(self.device)
+ eva_transform_mean = getattr(self.clip_vision_model, 'image_mean', OPENAI_DATASET_MEAN)
+ eva_transform_std = getattr(self.clip_vision_model, 'image_std', OPENAI_DATASET_STD)
+ if not isinstance(eva_transform_mean, (list, tuple)):
+ eva_transform_mean = (eva_transform_mean,) * 3
+ if not isinstance(eva_transform_std, (list, tuple)):
+ eva_transform_std = (eva_transform_std,) * 3
+ self.eva_transform_mean = eva_transform_mean
+ self.eva_transform_std = eva_transform_std
+ # antelopev2
+ snapshot_download('DIAMONIK7777/antelopev2', local_dir='models/antelopev2')
+ self.app = FaceAnalysis(
+ name='antelopev2', root='.', providers=['CUDAExecutionProvider', 'CPUExecutionProvider']
+ )
+ self.app.prepare(ctx_id=0, det_size=(640, 640))
+ self.handler_ante = insightface.model_zoo.get_model('models/antelopev2/glintr100.onnx')
+ self.handler_ante.prepare(ctx_id=0)
+
+ gc.collect()
+ torch.cuda.empty_cache()
+
+ self.load_pretrain()
+
+ # other configs
+ self.debug_img_list = []
+
+ # karras schedule related code, borrow from lllyasviel/Omost
+ linear_start = 0.00085
+ linear_end = 0.012
+ timesteps = 1000
+ betas = torch.linspace(linear_start**0.5, linear_end**0.5, timesteps, dtype=torch.float64) ** 2
+ alphas = 1.0 - betas
+ alphas_cumprod = torch.tensor(np.cumprod(alphas, axis=0), dtype=torch.float32)
+
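+        # convert the DDPM alphas_cumprod to k-diffusion sigmas: sigma = sqrt((1 - a_bar) / a_bar)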
+ self.sigmas = ((1 - alphas_cumprod) / alphas_cumprod) ** 0.5
+ self.log_sigmas = self.sigmas.log()
+ self.sigma_data = 1.0
+
+ if sampler == 'dpmpp_sde':
+ self.sampler = sample_dpmpp_sde
+ elif sampler == 'dpmpp_2m':
+ self.sampler = sample_dpmpp_2m
+ else:
+ raise NotImplementedError(f'sampler {sampler} not implemented')
+
+ @property
+ def sigma_min(self):
+ return self.sigmas[0]
+
+ @property
+ def sigma_max(self):
+ return self.sigmas[-1]
+
+ def timestep(self, sigma):
+ log_sigma = sigma.log()
+ dists = log_sigma.to(self.log_sigmas.device) - self.log_sigmas[:, None]
+ return dists.abs().argmin(dim=0).view(sigma.shape).to(sigma.device)
+
+ def get_sigmas_karras(self, n, rho=7.0):
+ ramp = torch.linspace(0, 1, n)
+ min_inv_rho = self.sigma_min ** (1 / rho)
+ max_inv_rho = self.sigma_max ** (1 / rho)
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
+ return torch.cat([sigmas, sigmas.new_zeros([1])])
+
+ def hack_unet_attn_layers(self, unet):
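+        # replace the UNet attention processors: cross-attention layers get an IDAttnProcessor so the
+        # id_embedding passed via cross_attention_kwargs can be injected; self-attention keeps the plain AttnProcessor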
+ id_adapter_attn_procs = {}
+ for name, _ in unet.attn_processors.items():
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
+ if name.startswith("mid_block"):
+ hidden_size = unet.config.block_out_channels[-1]
+ elif name.startswith("up_blocks"):
+ block_id = int(name[len("up_blocks.")])
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
+ elif name.startswith("down_blocks"):
+ block_id = int(name[len("down_blocks.")])
+ hidden_size = unet.config.block_out_channels[block_id]
+ if cross_attention_dim is not None:
+ id_adapter_attn_procs[name] = IDAttnProcessor(
+ hidden_size=hidden_size,
+ cross_attention_dim=cross_attention_dim,
+ ).to(unet.device)
+ else:
+ id_adapter_attn_procs[name] = AttnProcessor()
+ unet.set_attn_processor(id_adapter_attn_procs)
+ self.id_adapter_attn_layers = nn.ModuleList(unet.attn_processors.values())
+
+ def load_pretrain(self):
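+        # download the v1.1 checkpoint, split its flat state dict by top-level prefix,
+        # and load each group into the attribute of the same name (e.g. id_adapter)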
+ hf_hub_download('guozinan/PuLID', 'pulid_v1.1.safetensors', local_dir='models')
+ ckpt_path = 'models/pulid_v1.1.safetensors'
+ state_dict = load_file(ckpt_path)
+ state_dict_dict = {}
+ for k, v in state_dict.items():
+ module = k.split('.')[0]
+ state_dict_dict.setdefault(module, {})
+ new_k = k[len(module) + 1 :]
+ state_dict_dict[module][new_k] = v
+
+ for module in state_dict_dict:
+ print(f'loading from {module}')
+ getattr(self, module).load_state_dict(state_dict_dict[module], strict=True)
+
+ def to_gray(self, img):
+ x = 0.299 * img[:, 0:1] + 0.587 * img[:, 1:2] + 0.114 * img[:, 2:3]
+ x = x.repeat(1, 3, 1, 1)
+ return x
+
+ def get_id_embedding(self, image_list):
+ """
+ Args:
+            image_list: list of numpy RGB images, each in range [0, 255]
+ """
+ id_cond_list = []
+ id_vit_hidden_list = []
+ for ii, image in enumerate(image_list):
+ self.face_helper.clean_all()
+ image_bgr = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
+ # get antelopev2 embedding
+ face_info = self.app.get(image_bgr)
+ if len(face_info) > 0:
+ face_info = sorted(
+ face_info, key=lambda x: (x['bbox'][2] - x['bbox'][0]) * (x['bbox'][3] - x['bbox'][1])
+ )[
+ -1
+ ] # only use the maximum face
+ id_ante_embedding = face_info['embedding']
+ self.debug_img_list.append(
+ image[
+ int(face_info['bbox'][1]) : int(face_info['bbox'][3]),
+ int(face_info['bbox'][0]) : int(face_info['bbox'][2]),
+ ]
+ )
+ else:
+ id_ante_embedding = None
+
+ # using facexlib to detect and align face
+ self.face_helper.read_image(image_bgr)
+ self.face_helper.get_face_landmarks_5(only_center_face=True)
+ self.face_helper.align_warp_face()
+ if len(self.face_helper.cropped_faces) == 0:
+ raise RuntimeError('facexlib align face fail')
+ align_face = self.face_helper.cropped_faces[0]
+ # incase insightface didn't detect face
+ if id_ante_embedding is None:
+ print('fail to detect face using insightface, extract embedding on align face')
+ id_ante_embedding = self.handler_ante.get_feat(align_face)
+
+ id_ante_embedding = torch.from_numpy(id_ante_embedding).to(self.device)
+ if id_ante_embedding.ndim == 1:
+ id_ante_embedding = id_ante_embedding.unsqueeze(0)
+
+ # parsing
+ input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0
+ input = input.to(self.device)
+ parsing_out = self.face_helper.face_parse(normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))[
+ 0
+ ]
+ parsing_out = parsing_out.argmax(dim=1, keepdim=True)
+ bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
+ bg = sum(parsing_out == i for i in bg_label).bool()
+ white_image = torch.ones_like(input)
+ # only keep the face features
+ face_features_image = torch.where(bg, white_image, self.to_gray(input))
+ self.debug_img_list.append(tensor2img(face_features_image, rgb2bgr=False))
+
+ # transform img before sending to eva-clip-vit
+ face_features_image = resize(
+ face_features_image, self.clip_vision_model.image_size, InterpolationMode.BICUBIC
+ )
+ face_features_image = normalize(face_features_image, self.eva_transform_mean, self.eva_transform_std)
+ id_cond_vit, id_vit_hidden = self.clip_vision_model(
+ face_features_image, return_all_features=False, return_hidden=True, shuffle=False
+ )
+ id_cond_vit_norm = torch.norm(id_cond_vit, 2, 1, True)
+ id_cond_vit = torch.div(id_cond_vit, id_cond_vit_norm)
+
+ id_cond = torch.cat([id_ante_embedding, id_cond_vit], dim=-1)
+
+ id_cond_list.append(id_cond)
+ id_vit_hidden_list.append(id_vit_hidden)
+
+ id_uncond = torch.zeros_like(id_cond_list[0])
+ id_vit_hidden_uncond = []
+ for layer_idx in range(0, len(id_vit_hidden_list[0])):
+ id_vit_hidden_uncond.append(torch.zeros_like(id_vit_hidden_list[0][layer_idx]))
+
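+        # fuse multiple ID images: stack per-image ID conditions along a new axis and
+        # concatenate the ViT hidden states of additional images along dim 1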
+ id_cond = torch.stack(id_cond_list, dim=1)
+ id_vit_hidden = id_vit_hidden_list[0]
+ for i in range(1, len(image_list)):
+ for j, x in enumerate(id_vit_hidden_list[i]):
+ id_vit_hidden[j] = torch.cat([id_vit_hidden[j], x], dim=1)
+ id_embedding = self.id_adapter(id_cond, id_vit_hidden)
+ uncond_id_embedding = self.id_adapter(id_uncond, id_vit_hidden_uncond)
+
+ # return id_embedding
+ return uncond_id_embedding, id_embedding
+
+ def __call__(self, x, sigma, **extra_args):
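+        # k-diffusion style denoiser wrapper used by the samplers in pulid.utils:
+        # scale the latent, run conditional/unconditional UNet passes, apply CFG, and return the denoised sample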
+ x_ddim_space = x / (sigma[:, None, None, None] ** 2 + self.sigma_data**2) ** 0.5
+ t = self.timestep(sigma)
+ cfg_scale = extra_args['cfg_scale']
+ eps_positive = self.pipe.unet(x_ddim_space, t, return_dict=False, **extra_args['positive'])[0]
+ eps_negative = self.pipe.unet(x_ddim_space, t, return_dict=False, **extra_args['negative'])[0]
+ noise_pred = eps_negative + cfg_scale * (eps_positive - eps_negative)
+ return x - noise_pred * sigma[:, None, None, None]
+
+ def inference(
+ self,
+ prompt,
+ size,
+ prompt_n='',
+ id_embedding=None,
+ uncond_id_embedding=None,
+ id_scale=1.0,
+ guidance_scale=1.2,
+ steps=4,
+ seed=-1,
+ ):
+
+ # sigmas
+ sigmas = self.get_sigmas_karras(steps).to(self.device)
+
+ # latents
+ noise = torch.randn((size[0], 4, size[1] // 8, size[2] // 8), device="cpu", generator=torch.manual_seed(seed))
+ noise = noise.to(dtype=self.pipe.unet.dtype, device=self.device)
+ latents = noise * sigmas[0].to(noise)
+
+ (
+ prompt_embeds,
+ negative_prompt_embeds,
+ pooled_prompt_embeds,
+ negative_pooled_prompt_embeds,
+ ) = self.pipe.encode_prompt(
+ prompt=prompt,
+ negative_prompt=prompt_n,
+ )
+
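+        # SDXL micro-conditioning: (original H, original W, crop top, crop left, target H, target W)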
+ add_time_ids = list((size[1], size[2]) + (0, 0) + (size[1], size[2]))
+ add_time_ids = torch.tensor([add_time_ids], dtype=self.pipe.unet.dtype, device=self.device)
+ add_neg_time_ids = add_time_ids.clone()
+
+ sampler_kwargs = dict(
+ cfg_scale=guidance_scale,
+ positive=dict(
+ encoder_hidden_states=prompt_embeds,
+ added_cond_kwargs={"text_embeds": pooled_prompt_embeds, "time_ids": add_time_ids},
+ cross_attention_kwargs={'id_embedding': id_embedding, 'id_scale': id_scale},
+ ),
+ negative=dict(
+ encoder_hidden_states=negative_prompt_embeds,
+ added_cond_kwargs={"text_embeds": negative_pooled_prompt_embeds, "time_ids": add_neg_time_ids},
+ cross_attention_kwargs={'id_embedding': uncond_id_embedding, 'id_scale': id_scale},
+ ),
+ )
+
+ latents = self.sampler(self, latents, sigmas, extra_args=sampler_kwargs, disable=False)
+ latents = latents.to(dtype=self.pipe.vae.dtype, device=self.device) / self.pipe.vae.config.scaling_factor
+ images = self.pipe.vae.decode(latents).sample
+ images = self.pipe.image_processor.postprocess(images, output_type='pil')
+
+ return images
diff --git a/pulid/utils.py b/pulid/utils.py
index 8097672..14b5099 100644
--- a/pulid/utils.py
+++ b/pulid/utils.py
@@ -7,7 +7,9 @@
import numpy as np
import torch
import torch.nn.functional as F
+import torchsde
from torchvision.utils import make_grid
+from tqdm.auto import trange
from transformers import PretrainedConfig
@@ -164,3 +166,172 @@ def tensor2img(tensor, rgb2bgr=True, out_type=np.uint8, min_max=(0, 1)):
if len(result) == 1:
result = result[0]
return result
+
+
+# We didn't find a correct configuration to make the diffusers scheduler align with dpm++2m (karras) in ComfyUI,
+# so we copied the ComfyUI code directly.
+
+
+def append_dims(x, target_dims):
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
+ dims_to_append = target_dims - x.ndim
+ if dims_to_append < 0:
+ raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
+ expanded = x[(...,) + (None,) * dims_to_append]
+ # MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
+ # https://github.com/pytorch/pytorch/issues/84364
+ return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
+
+
+def to_d(x, sigma, denoised):
+ """Converts a denoiser output to a Karras ODE derivative."""
+ return (x - denoised) / append_dims(sigma, x.ndim)
+
+
+def get_ancestral_step(sigma_from, sigma_to, eta=1.0):
+ """Calculates the noise level (sigma_down) to step down to and the amount
+ of noise to add (sigma_up) when doing an ancestral sampling step."""
+ if not eta:
+ return sigma_to, 0.0
+ sigma_up = min(sigma_to, eta * (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5)
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
+ return sigma_down, sigma_up
+
+
+class BatchedBrownianTree:
+ """A wrapper around torchsde.BrownianTree that enables batches of entropy."""
+
+ def __init__(self, x, t0, t1, seed=None, **kwargs):
+ self.cpu_tree = True
+ if "cpu" in kwargs:
+ self.cpu_tree = kwargs.pop("cpu")
+ t0, t1, self.sign = self.sort(t0, t1)
+ w0 = kwargs.get('w0', torch.zeros_like(x))
+ if seed is None:
+ seed = torch.randint(0, 2**63 - 1, []).item()
+ self.batched = True
+ try:
+ assert len(seed) == x.shape[0]
+ w0 = w0[0]
+ except TypeError:
+ seed = [seed]
+ self.batched = False
+ if self.cpu_tree:
+ self.trees = [torchsde.BrownianTree(t0.cpu(), w0.cpu(), t1.cpu(), entropy=s, **kwargs) for s in seed]
+ else:
+ self.trees = [torchsde.BrownianTree(t0, w0, t1, entropy=s, **kwargs) for s in seed]
+
+ @staticmethod
+ def sort(a, b):
+ return (a, b, 1) if a < b else (b, a, -1)
+
+ def __call__(self, t0, t1):
+ t0, t1, sign = self.sort(t0, t1)
+ if self.cpu_tree:
+ w = torch.stack(
+ [tree(t0.cpu().float(), t1.cpu().float()).to(t0.dtype).to(t0.device) for tree in self.trees]
+ ) * (self.sign * sign)
+ else:
+ w = torch.stack([tree(t0, t1) for tree in self.trees]) * (self.sign * sign)
+
+ return w if self.batched else w[0]
+
+
+class BrownianTreeNoiseSampler:
+ """A noise sampler backed by a torchsde.BrownianTree.
+
+ Args:
+ x (Tensor): The tensor whose shape, device and dtype to use to generate
+ random samples.
+ sigma_min (float): The low end of the valid interval.
+ sigma_max (float): The high end of the valid interval.
+ seed (int or List[int]): The random seed. If a list of seeds is
+ supplied instead of a single integer, then the noise sampler will
+ use one BrownianTree per batch item, each with its own seed.
+ transform (callable): A function that maps sigma to the sampler's
+ internal timestep.
+ """
+
+ def __init__(self, x, sigma_min, sigma_max, seed=None, transform=lambda x: x, cpu=False):
+ self.transform = transform
+ t0, t1 = self.transform(torch.as_tensor(sigma_min)), self.transform(torch.as_tensor(sigma_max))
+ self.tree = BatchedBrownianTree(x, t0, t1, seed, cpu=cpu)
+
+ def __call__(self, sigma, sigma_next):
+ t0, t1 = self.transform(torch.as_tensor(sigma)), self.transform(torch.as_tensor(sigma_next))
+ return self.tree(t0, t1) / (t1 - t0).abs().sqrt()
+
+
+@torch.no_grad()
+def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
+ """DPM-Solver++(2M)."""
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ sigma_fn = lambda t: t.neg().exp()
+ t_fn = lambda sigma: sigma.log().neg()
+ old_denoised = None
+
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
+ h = t_next - t
+ if old_denoised is None or sigmas[i + 1] == 0:
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
+ else:
+ h_last = t - t_fn(sigmas[i - 1])
+ r = h_last / h
+ denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
+ x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
+ old_denoised = denoised
+ return x
+
+
+@torch.no_grad()
+def sample_dpmpp_sde(
+ model, x, sigmas, extra_args=None, callback=None, disable=None, eta=1.0, s_noise=1.0, noise_sampler=None, r=1 / 2
+):
+ """DPM-Solver++ (stochastic)."""
+ sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
+ seed = extra_args.get("seed", None)
+ noise_sampler = (
+ BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seed, cpu=False)
+ if noise_sampler is None
+ else noise_sampler
+ )
+ extra_args = {} if extra_args is None else extra_args
+ s_in = x.new_ones([x.shape[0]])
+ sigma_fn = lambda t: t.neg().exp()
+ t_fn = lambda sigma: sigma.log().neg()
+
+ for i in trange(len(sigmas) - 1, disable=disable):
+ denoised = model(x, sigmas[i] * s_in, **extra_args)
+ if callback is not None:
+ callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
+ if sigmas[i + 1] == 0:
+ # Euler method
+ d = to_d(x, sigmas[i], denoised)
+ dt = sigmas[i + 1] - sigmas[i]
+ x = x + d * dt
+ else:
+ # DPM-Solver++
+ t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
+ h = t_next - t
+ s = t + h * r
+ fac = 1 / (2 * r)
+
+ # Step 1
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(s), eta)
+ s_ = t_fn(sd)
+ x_2 = (sigma_fn(s_) / sigma_fn(t)) * x - (t - s_).expm1() * denoised
+ x_2 = x_2 + noise_sampler(sigma_fn(t), sigma_fn(s)) * s_noise * su
+ denoised_2 = model(x_2, sigma_fn(s) * s_in, **extra_args)
+
+ # Step 2
+ sd, su = get_ancestral_step(sigma_fn(t), sigma_fn(t_next), eta)
+ t_next_ = t_fn(sd)
+ denoised_d = (1 - fac) * denoised + fac * denoised_2
+ x = (sigma_fn(t_next_) / sigma_fn(t)) * x - (t - t_next_).expm1() * denoised_d
+ x = x + noise_sampler(sigma_fn(t), sigma_fn(t_next)) * s_noise * su
+ return x
diff --git a/requirements.txt b/requirements.txt
index 3b5c67e..081da6e 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -14,4 +14,5 @@ onnxruntime
onnxruntime-gpu
accelerate
SentencePiece
-safetensors
\ No newline at end of file
+safetensors
+torchsde
\ No newline at end of file