Skip to content

Commit eb2c77e

Browse files
committedJan 3, 2022
avoid softmax and torch multinomial
1 parent fe227a4 commit eb2c77e

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed
 

‎nuwa_pytorch/nuwa_pytorch.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,13 @@ def log(t, eps = 1e-20):
4343
def sigmoid(t):
4444
return torch.where(t >= 0, 1 / (1 + torch.exp(-t)), t.exp() / (1 + t.exp()))
4545

46+
def gumbel_noise(t):
47+
noise = torch.zeros_like(t).uniform_(0, 1)
48+
return -log(-log(noise))
49+
50+
def gumbel_sample(t, temperature = 1., dim = -1):
51+
return ((t / temperature) + gumbel_noise(t)).argmax(dim = dim)
52+
4653
# gan losses
4754

4855
def hinge_discr_loss(fake, real):
@@ -604,10 +611,8 @@ def generate(
604611
logits = logits[:, -1, :]
605612

606613
filtered_logits = top_k(logits, thres = filter_thres)
607-
filtered_logits /= temperature
608-
filtered_logits -= torch.amax(filtered_logits, dim = 1, keepdim = True)
609-
probs = F.softmax(filtered_logits, dim = -1)
610-
sample = torch.multinomial(probs, 1)
614+
sample = gumbel_sample(filtered_logits, temperature = temperature, dim = -1)
615+
sample = rearrange(sample, 'b -> b 1')
611616
video_indices = torch.cat((video_indices, sample), dim = 1)
612617

613618
codes = self.vae.codebook[video_indices]

‎setup.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
setup(
44
name = 'nuwa-pytorch',
55
packages = find_packages(exclude=[]),
6-
version = '0.0.2',
6+
version = '0.0.3',
77
license='MIT',
88
description = 'NÜWA - Pytorch',
99
author = 'Phil Wang',

0 commit comments

Comments
 (0)
Please sign in to comment.