-
import time

import pandas as pd
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import esm

DEVICE = "cuda:0"

# Load the pretrained ESM-2 650M model and its alphabet, then move it to the GPU.
model, alphabet = esm.pretrained.esm2_t33_650M_UR50D()
model = model.to(DEVICE)
model.eval()

# Build the batch converter once instead of recreating it inside every collate call.
batch_converter = alphabet.get_batch_converter()

class ProteinSequenceDataset(Dataset):
    """Wraps a DataFrame and yields (sequence, id) pairs."""

    def __init__(self, dataframe, seq_col, id_col):
        self.dataframe = dataframe
        self.seq_col = seq_col
        self.id_col = id_col

    def __len__(self):
        return len(self.dataframe)

    def __getitem__(self, idx):
        seq = self.dataframe.iloc[idx][self.seq_col]
        seq_id = self.dataframe.iloc[idx][self.id_col]
        return seq, seq_id

def protein_collate_batch(batch):
    seqs, ids = zip(*batch)
    batch_labels, batch_strs, batch_tokens = batch_converter(list(zip(ids, seqs)))
    # Token count per sequence (including BOS/EOS, excluding padding).
    lengths = (batch_tokens != alphabet.padding_idx).sum(1)
    return batch_tokens.to(DEVICE), lengths, ids

full_data = pd.read_csv('test.csv')
dataset = ProteinSequenceDataset(full_data, 'protein_seq', 'Raw_ID2')
data_loader = DataLoader(dataset, batch_size=1, shuffle=False, collate_fn=protein_collate_batch)

def process_protein_data(data_loader, fixed_dim=1024):
    sequence_representations = []
    ids_list = []
    for batch_tokens, batch_lens, ids in data_loader:
        with torch.no_grad():
            results = model(batch_tokens, repr_layers=[33], return_contacts=False)
            token_representations = results["representations"][33]
        for i, tokens_len in enumerate(batch_lens):
            # Drop the BOS and EOS tokens, keeping only the residue embeddings.
            seq_rep = token_representations[i, 1 : tokens_len - 1]
            # Pad (or truncate) every sequence to exactly `fixed_dim` residue positions.
            pad_size = fixed_dim - seq_rep.size(0)
            if pad_size > 0:
                seq_rep = F.pad(seq_rep, (0, 0, 0, pad_size), "constant", 0)
            else:
                seq_rep = seq_rep[:fixed_dim]
            sequence_representations.append(seq_rep)
            ids_list.append(ids[i])
    return ids_list, sequence_representations

time1 = time.time()
ids_list, sequence_representations = process_protein_data(data_loader)
time2 = time.time()
print('embedding time:', time2 - time1)

I want to use the above code to get protein embeddings for my downstream tasks. I tried different batch_size values in the code, and the output is as follows:
-
Am I understanding this correctly: with a batch size of 8, it embeds 8 sequences in 257 seconds? That works out to roughly 32 seconds per sequence (257 s / 8 ≈ 32 s), which should be much faster than batch_size=1?
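A quick way to check is to time one full pass per batch size and report seconds per sequence. This is only a rough sketch and assumes that dataset, protein_collate_batch, and process_protein_data from the post above are already defined:

# Sketch: compare seconds-per-sequence across batch sizes.
for bs in (1, 4, 8, 16):
    loader = DataLoader(dataset, batch_size=bs, shuffle=False,
                        collate_fn=protein_collate_batch)
    start = time.time()
    ids_list, reps = process_protein_data(loader)
    torch.cuda.synchronize()  # make sure GPU work has finished before reading the clock
    elapsed = time.time() - start
    print(f"batch_size={bs}: {elapsed:.1f} s total, {elapsed / len(dataset):.2f} s/sequence")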
Could it be that your file contains sequences of very ragged lengths, so the number of padding tokens per batch is high? Try sorting the sequences by length first, as in the sketch below.
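A minimal sketch of that idea, reusing the DataFrame and column names from the original post: sort by sequence length before building the dataset, so each batch contains similarly sized sequences and carries few padding tokens.

# Sketch: sort by sequence length so batches contain similarly sized sequences,
# which minimizes padding. Column names follow the post above.
full_data = pd.read_csv('test.csv')
full_data = (
    full_data
    .assign(seq_len=full_data['protein_seq'].str.len())
    .sort_values('seq_len')
    .reset_index(drop=True)
)
dataset = ProteinSequenceDataset(full_data, 'protein_seq', 'Raw_ID2')
data_loader = DataLoader(dataset, batch_size=8, shuffle=False,
                         collate_fn=protein_collate_batch)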