
Commit

added comments
Signed-off-by: 严照东 <[email protected]>
严照东 committed Oct 7, 2023
1 parent 3ca71af commit 41e724a
Showing 1 changed file with 12 additions and 0 deletions.
flagai/model/aquila2/modeling_aquila.py (12 additions, 0 deletions)
@@ -82,6 +82,7 @@ def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int]
return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)


# Copied from transformers.models.llama.modeling_llama.LlamaRMSNorm with Llama->Aquila
class AquilaRMSNorm(nn.Module):
def __init__(self, hidden_size, eps=1e-6):
"""
@@ -98,6 +99,8 @@ def forward(self, hidden_states):

return (self.weight * hidden_states).to(input_dtype)
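Note: the hunk above shows only the tail of AquilaRMSNorm.forward. For context, a minimal sketch of the RMSNorm pattern the new comment points to, reconstructed from the upstream LlamaRMSNorm it references (the class name RMSNormSketch is illustrative and not part of this diff):

import torch
from torch import nn

class RMSNormSketch(nn.Module):
    # Root-mean-square layer norm: scale by 1/RMS(x), then by a learned weight.
    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)  # normalize in fp32 for stability
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        return (self.weight * hidden_states).to(input_dtype)  # matches the context line above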


# Copied from transformers.models.llama.modeling_llama.LlamaRotaryEmbedding with Llama->Aquila
class AquilaRotaryEmbedding(torch.nn.Module):
def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
super().__init__()
@@ -133,6 +136,7 @@ def forward(self, x, seq_len=None):
self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
)

# Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Aquila
class AquilaLinearScalingRotaryEmbedding(AquilaRotaryEmbedding):
"""AquilaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev"""

@@ -151,6 +155,7 @@ def _set_cos_sin_cache(self, seq_len, device, dtype):
self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)

# Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Aquila
class AquilaDynamicNTKScalingRotaryEmbedding(AquilaRotaryEmbedding):
"""AquilaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla"""

@@ -195,6 +200,7 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
return q_embed, k_embed
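Note: the return q_embed, k_embed context line closes apply_rotary_pos_emb. A self-contained sketch of how that rotation is typically applied in the upstream Llama code these comments reference (the _sketch suffix marks illustrative names; the cos/sin layout is assumed to match the [1, 1, seq_len, head_dim] buffers above):

import torch

def rotate_half(x):
    # Swap the two halves of the last dimension and negate the second half.
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb_sketch(q, k, cos, sin, position_ids):
    cos = cos.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)  # -> [bs, 1, seq_len, head_dim]
    sin = sin.squeeze(1).squeeze(0)[position_ids].unsqueeze(1)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed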


# Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Aquila
class AquilaMLP(nn.Module):
def __init__(self, config):
super().__init__()
@@ -241,6 +247,7 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
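Note: this reshape closes repeat_kv, which duplicates shared key/value heads so grouped-query attention can reuse the standard multi-head math. A sketch of the full helper, reconstructed from the upstream Llama implementation (the _sketch name is illustrative, not part of this diff):

import torch

def repeat_kv_sketch(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, kv_heads, seq, head_dim) -> (batch, kv_heads * n_rep, seq, head_dim)
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)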


# Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->Aquila
class AquilaAttention(nn.Module):
"""Multi-headed attention from 'Attention Is All You Need' paper"""
def __init__(self, config: AquilaConfig):
@@ -390,6 +397,7 @@ def forward(
return attn_output, attn_weights, past_key_value


# Copied from transformers.models.llama.modeling_llama.LlamaDecoderLayer with Llama->Aquila
class AquilaDecoderLayer(nn.Module):
def __init__(self, config: AquilaConfig):
super().__init__()
@@ -474,6 +482,7 @@ def forward(
"The bare Aquila Model outputting raw hidden-states without any specific head on top.",
AQUILA_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaPreTrainedModel with Llama->Aquila
class AquilaPreTrainedModel(PreTrainedModel):
config_class = AquilaConfig
base_model_prefix = "model"
@@ -565,6 +574,7 @@ def _set_gradient_checkpointing(self, module, value=False):
"The bare Aquila Model outputting raw hidden-states without any specific head on top.",
AQUILA_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaModel with LLAMA->AQUILA,Llama->Aquila
class AquilaModel(AquilaPreTrainedModel):
"""
Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`AquilaDecoderLayer`]
@@ -742,6 +752,7 @@ def custom_forward(*inputs):
attentions=all_self_attns,
)

# Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->AQUILA,Llama->Aquila
class AquilaForCausalLM(AquilaPreTrainedModel):
_tied_weights_keys = ["lm_head.weight"]

@@ -1027,6 +1038,7 @@ def predict(self, text, tokenizer=None,
""",
AQUILA_START_DOCSTRING,
)
# Copied from transformers.models.llama.modeling_llama.LlamaForSequenceClassification with LLAMA->AQUILA,Llama->Aquila
class AquilaForSequenceClassification(AquilaPreTrainedModel):
_keys_to_ignore_on_load_missing = [r"lm_head.weight"]

