
Commit 69f919d

follow-up refactor on lumina2 (#10776)

* up

1 parent a6b843a · commit 69f919d

File tree

3 files changed: +86 −123 lines


Diff for: src/diffusers/models/transformers/transformer_lumina2.py

+83 −107
@@ -242,97 +242,85 @@ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300,
 
     def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
         freqs_cis = []
-        # Use float32 for MPS compatibility
-        dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
         for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
-            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=dtype)
+            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
             freqs_cis.append(emb)
         return freqs_cis
 
     def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
+        device = ids.device
+        if ids.device.type == "mps":
+            ids = ids.to("cpu")
+
         result = []
         for i in range(len(self.axes_dim)):
             freqs = self.freqs_cis[i].to(ids.device)
             index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
             result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
-        return torch.cat(result, dim=-1)
+        return torch.cat(result, dim=-1).to(device)
 
     def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
-        batch_size = len(hidden_states)
-        p_h = p_w = self.patch_size
-        device = hidden_states[0].device
+        batch_size, channels, height, width = hidden_states.shape
+        p = self.patch_size
+        post_patch_height, post_patch_width = height // p, width // p
+        image_seq_len = post_patch_height * post_patch_width
+        device = hidden_states.device
 
+        encoder_seq_len = attention_mask.shape[1]
         l_effective_cap_len = attention_mask.sum(dim=1).tolist()
-        # TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape
-        img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
-        l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes]
-
-        max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)))
-        max_img_len = max(l_effective_img_len)
+        seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
+        max_seq_len = max(seq_lengths)
 
+        # Create position IDs
         position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
 
-        for i in range(batch_size):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            H, W = img_sizes[i]
-            H_tokens, W_tokens = H // p_h, W // p_w
-            assert H_tokens * W_tokens == img_len
+        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+            # add caption position ids
+            position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
+            position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len
 
-            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
-            position_ids[i, cap_len : cap_len + img_len, 0] = cap_len
+            # add image position ids
             row_ids = (
-                torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
+                torch.arange(post_patch_height, dtype=torch.int32, device=device)
+                .view(-1, 1)
+                .repeat(1, post_patch_width)
+                .flatten()
             )
             col_ids = (
-                torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
+                torch.arange(post_patch_width, dtype=torch.int32, device=device)
+                .view(1, -1)
+                .repeat(post_patch_height, 1)
+                .flatten()
             )
-            position_ids[i, cap_len : cap_len + img_len, 1] = row_ids
-            position_ids[i, cap_len : cap_len + img_len, 2] = col_ids
+            position_ids[i, cap_seq_len:seq_len, 1] = row_ids
+            position_ids[i, cap_seq_len:seq_len, 2] = col_ids
 
+        # Get combined rotary embeddings
         freqs_cis = self._get_freqs_cis(position_ids)
 
-        cap_freqs_cis_shape = list(freqs_cis.shape)
-        cap_freqs_cis_shape[1] = attention_mask.shape[1]
-        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        img_freqs_cis_shape = list(freqs_cis.shape)
-        img_freqs_cis_shape[1] = max_img_len
-        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        for i in range(batch_size):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
-            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len]
-
-        flat_hidden_states = []
-        for i in range(batch_size):
-            img = hidden_states[i]
-            C, H, W = img.size()
-            img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
-            flat_hidden_states.append(img)
-        hidden_states = flat_hidden_states
-        padded_img_embed = torch.zeros(
-            batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype
+        # create separate rotary embeddings for captions and images
+        cap_freqs_cis = torch.zeros(
+            batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
         )
-        padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
-        for i in range(batch_size):
-            padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i]
-            padded_img_mask[i, : l_effective_img_len[i]] = True
-
-        return (
-            padded_img_embed,
-            padded_img_mask,
-            img_sizes,
-            l_effective_cap_len,
-            l_effective_img_len,
-            freqs_cis,
-            cap_freqs_cis,
-            img_freqs_cis,
-            max_seq_len,
+        img_freqs_cis = torch.zeros(
+            batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+        )
+
+        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
+            img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]
+
+        # image patch embeddings
+        hidden_states = (
+            hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
+            .permute(0, 2, 4, 3, 5, 1)
+            .flatten(3)
+            .flatten(1, 2)
         )
 
+        return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
+
 
 class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
@@ -472,75 +460,63 @@ def forward(
         hidden_states: torch.Tensor,
         timestep: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        use_mask_in_transformer: bool = True,
+        encoder_attention_mask: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
-        batch_size = hidden_states.size(0)
-
         # 1. Condition, positional & patch embedding
+        batch_size, _, height, width = hidden_states.shape
+
         temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
 
         (
             hidden_states,
-            hidden_mask,
-            hidden_sizes,
-            encoder_hidden_len,
-            hidden_len,
-            joint_rotary_emb,
-            encoder_rotary_emb,
-            hidden_rotary_emb,
-            max_seq_len,
-        ) = self.rope_embedder(hidden_states, attention_mask)
+            context_rotary_emb,
+            noise_rotary_emb,
+            rotary_emb,
+            encoder_seq_lengths,
+            seq_lengths,
+        ) = self.rope_embedder(hidden_states, encoder_attention_mask)
 
         hidden_states = self.x_embedder(hidden_states)
 
         # 2. Context & noise refinement
         for layer in self.context_refiner:
-            # NOTE: mask not used for performance
-            encoder_hidden_states = layer(
-                encoder_hidden_states, attention_mask if use_mask_in_transformer else None, encoder_rotary_emb
-            )
+            encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)
 
         for layer in self.noise_refiner:
-            # NOTE: mask not used for performance
-            hidden_states = layer(
-                hidden_states, hidden_mask if use_mask_in_transformer else None, hidden_rotary_emb, temb
-            )
+            hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)
+
+        # 3. Joint Transformer blocks
+        max_seq_len = max(seq_lengths)
+        use_mask = len(set(seq_lengths)) > 1
+
+        attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
+        joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
+        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+            attention_mask[i, :seq_len] = True
+            joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
+            joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]
+
+        hidden_states = joint_hidden_states
 
-        # 3. Attention mask preparation
-        mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
-        padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
-        for i in range(batch_size):
-            cap_len = encoder_hidden_len[i]
-            img_len = hidden_len[i]
-            mask[i, : cap_len + img_len] = True
-            padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len]
-            padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len]
-        hidden_states = padded_hidden_states
-
-        # 4. Transformer blocks
         for layer in self.layers:
-            # NOTE: mask not used for performance
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = self._gradient_checkpointing_func(
-                    layer, hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb
+                    layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb
                 )
             else:
-                hidden_states = layer(hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb)
+                hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb)
 
-        # 5. Output norm & projection & unpatchify
+        # 4. Output norm & projection
         hidden_states = self.norm_out(hidden_states, temb)
 
-        height_tokens = width_tokens = self.config.patch_size
+        # 5. Unpatchify
+        p = self.config.patch_size
         output = []
-        for i in range(len(hidden_sizes)):
-            height, width = hidden_sizes[i]
-            begin = encoder_hidden_len[i]
-            end = begin + (height // height_tokens) * (width // width_tokens)
+        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
             output.append(
-                hidden_states[i][begin:end]
-                .view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels)
+                hidden_states[i][encoder_seq_len:seq_len]
+                .view(height // p, width // p, p, p, self.out_channels)
                 .permute(4, 0, 2, 1, 3)
                 .flatten(3, 4)
                 .flatten(1, 2)
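The forward pass now packs captions and image tokens into a single padded joint sequence and unpatchifies with the same patch layout the RoPE embedder produced. A standalone sketch checking that the step-5 unpatchify layout is the exact inverse of the patchify layout; toy sizes are assumptions, and the transformer in between is deliberately skipped:

import torch

batch_size, channels, height, width, p = 1, 4, 8, 8, 2  # assumed toy sizes
images = torch.randn(batch_size, channels, height, width)

# patchify as in Lumina2RotaryPosEmbed.forward: (B, C, H, W) -> (B, num_patches, p * p * C)
tokens = (
    images.view(batch_size, channels, height // p, p, width // p, p)
    .permute(0, 2, 4, 3, 5, 1)
    .flatten(3)
    .flatten(1, 2)
)

# unpatchify one sample as in step 5: (num_patches, p * p * C) -> (C, H, W)
restored = (
    tokens[0]
    .view(height // p, width // p, p, p, channels)
    .permute(4, 0, 2, 1, 3)
    .flatten(3, 4)
    .flatten(1, 2)
)
assert torch.equal(restored, images[0])  # the layout round-trips exactly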

Diff for: src/diffusers/pipelines/lumina2/pipeline_lumina2.py

+2 −15
@@ -24,8 +24,6 @@
 from ...models.transformers.transformer_lumina2 import Lumina2Transformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
-    is_bs4_available,
-    is_ftfy_available,
     is_torch_xla_available,
     logging,
     replace_example_docstring,
@@ -44,12 +42,6 @@
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-if is_bs4_available():
-    pass
-
-if is_ftfy_available():
-    pass
-
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
@@ -527,7 +519,6 @@ def __call__(
         system_prompt: Optional[str] = None,
         cfg_trunc_ratio: float = 1.0,
         cfg_normalization: bool = True,
-        use_mask_in_transformer: bool = True,
         max_sequence_length: int = 256,
     ) -> Union[ImagePipelineOutput, Tuple]:
         """
@@ -599,8 +590,6 @@
                 The ratio of the timestep interval to apply normalization-based guidance scale.
             cfg_normalization (`bool`, *optional*, defaults to `True`):
                 Whether to apply normalization-based guidance scale.
-            use_mask_in_transformer (`bool`, *optional*, defaults to `True`):
-                Whether to use attention mask in `Lumina2Transformer2DModel`. Set `False` for performance gain.
             max_sequence_length (`int`, defaults to `256`):
                 Maximum sequence length to use with the `prompt`.
 
@@ -706,8 +695,7 @@
                     hidden_states=latents,
                     timestep=current_timestep,
                     encoder_hidden_states=prompt_embeds,
-                    attention_mask=prompt_attention_mask,
-                    use_mask_in_transformer=use_mask_in_transformer,
+                    encoder_attention_mask=prompt_attention_mask,
                     return_dict=False,
                 )[0]
 
@@ -717,8 +705,7 @@
                         hidden_states=latents,
                         timestep=current_timestep,
                         encoder_hidden_states=negative_prompt_embeds,
-                        attention_mask=negative_prompt_attention_mask,
-                        use_mask_in_transformer=use_mask_in_transformer,
+                        encoder_attention_mask=negative_prompt_attention_mask,
                         return_dict=False,
                     )[0]
                     noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_cond - noise_pred_uncond)
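With `use_mask_in_transformer` removed from the pipeline, the caller no longer decides whether masking is worth the cost; the transformer checks whether the padded joint sequences in the batch actually differ in length and skips the mask when they do not. A small illustration of that check, using made-up lengths:

# assumed values for illustration only
encoder_seq_lengths = [12, 12]  # effective caption lengths from the prompt attention mask
image_seq_len = 64              # (height // p) * (width // p), identical for every sample
seq_lengths = [cap_len + image_seq_len for cap_len in encoder_seq_lengths]

use_mask = len(set(seq_lengths)) > 1
print(use_mask)  # False: all joint sequences share a length, so no attention mask is passed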

Diff for: tests/models/transformers/test_models_transformer_lumina2.py

+1 −1
@@ -51,7 +51,7 @@ def dummy_input(self):
             "hidden_states": hidden_states,
             "encoder_hidden_states": encoder_hidden_states,
             "timestep": timestep,
-            "attention_mask": attention_mask,
+            "encoder_attention_mask": attention_mask,
         }
 
     @property
