diff --git a/src/models/esm_models.py b/src/models/esm_models.py
index 3c9504a..3f62b05 100644
--- a/src/models/esm_models.py
+++ b/src/models/esm_models.py
@@ -60,9 +60,13 @@ def __init__(self, esm_head:str='facebook/esm2_t6_8M_UR50D',
         self.out = nn.Linear(512, 1) # 1 output (binding affinity)
 
     def forward_pro(self, data):
+        pro_seqs = data.pro_seq
+        if type(pro_seqs) is str:
+            pro_seqs = [pro_seqs]
+
         #### ESM emb ####
         # cls and sep tokens are added to the sequence by the tokenizer
-        seq_tok = self.esm_tok(data.pro_seq,
+        seq_tok = self.esm_tok(pro_seqs,
                                return_tensors='pt',
                                padding=True) # [B, L_max+2]
         seq_tok['input_ids'] = seq_tok['input_ids'].to(data.x.device)
@@ -75,7 +79,7 @@ def forward_pro(self, data):
 
         # removing token by applying mask
         L_max = esm_emb.shape[1] # L_max+1
-        mask = torch.arange(L_max)[None, :] < torch.tensor([len(seq) for seq in data.pro_seq])[:, None]
+        mask = torch.arange(L_max)[None, :] < torch.tensor([len(seq) for seq in pro_seqs])[:, None]
         mask = mask.flatten(0,1) # [B*L_max+1]
 
         # flatten from [B, L_max+1, emb_dim]
@@ -217,9 +221,13 @@ def __init__(self, esm_head: str = 'westlake-repl/SaProt_35M_AF2',
 
     # overwrite the forward_pro pass to account for new saprot model
     def forward_pro(self, data):
+        pro_seqs = data.pro_seq
+        if type(pro_seqs) is str:
+            pro_seqs = [pro_seqs]
+
         #### ESM emb ####
         # cls and sep tokens are added to the sequence by the tokenizer
-        seq_tok = self.esm_tok(data.pro_seq,
+        seq_tok = self.esm_tok(pro_seqs,
                                return_tensors='pt',
                                padding=True) # [B, L_max+2]
         seq_tok['input_ids'] = seq_tok['input_ids'].to(data.x.device)
@@ -237,7 +245,7 @@ def forward_pro(self, data):
         # for saProt token 2 ==
         L_max = esm_emb.shape[1] # L_max+1
         mask = torch.arange(L_max)[None, :] < torch.tensor([len(seq)/2 #NOTE: this is the main difference from normal ESM since the input sequence includes SA tokens
-                                                            for seq in data.pro_seq])[:, None]
+                                                            for seq in pro_seqs])[:, None]
         mask = mask.flatten(0,1) # [B*L_max+1]
 
         # flatten from [B, L_max+1, emb_dim]
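
The patch above normalises data.pro_seq so that a single protein passed as a plain str is wrapped in a one-element list before tokenisation and length masking. Without that wrapping, the comprehension [len(seq) for seq in data.pro_seq] would iterate over individual characters and produce a mask with the wrong batch dimension. Below is a minimal, self-contained sketch of just the length-mask step (the ESM/SaProt tokenizer and model calls are omitted); the helper name build_residue_mask and the sa_tokens flag are illustrative and not part of the patch.

import torch

def build_residue_mask(pro_seqs, L_max, sa_tokens=False):
    # Accept a single sequence (str) or a batch (list of str), mirroring
    # the pro_seqs normalisation added in forward_pro above.
    if isinstance(pro_seqs, str):
        pro_seqs = [pro_seqs]

    # SaProt-style inputs pair each residue with a structure-aware token,
    # so the residue count is half the string length (the len(seq)/2 case).
    lengths = torch.tensor([len(seq) // 2 if sa_tokens else len(seq)
                            for seq in pro_seqs])

    # True at real residue positions, False at trailing special/padding positions.
    mask = torch.arange(L_max)[None, :] < lengths[:, None]  # [B, L_max]
    return mask.flatten(0, 1)                               # [B*L_max]

# If a bare string were iterated directly, each character would contribute
# its own length and the batch dimension of the mask would be wrong.
print(build_residue_mask("MKTAYIAK", L_max=10).shape)           # torch.Size([10])
print(build_residue_mask(["MKTAYIAK", "MKT"], L_max=10).shape)  # torch.Size([20])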