Separate AdaLN and modulation modules
corystephenson-db committed Jul 25, 2024
1 parent 0df0095 commit ab641eb
Showing 1 changed file with 66 additions and 32 deletions.
98 changes: 66 additions & 32 deletions diffusion/models/transformer.py
@@ -41,6 +41,59 @@ def get_multidimensional_position_embeddings(position_embeddings: torch.Tensor,
return sequenced_embeddings # (B, S, F, D)


+class AdaptiveLayerNorm(nn.Module):
+    """Adaptive LayerNorm.
+
+    Scales and shifts the output of a LayerNorm using an MLP conditioned on a scalar.
+
+    Args:
+        num_features (int): Number of input features.
+    """
+
+    def __init__(self, num_features: int):
+        super().__init__()
+        self.num_features = num_features
+        # MLP for computing modulations.
+        # Initialized to zero so modulation acts as identity at initialization.
+        self.adaLN_mlp_linear = nn.Linear(self.num_features, 2 * self.num_features, bias=True)
+        nn.init.zeros_(self.adaLN_mlp_linear.weight)
+        nn.init.zeros_(self.adaLN_mlp_linear.bias)
+        self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear)
+        # LayerNorm
+        self.layernorm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6)
+
+    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        # Calculate the modulations
+        mods = self.adaLN_mlp(t).unsqueeze(1).chunk(2, dim=2)
+        # Apply the modulations
+        return modulate(self.layernorm(x), mods[0], mods[1])
+
+
+class ModulationLayer(nn.Module):
+    """Modulation layer.
+
+    Scales the input by a factor determined by a scalar input.
+
+    Args:
+        num_features (int): Number of input features.
+    """
+
+    def __init__(self, num_features: int):
+        super().__init__()
+        self.num_features = num_features
+        # MLP for computing modulation.
+        # Initialized to zero so modulation starts off at zero.
+        self.adaLN_mlp_linear = nn.Linear(self.num_features, self.num_features, bias=True)
+        nn.init.zeros_(self.adaLN_mlp_linear.weight)
+        nn.init.zeros_(self.adaLN_mlp_linear.bias)
+        self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear)
+
+    def forward(self, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
+        # Calculate the modulations
+        mods = self.adaLN_mlp(t).unsqueeze(1)
+        return x * mods


class ScalarEmbedding(nn.Module):
"""Embedding block for scalars.
@@ -121,14 +174,8 @@ class PreAttentionBlock(nn.Module):
def __init__(self, num_features: int):
super().__init__()
self.num_features = num_features

-        # AdaLN MLP for pre-attention. Initialized to zero so modulation acts as identity at initialization.
-        self.adaLN_mlp_linear = nn.Linear(self.num_features, 2 * self.num_features, bias=True)
-        nn.init.zeros_(self.adaLN_mlp_linear.weight)
-        nn.init.zeros_(self.adaLN_mlp_linear.bias)
-        self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear)
-        # Input layernorm
-        self.input_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6)
+        # Adaptive layernorm
+        self.adaptive_layernorm = AdaptiveLayerNorm(self.num_features)
# Linear layer to get q, k, and v
self.qkv = nn.Linear(self.num_features, 3 * self.num_features)
# QK layernorms. Original MMDiT used RMSNorm here.
@@ -140,10 +187,7 @@ def __init__(self, num_features: int):
nn.init.normal_(self.qkv.weight, std=0.02)

def forward(self, x: torch.Tensor, t: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-        # Calculate the modulations
-        mods = self.adaLN_mlp(t).unsqueeze(1).chunk(2, dim=2)
-        # Forward, with modulations
-        x = modulate(self.input_norm(x), mods[0], mods[1])
+        x = self.adaptive_layernorm(x, t)
# Calculate the query, key, and values all in one go
q, k, v = self.qkv(x).chunk(3, dim=-1)
q = self.q_norm(q)
@@ -196,15 +240,12 @@ def __init__(self, num_features: int, expansion_factor: int = 4):
super().__init__()
self.num_features = num_features
self.expansion_factor = expansion_factor
-        # AdaLN MLP for post-attention. Initialized to zero so modulation acts as identity at initialization.
-        self.adaLN_mlp_linear = nn.Linear(self.num_features, 4 * self.num_features, bias=True)
-        nn.init.zeros_(self.adaLN_mlp_linear.weight)
-        nn.init.zeros_(self.adaLN_mlp_linear.bias)
-        self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear)
+        # Input modulation
+        self.modulate_v = ModulationLayer(self.num_features)
# Linear layer to process v
self.linear_v = nn.Linear(self.num_features, self.num_features)
# Layernorm for the output
-        self.output_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6)
+        self.output_norm = AdaptiveLayerNorm(self.num_features)
# Transformer style MLP layers
self.linear_1 = nn.Linear(self.num_features, self.expansion_factor * self.num_features)
self.nonlinearity = nn.GELU(approximate='tanh')
@@ -214,20 +255,20 @@ def __init__(self, num_features: int, expansion_factor: int = 4):
nn.init.zeros_(self.linear_2.bias)
# Output MLP
self.output_mlp = nn.Sequential(self.linear_1, self.nonlinearity, self.linear_2)
+        # Output modulation
+        self.modulate_output = ModulationLayer(self.num_features)

def forward(self, v: torch.Tensor, x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
"""Forward takes v from self attention and the original sequence x with scalar conditioning t."""
-        # Calculate the modulations
-        mods = self.adaLN_mlp(t).unsqueeze(1).chunk(4, dim=2)
# Postprocess v with linear + gating modulation
-        y = mods[0] * self.linear_v(v)
+        y = self.modulate_v(self.linear_v(v), t)
y = x + y
# Adaptive layernorm
-        y = modulate(self.output_norm(y), mods[1], mods[2])
+        y = self.output_norm(y, t)
# Output MLP
y = self.output_mlp(y)
# Gating modulation for the output
-        y = mods[3] * y
+        y = self.modulate_output(y, t)
y = x + y
return y

@@ -353,17 +394,11 @@ def __init__(self,
self.transformer_blocks.append(
MMDiTBlock(self.num_features, self.num_heads, expansion_factor=self.expansion_factor, is_last=True))
# Output projection layer
-        self.final_norm = nn.LayerNorm(self.num_features, elementwise_affine=True, eps=1e-6)
+        self.final_norm = AdaptiveLayerNorm(self.num_features)
self.final_linear = nn.Linear(self.num_features, self.input_features)
# Init the output layer to zero
nn.init.zeros_(self.final_linear.weight)
nn.init.zeros_(self.final_linear.bias)
-        # AdaLN MLP for the output layer
-        self.adaLN_mlp_linear = nn.Linear(self.num_features, 2 * self.num_features)
-        # Init the modulations to zero. This will ensure the block acts as identity at initialization
-        nn.init.zeros_(self.adaLN_mlp_linear.weight)
-        nn.init.zeros_(self.adaLN_mlp_linear.bias)
-        self.adaLN_mlp = nn.Sequential(nn.SiLU(), self.adaLN_mlp_linear)

def fsdp_wrap_fn(self, module: nn.Module) -> bool:
if isinstance(module, MMDiTBlock):
@@ -438,7 +473,6 @@ def forward(self,
for block in self.transformer_blocks:
y, c = block(y, c, t, mask=mask)
# Pass through the output layers to get the right number of elements
-        mods = self.adaLN_mlp(t).unsqueeze(1).chunk(2, dim=2)
-        y = modulate(self.final_norm(y), mods[0], mods[1])
+        y = self.final_norm(y, t)
y = self.final_linear(y)
return y
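Since the commit only moves the existing adaLN computation into AdaptiveLayerNorm without changing it, the old inline path and the new module should agree whenever they share weights. A small, hedged equivalence check (not part of the diff; it assumes the modulate helper is exposed at module level in diffusion/models/transformer.py, as its bare use above suggests):

import torch
import torch.nn as nn

from diffusion.models.transformer import AdaptiveLayerNorm, modulate

num_features = 64
adaln = AdaptiveLayerNorm(num_features)
# Give the modulation MLP non-trivial weights so the comparison is not vacuous.
nn.init.normal_(adaln.adaLN_mlp_linear.weight, std=0.02)
nn.init.normal_(adaln.adaLN_mlp_linear.bias, std=0.02)

x = torch.randn(2, 16, num_features)
t = torch.randn(2, num_features)

# Old inline formulation, re-using the module's own layers:
mods = adaln.adaLN_mlp(t).unsqueeze(1).chunk(2, dim=2)
old = modulate(adaln.layernorm(x), mods[0], mods[1])

print(torch.allclose(adaln(x, t), old))  # True: the refactor is behavior-preserving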
