@@ -24,6 +24,17 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+# decorators
+
+def eval_decorator(fn):
+    def inner(model, *args, **kwargs):
+        was_training = model.training
+        model.eval()
+        out = fn(model, *args, **kwargs)
+        model.train(was_training)
+        return out
+    return inner
+
 # tensor helper functions
 
 def log(t, eps = 1e-20):
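The new `eval_decorator` flips the model into eval mode for the duration of the wrapped call, then restores whatever training state it had. A minimal usage sketch, assuming a toy module (the class and method below are illustrative, not part of this commit):

    import torch
    from torch import nn

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.layer = nn.Linear(4, 4)

        @torch.no_grad()          # gradients off for sampling
        @eval_decorator           # eval mode on entry, prior mode restored on exit
        def sample(self, x):
            return self.layer(x)

    model = ToyModel()
    model.train()
    out = model.sample(torch.randn(1, 4))
    assert model.training         # training flag was restored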
@@ -129,6 +140,10 @@ def __init__(
         self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
         self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
 
+    @property
+    def codebook(self):
+        return self.vq.codebook
+
     def encode(self, fmap):
         for enc in self.encoders:
             fmap = enc(fmap)
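The `codebook` property forwards to the underlying vector quantizer so callers never reach into `self.vq` directly. A rough sketch of the lookup it enables, assuming the codebook is a `(num_codes, dim)` tensor (the vocab size and shapes here are placeholders):

    indices = torch.randint(0, 8192, (2, 64))  # (batch, num_tokens), placeholder vocab
    codes = vae.codebook[indices]              # fancy indexing -> (2, 64, dim)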
@@ -161,9 +176,11 @@ def forward(
 
         fmap = self.decode(fmap)
 
-        if not return_loss:
+        if not return_loss and not return_discr_loss:
             return fmap
 
+        assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
+
         if return_discr_loss:
             fmap.detach_()
             fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
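The two loss flags are now mutually exclusive, enforced by the XOR assert. A hedged sketch of the three valid call patterns (`vae` and `img` are assumed names):

    recon = vae(img)                                  # no flags: reconstruction only
    gen_loss = vae(img, return_loss = True)           # autoencoder / generator loss
    discr_loss = vae(img, return_discr_loss = True)   # discriminator loss, fmap detached
    # setting both flags at once would trip the assert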
@@ -332,16 +349,31 @@ def forward(self, x, mask = None):
         # more variables
 
         kernel_size = self.kernel_size
-        num_frames, fmap_size, _ = self.video_shape
+        fmap_size = self.video_shape[1]
+
+        bos_only = n == 1
+        tokens_per_frame = fmap_size ** 2
+
+        padding = 0 if bos_only else (tokens_per_frame - (n - 1) % tokens_per_frame)
+        num_frames = (n + padding) // tokens_per_frame
 
         # pad for last token in video
 
-        x = F.pad(x, (0, 0, 0, 1), value = 0.)
+        if padding > 0:
+            x = F.pad(x, (0, 0, 0, padding), value = 0.)
 
         # derive queries / keys / values
 
-        qkv = self.to_qkv(x).chunk(3, dim = -1)
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
+        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
+
+        # early return if <bos>
+
+        if bos_only:
+            return self.to_out(v)
+
+        # split out heads
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
 
         # scale queries
 
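The padding arithmetic rounds the `n - 1` non-<bos> tokens up to a whole number of frames, so the sparse 3DNA attention always operates over complete frames. A worked example, assuming an 8x8 feature map:

    fmap_size = 8
    tokens_per_frame = fmap_size ** 2                         # 64

    n = 130                                                   # <bos> + 129 video tokens
    padding = tokens_per_frame - (n - 1) % tokens_per_frame   # 64 - 1 = 63
    num_frames = (n + padding) // tokens_per_frame            # (130 + 63) // 64 = 3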
@@ -352,11 +384,6 @@ def forward(self, x, mask = None):
         q = q[:, 1:]
         bos_value = v[:, :1]
 
-        # prepare precomputed causal mask
-
-        causal_mask = self.causal_mask[:n]
-        causal_mask = repeat(causal_mask, 'i j -> b i j', b = b * h)
-
         # compute keys and values
 
         (k_bos, k), (v_bos, v) = map(lambda t: (t[:, :1], t[:, 1:]), (k, v))
@@ -376,6 +403,10 @@ def forward(self, x, mask = None):
 
         # causal mask
 
+        i, j = sim.shape[-2:]
+        causal_mask = self.causal_mask[:i, :j]
+        causal_mask = repeat(causal_mask, 'i j -> b i j', b = b * h)
+
         sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
 
         # attention
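Since the sequence is now padded by a variable amount, the causal mask can no longer be precomputed against `n` upfront; it is sliced to the attention map's actual dimensions at runtime. A sketch of the slicing idea only (the buffer contents below are stand-ins, not the real 3DNA mask construction):

    max_i, max_j = 256, 64                                      # assumed buffer extents
    causal_mask_buffer = torch.ones(max_i, max_j, dtype = torch.bool)
    i, j = 192, 27                                              # from sim.shape[-2:]
    causal_mask = causal_mask_buffer[:i, :j]                    # slice to fit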
@@ -450,6 +481,16 @@ def forward(
 
         return self.norm(x)
 
+# sampling helpers
+
+def top_k(logits, thres = 0.5):
+    num_logits = logits.shape[-1]
+    k = max(int((1 - thres) * num_logits), 1)
+    val, ind = torch.topk(logits, k)
+    probs = torch.full_like(logits, float('-inf'))
+    probs.scatter_(1, ind, val)
+    return probs
+
 # main class
 
 class NUWA(nn.Module):
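`top_k` keeps roughly the top `(1 - thres)` fraction of logits and floors the rest to -inf, so the later softmax assigns them zero probability. A quick check on a toy row:

    logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
    filtered = top_k(logits, thres = 0.5)       # k = max(int(0.5 * 4), 1) = 2
    # filtered == [[-inf, 3.0, 2.0, -inf]]
    probs = torch.softmax(filtered, dim = -1)   # mass only on the two kept logits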
@@ -495,9 +536,13 @@ def __init__(
 
         fmap_size = image_size // (2 ** vae_num_layers)
 
+        self.video_fmap_size = fmap_size
+        self.max_video_frames = max_video_frames
+        video_shape = (max_video_frames, fmap_size, fmap_size)
+
         self.video_pos_emb = AxialPositionalEmbedding(
             dim = dim,
-            axial_shape = (max_video_frames, fmap_size, fmap_size)
+            axial_shape = video_shape
         )
 
         self.video_transformer = Transformer(
@@ -511,11 +556,66 @@ def __init__(
             ff_dropout = ff_dropout,
             sparse_3dna_attn = True,
             sparse_3dna_kernel_size = sparse_3dna_kernel_size,
-            sparse_3dna_video_shape = (max_video_frames, fmap_size, fmap_size)
+            sparse_3dna_video_shape = video_shape
         )
 
         self.to_logits = nn.Linear(dim, num_image_tokens)
 
+    @torch.no_grad()
+    @eval_decorator
+    def generate(
+        self,
+        *,
+        text,
+        text_mask = None,
+        filter_thres = 0.9,
+        temperature = 1.
+    ):
+        batch, seq_len, device = *text.shape, text.device
+        assert seq_len <= self.text_max_seq_len, 'your input text has a greater length than what was designated on initialization'
+
+        tokens = self.text_embedding(text)
+        pos_emb = self.text_pos_embedding(torch.arange(seq_len, device = device))
+        tokens = tokens + rearrange(pos_emb, 'n d -> 1 n d')
+
+        text_embeds = self.text_transformer(
+            tokens,
+            mask = text_mask
+        )
+
+        bos = repeat(self.video_bos, 'd -> b 1 d', b = batch)
+
+        video_indices = torch.empty((batch, 0), device = device, dtype = torch.long)
+        total_video_tokens = self.video_fmap_size * self.video_fmap_size * self.max_video_frames
+
+        for _ in range(total_video_tokens):
+            frame_embeddings = self.image_embedding(video_indices)
+            frame_embeddings = self.video_pos_emb(frame_embeddings) + frame_embeddings
+            frame_embeddings = torch.cat((bos, frame_embeddings), dim = 1)
+
+            frame_embeddings = self.video_transformer(
+                frame_embeddings,
+                context = text_embeds,
+                context_mask = text_mask
+            )
+
+            logits = self.to_logits(frame_embeddings)
+            logits = logits[:, -1, :]
+
+            filtered_logits = top_k(logits, thres = filter_thres)
+            filtered_logits /= temperature
+            filtered_logits -= torch.amax(filtered_logits, dim = 1, keepdim = True)
+            probs = F.softmax(filtered_logits, dim = -1)
+            sample = torch.multinomial(probs, 1)
+            video_indices = torch.cat((video_indices, sample), dim = 1)
+
+        codes = self.vae.codebook[video_indices]
+        codes = rearrange(codes, 'b (f h w) d -> (b f) d h w', h = self.video_fmap_size, w = self.video_fmap_size)
+
+        image_reconstructions = self.vae.decode(codes)
+        video = rearrange(image_reconstructions, '(b f) d h w -> b f d h w', b = batch)
+        return video
+
     def forward(
         self,
         *,
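Putting it together, `generate` consumes tokenized text and autoregressively samples every video token before decoding through the VAE. A hedged usage sketch (the instance setup is omitted and the vocab size is a placeholder):

    # assume `nuwa` is a NUWA instance wired to a trained VAE
    text = torch.randint(0, 20000, (1, 256))    # (batch, seq_len), placeholder text vocab
    video = nuwa.generate(text = text, filter_thres = 0.9, temperature = 1.)
    # torch.no_grad() + eval_decorator make this safe to call mid-training
    # video: (batch, max_video_frames, channels, height, width)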