Put flash attention 2 into ProGen2 #6

Merged
14 commits merged on Dec 6, 2024
2 changes: 2 additions & 0 deletions .gitignore
@@ -162,3 +162,5 @@ cython_debug/
#.idea/
# vim
*.sw?

tests/progen2/
46 changes: 45 additions & 1 deletion README.md
@@ -53,9 +53,10 @@ pip install faesm

## ESM2

FAESM is a drop-in replacement for the official ESM implementation. You can use the same code as you would use the official ESM implementation. For example:import torch
FAESM is a drop-in replacement for the official ESM implementation; you can use it with the same code you would use for official ESM. For example:

```python
import torch
from faesm.esm import FAEsmForMaskedLM

# Step 1: Load the FAESM model
@@ -73,6 +74,47 @@ print("Repr shape:", outputs['last_hidden_state'].shape) # (batch_size, sequenc
# Step 5: star the repo if the code works for u!
```




## ProGen2

For generative protein language models like ProGen2:

```python
import torch
from faesm.progen2 import ProGenForCausalLM
from transformers import AutoTokenizer
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Available models on HF: ["jinyuan22/ProGen2-small", "jinyuan22/ProGen2-base", "jinyuan22/ProGen2-xlarge"]
model = ProGenForCausalLM.from_pretrained("jinyuan22/ProGen2-small").to(torch.float16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained("jinyuan22/ProGen2-small")

sequence = "2GFLPFRGADEGLAAREAATLAARGTAARAYREDSWAVPVPRGLLGDLTARVAALGAASPPPADPLAVTLDLHHVTAEVALTTVLDAATLVHGQTRVLSAEDAAEAATAAAAATEAYLERLQDFVLFMSASVRVWRRGNAAGATGPEWDQWYTVADRDALGSAPTHLAVLGRQADALCHFVLDRVAWGTCGTPLWSGDEDLGNVVATFAGYADRLATAPRDLIM1"

inputs = tokenizer(sequence, return_tensors="pt").to(device)
target = inputs.input_ids[0,...]
with torch.no_grad():
    logits = model(inputs.input_ids, labels=inputs.input_ids).logits[0,...]

# shift so that each position's logits predict the next token
logits = logits[:-1, ...]
target = target[1:]

bos_token, eos_token = 3, 4
if target[-1] in [bos_token, eos_token]:
    logits = logits[:-1, ...]
    target = target[:-1]

# remove unused logits
first_token, last_token = 5, 29
logits = logits[:, first_token:(last_token+1)]
target = target - first_token

ce_eval = torch.nn.functional.cross_entropy(input=logits.view(-1, logits.size(-1)), target=target.view(-1), reduction="mean").item()
print(ce_eval)
assert abs(ce_eval - 2.4) < 0.1 # 2.4 is the reference ce for the official progen2-small
```

## ESM-C

Right after EvolutionaryScale released [ESM-C](https://www.evolutionaryscale.ai/blog/esm-cambrian), we followed up with a flash attention version of ESM-C in FAESM. You can run ESM-C easily with the following code:
@@ -85,6 +127,7 @@ input_ids = model.tokenizer(sequence, return_tensors="pt")["input_ids"].to("cuda
output = model(input_ids)
print(output.sequence_logits.shape)
print(output.embeddings.shape)

```

### Training \[WIP\]
@@ -94,6 +137,7 @@ It's recommended to use flash attention for training. Because in the forward
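
Since this part is still WIP, here is only a minimal sketch of what masked-language-model fine-tuning with FAESM might look like. It assumes `FAEsmForMaskedLM` exposes a `tokenizer` attribute (as the ESM-C wrapper above does) and returns a dict with a `"logits"` key as in the inference example; the checkpoint name, toy sequence, masking rate, and optimizer settings are placeholder assumptions, not part of this PR.

```python
import torch
from faesm.esm import FAEsmForMaskedLM

device = "cuda" if torch.cuda.is_available() else "cpu"
model = FAEsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D").to(device).train()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Toy sequence for illustration only.
sequence = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLAGG"
input_ids = model.tokenizer(sequence, return_tensors="pt")["input_ids"].to(device)

# BERT-style masking: hide ~15% of the tokens and train the model to recover them.
labels = input_ids.clone()
mask = torch.rand_like(input_ids, dtype=torch.float) < 0.15
input_ids = input_ids.masked_fill(mask, model.tokenizer.mask_token_id)
labels[~mask] = -100  # only compute the loss on masked positions

# Flash attention kernels expect fp16/bf16, so run the forward under autocast.
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits = model(input_ids)["logits"]
    loss = torch.nn.functional.cross_entropy(
        logits.view(-1, logits.size(-1)), labels.view(-1), ignore_index=-100
    )

loss.backward()
optimizer.step()
print("loss:", loss.item())
```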

# Benchmarking


### FAESM vs. Official ESM2

Below is a comparison of peak memory usage and inference time between FAESM and the official ESM2 implementation. FAESM reduces peak memory usage by up to 60% and inference time by up to 70% (at sequence length 1000). The benchmark uses ESM-650M with batch size 8 on a single A100 with 80 GB of memory.
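
As a rough idea of how such numbers can be reproduced, here is a minimal sketch that times FAESM against the `transformers` ESM2 implementation and records peak GPU memory. The use of `transformers.EsmForMaskedLM` as the baseline, the random toy batch, and the checkpoint loading are assumptions of this sketch, not the script used for the plot.

```python
import time
import torch
from faesm.esm import FAEsmForMaskedLM
from transformers import EsmForMaskedLM


def benchmark(model, input_ids, n_runs=10):
    """Average seconds per forward pass and peak memory (GiB) for one model."""
    torch.cuda.reset_peak_memory_stats()
    torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        for _ in range(n_runs):
            model(input_ids)
    torch.cuda.synchronize()
    seconds = (time.perf_counter() - start) / n_runs
    peak_gib = torch.cuda.max_memory_allocated() / 1024**3
    return seconds, peak_gib


name = "facebook/esm2_t33_650M_UR50D"  # ESM-650M
fa_model = FAEsmForMaskedLM.from_pretrained(name).to("cuda", torch.float16).eval()
hf_model = EsmForMaskedLM.from_pretrained(name).to("cuda", torch.float16).eval()

# Batch of 8 random sequences of length 1000 (token ids in the amino-acid range).
input_ids = torch.randint(4, 24, (8, 1000), device="cuda")

# Both models stay resident, so treat the memory numbers as indicative only.
print("FAESM   (s/iter, GiB):", benchmark(fa_model, input_ids))
print("HF ESM2 (s/iter, GiB):", benchmark(hf_model, input_ids))
```
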
Binary file added assets/figs/FAProGen2_benchmark.png
13 changes: 9 additions & 4 deletions faesm/esm.py
@@ -589,14 +589,19 @@ def forward(
            attention_mask=attention_mask,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_hidden_states=output_hidden_states,  # For the hidden states
        )
        sequence_output = outputs[0]
        logits = self.lm_head(sequence_output)

        result = {
            "logits": logits,
            "last_hidden_state": sequence_output,
        }
        if outputs.hidden_states is not None:
            result = {
                "logits": logits,
                "last_hidden_state": sequence_output,
                "hidden_states": [x.unsqueeze(0) for x in outputs.hidden_states],
            }
        else:
            result = {"logits": logits, "last_hidden_state": sequence_output}
        return result

    @classmethod
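
For reference, the new `hidden_states` path in `faesm/esm.py` could be exercised roughly like this. The `output_hidden_states` keyword and the returned dict keys follow the hunk above; the checkpoint name, toy sequence, and the `model.tokenizer` call are assumptions borrowed from the README examples.

```python
import torch
from faesm.esm import FAEsmForMaskedLM

device = "cuda" if torch.cuda.is_available() else "cpu"
model = FAEsmForMaskedLM.from_pretrained("facebook/esm2_t6_8M_UR50D").to(device).eval()

input_ids = model.tokenizer("MKTAYIAKQRQISFVKSHFSRQLEERLGLIEVQ", return_tensors="pt")["input_ids"].to(device)
with torch.no_grad():
    out = model(input_ids, output_hidden_states=True)

print(out["logits"].shape)             # per-token logits
print(out["last_hidden_state"].shape)  # final-layer representations
print(len(out["hidden_states"]))       # one entry per returned layer, each unsqueezed along dim 0
```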