-
Notifications
You must be signed in to change notification settings - Fork 210
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
e50b852
commit cce7cdd
Showing
22 changed files
with
2,342 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,292 @@ | ||
import time | ||
|
||
import gradio as gr | ||
import torch | ||
from einops import rearrange | ||
from PIL import Image | ||
|
||
from flux.cli import SamplingOptions | ||
from flux.sampling import denoise, get_noise, get_schedule, prepare, unpack | ||
from flux.util import load_ae, load_clip, load_flow_model, load_t5 | ||
from pulid.pipeline_flux import PuLIDPipeline | ||
from pulid.utils import resize_numpy_image_long | ||
|
||
|
||
def get_models(name: str, device: torch.device, offload: bool): | ||
t5 = load_t5(device, max_length=128) | ||
clip = load_clip(device) | ||
model = load_flow_model(name, device="cpu" if offload else device) | ||
model.eval() | ||
ae = load_ae(name, device="cpu" if offload else device) | ||
return model, ae, t5, clip | ||
|
||
|
||
class FluxGenerator: | ||
def __init__(self, model_name: str, device: str, offload: bool, args): | ||
self.device = torch.device(device) | ||
self.offload = offload | ||
self.model_name = model_name | ||
self.model, self.ae, self.t5, self.clip = get_models( | ||
model_name, | ||
device=self.device, | ||
offload=self.offload, | ||
) | ||
self.pulid_model = PuLIDPipeline(self.model, device, weight_dtype=torch.bfloat16) | ||
self.pulid_model.load_pretrain(args.pretrained_model) | ||
|
||
@torch.inference_mode() | ||
def generate_image( | ||
self, | ||
width, | ||
height, | ||
num_steps, | ||
start_step, | ||
guidance, | ||
seed, | ||
prompt, | ||
id_image=None, | ||
id_weight=1.0, | ||
neg_prompt="", | ||
true_cfg=1.0, | ||
timestep_to_start_cfg=1, | ||
max_sequence_length=128, | ||
): | ||
self.t5.max_length = max_sequence_length | ||
|
||
seed = int(seed) | ||
if seed == -1: | ||
seed = None | ||
|
||
opts = SamplingOptions( | ||
prompt=prompt, | ||
width=width, | ||
height=height, | ||
num_steps=num_steps, | ||
guidance=guidance, | ||
seed=seed, | ||
) | ||
|
||
if opts.seed is None: | ||
opts.seed = torch.Generator(device="cpu").seed() | ||
print(f"Generating '{opts.prompt}' with seed {opts.seed}") | ||
t0 = time.perf_counter() | ||
|
||
use_true_cfg = abs(true_cfg - 1.0) > 1e-2 | ||
|
||
if id_image is not None: | ||
id_image = resize_numpy_image_long(id_image, 1024) | ||
id_embeddings, uncond_id_embeddings = self.pulid_model.get_id_embedding(id_image, cal_uncond=use_true_cfg) | ||
else: | ||
id_embeddings = None | ||
uncond_id_embeddings = None | ||
|
||
# prepare input | ||
x = get_noise( | ||
1, | ||
opts.height, | ||
opts.width, | ||
device=self.device, | ||
dtype=torch.bfloat16, | ||
seed=opts.seed, | ||
) | ||
timesteps = get_schedule( | ||
opts.num_steps, | ||
x.shape[-1] * x.shape[-2] // 4, | ||
shift=True, | ||
) | ||
|
||
if self.offload: | ||
self.t5, self.clip = self.t5.to(self.device), self.clip.to(self.device) | ||
inp = prepare(t5=self.t5, clip=self.clip, img=x, prompt=opts.prompt) | ||
inp_neg = prepare(t5=self.t5, clip=self.clip, img=x, prompt=neg_prompt) if use_true_cfg else None | ||
|
||
# offload TEs to CPU, load model to gpu | ||
if self.offload: | ||
self.t5, self.clip = self.t5.cpu(), self.clip.cpu() | ||
torch.cuda.empty_cache() | ||
self.model = self.model.to(self.device) | ||
|
||
# denoise initial noise | ||
x = denoise( | ||
self.model, **inp, timesteps=timesteps, guidance=opts.guidance, id=id_embeddings, id_weight=id_weight, | ||
start_step=start_step, uncond_id=uncond_id_embeddings, true_cfg=true_cfg, | ||
timestep_to_start_cfg=timestep_to_start_cfg, | ||
neg_txt=inp_neg["txt"] if use_true_cfg else None, | ||
neg_txt_ids=inp_neg["txt_ids"] if use_true_cfg else None, | ||
neg_vec=inp_neg["vec"] if use_true_cfg else None, | ||
) | ||
|
||
# offload model, load autoencoder to gpu | ||
if self.offload: | ||
self.model.cpu() | ||
torch.cuda.empty_cache() | ||
self.ae.decoder.to(x.device) | ||
|
||
# decode latents to pixel space | ||
x = unpack(x.float(), opts.height, opts.width) | ||
with torch.autocast(device_type=self.device.type, dtype=torch.bfloat16): | ||
x = self.ae.decode(x) | ||
|
||
if self.offload: | ||
self.ae.decoder.cpu() | ||
torch.cuda.empty_cache() | ||
|
||
t1 = time.perf_counter() | ||
|
||
print(f"Done in {t1 - t0:.1f}s.") | ||
# bring into PIL format | ||
x = x.clamp(-1, 1) | ||
# x = embed_watermark(x.float()) | ||
x = rearrange(x[0], "c h w -> h w c") | ||
|
||
img = Image.fromarray((127.5 * (x + 1.0)).cpu().byte().numpy()) | ||
return img, str(opts.seed), self.pulid_model.debug_img_list | ||
|
||
_HEADER_ = ''' | ||
<div style="text-align: center; max-width: 650px; margin: 0 auto;"> | ||
<h1 style="font-size: 2.5rem; font-weight: 700; margin-bottom: 1rem; display: contents;">PuLID for FLUX</h1> | ||
<p style="font-size: 1rem; margin-bottom: 1.5rem;">Paper: <a href='https://arxiv.org/abs/2404.16022' target='_blank'>PuLID: Pure and Lightning ID Customization via Contrastive Alignment</a> | Codes: <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'>GitHub</a></p> | ||
</div> | ||
❗️❗️❗️**Tips:** | ||
- `timestep to start inserting ID:` The smaller the value, the higher the fidelity, but the lower the editability; the higher the value, the lower the fidelity, but the higher the editability. **The recommended range for this value is between 0 and 4**. For photorealistic scenes, we recommend using 4; for stylized scenes, we recommend using 0-1. If you are not satisfied with the similarity, you can lower this value; conversely, if you are not satisfied with the editability, you can increase this value. | ||
- `true CFG scale:` In most scenarios, it is recommended to use a fake CFG, i.e., setting the true CFG scale to 1, and just adjusting the guidance scale. This is also more efficiency. However, in a few cases, utilizing a true CFG can yield better results. For more detaileds, please refer to XX. | ||
- please refer to the <a href='URL_ADDRESS' target='_blank'>github doc</a> for more details and info about the model, we provide the detail explanation about the above two parameters in the doc. | ||
- we provide some examples in the bottom, you can try these example prompts first | ||
''' # noqa E501 | ||
|
||
_CITE_ = r""" | ||
If PuLID is helpful, please help to ⭐ the <a href='https://github.com/ToTheBeginning/PuLID' target='_blank'> Github Repo</a>. Thanks! | ||
--- | ||
📧 **Contact** | ||
If you have any questions or feedbacks, feel free to open a discussion or contact <b>[email protected]</b>. | ||
""" # noqa E501 | ||
|
||
|
||
def create_demo(args, model_name: str, device: str = "cuda" if torch.cuda.is_available() else "cpu", | ||
offload: bool = False): | ||
generator = FluxGenerator(model_name, device, offload, args) | ||
|
||
with gr.Blocks() as demo: | ||
gr.Markdown(_HEADER_) | ||
|
||
with gr.Row(): | ||
with gr.Column(): | ||
prompt = gr.Textbox(label="Prompt", value="portrait, color, cinematic") | ||
id_image = gr.Image(label="ID Image") | ||
id_weight = gr.Slider(0.0, 3.0, 1, step=0.05, label="id weight") | ||
|
||
width = gr.Slider(256, 1536, 896, step=16, label="Width") | ||
height = gr.Slider(256, 1536, 1152, step=16, label="Height") | ||
num_steps = gr.Slider(1, 20, 20, step=1, label="Number of steps") | ||
start_step = gr.Slider(0, 10, 0, step=1, label="timestep to start inserting ID") | ||
guidance = gr.Slider(1.0, 10.0, 4, step=0.1, label="Guidance") | ||
seed = gr.Textbox(-1, label="Seed (-1 for random)") | ||
max_sequence_length = gr.Slider(128, 512, 128, step=128, | ||
label="max_sequence_length for prompt (T5), small will be faster") | ||
|
||
with gr.Accordion("Advanced Options (True CFG, true_cfg_scale=1 means use fake CFG, >1 means use true CFG, if using true CFG, we recommend set the guidance scale to 1)", open=False): # noqa E501 | ||
neg_prompt = gr.Textbox( | ||
label="Negative Prompt", | ||
value="bad quality, worst quality, text, signature, watermark, extra limbs") | ||
true_cfg = gr.Slider(1.0, 10.0, 1, step=0.1, label="true CFG scale") | ||
timestep_to_start_cfg = gr.Slider(0, 20, 1, step=1, label="timestep to start cfg", visible=args.dev) | ||
|
||
generate_btn = gr.Button("Generate") | ||
|
||
with gr.Column(): | ||
output_image = gr.Image(label="Generated Image") | ||
seed_output = gr.Textbox(label="Used Seed") | ||
intermediate_output = gr.Gallery(label='Output', elem_id="gallery", visible=args.dev) | ||
gr.Markdown(_CITE_) | ||
|
||
with gr.Row(), gr.Column(): | ||
gr.Markdown("## Examples") | ||
example_inps = [ | ||
[ | ||
'a woman holding sign with glowing green text \"PuLID for FLUX\"', | ||
'example_inputs/liuyifei.png', | ||
4, 4, 2680261499100305976, 1 | ||
], | ||
[ | ||
'portrait, side view', | ||
'example_inputs/liuyifei.png', | ||
4, 4, 1205240166692517553, 1 | ||
], | ||
[ | ||
'white-haired woman with vr technology atmosphere, revolutionary exceptional magnum with remarkable details', # noqa E501 | ||
'example_inputs/liuyifei.png', | ||
4, 4, 6349424134217931066, 1 | ||
], | ||
[ | ||
'a young child is eating Icecream', | ||
'example_inputs/liuyifei.png', | ||
4, 4, 10606046113565776207, 1 | ||
], | ||
[ | ||
'a man is holding a sign with text \"PuLID for FLUX\", winter, snowing, top of the mountain', | ||
'example_inputs/pengwei.jpg', | ||
4, 4, 2410129802683836089, 1 | ||
], | ||
[ | ||
'portrait, candle light', | ||
'example_inputs/pengwei.jpg', | ||
4, 4, 17522759474323955700, 1 | ||
], | ||
[ | ||
'profile shot dark photo of a 25-year-old male with smoke escaping from his mouth, the backlit smoke gives the image an ephemeral quality, natural face, natural eyebrows, natural skin texture, award winning photo, highly detailed face, atmospheric lighting, film grain, monochrome', # noqa E501 | ||
'example_inputs/pengwei.jpg', | ||
4, 4, 17733156847328193625, 1 | ||
], | ||
[ | ||
'American Comics, 1boy', | ||
'example_inputs/pengwei.jpg', | ||
1, 4, 13223174453874179686, 1 | ||
], | ||
[ | ||
'portrait, pixar', | ||
'example_inputs/pengwei.jpg', | ||
1, 4, 9445036702517583939, 1 | ||
], | ||
] | ||
gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg], | ||
label='fake CFG') | ||
|
||
example_inps = [ | ||
[ | ||
'portrait, made of ice sculpture', | ||
'example_inputs/lecun.jpg', | ||
1, 1, 3811899118709451814, 5 | ||
], | ||
] | ||
gr.Examples(examples=example_inps, inputs=[prompt, id_image, start_step, guidance, seed, true_cfg], | ||
label='true CFG') | ||
|
||
generate_btn.click( | ||
fn=generator.generate_image, | ||
inputs=[width, height, num_steps, start_step, guidance, seed, prompt, id_image, id_weight, neg_prompt, | ||
true_cfg, timestep_to_start_cfg, max_sequence_length], | ||
outputs=[output_image, seed_output, intermediate_output], | ||
) | ||
|
||
return demo | ||
|
||
|
||
if __name__ == "__main__": | ||
import argparse | ||
|
||
parser = argparse.ArgumentParser(description="PuLID for FLUX.1-dev") | ||
parser.add_argument("--name", type=str, default="flux-dev", choices=list('flux-dev'), | ||
help="currently only support flux-dev") | ||
parser.add_argument("--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu", | ||
help="Device to use") | ||
parser.add_argument("--offload", action="store_true", help="Offload model to CPU when not in use") | ||
parser.add_argument("--port", type=int, default=8080, help="Port to use") | ||
parser.add_argument("--dev", action='store_true', help="Development mode") | ||
parser.add_argument("--pretrained_model", type=str, help='for development') | ||
args = parser.parse_args() | ||
|
||
demo = create_demo(args, args.name, args.device, args.offload) | ||
demo.launch(server_name='0.0.0.0', server_port=args.port) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
# PuLID for FLUX | ||
We are happy to release the **PuLID-FLUX-v0.9.0** model, which provides a tuning-free ID customization solution for FLUX.1-dev. | ||
|
||
If PuLID-FLUX is helpful, please help to ⭐ this repo or recommend it to your friends 😊 | ||
|
||
## Inference | ||
### Local Gradio Demo | ||
1. Please first follow the [dependencies-and-installation](../README.md#wrench-dependencies-and-installation) to set up the environment. | ||
2. Download flux1-dev.safetensors and ae.safetensors from [black-forest-labs/FLUX.1-dev](https://huggingface.co/black-forest-labs/FLUX.1-dev/tree/main) to the models folder. | ||
3. Run the gradio demo with `python app_flux.py` | ||
|
||
### Online Demo | ||
- huggingface demo: we are currently preparing it, will be available soon. | ||
|
||
### ComfyUI | ||
Please stay tuned for the community implementation | ||
|
||
## Visual Results | ||
![pulid_flux_results](https://github.com/user-attachments/assets/7eafb90a-fdd1-4ae7-bc41-8c428d568848) | ||
|
||
|
||
## Useful Tips | ||
There are two parameters that are crucial and need to be set carefully: | ||
|
||
1. `timestep to start inserting ID`: This parameter controls the timing of ID insertion. If set to 0, the ID starts being inserted to the DIT from the first timestep. The earlier it is inserted, the higher the ID fidelity will be, but the editability may decrease. The later it is inserted, the lower the fidelity to the ID, but the editability will increase, and the disruption to the original model behavior will also be smaller. For generating realistic images, we suggest setting this to 4. If you found the ID similarity is not high enough, you could try lowering this parameter accordingly. For generating stylized images, we suggest setting it to 0-1. | ||
![start_id](https://github.com/user-attachments/assets/3866ffab-542d-4e2f-9a0c-6877c9158d49) | ||
|
||
2. `true CFG scale`: FLUX.1-dev is a guidance distill model. The original CFG process, which required twice the number of inference steps, is distilled into a guidance scale, thereby modulating the DIT through the guidance scale to simulate the true CFG process with half the inference steps. We will refer to this as fake CFG in the following doc. Our PuLID-FLUX model can be tested under the fake CFG settings, and the guidance scale can be set to a commonly used value, such as 4. However, the model also supports using the real CFG for inference. We compare the results of using true CFG with the fake CFG in photorealistic scenarios below. | ||
![fake_cfg_vs_true_cfg_fidelity](https://github.com/user-attachments/assets/73b44dc8-37c7-48c8-8f55-73882731126d) | ||
As shown in the above image, in terms of ID fidelity, using fake CFG is similar to true CFG in most cases, except that in a few cases, true CFG achieves higher ID similarity. In terms of image aesthetics and facial naturalness, fake CFG performs better. However, by carefully adjusting hyperparameters, the performance of true CFG may be further improved, we leave this to the community to explore. Therefore, we recommend using fake CFG for photorealistic scenes. If you are not satisfy about the ID fidelity, you can try switching to true CFG. Additionally, as shown below, we have found that using fake CFG in stylized scenes sometimes results in lower ID similarity and poorer style response, so if you encounter these two issues in stylized scenes, please consider switching to true CFG. | ||
![fake_cfg_vs_true_cfg_style](https://github.com/user-attachments/assets/fb042639-64e6-4bb3-a3a4-5c138793318e) | ||
|
||
|
||
|
||
## Some Technical Details | ||
- We switch the ID encoder from an MLP structure to a Transformer structure. Interested users can refer to [source code](../pulid/encoders_flux.py) | ||
- Inspired by [Flamingo](https://arxiv.org/abs/2204.14198), we insert additional cross-attention blocks every few DIT blocks to interact ID features with DIT image features | ||
- We would like to clarify that the acceleration method (lile SDXL-Lightning) serves as an | ||
optional acceleration trick, but it is not indispensable for training PuLID. We will update the arxiv paper with the relevant details in the near future. Please stay tuned. | ||
|
||
|
||
## limitation | ||
The model is currently in beta version, and we have observed that the ID fidelity may not be high for some male inputs, maybe the model requires more training. If the improved model is ready, we will release it here, so please stay tuned. | ||
|
||
## contact | ||
If you have any questions or suggestions about the model, please contact [Yanze Wu](https://tothebeginning.github.io/) or open an issue/discussion here. |
File renamed without changes.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,11 @@ | ||
try: | ||
from ._version import version as __version__ # type: ignore | ||
from ._version import version_tuple | ||
except ImportError: | ||
__version__ = "unknown (no version information available)" | ||
version_tuple = (0, 0, "unknown", "noinfo") | ||
|
||
from pathlib import Path | ||
|
||
PACKAGE = __package__.replace("_", "-") | ||
PACKAGE_ROOT = Path(__file__).parent |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,4 @@ | ||
from .cli import app | ||
|
||
if __name__ == "__main__": | ||
app() |
Oops, something went wrong.