Refactor code for better readability and maintainability
pengzhangzhi committed Dec 6, 2024
1 parent 9f40f92 commit d921c2b
Showing 7 changed files with 71 additions and 111 deletions.
4 changes: 1 addition & 3 deletions README.md
@@ -99,7 +99,7 @@ from faesm.progen2 import ProGenForCausalLM
from transformers import AutoTokenizer
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Available models from HF: ["jinyuan22/ProGen2-small", "jinyuan22/ProGen2-base", "jinyuan22/ProGen2-xlarge"]
model = ProGenForCausalLM.from_pretrained("jinyuan22/ProGen2-small").to(torch.float16).to(device).eval()
model = ProGenForCausalLM.from_pretrained("jinyuan22/ProGen2-small").to(torch.float16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained("jinyuan22/ProGen2-small")

sequence = "2GFLPFRGADM1"
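The hunk cuts the README example off after the sequence definition. Below is a minimal continuation sketch of how the snippet could proceed, based only on the ProGenForCausalLM forward pass shown later in this diff (labels are shifted inside the model, so `labels=input_ids` yields the causal LM loss); the README's actual continuation is not shown here, so treat the exact lines as illustrative.

```python
# Sketch only: assumed continuation of the README example, not the repository's own lines.
inputs = tokenizer(sequence, return_tensors="pt").to(device)
with torch.no_grad():
    # Passing labels=input_ids lets the model compute the causal LM loss internally
    # (labels are shifted inside the model, per the docstring in faesm/progen2.py).
    outputs = model(inputs.input_ids, labels=inputs.input_ids)

print(outputs.logits.shape)   # (batch, seq_len, vocab_size)
print(outputs.loss.item())    # causal LM loss over the sequence
```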
@@ -116,7 +116,6 @@ It's recommended to use the flash attention for training. Because in the forward

# Benchmarking


### FAESM vs. Official ESM2

Below is a comparison of peak memory usage and inference time between FAESM and the official ESM2. FAESM reduces memory usage by up to 60% and inference time by up to 70% (at sequence length 1000). The benchmark uses ESM2-650M with batch size 8 on a single A100 with 80 GB of memory.
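For context, the kind of measurement behind these numbers can be reproduced with PyTorch's CUDA utilities, the same primitives used in tests/benchmark.py further down this diff. The sketch below is illustrative: the checkpoint name, the sequence construction, and the assumption that FAEsmForMaskedLM accepts the same call signature as EsmForMaskedLM are assumptions, not taken from the benchmark script.

```python
# Illustrative sketch of a peak-memory / wall-clock comparison (assumptions noted above).
import time
import torch
from transformers import EsmForMaskedLM, EsmTokenizer
from faesm.esm import FAEsmForMaskedLM

def measure(model, input_ids):
    """Return (seconds, peak GB) for a single forward pass."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.time()
    with torch.no_grad():
        model(input_ids)
    torch.cuda.synchronize()
    return time.time() - start, torch.cuda.max_memory_allocated() / 1e9

name = "facebook/esm2_t33_650M_UR50D"  # assumed 650M checkpoint name
tokenizer = EsmTokenizer.from_pretrained(name)
seqs = ["MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ" * 30] * 8  # ~1000 residues, batch size 8
input_ids = tokenizer(seqs, return_tensors="pt", padding=True).input_ids.to("cuda")

for cls in (EsmForMaskedLM, FAEsmForMaskedLM):
    model = cls.from_pretrained(name).to(torch.float16).to("cuda").eval()
    seconds, gb = measure(model, input_ids)
    print(f"{cls.__name__}: {seconds:.2f} s, peak memory {gb:.2f} GB")
    del model
    torch.cuda.empty_cache()
```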
@@ -161,7 +160,6 @@ pytest tests/test_compare_esmc.py
Save up to 60% of memory and run time by using FAProgen2.
![benchmark_progen](assets/figs/FAProGen2_benchmark.png)


# TODOs

- Training script
2 changes: 1 addition & 1 deletion faesm/esm.py
@@ -589,7 +589,7 @@ def forward(
attention_mask=attention_mask,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
output_hidden_states=output_hidden_states, # For the hidden states
output_hidden_states=output_hidden_states, # For the hidden states
)
sequence_output = outputs[0]
logits = self.lm_head(sequence_output)
131 changes: 42 additions & 89 deletions faesm/progen2.py
@@ -1,4 +1,3 @@
# coding=utf-8
# Copyright 2021 The EleutherAI and HuggingFace Teams. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,18 +18,18 @@

import torch
import torch.utils.checkpoint
from einops import rearrange
from torch import nn
from torch.nn import CrossEntropyLoss
from transformers.configuration_utils import PretrainedConfig
from transformers.activations import ACT2FN
from transformers.configuration_utils import PretrainedConfig
from transformers.modeling_outputs import (
BaseModelOutputWithPast,
CausalLMOutputWithPast,
)
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import logging
from transformers.utils.model_parallel_utils import assert_device_map, get_device_map
from einops import rearrange

try:
from flash_attn import flash_attn_func # , flash_attn_qkvpacked_func
@@ -115,9 +114,7 @@ def fixed_pos_embedding(x, seq_dim=1, seq_len=None):
seq_len = x.shape[seq_dim]
inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2) / dim))
sinusoid_inp = (
torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq)
.to(x.device)
.float()
torch.einsum("i , j -> i j", torch.arange(seq_len), inv_freq).to(x.device).float()
)
return torch.sin(sinusoid_inp).half(), torch.cos(sinusoid_inp).half()

@@ -131,9 +128,7 @@ def rotate_every_two(x):

def apply_rotary_pos_emb(x, sincos, offset=0):
sin, cos = map(
lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(
2, 3
),
lambda t: t[None, offset : x.shape[1] + offset, None, :].repeat_interleave(2, 3),
sincos,
)
# einsum notation for lambda t: repeat(t[offset:x.shape[1]+offset,:], "n d -> () n () (d j)", j=2)
@@ -147,9 +142,9 @@ def __init__(self, config):
max_positions = config.max_position_embeddings
self.register_buffer(
"bias",
torch.tril(
torch.ones((max_positions, max_positions), dtype=torch.bool)
).view(1, 1, max_positions, max_positions),
torch.tril(torch.ones((max_positions, max_positions), dtype=torch.bool)).view(
1, 1, max_positions, max_positions
),
)
self.register_buffer("masked_bias", torch.tensor(-1e9))

@@ -163,9 +158,9 @@ def __init__(self, config):
raise ValueError(
f"embed_dim must be divisible by num_attention_heads (got `embed_dim`: {self.embed_dim} and `num_attention_heads`: {self.num_attention_heads})."
)
self.scale_attn = torch.sqrt(
torch.tensor(self.head_dim, dtype=torch.float16)
).to(torch.get_default_dtype())
self.scale_attn = torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float16)).to(
torch.get_default_dtype()
)
self.qkv_proj = nn.Linear(self.embed_dim, self.embed_dim * 3, bias=False)

self.out_proj = nn.Linear(self.embed_dim, self.embed_dim, bias=False)
@@ -180,9 +175,7 @@ def _split_heads(self, x, n_head, dim_head, mp_num=4):
return reshaped

def _merge_heads(self, tensor, num_attention_heads, attn_head_size):
"""
Merges attn_head_size dim and num_attn_heads dim into n_ctx
"""
"""Merges attn_head_size dim and num_attn_heads dim into n_ctx."""
if len(tensor.shape) == 5:
# tensor = tensor.permute(0, 1, 3, 2, 4).contiguous()
tensor = rearrange(tensor, "a b c d e -> a b d c e")
@@ -204,12 +197,9 @@ def _attn(
attention_mask=None,
head_mask=None,
):

# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
causal_mask = self.bias[
:, :, key_length - query_length : key_length, :key_length
]
causal_mask = self.bias[:, :, key_length - query_length : key_length, :key_length]

# Keep the attention weights computation in fp32 to avoid overflow issues
query = query.to(torch.float16)
@@ -247,7 +237,6 @@ def forward(
use_cache=False,
output_attentions=False,
):

qkv = self.qkv_proj(hidden_states)
# TODO(enijkamp): factor out number of logical TPU-v3/v4 cores or make forward pass agnostic
# mp_num = 4
@@ -256,16 +245,10 @@

local_dim = self.head_dim * self.num_attention_heads // mp_num
query, value, key = torch.split(qkv_split, local_dim, dim=-1)
query = self._split_heads(
query, self.num_attention_heads, self.head_dim, mp_num=mp_num
)
key = self._split_heads(
key, self.num_attention_heads, self.head_dim, mp_num=mp_num
)
query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)

value = self._split_heads(
value, self.num_attention_heads, self.head_dim, mp_num=mp_num
)
value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)
value = value.permute(0, 2, 1, 3)

seq_len = key.shape[1]
@@ -308,13 +291,9 @@ def forward(
present = None

# compute self-attention: V x Softmax(QK^T)
attn_output, attn_weights = self._attn(
query, key, value, attention_mask, head_mask
)
attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)

attn_output = self._merge_heads(
attn_output, self.num_attention_heads, self.head_dim
)
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)

attn_output = self.out_proj(attn_output)
attn_output = self.resid_dropout(attn_output)
@@ -331,9 +310,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)

def _split_heads(self, x, n_head, dim_head, mp_num=4):
return rearrange(
x, "b s m (l d) -> b s (m l) d", m=mp_num, l=n_head // mp_num, d=dim_head
)
return rearrange(x, "b s m (l d) -> b s (m l) d", m=mp_num, l=n_head // mp_num, d=dim_head)

def forward(
self,
@@ -358,15 +335,9 @@ def forward(
local_dim = self.head_dim * self.num_attention_heads // mp_num
query, value, key = torch.split(qkv_split, local_dim, dim=-1)

query = self._split_heads(
query, self.num_attention_heads, self.head_dim, mp_num=mp_num
)
key = self._split_heads(
key, self.num_attention_heads, self.head_dim, mp_num=mp_num
)
value = self._split_heads(
value, self.num_attention_heads, self.head_dim, mp_num=mp_num
)
query = self._split_heads(query, self.num_attention_heads, self.head_dim, mp_num=mp_num)
key = self._split_heads(key, self.num_attention_heads, self.head_dim, mp_num=mp_num)
value = self._split_heads(value, self.num_attention_heads, self.head_dim, mp_num=mp_num)

seq_len = key.shape[1]
offset = 0
@@ -413,9 +384,7 @@ def forward(
)

attn_output = rearrange(attn_output, "a b c d -> a c b d")
attn_output = self._merge_heads(
attn_output, self.num_attention_heads, self.head_dim
)
attn_output = self._merge_heads(attn_output, self.num_attention_heads, self.head_dim)
attn_output = self.resid_dropout(self.out_proj(attn_output))

outputs = (attn_output, present)
@@ -426,9 +395,7 @@ def forward(


class ProGenMLP(nn.Module):
def __init__(
self, intermediate_size, config
): # in MLP: intermediate_size= 4 * embed_dim
def __init__(self, intermediate_size, config): # in MLP: intermediate_size= 4 * embed_dim
super().__init__()
embed_dim = config.n_embd

@@ -490,10 +457,8 @@ def forward(


class ProGenPreTrainedModel(PreTrainedModel):
"""
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
models.
"""
"""An abstract class to handle weights initialization and a simple interface for downloading
and loading pretrained models."""

config_class = ProGenConfig
base_model_prefix = "transformer"
@@ -529,9 +494,7 @@ def __init__(self, config):
self.drop = nn.Dropout(config.embd_pdrop)
self.h = nn.ModuleList([ProGenBlock(config) for _ in range(config.n_layer)])
self.ln_f = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_epsilon)
self.rotary_dim = min(
config.rotary_dim, config.n_ctx // config.num_attention_heads
)
self.rotary_dim = min(config.rotary_dim, config.n_ctx // config.num_attention_heads)
self.init_weights()

# Model parallel
@@ -594,19 +557,15 @@ def forward(
return_dict=None,
):
output_attentions = (
output_attentions
if output_attentions is not None
else self.config.output_attentions
output_attentions if output_attentions is not None else self.config.output_attentions
)
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
use_cache = use_cache if use_cache is not None else self.config.use_cache
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

if input_ids is not None and inputs_embeds is not None:
raise ValueError(
@@ -687,7 +646,6 @@ def forward(
all_self_attentions = () if output_attentions else None
all_hidden_states = () if output_hidden_states else None
for i, (block, layer_past) in enumerate(zip(self.h, past_key_values)):

# Model parallel
if self.model_parallel:
torch.cuda.set_device(hidden_states.device)
@@ -705,7 +663,6 @@ def forward(
all_hidden_states = all_hidden_states + (hidden_states,)

if getattr(self.config, "gradient_checkpointing", False) and self.training:

if use_cache:
logger.warning(
"`use_cache=True` is incompatible with `config.gradient_checkpointing=True`. Setting "
@@ -742,9 +699,7 @@ def custom_forward(*inputs):
presents = presents + (outputs[1],)

if output_attentions:
all_self_attentions = all_self_attentions + (
outputs[2 if use_cache else 1],
)
all_self_attentions = all_self_attentions + (outputs[2 if use_cache else 1],)

# Model Parallel: If it's the last layer for that device, put things on the next device
if self.model_parallel:
@@ -863,15 +818,14 @@ def forward(
output_hidden_states=None,
return_dict=None,
):
r"""
labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
r"""Labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`,
`optional`):
Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
``labels = input_ids`` Indices are selected in ``[-100, 0, ..., config.vocab_size]`` All labels set to
``-100`` are ignored (masked), the loss is only computed for labels in ``[0, ..., config.vocab_size]``
"""
return_dict = (
return_dict if return_dict is not None else self.config.use_return_dict
)
return_dict = return_dict if return_dict is not None else self.config.use_return_dict

transformer_outputs = self.transformer(
input_ids,
@@ -905,9 +859,7 @@ def forward(
shift_labels = labels[..., 1:].contiguous()
# Flatten the tokens
loss_fct = CrossEntropyLoss()
loss = loss_fct(
shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1)
)
loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))

loss = loss.to(hidden_states.dtype)

Expand All @@ -927,10 +879,11 @@ def forward(
def _reorder_cache(
past: Tuple[Tuple[torch.Tensor]], beam_idx: torch.Tensor
) -> Tuple[Tuple[torch.Tensor]]:
"""
This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PretrainedModel.beam_search` or :meth:`~transformers.PretrainedModel.beam_sample` is
called. This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""This function is used to re-order the :obj:`past_key_values` cache if
:meth:`~transformers.PretrainedModel.beam_search` or
:meth:`~transformers.PretrainedModel.beam_sample` is called.
This is required to match :obj:`past_key_values` with the correct beam_idx at every generation step.
"""
return tuple(
tuple(
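The hunk boundary hides the body of _reorder_cache below its opening return tuple(tuple(. For readers following the diff, the docstring describes the standard HuggingFace-style reordering, which index-selects each cached tensor along the batch dimension with beam_idx; the sketch below is an assumption about the elided lines, not a quotation of them.

```python
# Sketch of the elided _reorder_cache body (assumed standard implementation).
return tuple(
    tuple(
        past_state.index_select(0, beam_idx.to(past_state.device))
        for past_state in layer_past
    )
    for layer_past in past
)
```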
3 changes: 3 additions & 0 deletions tests/benchmark.py
@@ -8,6 +8,7 @@
from transformers import EsmForMaskedLM, EsmTokenizer

from faesm.esm import FAEsmForMaskedLM

# from tests.utils import generate_random_esm2_inputs

# Set Seaborn theme and professional settings
@@ -27,6 +28,7 @@
}
)


def generate_random_esm2_inputs(
tokenizer, batch_size=3, min_seq_length=5, max_seq_length=10, device="cuda"
):
@@ -49,6 +51,7 @@ def generate_random_esm2_inputs(
esm_input = {k: v.to(device) for k, v in esm_input.items()}
return esm_input


def benchmark_torch_memory(f, *args, **kwargs):
torch.cuda.reset_peak_memory_stats()
torch.cuda.synchronize()
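The benchmark.py hunk ends two lines into benchmark_torch_memory. A plausible completion of the helper, following the reset-run-measure pattern those two lines set up; the actual remainder of the function is not shown in this diff, so the body below is an assumption.

```python
# Sketch: assumed shape of the full benchmark_torch_memory helper.
import torch

def benchmark_torch_memory(f, *args, **kwargs):
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    f(*args, **kwargs)                       # run the workload being measured
    torch.cuda.synchronize()
    return torch.cuda.max_memory_allocated() / (1024 ** 3)  # peak memory in GiB
```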
