
Commit 28f48f4

[Single File] Add Single File support for Lumina Image 2.0 Transformer (#10781)
* update
* update
1 parent 067eab1 commit 28f48f4

File tree: 5 files changed (+208, -1 lines changed)

docs/source/en/api/pipelines/lumina2.md (+50)

@@ -26,6 +26,56 @@ Make sure to check out the Schedulers [guide](../../using-diffusers/schedulers)
</Tip>

## Using Single File loading with Lumina Image 2.0

Single file loading for Lumina Image 2.0 is available for the `Lumina2Transformer2DModel`.

```python
import torch
from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline

ckpt_path = "https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/blob/main/consolidated.00-of-01.pth"
transformer = Lumina2Transformer2DModel.from_single_file(
    ckpt_path, torch_dtype=torch.bfloat16
)

pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
image = pipe(
    "a cat holding a sign that says hello",
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("lumina-single-file.png")
```

## Using GGUF Quantized Checkpoints with Lumina Image 2.0

GGUF quantized checkpoints for the `Lumina2Transformer2DModel` can be loaded via `from_single_file` with the `GGUFQuantizationConfig`.

```python
import torch
from diffusers import Lumina2Transformer2DModel, Lumina2Text2ImgPipeline, GGUFQuantizationConfig

ckpt_path = "https://huggingface.co/calcuis/lumina-gguf/blob/main/lumina2-q4_0.gguf"
transformer = Lumina2Transformer2DModel.from_single_file(
    ckpt_path,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)

pipe = Lumina2Text2ImgPipeline.from_pretrained(
    "Alpha-VLLM/Lumina-Image-2.0", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()
image = pipe(
    "a cat holding a sign that says hello",
    generator=torch.Generator("cpu").manual_seed(0),
).images[0]
image.save("lumina-gguf.png")
```

## Lumina2Text2ImgPipeline

[[autodoc]] Lumina2Text2ImgPipeline

src/diffusers/loaders/single_file_model.py (+5)

```diff
@@ -34,6 +34,7 @@
     convert_ldm_vae_checkpoint,
     convert_ltx_transformer_checkpoint_to_diffusers,
     convert_ltx_vae_checkpoint_to_diffusers,
+    convert_lumina2_to_diffusers,
     convert_mochi_transformer_checkpoint_to_diffusers,
     convert_sd3_transformer_checkpoint_to_diffusers,
     convert_stable_cascade_unet_single_file_to_diffusers,
@@ -111,6 +112,10 @@
         "checkpoint_mapping_fn": convert_auraflow_transformer_checkpoint_to_diffusers,
         "default_subfolder": "transformer",
     },
+    "Lumina2Transformer2DModel": {
+        "checkpoint_mapping_fn": convert_lumina2_to_diffusers,
+        "default_subfolder": "transformer",
+    },
 }
```
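This entry registers `convert_lumina2_to_diffusers` as the converter that `from_single_file` applies to original Lumina checkpoints before instantiating the model. A minimal sketch of how such a registry entry can be consumed is shown below; the helper name and signature are illustrative assumptions, not an excerpt from `single_file_model.py`.

```python
# Illustrative sketch only: how a checkpoint-mapping registry like the one above
# can be consumed. The helper and its signature are assumptions for illustration.
def resolve_single_file_converter(model_cls_name: str, registry: dict):
    entry = registry.get(model_cls_name)
    if entry is None:
        raise ValueError(f"Single-file loading is not supported for {model_cls_name}.")
    # The registered function maps the original state dict to diffusers key names;
    # the default subfolder tells the loader where the model config lives in the repo.
    return entry["checkpoint_mapping_fn"], entry.get("default_subfolder")
```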

src/diffusers/loaders/single_file_utils.py (+77)

```diff
@@ -116,6 +116,7 @@
     "mochi-1-preview": ["model.diffusion_model.blocks.0.attn.qkv_x.weight", "blocks.0.attn.qkv_x.weight"],
     "hunyuan-video": "txt_in.individual_token_refiner.blocks.0.adaLN_modulation.1.bias",
     "instruct-pix2pix": "model.diffusion_model.input_blocks.0.0.weight",
+    "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
 }

 DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
@@ -174,6 +175,7 @@
     "mochi-1-preview": {"pretrained_model_name_or_path": "genmo/mochi-1-preview"},
     "hunyuan-video": {"pretrained_model_name_or_path": "hunyuanvideo-community/HunyuanVideo"},
     "instruct-pix2pix": {"pretrained_model_name_or_path": "timbrooks/instruct-pix2pix"},
+    "lumina2": {"pretrained_model_name_or_path": "Alpha-VLLM/Lumina-Image-2.0"},
 }

 # Use to configure model sample size when original config is provided
@@ -657,6 +659,9 @@ def infer_diffusers_model_type(checkpoint):
     ):
         model_type = "instruct-pix2pix"

+    elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["lumina2"]):
+        model_type = "lumina2"
+
     else:
         model_type = "v1"
```
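The branch added to `infer_diffusers_model_type` tags a checkpoint as `lumina2` when either the ComfyUI-prefixed or the bare `cap_embedder` weight key is present. A self-contained sketch of that check, using a local copy of the relevant `CHECKPOINT_KEY_NAMES` entry rather than an import from diffusers:

```python
# Stand-alone sketch of the lumina2 detection rule; CHECKPOINT_KEY_NAMES here is a
# local copy of the relevant entry, not the module-level dict from diffusers.
CHECKPOINT_KEY_NAMES = {
    "lumina2": ["model.diffusion_model.cap_embedder.0.weight", "cap_embedder.0.weight"],
}

def is_lumina2_checkpoint(state_dict: dict) -> bool:
    # True if either the ComfyUI-prefixed or the bare cap_embedder key is present.
    return any(key in state_dict for key in CHECKPOINT_KEY_NAMES["lumina2"])

print(is_lumina2_checkpoint({"cap_embedder.0.weight": None}))  # True
print(is_lumina2_checkpoint({"some.other.key": None}))         # False
```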

```diff
@@ -2798,3 +2803,75 @@ def calculate_layers(keys, key_prefix):
     converted_state_dict["pos_embed.proj.bias"] = checkpoint.pop("init_x_linear.bias")

     return converted_state_dict
+
+
+def convert_lumina2_to_diffusers(checkpoint, **kwargs):
+    converted_state_dict = {}
+
+    # Original Lumina-Image-2 has an extra norm parameter that is unused
+    # We just remove it here
+    checkpoint.pop("norm_final.weight", None)
+
+    # Comfy checkpoints add this prefix
+    keys = list(checkpoint.keys())
+    for k in keys:
+        if "model.diffusion_model." in k:
+            checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
+
+    LUMINA_KEY_MAP = {
+        "cap_embedder": "time_caption_embed.caption_embedder",
+        "t_embedder.mlp.0": "time_caption_embed.timestep_embedder.linear_1",
+        "t_embedder.mlp.2": "time_caption_embed.timestep_embedder.linear_2",
+        "attention": "attn",
+        ".out.": ".to_out.0.",
+        "k_norm": "norm_k",
+        "q_norm": "norm_q",
+        "w1": "linear_1",
+        "w2": "linear_2",
+        "w3": "linear_3",
+        "adaLN_modulation.1": "norm1.linear",
+    }
+    ATTENTION_NORM_MAP = {
+        "attention_norm1": "norm1.norm",
+        "attention_norm2": "norm2",
+    }
+    CONTEXT_REFINER_MAP = {
+        "context_refiner.0.attention_norm1": "context_refiner.0.norm1",
+        "context_refiner.0.attention_norm2": "context_refiner.0.norm2",
+        "context_refiner.1.attention_norm1": "context_refiner.1.norm1",
+        "context_refiner.1.attention_norm2": "context_refiner.1.norm2",
+    }
+    FINAL_LAYER_MAP = {
+        "final_layer.adaLN_modulation.1": "norm_out.linear_1",
+        "final_layer.linear": "norm_out.linear_2",
+    }
+
+    def convert_lumina_attn_to_diffusers(tensor, diffusers_key):
+        q_dim = 2304
+        k_dim = v_dim = 768
+
+        to_q, to_k, to_v = torch.split(tensor, [q_dim, k_dim, v_dim], dim=0)
+
+        return {
+            diffusers_key.replace("qkv", "to_q"): to_q,
+            diffusers_key.replace("qkv", "to_k"): to_k,
+            diffusers_key.replace("qkv", "to_v"): to_v,
+        }
+
+    for key in keys:
+        diffusers_key = key
+        for k, v in CONTEXT_REFINER_MAP.items():
+            diffusers_key = diffusers_key.replace(k, v)
+        for k, v in FINAL_LAYER_MAP.items():
+            diffusers_key = diffusers_key.replace(k, v)
+        for k, v in ATTENTION_NORM_MAP.items():
+            diffusers_key = diffusers_key.replace(k, v)
+        for k, v in LUMINA_KEY_MAP.items():
+            diffusers_key = diffusers_key.replace(k, v)
+
+        if "qkv" in diffusers_key:
+            converted_state_dict.update(convert_lumina_attn_to_diffusers(checkpoint.pop(key), diffusers_key))
+        else:
+            converted_state_dict[diffusers_key] = checkpoint.pop(key)
+
+    return converted_state_dict
```
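The inner `convert_lumina_attn_to_diffusers` helper splits each fused `qkv` projection row-wise into separate `to_q`, `to_k`, and `to_v` tensors. Below is a standalone sketch of that split using the hardcoded output dimensions from the helper (2304 for q, 768 each for k and v); the 2304 input width is only an assumed hidden size for illustration.

```python
import torch

# Fake fused projection weight: rows are the stacked [q | k | v] output dimensions.
fused_qkv = torch.randn(2304 + 768 + 768, 2304)

# Row-wise split, mirroring torch.split(tensor, [q_dim, k_dim, v_dim], dim=0) above.
to_q, to_k, to_v = torch.split(fused_qkv, [2304, 768, 768], dim=0)
print(to_q.shape)  # torch.Size([2304, 2304])
print(to_k.shape)  # torch.Size([768, 2304])
print(to_v.shape)  # torch.Size([768, 2304])
```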

src/diffusers/models/transformers/transformer_lumina2.py (+2, -1)

```diff
@@ -21,6 +21,7 @@

 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
+from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import logging
 from ..attention import LuminaFeedForward
 from ..attention_processor import Attention
@@ -333,7 +334,7 @@ def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
     )


-class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
     Lumina2NextDiT: Diffusion model with a Transformer backbone.
```
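Adding `FromOriginalModelMixin` to the class bases is what exposes `from_single_file` on `Lumina2Transformer2DModel`; usage then matches the docs example added in this commit:

```python
import torch
from diffusers import Lumina2Transformer2DModel

# Same entry point as in the updated docs; accepts Hub URLs or local checkpoint paths.
transformer = Lumina2Transformer2DModel.from_single_file(
    "https://huggingface.co/Alpha-VLLM/Lumina-Image-2.0/blob/main/consolidated.00-of-01.pth",
    torch_dtype=torch.bfloat16,
)
```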
New file (+74)

@@ -0,0 +1,74 @@

```python
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import gc
import unittest

import torch

from diffusers import (
    Lumina2Transformer2DModel,
)
from diffusers.utils.testing_utils import (
    backend_empty_cache,
    enable_full_determinism,
    require_torch_accelerator,
    torch_device,
)


enable_full_determinism()


@require_torch_accelerator
class Lumina2Transformer2DModelSingleFileTests(unittest.TestCase):
    model_class = Lumina2Transformer2DModel
    ckpt_path = "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
    alternate_keys_ckpt_paths = [
        "https://huggingface.co/Comfy-Org/Lumina_Image_2.0_Repackaged/blob/main/split_files/diffusion_models/lumina_2_model_bf16.safetensors"
    ]

    repo_id = "Alpha-VLLM/Lumina-Image-2.0"

    def setUp(self):
        super().setUp()
        gc.collect()
        backend_empty_cache(torch_device)

    def tearDown(self):
        super().tearDown()
        gc.collect()
        backend_empty_cache(torch_device)

    def test_single_file_components(self):
        model = self.model_class.from_pretrained(self.repo_id, subfolder="transformer")
        model_single_file = self.model_class.from_single_file(self.ckpt_path)

        PARAMS_TO_IGNORE = ["torch_dtype", "_name_or_path", "_use_default_values", "_diffusers_version"]
        for param_name, param_value in model_single_file.config.items():
            if param_name in PARAMS_TO_IGNORE:
                continue
            assert (
                model.config[param_name] == param_value
            ), f"{param_name} differs between single file loading and pretrained loading"

    def test_checkpoint_loading(self):
        for ckpt_path in self.alternate_keys_ckpt_paths:
            torch.cuda.empty_cache()
            model = self.model_class.from_single_file(ckpt_path)

            del model
            gc.collect()
            torch.cuda.empty_cache()
```
