Commit

integrate the pseudo 3d convs into all resnet blocks within the unet3d - mashup of imagen video + make-a-video
lucidrains committed Dec 11, 2022
1 parent a455782 commit 2c4004f
Showing 2 changed files with 51 additions and 25 deletions.
74 changes: 50 additions & 24 deletions imagen_pytorch/imagen_video/imagen_video.py
@@ -341,15 +341,16 @@ def forward(self, x, mask = None):

return latents

- # pseudo 3d conv from make-a-video
+ # main contribution from make-a-video - pseudo conv3d
+ # axial space-time convolutions, but made causal to keep in line with the design decisions of imagen-video paper

- class PseudoConv3D(nn.Module):
+ class Conv3d(nn.Module):
def __init__(
self,
dim,
- *,
- kernel_size,
+ dim_out = None,
+ kernel_size = 3,
+ *,
temporal_kernel_size = None,
**kwargs
):
@@ -358,20 +358,22 @@ def __init__(
temporal_kernel_size = default(temporal_kernel_size, kernel_size)

self.spatial_conv = nn.Conv2d(dim, dim_out, kernel_size = kernel_size, padding = kernel_size // 2)
- self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size, padding = temporal_kernel_size // 2)
+ self.temporal_conv = nn.Conv1d(dim_out, dim_out, kernel_size = temporal_kernel_size) if kernel_size > 1 else None
+ self.kernel_size = kernel_size

- nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
- nn.init.zeros_(self.temporal_conv.bias.data)
+ if exists(self.temporal_conv):
+     nn.init.dirac_(self.temporal_conv.weight.data) # initialized to be identity
+     nn.init.zeros_(self.temporal_conv.bias.data)

def forward(
self,
x,
- convolve_across_time = True
+ ignore_time = False
):
b, c, *_, h, w = x.shape

is_video = x.ndim == 5
- convolve_across_time &= is_video
+ ignore_time |= not is_video

if is_video:
x = rearrange(x, 'b c f h w -> (b f) c h w')
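Aside: the dirac initialization above makes the temporal conv an identity mapping at the start of training, so bolting it onto pretrained image weights leaves the network's behavior essentially unchanged. A quick standalone check, not from the repo (symmetric padding is used here for simplicity; the code above pads causally in forward):

import torch
from torch import nn

conv = nn.Conv1d(8, 8, kernel_size = 3, padding = 1)
nn.init.dirac_(conv.weight.data)  # delta kernel - each channel passes through unchanged
nn.init.zeros_(conv.bias.data)

x = torch.randn(2, 8, 16)  # (batch, channels, frames)
assert torch.allclose(conv(x), x, atol = 1e-6)  # identity at initialization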
@@ -381,11 +384,16 @@
if is_video:
x = rearrange(x, '(b f) c h w -> b c f h w', b = b)

- if not convolve_across_time:
+ if ignore_time or not exists(self.temporal_conv):
return x

x = rearrange(x, 'b c f h w -> (b h w) c f')

+ # causal temporal convolution - time is causal in imagen-video
+ if self.kernel_size > 1:
+     x = F.pad(x, (self.kernel_size - 1, 0))

x = self.temporal_conv(x)

x = rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)
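To make the new forward pass easier to follow, here is a minimal self-contained sketch of the same axial space-time pattern (illustrative names, not the repo's exact class): a 2d conv applied per frame, then a causally left-padded 1d conv applied per pixel location across frames.

import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

class CausalPseudoConv3d(nn.Module):
    def __init__(self, dim, kernel_size = 3):
        super().__init__()
        self.kernel_size = kernel_size
        self.spatial_conv = nn.Conv2d(dim, dim, kernel_size, padding = kernel_size // 2)
        self.temporal_conv = nn.Conv1d(dim, dim, kernel_size)  # unpadded, made causal below

    def forward(self, x):  # x: (batch, channels, frames, height, width)
        b, *_, h, w = x.shape
        x = rearrange(x, 'b c f h w -> (b f) c h w')           # fold frames into batch
        x = self.spatial_conv(x)                               # 2d conv on each frame
        x = rearrange(x, '(b f) c h w -> (b h w) c f', b = b)  # fold pixels into batch
        x = F.pad(x, (self.kernel_size - 1, 0))                # left-pad only, so time stays causal
        x = self.temporal_conv(x)                              # 1d conv across frames
        return rearrange(x, '(b h w) c f -> b c f h w', h = h, w = w)

video = torch.randn(1, 8, 4, 16, 16)  # (b, c, f, h, w)
assert CausalPseudoConv3d(8)(video).shape == video.shape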
@@ -612,17 +620,22 @@ def __init__(
super().__init__()
self.groupnorm = nn.GroupNorm(groups, dim) if norm else Identity()
self.activation = nn.SiLU()
- self.project = Conv2d(dim, dim_out, 3, padding = 1)
+ self.project = Conv3d(dim, dim_out, 3, padding = 1)

- def forward(self, x, scale_shift = None):
+ def forward(
+     self,
+     x,
+     scale_shift = None,
+     ignore_time = False
+ ):
x = self.groupnorm(x)

if exists(scale_shift):
scale, shift = scale_shift
x = x * (scale + 1) + shift

x = self.activation(x)
- return self.project(x)
+ return self.project(x, ignore_time = ignore_time)

class ResnetBlock(nn.Module):
def __init__(
@@ -671,21 +684,27 @@ def __init__(
self.res_conv = Conv2d(dim, dim_out, 1) if dim != dim_out else Identity()


- def forward(self, x, time_emb = None, cond = None):
+ def forward(
+     self,
+     x,
+     time_emb = None,
+     cond = None,
+     ignore_time = False
+ ):

scale_shift = None
if exists(self.time_mlp) and exists(time_emb):
time_emb = self.time_mlp(time_emb)
time_emb = rearrange(time_emb, 'b c -> b c 1 1 1')
scale_shift = time_emb.chunk(2, dim = 1)

- h = self.block1(x)
+ h = self.block1(x, ignore_time = ignore_time)

if exists(self.cross_attn):
assert exists(cond)
h = self.cross_attn(h, context = cond) + h

- h = self.block2(h, scale_shift = scale_shift)
+ h = self.block2(h, scale_shift = scale_shift, ignore_time = ignore_time)

h = h * self.gca(h)
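The scale_shift path in the two blocks above is FiLM-style conditioning: the time embedding is projected to twice the channel count, split into a scale and a shift, and applied channel-wise. A standalone sketch of just that arithmetic (the dimensions and the time_mlp stand-in are illustrative, not the repo's exact modules):

import torch
from torch import nn
from einops import rearrange

time_cond_dim, dim_out = 128, 64
time_mlp = nn.Linear(time_cond_dim, dim_out * 2)   # stand-in for self.time_mlp

time_emb = time_mlp(torch.randn(2, time_cond_dim))
time_emb = rearrange(time_emb, 'b c -> b c 1 1 1') # broadcast over (frames, height, width)
scale, shift = time_emb.chunk(2, dim = 1)          # (b, dim_out, 1, 1, 1) each

h = torch.randn(2, dim_out, 4, 16, 16)
h = h * (scale + 1) + shift                        # the '+ 1' keeps this an identity for zero embeddings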

@@ -1545,12 +1564,19 @@ def forward(

time_attn_bias = self.time_rel_pos_bias(frames, device = device, dtype = dtype)

+ # ignoring time in pseudo 3d resnet blocks
+ conv_kwargs = dict(
+     ignore_time = ignore_time
+ )

# initial convolution

x = self.init_conv(x)

- x = self.init_temporal_peg(x)
- x = self.init_temporal_attn(x, attn_bias = time_attn_bias)
+ if not ignore_time:
+     x = self.init_temporal_peg(x)
+     x = self.init_temporal_attn(x, attn_bias = time_attn_bias)

# init conv residual
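The flag packed into conv_kwargs is what enables joint image and video training: with ignore_time = True, a batch of one-frame "videos" exercises only the spatial weights, and the purely temporal modules (the peg and temporal attention above) are skipped entirely. A runnable toy sketch of the threading pattern (illustrative classes, not the repo's Unet3D; the temporal conv here is symmetrically padded for brevity):

import torch
from torch import nn

class AxialBlock(nn.Module):
    # stand-in for a resnet block built on the pseudo conv3d
    def __init__(self, dim):
        super().__init__()
        self.spatial = nn.Conv3d(dim, dim, (1, 3, 3), padding = (0, 1, 1))
        self.temporal = nn.Conv3d(dim, dim, (3, 1, 1), padding = (1, 0, 0))

    def forward(self, x, ignore_time = False):
        x = self.spatial(x)
        return x if ignore_time else self.temporal(x)

class TinyUnet(nn.Module):
    def __init__(self, dim, depth = 3):
        super().__init__()
        self.blocks = nn.ModuleList([AxialBlock(dim) for _ in range(depth)])

    def forward(self, x, ignore_time = False):
        conv_kwargs = dict(ignore_time = ignore_time)  # built once, as in the diff
        for block in self.blocks:
            x = block(x, **conv_kwargs)                # forwarded to every block
        return x

unet = TinyUnet(8)
video = torch.randn(1, 8, 4, 16, 16)
images = torch.randn(1, 8, 1, 16, 16)                 # still images as 1-frame videos
unet(video); unet(images, ignore_time = True)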

@@ -1651,7 +1677,7 @@ def forward(
# initial resnet block (for memory efficient unet)

if exists(self.init_resnet_block):
- x = self.init_resnet_block(x, t)
+ x = self.init_resnet_block(x, t, **conv_kwargs)

# go through the layers of the unet, down and up

@@ -1661,10 +1687,10 @@
if exists(pre_downsample):
x = pre_downsample(x)

- x = init_block(x, t, c)
+ x = init_block(x, t, c, **conv_kwargs)

for resnet_block in resnet_blocks:
- x = resnet_block(x, t)
+ x = resnet_block(x, t, **conv_kwargs)
hiddens.append(x)

x = attn_block(x, c)
@@ -1695,11 +1721,11 @@ def forward(

for init_block, resnet_blocks, attn_block, temporal_peg, temporal_attn, upsample in self.ups:
x = add_skip_connection(x)
- x = init_block(x, t, c)
+ x = init_block(x, t, c, **conv_kwargs)

for resnet_block in resnet_blocks:
x = add_skip_connection(x)
- x = resnet_block(x, t)
+ x = resnet_block(x, t, **conv_kwargs)

x = attn_block(x, c)

@@ -1720,7 +1746,7 @@ def forward(
x = torch.cat((x, init_conv_residual), dim = 1)

if exists(self.final_res_block):
- x = self.final_res_block(x, t)
+ x = self.final_res_block(x, t, **conv_kwargs)

if exists(lowres_cond_img):
x = torch.cat((x, lowres_cond_img), dim = 1)
2 changes: 1 addition & 1 deletion imagen_pytorch/version.py
@@ -1 +1 @@
- __version__ = '1.17.2'
+ __version__ = '1.18.0'
