Commit 7e90aea

complete video generation code

committed Jan 3, 2022
1 parent 6dc8b21 commit 7e90aea

File tree

2 files changed: +137 −18 lines changed
README.md

+25 −6
@@ -32,9 +32,16 @@ vae = VQGanVAE(
 
 imgs = torch.randn(10, 3, 256, 256)
 
+# alternate learning for autoencoder ...
+
 loss = vae(imgs, return_loss = True)
 loss.backward()
 
+# and the discriminator ...
+
+discr_loss = vae(imgs, return_discr_loss = True)
+discr_loss.backward()
+
 # do above for many steps
 ```
 
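The hunk above has the same VQGanVAE alternate between an autoencoder loss and a discriminator loss on the same images. A minimal sketch of how those two calls might be wired to separate optimizers in a training loop (the optimizer choice, learning rate, and dummy dataloader here are assumptions for illustration, not part of this commit):

```python
import torch
from nuwa_pytorch import VQGanVAE

vae = VQGanVAE(dim = 512, num_layers = 4, image_size = 256)

# stand-in for a real image dataloader yielding (b, 3, 256, 256) tensors
dataloader = [torch.randn(4, 3, 256, 256) for _ in range(4)]

# one optimizer per objective is an assumption; vae.discr is the discriminator referenced in the diff
# note: a real setup would typically exclude the discriminator's parameters from the autoencoder optimizer
opt_vae = torch.optim.Adam(vae.parameters(), lr = 3e-4)
opt_discr = torch.optim.Adam(vae.discr.parameters(), lr = 3e-4)

for step, imgs in enumerate(dataloader):
    if step % 2 == 0:
        # autoencoder / generator update
        loss = vae(imgs, return_loss = True)
        opt_vae.zero_grad()
        loss.backward()
        opt_vae.step()
    else:
        # discriminator update on the next batch
        discr_loss = vae(imgs, return_discr_loss = True)
        opt_discr.zero_grad()
        discr_loss.backward()
        opt_discr.step()
```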
@@ -44,22 +51,29 @@ Then, with your learned VAE
 import torch
 from nuwa_pytorch import NUWA, VQGanVAE
 
+# autoencoder
+
 vae = VQGanVAE(
     dim = 512,
-    num_layers = 4
+    num_layers = 4,
+    image_size = 256
 )
 
+# NUWA transformer
+
 nuwa = NUWA(
     vae = vae,
     dim = 512,
     max_video_frames = 5,
     text_num_tokens = 20000,
     image_size = 256
-)
+).cuda()
+
+# data
 
-text = torch.randint(0, 20000, (1, 256))
-mask = torch.ones(1, 256).bool()
-video = torch.randn(1, 5, 3, 256, 256)
+text = torch.randint(0, 20000, (1, 256)).cuda()
+mask = torch.ones(1, 256).bool().cuda()
+video = torch.randn(1, 5, 3, 256, 256).cuda()
 
 loss = nuwa(
     text = text,
@@ -71,17 +85,22 @@ loss = nuwa(
 loss.backward()
 
 # do above with as much data as possible
+
+# then you can generate a video from text
+
+video = nuwa.generate(text = text, text_mask = mask) # (1, 5, 3, 256, 256)
+
 ```
 
 ## Todo
 
 - [x] complete 3dna causal attention in decoder
+- [x] write up easy generation functions
 - [ ] flesh out VAE resnet blocks, offer some choices
 - [ ] make sure GAN portion of VQGan is correct, reread paper
 - [ ] offer new vqvae improvements (orthogonal reg and smaller codebook dimensions)
 - [ ] offer vqvae training script
 - [ ] take care of audio transformer and cross modality attention
-- [ ] write up easy generation functions
 - [ ] segmentation mask encoder, make sure embeddings can undergo 3dna attention with decoder during cross attention
 - [ ] investigate custom attention layouts in microsoft deepspeed sparse attention (using triton)
 
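`nuwa.generate` returns a `(batch, frames, channels, height, width)` tensor, `(1, 5, 3, 256, 256)` in the README settings. If you want to inspect the sampled frames, one hedged way is to dump each frame to disk with torchvision (an extra dependency, not used by the repo's own example; the value range clamp is an assumption about the decoder output):

```python
from torchvision.utils import save_image

# video: (1, 5, 3, 256, 256) tensor returned by nuwa.generate above
frames = video[0].clamp(0., 1.)           # assume decoder output roughly in [0, 1]

for i, frame in enumerate(frames.cpu()):
    save_image(frame, f'frame_{i}.png')   # writes one PNG per generated frame
```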
nuwa_pytorch/nuwa_pytorch.py

+112 −12
@@ -24,6 +24,17 @@ def exists(val):
 def default(val, d):
     return val if exists(val) else d
 
+# decorators
+
+def eval_decorator(fn):
+    def inner(model, *args, **kwargs):
+        was_training = model.training
+        model.eval()
+        out = fn(model, *args, **kwargs)
+        model.train(was_training)
+        return out
+    return inner
+
 # tensor helper functions
 
 def log(t, eps = 1e-20):
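`eval_decorator` temporarily switches the model to eval mode for the decorated call and then restores whatever training state it was in before, which is exactly what you want around a sampling method. A standalone sketch of the behavior (the `TinyModel` class here is made up purely for illustration):

```python
import torch
from torch import nn

def eval_decorator(fn):
    def inner(model, *args, **kwargs):
        was_training = model.training
        model.eval()                      # disable dropout / batchnorm updates while sampling
        out = fn(model, *args, **kwargs)
        model.train(was_training)         # restore the original mode afterwards
        return out
    return inner

class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Sequential(nn.Linear(8, 8), nn.Dropout(0.5))

    @torch.no_grad()
    @eval_decorator
    def generate(self, x):
        return self.net(x)

model = TinyModel()
model.train()
model.generate(torch.randn(2, 8))
print(model.training)  # True, training mode was restored after generation
```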
@@ -129,6 +140,10 @@ def __init__(
         self.discr_loss = hinge_discr_loss if use_hinge_loss else bce_discr_loss
         self.gen_loss = hinge_gen_loss if use_hinge_loss else bce_gen_loss
 
+    @property
+    def codebook(self):
+        return self.vq.codebook
+
     def encode(self, fmap):
         for enc in self.encoders:
             fmap = enc(fmap)
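The new `codebook` property exposes the vector quantizer's codebook so that callers (such as `NUWA.generate` later in this commit) can turn sampled token indices back into code vectors by plain indexing. A sketch of that lookup with made-up shapes:

```python
import torch

# stand-in for vae.codebook: (num_image_tokens, codebook_dim)
codebook = torch.randn(1024, 256)

indices = torch.randint(0, 1024, (1, 80))   # sampled token ids, (batch, seq)
codes = codebook[indices]                   # (1, 80, 256) code vectors
```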
@@ -161,9 +176,11 @@ def forward(
 
         fmap = self.decode(fmap)
 
-        if not return_loss:
+        if not return_loss and not return_discr_loss:
             return fmap
 
+        assert return_loss ^ return_discr_loss, 'you should either return autoencoder loss or discriminator loss, but not both'
+
         if return_discr_loss:
             fmap.detach_()
             fmap_discr_logits, img_discr_logits = map(self.discr, (fmap, img))
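After this change the VQGanVAE forward pass has three mutually exclusive modes: plain reconstruction, autoencoder loss, and discriminator loss (where the reconstruction is detached), and the new XOR assert rejects asking for both losses at once. Roughly, continuing from a `vae` built as in the README (shapes illustrative):

```python
imgs = torch.randn(4, 3, 256, 256)

recon = vae(imgs)                                 # no flags: returns the reconstruction
ae_loss = vae(imgs, return_loss = True)           # autoencoder / generator loss
d_loss = vae(imgs, return_discr_loss = True)      # discriminator loss

# vae(imgs, return_loss = True, return_discr_loss = True)  # would trip the new assert
```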
@@ -332,16 +349,31 @@ def forward(self, x, mask = None):
         # more variables
 
         kernel_size = self.kernel_size
-        num_frames, fmap_size, _ = self.video_shape
+        fmap_size = self.video_shape[1]
+
+        bos_only = n == 1
+        tokens_per_frame = fmap_size ** 2
+
+        padding = 0 if bos_only else (tokens_per_frame - (n - 1) % tokens_per_frame)
+        num_frames = (n + padding) // tokens_per_frame
 
         # pad for last token in video
 
-        x = F.pad(x, (0, 0, 0, 1), value = 0.)
+        if padding > 0:
+            x = F.pad(x, (0, 0, 0, padding), value = 0.)
 
         # derive queries / keys / values
 
-        qkv = self.to_qkv(x).chunk(3, dim = -1)
-        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), qkv)
+        q, k, v = self.to_qkv(x).chunk(3, dim = -1)
+
+        # early return if <bos>
+
+        if bos_only:
+            return self.to_out(v)
+
+        # split out heads
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h = h), (q, k, v))
 
         # scale queries
 
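The rewritten padding logic rounds the non-bos part of the sequence up to a whole number of frames, instead of always padding by a single token, so partially generated videos can pass through the sparse 3DNA attention. A quick arithmetic check with a 16×16 feature map (the values match `image_size = 256` with 4 VAE layers):

```python
fmap_size = 16
tokens_per_frame = fmap_size ** 2             # 256 tokens per frame

n = 300                                       # current sequence length, including <bos>
bos_only = n == 1

padding = 0 if bos_only else (tokens_per_frame - (n - 1) % tokens_per_frame)
num_frames = (n + padding) // tokens_per_frame

print(padding, num_frames)                    # 213 2: the 299 video tokens are padded up to 2 full frames
```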
@@ -352,11 +384,6 @@ def forward(self, x, mask = None):
         q = q[:, 1:]
         bos_value = v[:, :1]
 
-        # prepare precomputed causal mask
-
-        causal_mask = self.causal_mask[:n]
-        causal_mask = repeat(causal_mask, 'i j -> b i j', b = b * h)
-
         # compute keys and values
 
         (k_bos, k), (v_bos, v) = map(lambda t: (t[:, :1], t[:, 1:]), (k, v))
@@ -376,6 +403,10 @@ def forward(self, x, mask = None):
 
         # causal mask
 
+        i, j = sim.shape[-2:]
+        causal_mask = self.causal_mask[:i, :j]
+        causal_mask = repeat(causal_mask, 'i j -> b i j', b = b * h)
+
         sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
 
         # attention
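Together with the removal of the up-front mask slice in the previous hunk, the causal mask is now sliced to whatever shape the attention matrix `sim` actually has at the point of masking, rather than to the raw input length. A small standalone sketch of that pattern (the shapes and the triangular stand-in mask are assumptions; the real `self.causal_mask` encodes the sparse 3DNA layout):

```python
import torch
from einops import repeat

b, h = 2, 8
i, j = 255, 511                                            # query / key lengths after <bos> handling

full_mask = torch.ones(1024, 1024).triu(1).bool()          # stand-in for a precomputed self.causal_mask

sim = torch.randn(b * h, i, j)                             # attention logits
causal_mask = full_mask[:i, :j]                            # slice to the actual attention shape
causal_mask = repeat(causal_mask, 'i j -> b i j', b = b * h)

sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
```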
@@ -450,6 +481,16 @@ def forward(
 
         return self.norm(x)
 
+# sampling helpers
+
+def top_k(logits, thres = 0.5):
+    num_logits = logits.shape[-1]
+    k = max(int((1 - thres) * num_logits), 1)
+    val, ind = torch.topk(logits, k)
+    probs = torch.full_like(logits, float('-inf'))
+    probs.scatter_(1, ind, val)
+    return probs
+
 # main class
 
 class NUWA(nn.Module):
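`top_k` keeps roughly the top `(1 - thres)` fraction of the vocabulary and sets every other logit to `-inf`, so the softmax later in `generate` only spreads probability over the kept tokens. A quick worked example of the filtering:

```python
import torch

def top_k(logits, thres = 0.5):
    num_logits = logits.shape[-1]
    k = max(int((1 - thres) * num_logits), 1)     # thres = 0.9 keeps roughly the top 10% of logits
    val, ind = torch.topk(logits, k)
    probs = torch.full_like(logits, float('-inf'))
    probs.scatter_(1, ind, val)
    return probs

logits = torch.tensor([[1.0, 4.0, 2.0, 8.0, 0.5]])
print(top_k(logits, thres = 0.5))
# tensor([[-inf, 4., -inf, 8., -inf]]): only the top 2 of 5 logits survive
```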
@@ -495,9 +536,13 @@ def __init__(
 
         fmap_size = image_size // (2 ** vae_num_layers)
 
+        self.video_fmap_size = fmap_size
+        self.max_video_frames = max_video_frames
+        video_shape = (max_video_frames, fmap_size, fmap_size)
+
         self.video_pos_emb = AxialPositionalEmbedding(
             dim = dim,
-            axial_shape = (max_video_frames, fmap_size, fmap_size)
+            axial_shape = video_shape
        )
 
         self.video_transformer = Transformer(
@@ -511,11 +556,66 @@ def __init__(
             ff_dropout = ff_dropout,
             sparse_3dna_attn = True,
             sparse_3dna_kernel_size = sparse_3dna_kernel_size,
-            sparse_3dna_video_shape = (max_video_frames, fmap_size, fmap_size)
+            sparse_3dna_video_shape = video_shape
         )
 
         self.to_logits = nn.Linear(dim, num_image_tokens)
 
+    @torch.no_grad()
+    @eval_decorator
+    def generate(
+        self,
+        *,
+        text,
+        text_mask = None,
+        filter_thres = 0.9,
+        temperature = 1.
+    ):
+        batch, seq_len, device = *text.shape, text.device
+        assert seq_len <= self.text_max_seq_len, 'your input text has a greater length than what was designated on initialization'
+
+        tokens = self.text_embedding(text)
+        pos_emb = self.text_pos_embedding(torch.arange(seq_len, device = device))
+        tokens = tokens + rearrange(pos_emb, 'n d -> 1 n d')
+
+        text_embeds = self.text_transformer(
+            tokens,
+            mask = text_mask
+        )
+
+        bos = repeat(self.video_bos, 'd -> b 1 d', b = batch)
+
+        video_indices = torch.empty((batch, 0), device = device, dtype = torch.long)
+        total_video_tokens = self.video_fmap_size * self.video_fmap_size * self.max_video_frames
+
+        for _ in range(total_video_tokens):
+            frame_embeddings = self.image_embedding(video_indices)
+            frame_embeddings = self.video_pos_emb(frame_embeddings) + frame_embeddings
+            frame_embeddings = torch.cat((bos, frame_embeddings), dim = 1)
+
+            frame_embeddings = self.video_transformer(
+                frame_embeddings,
+                context = text_embeds,
+                context_mask = text_mask
+            )
+
+            logits = self.to_logits(frame_embeddings)
+            logits = logits[:, -1, :]
+
+            filtered_logits = top_k(logits, thres = filter_thres)
+            filtered_logits /= temperature
+            filtered_logits -= torch.amax(filtered_logits, dim = 1, keepdim = True)
+            probs = F.softmax(filtered_logits, dim = -1)
+            sample = torch.multinomial(probs, 1)
+            video_indices = torch.cat((video_indices, sample), dim = 1)
+
+        codes = self.vae.codebook[video_indices]
+        codes = rearrange(codes, 'b (f h w) d -> (b f) d h w', h = self.video_fmap_size, w = self.video_fmap_size)
+
+        image_reconstructions = self.vae.decode(codes)
+        video = rearrange(image_reconstructions, '(b f) d h w -> b f d h w', b = batch)
+        return video
+
     def forward(
         self,
         *,