
Commit

[fix bug for esmc] generate sequence_id regardless of flash attention enabled
pengzhangzhi committed Dec 6, 2024
1 parent 4460468 commit 371d014
Showing 2 changed files with 43 additions and 30 deletions.
faesm/esmc.py (2 changes: 1 addition & 1 deletion)
@@ -473,6 +473,7 @@ def forward(
         Returns:
             ESMCOutput: The output of the ESMC model.
         """
+        sequence_id = sequence_tokens == self.tokenizer.pad_token_id
         if self.use_flash_attn:
             sequence_tokens, cu_seqlens, max_seqlen, _, pad_fn = unpad(
                 sequence_tokens.unsqueeze(-1), ~sequence_id
@@ -484,7 +485,6 @@ def forward(
             pad_fn = lambda x: x
             cu_seqlens = None
             max_seqlen = None
-            sequence_id = sequence_tokens == self.tokenizer.pad_token_id

         x = self.embed(sequence_tokens)
         x, _ = self.transformer(
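The practical effect of the reordering: previously `sequence_id` was only assigned inside the non-flash-attention branch, so when `use_flash_attn` was True the `~sequence_id` argument to `unpad` referenced a name that had never been bound, raising an UnboundLocalError. A minimal, self-contained sketch of the fixed pattern (not the FAESM code itself; the pad token id of 1 and the boolean-mask stand-in for `unpad` are illustrative assumptions):

import torch

PAD_TOKEN_ID = 1  # hypothetical pad token id, for illustration only


def forward_sketch(sequence_tokens: torch.Tensor, use_flash_attn: bool):
    # Fixed ordering: build the padding mask before branching so both paths can use it.
    sequence_id = sequence_tokens == PAD_TOKEN_ID

    if use_flash_attn:
        # Stand-in for the real unpad() call: drop padded positions using ~sequence_id.
        # With the old ordering this reference failed, because sequence_id was only
        # assigned in the non-flash branch below.
        return sequence_tokens[~sequence_id]

    # Non-flash path: keep the padded batch and hand the mask to attention.
    return sequence_tokens, sequence_id


tokens = torch.tensor([[5, 7, 9, PAD_TOKEN_ID, PAD_TOKEN_ID]])
print(forward_sketch(tokens, use_flash_attn=True))   # tensor([5, 7, 9])
print(forward_sketch(tokens, use_flash_attn=False))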
tests/test_compare_esmc.py (71 changes: 42 additions & 29 deletions)
@@ -1,40 +1,53 @@
 import torch
 from esm.models.esmc import ESMC
 from esm.sdk.api import ESMProtein, LogitsConfig

 from faesm.esmc import ESMC as FAESMC
+from huggingface_hub import login
+
+
+# Function for benchmarking two models
+def benchmark_flash_vs_official(sequence, use_flash_attn):
+    # Flash Attention Implementation
+    model_flash = FAESMC.from_pretrained("esmc_300m", use_flash_attn=use_flash_attn).to("cuda")
+    input_ids_flash = model_flash.tokenizer(sequence, return_tensors="pt")["input_ids"].to("cuda")
+    output_flash = model_flash(input_ids_flash)
+    logits_flash = output_flash.sequence_logits
+    embeddings_flash = output_flash.embeddings
+
+    # Official Implementation
+    protein = ESMProtein(sequence=sequence[0])  # Single sequence for now
+    model_official = ESMC.from_pretrained("esmc_300m").to("cuda")
+    protein_tensor = model_official.encode(protein)
+    logits_output_official = model_official.logits(
+        protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
+    )
+    logits_official = logits_output_official.logits.sequence
+    embeddings_official = logits_output_official.embeddings
+
+    # Compute differences
+    logits_diff = torch.abs(logits_flash - logits_official).max()
+    embeddings_diff = torch.abs(embeddings_flash - embeddings_official).max()
+
+    return logits_diff.item(), embeddings_diff.item()
+

 # Define the sequence
 seq = "MPGWFKKAWYGLASLLSFSSFILIIVALVVPHWLSGKILCQTGVDLVNATDRELVKFIGDIYYGLFRGCKVRQCGLGGRQSQFTIFPHLVKELNAGLHVMILLLLFLALALALVSMGFAILNMIQVPYRAVSGPGGICLWNVLAGGVVALAIASFVAAVKFHDLTERIANFQEKLFQFVVVEEQYEESFWICVASASAHAANLVVVAISQIPLPEIKTKIEEATVTAEDILY"
 sequence = [seq]

-# Flash Attention Implementation
-model_flash = FAESMC.from_pretrained("esmc_300m", use_flash_attn=False).to("cuda")
-input_ids_flash = model_flash.tokenizer(sequence, return_tensors="pt")["input_ids"].to("cuda")
-output_flash = model_flash(input_ids_flash)
-logits_flash = output_flash.sequence_logits
-embeddings_flash = output_flash.embeddings
+# Login to Hugging Face Hub (use your API key with "Read" permission)
+login("hf_...")

-# Official Implementation
-from huggingface_hub import login
+# Benchmark with `use_flash_attn=True`
+logits_diff_flash, embeddings_diff_flash = benchmark_flash_vs_official(sequence, use_flash_attn=True)
+print("[Flash Attention Enabled]")
+print("Max absolute error in logits:", logits_diff_flash)
+print("Max absolute error in embeddings:", embeddings_diff_flash)
+assert logits_diff_flash < 1, f"Logits diff: {logits_diff_flash}"
+assert embeddings_diff_flash < 0.1, f"Embeddings diff: {embeddings_diff_flash}"

-# Will instruct you how to get an API key from huggingface hub, make one with "Read" permission.
-login("Your API key")
-protein = ESMProtein(sequence=seq)
-model_official = ESMC.from_pretrained("esmc_300m").to("cuda")
-protein_tensor = model_official.encode(protein)
-logits_output_official = model_official.logits(
-    protein_tensor, LogitsConfig(sequence=True, return_embeddings=True)
-)
-logits_official = logits_output_official.logits.sequence
-embeddings_official = logits_output_official.embeddings
-
-# Compute differences
-logits_diff = torch.abs(logits_flash - logits_official).max()
-embeddings_diff = torch.abs(embeddings_flash - embeddings_official).max()
-
-# Print results
-print("Max absolute error in logits:", logits_diff.item())
-print("Max absolute error in embeddings:", embeddings_diff.item())
-assert logits_diff < 1, f"Logits diff: {logits_diff}"
-assert embeddings_diff < 0.1, f"Embeddings diff: {embeddings_diff}"
+# Benchmark with `use_flash_attn=False`
+logits_diff_no_flash, embeddings_diff_no_flash = benchmark_flash_vs_official(sequence, use_flash_attn=False)
+print("\n[Flash Attention Disabled]")
+print("Max absolute error in logits:", logits_diff_no_flash)
+print("Max absolute error in embeddings:", embeddings_diff_no_flash)
+assert logits_diff_no_flash < 1, f"Logits diff: {logits_diff_no_flash}"
+assert embeddings_diff_no_flash < 0.1, f"Embeddings diff: {embeddings_diff_no_flash}"
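
For reference, the thresholds above are plain max-absolute-error checks; the same kind of comparison can also be expressed with torch.testing.assert_close, which reports both absolute and relative mismatches on failure. A small self-contained sketch of that pattern, using random stand-in tensors rather than the real model outputs:

import torch

# Stand-in tensors; in the test these would be the FAESM vs. official logits or embeddings.
reference = torch.randn(1, 128, 64)
candidate = reference + 1e-4 * torch.randn_like(reference)

# Manual check, as in the test script: largest absolute deviation under a threshold.
max_abs_err = torch.abs(candidate - reference).max()
assert max_abs_err < 0.1, f"Embeddings diff: {max_abs_err.item()}"

# Equivalent idea via torch.testing.assert_close, which also checks relative error.
torch.testing.assert_close(candidate, reference, atol=1e-2, rtol=1e-2)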
