Put flash attention 2 into ProGen2 #6

Merged
merged 14 commits on Dec 6, 2024
update progen2
JinyuanSun committed Dec 5, 2024
commit 0ddc6a50181d74b7f6ef52a62331e0398af070b2
438 changes: 0 additions & 438 deletions 1pga_faesmfold.pdb

This file was deleted.

11 changes: 6 additions & 5 deletions README.md
@@ -73,18 +73,19 @@ print("Repr shape:", outputs['last_hidden_state'].shape) # (batch_size, sequenc
# Step 5: star the repo if the code works for you!
```

For generative protein language models like ProGen2:

### ProGen2

```python
import torch
from faesm.progen2 import ProGenForCausalLM
from transformers import AutoTokenizer
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Available models on HF: ["jinyuan22/ProGen2-small", "jinyuan22/ProGen2-base", "jinyuan22/ProGen2-xlarge"]
model = ProGenForCausalLM.from_pretrained("jinyuan22/ProGen2-small").to(torch.float16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained("jinyuan22/ProGen2-small")

# sequence = "1" + "ACDEFGHIKLMNPQRSTVWY" * 50 + "2" # 1002 tokens

sequence = "2GFLPFRGADEGLAAREAATLAARGTAARAYREDSWAVPVPRGLLGDLTARVAALGAASPPPADPLAVTLDLHHVTAEVALTTVLDAATLVHGQTRVLSAEDAAEAATAAAAATEAYLERLQDFVLFMSASVRVWRRGNAAGATGPEWDQWYTVADRDALGSAPTHLAVLGRQADALCHFVLDRVAWGTCGTPLWSGDEDLGNVVATFAGYADRLATAPRDLIM1"

inputs = tokenizer(sequence, return_tensors="pt").to(device)
@@ -102,7 +103,7 @@ target = target - first_token

ce_eval = torch.nn.functional.cross_entropy(input=logits.view(-1, logits.size(-1)), target=target.view(-1), reduction="mean").item()
print(ce_eval)
assert abs(ce_eval - 2.4) < 0.1  # 2.4 is the reference CE for the official ProGen2-small
```
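
Since ProGen2 is a causal language model, the same forward pass can also drive sampling. Below is a minimal generation sketch, assuming the forward call returns Hugging Face-style `.logits` (as the scoring example above implies) and ignoring any vocabulary offset that the `target = target - first_token` line may entail; the prompt, temperature, and length are illustrative, not from this repo.

```python
import torch
from transformers import AutoTokenizer
from faesm.progen2 import ProGenForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ProGenForCausalLM.from_pretrained("jinyuan22/ProGen2-small").to(torch.float16).to(device).eval()
tokenizer = AutoTokenizer.from_pretrained("jinyuan22/ProGen2-small")

# "1" marks the N-terminus, as in the scoring example above; "MAK" is a toy prompt.
ids = tokenizer("1MAK", return_tensors="pt").input_ids.to(device)
with torch.no_grad():
    for _ in range(50):
        logits = model(ids).logits[:, -1, :] / 0.8  # temperature 0.8 (illustrative)
        probs = torch.softmax(logits.float(), dim=-1)
        next_id = torch.multinomial(probs, num_samples=1)
        ids = torch.cat([ids, next_id], dim=-1)
print(tokenizer.decode(ids[0]))
```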

### Training \[WIP\]
@@ -112,7 +113,7 @@ It's recommended to use flash attention for training, because in the forward
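
As a rough illustration of the training path, here is a minimal sketch of one causal-LM training step with the class shown above. The optimizer, learning rate, toy batch, and the use of autocast to keep the attention kernels in half precision are all assumptions, not code from this repo.

```python
import torch
from transformers import AutoTokenizer
from faesm.progen2 import ProGenForCausalLM

device = "cuda" if torch.cuda.is_available() else "cpu"
model = ProGenForCausalLM.from_pretrained("jinyuan22/ProGen2-small").to(device).train()
tokenizer = AutoTokenizer.from_pretrained("jinyuan22/ProGen2-small")
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Toy batch: "1"/"2" are the terminus tokens used in the README's sequences.
batch = tokenizer(["1ACDEFGHIKLMNPQRSTVWY2"], return_tensors="pt").to(device)

# Autocast keeps activations in bf16, which flash attention kernels require on CUDA.
with torch.autocast(device_type=device, dtype=torch.bfloat16):
    logits = model(batch.input_ids).logits
    # Standard causal-LM objective (each position predicts the next token);
    # the exact target/vocab offset used in the scoring example above may differ.
    loss = torch.nn.functional.cross_entropy(
        logits[:, :-1].reshape(-1, logits.size(-1)),
        batch.input_ids[:, 1:].reshape(-1),
    )
loss.backward()
optimizer.step()
optimizer.zero_grad()
print(f"loss: {loss.item():.3f}")
```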

# Benchmarking

Below is a comparison of the peak memory usage and inference time of FAESM and the official ESM2 implementation: FAESM cuts memory usage by up to 60% and inference time by up to 70% (at sequence length 1000). The benchmark was run on ESM2-650M with batch size 8 on a single A100 with 80GB of memory.

![benchmark](assets/figs/benchmark.png)
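
For reference, a measurement like the one above can be sketched as follows. The `FAEsmForMaskedLM` import follows FAESM's ESM2 usage earlier in the README; the random-token batch and timing loop are illustrative assumptions, not the official benchmark script.

```python
import time
import torch
from faesm.esm import FAEsmForMaskedLM  # assumed import, per the ESM2 example above

device = "cuda"
model = FAEsmForMaskedLM.from_pretrained("facebook/esm2_t33_650M_UR50D").to(device).eval().half()

# Batch of 8 sequences of length 1000; random ids stand in for real proteins.
input_ids = torch.randint(4, 24, (8, 1000), device=device)

torch.cuda.reset_peak_memory_stats(device)
torch.cuda.synchronize(device)
start = time.perf_counter()
with torch.no_grad():
    model(input_ids)
torch.cuda.synchronize(device)

print(f"time: {time.perf_counter() - start:.3f} s, "
      f"peak memory: {torch.cuda.max_memory_allocated(device) / 1024**3:.2f} GB")
```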

File renamed without changes
31 changes: 0 additions & 31 deletions data.ipynb

This file was deleted.

Binary file removed esmfold_benchmark.png
Binary file not shown.