take a tiny step towards continuous times
lucidrains committed May 26, 2022
1 parent 2438326 commit a347be7
Showing 2 changed files with 15 additions and 4 deletions.
17 changes: 14 additions & 3 deletions imagen_pytorch/imagen_pytorch.py
@@ -220,6 +220,10 @@ def __init__(self, *, beta_schedule, timesteps):
         register_buffer('posterior_mean_coef1', betas * torch.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
         register_buffer('posterior_mean_coef2', (1. - alphas_cumprod_prev) * torch.sqrt(alphas) / (1. - alphas_cumprod))
 
+    def sample_random_times(self, batch_size):
+        device = self.betas.device
+        return torch.randint(0, self.num_timesteps, (batch_size,), device = device, dtype = torch.long)
+
     def get_learned_posterior_log_variance(self, var_interp_frac_unnormalized, x_t, t):
         # if learned variance, posterior variance and posterior log variance are predicted by the network
         # by an interpolation of the max and min log beta values
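
For context: the new sample_random_times helper just moves the uniform integer timestep sampling (previously inlined at two call sites in Imagen.forward, updated further down in this diff) onto the scheduler itself, presumably so a continuous-time subclass can later override it. A minimal sketch of the contract, with illustrative values:

    import torch

    num_timesteps = 1000   # illustrative schedule length
    batch_size = 4
    device = torch.device('cpu')

    # what sample_random_times(batch_size) returns for the discrete scheduler
    times = torch.randint(0, num_timesteps, (batch_size,), device = device, dtype = torch.long)
    assert times.shape == (batch_size,) and times.dtype == torch.long
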
@@ -254,6 +258,11 @@ def predict_start_from_noise(self, x_t, t, noise):
             extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
         )
 
+class GaussianDiffusionContinuousTimes(GaussianDiffusion):
+    def __init__(self, *, beta_schedule, timesteps):
+        super().__init__()
+        raise NotImplementedError
+
 # norms and residuals
 
 class LayerNorm(nn.Module):
@@ -963,6 +972,7 @@ def __init__(
         learned_variance = True,
         vb_loss_weight = 0.001,
         auto_normalize_img = True,            # whether to take care of normalizing the image from [0, 1] to [-1, 1] and back automatically - you can turn this off if you want to pass in the [-1, 1] ranged image yourself from the dataloader
+        continuous_times = False,
         dynamic_thresholding_percentile = 0.9 # unsure what this was based on perusal of paper
     ):
         super().__init__()
@@ -1001,10 +1011,11 @@ def __init__(
         timesteps = cast_tuple(timesteps, num_unets)
         beta_schedules = cast_tuple(beta_schedules, num_unets)
 
+        noise_scheduler_klass = GaussianDiffusion if not continuous_times else GaussianDiffusionContinuousTimes
         self.noise_schedulers = nn.ModuleList([])
 
         for timestep, beta_schedule in zip(timesteps, beta_schedules):
-            noise_scheduler = GaussianDiffusion(beta_schedule = beta_schedule, timesteps = timestep)
+            noise_scheduler = noise_scheduler_klass(beta_schedule = beta_schedule, timesteps = timestep)
             self.noise_schedulers.append(noise_scheduler)
 
         # whether to use learned variance, defaults to True for the first unet in the cascade, as in paper
@@ -1327,7 +1338,7 @@ def forward(
         check_shape(image, 'b c h w', c = self.channels)
         assert h >= target_image_size and w >= target_image_size
 
-        times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
+        times = noise_scheduler.sample_random_times(b)
 
         if exists(texts) and not exists(text_embeds) and not self.unconditional:
             text_embeds, text_masks = t5_encode_text(texts, name = self.text_encoder_name)
@@ -1342,7 +1353,7 @@ def forward(
         if exists(prev_image_size):
             lowres_cond_img = resize_image_to(image, prev_image_size)
             lowres_cond_img = resize_image_to(lowres_cond_img, target_image_size)
-            lowres_aug_times = torch.randint(0, noise_scheduler.num_timesteps, (b,), device = device, dtype = torch.long)
+            lowres_aug_times = noise_scheduler.sample_random_times(b)
 
         image = resize_image_to(image, target_image_size)
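
Taken together, the imagen_pytorch.py changes thread a single continuous_times flag through the Imagen constructor to pick the noise scheduler class per unet. A minimal sketch of how the two classes behave at this commit (beta_schedule = 'cosine' is an assumption here; the valid schedule names are not shown in this diff):

    import torch
    from imagen_pytorch.imagen_pytorch import GaussianDiffusion, GaussianDiffusionContinuousTimes

    scheduler = GaussianDiffusion(beta_schedule = 'cosine', timesteps = 1000)
    times = scheduler.sample_random_times(4)   # LongTensor of shape (4,)

    # the continuous-time variant is still a stub
    try:
        GaussianDiffusionContinuousTimes(beta_schedule = 'cosine', timesteps = 1000)
    except (TypeError, NotImplementedError):
        # NotImplementedError is the stub's intent, though the bare super().__init__()
        # would actually raise TypeError first, since GaussianDiffusion.__init__
        # requires its keyword-only arguments
        pass
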
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
 setup(
   name = 'imagen-pytorch',
   packages = find_packages(exclude=[]),
-  version = '0.0.30',
+  version = '0.0.31',
   license='MIT',
   description = 'Imagen - unprecedented photorealism × deep level of language understanding',
   author = 'Phil Wang',