fix(esm_models): check if pro_seq is just string before running esm and building mask.
jyaacoub committed Aug 12, 2024
1 parent 40942ab commit b53a0ff
Showing 1 changed file with 12 additions and 4 deletions.
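
The guard added in both forward_pro methods below matters because a bare Python string is iterable character by character: fed straight into the length comprehension, it yields one length per character, so the mask ends up with one row per character instead of one row per sequence. A minimal sketch of the failure mode and the fix (a standalone toy example, not code from src/models/esm_models.py; build_mask and the example sequence are hypothetical):

import torch

def build_mask(pro_seqs, L_max):
    # Same normalization as the commit: a single sequence arrives as a plain
    # string, so wrap it in a list to keep batch semantics.
    if type(pro_seqs) is str:
        pro_seqs = [pro_seqs]
    lengths = torch.tensor([len(seq) for seq in pro_seqs])
    # One row per sequence, True where the position holds a real residue.
    return torch.arange(L_max)[None, :] < lengths[:, None]  # [B, L_max]

seq = "MKTAYIAKQR"                 # hypothetical single 10-residue protein
print(build_mask(seq, 12).shape)   # torch.Size([1, 12])
# Without the guard, [len(s) for s in seq] gives ten 1s (one per character),
# producing a [10, 12] mask and breaking the downstream flatten/indexing.

The same wrapping also keeps the tokenizer call uniform: Hugging Face tokenizers accept either a single string or a list of strings, so passing a one-element list leaves a single code path for both single-sequence and batched inputs.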
16 changes: 12 additions & 4 deletions in src/models/esm_models.py
@@ -60,9 +60,13 @@ def __init__(self, esm_head:str='facebook/esm2_t6_8M_UR50D',
         self.out = nn.Linear(512, 1) # 1 output (binding affinity)
 
     def forward_pro(self, data):
+        pro_seqs = data.pro_seq
+        if type(pro_seqs) is str:
+            pro_seqs = [pro_seqs]
+
         #### ESM emb ####
         # cls and sep tokens are added to the sequence by the tokenizer
-        seq_tok = self.esm_tok(data.pro_seq,
+        seq_tok = self.esm_tok(pro_seqs,
                                return_tensors='pt',
                                padding=True) # [B, L_max+2]
         seq_tok['input_ids'] = seq_tok['input_ids'].to(data.x.device)
@@ -75,7 +79,7 @@ def forward_pro(self, data):
 
         # removing <sep> token by applying mask
         L_max = esm_emb.shape[1] # L_max+1
-        mask = torch.arange(L_max)[None, :] < torch.tensor([len(seq) for seq in data.pro_seq])[:, None]
+        mask = torch.arange(L_max)[None, :] < torch.tensor([len(seq) for seq in pro_seqs])[:, None]
         mask = mask.flatten(0,1) # [B*L_max+1]
 
         # flatten from [B, L_max+1, emb_dim]
@@ -217,9 +221,13 @@ def __init__(self, esm_head: str = 'westlake-repl/SaProt_35M_AF2',
 
     # overwrite the forward_pro pass to account for new saprot model
     def forward_pro(self, data):
+        pro_seqs = data.pro_seq
+        if type(pro_seqs) is str:
+            pro_seqs = [pro_seqs]
+
         #### ESM emb ####
         # cls and sep tokens are added to the sequence by the tokenizer
-        seq_tok = self.esm_tok(data.pro_seq,
+        seq_tok = self.esm_tok(pro_seqs,
                                return_tensors='pt',
                                padding=True) # [B, L_max+2]
         seq_tok['input_ids'] = seq_tok['input_ids'].to(data.x.device)
@@ -237,7 +245,7 @@ def forward_pro(self, data):
         # for saProt token 2 == <eos>
         L_max = esm_emb.shape[1] # L_max+1
         mask = torch.arange(L_max)[None, :] < torch.tensor([len(seq)/2 #NOTE: this is the main difference from normal ESM since the input sequence includes SA tokens
-                                                            for seq in data.pro_seq])[:, None]
+                                                            for seq in pro_seqs])[:, None]
         mask = mask.flatten(0,1) # [B*L_max+1]
 
         # flatten from [B, L_max+1, emb_dim]
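
The len(seq)/2 in the SaProt variant (see the NOTE in the last hunk) reflects SaProt's structure-aware vocabulary: each residue is written as an amino-acid character followed by a Foldseek 3Di structure character, so the raw string is twice the residue count. A short sketch with toy values (an assumption-laden illustration, not repository code; the sequence shown is made up):

import torch

sa_seq = "MdKvTp"              # 3 residues, each as AA letter + lowercase 3Di letter
n_res = len(sa_seq) // 2       # 3 -- halving recovers the residue count
L_max = 5
mask = torch.arange(L_max)[None, :] < torch.tensor([n_res])[:, None]
print(mask)                    # tensor([[ True,  True,  True, False, False]])

For the plain ESM model the string already has one character per residue, which is why the first forward_pro uses len(seq) without the division.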

0 comments on commit b53a0ff
