@@ -242,97 +242,85 @@ def __init__(self, theta: int, axes_dim: List[int], axes_lens: List[int] = (300,
 
     def _precompute_freqs_cis(self, axes_dim: List[int], axes_lens: List[int], theta: int) -> List[torch.Tensor]:
         freqs_cis = []
-        # Use float32 for MPS compatibility
-        dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
+        freqs_dtype = torch.float32 if torch.backends.mps.is_available() else torch.float64
         for i, (d, e) in enumerate(zip(axes_dim, axes_lens)):
-            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=dtype)
+            emb = get_1d_rotary_pos_embed(d, e, theta=self.theta, freqs_dtype=freqs_dtype)
             freqs_cis.append(emb)
         return freqs_cis
 
     def _get_freqs_cis(self, ids: torch.Tensor) -> torch.Tensor:
+        device = ids.device
+        if ids.device.type == "mps":
+            ids = ids.to("cpu")
+
         result = []
         for i in range(len(self.axes_dim)):
             freqs = self.freqs_cis[i].to(ids.device)
             index = ids[:, :, i : i + 1].repeat(1, 1, freqs.shape[-1]).to(torch.int64)
             result.append(torch.gather(freqs.unsqueeze(0).repeat(index.shape[0], 1, 1), dim=1, index=index))
-        return torch.cat(result, dim=-1)
+        return torch.cat(result, dim=-1).to(device)
 
     def forward(self, hidden_states: torch.Tensor, attention_mask: torch.Tensor):
-        batch_size = len(hidden_states)
-        p_h = p_w = self.patch_size
-        device = hidden_states[0].device
+        batch_size, channels, height, width = hidden_states.shape
+        p = self.patch_size
+        post_patch_height, post_patch_width = height // p, width // p
+        image_seq_len = post_patch_height * post_patch_width
+        device = hidden_states.device
 
+        encoder_seq_len = attention_mask.shape[1]
         l_effective_cap_len = attention_mask.sum(dim=1).tolist()
-        # TODO: this should probably be refactored because all subtensors of hidden_states will be of same shape
-        img_sizes = [(img.size(1), img.size(2)) for img in hidden_states]
-        l_effective_img_len = [(H // p_h) * (W // p_w) for (H, W) in img_sizes]
-
-        max_seq_len = max((cap_len + img_len for cap_len, img_len in zip(l_effective_cap_len, l_effective_img_len)))
-        max_img_len = max(l_effective_img_len)
+        seq_lengths = [cap_seq_len + image_seq_len for cap_seq_len in l_effective_cap_len]
+        max_seq_len = max(seq_lengths)
 
+        # Create position IDs
         position_ids = torch.zeros(batch_size, max_seq_len, 3, dtype=torch.int32, device=device)
 
-        for i in range(batch_size):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            H, W = img_sizes[i]
-            H_tokens, W_tokens = H // p_h, W // p_w
-            assert H_tokens * W_tokens == img_len
+        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+            # add caption position ids
+            position_ids[i, :cap_seq_len, 0] = torch.arange(cap_seq_len, dtype=torch.int32, device=device)
+            position_ids[i, cap_seq_len:seq_len, 0] = cap_seq_len
 
-            position_ids[i, :cap_len, 0] = torch.arange(cap_len, dtype=torch.int32, device=device)
-            position_ids[i, cap_len : cap_len + img_len, 0] = cap_len
+            # add image position ids
             row_ids = (
-                torch.arange(H_tokens, dtype=torch.int32, device=device).view(-1, 1).repeat(1, W_tokens).flatten()
+                torch.arange(post_patch_height, dtype=torch.int32, device=device)
+                .view(-1, 1)
+                .repeat(1, post_patch_width)
+                .flatten()
             )
             col_ids = (
-                torch.arange(W_tokens, dtype=torch.int32, device=device).view(1, -1).repeat(H_tokens, 1).flatten()
+                torch.arange(post_patch_width, dtype=torch.int32, device=device)
+                .view(1, -1)
+                .repeat(post_patch_height, 1)
+                .flatten()
             )
-            position_ids[i, cap_len : cap_len + img_len, 1] = row_ids
-            position_ids[i, cap_len : cap_len + img_len, 2] = col_ids
+            position_ids[i, cap_seq_len:seq_len, 1] = row_ids
+            position_ids[i, cap_seq_len:seq_len, 2] = col_ids
 
+        # Get combined rotary embeddings
         freqs_cis = self._get_freqs_cis(position_ids)
 
-        cap_freqs_cis_shape = list(freqs_cis.shape)
-        cap_freqs_cis_shape[1] = attention_mask.shape[1]
-        cap_freqs_cis = torch.zeros(*cap_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        img_freqs_cis_shape = list(freqs_cis.shape)
-        img_freqs_cis_shape[1] = max_img_len
-        img_freqs_cis = torch.zeros(*img_freqs_cis_shape, device=device, dtype=freqs_cis.dtype)
-
-        for i in range(batch_size):
-            cap_len = l_effective_cap_len[i]
-            img_len = l_effective_img_len[i]
-            cap_freqs_cis[i, :cap_len] = freqs_cis[i, :cap_len]
-            img_freqs_cis[i, :img_len] = freqs_cis[i, cap_len : cap_len + img_len]
-
-        flat_hidden_states = []
-        for i in range(batch_size):
-            img = hidden_states[i]
-            C, H, W = img.size()
-            img = img.view(C, H // p_h, p_h, W // p_w, p_w).permute(1, 3, 2, 4, 0).flatten(2).flatten(0, 1)
-            flat_hidden_states.append(img)
-        hidden_states = flat_hidden_states
-        padded_img_embed = torch.zeros(
-            batch_size, max_img_len, hidden_states[0].shape[-1], device=device, dtype=hidden_states[0].dtype
+        # create separate rotary embeddings for captions and images
+        cap_freqs_cis = torch.zeros(
+            batch_size, encoder_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
         )
-        padded_img_mask = torch.zeros(batch_size, max_img_len, dtype=torch.bool, device=device)
-        for i in range(batch_size):
-            padded_img_embed[i, : l_effective_img_len[i]] = hidden_states[i]
-            padded_img_mask[i, : l_effective_img_len[i]] = True
-
-        return (
-            padded_img_embed,
-            padded_img_mask,
-            img_sizes,
-            l_effective_cap_len,
-            l_effective_img_len,
-            freqs_cis,
-            cap_freqs_cis,
-            img_freqs_cis,
-            max_seq_len,
+        img_freqs_cis = torch.zeros(
+            batch_size, image_seq_len, freqs_cis.shape[-1], device=device, dtype=freqs_cis.dtype
+        )
+
+        for i, (cap_seq_len, seq_len) in enumerate(zip(l_effective_cap_len, seq_lengths)):
+            cap_freqs_cis[i, :cap_seq_len] = freqs_cis[i, :cap_seq_len]
+            img_freqs_cis[i, :image_seq_len] = freqs_cis[i, cap_seq_len:seq_len]
+
+        # image patch embeddings
+        hidden_states = (
+            hidden_states.view(batch_size, channels, post_patch_height, p, post_patch_width, p)
+            .permute(0, 2, 4, 3, 5, 1)
+            .flatten(3)
+            .flatten(1, 2)
         )
 
+        return hidden_states, cap_freqs_cis, img_freqs_cis, freqs_cis, l_effective_cap_len, seq_lengths
+
 
 
 class Lumina2Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
@@ -472,75 +460,63 @@ def forward(
         hidden_states: torch.Tensor,
         timestep: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
-        attention_mask: torch.Tensor,
-        use_mask_in_transformer: bool = True,
+        encoder_attention_mask: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
-        batch_size = hidden_states.size(0)
-
         # 1. Condition, positional & patch embedding
+        batch_size, _, height, width = hidden_states.shape
+
         temb, encoder_hidden_states = self.time_caption_embed(hidden_states, timestep, encoder_hidden_states)
 
         (
             hidden_states,
-            hidden_mask,
-            hidden_sizes,
-            encoder_hidden_len,
-            hidden_len,
-            joint_rotary_emb,
-            encoder_rotary_emb,
-            hidden_rotary_emb,
-            max_seq_len,
-        ) = self.rope_embedder(hidden_states, attention_mask)
+            context_rotary_emb,
+            noise_rotary_emb,
+            rotary_emb,
+            encoder_seq_lengths,
+            seq_lengths,
+        ) = self.rope_embedder(hidden_states, encoder_attention_mask)
 
         hidden_states = self.x_embedder(hidden_states)
 
         # 2. Context & noise refinement
         for layer in self.context_refiner:
-            # NOTE: mask not used for performance
-            encoder_hidden_states = layer(
-                encoder_hidden_states, attention_mask if use_mask_in_transformer else None, encoder_rotary_emb
-            )
+            encoder_hidden_states = layer(encoder_hidden_states, encoder_attention_mask, context_rotary_emb)
 
         for layer in self.noise_refiner:
-            # NOTE: mask not used for performance
-            hidden_states = layer(
-                hidden_states, hidden_mask if use_mask_in_transformer else None, hidden_rotary_emb, temb
-            )
+            hidden_states = layer(hidden_states, None, noise_rotary_emb, temb)
+
+        # 3. Joint Transformer blocks
+        max_seq_len = max(seq_lengths)
+        use_mask = len(set(seq_lengths)) > 1
+
+        attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
+        joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
+        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
+            attention_mask[i, :seq_len] = True
+            joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
+            joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]
+
+        hidden_states = joint_hidden_states
 
-        # 3. Attention mask preparation
-        mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
-        padded_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, self.config.hidden_size)
-        for i in range(batch_size):
-            cap_len = encoder_hidden_len[i]
-            img_len = hidden_len[i]
-            mask[i, : cap_len + img_len] = True
-            padded_hidden_states[i, :cap_len] = encoder_hidden_states[i, :cap_len]
-            padded_hidden_states[i, cap_len : cap_len + img_len] = hidden_states[i, :img_len]
-        hidden_states = padded_hidden_states
-
-        # 4. Transformer blocks
         for layer in self.layers:
-            # NOTE: mask not used for performance
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states = self._gradient_checkpointing_func(
-                    layer, hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb
+                    layer, hidden_states, attention_mask if use_mask else None, rotary_emb, temb
                 )
             else:
-                hidden_states = layer(hidden_states, mask if use_mask_in_transformer else None, joint_rotary_emb, temb)
+                hidden_states = layer(hidden_states, attention_mask if use_mask else None, rotary_emb, temb)
 
-        # 5. Output norm & projection & unpatchify
+        # 4. Output norm & projection
         hidden_states = self.norm_out(hidden_states, temb)
 
-        height_tokens = width_tokens = self.config.patch_size
+        # 5. Unpatchify
+        p = self.config.patch_size
         output = []
-        for i in range(len(hidden_sizes)):
-            height, width = hidden_sizes[i]
-            begin = encoder_hidden_len[i]
-            end = begin + (height // height_tokens) * (width // width_tokens)
+        for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
             output.append(
-                hidden_states[i][begin:end]
-                .view(height // height_tokens, width // width_tokens, height_tokens, width_tokens, self.out_channels)
+                hidden_states[i][encoder_seq_len:seq_len]
+                .view(height // p, width // p, p, p, self.out_channels)
                 .permute(4, 0, 2, 1, 3)
                 .flatten(3, 4)
                 .flatten(1, 2)
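The joint-sequence packing introduced in step 3 of `forward` can also be exercised in isolation. The sketch below is not part of the diff; the caption lengths, image sequence length, and hidden size are invented for illustration, standing in for what the text-encoder mask and the rope embedder would produce.

```python
import torch

# Illustrative numbers only: two prompts with different effective caption
# lengths, a 4-token image sequence, and a toy hidden size of 8.
encoder_seq_lengths = [3, 5]       # per-sample effective caption lengths
image_seq_len = 4
hidden_size = 8
seq_lengths = [cap_len + image_seq_len for cap_len in encoder_seq_lengths]

batch_size = len(seq_lengths)
max_seq_len = max(seq_lengths)
use_mask = len(set(seq_lengths)) > 1  # mask only needed when lengths differ

encoder_hidden_states = torch.randn(batch_size, max(encoder_seq_lengths), hidden_size)  # padded captions
hidden_states = torch.randn(batch_size, image_seq_len, hidden_size)                     # image tokens

# Pack caption tokens followed by image tokens into one padded joint sequence.
attention_mask = hidden_states.new_zeros(batch_size, max_seq_len, dtype=torch.bool)
joint_hidden_states = hidden_states.new_zeros(batch_size, max_seq_len, hidden_size)
for i, (encoder_seq_len, seq_len) in enumerate(zip(encoder_seq_lengths, seq_lengths)):
    attention_mask[i, :seq_len] = True
    joint_hidden_states[i, :encoder_seq_len] = encoder_hidden_states[i, :encoder_seq_len]
    joint_hidden_states[i, encoder_seq_len:seq_len] = hidden_states[i]

print(attention_mask.int())
# tensor([[1, 1, 1, 1, 1, 1, 1, 0, 0],
#         [1, 1, 1, 1, 1, 1, 1, 1, 1]])
print(use_mask)  # True here, so the joint blocks would receive the mask
```

This illustrates the design choice behind `use_mask`: when every sample has the same joint length (the common single-prompt, single-resolution case), the padded positions do not exist and the mask can be dropped, which avoids the cost of masked attention that the old `use_mask_in_transformer` flag exposed manually.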