From 73cd7b19c04b78c076ba571de99a537ad31f6a1e Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Fri, 2 Aug 2024 17:34:15 -0400 Subject: [PATCH 1/7] refactor(config): update issue number for pockets #103 --- src/utils/config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/utils/config.py b/src/utils/config.py index c4e8650..5d59804 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -87,7 +87,7 @@ class LIG_FEAT_OPT(StringEnum): from pathlib import Path # Model save paths -issue_number = 131 +issue_number = 103 DATA_BASENAME = f'data/{f"v{issue_number}" if issue_number else ""}' RESULTS_PATH = os.path.abspath(f'results/v{issue_number}/') MEDIA_SAVE_DIR = f'{RESULTS_PATH}/model_media/' From a8dce152c2f3857c622229a6542a4ed038d2d2c7 Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Mon, 5 Aug 2024 21:33:18 -0400 Subject: [PATCH 2/7] fix(splitting): checks that split worked --- src/train_test/splitting.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/train_test/splitting.py b/src/train_test/splitting.py index ed7460c..6e3c779 100644 --- a/src/train_test/splitting.py +++ b/src/train_test/splitting.py @@ -343,19 +343,22 @@ def resplit(dataset:str|BaseDataset, split_files:dict|str=None, use_train_set=Fa split_files = split_files.copy() test_prots = set(pd.read_csv(split_files['test'])['prot_id']) test_idxs = [i for i in range(len(dataset.df)) if dataset.df.iloc[i]['prot_id'] in test_prots] + assert len(test_idxs) > 100, f"Error in splitting, not enough entries in test split - {split_files['test']}" dataset.save_subset(test_idxs, 'test') del split_files['test'] # Building the folds - for k, v in tqdm(split_files.items(), desc="Building folds from split files"): + for k, v in split_files.items(): prots = set(pd.read_csv(v)['prot_id']) - val_idxs = [i for i in range(len(dataset.df)) if dataset.df.iloc[i]['prot_id'] in prots] + val_idxs = [i for i in range(len(dataset.df)) if dataset.df.iloc[i]['prot_id'] in prots] + assert len(val_idxs) > 100, f"Error in splitting, not enough entries in {k} split - {v}" dataset.save_subset(val_idxs, k) if not use_train_set: # Build training set from all proteins not in the val/test set idxs = set(val_idxs + test_idxs) train_idxs = [i for i in range(len(dataset.df)) if i not in idxs] + assert len(train_idxs) > 100, f"Error in splitting, not enough entries in train split" dataset.save_subset(train_idxs, k.replace('val', 'train')) return dataset From c163778e917af781c0c1a99058ab8e4e94c5a87e Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Wed, 7 Aug 2024 10:04:20 -0400 Subject: [PATCH 3/7] fix(pocket_alignment): in place modification, offline setup, and edge index renumbering #103 - Had to make some modifications since edge index needs to be updated after applying the mask so that it still points to the right nodes and we dont get something like an "IndexError" for being out of bounds - Also error due to not removing all proteins without pocket sequences (line 216 saved the old dataset instead of the new one). 
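  For reference, the remapping boils down to a cumulative-sum lookup over the
  node mask — a minimal standalone sketch of the idea (illustrative names only,
  assuming mask is a list[bool] over all nodes and edge_index is a (2, E)
  LongTensor; the actual implementation is in src/utils/pocket_alignment.py):

      import numpy as np
      import torch

      node_map = np.cumsum(mask) - 1  # old node index -> position after masking
      kept = [(node_map[i], node_map[j])
              for i, j in edge_index.t().tolist()
              if mask[i] and mask[j]]  # drop edges touching any non-pocket node
      new_edge_index = torch.tensor(kept).t()  # (2, E_pocket), consistent with x[mask]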
- Successfully built pocket datasets for davis and kiba #131 #103 --- playground.py | 112 +++++++++++++++++++++++++--------- src/utils/pocket_alignment.py | 67 ++++++++++++-------- train_test.py | 17 +++++- 3 files changed, 142 insertions(+), 54 deletions(-) diff --git a/playground.py b/playground.py index 1182518..ae013c9 100644 --- a/playground.py +++ b/playground.py @@ -1,3 +1,37 @@ +# # %% +# import numpy as np +# import torch + +# d = torch.load("/cluster/home/t122995uhn/projects/data/v131/DavisKibaDataset/davis/nomsa_aflow_original_binary/full/data_pro.pt") +# np.array(list(d['ABL1(F317I)p'].pro_seq))[d['ABL1(F317I)p'].pocket_mask].shape + + + +# %% +# building pocket datasets: +from src.utils.pocket_alignment import pocket_dataset_full +import shutil +import os + +data_dir = '/cluster/home/t122995uhn/projects/data/' +db_type = ['kiba', 'davis'] +db_feat = ['nomsa_binary_original_binary', 'nomsa_aflow_original_binary', + 'nomsa_binary_gvp_binary', 'nomsa_aflow_gvp_binary'] + +for t in db_type: + for f in db_feat: + print(f'\n---{t}-{f}---\n') + dataset_dir= f"{data_dir}/DavisKibaDataset/{t}/{f}/full" + save_dir = f"{data_dir}/v131/DavisKibaDataset/{t}/{f}/full" + + pocket_dataset_full( + dataset_dir= dataset_dir, + pocket_dir = f"{data_dir}/{t}/", + save_dir = save_dir, + skip_download=True + ) + + #%% import pandas as pd @@ -37,45 +71,65 @@ def get_test_oncokbs(train_df=pd.read_csv('/cluster/home/t122995uhn/projects/dat #%% -######################################################################## -########################## BUILD DATASETS ############################## -######################################################################## +############################################################################## +########################## BUILD/SPLIT DATASETS ############################## +############################################################################## import os from src.data_prep.init_dataset import create_datasets from src import cfg import logging cfg.logger.setLevel(logging.DEBUG) -splits = '/cluster/home/t122995uhn/projects/MutDTA/splits/davis/' -create_datasets([cfg.DATA_OPT.PDBbind, cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba], +dbs = [cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba] +splits = ['davis', 'kiba'] +splits = ['/cluster/home/t122995uhn/projects/MutDTA/splits/' + s for s in splits] +print(splits) + +#%% +for split, db in zip(splits, dbs): + print('\n',split, db) + create_datasets(db, feat_opt=cfg.PRO_FEAT_OPT.nomsa, edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow], ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp], ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=False, k_folds=5, - test_prots_csv=f'{splits}/test.csv', - val_prots_csv=[f'{splits}/val{i}.csv' for i in range(5)],) - # data_root=os.path.abspath('../data/test/')) + test_prots_csv=f'{split}/test.csv', + val_prots_csv=[f'{split}/val{i}.csv' for i in range(5)]) -# %% Copy splits to commit them: -#from to: -import shutil -from_dir_p = '/cluster/home/t122995uhn/projects/data/v131/' -to_dir_p = '/cluster/home/t122995uhn/projects/MutDTA/splits/' -from_db = ['PDBbindDataset', 'DavisKibaDataset/kiba', 'DavisKibaDataset/davis'] -to_db = ['pdbbind', 'kiba', 'davis'] - -from_db = [f'{from_dir_p}/{f}/nomsa_binary_original_binary/' for f in from_db] -to_db = [f'{to_dir_p}/{f}' for f in to_db] - -for src, dst in zip(from_db, to_db): - for x in ['train', 'val']: - for i in range(5): - print(f"{src}/{x}{i}/XY.csv", f"{dst}/{x}{i}.csv") - shutil.copy(f"{src}/{x}{i}/XY.csv", 
f"{dst}/{x}{i}.csv") - - print(f"{src}/test/XY.csv", f"{dst}/test.csv") - shutil.copy(f"{src}/test/XY.csv", f"{dst}/test.csv") - - +#%% TEST INFERENCE +from src import cfg +from src.utils.loader import Loader + +# db2 = Loader.load_dataset(cfg.DATA_OPT.davis, +# cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow, +# path='/cluster/home/t122995uhn/projects/data/', +# subset="full") + +db2 = Loader.load_DataLoaders(cfg.DATA_OPT.davis, + cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow, + path='/cluster/home/t122995uhn/projects/data/v131', + training_fold=0, + batch_train=2) +for b2 in db2['test']: break + + +# %% +m = Loader.init_model(cfg.MODEL_OPT.DG, cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow, + dropout=0.3480, output_dim=256, + ) + +#%% +# m(b['protein'], b['ligand']) +m(b2['protein'], b2['ligand']) +#%% +model = m +loaders = db2 +device = 'cpu' +NUM_EPOCHS = 1 +LEARNING_RATE = 0.001 +from src.train_test.training import train + +logs = train(model, loaders['train'], loaders['val'], device, + epochs=NUM_EPOCHS, lr_0=LEARNING_RATE) # %% diff --git a/src/utils/pocket_alignment.py b/src/utils/pocket_alignment.py index 8ff30eb..c0bbec9 100644 --- a/src/utils/pocket_alignment.py +++ b/src/utils/pocket_alignment.py @@ -9,6 +9,7 @@ from Bio import Align from Bio.Align import substitution_matrices +import numpy as np import pandas as pd import torch @@ -78,26 +79,34 @@ def mask_graph(data, mask: list[bool]): additional attributes: -pocket_mask : list[bool] The mask specified by the mask parameter of dimension [full_seuqence_length] - -pocket_mask_x : torch.Tensor + -x : torch.Tensor The nodes of only the pocket of the protein sequence of dimension [pocket_sequence_length, num_features] - -pocket_mask_edge_index : torch.Tensor + -edge_index : torch.Tensor The edge connections in COO format only relating to the pocket nodes of the protein sequence of dimension [2, num_pocket_edges] """ + # node map for updating edge indicies after mask + node_map = np.cumsum(mask) - 1 + nodes = data.x[mask] - edges = data.edge_index + edges = [] edge_mask = [] - for i in range(edges.shape[1]): - # Throw out edges that are connected to at least one node not in the - # binding pocket - node_1, node_2 = edges[:,i][0], edges[:,i][1] - edge_mask.append(True) if mask[node_1] and mask[node_2] else edge_mask.append(False) - edges = torch.transpose(torch.transpose(edges, 0, 1)[edge_mask], 0, 1) + for i in range(data.edge_index.shape[1]): + # Throw out edges that are not part of connecting two nodes in the pocket... 
+ node_1, node_2 = data.edge_index[:,i][0], data.edge_index[:,i][1] + if mask[node_1] and mask[node_2]: + # append mapped index: + edges.append([node_map[node_1], node_map[node_2]]) + edge_mask.append(True) + else: + edge_mask.append(False) + data.x = nodes data.pocket_mask = mask - data.pocket_mask_x = nodes - data.pocket_mask_edge_index = edges + data.edge_index = torch.tensor(edges).T # reshape to (2, E) + if 'edge_weight' in data: + data.edge_weight = data.edge_weight[edge_mask] return data @@ -122,7 +131,8 @@ def _parse_json(json_path: str) -> str: def get_dataset_binding_pockets( dataset_path: str = 'data/DavisKibaDataset/kiba/nomsa_binary_original_binary/full', - pockets_path: str = 'data/DavisKibaDataset/kiba_pocket' + pockets_path: str = 'data/DavisKibaDataset/kiba_pocket', + skip_download: bool = False, ) -> tuple[dict[str, str], set[str]]: """ Get all binding pocket sequences for a dataset @@ -149,14 +159,14 @@ def get_dataset_binding_pockets( # Strip out mutations and '-(alpha, beta, gamma)' tags if they are present, # the binding pocket sequence will be the same for mutated and non-mutated genes prot_ids = [id.split('(')[0].split('-')[0] for id in prot_ids] - dl = Downloader() seq_save_dir = os.path.join(pockets_path, 'pockets') - os.makedirs(seq_save_dir, exist_ok=True) - download_check = dl.download_pocket_seq(prot_ids, seq_save_dir) + + if not skip_download: # to use cached downloads only! (useful when on compute node) + dl = Downloader() + os.makedirs(seq_save_dir, exist_ok=True) + dl.download_pocket_seq(prot_ids, seq_save_dir) + download_errors = set() - for key, val in download_check.items(): - if val == 400: - download_errors.add(key) sequences = {} for file in os.listdir(seq_save_dir): pocket_seq = _parse_json(os.path.join(seq_save_dir, file)) @@ -164,6 +174,12 @@ def get_dataset_binding_pockets( download_errors.add(file.split('.')[0]) else: sequences[file.split('.')[0]] = pocket_seq + + # adding any remainder prots not downloaded. + for p in prot_ids: + if p not in sequences: + download_errors.add(p) + return (sequences, download_errors) @@ -197,7 +213,7 @@ def create_binding_pocket_dataset( new_data = mask_graph(data, mask) new_dataset[id] = new_data os.makedirs(os.path.dirname(new_dataset_path), exist_ok=True) - torch.save(dataset, new_dataset_path) + torch.save(new_dataset, new_dataset_path) def binding_pocket_filter(dataset_csv_path: str, download_errors: set[str], csv_save_path: str): @@ -215,8 +231,8 @@ def binding_pocket_filter(dataset_csv_path: str, download_errors: set[str], csv_ csv_save_path : str The path to save the new CSV file to. """ - df = pd.read_csv(dataset_csv_path) - df = df[~df['prot_id'].isin(download_errors)] + df = pd.read_csv(dataset_csv_path, index_col=0) + df = df[~df.prot_id.str.split('(').str[0].str.split('-').str[0].isin(download_errors)] os.makedirs(os.path.dirname(csv_save_path), exist_ok=True) df.to_csv(csv_save_path) @@ -224,7 +240,8 @@ def binding_pocket_filter(dataset_csv_path: str, download_errors: set[str], csv_ def pocket_dataset_full( dataset_dir: str, pocket_dir: str, - save_dir: str + save_dir: str, + skip_download: bool = False ) -> None: """ Create all elements of a dataset that includes binding pockets. 
This @@ -240,7 +257,7 @@ def pocket_dataset_full( save_dir : str The path to where the new dataset is to be saved """ - pocket_map, download_errors = get_dataset_binding_pockets(dataset_dir, pocket_dir) + pocket_map, download_errors = get_dataset_binding_pockets(dataset_dir, pocket_dir, skip_download) print(f'Binding pocket sequences were not found for the following {len(download_errors)} protein IDs:') print(','.join(list(download_errors))) create_binding_pocket_dataset( @@ -254,7 +271,9 @@ def pocket_dataset_full( download_errors, os.path.join(save_dir, 'cleaned_XY.csv') ) - shutil.copy2(os.path.join(dataset_dir, 'data_mol.pt'), os.path.join(save_dir, 'data_mol.pt')) + if dataset_dir != save_dir: + shutil.copy2(os.path.join(dataset_dir, 'data_mol.pt'), os.path.join(save_dir, 'data_mol.pt')) + shutil.copy2(os.path.join(dataset_dir, 'XY.csv'), os.path.join(save_dir, 'XY.csv')) if __name__ == '__main__': diff --git a/train_test.py b/train_test.py index 8bc9bba..d98b730 100644 --- a/train_test.py +++ b/train_test.py @@ -2,7 +2,22 @@ from src.utils.arg_parse import parse_train_test_args args, unknown_args = parse_train_test_args(verbose=True, - jyp_args='-m DG -d PDBbind -f nomsa -e binary -bs 64') + jyp_args='--model_opt DG \ + --data_opt davis \ + \ + --feature_opt nomsa \ + --edge_opt binary \ + --ligand_feature_opt original \ + --ligand_edge_opt binary \ + \ + --learning_rate 0.00012 \ + --batch_size 128 \ + --dropout 0.24 \ + --output_dim 128 \ + \ + --train \ + --fold_selection 0 \ + --num_epochs 2000') FORCE_TRAINING = args.train DEBUG = args.debug From a70a44aa241698f32bbd097e6eac2f79ba785d37 Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Mon, 12 Aug 2024 10:42:16 -0400 Subject: [PATCH 4/7] results(davis): pocket version on DG and aflow #103 Aflow still underperforms here... 
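A quick, throwaway way to compare the two edge types from the stats CSV added
here (a sketch only — it assumes the edge option can be parsed out of the run
name the way it appears in these rows):

    import pandas as pd

    df = pd.read_csv('results/v103/model_media/model_stats.csv')
    # pull 'binary' / 'aflow' out of run names like DGM_davis0D_nomsaF_aflowE_...
    df['edge'] = df['run'].str.extract(r'_([a-z]+)E_', expand=False)
    print(df.groupby('edge')[['cindex', 'pearson', 'mse']].mean())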
--- playground.py | 134 ++++++++++++++++++- results/v103/model_media/model_stats.csv | 11 ++ results/v103/model_media/model_stats_val.csv | 11 ++ train_test.py | 4 +- 4 files changed, 153 insertions(+), 7 deletions(-) create mode 100644 results/v103/model_media/model_stats.csv create mode 100644 results/v103/model_media/model_stats_val.csv diff --git a/playground.py b/playground.py index c2d110a..cd08998 100644 --- a/playground.py +++ b/playground.py @@ -1,9 +1,133 @@ -# # %% -# import numpy as np -# import torch +#%% +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns + +new = '/cluster/home/t122995uhn/projects/splits/new/pdbbind/' + +train_df = pd.concat([pd.read_csv(f'{new}train0.csv'), + pd.read_csv(f'{new}val0.csv')], axis=0) +test_df = pd.read_csv(f'{new}test.csv') + +all_df = pd.concat([train_df, test_df], axis=0) +print(len(all_df)) + + +#%% +old = '/cluster/home/t122995uhn/projects/splits/old/pdbbind/' +old_test_df = pd.read_csv(f'{old}test.csv') +old_train_df = all_df[~all_df['code'].isin(old_test_df['code'])] + +# %% +# this will give us an estimate to how well targeted the training proteins are vs the test proteins +def proteins_targeted(train_df, test_df, split='new', min_freq=0, normalized=False): + # protein count comparison (number of diverse proteins) + plt.figure(figsize=(18,8)) + # x-axis is the normalized frequency, y-axis is the number of proteins that have that frequency (also normalized) + vc = train_df.prot_id.value_counts() + vc = vc[vc > min_freq] + train_counts = list(vc/len(test_df)) if normalized else vc.values + vc = test_df.prot_id.value_counts() + vc = vc[vc > min_freq] + test_counts = list(vc/len(test_df)) if normalized else vc.values + + sns.histplot(train_counts, + bins=50, stat='density', color='green', alpha=0.4) + sns.histplot(test_counts, + bins=50,stat='density', color='blue', alpha=0.4) + + sns.kdeplot(train_counts, color='green', alpha=0.8) + sns.kdeplot(test_counts, color='blue', alpha=0.8) + + plt.xlabel(f"{'normalized ' if normalized else ''} frequency") + plt.ylabel("normalized number of proteins with that frequency") + plt.title(f"Targeted differences for {split} split{f' (> {min_freq})' if min_freq else ''}") + if not normalized: + plt.xlim(-8,100) + +# proteins_targeted(old_train_df, old_test_df, split='oncoKB') +# plt.show() +# proteins_targeted(train_df, test_df, split='random') +# plt.show() + + +proteins_targeted(old_test_df, test_df, split='oncoKB(green) vs random(blue) test') +plt.show() +proteins_targeted(old_test_df, test_df, split='oncoKB(green) vs random(blue) test', min_freq=5) +plt.show() +proteins_targeted(old_test_df, test_df, split='oncoKB(green) vs random(blue) test', min_freq=10) +plt.show() +proteins_targeted(old_test_df, test_df, split='oncoKB(green) vs random(blue) test', min_freq=15) +plt.show() +proteins_targeted(old_test_df, test_df, split='oncoKB(green) vs random(blue) test', min_freq=20) +plt.show() +# proteins_targeted(old_train_df, train_df, split='oncoKB(green) vs random train') +# plt.show() +#%% sequence length comparison +def seq_kde(all_df, train_df, test_df, split='new'): + plt.figure(figsize=(12, 8)) + + sns.kdeplot(all_df.prot_seq.str.len().reset_index()['prot_seq'], label='All', color='blue') + sns.kdeplot(train_df.prot_seq.str.len().reset_index()['prot_seq'], label='Train', color='green') + sns.kdeplot(test_df.prot_seq.str.len().reset_index()['prot_seq'], label='Test', color='red') + + plt.xlabel('Sequence Length') + plt.ylabel('Density') + plt.title(f'Sequence Length Distribution 
({split} split)') + plt.legend() + +seq_kde(all_df,train_df,test_df, split='new') +plt.show() +seq_kde(all_df,old_train_df,old_test_df, split='old') + +# %% +from Bio import pairwise2 +from Bio.Seq import Seq +from Bio.SeqRecord import SeqRecord +from Bio.Align import substitution_matrices + +from tqdm import tqdm +import random + +def get_group_similarity(group1, group2): + # Choose a substitution matrix (e.g., BLOSUM62) + matrix = substitution_matrices.load("BLOSUM62") + + # Define gap penalties + gap_open = -10 + gap_extend = -0.5 + + # Function to calculate pairwise similarity score + def calculate_similarity(seq1, seq2): + alignments = pairwise2.align.globalds(seq1, seq2, matrix, gap_open, gap_extend) + return alignments[0][2] # Return the score of the best alignment + + # Compute pairwise similarity between all sequences in group1 and group2 + similarity_scores = [] + for seq1 in group1: + for seq2 in group2: + score = calculate_similarity(seq1, seq2) + similarity_scores.append(score) + + # Calculate the average similarity score + average_similarity = sum(similarity_scores) / len(similarity_scores) + return similarity_scores, average_similarity + + +# sample 10 sequences randomly 100x +train_seq = old_train_df.prot_seq.drop_duplicates().to_list() +test_seq = old_test_df.prot_seq.drop_duplicates().to_list() +sample_size = 5 +trials = 100 + +est_similarity = 0 +for _ in tqdm(range(trials)): + _, avg = get_group_similarity(random.sample(train_seq, sample_size), + random.sample(test_seq, sample_size)) + est_similarity += avg + +print(est_similarity/1000) -# d = torch.load("/cluster/home/t122995uhn/projects/data/v131/DavisKibaDataset/davis/nomsa_aflow_original_binary/full/data_pro.pt") -# np.array(list(d['ABL1(F317I)p'].pro_seq))[d['ABL1(F317I)p'].pocket_mask].shape diff --git a/results/v103/model_media/model_stats.csv b/results/v103/model_media/model_stats.csv new file mode 100644 index 0000000..822830d --- /dev/null +++ b/results/v103/model_media/model_stats.csv @@ -0,0 +1,11 @@ +run,cindex,pearson,spearman,mse,mae,rmse +DGM_davis0D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8174366987963239,0.6808973439070014,0.5780986864623106,0.374029119754687,0.3416232488841833,0.6115792015386781 +DGM_davis1D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8359138385559401,0.7212884148849212,0.6093121108415754,0.3444294398275105,0.3380570360012467,0.5868811121747832 +DGM_davis2D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.811306156371881,0.6771836874485692,0.5650256869521153,0.3933000326926663,0.333361968167426,0.6271363748760442 +DGM_davis3D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8148631243802541,0.717113315384429,0.571925536761479,0.3422128815756367,0.3177703711270548,0.5849896422806448 +DGM_davis4D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8196459665927316,0.694403802004145,0.5825760745508323,0.3702764201890446,0.33563001218595,0.6085034266041931 +DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7936510071795485,0.628767072325098,0.5217398281378556,0.3566859747000747,0.3591853688744937,0.597231927060229 +DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8009815205097928,0.6035635252189794,0.5304746622864567,0.4253406250688673,0.364227359902625,0.6521814356978182 
+DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7783955876418098,0.5816462981556966,0.4961723044095886,0.4376154312774337,0.3656365177210639,0.6615250798552038 +DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8181456735871336,0.6918684941945846,0.56229516172368,0.3071043302279289,0.2969707269294589,0.554169947063109 +DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8014907154138579,0.6425965261636467,0.5354462017864902,0.3606209315377456,0.3375259168007795,0.6005172200176657 diff --git a/results/v103/model_media/model_stats_val.csv b/results/v103/model_media/model_stats_val.csv new file mode 100644 index 0000000..e1d1ab8 --- /dev/null +++ b/results/v103/model_media/model_stats_val.csv @@ -0,0 +1,11 @@ +run,cindex,pearson,spearman,mse,mae,rmse +DGM_davis0D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8507053734550806,0.7688628504779598,0.6689225345680122,0.3760747658599554,0.3388000398874283,0.6132493504765867 +DGM_davis1D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8718442303414345,0.8308115505911805,0.7173863620955029,0.323120446450846,0.3234809194096876,0.5684368447337365 +DGM_davis2D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8818149976678145,0.834014760655388,0.7187113282294693,0.3071136635927556,0.2922643621762593,0.5541783680303262 +DGM_davis3D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8686054788183053,0.828507036778059,0.7018974086625753,0.3046836030153428,0.3018857493804209,0.5519815241612194 +DGM_davis4D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8510324353139875,0.7912085758636695,0.6660120299194481,0.3556152282825756,0.3246624081237138,0.5963348290034514 +DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8224153103333314,0.7079363892542606,0.598653291929885,0.390209011583234,0.3662272181520597,0.6246671206196417 +DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8493586014591634,0.7889115161443931,0.673101173187512,0.385586498014305,0.3309167098727579,0.6209561160132856 +DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8337160086816762,0.736812322417127,0.6264321347273434,0.3912511533576103,0.3474602306165372,0.6255007221079847 +DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8463969226855363,0.7444417955240732,0.6410258059445946,0.3682857548365584,0.3208404985420844,0.6068655162690977 +DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8485013700858233,0.7861129348915608,0.6464621130340457,0.36031183643520487,0.33436829519000405,0.6002598074460799 diff --git a/train_test.py b/train_test.py index d98b730..486b905 100644 --- a/train_test.py +++ b/train_test.py @@ -61,6 +61,7 @@ cp_saver = CheckpointSaver(model=None, save_path=None, train_all=False, # forces full training + min_delta=0.2, patience=100) # %% Training loop @@ -103,14 +104,13 @@ ligand_feature=ligand_feature, ligand_edge=ligand_edge, **unknown_args).to(device) cp_saver.new_model(model, save_path=model_save_p) - cp_saver.min_delta = 0.2 if DATA == cfg.DATA_OPT.PDBbind else 0.05 + cp_saver.min_delta = 0.2 if DATA == cfg.DATA_OPT.PDBbind else 0.03 if DEBUG: # run single batch through model 
debug(model, loaders['train'], device) continue # skip training - # ==== TRAINING ==== # check if model has already been trained: logs = None From 862dffc1467371da643403ded61978deeca91054 Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Fri, 16 Aug 2024 13:42:24 -0400 Subject: [PATCH 5/7] fix(esm): apply pocket mask to ESM embeddings #103 --- src/models/esm_models.py | 11 +++++++++-- src/utils/config.py | 2 +- src/utils/loader.py | 4 ++-- 3 files changed, 12 insertions(+), 5 deletions(-) diff --git a/src/models/esm_models.py b/src/models/esm_models.py index 3c9504a..7a3c94c 100644 --- a/src/models/esm_models.py +++ b/src/models/esm_models.py @@ -18,7 +18,7 @@ class EsmDTA(BaseModel): def __init__(self, esm_head:str='facebook/esm2_t6_8M_UR50D', num_features_pro=320, pro_emb_dim=54, num_features_mol=78, output_dim=128, dropout=0.2, pro_feat='esm_only', edge_weight_opt='binary', - dropout_prot=0.0, pro_extra_fc_lyr=False): + dropout_prot=0.0, pro_extra_fc_lyr=False, **kwargs): super(EsmDTA, self).__init__(pro_feat, edge_weight_opt) @@ -75,13 +75,20 @@ def forward_pro(self, data): # removing token by applying mask L_max = esm_emb.shape[1] # L_max+1 - mask = torch.arange(L_max)[None, :] < torch.tensor([len(seq) for seq in data.pro_seq])[:, None] + mask = torch.arange(L_max)[None, :] < torch.tensor([len(seq) for seq in data.pro_seq])[:, None] mask = mask.flatten(0,1) # [B*L_max+1] # flatten from [B, L_max+1, emb_dim] esm_emb = esm_emb.flatten(0,1) # to [B*L_max+1, emb_dim] esm_emb = esm_emb[mask] # [B*L, emb_dim] + # applying pocket mask if relevant + if "pocket_mask" in data: + # pocket mask must be a flattened mask from [B,L] to [B*L] + mask = [torch.tensor(d, dtype=torch.bool) for d in data.pocket_mask]# [B, L] + mask = torch.cat(mask) + esm_emb = esm_emb[mask] + if self.esm_only: target_x = esm_emb # [B*L, emb_dim] else: diff --git a/src/utils/config.py b/src/utils/config.py index 5d59804..fd90ab9 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -108,7 +108,7 @@ class LIG_FEAT_OPT(StringEnum): SLURM_ACCOUNT = None SLURM_GPU_NAME = 'v100' -if 'uhnh4h' in DOMAIN_NAME: +if ('uhnh4h' in DOMAIN_NAME) or ('h4h' in DOMAIN_NAME): CLUSTER = 'h4h' SLURM_PARTITION = 'gpu' SLURM_CONSTRAINT = 'gpu32g' diff --git a/src/utils/loader.py b/src/utils/loader.py index cbc646a..a02b082 100644 --- a/src/utils/loader.py +++ b/src/utils/loader.py @@ -291,7 +291,7 @@ def load_DataLoaders(data:str=None, pro_feature:str=None, edge_opt:str=None, pat bs = 1 if d == 'test' else batch_train loader = DataLoader(dataset=loaded_datasets[d], batch_size=bs, - shuffle=False) + shuffle=True) loaders[d] = loader return loaders @@ -327,7 +327,7 @@ def load_distributed_DataLoaders(num_replicas:int, rank:int, seed:int, data:str, sampler=sampler, batch_size=bs, # should be per gpu batch size (local batch size) num_workers=num_workers, - shuffle=False, + shuffle=True, pin_memory=True, drop_last=True) # drop last batch if not divisible by batch size loaders[d] = loader From 78333345d44f6c985957ee9ebf50502266793046 Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Mon, 19 Aug 2024 13:34:33 -0400 Subject: [PATCH 6/7] fix(loader): no support for shuffle with DDP --- src/utils/loader.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/utils/loader.py b/src/utils/loader.py index a02b082..df33ae2 100644 --- a/src/utils/loader.py +++ b/src/utils/loader.py @@ -318,7 +318,7 @@ def load_distributed_DataLoaders(num_replicas:int, rank:int, seed:int, data:str, loaders = {} for d in loaded_datasets: dataset = 
loaded_datasets[d] - sampler = DistributedSampler(dataset, shuffle=True, + sampler = DistributedSampler(dataset, shuffle=False, num_replicas=num_replicas, rank=rank, seed=seed) @@ -327,7 +327,7 @@ def load_distributed_DataLoaders(num_replicas:int, rank:int, seed:int, data:str, sampler=sampler, batch_size=bs, # should be per gpu batch size (local batch size) num_workers=num_workers, - shuffle=True, + shuffle=False, # mut exclusive with DDP pin_memory=True, drop_last=True) # drop last batch if not divisible by batch size loaders[d] = loader @@ -418,4 +418,4 @@ def wrapper(*args, **kwargs): # Return the function call output return func(*args, **kwargs) return wrapper - return decorator \ No newline at end of file + return decorator From 9b92565244dd9360a5e0b42f9e8240bc7deef8f8 Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Tue, 20 Aug 2024 09:39:03 -0400 Subject: [PATCH 7/7] results(kiba): DG and aflow pocket #103 --- results/v103/model_media/model_stats.csv | 11 +++++++ results/v103/model_media/model_stats_val.csv | 13 +++++++- train_test.py | 34 +++++++++++--------- 3 files changed, 41 insertions(+), 17 deletions(-) diff --git a/results/v103/model_media/model_stats.csv b/results/v103/model_media/model_stats.csv index 822830d..d4227f5 100644 --- a/results/v103/model_media/model_stats.csv +++ b/results/v103/model_media/model_stats.csv @@ -9,3 +9,14 @@ DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7783955876418098,0.5816462981556966,0.4961723044095886,0.4376154312774337,0.3656365177210639,0.6615250798552038 DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8181456735871336,0.6918684941945846,0.56229516172368,0.3071043302279289,0.2969707269294589,0.554169947063109 DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8014907154138579,0.6425965261636467,0.5354462017864902,0.3606209315377456,0.3375259168007795,0.6005172200176657 +DGM_kiba0D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7476189416085207,0.7148917008987766,0.6299877614860792,0.3746319657859179,0.3958356694230301,0.6120718632529336 +DGM_kiba1D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7073391610819149,0.624956249151526,0.5401876728173656,0.4451318825041403,0.458846963456725,0.667182045999546 +DGM_kiba2D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7401141841678894,0.6795148074510864,0.6127459332278625,0.4004100160666026,0.4095781581139723,0.6327795951724444 +DGM_kiba3D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7396234368040389,0.6913457932090825,0.6201197126448974,0.3934219012917641,0.4068530834848238,0.6272335301080165 +DGM_kiba4D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.752441708545282,0.7025492844189518,0.6449954833411846,0.3728163774990898,0.4045171920082104,0.6105869123221442 +DGM_kiba0D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7599929872587803,0.7067412429690916,0.6593355592769512,0.3962219319168832,0.4099100126533609,0.6294616206861887 +DGM_kiba2D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7149604681753393,0.6152047008431843,0.5597795125500629,0.4741719822054008,0.4603646989542154,0.6886014683439187 
+DGM_kiba1D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7140873472783476,0.6102548954720128,0.5558196740209606,0.4781851688315759,0.4659358458753446,0.6915093411021834 +DGM_kiba3D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7164547158247304,0.6084847523640808,0.5607065445063388,0.4802083760845744,0.4646035882672965,0.6929706891958523 +DGM_kiba4D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7687577053257117,0.7532822502738942,0.6745267167129126,0.3466135049736077,0.3887611475832294,0.5887389107011765 +EDIM_davis0D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.8114566611493259,0.7317647125777735,0.6044949818493646,0.3736163373086704,0.3493916191183746,0.611241635778086 diff --git a/results/v103/model_media/model_stats_val.csv b/results/v103/model_media/model_stats_val.csv index e1d1ab8..1aba5af 100644 --- a/results/v103/model_media/model_stats_val.csv +++ b/results/v103/model_media/model_stats_val.csv @@ -8,4 +8,15 @@ DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8493586014591634,0.7889115161443931,0.673101173187512,0.385586498014305,0.3309167098727579,0.6209561160132856 DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8337160086816762,0.736812322417127,0.6264321347273434,0.3912511533576103,0.3474602306165372,0.6255007221079847 DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8463969226855363,0.7444417955240732,0.6410258059445946,0.3682857548365584,0.3208404985420844,0.6068655162690977 -DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8485013700858233,0.7861129348915608,0.6464621130340457,0.36031183643520487,0.33436829519000405,0.6002598074460799 +DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8485013700858233,0.7861129348915608,0.6464621130340457,0.3603118364352048,0.334368295190004,0.6002598074460799 +DGM_kiba0D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.76838489785244,0.7349294201300529,0.6720189966892385,0.2870000493840723,0.36449329645463846,0.5357238555301344 +DGM_kiba1D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6830895888292509,0.6256216928279031,0.4862834063605662,0.4635991200460103,0.4637030257199048,0.6808811350346038 +DGM_kiba2D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7424018620084226,0.7064897658791767,0.6113374073010096,0.3869704646200874,0.4214332353778002,0.6220695014386153 +DGM_kiba3D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7887989201757725,0.7493602745702617,0.7146442736012206,0.2624049771770561,0.3196535898269914,0.5122547971244936 +DGM_kiba4D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.805961202163743,0.797223082421482,0.7422315375449509,0.2469041691088263,0.3146771252445765,0.4968945251346872 +DGM_kiba0D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7414150463245769,0.7180513866946735,0.6154427990455673,0.2477156183512984,0.352769788125809,0.4977103759731139 
+DGM_kiba2D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7017734270681899,0.6190248117265895,0.5234448165080476,0.4732107505692587,0.4709162900019082,0.6879031549348054 +DGM_kiba1D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.715296095350676,0.6760788124615275,0.5590996884326331,0.4113607007891719,0.446394580254481,0.6413740724329071 +DGM_kiba3D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.6897289677863744,0.552288313673736,0.4932783355544295,0.4163241510325705,0.4502416662950819,0.6452318583521511 +DGM_kiba4D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7785830290333009,0.772583639636063,0.6834004931220337,0.2542728701347933,0.3402214839171678,0.504254767091788 +EDIM_davis0D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.8460899419942509,0.7821818481200006,0.6773536752793916,0.3864269594875424,0.3440388107583636,0.6216324955209006 diff --git a/train_test.py b/train_test.py index 486b905..0b53a5b 100644 --- a/train_test.py +++ b/train_test.py @@ -2,22 +2,24 @@ from src.utils.arg_parse import parse_train_test_args args, unknown_args = parse_train_test_args(verbose=True, - jyp_args='--model_opt DG \ - --data_opt davis \ - \ - --feature_opt nomsa \ - --edge_opt binary \ - --ligand_feature_opt original \ - --ligand_edge_opt binary \ - \ - --learning_rate 0.00012 \ - --batch_size 128 \ - --dropout 0.24 \ - --output_dim 128 \ - \ - --train \ - --fold_selection 0 \ - --num_epochs 2000') + jyp_args='--model_opt EDI \ + --data_opt davis \ + --fold_selection 0 \ + \ + --feature_opt nomsa \ + --edge_opt binary \ + --ligand_feature_opt original \ + --ligand_edge_opt binary \ + \ + --learning_rate 0.0001 \ + --batch_size 12 \ + \ + --dropout 0.4 \ + --dropout_prot 0.0 \ + --output_dim 128 \ + --pro_emb_dim 512 \ + --pro_extra_fc_lyr False\ + --debug') FORCE_TRAINING = args.train DEBUG = args.debug