update examples
sarahalamdari committed Sep 12, 2023
1 parent 1188a16 commit 002447c
Showing 7 changed files with 3,235 additions and 692 deletions.
2 changes: 1 addition & 1 deletion README.md
@@ -149,7 +149,7 @@ Our LRAR baseline models are also available:
* `lr_ar_640M`


- An example of generating one sequence randomly sampled from the train distribution length can be found in
+ An example of unconditionally generating a sequence of a specified length can be found in
[this notebook](https://github.com/microsoft/evodiff/tree/main/examples/evodiff.ipynb).
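
A minimal sketch of that unconditional-generation call, assuming the `OA_DM_640M` checkpoint loader from `evodiff.pretrained` and the `generate_oaardm` helper from `evodiff.generate` (the sequence length below is illustrative):

```python
from evodiff.pretrained import OA_DM_640M
from evodiff.generate import generate_oaardm

model, collater, tokenizer, scheme = OA_DM_640M()
model.eval().cuda()

# Generate one sequence of length 100 by order-agnostic autoregressive decoding
sample, untokenized = generate_oaardm(model, tokenizer, seq_len=100, batch_size=1, device='cuda')
print(untokenized)
```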

To evaluate the generated sequences, we implement our self-consistency Omegafold ESM-IF pipeline, as shown in
25 changes: 24 additions & 1 deletion evodiff/conditional_generation.py
@@ -394,7 +394,7 @@ def generate_scaffold(model, PDB_ID, motif_start_idxs, motif_end_idxs, scaffold_
            p = torch.nn.functional.softmax(p, dim=1)  # softmax over categorical probs
            p_sample = torch.multinomial(p, num_samples=1)
            sample[:, i] = p_sample.squeeze()
-    print("new sequence", [tokenizer.untokenize(s) for s in sample])
+    print("Generated sequence:", [tokenizer.untokenize(s) for s in sample])
    untokenized = [tokenizer.untokenize(s) for s in sample]

    return untokenized, new_start_idxs, new_end_idxs
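
# A hypothetical call sketch for generate_scaffold; `scaffold_length` and
# `data_top_dir` are assumed names completing the truncated signature above,
# `model` is assumed to be a loaded conditional model, and the PDB ID and
# motif indices are purely illustrative.
untokenized, new_start_idxs, new_end_idxs = generate_scaffold(
    model, '1prw', motif_start_idxs=[15], motif_end_idxs=[34],
    scaffold_length=75, data_top_dir='data/', tokenizer=Tokenizer())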
@@ -484,6 +484,29 @@ def scramble_input(sequences, start_idxs, end_idxs):
    sequences = [[s] for s in sequences]
    return scrambled_seqs, sequences, scrambled_idrs, original_idrs, start_idxs, end_idxs

def inpaint_simple(model, sequence, start_idx, end_idx, tokenizer=Tokenizer(), device='cuda'):
    """Simplified inpainting used in the examples: masks [start_idx, end_idx) and regenerates it."""
    all_aas = tokenizer.all_aas
    idr_length = end_idx - start_idx
    masked_sequence = sequence[0:start_idx] + '#' * idr_length + sequence[end_idx:]
    tokenized_sequence = torch.tensor(tokenizer.tokenizeMSA(masked_sequence))

    loc = np.arange(start_idx, end_idx)
    sample = tokenized_sequence.to(torch.long)
    sample = sample.to(device)
    np.random.shuffle(loc)  # fill in masked positions in random order
    with torch.no_grad():
        for i in tqdm(loc):
            timestep = torch.tensor([0])  # placeholder; not used by the model
            timestep = timestep.to(device)
            prediction = model(sample.unsqueeze(0), timestep)
            p = prediction[:, i, :len(all_aas)-6]  # sample only over canonical amino acids (drop special tokens)
            p = torch.nn.functional.softmax(p, dim=1)
            p_sample = torch.multinomial(p, num_samples=1)
            sample[i] = p_sample.squeeze()
    untokenized_seq = tokenizer.untokenize(sample)
    untokenized_idr = tokenizer.untokenize(sample[start_idx:end_idx])
    return sample, untokenized_seq, untokenized_idr
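
# A hypothetical usage sketch of inpaint_simple; loading via
# evodiff.pretrained.OA_DM_640M is an assumption, and the sequence and
# region indices below are purely illustrative.
from evodiff.pretrained import OA_DM_640M

model, collater, tokenizer, scheme = OA_DM_640M()
model.eval().cuda()
seq = "MKTVRQERLKSIVRILERSKEPVSGAQLAEELSVSRQVIVQDIAYLRSLGYNIVATPRGYVLA"
sample, full_seq, new_idr = inpaint_simple(model, seq, start_idx=20, end_idx=34,
                                           tokenizer=tokenizer, device='cuda')
print(new_idr)  # the regenerated region only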

def inpaint(model, tokenized_sequences, start_idxs, end_idxs, sequences, tokenizer=Tokenizer(), device='cuda', random_baseline=False, data_top_dir='/'):
    if random_baseline:
72 changes: 72 additions & 0 deletions evodiff/data.py
@@ -14,6 +14,78 @@
import os
from torch.utils.data import Subset

def subsample_msa(path_to_msa, n_sequences=64, max_seq_len=512, selection_type='random'):
    alphabet = PROTEIN_ALPHABET
    tokenizer = Tokenizer(alphabet)
    alpha = np.array(list(alphabet))
    gap_idx = tokenizer.alphabet.index(GAP)

    if not os.path.exists(path_to_msa):
        raise FileNotFoundError(f"Path to MSA does not exist: {path_to_msa}")
    path = path_to_msa
    parsed_msa = parse_fasta(path)

    # Keep match columns (uppercase residues and '-'); drop insertions ('.' and lowercase)
    aligned_msa = [[char for char in seq if (char.isupper() or char == '-') and not char == '.'] for seq in
                   parsed_msa]
    aligned_msa = [''.join(seq) for seq in aligned_msa]

    tokenized_msa = [tokenizer.tokenizeMSA(seq) for seq in aligned_msa]
    tokenized_msa = np.array([l.tolist() for l in tokenized_msa])
    msa_seq_len = len(tokenized_msa[0])

    if msa_seq_len > max_seq_len:
        slice_start = np.random.choice(msa_seq_len - max_seq_len + 1)
        seq_len = max_seq_len
    else:
        slice_start = 0
        seq_len = msa_seq_len

    # Slice columns down to max_seq_len
    sliced_msa_seq = tokenized_msa[:, slice_start: slice_start + max_seq_len]
    anchor_seq = sliced_msa_seq[0]  # the query sequence is the first row of the MSA

    # Drop rows that are all gaps
    sliced_msa = [seq for seq in sliced_msa_seq if (list(set(seq)) != [gap_idx])]
    msa_num_seqs = len(sliced_msa)

    if msa_num_seqs < n_sequences:
        raise Exception("msa num_seqs < n_sequences, indicates dataset not filtered properly")
    elif msa_num_seqs > n_sequences:
        if selection_type == 'random':
            # Keep the query, then sample the remaining rows uniformly without replacement
            random_idx = np.random.choice(msa_num_seqs - 1, size=n_sequences - 1, replace=False) + 1
            anchor_seq = np.expand_dims(anchor_seq, axis=0)
            output = np.concatenate((anchor_seq, np.array(sliced_msa)[random_idx.astype(int)]), axis=0)
        elif selection_type == "MaxHamming":
            # Greedy max-min selection: repeatedly add the candidate sequence whose
            # minimum Hamming distance to the already-selected set is largest
            output = [list(anchor_seq)]
            msa_subset = sliced_msa[1:]
            msa_ind = np.arange(msa_num_seqs)[1:]
            random_ind = np.random.choice(msa_ind)
            random_seq = sliced_msa[random_ind]
            output.append(list(random_seq))
            random_seq = np.expand_dims(random_seq, axis=0)
            msa_subset = np.delete(msa_subset, (random_ind - 1), axis=0)
            m = len(msa_ind) - 1
            distance_matrix = np.ones((n_sequences - 2, m))

            for i in range(n_sequences - 2):
                curr_dist = cdist(random_seq, msa_subset, metric='hamming')
                curr_dist = np.expand_dims(np.array(curr_dist), axis=0)  # shape is now (1, msa_num_seqs)
                distance_matrix[i] = curr_dist
                col_min = np.min(distance_matrix, axis=0)  # min distance from each candidate to the selected set
                max_ind = np.argmax(col_min)  # candidate farthest from everything selected so far
                random_ind = max_ind
                random_seq = msa_subset[random_ind]
                output.append(list(random_seq))
                random_seq = np.expand_dims(random_seq, axis=0)
                msa_subset = np.delete(msa_subset, random_ind, axis=0)
                distance_matrix = np.delete(distance_matrix, random_ind, axis=1)
    else:
        output = sliced_msa

    output = [''.join(seq) for seq in alpha[output]]
    return output, output[0]  # the subsampled MSA and its query sequence
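
# A hypothetical usage sketch of subsample_msa; the alignment path is
# illustrative and only needs to be readable by parse_fasta.
msa, query = subsample_msa('data/example_msa.a3m', n_sequences=64,
                           max_seq_len=512, selection_type='MaxHamming')
print(len(msa), 'sequences; query starts with', query[:40])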

def read_openfold_files(data_dir, filename):
"""
6 changes: 3 additions & 3 deletions evodiff/generate.py
@@ -191,7 +191,7 @@ def generate_oaardm(model, tokenizer, seq_len, penalty=None, batch_size=3, devic
    loc = np.arange(seq_len)
    np.random.shuffle(loc)
    with torch.no_grad():
-        for i in loc:
+        for i in tqdm(loc):
            timestep = torch.tensor([0] * batch_size)  # placeholder; not used by the model
            timestep = timestep.to(device)
            prediction = model(sample, timestep) #, input_mask=input_mask.unsqueeze(-1)) # sample prediction given input
@@ -217,7 +217,7 @@ def generate_oaardm(model, tokenizer, seq_len, penalty=None, batch_size=3, devic
def generate_autoreg(model, tokenizer, samples=100, batch_size=1, max_seq_len=1024):
    # Generates 1 seq at a time, no batching, to make it easier to deal w/ variable seq lengths
    # Generates until max length or until stop token is predicted
-    model.eval().cuda()
+    #model.eval().cuda()
    device = model.device()

    start = tokenizer.start_id
@@ -260,7 +260,7 @@ def generate_d3pm(model, tokenizer, Q, Q_bar, timesteps, seq_len, batch_size=3,
"""
Generate a random start string from uniform dist and convert to predictions
"""
model.eval().cuda()
#model.eval()
#device = model.device()

sample = torch.randint(0, tokenizer.K, (batch_size, seq_len))
60 changes: 46 additions & 14 deletions evodiff/generate-msa.py → evodiff/generate_msa.py
@@ -1,32 +1,18 @@
import argparse
import json
import evodiff
import os
import numpy as np
import torch
import pandas as pd
from sequence_models.esm import MSATransformer
from sequence_models.constants import MSA_ALPHABET, MSA_PAD, MASK
from evodiff.utils import Tokenizer
from sequence_models.utils import parse_fasta
from evodiff.model import MSATransformerTime
from evodiff.data import read_idr_files
from tqdm import tqdm
import pathlib
import glob
import string

from evodiff.data import A3MMSADataset, IDRDataset
from torch.utils.data import Subset
from sequence_models.samplers import SortishSampler, ApproxBatchSampler
from torch.utils.data import DataLoader
import torch
from sequence_models.collaters import MSAAbsorbingCollater
from evodiff.collaters import D3PMCollaterMSA
from sequence_models.constants import MSA_ALPHABET
from evodiff.utils import Tokenizer
from scipy.spatial.distance import hamming, cdist

home = str(pathlib.Path.home())

def main():
@@ -235,6 +221,52 @@ def generate_msa(model, tokenizer, batch_size, n_sequences, seq_length, penalty_
    untokenized = [[tokenizer.untokenize(msa.flatten())] for msa in sample]
    return sample, untokenized  # return output and untokenized output

def generate_query_oadm_msa_simple(path_to_msa, model, tokenizer, n_sequences, seq_length, batch_size=1, penalty_value=2, device='gpu',
                                   start_msa=True, selection_type='MaxHamming'):
    mask_id = tokenizer.mask_id
    src = torch.full((batch_size, n_sequences, seq_length), fill_value=mask_id)

    valid_msas = []
    query_sequences = []
    for i in range(batch_size):
        valid_msa, query_sequence = evodiff.data.subsample_msa(path_to_msa, n_sequences=n_sequences,
                                                               max_seq_len=seq_length, selection_type=selection_type)
        valid_msa = torch.tensor(np.array([tokenizer.tokenizeMSA(seq) for seq in valid_msa]))
        valid_msas.append(valid_msa)
        query_sequences.append(query_sequence)

    for i in range(batch_size):
        seq_len = len(query_sequences[i])
        # Rows 1..n_sequences hold the fixed MSA context; row 0 (the query) stays masked
        src[i, 1:n_sequences, :seq_len] = valid_msas[i][1:n_sequences, :seq_len].squeeze()
        padding = torch.full((n_sequences, seq_length-seq_len), fill_value=tokenizer.pad_id)
        src[i, :, seq_len:] = padding
        x_indices = np.arange(0, 1)  # only decode row 0 (the query)
        y_indices = np.arange(seq_len)
    src = src.to(device)
    sample = src.clone()
    if start_msa:
        all_ind = np.transpose([np.tile(x_indices, len(y_indices)), np.repeat(y_indices, len(x_indices))])
        np.random.shuffle(all_ind)

        # NOTE: currently only supports batch_size=1
        with torch.no_grad():
            for i in tqdm(all_ind):
                random_x, random_y = i
                preds = model(sample)  # output shape of preds is (BS=1, N=64, L, n_tokens=31)
                p = preds[:, random_x, random_y, :]
                if random_x == 0:  # for the first row (the query), don't allow predicting gaps
                    p = preds[:, random_x, random_y, :tokenizer.K-1]
                p_softmax = torch.nn.functional.softmax(p, dim=1)
                # Penalize gaps
                penalty = torch.ones(p.shape).to(p.device)
                penalty[:, -1] += penalty_value
                p_softmax /= penalty
                p_sample = torch.multinomial(input=p_softmax, num_samples=1)
                p_sample = p_sample.squeeze()
                sample[:, random_x, random_y] = p_sample
    untokenized = [[tokenizer.untokenize(msa[0])] for msa in sample]  # untokenize the query row only
    return sample, untokenized  # return sampled MSAs and untokenized query sequences

def generate_msa_d3pm(model, batch_size, n_sequences, seq_length, Q_bar=None, Q=None, tokenizer=Tokenizer(),
                      start_query=False, data_top_dir='../data', selection_type='MaxHamming', out_path='../ref/',