Merge branch 'main' into refactor/omnigen
a-r-r-o-w committed Feb 11, 2025
2 parents 970b086 + 798e171 commit d5d7caa
Showing 8 changed files with 163 additions and 7 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/trufflehog.yml
@@ -13,3 +13,6 @@ jobs:
          fetch-depth: 0
      - name: Secret Scanning
        uses: trufflesecurity/trufflehog@main
        with:
          extra_args: --results=verified,unknown

84 changes: 83 additions & 1 deletion src/diffusers/loaders/lora_conversion_utils.py
@@ -519,7 +519,7 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
        remaining_keys = list(sds_sd.keys())
        te_state_dict = {}
        if remaining_keys:
-           if not all(k.startswith("lora_te1") for k in remaining_keys):
+           if not all(k.startswith("lora_te") for k in remaining_keys):
                raise ValueError(f"Incompatible keys detected: \n\n {', '.join(remaining_keys)}")
            for key in remaining_keys:
                if not key.endswith("lora_down.weight"):
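The one-line change above relaxes the prefix check from `lora_te1` to the more general `lora_te`, so text-encoder LoRA keys with other suffixes (for example `lora_te2`) no longer trip the incompatible-keys error. A minimal sketch with hypothetical key names:

```python
# Hypothetical kohya-style text-encoder key names; only the prefixes matter here.
remaining_keys = [
    "lora_te1_text_model_encoder_layers_0_self_attn_q_proj.lora_down.weight",
    "lora_te2_text_model_encoder_layers_0_mlp_fc1.lora_down.weight",
]

print(all(k.startswith("lora_te1") for k in remaining_keys))  # False -> the old check raised ValueError
print(all(k.startswith("lora_te") for k in remaining_keys))   # True  -> the new check accepts these keys
```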
@@ -558,6 +558,88 @@ def _convert_sd_scripts_to_ai_toolkit(sds_sd):
        new_state_dict = {**ait_sd, **te_state_dict}
        return new_state_dict

    def _convert_mixture_state_dict_to_diffusers(state_dict):
        new_state_dict = {}

        def _convert(original_key, diffusers_key, state_dict, new_state_dict):
            down_key = f"{original_key}.lora_down.weight"
            down_weight = state_dict.pop(down_key)
            lora_rank = down_weight.shape[0]

            up_weight_key = f"{original_key}.lora_up.weight"
            up_weight = state_dict.pop(up_weight_key)

            alpha_key = f"{original_key}.alpha"
            alpha = state_dict.pop(alpha_key)

            # scale weight by alpha and dim
            scale = alpha / lora_rank
            # calculate scale_down and scale_up
            scale_down = scale
            scale_up = 1.0
            while scale_down * 2 < scale_up:
                scale_down *= 2
                scale_up /= 2
            down_weight = down_weight * scale_down
            up_weight = up_weight * scale_up

            diffusers_down_key = f"{diffusers_key}.lora_A.weight"
            new_state_dict[diffusers_down_key] = down_weight
            new_state_dict[diffusers_down_key.replace(".lora_A.", ".lora_B.")] = up_weight

        all_unique_keys = {
            k.replace(".lora_down.weight", "").replace(".lora_up.weight", "").replace(".alpha", "") for k in state_dict
        }
        all_unique_keys = sorted(all_unique_keys)
        assert all("lora_transformer_" in k for k in all_unique_keys), f"{all_unique_keys=}"

        for k in all_unique_keys:
            if k.startswith("lora_transformer_single_transformer_blocks_"):
                i = int(k.split("lora_transformer_single_transformer_blocks_")[-1].split("_")[0])
                diffusers_key = f"single_transformer_blocks.{i}"
            elif k.startswith("lora_transformer_transformer_blocks_"):
                i = int(k.split("lora_transformer_transformer_blocks_")[-1].split("_")[0])
                diffusers_key = f"transformer_blocks.{i}"
            else:
                raise NotImplementedError

            if "attn_" in k:
                if "_to_out_0" in k:
                    diffusers_key += ".attn.to_out.0"
                elif "_to_add_out" in k:
                    diffusers_key += ".attn.to_add_out"
                elif any(qkv in k for qkv in ["to_q", "to_k", "to_v"]):
                    remaining = k.split("attn_")[-1]
                    diffusers_key += f".attn.{remaining}"
                elif any(add_qkv in k for add_qkv in ["add_q_proj", "add_k_proj", "add_v_proj"]):
                    remaining = k.split("attn_")[-1]
                    diffusers_key += f".attn.{remaining}"

            if diffusers_key == f"transformer_blocks.{i}":
                print(k, diffusers_key)
            _convert(k, diffusers_key, state_dict, new_state_dict)

        if len(state_dict) > 0:
            raise ValueError(
                f"Expected an empty state dict at this point but it has these keys which couldn't be parsed: {list(state_dict.keys())}."
            )

        new_state_dict = {f"transformer.{k}": v for k, v in new_state_dict.items()}
        return new_state_dict

    # This is weird.
    # https://huggingface.co/sayakpaul/different-lora-from-civitai/tree/main?show_file_info=sharp_detailed_foot.safetensors
    # has both `peft` and non-peft state dict.
    has_peft_state_dict = any(k.startswith("transformer.") for k in state_dict)
    if has_peft_state_dict:
        state_dict = {k: v for k, v in state_dict.items() if k.startswith("transformer.")}
        return state_dict
    # Another weird one.
    has_mixture = any(
        k.startswith("lora_transformer_") and ("lora_down" in k or "lora_up" in k or "alpha" in k) for k in state_dict
    )
    if has_mixture:
        return _convert_mixture_state_dict_to_diffusers(state_dict)
    return _convert_sd_scripts_to_ai_toolkit(state_dict)
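For context on the `_convert` helper above: LoRA checkpoints store an `alpha` next to each down/up pair, and the effective update is `(alpha / rank) * up @ down`. Rather than folding the whole factor into one matrix, the loop splits it between the two so their magnitudes stay balanced. A standalone sketch of that invariant, using toy tensors rather than the diffusers API:

```python
import torch

# Toy tensors: the converter splits the alpha/rank scale between lora_down (-> lora_A)
# and lora_up (-> lora_B) while keeping their product equal to the scale, so
# (up * scale_up) @ (down * scale_down) == (alpha / rank) * (up @ down).
rank, alpha = 4, 1.0
down_weight = torch.randn(rank, 16)  # lora_down.weight
up_weight = torch.randn(16, rank)    # lora_up.weight

scale = alpha / rank
scale_down, scale_up = scale, 1.0
while scale_down * 2 < scale_up:  # rebalance magnitudes; the product stays equal to scale
    scale_down *= 2
    scale_up /= 2

lora_A = down_weight * scale_down  # becomes "...lora_A.weight"
lora_B = up_weight * scale_up      # becomes "...lora_B.weight"
assert torch.allclose(lora_B @ lora_A, scale * (up_weight @ down_weight), atol=1e-6)
```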


10 changes: 5 additions & 5 deletions src/diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py
@@ -36,11 +36,11 @@
def prepare_causal_attention_mask(
    num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
) -> torch.Tensor:
-   seq_len = num_frames * height_width
-   mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
-   for i in range(seq_len):
-       i_frame = i // height_width
-       mask[i, : (i_frame + 1) * height_width] = 0
+   indices = torch.arange(1, num_frames + 1, dtype=torch.int32, device=device)
+   indices_blocks = indices.repeat_interleave(height_width)
+   x, y = torch.meshgrid(indices_blocks, indices_blocks, indexing="xy")
+   mask = torch.where(x <= y, 0, -float("inf")).to(dtype=dtype)

    if batch_size is not None:
        mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
    return mask
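To make the vectorized version concrete, here is a hedged sketch of the tiny case `num_frames=2, height_width=2`: every token may attend to all tokens in its own frame and in earlier frames, which is exactly what the removed per-row loop produced.

```python
import torch

num_frames, height_width = 2, 2
indices = torch.arange(1, num_frames + 1)
blocks = indices.repeat_interleave(height_width)       # tensor([1, 1, 2, 2]): frame id per token
x, y = torch.meshgrid(blocks, blocks, indexing="xy")   # x[i, j] = blocks[j], y[i, j] = blocks[i]
mask = torch.where(x <= y, 0, -float("inf"))
print(mask)
# Rows 0-1 (frame 1) mask out columns 2-3 (frame 2); rows 2-3 (frame 2) attend everywhere:
# [[0., 0., -inf, -inf],
#  [0., 0., -inf, -inf],
#  [0., 0.,   0.,   0.],
#  [0., 0.,   0.,   0.]]
```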
3 changes: 2 additions & 1 deletion src/diffusers/models/modeling_utils.py
@@ -31,6 +31,7 @@
from huggingface_hub import DDUFEntry, create_repo, split_torch_state_dict_into_shards
from huggingface_hub.utils import validate_hf_hub_args
from torch import Tensor, nn
from typing_extensions import Self

from .. import __version__
from ..hooks import apply_layerwise_casting
@@ -605,7 +606,7 @@ def dequantize(self):

    @classmethod
    @validate_hf_hub_args
-   def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
+   def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs) -> Self:
        r"""
        Instantiate a pretrained PyTorch model from a pretrained model configuration.
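The `-> Self` annotation (via `typing_extensions.Self`) lets type checkers infer the concrete subclass at the call site instead of the base class. A minimal sketch with made-up class names, not the diffusers hierarchy:

```python
from typing_extensions import Self


class BaseModel:
    @classmethod
    def from_pretrained(cls, name: str) -> Self:
        # With `Self`, the return type is "whichever class this was called on".
        return cls()


class UNet(BaseModel):
    def sample(self) -> None: ...


unet = UNet.from_pretrained("some/repo")  # inferred as UNet, not BaseModel
unet.sample()  # type-checks because of the `Self` annotation
```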
26 changes: 26 additions & 0 deletions tests/models/autoencoders/test_models_autoencoder_hunyuan_video.py
@@ -18,6 +18,7 @@
import torch

from diffusers import AutoencoderKLHunyuanVideo
from diffusers.models.autoencoders.autoencoder_kl_hunyuan_video import prepare_causal_attention_mask
from diffusers.utils.testing_utils import (
    enable_full_determinism,
    floats_tensor,
@@ -182,3 +183,28 @@ def test_forward_with_norm_groups(self):
@unittest.skip("Unsupported test.")
def test_outputs_equivalence(self):
pass

def test_prepare_causal_attention_mask(self):
def prepare_causal_attention_mask_orig(
num_frames: int, height_width: int, dtype: torch.dtype, device: torch.device, batch_size: int = None
) -> torch.Tensor:
seq_len = num_frames * height_width
mask = torch.full((seq_len, seq_len), float("-inf"), dtype=dtype, device=device)
for i in range(seq_len):
i_frame = i // height_width
mask[i, : (i_frame + 1) * height_width] = 0
if batch_size is not None:
mask = mask.unsqueeze(0).expand(batch_size, -1, -1)
return mask

# test with some odd shapes
original_mask = prepare_causal_attention_mask_orig(
num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
)
new_mask = prepare_causal_attention_mask(
num_frames=31, height_width=111, dtype=torch.float32, device=torch_device
)
self.assertTrue(
torch.allclose(original_mask, new_mask),
"Causal attention mask should be the same",
)
6 changes: 6 additions & 0 deletions tests/models/autoencoders/test_models_autoencoder_oobleck.py
@@ -114,6 +114,12 @@ def test_forward_with_norm_groups(self):
    def test_set_attn_processor_for_determinism(self):
        return

    @unittest.skip(
        "Test not supported because of 'weight_norm_fwd_first_dim_kernel' not implemented for 'Float8_e4m3fn'"
    )
    def test_layerwise_casting_training(self):
        return super().test_layerwise_casting_training()

    @unittest.skip(
        "The convolution layers of AutoencoderOobleck are wrapped with torch.nn.utils.weight_norm. This causes the hook's pre_forward to not "
        "cast the module weights to compute_dtype (as required by forward pass). As a result, forward pass errors out. To fix:\n"
30 changes: 30 additions & 0 deletions tests/models/test_modeling_common.py
@@ -1338,6 +1338,36 @@ def test_variant_sharded_ckpt_right_format(self):
        # Example: diffusion_pytorch_model.fp16-00001-of-00002.safetensors
        assert all(f.split(".")[1].split("-")[0] == variant for f in shard_files)

    def test_layerwise_casting_training(self):
        def test_fn(storage_dtype, compute_dtype):
            if torch.device(torch_device).type == "cpu" and compute_dtype == torch.bfloat16:
                return
            init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()

            model = self.model_class(**init_dict)
            model = model.to(torch_device, dtype=compute_dtype)
            model.enable_layerwise_casting(storage_dtype=storage_dtype, compute_dtype=compute_dtype)
            model.train()

            inputs_dict = cast_maybe_tensor_dtype(inputs_dict, torch.float32, compute_dtype)
            with torch.amp.autocast(device_type=torch.device(torch_device).type):
                output = model(**inputs_dict)

                if isinstance(output, dict):
                    output = output.to_tuple()[0]

                input_tensor = inputs_dict[self.main_input_name]
                noise = torch.randn((input_tensor.shape[0],) + self.output_shape).to(torch_device)
                noise = cast_maybe_tensor_dtype(noise, torch.float32, compute_dtype)
                loss = torch.nn.functional.mse_loss(output, noise)

            loss.backward()

        test_fn(torch.float16, torch.float32)
        test_fn(torch.float8_e4m3fn, torch.float32)
        test_fn(torch.float8_e5m2, torch.float32)
        test_fn(torch.float8_e4m3fn, torch.bfloat16)

    def test_layerwise_casting_inference(self):
        from diffusers.hooks.layerwise_casting import DEFAULT_SKIP_MODULES_PATTERN, SUPPORTED_PYTORCH_LAYERS

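For reference, a hedged usage sketch of the API exercised by the new `test_layerwise_casting_training` above: weights are kept in a low-precision storage dtype and upcast to the compute dtype on the fly. The model class here is a stand-in chosen for illustration, not part of this diff; any diffusers model exposing `enable_layerwise_casting` should behave the same way.

```python
import torch

from diffusers import UNet2DModel  # stand-in model purely for illustration

device = "cuda" if torch.cuda.is_available() else "cpu"

model = UNet2DModel()  # default config, used here only to show the call pattern
model = model.to(device, dtype=torch.bfloat16)
model.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)
model.train()  # the new test additionally runs a forward/backward pass in this mode
```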
8 changes: 8 additions & 0 deletions tests/models/unets/test_models_unet_1d.py
@@ -60,6 +60,10 @@ def test_ema_training(self):
    def test_training(self):
        pass

    @unittest.skip("Test not supported.")
    def test_layerwise_casting_training(self):
        pass

    def test_determinism(self):
        super().test_determinism()

@@ -239,6 +243,10 @@ def test_ema_training(self):
    def test_training(self):
        pass

    @unittest.skip("Test not supported.")
    def test_layerwise_casting_training(self):
        pass

    def prepare_init_args_and_inputs_for_common(self):
        init_dict = {
            "in_channels": 14,
