From a347be7111fbe4d3155744701f1140c3937a1414 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Thu, 26 May 2022 09:29:28 -0700
Subject: [PATCH] take a tiny step towards continuous times

---
 imagen_pytorch/imagen_pytorch.py | 17 ++++++++++++++---
 setup.py                         |  2 +-
 2 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/imagen_pytorch/imagen_pytorch.py b/imagen_pytorch/imagen_pytorch.py
index 39c5e7e..00ab433 100644
--- a/imagen_pytorch/imagen_pytorch.py
+++ b/imagen_pytorch/imagen_pytorch.py
@@ -220,6 +220,10 @@ def __init__(self, *, beta_schedule, timesteps):
         register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
 
+    def sample_random_times(self, batch_size):
+        device = self.betas.device
+        return torch.randint(0, self.num_timesteps, (batch_size,), device = device, dtype = torch.long)
+
     def get_learned_posterior_log_variance(self, var_interp_frac_unnormalized, x_t, t):
         # if learned variance, posterior variance and posterior log variance are predicted by the network
         # by an interpolation of the max and min log beta values
@@ -254,6 +258,11 @@ def predict_start_from_noise(self, x_t, t, noise):
             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
         )
 
+class GaussianDiffusionContinuousTimes(GaussianDiffusion):
+    def __init__(self, *, beta_schedule, timesteps):
+        super().__init__()
+        raise NotImplementedError
+
 # norms and residuals
 
 class LayerNorm(nn.Module):
@@ -963,6 +972,7 @@ def __init__(
         learned_variance = True,
         vb_loss_weight = 0.001,
         auto_normalize_img = True,              # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
+        continuous_times = False,
         dynamic_thresholding_percentile = 0.9   # unsure what this was based on perusal of paper
     ):
         super().__init__()
@@ -1001,10 +1011,11 @@ def __init__(
         timesteps = cast_tuple(timesteps, num_unets)
         beta_schedules = cast_tuple(beta_schedules, num_unets)
 
+        noise_scheduler_klass = GaussianDiffusion if not continuous_times else GaussianDiffusionContinuousTimes
         self.noise_schedulers = nn.ModuleList([])
 
         for timestep, beta_schedule in zip(timesteps, beta_schedules):
-            noise_scheduler = GaussianDiffusion(beta_schedule = beta_schedule, timesteps = timestep)
+            noise_scheduler = noise_scheduler_klass(beta_schedule = beta_schedule, timesteps = timestep)
             self.noise_schedulers.append(noise_scheduler)
 
         # whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
@@ -1327,7 +1338,7 @@ def forward(
         check_shape(image, 'b c h w', c = self.channels)
         assert h >= target_image_size and w >= target_image_size
 
-        times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
+        times = noise_scheduler.sample_random_times(b)
 
         if exists(texts) and not exists(text_embeds) and not self.unconditional:
             text_embeds, text_masks = t5_encode_text(texts, name = self.text_encoder_name)
@@ -1342,7 +1353,7 @@ def forward(
         if exists(prev_image_size):
             lowres_cond_img = resize_image_to(image, prev_image_size)
             lowres_cond_img = resize_image_to(lowres_cond_img, target_image_size)
-            lowres_aug_times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
+            lowres_aug_times = noise_scheduler.sample_random_times(b)
 
         image = resize_image_to(image, target_image_size)
 
diff --git a/setup.py b/setup.py
index 207be6b..44d8e77 100644
--- a/setup.py
+++ b/setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'imagen-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.30',
+  version = '0.0.31',
   license='MIT',
   description = 'Imagen - unprecedented photorealism × deep level of language understanding',
   author = 'Phil Wang',
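
Note on where this is heading (not part of the commit above): GaussianDiffusionContinuousTimes
lands here only as a stub that raises NotImplementedError, and the new sample_random_times hook
is the natural point for a continuous-time scheduler to override. The sketch below is a guess at
that direction, assuming times drawn uniformly from [0, 1) as floats; the class body, the 'cosine'
argument, and the usage lines are illustrative assumptions, not repository code.

    import torch
    from torch import nn

    class GaussianDiffusionContinuousTimes(nn.Module):
        def __init__(self, *, beta_schedule, timesteps):
            super().__init__()
            # beta_schedule is accepted but unused in this sketch; num_timesteps is kept
            # only for interface parity with the discrete GaussianDiffusion scheduler
            self.num_timesteps = timesteps

        def sample_random_times(self, batch_size):
            # continuous times sampled uniformly from [0, 1), float instead of torch.long
            return torch.rand((batch_size,))

    scheduler = GaussianDiffusionContinuousTimes(beta_schedule = 'cosine', timesteps = 1000)
    times = scheduler.sample_random_times(4)  # shape torch.Size([4]), values in [0, 1)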