add ability to specify temporal interpolation using temporal_downsample_factor keyword for both Imagen and ElucidatedImagen
lucidrains committed Jan 25, 2023
1 parent 3540965 commit 44da9be
Showing 5 changed files with 92 additions and 33 deletions.
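
For orientation before the diff, here is a minimal sketch of what the new keyword enables, modeled on the repository's video example; the unet settings and tensor sizes are illustrative assumptions, not part of this commit:

```python
import torch
from imagen_pytorch import Unet3D, ElucidatedImagen

unet1 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8))
unet2 = Unet3D(dim = 64, dim_mults = (1, 2, 4, 8))

imagen = ElucidatedImagen(
    unets = (unet1, unet2),
    image_sizes = (16, 32),
    random_crop_sizes = (None, 16),
    num_sample_steps = 10,
    temporal_downsample_factor = (2, 1)   # first unet trains at half the frame rate; the last stage must be 1
)

# the number of frames (8) must be divisible by every downsample factor
videos = torch.randn(2, 3, 8, 32, 32)     # (batch, channels, frames, height, width)
text_embeds = torch.randn(2, 256, 768)    # mock T5 embeddings

loss = imagen(videos, text_embeds = text_embeds, unet_number = 1)
loss.backward()
```

At sampling time the same schedule is applied in reverse: each stage's low-frame-rate output is temporally upsampled (nearest neighbor) to serve as the conditioning video for the next stage.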
3 changes: 2 additions & 1 deletion README.md
@@ -695,6 +695,7 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
- [x] add v-parameterization (https://arxiv.org/abs/2202.00512) from <a href="https://imagen.research.google/video/paper.pdf">imagen video</a> paper, the only thing new
- [x] incorporate all learnings from make-a-video (https://makeavideo.studio/)
- [x] build out CLI tool for training, resuming training off config file
- [x] allow for temporal interpolation at specific stages

- [ ] reread <a href="https://arxiv.org/abs/2205.15868">cogvideo</a> and figure out how frame rate conditioning could be used
- [ ] bring in attention expertise for self attention layers in unet3d
@@ -716,7 +717,7 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
- [ ] add textual inversion
- [ ] cleanup self conditioning to be extracted at imagen instantiation
- [ ] make sure eventual dreambooth works with imagen-video
- [ ] allow for temporal interpolation at specific stages
- [ ] make sure temporal interpolation works with inpainting

## Citations

44 changes: 34 additions & 10 deletions imagen_pytorch/elucidated_imagen.py
@@ -29,11 +29,11 @@
default,
cast_tuple,
cast_uint8_images_to_float,
is_float_dtype,
eval_decorator,
check_shape,
pad_tuple_to_length,
resize_image_to,
calc_all_frame_dims,
right_pad_dims_to,
module_device,
normalize_neg_one_to_one,
@@ -83,6 +83,7 @@ def __init__(
channels = 3,
cond_drop_prob = 0.1,
random_crop_sizes = None,
temporal_downsample_factor = 1,
lowres_sample_noise_level = 0.2, # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
condition_on_text = True,
@@ -196,6 +197,14 @@ def __init__(
self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
self.dynamic_thresholding_percentile = dynamic_thresholding_percentile

# temporal interpolations

temporal_downsample_factor = cast_tuple(temporal_downsample_factor, num_unets)
self.temporal_downsample_factor = temporal_downsample_factor

assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1'
assert all([left >= right for left, right in zip(temporal_downsample_factor[:-1], temporal_downsample_factor[1:])]), 'temporal downsample factors must be in descending order'
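
Restated on their own, the constraint these two assertions enforce is that the cascade must end at the native frame rate and may only get finer in time as it goes. A standalone sketch of that check, with assumed factor tuples, not code from the commit:

```python
def validate_temporal_factors(factors):
    # final stage runs at the native frame rate
    assert factors[-1] == 1, 'downsample factor of last stage must be 1'
    # earlier stages may only be coarser in time than later ones
    assert all(left >= right for left, right in zip(factors[:-1], factors[1:]))

validate_temporal_factors((4, 2, 1))       # passes: coarse-to-fine in time

try:
    validate_temporal_factors((2, 4, 1))   # not descending
except AssertionError:
    print('rejected: factors must not increase across stages')
```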

# elucidating parameters

hparams = [
@@ -589,7 +598,12 @@ def sample(

assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'

frame_dims = (video_frames,) if self.is_video else tuple()
# determine the frame dimensions, if needed

if self.is_video:
all_frame_dims = calc_all_frame_dims(self.temporal_downsample_factor, video_frames)
else:
all_frame_dims = (tuple(),) * num_unets

# initializing with an image or video

@@ -613,7 +627,7 @@

# go through each unet in cascade

for unet_number, unet, channel, image_size, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm):
for unet_number, unet, channel, image_size, frame_dims, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, all_frame_dims, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm):
if unet_number < start_at_unet_number:
continue

@@ -626,16 +640,20 @@

shape = (batch_size, channel, *frame_dims, image_size, image_size)

resize_kwargs = dict()
if self.is_video:
resize_kwargs = dict(target_frames = frame_dims[0])

if unet.lowres_cond:
lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)

lowres_cond_img = self.resize_to(img, image_size)
lowres_cond_img = self.resize_to(img, image_size, **resize_kwargs)
lowres_cond_img = self.normalize_img(lowres_cond_img)

lowres_cond_img, *_ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))

if exists(unet_init_images):
unet_init_images = self.resize_to(unet_init_images, image_size)
unet_init_images = self.resize_to(unet_init_images, image_size, **resize_kwargs)

shape = (batch_size, self.channels, *frame_dims, image_size, image_size)
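
Concretely, each stage now samples a noise tensor with its own frame count. A sketch of the shape arithmetic above, under assumed values (three unets, factors (4, 2, 1), video_frames = 16):

```python
batch_size, channels = 2, 3
image_sizes = (16, 32, 64)                         # assumed per-unet spatial sizes
all_frame_dims = [(16 // f,) for f in (4, 2, 1)]   # [(4,), (8,), (16,)]

shapes = [
    (batch_size, channels, *frame_dims, size, size)
    for frame_dims, size in zip(all_frame_dims, image_sizes)
]
print(shapes)   # [(2, 3, 4, 16, 16), (2, 3, 8, 32, 32), (2, 3, 16, 64, 64)]
```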

@@ -710,7 +728,7 @@ def forward(
images = cast_uint8_images_to_float(images)
cond_images = maybe(cast_uint8_images_to_float)(cond_images)

assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'
assert images.dtype == torch.float, f'images tensor needs to be floats but {images.dtype} dtype found instead'

unet_index = unet_number - 1

@@ -725,7 +743,13 @@

batch_size, c, *_, h, w, device, is_video = *images.shape, images.device, (images.ndim == 5)

frames = images.shape[2] if is_video else None
all_frame_dims = tuple(el[0] for el in calc_all_frame_dims(self.temporal_downsample_factor, frames)) if is_video else None
ignore_time = kwargs.get('ignore_time', False)

target_frame_size = all_frame_dims[unet_index] if is_video and not ignore_time else None
prev_frame_size = all_frame_dims[unet_index - 1] if is_video and not ignore_time and unet_index > 0 else None
frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()

check_shape(images, 'b c ...', c = self.channels)

@@ -750,16 +774,16 @@

lowres_cond_img = lowres_aug_times = None
if exists(prev_image_size):
lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range)
lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range)
lowres_cond_img = self.resize_to(images, prev_image_size, **frames_to_resize_kwargs(prev_frame_size), clamp_range = self.input_image_range)
lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, **frames_to_resize_kwargs(target_frame_size), clamp_range = self.input_image_range)

if self.per_sample_random_aug_noise_level:
lowres_aug_times = self.lowres_noise_schedule.sample_random_times(batch_size, device = device)
else:
lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = batch_size)

images = self.resize_to(images, target_image_size)
images = self.resize_to(images, target_image_size, **frames_to_resize_kwargs(target_frame_size))

# normalize to [-1, 1]

64 changes: 49 additions & 15 deletions imagen_pytorch/imagen_pytorch.py
@@ -35,6 +35,9 @@ def exists(val):
def identity(t, *args, **kwargs):
return t

def divisible_by(numer, denom):
return (numer % denom) == 0

def first(arr, d = None):
if len(arr) == 0:
return d
@@ -77,9 +80,6 @@ def cast_tuple(val, length = None):

return output

def is_float_dtype(dtype):
return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])

def cast_uint8_images_to_float(images):
if not images.dtype == torch.uint8:
return images
@@ -158,6 +158,18 @@ def resize_image_to(

return out

def calc_all_frame_dims(
downsample_factors: List[int],
frames
):
all_frame_dims = []

for divisor in downsample_factors:
assert divisible_by(frames, divisor)
all_frame_dims.append((frames // divisor,))

return all_frame_dims
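
Usage of the helper just defined, with assumed values: each entry comes back as a 1-tuple so it can be splatted into a (batch, channels, frames, height, width) shape, and the input frame count must divide cleanly by every factor.

```python
print(calc_all_frame_dims([4, 2, 1], 16))   # [(4,), (8,), (16,)]

# calc_all_frame_dims([4, 2, 1], 10)        # would raise AssertionError: 10 is not divisible by 4
```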

# image normalization functions
# ddpms expect images to be in the range of -1 to 1

@@ -1124,7 +1136,7 @@ def __init__(
cosine_sim_attn = False,
self_cond = False,
combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully
pixel_shuffle_upsample = True # may address checkerboard artifacts
pixel_shuffle_upsample = True, # may address checkerboard artifacts
):
super().__init__()

@@ -1772,7 +1784,8 @@ def __init__(
p2_loss_weight_k = 1,
dynamic_thresholding = True,
dynamic_thresholding_percentile = 0.95, # unsure what this was based on perusal of paper
only_train_unet_number = None
only_train_unet_number = None,
temporal_downsample_factor = 1
):
super().__init__()

@@ -1882,6 +1895,13 @@ def __init__(
self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1'))
self.resize_to = resize_video_to if is_video else resize_image_to

# temporal interpolation

temporal_downsample_factor = cast_tuple(temporal_downsample_factor, num_unets)
self.temporal_downsample_factor = temporal_downsample_factor
assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1'
assert all([left >= right for left, right in zip(temporal_downsample_factor[:-1], temporal_downsample_factor[1:])]), 'temporal downsample factors must be in descending order'

# cascading ddpm related stuff

lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
@@ -2240,7 +2260,12 @@ def sample(

assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'

frame_dims = (video_frames,) if self.is_video else tuple()
if self.is_video:
all_frame_dims = calc_all_frame_dims(self.temporal_downsample_factor, video_frames)
else:
all_frame_dims = (tuple(),) * num_unets

frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()

# for initial image and skipping steps

@@ -2257,11 +2282,12 @@
assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'

prev_image_size = self.image_sizes[start_at_unet_number - 2]
img = self.resize_to(start_image_or_video, prev_image_size)
prev_frame_size = all_frame_dims[start_at_unet_number - 2][0] if self.is_video else None
img = self.resize_to(start_image_or_video, prev_image_size, **frames_to_resize_kwargs(prev_frame_size))

# go through each unet in cascade

for unet_number, unet, channel, image_size, noise_scheduler, pred_objective, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.noise_schedulers, self.pred_objectives, self.dynamic_thresholding, cond_scale, init_images, skip_steps), disable = not use_tqdm):
for unet_number, unet, channel, image_size, frame_dims, noise_scheduler, pred_objective, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, all_frame_dims, self.noise_schedulers, self.pred_objectives, self.dynamic_thresholding, cond_scale, init_images, skip_steps), disable = not use_tqdm):

if unet_number < start_at_unet_number:
continue
@@ -2274,16 +2300,18 @@
lowres_cond_img = lowres_noise_times = None
shape = (batch_size, channel, *frame_dims, image_size, image_size)

resize_kwargs = dict(target_frames = frame_dims[0]) if self.is_video else dict()

if unet.lowres_cond:
lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)

lowres_cond_img = self.resize_to(img, image_size)
lowres_cond_img = self.resize_to(img, image_size, **resize_kwargs)

lowres_cond_img = self.normalize_img(lowres_cond_img)
lowres_cond_img, *_ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))

if exists(unet_init_images):
unet_init_images = self.resize_to(unet_init_images, image_size)
unet_init_images = self.resize_to(unet_init_images, image_size, **resize_kwargs)

shape = (batch_size, self.channels, *frame_dims, image_size, image_size)

@@ -2480,7 +2508,7 @@ def forward(
images = cast_uint8_images_to_float(images)
cond_images = maybe(cast_uint8_images_to_float)(cond_images)

assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'
assert images.dtype == torch.float, f'images tensor needs to be floats but {images.dtype} dtype found instead'

unet_index = unet_number - 1

@@ -2500,7 +2528,13 @@
check_shape(images, 'b c ...', c = self.channels)
assert h >= target_image_size and w >= target_image_size

frames = images.shape[2] if is_video else None
all_frame_dims = tuple(el[0] for el in calc_all_frame_dims(self.temporal_downsample_factor, frames)) if is_video else None
ignore_time = kwargs.get('ignore_time', False)

target_frame_size = all_frame_dims[unet_index] if is_video and not ignore_time else None
prev_frame_size = all_frame_dims[unet_index - 1] if is_video and not ignore_time and unet_index > 0 else None
frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()

times = noise_scheduler.sample_random_times(b, device = device)

@@ -2523,15 +2557,15 @@

lowres_cond_img = lowres_aug_times = None
if exists(prev_image_size):
lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range)
lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range)
lowres_cond_img = self.resize_to(images, prev_image_size, **frames_to_resize_kwargs(prev_frame_size), clamp_range = self.input_image_range)
lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, **frames_to_resize_kwargs(target_frame_size), clamp_range = self.input_image_range)

if self.per_sample_random_aug_noise_level:
lowres_aug_times = self.lowres_noise_schedule.sample_random_times(b, device = device)
else:
lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = b)

images = self.resize_to(images, target_image_size)
images = self.resize_to(images, target_image_size, **frames_to_resize_kwargs(target_frame_size))

return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, p2_loss_weight_gamma = p2_loss_weight_gamma, random_crop_size = random_crop_size, **kwargs)
12 changes: 6 additions & 6 deletions imagen_pytorch/imagen_video.py
@@ -137,23 +137,23 @@ def masked_mean(t, *, dim, mask = None):
def resize_video_to(
video,
target_image_size,
target_frames = None,
clamp_range = None
):
orig_video_size = video.shape[-1]

if orig_video_size == target_image_size:
return video


frames = video.shape[2]
video = rearrange(video, 'b c f h w -> (b f) c h w')
target_frames = default(target_frames, frames)

target_shape = (target_frames, target_image_size, target_image_size)

out = F.interpolate(video, target_image_size, mode = 'nearest')
out = F.interpolate(video, target_shape, mode = 'nearest')

if exists(clamp_range):
out = out.clamp(*clamp_range)

out = rearrange(out, '(b f) c h w -> b c f h w', f = frames)

return out
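
With the per-frame rearranges gone, the video stays five-dimensional and a single nearest-neighbor F.interpolate call resizes frames, height, and width together; this is what lets one stage's low-frame-rate output be stretched into the next stage's conditioning input. A minimal shape check, independent of the function above:

```python
import torch
import torch.nn.functional as F

video = torch.randn(1, 3, 8, 16, 16)                        # (batch, channels, frames, height, width)
out = F.interpolate(video, (16, 32, 32), mode = 'nearest')  # target (frames, height, width)
print(out.shape)                                            # torch.Size([1, 3, 16, 32, 32])
```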

@@ -1600,7 +1600,7 @@ def forward(

batch_size, frames, device, dtype = x.shape[0], x.shape[2], x.device, x.dtype

assert ignore_time or divisible_by(frames, self.total_temporal_divisor), f'number of input frames must be divisible by {self.total_temporal_divisor}'
assert ignore_time or divisible_by(frames, self.total_temporal_divisor), f'number of input frames {frames} must be divisible by {self.total_temporal_divisor}'

# add self conditioning if needed

2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
@@ -1 +1 @@
__version__ = '1.19.0'
__version__ = '1.20.0'
