Skip to content

Commit

Permalink
cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
lucidrains committed Jun 8, 2022
1 parent 275998c commit f55f1b0
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions video_diffusion_pytorch/video_diffusion_pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import torch
from torch import nn, einsum
import torch.nn.functional as F
from inspect import isfunction
from functools import partial

from torch.utils import data
Expand Down Expand Up @@ -35,7 +34,7 @@ def is_odd(n):
def default(val, d):
if exists(val):
return val
return d() if isfunction(d) else d
return d() if callable(d) else d

def cycle(dl):
while True:
Expand Down Expand Up @@ -530,11 +529,6 @@ def extract(a, t, x_shape):
out = a.gather(-1, t)
return out.reshape(b, *((1,) * (len(x_shape) - 1)))

def noise_like(shape, device, repeat=False):
repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
noise = lambda: torch.randn(shape, device=device)
return repeat_noise() if repeat else noise()

def cosine_beta_schedule(timesteps, s = 0.008):
"""
cosine schedule
Expand Down Expand Up @@ -654,10 +648,10 @@ def p_mean_variance(self, x, t, clip_denoised: bool, cond = None, cond_scale = 1
return model_mean, posterior_variance, posterior_log_variance

@torch.inference_mode()
def p_sample(self, x, t, cond = None, cond_scale = 1., clip_denoised = True, repeat_noise = False):
def p_sample(self, x, t, cond = None, cond_scale = 1., clip_denoised = True):
b, *_, device = *x.shape, x.device
model_mean, _, model_log_variance = self.p_mean_variance(x = x, t = t, clip_denoised = clip_denoised, cond = cond, cond_scale = cond_scale)
noise = noise_like(x.shape, device, repeat_noise)
noise = torch.randn_like(x)
# no noise when t == 0
nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1,) * (len(x.shape) - 1)))
return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
Expand Down

0 comments on commit f55f1b0

Please sign in to comment.