
Commit

add vanilla res blocks
patil-suraj committed Jul 20, 2023
1 parent 5397da8 commit 53b15aa
Showing 2 changed files with 282 additions and 79 deletions.
231 changes: 209 additions & 22 deletions muse/modeling_transformer.py
@@ -159,7 +159,7 @@ def forward(self, x):
class ResBlock(nn.Module):
def __init__(
self,
- channels,
+ in_channels,
skip_channels=None,
kernel_size=3,
dropout=0.0,
@@ -168,30 +168,31 @@ def __init__(
add_cond_embeds=False,
cond_embed_dim=None,
use_bias=False,
**kwargs,
):
super().__init__()
self.depthwise = nn.Conv2d(
- channels + skip_channels,
- channels,
+ in_channels + skip_channels,
+ in_channels,
kernel_size=kernel_size,
padding=kernel_size // 2,
- groups=channels,
+ groups=in_channels,
bias=use_bias,
)
self.norm = Norm2D(
- channels, eps=1e-6, norm_type=norm_type, use_bias=use_bias, elementwise_affine=ln_elementwise_affine
+ in_channels, eps=1e-6, norm_type=norm_type, use_bias=use_bias, elementwise_affine=ln_elementwise_affine
)
self.channelwise = nn.Sequential(
- nn.Linear(channels, channels * 4, bias=use_bias),
+ nn.Linear(in_channels, in_channels * 4, bias=use_bias),
nn.GELU(),
- GlobalResponseNorm(channels * 4),
+ GlobalResponseNorm(in_channels * 4),
nn.Dropout(dropout),
- nn.Linear(channels * 4, channels, bias=use_bias),
+ nn.Linear(in_channels * 4, in_channels, bias=use_bias),
)

if add_cond_embeds:
self.adaLN_modulation = AdaLNModulation(
- cond_embed_dim=cond_embed_dim, hidden_size=channels, use_bias=use_bias
+ cond_embed_dim=cond_embed_dim, hidden_size=in_channels, use_bias=use_bias
)

def forward(self, x, x_skip=None, cond_embeds=None):
@@ -206,6 +207,51 @@ def forward(self, x, x_skip=None, cond_embeds=None):
return x


class ResnetBlockVanilla(nn.Module):
def __init__(self, in_channels, out_channels=None, conv_shortcut=False, dropout=0.0, use_bias=False, **kwargs):
super().__init__()
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut

self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=1e-6, affine=True)
self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=use_bias)

self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=1e-6, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=use_bias)

if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
self.conv_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=use_bias
)
else:
self.nin_shortcut = torch.nn.Conv2d(
in_channels, out_channels, kernel_size=1, stride=1, padding=0, bias=use_bias
)

def forward(self, hidden_states, **kwargs):
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = F.silu(hidden_states)
hidden_states = self.conv1(hidden_states)

hidden_states = self.norm2(hidden_states)
hidden_states = F.silu(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)

if self.in_channels != self.out_channels:
if self.use_conv_shortcut:
residual = self.conv_shortcut(residual)
else:
residual = self.nin_shortcut(residual)

return residual + hidden_states
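
For orientation, a minimal usage sketch of the new block (not part of the commit; it assumes the module is importable as muse.modeling_transformer and only illustrates shapes):

import torch
from muse.modeling_transformer import ResnetBlockVanilla

# in_channels != out_channels, so the block falls back to the 1x1 nin_shortcut projection
block = ResnetBlockVanilla(in_channels=64, out_channels=128)
x = torch.randn(2, 64, 32, 32)  # (batch, channels, height, width); channels must be divisible by the 32 GroupNorm groups
out = block(x)
print(out.shape)                # torch.Size([2, 128, 32, 32]); spatial size is preserved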


class DownsampleBlock(nn.Module):
def __init__(
self,
@@ -221,6 +267,7 @@ def __init__(
add_cond_embeds=False,
cond_embed_dim=None,
use_bias=False,
**kwargs,
):
super().__init__()
self.add_downsample = add_downsample
@@ -258,7 +305,7 @@ def __init__(

self.gradient_checkpointing = False

- def forward(self, x, x_skip=None, cond_embeds=None):
+ def forward(self, x, x_skip=None, cond_embeds=None, **kwargs):
if self.add_downsample:
x = self.downsample(x)

@@ -295,6 +342,7 @@ def __init__(
add_cond_embeds=False,
cond_embed_dim=None,
use_bias=False,
**kwargs,
):
super().__init__()
self.add_upsample = add_upsample
@@ -332,7 +380,7 @@ def __init__(

self.gradient_checkpointing = False

- def forward(self, x, x_skip=None, cond_embeds=None):
+ def forward(self, x, x_skip=None, cond_embeds=None, **kwargs):
for i, res_block in enumerate(self.res_blocks):
x_res = x_skip[0] if i == 0 and x_skip is not None else None

@@ -353,6 +401,121 @@ def custom_forward(*inputs):
return x


class DownsampleBlockVanilla(nn.Module):
def __init__(
self,
input_channels,
output_channels=None,
num_res_blocks=4,
dropout=0.0,
add_downsample=True,
use_bias=False,
**kwargs,
):
super().__init__()
self.add_downsample = add_downsample

res_blocks = []
for i in range(num_res_blocks):
in_channels = input_channels if i == 0 else output_channels
res_blocks.append(
ResnetBlockVanilla(
in_channels=in_channels, out_channels=output_channels, dropout=dropout, use_bias=use_bias
)
)
self.res_blocks = nn.ModuleList(res_blocks)

if add_downsample:
self.downsample_conv = nn.Conv2d(output_channels, output_channels, 3, stride=2, bias=use_bias)

self.gradient_checkpointing = False

def forward(self, x, **kwargs):
output_states = ()
for res_block in self.res_blocks:
if self.training and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)

return custom_forward

x = torch.utils.checkpoint.checkpoint(create_custom_forward(res_block), x)
else:
x = res_block(x)

output_states = output_states + (x,)

if self.add_downsample:
pad = (0, 1, 0, 1)
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
x = self.downsample_conv(x)
output_states = output_states + (x,)

return x, output_states
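
A similar sketch for the encoder-side block (same import assumption): its forward returns both the downsampled feature map and the tuple of intermediate hidden states that the decoder later consumes as skip connections.

import torch
from muse.modeling_transformer import DownsampleBlockVanilla

down = DownsampleBlockVanilla(input_channels=64, output_channels=128, num_res_blocks=2)
x = torch.randn(1, 64, 32, 32)
x, skips = down(x)
# the asymmetric (0, 1, 0, 1) padding plus the stride-2 conv halve the resolution: 32 -> 16
print(x.shape)     # torch.Size([1, 128, 16, 16])
print(len(skips))  # 3: one state per res block plus one more after the downsample conv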


class UpsampleBlockVanilla(nn.Module):
def __init__(
self,
input_channels,
output_channels,
skip_channels=None,
num_res_blocks=4,
dropout=0.0,
add_upsample=True,
use_bias=False,
**kwargs,
):
super().__init__()
self.add_upsample = add_upsample
res_blocks = []
for i in range(num_res_blocks):
res_skip_channels = input_channels if (i == num_res_blocks - 1) else output_channels
resnet_in_channels = skip_channels if i == 0 else output_channels

res_blocks.append(
ResnetBlockVanilla(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=output_channels,
dropout=dropout,
)
)
self.res_blocks = nn.ModuleList(res_blocks)

if add_upsample:
self.upsample_conv = nn.Conv2d(output_channels, output_channels, 3, padding=1)

self.gradient_checkpointing = False

def forward(self, x, x_skip, **kwargs):
for res_block in self.res_blocks:
# pop res hidden states
res_hidden_states = x_skip[-1]
x_skip = x_skip[:-1]
x = torch.cat([x, res_hidden_states], dim=1)
if self.training and self.gradient_checkpointing:

def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)

return custom_forward

x = torch.utils.checkpoint.checkpoint(create_custom_forward(res_block), x)
else:
x = res_block(x)

if self.add_upsample:
if x.shape[0] >= 64:
x = x.contiguous()
x = F.interpolate(x, scale_factor=2.0, mode="nearest")
x = self.upsample_conv(x)

return x
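
And a matching decoder-side sketch (hypothetical shapes, same import assumption): the block pops one skip tensor per res block from the end of x_skip, concatenates it along the channel dimension, and finishes with a 2x nearest-neighbor upsample.

import torch
from muse.modeling_transformer import UpsampleBlockVanilla

up = UpsampleBlockVanilla(input_channels=64, output_channels=64, skip_channels=64, num_res_blocks=2)
x = torch.randn(1, 64, 16, 16)
skips = (torch.randn(1, 64, 16, 16), torch.randn(1, 64, 16, 16))  # consumed back to front
out = up(x, x_skip=skips)
print(out.shape)   # torch.Size([1, 64, 32, 32]); resolution doubled by the final upsample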


# End U-ViT blocks


@@ -1186,6 +1349,7 @@ def __init__(
xavier_init_embed=True,
use_empty_embeds_for_uncond=False,
learn_uncond_embeds=False,
use_vannilla_resblock=False,
**kwargs,
):
super().__init__()
@@ -1239,14 +1403,22 @@ def __init__(
cond_embed_dim = hidden_size

# Downsample
DownBlock = DownsampleBlockVanilla if use_vannilla_resblock else DownsampleBlock
output_channels = block_out_channels[0]
self.down_blocks = nn.ModuleList([])
for i in range(len(block_out_channels)):
is_first_block = i == 0
is_final_block = i == len(block_out_channels) - 1
input_channels = output_channels
output_channels = block_out_channels[i]

if use_vannilla_resblock:
add_downsample = not is_final_block
else:
add_downsample = not is_first_block

self.down_blocks.append(
- DownsampleBlock(
+ DownBlock(
input_channels=input_channels,
output_channels=output_channels,
skip_channels=0,
@@ -1255,7 +1427,7 @@ def __init__(
dropout=hidden_dropout if i == 0 else 0.0,
norm_type=norm_type,
ln_elementwise_affine=ln_elementwise_affine,
- add_downsample=not is_first_block,
+ add_downsample=add_downsample,
add_cond_embeds=add_cond_embeds,
cond_embed_dim=cond_embed_dim,
use_bias=use_bias,
@@ -1302,20 +1474,28 @@ def __init__(
self.project_from_hidden = nn.Linear(hidden_size, block_out_channels[-1], bias=use_bias)

# Up sample
UpBlock = UpsampleBlockVanilla if use_vannilla_resblock else UpsampleBlock
reversed_block_out_channels = list(reversed(block_out_channels))
output_channels = reversed_block_out_channels[0]
self.up_blocks = nn.ModuleList([])
for i in range(len(reversed_block_out_channels)):
is_final_block = i == len(block_out_channels) - 1
- input_channel = reversed_block_out_channels[i]
- output_channels = reversed_block_out_channels[i + 1] if not is_final_block else output_channels
- prev_output_channels = input_channel if i != 0 else 0

+ if use_vannilla_resblock:
+ prev_output_channels = output_channels
+ output_channels = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+ else:
+ input_channel = reversed_block_out_channels[i]
+ output_channels = reversed_block_out_channels[i + 1] if not is_final_block else output_channels
+ prev_output_channels = input_channel if i != 0 else 0

self.up_blocks.append(
- UpsampleBlock(
+ UpBlock(
input_channels=input_channel,
skip_channels=prev_output_channels,
output_channels=output_channels,
- num_res_blocks=num_res_blocks,
+ num_res_blocks=num_res_blocks + 1 if use_vannilla_resblock else num_res_blocks,
kernel_size=3,
dropout=hidden_dropout if i == 0 else 0.0,
norm_type=norm_type,
Expand Down Expand Up @@ -1431,7 +1611,7 @@ def forward(

hidden_states = self.embed(input_ids)

- down_block_res_samples = ()
+ down_block_res_samples = (hidden_states,)
for down_block in self.down_blocks:
hidden_states, res_samples = down_block(hidden_states, cond_embeds=cond_embeds)
down_block_res_samples += res_samples
@@ -1477,9 +1657,16 @@ def custom_forward(*inputs):
hidden_states = hidden_states.reshape(batch_size, height, width, channels).permute(0, 3, 1, 2)

for i, up_block in enumerate(self.up_blocks):
- res_samples = down_block_res_samples[-self.config.num_res_blocks :]
- down_block_res_samples = down_block_res_samples[: -self.config.num_res_blocks]
- hidden_states = up_block(hidden_states, x_skip=res_samples if i > 0 else None, cond_embeds=cond_embeds)
+ num_up_blocks = len(up_block.res_blocks)
+ res_samples = down_block_res_samples[-num_up_blocks:]
+ down_block_res_samples = down_block_res_samples[:-num_up_blocks]
+
+ if self.config.use_vannilla_resblock:
+ x_skip = res_samples
+ else:
+ x_skip = res_samples if i > 0 else None
+
+ hidden_states = up_block(hidden_states, x_skip=x_skip, cond_embeds=cond_embeds)

if self.config.layer_norm_before_mlm:
hidden_states = self.layer_norm_before_mlm(hidden_states)
