Skip to content

Commit

Permalink
bring in the pseudo 3d conv for the 3dunet from make-a-video, and add…
Browse files Browse the repository at this point in the history
… the new prediction objective from the progressive distillation paper, which allows for distillation, and noted in imagen video to improve upresoluting unets (fixes the color shifting issue)
  • Loading branch information
lucidrains committed Oct 24, 2022
1 parent fa29d24 commit 906f296
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 5 deletions.
27 changes: 27 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,13 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
}
```

```bibtex
@misc{Singer2022,
author = {Uriel Singer},
url = {https://makeavideo.studio/Make-A-Video.pdf}
}
```

```bibtex
@article{Sunkara2022NoMS,
title = {No More Strided Convolutions or Pooling: A New CNN Building Block for Low-Resolution Images and Small Objects},
Expand All @@ -783,3 +790,23 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
volume = {abs/2208.03641}
}
```

```bibtex
@article{Salimans2022ProgressiveDF,
title = {Progressive Distillation for Fast Sampling of Diffusion Models},
author = {Tim Salimans and Jonathan Ho},
journal = {ArXiv},
year = {2022},
volume = {abs/2202.00512}
}
```

```bibtx
@article{Ho2022ImagenVH,
title = {Imagen Video: High Definition Video Generation with Diffusion Models},
author = {Jonathan Ho and William Chan and Chitwan Saharia and Jay Whang and Ruiqi Gao and Alexey A. Gritsenko and Diederik P. Kingma and Ben Poole and Mohammad Norouzi and David J. Fleet and Tim Salimans},
journal = {ArXiv},
year = {2022},
volume = {abs/2210.02303}
}
```
21 changes: 17 additions & 4 deletions imagen_pytorch/imagen_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def q_sample(self, x_start, t, noise = None):
log_snr_padded_dim = right_pad_dims_to(x_start, log_snr)
alpha, sigma = log_snr_to_alpha_sigma(log_snr_padded_dim)

return alpha * x_start + sigma * noise, log_snr
return alpha * x_start + sigma * noise, log_snr, alpha, sigma

def q_sample_from_to(self, x_from, from_t, to_t, noise = None):
shape, device, dtype = x_from.shape, x_from.device, x_from.dtype
Expand All @@ -276,6 +276,12 @@ def q_sample_from_to(self, x_from, from_t, to_t, noise = None):

return x_from * (alpha_to / alpha) + noise * (sigma_to * alpha - sigma * alpha_to) / alpha

def predict_start_from_v(self, x_t, t, v):
log_snr = self.log_snr(t)
log_snr = right_pad_dims_to(x_t, log_snr)
alpha, sigma = log_snr_to_alpha_sigma(log_snr)
return alpha * x_t - sigma * v

def predict_start_from_noise(self, x_t, t, noise):
log_snr = self.log_snr(t)
log_snr = right_pad_dims_to(x_t, log_snr)
Expand Down Expand Up @@ -1997,6 +2003,8 @@ def p_mean_variance(
x_start = noise_scheduler.predict_start_from_noise(x, t = t, noise = pred)
elif pred_objective == 'x_start':
x_start = pred
elif pred_objective == 'v':
x_start = noise_scheduler.predict_start_from_v(x, t = t, v = pred)
else:
raise ValueError(f'unknown objective {pred_objective}')

Expand Down Expand Up @@ -2108,7 +2116,7 @@ def p_sample_loop(
is_last_resample_step = r == 0

if has_inpainting:
noised_inpaint_images, _ = noise_scheduler.q_sample(inpaint_images, t = times)
noised_inpaint_images, *_ = noise_scheduler.q_sample(inpaint_images, t = times)
img = img * ~inpaint_masks + noised_inpaint_images * inpaint_masks

self_cond = x_start if unet.self_cond else None
Expand Down Expand Up @@ -2264,7 +2272,7 @@ def sample(
lowres_cond_img = self.resize_to(img, image_size)

lowres_cond_img = self.normalize_img(lowres_cond_img)
lowres_cond_img, _ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))
lowres_cond_img, *_ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))

if exists(unet_init_images):
unet_init_images = self.resize_to(unet_init_images, image_size)
Expand Down Expand Up @@ -2358,7 +2366,7 @@ def p_losses(

# get x_t

x_noisy, log_snr = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)
x_noisy, log_snr, alpha, sigma = noise_scheduler.q_sample(x_start = x_start, t = times, noise = noise)

# also noise the lowres conditioning image
# at sample time, they then fix the noise level of 0.1 - 0.3
Expand Down Expand Up @@ -2416,6 +2424,11 @@ def p_losses(
target = noise
elif pred_objective == 'x_start':
target = x_start
elif pred_objective == 'v':
# derivation detailed in Appendix D of Progressive Distillation paper
# https://arxiv.org/abs/2202.00512
# this makes distillation viable as well as solve an issue with color shifting in upresoluting unets, noted in imagen-video
target = alpha * noise - sigma * x_start
else:
raise ValueError(f'unknown objective {pred_objective}')

Expand Down
51 changes: 51 additions & 0 deletions imagen_pytorch/imagen_video/imagen_video.py
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,57 @@ def forward(self, x, mask = None):

return latents

# pseudo 3d conv from make-a-video

class PseudoConv3D(nn.Module):
def __init__(
self,
dim,
*,
kernel_size,
dim_out = None,
temporal_kernel_size = None,
**kwargs
):
super().__init__()
dim_out = default(dim_out, dim)
temporal_kernel_size = default(temporal_kernel_size, kernel_size)

self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)
self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size, padding = temporal_kernel_size // 2)

nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
nn.init.zeros_(self.temporal_conv.bias.data)

def forward(
self,
x,
convolve_across_time = True
):
b, c, *_, h, w = x.shape

is_video = x.ndim == 5
convolve_across_time &= is_video

if is_video:
x = rearrange(x, 'b c f h w -> (b f) c h w')

x = self.spatial_conv(x)

if is_video:
x = rearrange(x, '(b f) c h w -> b c f h w', b = b)

if not convolve_across_time:
return x

x = rearrange(x, 'b c f h w -> (b h w) c f')

x = self.temporal_conv(x)

x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)

return x

# attention

class Attention(nn.Module):
Expand Down
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.11.15'
__version__ = '1.12.0'

0 comments on commit 906f296

Please sign in to comment.