[training] add ds support to lora sd3. #10378
```diff
@@ -29,7 +29,7 @@
 import torch
 import torch.utils.checkpoint
 import transformers
-from accelerate import Accelerator
+from accelerate import Accelerator, DistributedType
 from accelerate.logging import get_logger
 from accelerate.utils import DistributedDataParallelKwargs, ProjectConfiguration, set_seed
 from huggingface_hub import create_repo, upload_folder
```
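For context, the newly imported `DistributedType` is what the rest of the PR branches on. A minimal sketch (not part of the diff) of how a DeepSpeed-backed run is detected with Accelerate:

```python
# Minimal sketch: detect whether Accelerate is running under DeepSpeed.
# `Accelerator` and `DistributedType` are real Accelerate APIs; the branch
# body here is only a placeholder.
from accelerate import Accelerator, DistributedType

accelerator = Accelerator()
if accelerator.distributed_type == DistributedType.DEEPSPEED:
    # DeepSpeed wraps models and owns optimizer/weight state, so the
    # save/load hooks below take a different path in that case.
    pass
```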
```diff
@@ -1292,11 +1292,17 @@ def save_model_hook(models, weights, output_dir):
             text_encoder_two_lora_layers_to_save = None

             for model in models:
-                if isinstance(model, type(unwrap_model(transformer))):
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    model = unwrap_model(model)
                     if args.upcast_before_saving:
                         model = model.to(torch.float32)
                     transformer_lora_layers_to_save = get_peft_model_state_dict(model)
-                elif isinstance(model, type(unwrap_model(text_encoder_one))):  # or text_encoder_two
+                elif args.train_text_encoder and isinstance(
+                    unwrap_model(model), type(unwrap_model(text_encoder_one))
+                ):  # or text_encoder_two
                     # both text encoders are of the same class, so we check hidden size to distinguish between the two
-                    hidden_size = unwrap_model(model).config.hidden_size
+                    model = unwrap_model(model)
+                    hidden_size = model.config.hidden_size
                     if hidden_size == 768:
                         text_encoder_one_lora_layers_to_save = get_peft_model_state_dict(model)
                     elif hidden_size == 1280:
```
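Condensed, the control flow this hunk arrives at looks roughly like the sketch below. It is a paraphrase, not the full hook; `unwrap_model`, `transformer`, `text_encoder_one`, and `args` are the script's own objects and are taken as given here.

```python
# Rough paraphrase of the save-hook branch above (not a drop-in replacement).
import torch
from peft.utils import get_peft_model_state_dict

def collect_lora_layers(model, unwrap_model, transformer, text_encoder_one, args):
    if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
        model = unwrap_model(model)  # strip the DDP/DeepSpeed wrapper first
        if args.upcast_before_saving:
            model = model.to(torch.float32)
        return "transformer", get_peft_model_state_dict(model)
    if args.train_text_encoder and isinstance(
        unwrap_model(model), type(unwrap_model(text_encoder_one))
    ):
        model = unwrap_model(model)
        # both CLIP text encoders share a class; hidden_size (768 vs 1280) tells them apart
        key = "text_encoder" if model.config.hidden_size == 768 else "text_encoder_2"
        return key, get_peft_model_state_dict(model)
    raise ValueError(f"unexpected save model: {model.__class__}")
```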
```diff
@@ -1305,7 +1311,8 @@ def save_model_hook(models, weights, output_dir):
                     raise ValueError(f"unexpected save model: {model.__class__}")

                 # make sure to pop weight so that corresponding model is not saved again
-                weights.pop()
+                if weights:
+                    weights.pop()

             StableDiffusion3Pipeline.save_lora_weights(
                 output_dir,
```
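The reason for guarding `weights.pop()` appears to be that, when DeepSpeed checkpoints the weights itself, the hook can receive an empty `weights` list, and the unconditional `pop()` would then raise. A tiny self-contained illustration:

```python
# Why the `if weights:` guard matters: pop() on an empty list raises IndexError.
weights = []            # what the save hook can receive when DeepSpeed owns the state
if weights:             # guarded version: nothing to pop, nothing to fail
    weights.pop()

try:
    [].pop()            # the unguarded version would fail like this
except IndexError as exc:
    print(f"unguarded pop would raise: {exc}")
```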
```diff
@@ -1319,17 +1326,31 @@ def load_model_hook(models, input_dir):
         text_encoder_one_ = None
         text_encoder_two_ = None

-        while len(models) > 0:
-            model = models.pop()
+        if not accelerator.distributed_type == DistributedType.DEEPSPEED:
+            while len(models) > 0:
+                model = models.pop()

-            if isinstance(model, type(unwrap_model(transformer))):
-                transformer_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_one))):
-                text_encoder_one_ = model
-            elif isinstance(model, type(unwrap_model(text_encoder_two))):
-                text_encoder_two_ = model
-            else:
-                raise ValueError(f"unexpected save model: {model.__class__}")
+                if isinstance(unwrap_model(model), type(unwrap_model(transformer))):
+                    transformer_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_one))):
+                    text_encoder_one_ = unwrap_model(model)
+                elif isinstance(unwrap_model(model), type(unwrap_model(text_encoder_two))):
+                    text_encoder_two_ = unwrap_model(model)
+                else:
+                    raise ValueError(f"unexpected save model: {model.__class__}")

+        else:
+            transformer_ = SD3Transformer2DModel.from_pretrained(
+                args.pretrained_model_name_or_path, subfolder="transformer"
+            )
+            transformer_.add_adapter(transformer_lora_config)
+            if args.train_text_encoder:
+                text_encoder_one_ = text_encoder_cls_one.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder"
+                )
+                text_encoder_two_ = text_encoder_cls_two.from_pretrained(
+                    args.pretrained_model_name_or_path, subfolder="text_encoder_2"
+                )

         lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
```
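The new `else` branch covers the DeepSpeed case, where the hook does not pull the models out of the `models` list; instead it rebuilds the base transformer (and, if requested, the text encoders) from the pretrained checkpoint, re-attaches the LoRA adapter, and then reads the LoRA weights back from `input_dir`. A hedged sketch of just that branch, with the script's `args` and `transformer_lora_config` taken as given:

```python
# Sketch of the DeepSpeed-only load path (paraphrase of the diff above).
from diffusers import SD3Transformer2DModel, StableDiffusion3Pipeline

def load_lora_under_deepspeed(args, transformer_lora_config, input_dir):
    # Rebuild the frozen base transformer and re-attach an (empty) LoRA adapter...
    transformer_ = SD3Transformer2DModel.from_pretrained(
        args.pretrained_model_name_or_path, subfolder="transformer"
    )
    transformer_.add_adapter(transformer_lora_config)
    # ...then fetch the LoRA weights that the save hook wrote next to the checkpoint.
    lora_state_dict = StableDiffusion3Pipeline.lora_state_dict(input_dir)
    return transformer_, lora_state_dict
```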
```diff
@@ -1829,7 +1850,7 @@ def get_sigmas(timesteps, n_dim=4, dtype=torch.float32):
                 progress_bar.update(1)
                 global_step += 1

-                if accelerator.is_main_process:
+                if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
                     if global_step % args.checkpointing_steps == 0:
                         # _before_ saving state, check if this save would set us over the `checkpoints_total_limit`
                         if args.checkpoints_total_limit is not None:
```
Review comments on this hunk:

- "I think it should be `if args.checkpoints_total_limit is not None and accelerator.is_main_process:`"
- "This is a must! It has to be changed! And I think the correct format should be `if accelerator.is_main_process and args.checkpoints_total_limit is not None:`"
- "Calm down sir :) The line is already under the `if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:` check shown above, so this should already take care of what you're suggesting."
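For reference, the nesting that reply refers to, as a minimal sketch (a paraphrase of the training loop, not the script itself; `accelerator.save_state` is the actual Accelerate call used for checkpointing):

```python
# Sketch of the checkpointing nesting in the training loop: the total-limit
# check sits inside the widened guard, not alongside it.
from accelerate import DistributedType

def maybe_checkpoint(accelerator, args, global_step, save_path):
    if accelerator.is_main_process or accelerator.distributed_type == DistributedType.DEEPSPEED:
        if global_step % args.checkpointing_steps == 0:
            if args.checkpoints_total_limit is not None:
                ...  # prune the oldest checkpoints before writing a new one
            accelerator.save_state(save_path)
```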
Further review comments:

- "We can modify this to: […] As […]"
- "`else` should not be needed as the model would already be type-casted."
- "Addressed."
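The point about the cast in the comment above can be shown with a small self-contained sketch (the `upcast_before_saving` flag mirrors the script's CLI option; everything else here is illustrative):

```python
# Without an `else`, the model simply keeps its current dtype unless the
# upcast flag is set, so an explicit else branch is redundant.
import torch
import torch.nn as nn

def maybe_upcast(model: nn.Module, upcast_before_saving: bool) -> nn.Module:
    if upcast_before_saving:
        model = model.to(torch.float32)
    return model  # no else needed: dtype is unchanged otherwise

m = nn.Linear(2, 2).to(torch.float16)
print(maybe_upcast(m, upcast_before_saving=False).weight.dtype)  # torch.float16
print(maybe_upcast(m, upcast_before_saving=True).weight.dtype)   # torch.float32
```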
On the switch to `unwrap_model(...)` in the type checks:

- "Why is the change from `isinstance(model, ...)` to `isinstance(unwrap_model(model), ...)` needed in the if statement?"
- "That is a DeepSpeed-specific change: with DeepSpeed, the model gets wrapped into a `Module`."
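That reply is the crux of most of the `unwrap_model` churn in this diff. A toy, self-contained illustration (the `Wrapper` class stands in for the DeepSpeed engine; `unwrap` stands in for the script's `unwrap_model` helper):

```python
# Once a model is wrapped in another nn.Module, isinstance() against the
# original class fails; comparing types after unwrapping restores the check.
import torch.nn as nn

class Wrapper(nn.Module):                  # stand-in for the DeepSpeed wrapper
    def __init__(self, module: nn.Module):
        super().__init__()
        self.module = module

def unwrap(m: nn.Module) -> nn.Module:     # stand-in for unwrap_model
    return m.module if isinstance(m, Wrapper) else m

base = nn.Linear(4, 4)
wrapped = Wrapper(base)
print(isinstance(wrapped, nn.Linear))          # False: the wrapper hides the type
print(isinstance(unwrap(wrapped), nn.Linear))  # True: check the underlying module
```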