From 0eae18dac2ab0739871af52b67622a74f0b6ecf0 Mon Sep 17 00:00:00 2001 From: yanzewu Date: Thu, 19 Sep 2024 19:50:59 +0800 Subject: [PATCH] support 12GB cards --- README.md | 1 + docs/pulid_for_flux.md | 5 ++++- flux/model.py | 10 +++++++++- 3 files changed, 14 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index ef1623d..5330ed8 100644 --- a/README.md +++ b/README.md @@ -21,6 +21,7 @@ We will actively update and maintain this repository in the near future, so plea - [x] Local gradio demo is ready now - [x] Online HuggingFace demo is ready now [![flux](https://img.shields.io/badge/🤗-PuLID_FLUX_demo-orange)](https://huggingface.co/spaces/yanze/PuLID-FLUX) - [x] We have optimized the codes to support consumer-grade GPUS, and now **PuLID-FLUX can run on a 16GB graphic card**. Check the details [here](https://github.com/ToTheBeginning/PuLID/blob/main/docs/pulid_for_flux.md#local-gradio-demo) +- [x] Support 12GB graphic card Below results are generated with PuLID-FLUX. diff --git a/docs/pulid_for_flux.md b/docs/pulid_for_flux.md index 24354a0..83f5e7b 100644 --- a/docs/pulid_for_flux.md +++ b/docs/pulid_for_flux.md @@ -25,7 +25,10 @@ Run `python app_flux.py --offload --fp8 --onnx_provider cpu`, the peak memory is For 24GB graphic memory users, you can run `python app_flux.py --offload --fp8`, the peak memory is under 17GB. -However, there is a difference in image quality between fp8 and bf16, with some degradation in the former. +For 12GB graphic memory users, you can run `python app_flux.py --aggressive_offload --fp8 --onnx_provider cpu`, the peak memory is about 11GB. +However, using aggressive offload (like sequential offload), the speed will be very slow due to the frequent need for memory transfers between CPU and GPU at each timestep. + +Please note that, there is a difference in image quality between fp8 and bf16, with some degradation in the former. Specifically, the details of the face may be slightly worse, but the layout is similar. If you want the best results of PuLID-FLUX or you have the resources, please use bf16 rather than fp8. We have included a comparison in the table below. diff --git a/flux/model.py b/flux/model.py index 981c7c7..846b42d 100644 --- a/flux/model.py +++ b/flux/model.py @@ -127,8 +127,16 @@ def forward( img = torch.cat((txt, img), 1) if aggressive_offload: - self.single_blocks = self.single_blocks.to(DEVICE) + # put half of the single blcoks to gpu + for i in range(len(self.single_blocks) // 2): + self.single_blocks[i] = self.single_blocks[i].to(DEVICE) for i, block in enumerate(self.single_blocks): + if aggressive_offload and i == len(self.single_blocks)//2: + # put first half of the single blcoks to cpu and last half to gpu + for j in range(len(self.single_blocks) // 2): + self.single_blocks[j].cpu() + for j in range(len(self.single_blocks) // 2, len(self.single_blocks)): + self.single_blocks[j] = self.single_blocks[j].to(DEVICE) x = block(img, vec=vec, pe=pe) real_img, txt = x[:, txt.shape[1]:, ...], x[:, :txt.shape[1], ...]