add ability to ignore time in unet3d, akin to video ddpm
lucidrains committed Dec 11, 2022 · 1 parent 8243df6 · commit a455782
Showing 4 changed files with 30 additions and 13 deletions.
imagen_pytorch/elucidated_imagen.py (4 changes: 3 additions & 1 deletion)

@@ -693,7 +693,8 @@ def forward(
         text_embeds = None,
         text_masks = None,
         unet_number = None,
-        cond_images = None
+        cond_images = None,
+        **kwargs
     ):
         assert images.shape[-1] == images.shape[-2], f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}'
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
@@ -803,6 +804,7 @@ def forward(
             lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
             lowres_cond_img = lowres_cond_img_noisy,
             cond_drop_prob = self.cond_drop_prob,
+            **kwargs
         )
 
         # self conditioning - https://arxiv.org/abs/2208.04202 - training will be 25% slower
imagen_pytorch/imagen_pytorch.py (9 changes: 6 additions & 3 deletions)

@@ -2343,7 +2343,8 @@ def p_losses(
         times_next = None,
         pred_objective = 'noise',
         p2_loss_weight_gamma = 0.,
-        random_crop_size = None
+        random_crop_size = None,
+        **kwargs
     ):
         is_video = x_start.ndim == 5
 
@@ -2398,6 +2399,7 @@ def p_losses(
             lowres_noise_times = self.lowres_noise_schedule.get_condition(lowres_aug_times),
             lowres_cond_img = lowres_cond_img_noisy,
             cond_drop_prob = self.cond_drop_prob,
+            **kwargs
         )
 
         # self condition if needed
@@ -2463,7 +2465,8 @@ def forward(
         text_embeds = None,
         text_masks = None,
         unet_number = None,
-        cond_images = None
+        cond_images = None,
+        **kwargs
     ):
         assert images.shape[-1] == images.shape[-2], f'the images you pass in must be a square, but received dimensions of {images.shape[2]}, {images.shape[-1]}'
         assert not (len(self.unets) > 1 and not exists(unet_number)), f'you must specify which unet you want trained, from a range of 1 to {len(self.unets)}, if you are training cascading DDPM (multiple unets)'
@@ -2527,4 +2530,4 @@ def forward(
 
         images = self.resize_to(images, target_image_size)
 
-        return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, p2_loss_weight_gamma = p2_loss_weight_gamma, random_crop_size = random_crop_size)
+        return self.p_losses(unet, images, times, text_embeds = text_embeds, text_mask = text_masks, cond_images = cond_images, noise_scheduler = noise_scheduler, lowres_cond_img = lowres_cond_img, lowres_aug_times = lowres_aug_times, pred_objective = pred_objective, p2_loss_weight_gamma = p2_loss_weight_gamma, random_crop_size = random_crop_size, **kwargs)
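The **kwargs changes in the two files above amount to a plain keyword relay: anything extra handed to Imagen.forward (or ElucidatedImagen.forward) is passed through p_losses into the unet's own forward. A minimal sketch of that pattern, using toy stand-in classes rather than the library's own:

# Toy stand-ins (not the library's classes) illustrating the kwargs relay:
# top-level forward -> p_losses -> unet.forward, so a flag such as ignore_time
# can reach the unet without every intermediate signature naming it explicitly.

class ToyUnet:
    def forward(self, x, *, cond_drop_prob = 0., ignore_time = False):
        # a real 3d unet would skip its temporal layers when ignore_time is True
        return x if ignore_time else x + 1

class ToyImagen:
    def __init__(self, unet):
        self.unet = unet

    def p_losses(self, unet, x, *, cond_drop_prob = 0., **kwargs):
        return unet.forward(x, cond_drop_prob = cond_drop_prob, **kwargs)

    def forward(self, x, **kwargs):
        return self.p_losses(self.unet, x, **kwargs)

toy = ToyImagen(ToyUnet())
print(toy.forward(0))                      # 1 - temporal path taken
print(toy.forward(0, ignore_time = True))  # 0 - temporal path skipped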
imagen_pytorch/imagen_video/imagen_video.py (28 changes: 20 additions & 8 deletions)

@@ -430,7 +430,13 @@ def __init__(
             LayerNorm(dim)
         )
 
-    def forward(self, x, context = None, mask = None, attn_bias = None):
+    def forward(
+        self,
+        x,
+        context = None,
+        mask = None,
+        attn_bias = None
+    ):
         b, n, device = *x.shape[:2], x.device
 
         x = self.norm(x)
@@ -1505,7 +1511,8 @@ def forward(
         text_mask = None,
         cond_images = None,
         self_cond = None,
-        cond_drop_prob = 0.
+        cond_drop_prob = 0.,
+        ignore_time = False
     ):
         assert x.ndim == 5, 'input to 3d unet must have 5 dimensions (batch, channels, time, height, width)'
 
@@ -1661,8 +1668,10 @@ def forward(
                 hiddens.append(x)
 
             x = attn_block(x, c)
-            x = temporal_peg(x)
-            x = temporal_attn(x, attn_bias = time_attn_bias)
+
+            if not ignore_time:
+                x = temporal_peg(x)
+                x = temporal_attn(x, attn_bias = time_attn_bias)
 
             hiddens.append(x)
 
@@ -1674,8 +1683,9 @@ def forward(
         if exists(self.mid_attn):
             x = self.mid_attn(x)
 
-        x = self.mid_temporal_peg(x)
-        x = self.mid_temporal_attn(x, attn_bias = time_attn_bias)
+        if not ignore_time:
+            x = self.mid_temporal_peg(x)
+            x = self.mid_temporal_attn(x, attn_bias = time_attn_bias)
 
         x = self.mid_block2(x, t, c)
 
@@ -1692,8 +1702,10 @@ def forward(
                 x = resnet_block(x, t)
 
             x = attn_block(x, c)
-            x = temporal_peg(x)
-            x = temporal_attn(x, attn_bias = time_attn_bias)
+
+            if not ignore_time:
+                x = temporal_peg(x)
+                x = temporal_attn(x, attn_bias = time_attn_bias)
 
             up_hiddens.append(x.contiguous())
             x = upsample(x)
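With ignore_time = True, the 3D unet leaves its temporal PEG and temporal attention blocks untouched and effectively runs as a per-frame 2D model, which is the "akin to video ddpm" idea in the commit title. A hedged usage sketch; the Unet3D and ElucidatedImagen constructor arguments below are illustrative placeholders, not a tested configuration from this repository:

import torch
from imagen_pytorch import Unet3D, ElucidatedImagen

# illustrative, pared-down setup (placeholder hyperparameters)
unet = Unet3D(dim = 32, dim_mults = (1, 2, 4))

imagen = ElucidatedImagen(
    unets = (unet,),
    image_sizes = (32,),
    condition_on_text = False  # assumption: unconditional, to keep the sketch short
)

# dummy batch of videos: (batch, channels, frames, height, width)
videos = torch.randn(2, 3, 4, 32, 32)

# the new **kwargs pass-through carries ignore_time down to Unet3D.forward,
# where the temporal layers are skipped
loss = imagen(videos, unet_number = 1, ignore_time = True)
loss.backward()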
imagen_pytorch/version.py (2 changes: 1 addition & 1 deletion)

@@ -1 +1 @@
-__version__ = '1.17.1'
+__version__ = '1.17.2'
