diff --git a/README.md b/README.md
index cfff128..5ffe541 100644
--- a/README.md
+++ b/README.md
@@ -695,6 +695,7 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
- [x] add v-parameterization (https://arxiv.org/abs/2202.00512) from imagen video paper, the only thing new
- [x] incorporate all learnings from make-a-video (https://makeavideo.studio/)
- [x] build out CLI tool for training, resuming training off config file
+- [x] allow for temporal interpolation at specific stages
- [ ] reread cogvideo and figure out how frame rate conditioning could be used
- [ ] bring in attention expertise for self attention layers in unet3d
@@ -716,7 +717,7 @@ Anything! It is MIT licensed. In other words, you can freely copy / paste for yo
- [ ] add textual inversion
- [ ] cleanup self conditioning to be extracted at imagen instantiation
- [ ] make sure eventual dreambooth works with imagen-video
-- [ ] allow for temporal interpolation at specific stages
+- [ ] make sure temporal interpolation works with inpainting
## Citations
diff --git a/imagen_pytorch/elucidated_imagen.py b/imagen_pytorch/elucidated_imagen.py
index 62dd583..ba63c35 100644
--- a/imagen_pytorch/elucidated_imagen.py
+++ b/imagen_pytorch/elucidated_imagen.py
@@ -29,11 +29,11 @@
default,
cast_tuple,
cast_uint8_images_to_float,
- is_float_dtype,
eval_decorator,
check_shape,
pad_tuple_to_length,
resize_image_to,
+ calc_all_frame_dims,
right_pad_dims_to,
module_device,
normalize_neg_one_to_one,
@@ -83,6 +83,7 @@ def __init__(
channels = 3,
cond_drop_prob = 0.1,
random_crop_sizes = None,
+ temporal_downsample_factor = 1,
lowres_sample_noise_level = 0.2, # in the paper, they present a new trick where they noise the lowres conditioning image, and at sample time, fix it to a certain level (0.1 or 0.3) - the unets are also made to be conditioned on this noise level
per_sample_random_aug_noise_level = False, # unclear when conditioning on augmentation noise level, whether each batch element receives a random aug noise value - turning off due to @marunine's find
condition_on_text = True,
@@ -196,6 +197,14 @@ def __init__(
self.dynamic_thresholding = cast_tuple(dynamic_thresholding, num_unets)
self.dynamic_thresholding_percentile = dynamic_thresholding_percentile
+ # temporal interpolations
+
+ temporal_downsample_factor = cast_tuple(temporal_downsample_factor, num_unets)
+ self.temporal_downsample_factor = temporal_downsample_factor
+
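+ # each unet stage is trained and sampled at (video_frames // factor) frames,
+ # so stages with smaller factors temporally upsample as well as spatially upscale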
+ assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1'
+ assert all([left >= right for left, right in zip(temporal_downsample_factor[:-1], temporal_downsample_factor[1:])]), 'temporal downsample factors must be in descending order'
+
# elucidating parameters
hparams = [
@@ -589,7 +598,12 @@ def sample(
assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'
- frame_dims = (video_frames,) if self.is_video else tuple()
+ # determine the frame dimensions, if needed
+
+ if self.is_video:
+ all_frame_dims = calc_all_frame_dims(self.temporal_downsample_factor, video_frames)
+ else:
+ all_frame_dims = (tuple(),) * num_unets
# initializing with an image or video
@@ -613,7 +627,7 @@ def sample(
# go through each unet in cascade
- for unet_number, unet, channel, image_size, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm):
+ for unet_number, unet, channel, image_size, frame_dims, unet_hparam, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps, unet_sigma_min, unet_sigma_max in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, all_frame_dims, self.hparams, self.dynamic_thresholding, cond_scale, init_images, skip_steps, sigma_min, sigma_max), disable = not use_tqdm):
if unet_number < start_at_unet_number:
continue
@@ -626,16 +640,20 @@ def sample(
shape = (batch_size, channel, *frame_dims, image_size, image_size)
+ resize_kwargs = dict()
+ if self.is_video:
+ resize_kwargs = dict(target_frames = frame_dims[0])
+
if unet.lowres_cond:
lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)
- lowres_cond_img = self.resize_to(img, image_size)
+ lowres_cond_img = self.resize_to(img, image_size, **resize_kwargs)
lowres_cond_img = self.normalize_img(lowres_cond_img)
lowres_cond_img, *_ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))
if exists(unet_init_images):
- unet_init_images = self.resize_to(unet_init_images, image_size)
+ unet_init_images = self.resize_to(unet_init_images, image_size, **resize_kwargs)
shape = (batch_size, self.channels, *frame_dims, image_size, image_size)
@@ -710,7 +728,7 @@ def forward(
images = cast_uint8_images_to_float(images)
cond_images = maybe(cast_uint8_images_to_float)(cond_images)
- assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'
+ assert images.dtype == torch.float, f'images tensor needs to be floats but {images.dtype} dtype found instead'
unet_index = unet_number - 1
@@ -725,7 +743,13 @@ def forward(
batch_size, c, *_, h, w, device, is_video = *images.shape, images.device, (images.ndim == 5)
- frames = images.shape[2] if is_video else None
+ frames = images.shape[2] if is_video else None
+ all_frame_dims = tuple((el[0] if len(el) > 0 else None) for el in calc_all_frame_dims(self.temporal_downsample_factor, frames))
+ ignore_time = kwargs.get('ignore_time', False)
+
+ target_frame_size = all_frame_dims[unet_index] if is_video and not ignore_time else None
+ prev_frame_size = all_frame_dims[unet_index - 1] if is_video and not ignore_time and unet_index > 0 else None
+ frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()
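+ # resize_image_to does not accept target_frames, so the kwarg is only passed when a frame count exists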
check_shape(images, 'b c ...', c = self.channels)
@@ -750,8 +774,8 @@ def forward(
lowres_cond_img = lowres_aug_times = None
if exists(prev_image_size):
- lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range)
- lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range)
+ lowres_cond_img = self.resize_to(images, prev_image_size, **frames_to_resize_kwargs(prev_frame_size), clamp_range = self.input_image_range)
+ lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, **frames_to_resize_kwargs(target_frame_size), clamp_range = self.input_image_range)
if self.per_sample_random_aug_noise_level:
lowres_aug_times = self.lowres_noise_schedule.sample_random_times(batch_size, device = device)
@@ -759,7 +783,7 @@ def forward(
lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = batch_size)
- images = self.resize_to(images, target_image_size)
+ images = self.resize_to(images, target_image_size, **frames_to_resize_kwargs(target_frame_size))
# normalize to [-1, 1]
diff --git a/imagen_pytorch/imagen_pytorch.py b/imagen_pytorch/imagen_pytorch.py
index dd82aa9..6fb2eb4 100644
--- a/imagen_pytorch/imagen_pytorch.py
+++ b/imagen_pytorch/imagen_pytorch.py
@@ -35,6 +35,9 @@ def exists(val):
def identity(t, *args, **kwargs):
return t
+def divisible_by(numer, denom):
+ return (numer % denom) == 0
+
def first(arr, d = None):
if len(arr) == 0:
return d
@@ -77,9 +80,6 @@ def cast_tuple(val, length = None):
return output
-def is_float_dtype(dtype):
- return any([dtype == float_dtype for float_dtype in (torch.float64, torch.float32, torch.float16, torch.bfloat16)])
-
def cast_uint8_images_to_float(images):
if not images.dtype == torch.uint8:
return images
@@ -158,6 +158,18 @@ def resize_image_to(
return out
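+
+# example: downsample_factors = (4, 2, 1) with frames = 16 yields ((4,), (8,), (16,)),
+# one frame dimension tuple per unet in the cascade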
+def calc_all_frame_dims(
+ downsample_factors: List[int],
+ frames
+):
+ if not exists(frames):
+ return ((tuple(),) * len(downsample_factors))
+
+ all_frame_dims = []
+
+ for divisor in downsample_factors:
+ assert divisible_by(frames, divisor), f'frames {frames} must be divisible by temporal downsample factor {divisor}'
+ all_frame_dims.append((frames // divisor,))
+
+ return all_frame_dims
+
# image normalization functions
# ddpms expect images to be in the range of -1 to 1
@@ -1124,7 +1136,7 @@ def __init__(
cosine_sim_attn = False,
self_cond = False,
combine_upsample_fmaps = False, # combine feature maps from all upsample blocks, used in unet squared successfully
- pixel_shuffle_upsample = True # may address checkboard artifacts
+ pixel_shuffle_upsample = True, # may address checkerboard artifacts
):
super().__init__()
@@ -1772,7 +1784,8 @@ def __init__(
p2_loss_weight_k = 1,
dynamic_thresholding = True,
dynamic_thresholding_percentile = 0.95, # unsure what this was based on perusal of paper
- only_train_unet_number = None
+ only_train_unet_number = None,
+ temporal_downsample_factor = 1
):
super().__init__()
@@ -1882,6 +1895,13 @@ def __init__(
self.right_pad_dims_to_datatype = partial(rearrange, pattern = ('b -> b 1 1 1' if not is_video else 'b -> b 1 1 1 1'))
self.resize_to = resize_video_to if is_video else resize_image_to
+ # temporal interpolation
+
+ temporal_downsample_factor = cast_tuple(temporal_downsample_factor, num_unets)
+ self.temporal_downsample_factor = temporal_downsample_factor
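+
+ # eg. temporal_downsample_factor = (4, 2, 1) means a 16 frame video is learned
+ # at 4, then 8, then all 16 frames across the three unet stages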
+ assert temporal_downsample_factor[-1] == 1, 'downsample factor of last stage must be 1'
+ assert all([left >= right for left, right in zip(temporal_downsample_factor[:-1], temporal_downsample_factor[1:])]), 'temporal downsample factors must be in descending order'
+
# cascading ddpm related stuff
lowres_conditions = tuple(map(lambda t: t.lowres_cond, self.unets))
@@ -2240,7 +2260,12 @@ def sample(
assert not (self.is_video and not exists(video_frames)), 'video_frames must be passed in on sample time if training on video'
- frame_dims = (video_frames,) if self.is_video else tuple()
+ if self.is_video:
+ all_frame_dims = calc_all_frame_dims(self.temporal_downsample_factor, video_frames)
+ else:
+ all_frame_dims = (tuple(),) * num_unets
+
+ frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()
# for initial image and skipping steps
@@ -2257,11 +2282,12 @@ def sample(
assert exists(start_image_or_video), 'starting image or video must be supplied if only doing upscaling'
prev_image_size = self.image_sizes[start_at_unet_number - 2]
- img = self.resize_to(start_image_or_video, prev_image_size)
+ prev_frame_size = all_frame_dims[start_at_unet_number - 2][0] if self.is_video else None
+ img = self.resize_to(start_image_or_video, prev_image_size, **frames_to_resize_kwargs(prev_frame_size))
# go through each unet in cascade
- for unet_number, unet, channel, image_size, noise_scheduler, pred_objective, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, self.noise_schedulers, self.pred_objectives, self.dynamic_thresholding, cond_scale, init_images, skip_steps), disable = not use_tqdm):
+ for unet_number, unet, channel, image_size, frame_dims, noise_scheduler, pred_objective, dynamic_threshold, unet_cond_scale, unet_init_images, unet_skip_steps in tqdm(zip(range(1, num_unets + 1), self.unets, self.sample_channels, self.image_sizes, all_frame_dims, self.noise_schedulers, self.pred_objectives, self.dynamic_thresholding, cond_scale, init_images, skip_steps), disable = not use_tqdm):
if unet_number < start_at_unet_number:
continue
@@ -2274,16 +2300,18 @@ def sample(
lowres_cond_img = lowres_noise_times = None
shape = (batch_size, channel, *frame_dims, image_size, image_size)
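+
+ # all resizes for this stage must target its frame count as well as its image size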
+ resize_kwargs = dict(target_frames = frame_dims[0]) if self.is_video else dict()
+
if unet.lowres_cond:
lowres_noise_times = self.lowres_noise_schedule.get_times(batch_size, lowres_sample_noise_level, device = device)
- lowres_cond_img = self.resize_to(img, image_size)
+ lowres_cond_img = self.resize_to(img, image_size, **resize_kwargs)
lowres_cond_img = self.normalize_img(lowres_cond_img)
lowres_cond_img, *_ = self.lowres_noise_schedule.q_sample(x_start = lowres_cond_img, t = lowres_noise_times, noise = torch.randn_like(lowres_cond_img))
if exists(unet_init_images):
- unet_init_images = self.resize_to(unet_init_images, image_size)
+ unet_init_images = self.resize_to(unet_init_images, image_size, **resize_kwargs)
shape = (batch_size, self.channels, *frame_dims, image_size, image_size)
@@ -2480,7 +2508,7 @@ def forward(
images = cast_uint8_images_to_float(images)
cond_images = maybe(cast_uint8_images_to_float)(cond_images)
- assert is_float_dtype(images.dtype), f'images tensor needs to be floats but {images.dtype} dtype found instead'
+ assert images.dtype == torch.float, f'images tensor needs to be floats but {images.dtype} dtype found instead'
unet_index = unet_number - 1
@@ -2500,7 +2528,13 @@ def forward(
check_shape(images, 'b c ...', c = self.channels)
assert h >= target_image_size and w >= target_image_size
- frames = images.shape[2] if is_video else None
+ frames = images.shape[2] if is_video else None
+ all_frame_dims = tuple((el[0] if len(el) > 0 else None) for el in calc_all_frame_dims(self.temporal_downsample_factor, frames))
+ ignore_time = kwargs.get('ignore_time', False)
+
+ target_frame_size = all_frame_dims[unet_index] if is_video and not ignore_time else None
+ prev_frame_size = all_frame_dims[unet_index - 1] if is_video and not ignore_time and unet_index > 0 else None
+ frames_to_resize_kwargs = lambda frames: dict(target_frames = frames) if exists(frames) else dict()
times = noise_scheduler.sample_random_times(b, device = device)
@@ -2523,8 +2557,8 @@ def forward(
lowres_cond_img = lowres_aug_times = None
if exists(prev_image_size):
- lowres_cond_img = self.resize_to(images, prev_image_size, clamp_range = self.input_image_range)
- lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, clamp_range = self.input_image_range)
+ lowres_cond_img = self.resize_to(images, prev_image_size, **frames_to_resize_kwargs(prev_frame_size), clamp_range = self.input_image_range)
+ lowres_cond_img = self.resize_to(lowres_cond_img, target_image_size, **frames_to_resize_kwargs(target_frame_size), clamp_range = self.input_image_range)
if self.per_sample_random_aug_noise_level:
lowres_aug_times = self.lowres_noise_schedule.sample_random_times(b, device = device)
@@ -2532,6 +2566,6 @@ def forward(
lowres_aug_time = self.lowres_noise_schedule.sample_random_times(1, device = device)
lowres_aug_times = repeat(lowres_aug_time, '1 -> b', b = b)
- images = self.resize_to(images, target_image_size)
+ images = self.resize_to(images, target_image_size, **frames_to_resize_kwargs(target_frame_size))
return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, p2_loss_weight_gamma = p2_loss_weight_gamma, random_crop_size = random_crop_size, **kwargs)
diff --git a/imagen_pytorch/imagen_video.py b/imagen_pytorch/imagen_video.py
index f05f04a..005e2f8 100644
--- a/imagen_pytorch/imagen_video.py
+++ b/imagen_pytorch/imagen_video.py
@@ -137,6 +137,7 @@ def masked_mean(t, *, dim, mask = None):
def resize_video_to(
video,
target_image_size,
+ target_frames = None,
clamp_range = None
):
orig_video_size = video.shape[-1]
@@ -144,16 +145,15 @@ def resize_video_to(
- if orig_video_size == target_image_size:
- return video
-
frames = video.shape[2]
- video = rearrange(video, 'b c f h w -> (b f) c h w')
+ target_frames = default(target_frames, frames)
+
+ target_shape = (target_frames, target_image_size, target_image_size)
+
+ if tuple(video.shape[-3:]) == target_shape:
+ return video
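+
+ # nearest interpolation over all three (frames, height, width) dims of the 5d video tensor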
- out = F.interpolate(video, target_image_size, mode = 'nearest')
+ out = F.interpolate(video, target_shape, mode = 'nearest')
if exists(clamp_range):
out = out.clamp(*clamp_range)
-
- out = rearrange(out, '(b f) c h w -> b c f h w', f = frames)
return out
@@ -1600,7 +1600,7 @@ def forward(
batch_size, frames, device, dtype = x.shape[0], x.shape[2], x.device, x.dtype
- assert ignore_time or divisible_by(frames, self.total_temporal_divisor), f'number of input frames must be divisible by {self.total_temporal_divisor}'
+ assert ignore_time or divisible_by(frames, self.total_temporal_divisor), f'number of input frames {frames} must be divisible by {self.total_temporal_divisor}'
# add self conditioning if needed
diff --git a/imagen_pytorch/version.py b/imagen_pytorch/version.py
index 8ac48f0..bec0738 100644
--- a/imagen_pytorch/version.py
+++ b/imagen_pytorch/version.py
@@ -1 +1 @@
-__version__ = '1.19.0'
+__version__ = '1.20.0'