mlp refact
vince62s committed Jun 5, 2024
1 parent fb2a37b commit 1d6f317
Showing 10 changed files with 128 additions and 142 deletions.
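
For downstream code, the most visible change in this commit is a rename: eole.modules.position_ffn.PositionwiseFeedForward becomes eole.modules.transformer_mlp.MLP, and the decoder-layer attribute feed_forward becomes mlp. A minimal before/after sketch of a caller, assuming only the names visible in the hunks below (TinyLayer is a hypothetical wrapper for illustration, not an eole class):

import torch.nn as nn
# old: from eole.modules.position_ffn import PositionwiseFeedForward
from eole.modules.transformer_mlp import MLP

class TinyLayer(nn.Module):
    def __init__(self, model_config, running_config=None):
        super().__init__()
        # old: self.feed_forward = PositionwiseFeedForward(model_config, running_config=running_config)
        self.mlp = MLP(model_config, running_config=running_config)

    def forward(self, x):
        # the renamed MLP returns only the transformed tensor,
        # so the caller adds the residual explicitly
        return x + self.mlp(x)
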
2 changes: 1 addition & 1 deletion docs/source/Core API/0_modules.rst
@@ -10,7 +10,7 @@ Embeddings
.. autoclass:: eole.modules.PositionalEncoding
:members:

.. autoclass:: eole.modules.position_ffn.PositionwiseFeedForward
.. autoclass:: eole.modules.transformer_mlp.MLP
:members:

Encoders
2 changes: 1 addition & 1 deletion eole/config/models.py
@@ -2,7 +2,7 @@
from pydantic import Field, field_validator, model_validator # , TypeAdapter

from eole import constants
from eole.modules.position_ffn import (
from eole.modules.transformer_mlp import (
ActivationFunction,
) # might be better defined elsewhere
from eole.config.config import Config
81 changes: 32 additions & 49 deletions eole/decoders/transformer_base.py
@@ -7,7 +7,7 @@
import torch.nn as nn
from eole.decoders.decoder import DecoderBase
from eole.modules import MultiHeadedAttention, AverageAttention
from eole.modules.position_ffn import PositionwiseFeedForward
from eole.modules.transformer_mlp import MLP
from eole.modules.moe import MoE
from eole.modules.rmsnorm import RMSNorm

@@ -23,56 +23,52 @@ def __init__(
model_config (eole.config.TransformerDecoderConfig): full decoder config
"""
super(TransformerDecoderLayerBase, self).__init__()
if model_config.layer_norm == "standard":
layernorm = nn.LayerNorm
elif model_config.layer_norm == "rms":
layernorm = RMSNorm
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)
self.parallel_residual = model_config.parallel_residual
self.shared_layer_norm = model_config.shared_layer_norm
self.dropout_p = getattr(running_config, "dropout", [0.0])[0]
self.full_context_alignment = model_config.full_context_alignment
self.alignment_heads = model_config.alignment_heads
self.sliding_window = model_config.sliding_window
self.self_attn_type = model_config.self_attn_type

if model_config.self_attn_type in ["scaled-dot", "scaled-dot-flash"]:
self.input_layernorm = layernorm(
model_config.hidden_size, eps=model_config.norm_eps
)
if self.self_attn_type in ["scaled-dot", "scaled-dot-flash"]:
self.self_attn = MultiHeadedAttention(
model_config,
running_config=running_config,
attn_type="self",
)
elif model_config.self_attn_type == "average":
elif self.self_attn_type == "average":
self.self_attn = AverageAttention(
model_config.hidden_size,
dropout=getattr(running_config, "attention_dropout", [0.0])[0],
aan_useffn=model_config.aan_useffn,
)

self.dropout = nn.Dropout(self.dropout_p)
self.post_attention_layernorm = layernorm(
model_config.hidden_size, eps=model_config.norm_eps
)
if model_config.parallel_residual and not model_config.shared_layer_norm:
self.residual_layernorm = layernorm(
model_config.hidden_size, eps=model_config.norm_eps
)
if model_config.num_experts > 0:
self.feed_forward = MoE(model_config, running_config)
self.mlp = MoE(model_config, running_config)
else:
self.feed_forward = PositionwiseFeedForward(
self.mlp = MLP(
model_config,
running_config=running_config,
)
self.parallel_residual = model_config.parallel_residual
self.shared_layer_norm = model_config.shared_layer_norm
if model_config.layer_norm == "standard":
self.layer_norm_1 = nn.LayerNorm(
model_config.hidden_size, eps=model_config.norm_eps
)
if model_config.parallel_residual and not model_config.shared_layer_norm:
self.layer_norm_res = nn.LayerNorm(
model_config.hidden_size, eps=model_config.norm_eps
)
elif model_config.layer_norm == "rms":
self.layer_norm_1 = RMSNorm(
model_config.hidden_size, eps=model_config.norm_eps
)
if model_config.parallel_residual and not model_config.shared_layer_norm:
self.layer_norm_res = RMSNorm(
model_config.hidden_size, eps=model_config.norm_eps
)
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)

self.dropout_p = getattr(running_config, "dropout", [0.0])[0]
self.dropout = nn.Dropout(self.dropout_p)
self.full_context_alignment = model_config.full_context_alignment
self.alignment_heads = model_config.alignment_heads
self.sliding_window = model_config.sliding_window
self.self_attn_type = model_config.self_attn_type

def forward(self, *args, **kwargs):
"""Extend `_forward` for (possibly) multiple decoder pass:
@@ -112,7 +108,7 @@ def forward(self, *args, **kwargs):

def update_dropout(self, dropout, attention_dropout):
self.self_attn.update_dropout(attention_dropout)
self.feed_forward.update_dropout(dropout)
self.mlp.update_dropout(dropout)
self.dropout.p = dropout

def _forward(self, *args, **kwargs):
@@ -168,19 +164,6 @@ def __init__(
):
super(TransformerDecoderBase, self).__init__()

if model_config.layer_norm == "standard":
self.layer_norm = nn.LayerNorm(
model_config.hidden_size, eps=model_config.norm_eps
)
elif model_config.layer_norm == "rms":
self.layer_norm = RMSNorm(
model_config.hidden_size, eps=model_config.norm_eps
)
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)

self.alignment_layer = model_config.alignment_layer

@classmethod
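Beyond the rename, transformer_base.py moves the norm selection to the top of __init__ and gives each norm a role-based name: layer_norm_1 becomes input_layernorm, a post_attention_layernorm is now built in the base layer, the parallel-residual norm becomes residual_layernorm, and the stack-level output norm is removed from TransformerDecoderBase (it reappears in the concrete decoders further down). A self-contained sketch of the pre-norm flow these names imply, using plain torch modules as stand-ins (hypothetical class, simplified: no masks, caching, dropout or context attention):

import torch
import torch.nn as nn

class PreNormDecoderLayerSketch(nn.Module):
    def __init__(self, hidden_size, num_heads=8, ff_size=2048, eps=1e-6):
        super().__init__()
        self.input_layernorm = nn.LayerNorm(hidden_size, eps=eps)
        self.self_attn = nn.MultiheadAttention(hidden_size, num_heads, batch_first=True)
        self.post_attention_layernorm = nn.LayerNorm(hidden_size, eps=eps)
        self.mlp = nn.Sequential(            # stand-in for eole's MLP
            nn.Linear(hidden_size, ff_size),
            nn.SiLU(),
            nn.Linear(ff_size, hidden_size),
        )

    def forward(self, x):
        normed = self.input_layernorm(x)
        attn_out, _ = self.self_attn(normed, normed, normed, need_weights=False)
        h = x + attn_out                                   # residual around attention
        # residual around the feed-forward block, added here rather than inside mlp
        return h + self.mlp(self.post_attention_layernorm(h))

x = torch.randn(2, 5, 512)
print(PreNormDecoderLayerSketch(512)(x).shape)             # torch.Size([2, 5, 512])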
49 changes: 27 additions & 22 deletions eole/decoders/transformer_decoder.py
@@ -35,23 +35,22 @@ def __init__(
model_config,
running_config=running_config,
)
self.context_attn = MultiHeadedAttention(
model_config,
running_config=running_config,
attn_type="context",
)
if model_config.layer_norm == "standard":
self.layer_norm_2 = nn.LayerNorm(
model_config.hidden_size, eps=model_config.norm_eps
)
layernorm = nn.LayerNorm
elif model_config.layer_norm == "rms":
self.layer_norm_2 = RMSNorm(
model_config.hidden_size, eps=model_config.norm_eps
)
layernorm = RMSNorm
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)
self.precontext_layernorm = layernorm(
model_config.hidden_size, eps=model_config.norm_eps
)
self.context_attn = MultiHeadedAttention(
model_config,
running_config=running_config,
attn_type="context",
)

def update_dropout(self, dropout, attention_dropout):
super(TransformerDecoderLayer, self).update_dropout(dropout, attention_dropout)
@@ -99,13 +98,14 @@ def _forward(
# mask now are (batch x 1 x tlen x s or t len)
# 1 = heads to be expanded in MHA

norm_layer_in = self.layer_norm_1(layer_in)
norm_layer_in = self.input_layernorm(layer_in)

self_attn, _ = self._forward_self_attn(
norm_layer_in, dec_mask, step, return_attn=return_attn
)
if self.dropout_p > 0:
self_attn = self.dropout(self_attn)

if self.parallel_residual:
ctx_attn, attns = self.context_attn(
enc_out,
@@ -114,23 +114,18 @@
mask=src_pad_mask,
return_attn=return_attn,
)
# feed_forward applies residual, so we remove and apply residual with un-normed
layer_out = (
self.feed_forward(norm_layer_in)
- norm_layer_in
+ layer_in
+ self_attn
+ ctx_attn
)
# mlp no longer applies the residual, so we add it here with the un-normed input
layer_out = self.mlp(norm_layer_in) + layer_in + self_attn + ctx_attn
else:
query = self_attn + layer_in
norm_query = self.layer_norm_2(query)
norm_query = self.precontext_layernorm(query)
ctx_attn, attns = self.context_attn(
enc_out, enc_out, norm_query, mask=src_pad_mask, return_attn=return_attn
)
if self.dropout_p > 0:
ctx_attn = self.dropout(ctx_attn)
layer_out = self.feed_forward(ctx_attn + query)
layer_out = self.post_attention_layernorm(ctx_attn + query)
layer_out = self.mlp(layer_out) + layer_out

return layer_out, attns
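
The parallel-residual branch shows the semantic change most clearly: the old PositionwiseFeedForward added its own input back, so the caller had to subtract norm_layer_in before adding the true residual; the new MLP returns only the transformed tensor, so the subtraction disappears. A small equivalence check of the two expressions, with a hypothetical core standing in for the feed-forward stack (only the residual handling of the old module is modelled):

import torch
import torch.nn as nn

torch.manual_seed(0)
hidden = 8
norm = nn.LayerNorm(hidden)
core = nn.Linear(hidden, hidden)                # stand-in for the FFN transformation

def old_feed_forward(normed):                   # old behaviour: adds its input back
    return core(normed) + normed

def new_mlp(normed):                            # new behaviour: transformation only
    return core(normed)

layer_in = torch.randn(2, 5, hidden)
self_attn = torch.randn(2, 5, hidden)
ctx_attn = torch.randn(2, 5, hidden)
norm_layer_in = norm(layer_in)

old_out = old_feed_forward(norm_layer_in) - norm_layer_in + layer_in + self_attn + ctx_attn
new_out = new_mlp(norm_layer_in) + layer_in + self_attn + ctx_attn
print(torch.allclose(old_out, new_out, atol=1e-5))  # True: same result, simpler expression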

@@ -154,6 +149,14 @@ def __init__(
super(TransformerDecoder, self).__init__(
model_config, running_config=running_config
)
if model_config.layer_norm == "standard":
layernorm = nn.LayerNorm
elif model_config.layer_norm == "rms":
layernorm = RMSNorm
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)

self.transformer_layers = nn.ModuleList(
[
@@ -164,6 +167,8 @@
for i in range(model_config.layers)
]
)
# This is the Decoder out layer norm
self.layer_norm = layernorm(model_config.hidden_size, eps=model_config.norm_eps)

def forward(self, emb, **kwargs):
"""Decode, possibly stepwise."""
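With the stack-level norm gone from TransformerDecoderBase, each concrete decoder now builds its own output layer_norm and, presumably as before, applies it once after the layer loop. A condensed sketch of that pattern, reusing PreNormDecoderLayerSketch and the imports from the earlier sketch (forward signature simplified; masks and stepwise state omitted):

class DecoderStackSketch(nn.Module):
    def __init__(self, num_layers, hidden_size, eps=1e-6):
        super().__init__()
        self.transformer_layers = nn.ModuleList(
            [PreNormDecoderLayerSketch(hidden_size) for _ in range(num_layers)]
        )
        # final output norm, now owned by the concrete decoder class
        self.layer_norm = nn.LayerNorm(hidden_size, eps=eps)

    def forward(self, emb):
        out = emb
        for layer in self.transformer_layers:
            out = layer(out)
        return self.layer_norm(out)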
22 changes: 16 additions & 6 deletions eole/decoders/transformer_lm_decoder.py
@@ -50,24 +50,24 @@ def _forward(self, layer_in, pad_mask, step=None, future=False, return_attn=Fals
# mask now are (batch x 1 x tlen x tlen)
# 1 = heads to be expanded in MHA

norm_layer_in = self.layer_norm_1(layer_in)
norm_layer_in = self.input_layernorm(layer_in)

attn_output, attns = self._forward_self_attn(
norm_layer_in, dec_mask, step, return_attn=return_attn
)
if self.dropout_p > 0:
attn_output = self.dropout(attn_output)
if self.parallel_residual:
# feed_forward applies residual, so we remove and apply residual with un-normed
# mlp no longer applies the residual, so we add it here with the un-normed input
if not self.shared_layer_norm:
norm_res_layer_in = self.layer_norm_res(layer_in)
norm_res_layer_in = self.residual_layernorm(layer_in)
ff_in = norm_res_layer_in
else:
ff_in = norm_layer_in
layer_out = self.feed_forward(ff_in) - ff_in + layer_in + attn_output
layer_out = self.mlp(ff_in) + layer_in + attn_output
else:
layer_out = attn_output + layer_in
layer_out = self.feed_forward(layer_out)
layer_out = self.post_attention_layernorm(attn_output + layer_in)
layer_out = self.mlp(layer_out) + layer_out

return layer_out, attns
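
In the LM layer's parallel-residual branch, the feed-forward input depends on shared_layer_norm: either reuse the normalized activations that fed attention, or apply a dedicated residual_layernorm to the raw input. A condensed, free-standing restatement of that branch under the new names (dropout omitted; residual_layernorm=None models shared_layer_norm=True):

def parallel_residual_block(layer_in, norm_layer_in, attn_output, mlp,
                            residual_layernorm=None):
    # shared_layer_norm=True  -> reuse the attention input norm
    # shared_layer_norm=False -> dedicated norm over the raw layer input
    ff_in = norm_layer_in if residual_layernorm is None else residual_layernorm(layer_in)
    # mlp carries no internal residual, so everything is added explicitly
    return mlp(ff_in) + layer_in + attn_output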

@@ -85,6 +85,14 @@ def __init__(
running_config=None,
):
super(TransformerLMDecoder, self).__init__(model_config)
if model_config.layer_norm == "standard":
layernorm = nn.LayerNorm
elif model_config.layer_norm == "rms":
layernorm = RMSNorm
else:
raise ValueError(
f"{model_config.layer_norm} layer norm type is not supported"
)
self.transformer_layers = nn.ModuleList(
[
TransformerLMDecoderLayer(
@@ -94,6 +102,8 @@
for i in range(model_config.layers)
]
)
# This is the Decoder out layer norm
self.layer_norm = layernorm(model_config.hidden_size, eps=model_config.norm_eps)

def forward(self, emb, **kwargs):
"""Decode, possibly stepwise."""
[Diffs for the remaining changed files were not loaded.]
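
The new eole/modules/transformer_mlp.py itself is among the files not loaded above, so the following is only a plausible reconstruction inferred from how the decoders call it: constructed as MLP(model_config, running_config=...), exposing update_dropout(dropout), and applying no internal residual or layer norm. The layer sizes and activation below are assumptions, not the actual eole code:

import torch.nn as nn

class MLPSketch(nn.Module):
    """Hypothetical stand-in for eole.modules.transformer_mlp.MLP."""
    def __init__(self, hidden_size, transformer_ff, activation=nn.SiLU, dropout=0.0):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, transformer_ff)
        self.down_proj = nn.Linear(transformer_ff, hidden_size)
        self.activation = activation()
        self.dropout = nn.Dropout(dropout)

    def update_dropout(self, dropout):
        self.dropout.p = dropout

    def forward(self, x):
        # transformation only: the calling decoder layer adds the residual
        return self.dropout(self.down_proj(self.activation(self.up_proj(x))))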
