@@ -43,6 +43,13 @@ def log(t, eps = 1e-20):
43
43
def sigmoid (t ):
44
44
return torch .where (t >= 0 , 1 / (1 + torch .exp (- t )), t .exp () / (1 + t .exp ()))
45
45
46
+ def gumbel_noise (t ):
47
+ noise = torch .zeros_like (t ).uniform_ (0 , 1 )
48
+ return - log (- log (noise ))
49
+
50
+ def gumbel_sample (t , temperature = 1. , dim = - 1 ):
51
+ return ((t / temperature ) + gumbel_noise (t )).argmax (dim = dim )
52
+
46
53
# gan losses
47
54
48
55
def hinge_discr_loss (fake , real ):
@@ -604,10 +611,8 @@ def generate(
604
611
logits = logits [:, - 1 , :]
605
612
606
613
filtered_logits = top_k (logits , thres = filter_thres )
607
- filtered_logits /= temperature
608
- filtered_logits -= torch .amax (filtered_logits , dim = 1 , keepdim = True )
609
- probs = F .softmax (filtered_logits , dim = - 1 )
610
- sample = torch .multinomial (probs , 1 )
614
+ sample = gumbel_sample (filtered_logits , temperature = temperature , dim = - 1 )
615
+ sample = rearrange (sample , 'b -> b 1' )
611
616
video_indices = torch .cat ((video_indices , sample ), dim = 1 )
612
617
613
618
codes = self .vae .codebook [video_indices ]
0 commit comments