Commit
Merge branch 'pocket-training-v103' into v103
jyaacoub authored Aug 7, 2024
2 parents 9f97d76 + c163778 commit 1aa39ef
Showing 4 changed files with 146 additions and 59 deletions.
116 changes: 83 additions & 33 deletions playground.py
@@ -1,3 +1,36 @@
# # %%
# import numpy as np
# import torch

# d = torch.load("/cluster/home/t122995uhn/projects/data/v131/DavisKibaDataset/davis/nomsa_aflow_original_binary/full/data_pro.pt")
# np.array(list(d['ABL1(F317I)p'].pro_seq))[d['ABL1(F317I)p'].pocket_mask].shape



# %%
# building pocket datasets:
from src.utils.pocket_alignment import pocket_dataset_full
import shutil
import os

data_dir = '/cluster/home/t122995uhn/projects/data/'
db_type = ['kiba', 'davis']
db_feat = ['nomsa_binary_original_binary', 'nomsa_aflow_original_binary',
'nomsa_binary_gvp_binary', 'nomsa_aflow_gvp_binary']

for t in db_type:
    for f in db_feat:
        print(f'\n---{t}-{f}---\n')
        dataset_dir = f"{data_dir}/DavisKibaDataset/{t}/{f}/full"
        save_dir = f"{data_dir}/v131/DavisKibaDataset/{t}/{f}/full"

        pocket_dataset_full(
            dataset_dir=dataset_dir,
            pocket_dir=f"{data_dir}/{t}/",
            save_dir=save_dir,
            skip_download=True
        )

#%%
import pandas as pd

@@ -32,50 +65,67 @@ def get_test_oncokbs(train_df=pd.read_csv('/cluster/home/t122995uhn/projects/dat

get_test_oncokbs(train_df=train_df)





#%%
########################################################################
########################## BUILD DATASETS ##############################
########################################################################
##############################################################################
########################## BUILD/SPLIT DATASETS ##############################
##############################################################################
import os
from src.data_prep.init_dataset import create_datasets
from src import cfg
import logging
cfg.logger.setLevel(logging.DEBUG)

splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/davis/'
create_datasets([cfg.DATA_OPT.PDBbind, cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba],
dbs = [cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba]
splits = ['davis', 'kiba']
splits = ['/cluster/home/t122995uhn/projects/MutDTA/splits/' + s for s in splits]
print(splits)

#%%
for split, db in zip(splits, dbs):
    print('\n', split, db)
    create_datasets(db,
                feat_opt=cfg.PRO_FEAT_OPT.nomsa,
                edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow],
                ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
                ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=False,
                k_folds=5,
                test_prots_csv=f'{splits}/test.csv',
                val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)],)
                test_prots_csv=f'{split}/test.csv',
                val_prots_csv=[f'{split}/val{i}.csv' for i in range(5)])
                # data_root=os.path.abspath('../data/test/'))

# %% Copy splits to commit them:
#from to:
import shutil
from_dir_p = '/cluster/home/t122995uhn/projects/data/v131/'
to_dir_p = '/cluster/home/t122995uhn/projects/MutDTA/splits/'
from_db = ['PDBbindDataset', 'DavisKibaDataset/kiba', 'DavisKibaDataset/davis']
to_db = ['pdbbind', 'kiba', 'davis']

from_db = [f'{from_dir_p}/{f}/nomsa_binary_original_binary/' for f in from_db]
to_db = [f'{to_dir_p}/{f}' for f in to_db]

for src, dst in zip(from_db, to_db):
    for x in ['train', 'val']:
        for i in range(5):
            print(f"{src}/{x}{i}/XY.csv", f"{dst}/{x}{i}.csv")
            shutil.copy(f"{src}/{x}{i}/XY.csv", f"{dst}/{x}{i}.csv")

    print(f"{src}/test/XY.csv", f"{dst}/test.csv")
    shutil.copy(f"{src}/test/XY.csv", f"{dst}/test.csv")

#%% TEST INFERENCE
from src import cfg
from src.utils.loader import Loader

# db2 = Loader.load_dataset(cfg.DATA_OPT.davis,
# cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
# path='/cluster/home/t122995uhn/projects/data/',
# subset="full")

db2 = Loader.load_DataLoaders(cfg.DATA_OPT.davis,
                              cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
                              path='/cluster/home/t122995uhn/projects/data/v131',
                              training_fold=0,
                              batch_train=2)
for b2 in db2['test']: break  # grab a single test batch for a quick forward pass


# %%
m = Loader.init_model(cfg.MODEL_OPT.DG, cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
                      dropout=0.3480, output_dim=256,
                      )

#%%
# m(b['protein'], b['ligand'])
m(b2['protein'], b2['ligand'])
#%%
model = m
loaders = db2
device = 'cpu'
NUM_EPOCHS = 1
LEARNING_RATE = 0.001
from src.train_test.training import train

logs = train(model, loaders['train'], loaders['val'], device,
             epochs=NUM_EPOCHS, lr_0=LEARNING_RATE)
# %%
5 changes: 4 additions & 1 deletion src/train_test/splitting.py
@@ -343,19 +343,22 @@ def resplit(dataset:str|BaseDataset, split_files:dict|str=None, use_train_set=Fa
    split_files = split_files.copy()
    test_prots = set(pd.read_csv(split_files['test'])['prot_id'])
    test_idxs = [i for i in range(len(dataset.df)) if dataset.df.iloc[i]['prot_id'] in test_prots]
    assert len(test_idxs) > 100, f"Error in splitting, not enough entries in test split - {split_files['test']}"
    dataset.save_subset(test_idxs, 'test')
    del split_files['test']

    # Building the folds
    for k, v in tqdm(split_files.items(), desc="Building folds from split files"):
        prots = set(pd.read_csv(v)['prot_id'])
        val_idxs = [i for i in range(len(dataset.df)) if dataset.df.iloc[i]['prot_id'] in prots]
        val_idxs = [i for i in range(len(dataset.df)) if dataset.df.iloc[i]['prot_id'] in prots]
        assert len(val_idxs) > 100, f"Error in splitting, not enough entries in {k} split - {v}"
        dataset.save_subset(val_idxs, k)

        if not use_train_set:
            # Build training set from all proteins not in the val/test set
            idxs = set(val_idxs + test_idxs)
            train_idxs = [i for i in range(len(dataset.df)) if i not in idxs]
            assert len(train_idxs) > 100, f"Error in splitting, not enough entries in train split"
            dataset.save_subset(train_idxs, k.replace('val', 'train'))

    return dataset
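For context, a minimal sketch of how this resplit helper appears to be driven, inferred from the asserts and save_subset calls above; the dataset object, paths, and fold count here are placeholders, not part of the commit:

# Hypothetical usage sketch: each CSV must have a 'prot_id' column, and the 'val{i}'
# keys let resplit derive the matching 'train{i}' subsets when use_train_set is False.
split_files = {'test': 'splits/davis/test.csv',
               **{f'val{i}': f'splits/davis/val{i}.csv' for i in range(5)}}
dataset = resplit(dataset, split_files=split_files)  # writes 'test', 'val{i}', and derived 'train{i}' subsets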
67 changes: 43 additions & 24 deletions src/utils/pocket_alignment.py
@@ -9,6 +9,7 @@

from Bio import Align
from Bio.Align import substitution_matrices
import numpy as np
import pandas as pd
import torch

@@ -78,26 +79,34 @@ def mask_graph(data, mask: list[bool]):
    additional attributes:
    -pocket_mask : list[bool]
        The mask specified by the mask parameter of dimension [full_sequence_length]
    -pocket_mask_x : torch.Tensor
    -x : torch.Tensor
        The nodes of only the pocket of the protein sequence of dimension
        [pocket_sequence_length, num_features]
    -pocket_mask_edge_index : torch.Tensor
    -edge_index : torch.Tensor
        The edge connections in COO format only relating to
        the pocket nodes of the protein sequence of dimension [2, num_pocket_edges]
    """
    # node map for updating edge indices after mask
    node_map = np.cumsum(mask) - 1

    nodes = data.x[mask]
    edges = data.edge_index
    edges = []
    edge_mask = []
    for i in range(edges.shape[1]):
        # Throw out edges that are connected to at least one node not in the
        # binding pocket
        node_1, node_2 = edges[:,i][0], edges[:,i][1]
        edge_mask.append(True) if mask[node_1] and mask[node_2] else edge_mask.append(False)
    edges = torch.transpose(torch.transpose(edges, 0, 1)[edge_mask], 0, 1)
    for i in range(data.edge_index.shape[1]):
        # Keep only edges whose two endpoints are both in the pocket
        node_1, node_2 = data.edge_index[:,i][0], data.edge_index[:,i][1]
        if mask[node_1] and mask[node_2]:
            # append mapped index:
            edges.append([node_map[node_1], node_map[node_2]])
            edge_mask.append(True)
        else:
            edge_mask.append(False)

    data.x = nodes
    data.pocket_mask = mask
    data.pocket_mask_x = nodes
    data.pocket_mask_edge_index = edges
    data.edge_index = torch.tensor(edges).T  # reshape to (2, E)
    if 'edge_weight' in data:
        data.edge_weight = data.edge_weight[edge_mask]
    return data
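The np.cumsum remap above is the heart of this change; a tiny worked illustration (mask values invented for demonstration) of how a surviving edge gets re-indexed into the compacted pocket graph:

# Illustration only, not part of the commit: np.cumsum(mask) - 1 maps old node indices to new ones.
import numpy as np
mask = np.array([True, False, True, True, False, True])
node_map = np.cumsum(mask) - 1          # -> [0, 0, 1, 2, 2, 3]
# An original edge (2, 5) has both endpoints in the pocket, so it is kept
# and re-indexed to (node_map[2], node_map[5]) == (1, 3) in the masked graph.
print(node_map[2], node_map[5])         # 1 3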


@@ -122,7 +131,8 @@ def _parse_json(json_path: str) -> str:

def get_dataset_binding_pockets(
    dataset_path: str = 'data/DavisKibaDataset/kiba/nomsa_binary_original_binary/full',
    pockets_path: str = 'data/DavisKibaDataset/kiba_pocket'
    pockets_path: str = 'data/DavisKibaDataset/kiba_pocket',
    skip_download: bool = False,
) -> tuple[dict[str, str], set[str]]:
"""
Get all binding pocket sequences for a dataset
@@ -149,21 +159,27 @@
    # Strip out mutations and '-(alpha, beta, gamma)' tags if they are present;
    # the binding pocket sequence will be the same for mutated and non-mutated genes
    prot_ids = [id.split('(')[0].split('-')[0] for id in prot_ids]
    dl = Downloader()
    seq_save_dir = os.path.join(pockets_path, 'pockets')
    os.makedirs(seq_save_dir, exist_ok=True)
    download_check = dl.download_pocket_seq(prot_ids, seq_save_dir)

    if not skip_download:  # skip_download=True uses cached downloads only (useful on a compute node)
        dl = Downloader()
        os.makedirs(seq_save_dir, exist_ok=True)
        dl.download_pocket_seq(prot_ids, seq_save_dir)

    download_errors = set()
    for key, val in download_check.items():
        if val == 400:
            download_errors.add(key)
    sequences = {}
    for file in os.listdir(seq_save_dir):
        pocket_seq = _parse_json(os.path.join(seq_save_dir, file))
        if pocket_seq == 0 or len(pocket_seq) == 0:
            download_errors.add(file.split('.')[0])
        else:
            sequences[file.split('.')[0]] = pocket_seq

    # add any remaining prots that were not downloaded
    for p in prot_ids:
        if p not in sequences:
            download_errors.add(p)

    return (sequences, download_errors)


@@ -197,7 +213,7 @@ def create_binding_pocket_dataset(
        new_data = mask_graph(data, mask)
        new_dataset[id] = new_data
    os.makedirs(os.path.dirname(new_dataset_path), exist_ok=True)
    torch.save(dataset, new_dataset_path)
    torch.save(new_dataset, new_dataset_path)


def binding_pocket_filter(dataset_csv_path: str, download_errors: set[str], csv_save_path: str):
@@ -215,16 +231,17 @@ def binding_pocket_filter(dataset_csv_path: str, download_errors: set[str], csv_
    csv_save_path : str
        The path to save the new CSV file to.
    """
    df = pd.read_csv(dataset_csv_path)
    df = df[~df['prot_id'].isin(download_errors)]
    df = pd.read_csv(dataset_csv_path, index_col=0)
    df = df[~df.prot_id.str.split('(').str[0].str.split('-').str[0].isin(download_errors)]
    os.makedirs(os.path.dirname(csv_save_path), exist_ok=True)
    df.to_csv(csv_save_path)
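The new filter line matches download_errors against the base gene name rather than the raw prot_id; a quick illustration of the chained split it mirrors (the second ID is hypothetical, the first appears in playground.py above):

# Illustration only: the same normalization applied to single IDs.
print('ABL1(F317I)p'.split('(')[0].split('-')[0])  # -> ABL1  (mutation tag stripped)
print('ABL1-beta'.split('(')[0].split('-')[0])     # -> ABL1  (hypothetical '-beta' tag stripped)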


def pocket_dataset_full(
    dataset_dir: str,
    pocket_dir: str,
    save_dir: str
    save_dir: str,
    skip_download: bool = False
) -> None:
"""
Create all elements of a dataset that includes binding pockets. This
@@ -240,7 +257,7 @@
    save_dir : str
        The path to where the new dataset is to be saved
    """
    pocket_map, download_errors = get_dataset_binding_pockets(dataset_dir, pocket_dir)
    pocket_map, download_errors = get_dataset_binding_pockets(dataset_dir, pocket_dir, skip_download)
    print(f'Binding pocket sequences were not found for the following {len(download_errors)} protein IDs:')
    print(','.join(list(download_errors)))
    create_binding_pocket_dataset(
@@ -254,7 +271,9 @@
        download_errors,
        os.path.join(save_dir, 'cleaned_XY.csv')
    )
    shutil.copy2(os.path.join(dataset_dir, 'data_mol.pt'), os.path.join(save_dir, 'data_mol.pt'))
    if dataset_dir != save_dir:
        shutil.copy2(os.path.join(dataset_dir, 'data_mol.pt'), os.path.join(save_dir, 'data_mol.pt'))
        shutil.copy2(os.path.join(dataset_dir, 'XY.csv'), os.path.join(save_dir, 'XY.csv'))


if __name__ == '__main__':
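A minimal sanity-check sketch for an output dataset produced by pocket_dataset_full, assuming the data_pro.pt layout and the 'ABL1(F317I)p' key seen at the top of playground.py; anything beyond those names is inferred from the code above, not verified against the repo:

# Illustration only: check that nodes and edges were masked and re-indexed consistently.
import os
import torch
save_dir = '/cluster/home/t122995uhn/projects/data/v131/DavisKibaDataset/davis/nomsa_aflow_original_binary/full'
d = torch.load(os.path.join(save_dir, 'data_pro.pt'))
g = d['ABL1(F317I)p']
assert g.x.shape[0] == sum(g.pocket_mask)       # node features reduced to pocket residues
assert int(g.edge_index.max()) < g.x.shape[0]   # edge indices remapped into the masked graph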
17 changes: 16 additions & 1 deletion train_test.py
@@ -2,7 +2,22 @@
from src.utils.arg_parse import parse_train_test_args

args, unknown_args = parse_train_test_args(verbose=True,
                        jyp_args='-m DG -d PDBbind -f nomsa -e binary -bs 64')
                        jyp_args='--model_opt DG \
                                  --data_opt davis \
                                  \
                                  --feature_opt nomsa \
                                  --edge_opt binary \
                                  --ligand_feature_opt original \
                                  --ligand_edge_opt binary \
                                  \
                                  --learning_rate 0.00012 \
                                  --batch_size 128 \
                                  --dropout 0.24 \
                                  --output_dim 128 \
                                  \
                                  --train \
                                  --fold_selection 0 \
                                  --num_epochs 2000')
FORCE_TRAINING = args.train
DEBUG = args.debug

Expand Down
