@@ -116,6 +116,7 @@
     "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
     "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
+    "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
 }

 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -174,6 +175,7 @@
     "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
     "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
+    "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
 }

 # Use to configure model sample size when original config is provided
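
Together, these two table entries register Lumina2 with the single-file loader: the `CHECKPOINT_KEY_NAMES` list is probed against a raw state dict (with and without the ComfyUI `model.diffusion_model.` prefix), and `DIFFUSERS_DEFAULT_PIPELINE_PATHS` supplies the default Hub repo for configs. A minimal sketch of the pairing, for illustration only (the real resolution happens in `infer_diffusers_model_type`, changed below):

# Illustrative lookup only; the library performs this resolution internally.
model_type = "lumina2"
detection_keys = CHECKPOINT_KEY_NAMES[model_type]  # keys probed in the state dict
default_repo = DIFFUSERS_DEFAULT_PIPELINE_PATHS[model_type]["pretrained_model_name_or_path"]
assert default_repo == "Alpha-VLLM/Lumina-Image-2.0"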
|
@@ -657,6 +659,9 @@ def infer_diffusers_model_type(checkpoint):
     ):
         model_type = "instruct-pix2pix"

+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
+        model_type = "lumina2"
+
     else:
         model_type = "v1"

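The new branch checks both spellings from `CHECKPOINT_KEY_NAMES["lumina2"]`, since the caption-embedder key may or may not carry the ComfyUI prefix. A hedged sketch of the behavior with hypothetical, stripped-down state dicts:

# Hypothetical minimal state dicts; real checkpoints contain many more tensors.
comfy_style = {"model.diffusion_model.cap_embedder.0.weight": ...}
plain_style = {"cap_embedder.0.weight": ...}
for ckpt in (comfy_style, plain_style):
    assert any(key in ckpt for key in CHECKPOINT_KEY_NAMES["lumina2"])  # -> "lumina2"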
@@ -2798,3 +2803,78 @@ def calculate_layers(keys, key_prefix):
     converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")

     return converted_state_dict
+
+
+def convert_lumina2_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+
+    # The original Lumina-Image-2 checkpoint has an extra norm parameter that is
+    # unused; we just remove it here.
+    checkpoint.pop("norm_final.weight", None)
+
+    # ComfyUI checkpoints add this prefix; strip it so the maps below apply.
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    LUMINA_KEY_MAP = {
+        "cap_embedder": "time_caption_embed.caption_embedder",
+        "t_embedder.mlp.0": "time_caption_embed.timestep_embedder.linear_1",
+        "t_embedder.mlp.2": "time_caption_embed.timestep_embedder.linear_2",
+        "attention": "attn",
+        ".out.": ".to_out.0.",
+        "k_norm": "norm_k",
+        "q_norm": "norm_q",
+        "w1": "linear_1",
+        "w2": "linear_2",
+        "w3": "linear_3",
+        "adaLN_modulation.1": "norm1.linear",
+    }
+    ATTENTION_NORM_MAP = {
+        "attention_norm1": "norm1.norm",
+        "attention_norm2": "norm2",
+    }
+    CONTEXT_REFINER_MAP = {
+        "context_refiner.0.attention_norm1": "context_refiner.0.norm1",
+        "context_refiner.0.attention_norm2": "context_refiner.0.norm2",
+        "context_refiner.1.attention_norm1": "context_refiner.1.norm1",
+        "context_refiner.1.attention_norm2": "context_refiner.1.norm2",
+    }
+    FINAL_LAYER_MAP = {
+        "final_layer.adaLN_modulation.1": "norm_out.linear_1",
+        "final_layer.linear": "norm_out.linear_2",
+    }
+
+    def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
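+        # The fused qkv weight stacks 2304 query channels followed by 768 key
+        # and 768 value channels, so split it in that order along dim 0.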
+        q_dim = 2304
+        k_dim = v_dim = 768
+
+        to_q, to_k, to_v = torch.split(tensor, [q_dim, k_dim, v_dim], dim=0)
+
+        return {
+            diffusers_key.replace("qkv", "to_q"): to_q,
+            diffusers_key.replace("qkv", "to_k"): to_k,
+            diffusers_key.replace("qkv", "to_v"): to_v,
+        }
+
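+    # Re-list the keys so renames from the ComfyUI prefix strip above are seen.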
+    for key in list(checkpoint.keys()):
+        diffusers_key = key
+        for k, v in CONTEXT_REFINER_MAP.items():
+            diffusers_key = diffusers_key.replace(k, v)
+        for k, v in FINAL_LAYER_MAP.items():
+            diffusers_key = diffusers_key.replace(k, v)
+        for k, v in ATTENTION_NORM_MAP.items():
+            diffusers_key = diffusers_key.replace(k, v)
+        for k, v in LUMINA_KEY_MAP.items():
+            diffusers_key = diffusers_key.replace(k, v)
+
+        if "qkv" in diffusers_key:
+            converted_state_dict.update(convert_lumina_attn_to_diffusers(checkpoint.pop(key), diffusers_key))
+        else:
+            converted_state_dict[diffusers_key] = checkpoint.pop(key)
+
+    return converted_state_dict
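
With the converter in place, a Lumina-Image-2.0 single-file checkpoint can be loaded directly. A hedged usage sketch, assuming this PR wires `convert_lumina2_to_diffusers` into the `from_single_file` support of `Lumina2Transformer2DModel` (the local filename below is hypothetical):

# Hedged usage sketch; "lumina2.safetensors" is a hypothetical local file.
from diffusers import Lumina2Transformer2DModel

transformer = Lumina2Transformer2DModel.from_single_file("lumina2.safetensors")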