diff --git a/flagai/model/aquila2/modeling_aquila.py b/flagai/model/aquila2/modeling_aquila.py
index ce6333e5..7963fa2d 100755
--- a/flagai/model/aquila2/modeling_aquila.py
+++ b/flagai/model/aquila2/modeling_aquila.py
@@ -82,6 +82,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
     return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
 
 
+# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Aquila
 class AquilaRMSNorm(nn.Module):
     def __init__(self, hidden_size, eps=1e-6):
         """
@@ -98,6 +99,8 @@ def forward(self, hidden_states):
         return (self.weight * hidden_states).to(input_dtype)
 
 
+
+# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Aquila
 class AquilaRotaryEmbedding(torch.nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
         super().__init__()
@@ -133,6 +136,7 @@ def forward(self, x, seq_len=None):
             self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
         )
 
+# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Aquila
 class AquilaLinearScalingRotaryEmbedding(AquilaRotaryEmbedding):
     """AquilaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""
 
@@ -151,6 +155,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
         self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
         self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
 
+# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Aquila
 class AquilaDynamicNTKScalingRotaryEmbedding(AquilaRotaryEmbedding):
     """AquilaRotaryEmbedding extended with Dynamic NTK scaling.
     Credits to the Reddit users /u/bloc97 and /u/emozilla"""
@@ -195,6 +200,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
     return q_embed, k_embed
 
 
+# Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Aquila
 class AquilaMLP(nn.Module):
     def __init__(self, config):
         super().__init__()
@@ -241,6 +247,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
 
 
+# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->Aquila
 class AquilaAttention(nn.Module):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
     def __init__(self, config: AquilaConfig):
@@ -390,6 +397,7 @@ def forward(
         return attn_output, attn_weights, past_key_value
 
 
+# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Aquila
 class AquilaDecoderLayer(nn.Module):
     def __init__(self, config: AquilaConfig):
         super().__init__()
@@ -474,6 +482,7 @@ def forward(
     "The bare Aquila Model outputting raw hidden-states without any specific head on top.",
     AQUILA_START_DOCSTRING,
 )
+# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Aquila
 class AquilaPreTrainedModel(PreTrainedModel):
     config_class = AquilaConfig
     base_model_prefix = "model"
@@ -565,6 +574,7 @@ def _set_gradient_checkpointing(self, module, value=False):
     "The bare Aquila Model outputting raw hidden-states without any specific head on top.",
     AQUILA_START_DOCSTRING,
 )
+# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->AQUILA,Llama->Aquila
 class AquilaModel(AquilaPreTrainedModel):
     """
     Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AquilaDecoderLayer`]
@@ -742,6 +752,7 @@ def custom_forward(*inputs):
             attentions=all_self_attns,
         )
 
+# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->AQUILA,Llama->Aquila
 class AquilaForCausalLM(AquilaPreTrainedModel):
     _tied_weights_keys = ["lm_head.weight"]
 
@@ -1027,6 +1038,7 @@ def predict(self, text, tokenizer=None,
     """,
     AQUILA_START_DOCSTRING,
 )
+# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->AQUILA,Llama->Aquila
 class AquilaForSequenceClassification(AquilaPreTrainedModel):
     _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
 
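Note: the `# Copied from transformers.models.llama.modeling_llama.<Class> with Llama->Aquila` markers added above follow the Hugging Face `transformers` convention for flagging code duplicated from another model (upstream, `utils/check_copies.py` keeps such blocks in sync with their source); whether FlagAI runs an equivalent check is not shown in this diff. For reference, the annotated `AquilaLinearScalingRotaryEmbedding` and `AquilaDynamicNTKScalingRotaryEmbedding` classes mirror the upstream Llama RoPE-scaling variants. The sketch below is illustrative only, not part of the diff, and the standalone helper name `rope_cos_sin` is made up: linear scaling stretches the position indices by `scaling_factor`, while dynamic NTK scaling enlarges the rotary `base` once the sequence length exceeds `max_position_embeddings`.

# Illustrative sketch of the two RoPE scaling rules mirrored by the annotated classes.
import torch


def rope_cos_sin(seq_len, dim, base=10000.0, scaling_factor=1.0,
                 max_position_embeddings=2048, dynamic_ntk=False):
    """Build RoPE cos/sin tables with optional linear or dynamic-NTK scaling."""
    if dynamic_ntk and seq_len > max_position_embeddings:
        # Dynamic NTK (per the upstream Llama variant): grow the rotary base as the
        # sequence exceeds the trained context window.
        base = base * (
            (scaling_factor * seq_len / max_position_embeddings) - (scaling_factor - 1)
        ) ** (dim / (dim - 2))

    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(seq_len, dtype=torch.float32)
    if not dynamic_ntk:
        # Linear scaling (credited upstream to /u/kaiokendev): stretch position indices.
        t = t / scaling_factor

    freqs = torch.outer(t, inv_freq)          # (seq_len, dim/2)
    emb = torch.cat((freqs, freqs), dim=-1)   # (seq_len, dim)
    return emb.cos(), emb.sin()


cos, sin = rope_cos_sin(seq_len=4096, dim=128, scaling_factor=2.0, dynamic_ntk=True)
print(cos.shape)  # torch.Size([4096, 128])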