[training] add ds support to lora sd3. #10378

Merged: 4 commits, merged on Dec 30, 2024
Changes from all commits
53 changes: 37 additions & 16 deletions examples/dreambooth/train_dreambooth_lora_sd3.py
@@ -29,7 +29,7 @@
 import torch
 import torch.utils.checkpoint
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
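
This new import is what the rest of the diff branches on to detect a DeepSpeed run. A minimal sketch of the pattern, assuming only that accelerate is installed (the branch body is illustrative, not part of the diff):

```python
from accelerate import Accelerator, DistributedType

accelerator = Accelerator()

if accelerator.distributed_type == DistributedType.DEEPSPEED:
    # DeepSpeed wraps prepared models in its own engine and manages checkpoint
    # saving itself, so the save/load hooks below special-case it.
    pass
```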
@@ -1292,11 +1292,17 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_two_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
Contributor:
We can modify this to:

    transformer_model = unwrap_model(model)
    if args.upcast_before_saving:
        transformer_model = transformer_model.to(torch.float32)
    else:
        transformer_model = transformer_model.to(weight_dtype)
    transformer_lora_layers_to_save = get_peft_model_state_dict(transformer_model)
As

Member Author:
The else branch should not be needed, as the model would already be type-cast.

Member Author:
Addressed.

Collaborator:
Why is the change from isinstance(model, ...) to isinstance(unwrap_model(model), ...) needed in the if statement?

Member Author:
That is a DeepSpeed-specific change: with DeepSpeed, the model gets wrapped into a Module.

                     if args.upcast_before_saving:
                         model = model.to(torch.float32)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):  # or text_encoder_two
+                elif args.train_text_encoder and isinstance(
+                    unwrap_model(model), type(unwrap_model(text_encoder_one))
+                ):  # or text_encoder_two
                     # both text encoders are of the same class, so we check hidden size to distinguish between the two
-                    hidden_size = unwrap_model(model).config.hidden_size
+                    model = unwrap_model(model)
+                    hidden_size = model.config.hidden_size
                     if hidden_size == 768:
                         text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                     elif hidden_size == 1280:
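
As the thread above explains, DeepSpeed wraps the prepared model in its own engine module, which is why the hook now compares the unwrapped objects. A self-contained sketch of the failure mode, using a hypothetical EngineLikeWrapper as a stand-in for the DeepSpeed wrapper (in the script, the unwrap_model helper plays the role of the .module access shown here):

```python
from torch import nn

class Tiny(nn.Module):
    """Stand-in for the SD3 transformer."""
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(2, 2)

class EngineLikeWrapper(nn.Module):
    """Hypothetical stand-in for the engine DeepSpeed wraps a prepared model in."""
    def __init__(self, module):
        super().__init__()
        self.module = module

base = Tiny()
wrapped = EngineLikeWrapper(base)

print(isinstance(wrapped, Tiny))         # False: the wrapper hides the original class
print(isinstance(wrapped.module, Tiny))  # True: unwrapping recovers the real model
```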
@@ -1305,7 +1311,8 @@ def save_model_hook(models, weights, output_dir):
                     raise ValueError(f"unexpected save model: {model.__class__}")

                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()

             StableDiffusion3Pipeline.save_lora_weights(
                 output_dir,
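
The if weights: guard above matters on DeepSpeed runs, where the hook can be handed an empty weights list (the assumption being that DeepSpeed serializes model state through its own engine rather than through the weights accelerate tracks). A tiny runnable sketch of what the unguarded call would otherwise hit:

```python
weights = []  # stand-in for the empty list the hook can receive under DeepSpeed
if weights:
    weights.pop()  # skipped safely; an unconditional weights.pop() would raise IndexError
```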
@@ -1319,17 +1326,31 @@ def load_model_hook(models, input_dir):
         text_encoder_one_ = None
         text_encoder_two_ = None

-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()

-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                text_encoder_one_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_two))):
-                text_encoder_two_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    text_encoder_one_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
+                    text_encoder_two_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")
+
+        else:
+            transformer_ = SD3Transformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
+            if args.train_text_encoder:
+                text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder"
+                )
+                text_encoder_two_ = text_encoder_cls_two.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder_2"
+                )

         lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)

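For context (not part of this diff): hooks like save_model_hook and load_model_hook only take effect once they are registered with accelerate, which provides register_save_state_pre_hook and register_load_state_pre_hook for this purpose. A hedged sketch of the wiring, with the hook bodies stubbed out:

```python
from accelerate import Accelerator

accelerator = Accelerator()

def save_model_hook(models, weights, output_dir):
    ...  # extract LoRA layers and pop tracked weights, as in the hunk above

def load_model_hook(models, input_dir):
    ...  # unwrap the tracked models, or rebuild them on DeepSpeed, then reload LoRA weights

# These run automatically inside accelerator.save_state() / accelerator.load_state().
accelerator.register_save_state_pre_hook(save_model_hook)
accelerator.register_load_state_pre_hook(load_model_hook)
```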
@@ -1829,7 +1850,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
Contributor:
I think it should be: if args.checkpoints_total_limit is not None and accelerator.is_main_process:

Contributor:
This is a must! It has to be changed! And I think the correct form should be: if accelerator.is_main_process and args.checkpoints_total_limit is not None:

Member Author:
Calm down, sir :)

The line is already under:

    if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:

So, this should already take care of what you're suggesting.
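
For readers following the thread, here is a hedged, simplified sketch of the nesting being discussed, with SimpleNamespace stand-ins for the objects that come from accelerate and argparse in the real script:

```python
import os
from types import SimpleNamespace

# Stand-ins for objects provided by accelerate/argparse in the real script.
accelerator = SimpleNamespace(is_main_process=True, distributed_type="DEEPSPEED")
args = SimpleNamespace(checkpointing_steps=500, checkpoints_total_limit=None, output_dir="out")
global_step = 500

if accelerator.is_main_process or accelerator.distributed_type == "DEEPSPEED":
    if global_step % args.checkpointing_steps == 0:
        if args.checkpoints_total_limit is not None:
            pass  # checkpoint rotation lives here, inside the same outer guard
        save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
        print(f"would call accelerator.save_state({save_path!r}) here")
```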
