Add offload option in flux-control training (#10225)
* Add offload option in flux-control training

* Update examples/flux-control/train_control_flux.py

Co-authored-by: Sayak Paul <[email protected]>

* modify help message

* fix format

---------

Co-authored-by: Sayak Paul <[email protected]>
Adenialzz and sayakpaul authored Dec 15, 2024
1 parent a5f35ee commit 96a9097
Showing 3 changed files with 23 additions and 6 deletions.
2 changes: 2 additions & 0 deletions examples/flux-control/README.md
````diff
@@ -36,6 +36,7 @@ accelerate launch train_control_lora_flux.py \
   --max_train_steps=5000 \
   --validation_image="openpose.png" \
   --validation_prompt="A couple, 4k photo, highly detailed" \
+  --offload \
   --seed="0" \
   --push_to_hub
 ```
````
````diff
@@ -154,6 +155,7 @@ accelerate launch --config_file=accelerate_ds2.yaml train_control_flux.py \
   --validation_steps=200 \
   --validation_image "2_pose_1024.jpg" "3_pose_1024.jpg" \
   --validation_prompt "two friends sitting by each other enjoying a day at the park, full hd, cinematic" "person enjoying a day at the park, full hd, cinematic" \
+  --offload \
   --seed="0" \
   --push_to_hub
 ```
````
13 changes: 10 additions & 3 deletions examples/flux-control/train_control_flux.py
```diff
@@ -541,6 +541,11 @@ def parse_args(input_args=None):
         default=1.29,
         help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
     )
+    parser.add_argument(
+        "--offload",
+        action="store_true",
+        help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
```
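For context, `action="store_true"` gives the flag a `False` default that flips to `True` only when `--offload` is passed. A minimal standalone sketch of just this flag (the real parsers define many more arguments):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--offload",
    action="store_true",  # absent -> False; bare flag -> True
    help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
)

print(parser.parse_args([]).offload)             # False
print(parser.parse_args(["--offload"]).offload)  # True
```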
```diff
@@ -999,8 +1004,9 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 control_latents = encode_images(
                     batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
                 )
-                # offload vae to CPU.
-                vae.cpu()
+                if args.offload:
+                    # offload vae to CPU.
+                    vae.cpu()
 
                 # Sample a random timestep for each image
                 # for weighting schemes where we sample timesteps non-uniformly
```
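The guarded pattern: the VAE is moved to the accelerator only for encoding, then optionally parked on CPU so its weights stop occupying VRAM during the transformer's forward and backward passes. A minimal sketch, assuming a diffusers `AutoencoderKL`-style `vae`; `encode_once` is a hypothetical helper, not a function from the script:

```python
import torch

def encode_once(vae, pixel_values, device, offload: bool):
    # Hypothetical helper mirroring the guarded branch above.
    vae.to(device)  # bring the VAE onto the accelerator just for this encode
    with torch.no_grad():
        latents = vae.encode(pixel_values.to(device)).latent_dist.sample()
    if offload:
        # Park the VAE on CPU; the next step pays a host-to-device copy
        # to bring it back, trading step time for peak memory.
        vae.cpu()
    return latents
```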
```diff
@@ -1064,7 +1070,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
                     prompt_embeds.zero_()
                     pooled_prompt_embeds.zero_()
-                text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+                if args.offload:
+                    text_encoding_pipeline = text_encoding_pipeline.to("cpu")
 
                 # Predict.
                 model_pred = flux_transformer(
```
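The same idea applies to the prompt encoders: encode on the accelerator, then move the whole text-encoding pipeline back to CPU until the next batch. A sketch assuming `text_encoding_pipeline` is a Flux pipeline exposing `encode_prompt` as the script uses it; `embed_prompts` is a hypothetical wrapper:

```python
import torch

def embed_prompts(text_encoding_pipeline, captions, device, offload: bool):
    # Hypothetical wrapper around the guarded branch above.
    text_encoding_pipeline = text_encoding_pipeline.to(device)
    with torch.no_grad():
        prompt_embeds, pooled_prompt_embeds, text_ids = (
            text_encoding_pipeline.encode_prompt(captions, prompt_2=None)
        )
    if offload:
        # Keep only the transformer resident on the GPU between batches.
        text_encoding_pipeline = text_encoding_pipeline.to("cpu")
    return prompt_embeds, pooled_prompt_embeds, text_ids
```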
14 changes: 11 additions & 3 deletions examples/flux-control/train_control_lora_flux.py
```diff
@@ -573,6 +573,11 @@ def parse_args(input_args=None):
         default=1.29,
         help="Scale of mode weighting scheme. Only effective when using the `'mode'` as the `weighting_scheme`.",
     )
+    parser.add_argument(
+        "--offload",
+        action="store_true",
+        help="Whether to offload the VAE and the text encoders to CPU when they are not used.",
+    )
 
     if input_args is not None:
         args = parser.parse_args(input_args)
```
```diff
@@ -1140,8 +1145,10 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 control_latents = encode_images(
                     batch["conditioning_pixel_values"], vae.to(accelerator.device), weight_dtype
                 )
-                # offload vae to CPU.
-                vae.cpu()
+
+                if args.offload:
+                    # offload vae to CPU.
+                    vae.cpu()
 
                 # Sample a random timestep for each image
                 # for weighting schemes where we sample timesteps non-uniformly
```
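To confirm the saving on a given setup, allocated VRAM can be sampled around the guarded branch; a hedged sketch using PyTorch's allocator counters (the placement of the calls is illustrative, not from the script):

```python
import torch

def report_gpu_memory(tag: str, device: int = 0) -> None:
    # memory_allocated counts live tensor storage; memory_reserved adds
    # the caching allocator's held-but-free blocks on top.
    allocated = torch.cuda.memory_allocated(device) / 1024**3
    reserved = torch.cuda.memory_reserved(device) / 1024**3
    print(f"{tag}: {allocated:.2f} GiB allocated, {reserved:.2f} GiB reserved")

# Illustrative placement around the guarded branch:
# report_gpu_memory("before offload")
# if args.offload:
#     vae.cpu()
# report_gpu_memory("after offload")
```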
```diff
@@ -1205,7 +1212,8 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 if args.proportion_empty_prompts and random.random() < args.proportion_empty_prompts:
                     prompt_embeds.zero_()
                     pooled_prompt_embeds.zero_()
-                text_encoding_pipeline = text_encoding_pipeline.to("cpu")
+                if args.offload:
+                    text_encoding_pipeline = text_encoding_pipeline.to("cpu")
 
                 # Predict.
                 model_pred = flux_transformer(
```
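The cost of the flag is one extra host-device round trip per offloaded module per step. A rough way to gauge that on real hardware (a sketch; the synchronize calls keep the wall-clock timings honest against CUDA's asynchronous execution):

```python
import time
import torch

def round_trip_seconds(module: torch.nn.Module, device: str = "cuda"):
    # Time one reload (CPU -> GPU) and one offload (GPU -> CPU).
    torch.cuda.synchronize()
    t0 = time.perf_counter()
    module.to(device)
    torch.cuda.synchronize()
    t1 = time.perf_counter()
    module.cpu()
    torch.cuda.synchronize()
    t2 = time.perf_counter()
    return t1 - t0, t2 - t1
```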
