From 1d6f317a2052c5bf33407101602e477e99961454 Mon Sep 17 00:00:00 2001 From: vince62s Date: Wed, 5 Jun 2024 14:26:37 +0200 Subject: [PATCH] mlp refact --- docs/source/Core API/0_modules.rst | 2 +- eole/config/models.py | 2 +- eole/decoders/transformer_base.py | 81 +++++++----------- eole/decoders/transformer_decoder.py | 49 ++++++----- eole/decoders/transformer_lm_decoder.py | 22 +++-- eole/encoders/transformer.py | 52 +++++------ eole/modules/average_attn.py | 4 +- eole/modules/moe.py | 2 +- .../{position_ffn.py => transformer_mlp.py} | 56 +++++------- eole/tests/test_model_lm/model.00.safetensors | Bin 10135212 -> 10135220 bytes 10 files changed, 128 insertions(+), 142 deletions(-) rename eole/modules/{position_ffn.py => transformer_mlp.py} (63%) diff --git a/docs/source/Core API/0_modules.rst b/docs/source/Core API/0_modules.rst index abfa213f..3904fb8f 100644 --- a/docs/source/Core API/0_modules.rst +++ b/docs/source/Core API/0_modules.rst @@ -10,7 +10,7 @@ Embeddings .. autoclass:: eole.modules.PositionalEncoding :members: -.. autoclass:: eole.modules.position_ffn.PositionwiseFeedForward +.. autoclass:: eole.modules.transformer_mlp.MLP :members: Encoders diff --git a/eole/config/models.py b/eole/config/models.py index c3b94b1c..ec7f28dd 100644 --- a/eole/config/models.py +++ b/eole/config/models.py @@ -2,7 +2,7 @@ from pydantic import Field, field_validator, model_validator # , TypeAdapter from eole import constants -from eole.modules.position_ffn import ( +from eole.modules.transformer_mlp import ( ActivationFunction, ) # might be better defined elsewhere from eole.config.config import Config diff --git a/eole/decoders/transformer_base.py b/eole/decoders/transformer_base.py index 13d3f94f..a09a4aa5 100644 --- a/eole/decoders/transformer_base.py +++ b/eole/decoders/transformer_base.py @@ -7,7 +7,7 @@ import torch.nn as nn from eole.decoders.decoder import DecoderBase from eole.modules import MultiHeadedAttention, AverageAttention -from eole.modules.position_ffn import PositionwiseFeedForward +from eole.modules.transformer_mlp import MLP from eole.modules.moe import MoE from eole.modules.rmsnorm import RMSNorm @@ -23,56 +23,52 @@ def __init__( model_config (eole.config.TransformerDecoderConfig): full decoder config """ super(TransformerDecoderLayerBase, self).__init__() + if model_config.layer_norm == "standard": + layernorm = nn.LayerNorm + elif model_config.layer_norm == "rms": + layernorm = RMSNorm + else: + raise ValueError( + f"{model_config.layer_norm} layer norm type is not supported" + ) + self.parallel_residual = model_config.parallel_residual + self.shared_layer_norm = model_config.shared_layer_norm + self.dropout_p = getattr(running_config, "dropout", [0.0])[0] + self.full_context_alignment = model_config.full_context_alignment + self.alignment_heads = model_config.alignment_heads + self.sliding_window = model_config.sliding_window + self.self_attn_type = model_config.self_attn_type - if model_config.self_attn_type in ["scaled-dot", "scaled-dot-flash"]: + self.input_layernorm = layernorm( + model_config.hidden_size, eps=model_config.norm_eps + ) + if self.self_attn_type in ["scaled-dot", "scaled-dot-flash"]: self.self_attn = MultiHeadedAttention( model_config, running_config=running_config, attn_type="self", ) - elif model_config.self_attn_type == "average": + elif self.self_attn_type == "average": self.self_attn = AverageAttention( model_config.hidden_size, dropout=getattr(running_config, "attention_dropout", [0.0])[0], aan_useffn=model_config.aan_useffn, ) - + self.dropout 
= nn.Dropout(self.dropout_p) + self.post_attention_layernorm = layernorm( + model_config.hidden_size, eps=model_config.norm_eps + ) + if model_config.parallel_residual and not model_config.shared_layer_norm: + self.residual_layernorm = layernorm( + model_config.hidden_size, eps=model_config.norm_eps + ) if model_config.num_experts > 0: - self.feed_forward = MoE(model_config, running_config) + self.mlp = MoE(model_config, running_config) else: - self.feed_forward = PositionwiseFeedForward( + self.mlp = MLP( model_config, running_config=running_config, ) - self.parallel_residual = model_config.parallel_residual - self.shared_layer_norm = model_config.shared_layer_norm - if model_config.layer_norm == "standard": - self.layer_norm_1 = nn.LayerNorm( - model_config.hidden_size, eps=model_config.norm_eps - ) - if model_config.parallel_residual and not model_config.shared_layer_norm: - self.layer_norm_res = nn.LayerNorm( - model_config.hidden_size, eps=model_config.norm_eps - ) - elif model_config.layer_norm == "rms": - self.layer_norm_1 = RMSNorm( - model_config.hidden_size, eps=model_config.norm_eps - ) - if model_config.parallel_residual and not model_config.shared_layer_norm: - self.layer_norm_res = RMSNorm( - model_config.hidden_size, eps=model_config.norm_eps - ) - else: - raise ValueError( - f"{model_config.layer_norm} layer norm type is not supported" - ) - - self.dropout_p = getattr(running_config, "dropout", [0.0])[0] - self.dropout = nn.Dropout(self.dropout_p) - self.full_context_alignment = model_config.full_context_alignment - self.alignment_heads = model_config.alignment_heads - self.sliding_window = model_config.sliding_window - self.self_attn_type = model_config.self_attn_type def forward(self, *args, **kwargs): """Extend `_forward` for (possibly) multiple decoder pass: @@ -112,7 +108,7 @@ def forward(self, *args, **kwargs): def update_dropout(self, dropout, attention_dropout): self.self_attn.update_dropout(attention_dropout) - self.feed_forward.update_dropout(dropout) + self.mlp.update_dropout(dropout) self.dropout.p = dropout def _forward(self, *args, **kwargs): @@ -168,19 +164,6 @@ def __init__( ): super(TransformerDecoderBase, self).__init__() - if model_config.layer_norm == "standard": - self.layer_norm = nn.LayerNorm( - model_config.hidden_size, eps=model_config.norm_eps - ) - elif model_config.layer_norm == "rms": - self.layer_norm = RMSNorm( - model_config.hidden_size, eps=model_config.norm_eps - ) - else: - raise ValueError( - f"{model_config.layer_norm} layer norm type is not supported" - ) - self.alignment_layer = model_config.alignment_layer @classmethod diff --git a/eole/decoders/transformer_decoder.py b/eole/decoders/transformer_decoder.py index 5a5c98e0..68bee275 100644 --- a/eole/decoders/transformer_decoder.py +++ b/eole/decoders/transformer_decoder.py @@ -35,23 +35,22 @@ def __init__( model_config, running_config=running_config, ) - self.context_attn = MultiHeadedAttention( - model_config, - running_config=running_config, - attn_type="context", - ) if model_config.layer_norm == "standard": - self.layer_norm_2 = nn.LayerNorm( - model_config.hidden_size, eps=model_config.norm_eps - ) + layernorm = nn.LayerNorm elif model_config.layer_norm == "rms": - self.layer_norm_2 = RMSNorm( - model_config.hidden_size, eps=model_config.norm_eps - ) + layernorm = RMSNorm else: raise ValueError( f"{model_config.layer_norm} layer norm type is not supported" ) + self.precontext_layernorm = layernorm( + model_config.hidden_size, eps=model_config.norm_eps + ) + self.context_attn = 
MultiHeadedAttention( + model_config, + running_config=running_config, + attn_type="context", + ) def update_dropout(self, dropout, attention_dropout): super(TransformerDecoderLayer, self).update_dropout(dropout, attention_dropout) @@ -99,13 +98,14 @@ def _forward( # mask now are (batch x 1 x tlen x s or t len) # 1 = heads to be expanded in MHA - norm_layer_in = self.layer_norm_1(layer_in) + norm_layer_in = self.input_layernorm(layer_in) self_attn, _ = self._forward_self_attn( norm_layer_in, dec_mask, step, return_attn=return_attn ) if self.dropout_p > 0: self_attn = self.dropout(self_attn) + if self.parallel_residual: ctx_attn, attns = self.context_attn( enc_out, @@ -114,23 +114,18 @@ def _forward( mask=src_pad_mask, return_attn=return_attn, ) - # feed_forward applies residual, so we remove and apply residual with un-normed - layer_out = ( - self.feed_forward(norm_layer_in) - - norm_layer_in - + layer_in - + self_attn - + ctx_attn - ) + # we apply residual with un-normed + layer_out = self.mlp(norm_layer_in) + layer_in + self_attn + ctx_attn else: query = self_attn + layer_in - norm_query = self.layer_norm_2(query) + norm_query = self.precontext_layernorm(query) ctx_attn, attns = self.context_attn( enc_out, enc_out, norm_query, mask=src_pad_mask, return_attn=return_attn ) if self.dropout_p > 0: ctx_attn = self.dropout(ctx_attn) - layer_out = self.feed_forward(ctx_attn + query) + layer_out = self.post_attention_layernorm(ctx_attn + query) + layer_out = self.mlp(layer_out) + layer_out return layer_out, attns @@ -154,6 +149,14 @@ def __init__( super(TransformerDecoder, self).__init__( model_config, running_config=running_config ) + if model_config.layer_norm == "standard": + layernorm = nn.LayerNorm + elif model_config.layer_norm == "rms": + layernorm = RMSNorm + else: + raise ValueError( + f"{model_config.layer_norm} layer norm type is not supported" + ) self.transformer_layers = nn.ModuleList( [ @@ -164,6 +167,8 @@ def __init__( for i in range(model_config.layers) ] ) + # This is the Decoder out layer norm + self.layer_norm = layernorm(model_config.hidden_size, eps=model_config.norm_eps) def forward(self, emb, **kwargs): """Decode, possibly stepwise.""" diff --git a/eole/decoders/transformer_lm_decoder.py b/eole/decoders/transformer_lm_decoder.py index a6f1f576..fe784646 100644 --- a/eole/decoders/transformer_lm_decoder.py +++ b/eole/decoders/transformer_lm_decoder.py @@ -50,7 +50,7 @@ def _forward(self, layer_in, pad_mask, step=None, future=False, return_attn=Fals # mask now are (batch x 1 x tlen x tlen) # 1 = heads to be expanded in MHA - norm_layer_in = self.layer_norm_1(layer_in) + norm_layer_in = self.input_layernorm(layer_in) attn_output, attns = self._forward_self_attn( norm_layer_in, dec_mask, step, return_attn=return_attn @@ -58,16 +58,16 @@ def _forward(self, layer_in, pad_mask, step=None, future=False, return_attn=Fals if self.dropout_p > 0: attn_output = self.dropout(attn_output) if self.parallel_residual: - # feed_forward applies residual, so we remove and apply residual with un-normed + # we apply residual with un-normed if not self.shared_layer_norm: - norm_res_layer_in = self.layer_norm_res(layer_in) + norm_res_layer_in = self.residual_layernorm(layer_in) ff_in = norm_res_layer_in else: ff_in = norm_layer_in - layer_out = self.feed_forward(ff_in) - ff_in + layer_in + attn_output + layer_out = self.mlp(ff_in) + layer_in + attn_output else: - layer_out = attn_output + layer_in - layer_out = self.feed_forward(layer_out) + layer_out = self.post_attention_layernorm(attn_output + 
layer_in) + layer_out = self.mlp(layer_out) + layer_out return layer_out, attns @@ -85,6 +85,14 @@ def __init__( running_config=None, ): super(TransformerLMDecoder, self).__init__(model_config) + if model_config.layer_norm == "standard": + layernorm = nn.LayerNorm + elif model_config.layer_norm == "rms": + layernorm = RMSNorm + else: + raise ValueError( + f"{model_config.layer_norm} layer norm type is not supported" + ) self.transformer_layers = nn.ModuleList( [ TransformerLMDecoderLayer( @@ -94,6 +102,8 @@ def __init__( for i in range(model_config.layers) ] ) + # This is the Decoder out layer norm + self.layer_norm = layernorm(model_config.hidden_size, eps=model_config.norm_eps) def forward(self, emb, **kwargs): """Decode, possibly stepwise.""" diff --git a/eole/encoders/transformer.py b/eole/encoders/transformer.py index 265ae03e..7599d46a 100644 --- a/eole/encoders/transformer.py +++ b/eole/encoders/transformer.py @@ -6,7 +6,7 @@ from eole.encoders.encoder import EncoderBase from eole.modules import MultiHeadedAttention -from eole.modules.position_ffn import PositionwiseFeedForward +from eole.modules.transformer_mlp import MLP from eole.modules.rmsnorm import RMSNorm @@ -26,32 +26,34 @@ def __init__( running_config=None, ): super(TransformerEncoderLayer, self).__init__() + if model_config.layer_norm == "standard": + layernorm = nn.LayerNorm + elif model_config.layer_norm == "rms": + layernorm = RMSNorm + else: + raise ValueError( + f"{model_config.layer_norm} layer norm type is not supported" + ) + self.parallel_residual = model_config.parallel_residual + self.dropout_p = getattr(running_config, "dropout", [0.0])[0] + self.input_layernorm = layernorm( + model_config.hidden_size, eps=model_config.norm_eps + ) self.self_attn = MultiHeadedAttention( model_config, running_config=running_config, is_decoder=False, attn_type="self", ) - self.feed_forward = PositionwiseFeedForward( + self.dropout = nn.Dropout(self.dropout_p) + self.post_attention_layernorm = layernorm( + model_config.hidden_size, eps=model_config.norm_eps + ) + self.mlp = MLP( model_config, running_config=running_config, ) - self.parallel_residual = model_config.parallel_residual - if model_config.layer_norm == "standard": - self.layer_norm = nn.LayerNorm( - model_config.hidden_size, eps=model_config.norm_eps - ) - elif model_config.layer_norm == "rms": - self.layer_norm = RMSNorm( - model_config.hidden_size, eps=model_config.norm_eps - ) - else: - raise ValueError( - f"{model_config.layer_norm} layer norm type is not supported" - ) - self.dropout_p = getattr(running_config, "dropout", [0.0])[0] - self.dropout = nn.Dropout(self.dropout_p) def forward(self, layer_in, mask): """ @@ -63,26 +65,25 @@ def forward(self, layer_in, mask): (FloatTensor): * layer_out ``(batch_size, src_len, model_dim)`` """ - norm_layer_in = self.layer_norm(layer_in) + norm_layer_in = self.input_layernorm(layer_in) context, _ = self.self_attn( norm_layer_in, norm_layer_in, norm_layer_in, mask=mask ) if self.dropout_p > 0: context = self.dropout(context) if self.parallel_residual: - # feed_forward applies residual, so we remove and apply residual with un-normed - layer_out = ( - self.feed_forward(norm_layer_in) - norm_layer_in + layer_in + context - ) + # apply mlp and add residual with un-normed + layer_out = self.mlp(norm_layer_in) + layer_in + context else: - layer_out = context + layer_in - layer_out = self.feed_forward(layer_out) + # apply post attention norm and add residual after mlp + layer_out = self.post_attention_layernorm(context + layer_in) + 
layer_in)
layer_out = self.mlp(layer_out) + layer_out return layer_out def update_dropout(self, dropout, attention_dropout): self.self_attn.update_dropout(attention_dropout) - self.feed_forward.update_dropout(dropout) + self.mlp.update_dropout(dropout) self.dropout.p = dropout @@ -119,6 +120,7 @@ def __init__( for i in range(model_config.layers) ] ) + # This is the Encoder out layer norm if model_config.layer_norm == "standard": self.layer_norm = nn.LayerNorm( model_config.hidden_size, eps=model_config.norm_eps diff --git a/eole/modules/average_attn.py b/eole/modules/average_attn.py index 770d1310..e2464f67 100644 --- a/eole/modules/average_attn.py +++ b/eole/modules/average_attn.py @@ -5,8 +5,8 @@ import torch.nn as nn from torch import Tensor from typing import Optional -from eole.modules.position_ffn import PositionwiseFeedForward -from eole.modules.position_ffn import ActivationFunction +from eole.modules.transformer_mlp import MLP +from eole.modules.transformer_mlp import ActivationFunction def cumulative_average_mask( diff --git a/eole/modules/moe.py b/eole/modules/moe.py index d67b0067..6f45227a 100644 --- a/eole/modules/moe.py +++ b/eole/modules/moe.py @@ -1,7 +1,7 @@ """MoE mixture of experts".""" import torch import torch.nn as nn -from eole.modules.position_ffn import PositionwiseFeedForward +from eole.modules.transformer_mlp import MLP from torch.distributed import all_reduce diff --git a/eole/modules/position_ffn.py b/eole/modules/transformer_mlp.py similarity index 63% rename from eole/modules/position_ffn.py rename to eole/modules/transformer_mlp.py index ccc7acb7..51acc42c 100644 --- a/eole/modules/position_ffn.py +++ b/eole/modules/transformer_mlp.py @@ -1,4 +1,4 @@ -"""Position feed-forward network from "Attention is All You Need".""" +"""MLP network from "Attention is All You Need".""" import torch.nn as nn import torch.nn.functional as F @@ -14,6 +14,7 @@ class ActivationFunction(str, Enum): gelu = "gelu" silu = "silu" gated_gelu = "gated-gelu" + gated_silu = "gated-silu" # for silu, see: https://arxiv.org/pdf/2002.05202.pdf @@ -22,11 +23,12 @@ class ActivationFunction(str, Enum): ActivationFunction.gelu: F.gelu, ActivationFunction.silu: F.silu, ActivationFunction.gated_gelu: F.gelu, + ActivationFunction.gated_silu: F.silu, } -class PositionwiseFeedForward(nn.Module): - """A two-layer Feed-Forward-Network with residual layer norm. +class MLP(nn.Module): + """A two/three-layer Feed-Forward-Network. 
Args: model_config: eole.config.models.ModelConfig object @@ -39,51 +41,39 @@ def __init__( running_config=None, ): self.parallel_gpu = running_config.parallel_gpu - super(PositionwiseFeedForward, self).__init__() + super(MLP, self).__init__() assert ( model_config.transformer_ff % self.parallel_gpu == 0 ), "Model intermediate ffn size must be divisible by the number of partitions" - self.w_1 = skip_init( + self.gate_up_proj = skip_init( nn.Linear, in_features=model_config.hidden_size, out_features=model_config.transformer_ff // self.parallel_gpu, bias=model_config.add_ffnbias, ) - self.w_2 = skip_init( + self.down_proj = skip_init( nn.Linear, in_features=model_config.transformer_ff // self.parallel_gpu, out_features=model_config.hidden_size, bias=model_config.add_ffnbias, ) - if model_config.layer_norm == "standard" and not model_config.parallel_residual: - self.layer_norm = nn.LayerNorm( - model_config.hidden_size, eps=model_config.norm_eps - ) - elif model_config.layer_norm == "rms" and not model_config.parallel_residual: - self.layer_norm = RMSNorm( - model_config.hidden_size, eps=model_config.norm_eps - ) - elif not model_config.parallel_residual: - raise ValueError( - f"{model_config.layer_norm} layer norm type is not supported" - ) - self.parallel_residual = model_config.parallel_residual + self.dropout_p = getattr(running_config, "dropout", [0.0])[0] self.dropout_1 = nn.Dropout(self.dropout_p) self.dropout_2 = nn.Dropout(self.dropout_p) self.activation = ACTIVATION_FUNCTIONS[model_config.pos_ffn_activation_fn] if ( - model_config.pos_ffn_activation_fn == "silu" + model_config.pos_ffn_activation_fn == "gated-silu" or model_config.pos_ffn_activation_fn == "gated-gelu" ): - self.w_3 = skip_init( + self.up_proj = skip_init( nn.Linear, in_features=model_config.hidden_size, out_features=model_config.transformer_ff // self.parallel_gpu, bias=model_config.add_ffnbias, ) else: - self.w_3 = None + self.up_proj = None self.maybe_ckpt = ( checkpoint if "ffn" in getattr(running_config, "use_ckpting", []) @@ -99,24 +89,20 @@ def forward(self, x): Returns: (FloatTensor): Output ``(batch_size, input_len, model_dim)``. 
""" - if not self.parallel_residual: - norm_x = self.layer_norm(x) - else: - norm_x = x.clone() - inter = self.maybe_ckpt(self.w_1, norm_x) - inter = self.activation(inter) - if self.w_3 is not None: - inter.mul_(self.maybe_ckpt(self.w_3, norm_x)) + mlp_out = self.maybe_ckpt(self.gate_up_proj, x) + mlp_out = self.activation(mlp_out) + if self.up_proj is not None: + mlp_out.mul_(self.maybe_ckpt(self.up_proj, x)) if self.dropout_p > 0: - inter = self.dropout_1(inter) - inter = self.maybe_ckpt(self.w_2, inter) + mlp_out = self.dropout_1(mlp_out) + mlp_out = self.maybe_ckpt(self.down_proj, mlp_out) if self.dropout_p > 0: - inter = self.dropout_2(inter) + mlp_out = self.dropout_2(mlp_out) if self.parallel_gpu > 1: - all_reduce(inter) + all_reduce(mlp_out) - return inter + x + return mlp_out def update_dropout(self, dropout): self.dropout_1.p = dropout diff --git a/eole/tests/test_model_lm/model.00.safetensors b/eole/tests/test_model_lm/model.00.safetensors index 73b35b6ce52d3e40388cd0b14be2cb5771311649..416d01d2cba94c63d1ddf4a3b50372ee44318e37 100644 GIT binary patch delta 1186 zcmajb$xjnu9LDiXr$xFTr4}fnwpGERrA~)3EvP82xbOQ?t5gJ`t)+;lsECLwT3`2* ziHBaimBg!F$Xxs*Og#Ayc=cO|2?*-plg#`6o_X^=@5iq{jN_2>tTbWmOr-iU(eBtl zJS~*suqcy!rPKBj$#jp%n!Vz^>=YwXT}e-OD$t(nPefDcWS1C{Lq&=j@&{F6w|PsI zz|3*sM@4p7f@(zZYg#0zi7&F(UI=x>GVy3%>i02j(PA#QOh)`#I1&9n3SFcMI1ID}ms;GHy6-EZrEXmmtys0%S`~N2*ZrbN z518n!CSE;xGZ_;91W*16p8V3N5vqqzGV^<%dEfV$4`05^?;(B4NT02}i9}DVcQBcX zCwnp#ox&`63TLq{F(kFx{?7SRv4Dt{I>lRq!xT`pkXKVfA|Sh`r*eu*)ER5c1j63- zj<&GKSR3rBzi2y+?Shv4dWcjwq>8apSMhF|$SFB#!ah(2b4Qgw5E3t=NX`*nyqch27YLz1WBS zIDmsVgcuGZjw3jV9wg9pfzIEho}Lq7&Eh|@TOB+lX-hHxG!T);3c;u0=n1Xpks wqZq?AjN>|P;3m?zh1UVhX)ufC1^1Y!8YVE_OC