From 44da9be862eb026ca3edb5221a96025f552af694 Mon Sep 17 00:00:00 2001
From: Phil Wang
Date: Wed, 25 Jan 2023 12:12:45 -0800
Subject: [PATCH] add ability to specify temporal interpolation using
 temporal_downsample_factor keyword for both Imagen and ElucidatedImagen

---
 README.md                           |  3 +-
 imagen_pytorch/elucidated_imagen.py | 44 +++++++++++++++-----
 imagen_pytorch/imagen_pytorch.py    | 64 ++++++++++++++++++++++-------
 imagen_pytorch/imagen_video.py      | 12 +++---
 imagen_pytorch/version.py           |  2 +-
 5 files changed, 92 insertions(+), 33 deletions(-)

diff --git a/README.md b/README.md
index cfff128..5ffe541 100644
--- a/README.md
+++ b/README.md
@@ -695,6 +695,7 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
 - [x] add v-parameterization (https://arxiv.org/abs/2202.00512) from imagen video paper, the only thing new
 - [x] incorporate all learnings from make-a-video (https://makeavideo.studio/)
 - [x] build out CLI tool for training, resuming training off config file
+- [x] allow for temporal interpolation at specific stages
 - [ ] reread cogvideo and figure out how frame rate conditioning could be used
 - [ ] bring in attention expertise for self attention layers in unet3d
@@ -716,7 +717,7 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
 - [ ] add textual inversion
 - [ ] cleanup self conditioning to be extracted at imagen instantiation
 - [ ] make sure eventual dreambooth works with imagen-video
-- [ ] allow for temporal interpolation at specific stages
+- [ ] make sure temporal interpolation works with inpainting
 
 ## Citations

diff --git a/imagen_pytorch/elucidated_imagen.py b/imagen_pytorch/elucidated_imagen.py
index 62dd583..ba63c35 100644
--- a/imagen_pytorch/elucidated_imagen.py
+++ b/imagen_pytorch/elucidated_imagen.py
@@ -29,11 +29,11 @@
     default,
     cast_tuple,
     cast_uint8_images_to_float,
-    is_float_dtype,
     eval_decorator,
     check_shape,
     pad_tuple_to_length,
     resize_image_to,
+    calc_all_frame_dims,
     right_pad_dims_to,
     module_device,
     normalize_neg_one_to_one,
@@ -83,6 +83,7 @@ def __init__(
         channels = 3,
         cond_drop_prob = 0.1,
         random_crop_sizes = None,
+        temporal_downsample_factor = 1,
         lowres_sample_noise_level = 0.2,            # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
         per_sample_random_aug_noise_level = False,  # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
         condition_on_text = True,
@@ -196,6 +197,14 @@ def __init__(
         self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
         self.dynamic_thresholding_percentile = dynamic_thresholding_percentile
 
+        # temporal interpolations
+
+        temporal_downsample_factor = cast_tuple(temporal_downsample_factor, num_unets)
+        self.temporal_downsample_factor = temporal_downsample_factor
+
+        assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1'
+        assert all([left >= right for left, right in zip(temporal_downsample_factor[:-1], temporal_downsample_factor[1:])]), 'temporal downsample factors must be in descending order'
+
         # elucidating parameters
 
         hparams = [
@@ -589,7 +598,12 @@ def sample(
 
         assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'
 
-        frame_dims = (video_frames,) if self.is_video else tuple()
+        # determine the frame dimensions, if needed
+
+        if self.is_video:
+            all_frame_dims = calc_all_frame_dims(self.temporal_downsample_factor, video_frames)
+        else:
+            all_frame_dims = (tuple(),) * num_unets
 
         # initializing with an image or video
 
@@ -613,7 +627,7 @@ def sample(
 
         # go through each unet in cascade
 
-        for unet_number, unet, channel, image_size, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm):
+        for unet_number, unet, channel, image_size, frame_dims, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, all_frame_dims, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm):
 
             if unet_number < start_at_unet_number:
                 continue
@@ -626,16 +640,20 @@ def sample(
 
                 shape = (batch_size, channel, *frame_dims, image_size, image_size)
 
+                resize_kwargs = dict()
+                if self.is_video:
+                    resize_kwargs = dict(target_frames = frame_dims[0])
+
                 if unet.lowres_cond:
                     lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)
 
-                    lowres_cond_img = self.resize_to(img, image_size)
+                    lowres_cond_img = self.resize_to(img, image_size, **resize_kwargs)
 
                     lowres_cond_img = self.normalize_img(lowres_cond_img)
                     lowres_cond_img, *_ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))
 
                 if exists(unet_init_images):
-                    unet_init_images = self.resize_to(unet_init_images, image_size)
+                    unet_init_images = self.resize_to(unet_init_images, image_size, **resize_kwargs)
 
                 shape = (batch_size, self.channels, *frame_dims, image_size, image_size)
 
@@ -710,7 +728,7 @@ def forward(
         images = cast_uint8_images_to_float(images)
         cond_images = maybe(cast_uint8_images_to_float)(cond_images)
 
-        assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'
+        assert images.dtype == torch.float, f'images tensor needs to be floats but {images.dtype} dtype found instead'
 
         unet_index = unet_number - 1
 
@@ -725,7 +743,13 @@ def forward(
 
         batch_size, c, *_, h, w, device, is_video = *images.shape, images.device, (images.ndim == 5)
 
-        frames = images.shape[2] if is_video else None
+        frames = images.shape[2] if is_video else None
+        all_frame_dims = tuple(el[0] for el in calc_all_frame_dims(self.temporal_downsample_factor, frames)) if is_video else ((None,) * len(self.unets))
+        ignore_time = kwargs.get('ignore_time', False)
+
+        target_frame_size = all_frame_dims[unet_index] if is_video and not ignore_time else None
+        prev_frame_size = all_frame_dims[unet_index - 1] if is_video and not ignore_time and unet_index > 0 else None
+        frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()
 
         check_shape(images, 'b c ...', c = self.channels)
 
@@ -750,8 +774,8 @@ def forward(
         lowres_cond_img = lowres_aug_times = None
         if exists(prev_image_size):
-            lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range)
-            lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range)
+            lowres_cond_img = self.resize_to(images, prev_image_size, **frames_to_resize_kwargs(prev_frame_size), clamp_range = self.input_image_range)
+            lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, **frames_to_resize_kwargs(target_frame_size), clamp_range = self.input_image_range)
 
             if self.per_sample_random_aug_noise_level:
                 lowres_aug_times = self.lowres_noise_schedule.sample_random_times(batch_size, device = device)
@@ -759,7 +783,7 @@ def forward(
                 lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
                 lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = batch_size)
 
-        images = self.resize_to(images, target_image_size)
+        images = self.resize_to(images, target_image_size, **frames_to_resize_kwargs(target_frame_size))
 
         # normalize to [-1, 1]
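Editor's note, a quick illustration (not part of the patch) of what the two validation asserts above accept: factors must be non-increasing across the cascade, and the final stage must be 1.

```python
# hypothetical configurations against the asserts added in __init__ above
def validate(temporal_downsample_factor):
    assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1'
    assert all([left >= right for left, right in zip(temporal_downsample_factor[:-1], temporal_downsample_factor[1:])]), 'temporal downsample factors must be in descending order'

validate((4, 2, 1))    # ok - each stage reduces the temporal downsampling, last stage is full frame rate
validate((1, 1, 1))    # ok - no temporal interpolation anywhere (the default)
# validate((2, 4, 1))  # AssertionError - factors may not increase across stages
# validate((4, 2, 2))  # AssertionError - last stage must be 1
```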
diff --git a/imagen_pytorch/imagen_pytorch.py b/imagen_pytorch/imagen_pytorch.py
index dd82aa9..6fb2eb4 100644
--- a/imagen_pytorch/imagen_pytorch.py
+++ b/imagen_pytorch/imagen_pytorch.py
@@ -35,6 +35,9 @@ def exists(val):
 def identity(t, *args, **kwargs):
     return t
 
+def divisible_by(numer, denom):
+    return (numer % denom) == 0
+
 def first(arr, d = None):
     if len(arr) == 0:
         return d
@@ -77,9 +80,6 @@ def cast_tuple(val, length = None):
 
     return output
 
-def is_float_dtype(dtype):
-    return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
-
 def cast_uint8_images_to_float(images):
     if not images.dtype == torch.uint8:
         return images
@@ -158,6 +158,18 @@ def resize_image_to(
 
     return out
 
+def calc_all_frame_dims(
+    downsample_factors: List[int],
+    frames
+):
+    all_frame_dims = []
+
+    for divisor in downsample_factors:
+        assert divisible_by(frames, divisor)
+        all_frame_dims.append((frames // divisor,))
+
+    return all_frame_dims
+
 # image normalization functions
 # ddpms expect images to be in the range of -1 to 1
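For reference, a sketch of what the new helper returns (the frame counts here are illustrative): each stage gets the frame count divided by its factor.

```python
from imagen_pytorch.imagen_pytorch import calc_all_frame_dims

# with 16-frame videos and a (4, 2, 1) schedule, the three stages see 4, 8 and 16 frames
print(calc_all_frame_dims((4, 2, 1), 16))   # [(4,), (8,), (16,)]

# a frame count that is not divisible by every factor fails the divisibility assert
# calc_all_frame_dims((4, 2, 1), 10)        # AssertionError (10 is not divisible by 4)
```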
@@ -1124,7 +1136,7 @@ def __init__(
         cosine_sim_attn = False,
         self_cond = False,
         combine_upsample_fmaps = False,      # combine feature maps from all upsample blocks, used in unet squared successfully
-        pixel_shuffle_upsample = True        # may address checkboard artifacts
+        pixel_shuffle_upsample = True,       # may address checkerboard artifacts
     ):
         super().__init__()
 
@@ -1772,7 +1784,8 @@ def __init__(
         p2_loss_weight_k = 1,
         dynamic_thresholding = True,
         dynamic_thresholding_percentile = 0.95,     # unsure what this was based on perusal of paper
-        only_train_unet_number = None
+        only_train_unet_number = None,
+        temporal_downsample_factor = 1
     ):
         super().__init__()
 
@@ -1882,6 +1895,13 @@ def __init__(
         self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1'))
         self.resize_to = resize_video_to if is_video else resize_image_to
 
+        # temporal interpolation
+
+        temporal_downsample_factor = cast_tuple(temporal_downsample_factor, num_unets)
+        self.temporal_downsample_factor = temporal_downsample_factor
+        assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1'
+        assert all([left >= right for left, right in zip(temporal_downsample_factor[:-1], temporal_downsample_factor[1:])]), 'temporal downsample factors must be in descending order'
+
         # cascading ddpm related stuff
 
         lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
 
@@ -2240,7 +2260,12 @@ def sample(
 
         assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'
 
-        frame_dims = (video_frames,) if self.is_video else tuple()
+        if self.is_video:
+            all_frame_dims = calc_all_frame_dims(self.temporal_downsample_factor, video_frames)
+        else:
+            all_frame_dims = (tuple(),) * num_unets
+
+        frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()
 
         # for initial image and skipping steps
 
@@ -2257,11 +2282,12 @@ def sample(
             assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'
 
             prev_image_size = self.image_sizes[start_at_unet_number - 2]
-            img = self.resize_to(start_image_or_video, prev_image_size)
+            prev_frame_size = all_frame_dims[start_at_unet_number - 2][0] if self.is_video else None
+            img = self.resize_to(start_image_or_video, prev_image_size, **frames_to_resize_kwargs(prev_frame_size))
 
         # go through each unet in cascade
 
-        for unet_number, unet, channel, image_size, noise_scheduler, pred_objective, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.noise_schedulers, self.pred_objectives, self.dynamic_thresholding, cond_scale, init_images, skip_steps), disable = not use_tqdm):
+        for unet_number, unet, channel, image_size, frame_dims, noise_scheduler, pred_objective, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, all_frame_dims, self.noise_schedulers, self.pred_objectives, self.dynamic_thresholding, cond_scale, init_images, skip_steps), disable = not use_tqdm):
 
             if unet_number < start_at_unet_number:
                 continue
 
@@ -2274,16 +2300,18 @@ def sample(
                 lowres_cond_img = lowres_noise_times = None
                 shape = (batch_size, channel, *frame_dims, image_size, image_size)
 
+                resize_kwargs = dict(target_frames = frame_dims[0]) if self.is_video else dict()
+
                 if unet.lowres_cond:
                     lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)
 
-                    lowres_cond_img = self.resize_to(img, image_size)
+                    lowres_cond_img = self.resize_to(img, image_size, **resize_kwargs)
 
                     lowres_cond_img = self.normalize_img(lowres_cond_img)
                     lowres_cond_img, *_ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))
 
                 if exists(unet_init_images):
-                    unet_init_images = self.resize_to(unet_init_images, image_size)
+                    unet_init_images = self.resize_to(unet_init_images, image_size, **resize_kwargs)
 
                 shape = (batch_size, self.channels, *frame_dims, image_size, image_size)
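To make the shape bookkeeping concrete, here is a standalone sketch (with made-up sizes, not from the patch) of how the sampling loop above derives each stage's tensor shape.

```python
# mirrors the per-unet shape computation in sample(), with illustrative numbers
batch_size, channels, video_frames = 2, 3, 16
image_sizes = (16, 32)                  # spatial size per unet
temporal_downsample_factor = (2, 1)     # temporal factor per unet

all_frame_dims = [(video_frames // factor,) for factor in temporal_downsample_factor]

for image_size, frame_dims in zip(image_sizes, all_frame_dims):
    shape = (batch_size, channels, *frame_dims, image_size, image_size)
    print(shape)
# (2, 3, 8, 16, 16)  - first unet generates 8 frames at 16x16
# (2, 3, 16, 32, 32) - second unet upsamples to 16 frames at 32x32
```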
@@ -2480,7 +2508,7 @@ def forward(
         images = cast_uint8_images_to_float(images)
         cond_images = maybe(cast_uint8_images_to_float)(cond_images)
 
-        assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'
+        assert images.dtype == torch.float, f'images tensor needs to be floats but {images.dtype} dtype found instead'
 
         unet_index = unet_number - 1
 
@@ -2500,7 +2528,13 @@ def forward(
         check_shape(images, 'b c ...', c = self.channels)
         assert h >= target_image_size and w >= target_image_size
 
-        frames = images.shape[2] if is_video else None
+        frames = images.shape[2] if is_video else None
+        all_frame_dims = tuple(el[0] for el in calc_all_frame_dims(self.temporal_downsample_factor, frames)) if is_video else ((None,) * len(self.unets))
+        ignore_time = kwargs.get('ignore_time', False)
+
+        target_frame_size = all_frame_dims[unet_index] if is_video and not ignore_time else None
+        prev_frame_size = all_frame_dims[unet_index - 1] if is_video and not ignore_time and unet_index > 0 else None
+        frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()
 
         times = noise_scheduler.sample_random_times(b, device = device)
 
@@ -2523,8 +2557,8 @@ def forward(
         lowres_cond_img = lowres_aug_times = None
         if exists(prev_image_size):
-            lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range)
-            lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range)
+            lowres_cond_img = self.resize_to(images, prev_image_size, **frames_to_resize_kwargs(prev_frame_size), clamp_range = self.input_image_range)
+            lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, **frames_to_resize_kwargs(target_frame_size), clamp_range = self.input_image_range)
 
             if self.per_sample_random_aug_noise_level:
                 lowres_aug_times = self.lowres_noise_schedule.sample_random_times(b, device = device)
@@ -2532,6 +2566,6 @@ def forward(
                 lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
                 lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = b)
 
-        images = self.resize_to(images, target_image_size)
+        images = self.resize_to(images, target_image_size, **frames_to_resize_kwargs(target_frame_size))
 
         return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, p2_loss_weight_gamma = p2_loss_weight_gamma, random_crop_size = random_crop_size, **kwargs)

diff --git a/imagen_pytorch/imagen_video.py b/imagen_pytorch/imagen_video.py
index f05f04a..005e2f8 100644
--- a/imagen_pytorch/imagen_video.py
+++ b/imagen_pytorch/imagen_video.py
@@ -137,6 +137,7 @@ def masked_mean(t, *, dim, mask = None):
 def resize_video_to(
     video,
     target_image_size,
+    target_frames = None,
     clamp_range = None
 ):
     orig_video_size = video.shape[-1]
@@ -144,16 +145,15 @@ def resize_video_to(
     if orig_video_size == target_image_size:
         return video
 
-    frames = video.shape[2]
-    video = rearrange(video, 'b c f h w -> (b f) c h w')
+    frames = video.shape[2]
+    target_frames = default(target_frames, frames)
+
+    target_shape = (target_frames, target_image_size, target_image_size)
 
-    out = F.interpolate(video, target_image_size, mode = 'nearest')
+    out = F.interpolate(video, target_shape, mode = 'nearest')
 
     if exists(clamp_range):
         out = out.clamp(*clamp_range)
-
-    out = rearrange(out, '(b f) c h w -> b c f h w', f = frames)
 
     return out
 
@@ -1600,7 +1600,7 @@ def forward(
         batch_size, frames, device, dtype = x.shape[0], x.shape[2], x.device, x.dtype
 
-        assert ignore_time or divisible_by(frames, self.total_temporal_divisor), f'number of input frames must be divisible by {self.total_temporal_divisor}'
+        assert ignore_time or divisible_by(frames, self.total_temporal_divisor), f'number of input frames {frames} must be divisible by {self.total_temporal_divisor}'
 
         # add self conditioning if needed
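The core of the `resize_video_to` rewrite is that `F.interpolate` accepts 5D input with a three-element size, so time and space are resized in a single nearest-neighbor call instead of flattening frames into the batch. A standalone check (shapes are illustrative):

```python
import torch
import torch.nn.functional as F

video = torch.randn(1, 3, 8, 16, 16)   # (batch, channels, frames, height, width)

# resize 8 -> 16 frames and 16x16 -> 32x32 pixels in one call
out = F.interpolate(video, (16, 32, 32), mode = 'nearest')
print(out.shape)                        # torch.Size([1, 3, 16, 32, 32])
```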
diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py
index 8ac48f0..bec0738 100644
--- a/imagen_pytorch/version.py
+++ b/imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.19.0'
+__version__ = '1.20.0'
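Putting it together, a minimal usage sketch of the new keyword (hyperparameters, prompts, and tensor sizes are illustrative, in the spirit of the video example in the repository README):

```python
import torch
from imagen_pytorch import Unet3D, ElucidatedImagen

unet1 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8))
unet2 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8))

imagen = ElucidatedImagen(
    unets = (unet1, unet2),
    image_sizes = (16, 32),
    temporal_downsample_factor = (2, 1),  # first unet trains on 2x temporally downsampled video; last stage must be 1
    num_sample_steps = 10,
    cond_drop_prob = 0.1
)

texts = ['a puppy chasing its tail']
videos = torch.randn(1, 3, 8, 32, 32)    # (batch, channels, frames, height, width)

# train each stage - the first unet internally resizes the video to 4 frames
for unet_number in (1, 2):
    loss = imagen(videos, texts = texts, unet_number = unet_number)
    loss.backward()

# at sampling time, video_frames is the frame count of the final stage
sampled_videos = imagen.sample(texts = texts, video_frames = 8)
```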