diff --git a/playground.py b/playground.py index 8669224..f0a76ff 100644 --- a/playground.py +++ b/playground.py @@ -1,34 +1,254 @@ +#%% +import pandas as pd +import matplotlib.pyplot as plt +import seaborn as sns + +new = '/cluster/home/t122995uhn/projects/splits/new/pdbbind/' + +train_df = pd.concat([pd.read_csv(f'{new}train0.csv'), + pd.read_csv(f'{new}val0.csv')], axis=0) +test_df = pd.read_csv(f'{new}test.csv') + +all_df = pd.concat([train_df, test_df], axis=0) +print(len(all_df)) + + +#%% +old = '/cluster/home/t122995uhn/projects/splits/old/pdbbind/' +old_test_df = pd.read_csv(f'{old}test.csv') +old_train_df = all_df[~all_df['code'].isin(old_test_df['code'])] + +# %% +# this will give us an estimate to how well targeted the training proteins are vs the test proteins +def proteins_targeted(train_df, test_df, split='new', min_freq=0, normalized=False): + # protein count comparison (number of diverse proteins) + plt.figure(figsize=(18,8)) + # x-axis is the normalized frequency, y-axis is the number of proteins that have that frequency (also normalized) + vc = train_df.prot_id.value_counts() + vc = vc[vc > min_freq] + train_counts = list(vc/len(test_df)) if normalized else vc.values + vc = test_df.prot_id.value_counts() + vc = vc[vc > min_freq] + test_counts = list(vc/len(test_df)) if normalized else vc.values + + sns.histplot(train_counts, + bins=50, stat='density', color='green', alpha=0.4) + sns.histplot(test_counts, + bins=50,stat='density', color='blue', alpha=0.4) + + sns.kdeplot(train_counts, color='green', alpha=0.8) + sns.kdeplot(test_counts, color='blue', alpha=0.8) + + plt.xlabel(f"{'normalized ' if normalized else ''} frequency") + plt.ylabel("normalized number of proteins with that frequency") + plt.title(f"Targeted differences for {split} split{f' (> {min_freq})' if min_freq else ''}") + if not normalized: + plt.xlim(-8,100) + +# proteins_targeted(old_train_df, old_test_df, split='oncoKB') +# plt.show() +# proteins_targeted(train_df, test_df, split='random') +# plt.show() + + +proteins_targeted(old_test_df, test_df, split='oncoKB(green) vs random(blue) test') +plt.show() +proteins_targeted(old_test_df, test_df, split='oncoKB(green) vs random(blue) test', min_freq=5) +plt.show() +proteins_targeted(old_test_df, test_df, split='oncoKB(green) vs random(blue) test', min_freq=10) +plt.show() +proteins_targeted(old_test_df, test_df, split='oncoKB(green) vs random(blue) test', min_freq=15) +plt.show() +proteins_targeted(old_test_df, test_df, split='oncoKB(green) vs random(blue) test', min_freq=20) +plt.show() +# proteins_targeted(old_train_df, train_df, split='oncoKB(green) vs random train') +# plt.show() +#%% sequence length comparison +def seq_kde(all_df, train_df, test_df, split='new'): + plt.figure(figsize=(12, 8)) + + sns.kdeplot(all_df.prot_seq.str.len().reset_index()['prot_seq'], label='All', color='blue') + sns.kdeplot(train_df.prot_seq.str.len().reset_index()['prot_seq'], label='Train', color='green') + sns.kdeplot(test_df.prot_seq.str.len().reset_index()['prot_seq'], label='Test', color='red') + + plt.xlabel('Sequence Length') + plt.ylabel('Density') + plt.title(f'Sequence Length Distribution ({split} split)') + plt.legend() + +seq_kde(all_df,train_df,test_df, split='new') +plt.show() +seq_kde(all_df,old_train_df,old_test_df, split='old') + +# %% +from Bio import pairwise2 +from Bio.Seq import Seq +from Bio.SeqRecord import SeqRecord +from Bio.Align import substitution_matrices + +from tqdm import tqdm +import random + +def get_group_similarity(group1, group2): + # Choose 
a substitution matrix (e.g., BLOSUM62) + matrix = substitution_matrices.load("BLOSUM62") + + # Define gap penalties + gap_open = -10 + gap_extend = -0.5 + + # Function to calculate pairwise similarity score + def calculate_similarity(seq1, seq2): + alignments = pairwise2.align.globalds(seq1, seq2, matrix, gap_open, gap_extend) + return alignments[0][2] # Return the score of the best alignment + + # Compute pairwise similarity between all sequences in group1 and group2 + similarity_scores = [] + for seq1 in group1: + for seq2 in group2: + score = calculate_similarity(seq1, seq2) + similarity_scores.append(score) + + # Calculate the average similarity score + average_similarity = sum(similarity_scores) / len(similarity_scores) + return similarity_scores, average_similarity + + +# sample 10 sequences randomly 100x +train_seq = old_train_df.prot_seq.drop_duplicates().to_list() +test_seq = old_test_df.prot_seq.drop_duplicates().to_list() +sample_size = 5 +trials = 100 + +est_similarity = 0 +for _ in tqdm(range(trials)): + _, avg = get_group_similarity(random.sample(train_seq, sample_size), + random.sample(test_seq, sample_size)) + est_similarity += avg + +print(est_similarity/1000) + + + + # %% -######################################################################## -########################## VIOLIN PLOTTING ############################# -######################################################################## +# building pocket datasets: +from src.utils.pocket_alignment import pocket_dataset_full +import shutil +import os + +data_dir = '/cluster/home/t122995uhn/projects/data/' +db_type = ['kiba', 'davis'] +db_feat = ['nomsa_binary_original_binary', 'nomsa_aflow_original_binary', + 'nomsa_binary_gvp_binary', 'nomsa_aflow_gvp_binary'] + +for t in db_type: + for f in db_feat: + print(f'\n---{t}-{f}---\n') + dataset_dir= f"{data_dir}/DavisKibaDataset/{t}/{f}/full" + save_dir = f"{data_dir}/v131/DavisKibaDataset/{t}/{f}/full" + + pocket_dataset_full( + dataset_dir= dataset_dir, + pocket_dir = f"{data_dir}/{t}/", + save_dir = save_dir, + skip_download=True + ) + +#%% +import pandas as pd + +def get_test_oncokbs(train_df=pd.read_csv('/cluster/home/t122995uhn/projects/data/test/PDBbindDataset/nomsa_binary_original_binary/full/cleaned_XY.csv'), + oncokb_fp='/cluster/home/t122995uhn/projects/data/tcga/mart_export.tsv', + biomart='/cluster/home/t122995uhn/projects/downloads/oncoKB_DrugGenePairList.csv'): + #Get gene names for PDBbind + dfbm = pd.read_csv(oncokb_fp, sep='\t') + dfbm['PDB ID'] = dfbm['PDB ID'].str.lower() + train_df.reset_index(names='idx',inplace=True) + + df_uni = train_df.merge(dfbm, how='inner', left_on='prot_id', right_on='UniProtKB/Swiss-Prot ID') + df_pdb = train_df.merge(dfbm, how='inner', left_on='code', right_on='PDB ID') + + # identifying ovelap with oncokb + # df_all will have duplicate entries for entries with multiple gene names... 
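    # (editor's note) duplicates arise because one training entry can match the biomart
    # export on both its UniProt ID and its PDB code, and a protein can map to several
    # gene names; the drop_duplicates(['idx', 'Gene name']) below keeps one row per
    # (entry, gene) pair before checking overlap against the oncoKB drug-gene pair list.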
+ df_all = pd.concat([df_uni, df_pdb]).drop_duplicates(['idx', 'Gene name'])[['idx', 'code', 'Gene name']] + + dfkb = pd.read_csv(biomart) + df_all_kb = df_all.merge(dfkb.drop_duplicates('gene'), left_on='Gene name', right_on='gene', how='inner') + + trained_genes = set(df_all_kb.gene) + + #Identify non-trained genes + return dfkb[~dfkb['gene'].isin(trained_genes)], dfkb[dfkb['gene'].isin(trained_genes)], dfkb + + +train_df = pd.read_csv('/cluster/home/t122995uhn/projects/data/test/PDBbindDataset/nomsa_binary_original_binary/train0/cleaned_XY.csv') +val_df = pd.read_csv('/cluster/home/t122995uhn/projects/data/test/PDBbindDataset/nomsa_binary_original_binary/val0/cleaned_XY.csv') + +train_df = pd.concat([train_df, val_df]) + +get_test_oncokbs(train_df=train_df) + +#%% +############################################################################## +########################## BUILD/SPLIT DATASETS ############################## +############################################################################## +import os +from src.data_prep.init_dataset import create_datasets +from src import cfg import logging -from matplotlib import pyplot as plt - -from src.analysis.figures import prepare_df, fig_combined, custom_fig - -dft = prepare_df('./results/v115/model_media/model_stats.csv') -dfv = prepare_df('./results/v115/model_media/model_stats_val.csv') - -models = { - 'DG': ('nomsa', 'binary', 'original', 'binary'), - 'esm': ('ESM', 'binary', 'original', 'binary'), # esm model - 'aflow': ('nomsa', 'aflow', 'original', 'binary'), - # 'gvpP': ('gvp', 'binary', 'original', 'binary'), - 'gvpL': ('nomsa', 'binary', 'gvp', 'binary'), - # 'aflow_ring3': ('nomsa', 'aflow_ring3', 'original', 'binary'), - 'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'), - # 'gvpL_aflow_rng3': ('nomsa', 'aflow_ring3', 'gvp', 'binary'), - #GVPL_ESMM_davis3D_nomsaF_aflowE_48B_0.00010636872718329864LR_0.23282479481785903D_2000E_gvpLF_binaryLE - # 'gvpl_esm_aflow': ('ESM', 'aflow', 'gvp', 'binary'), -} - -fig, axes = fig_combined(dft, datasets=['davis'], fig_callable=custom_fig, - models=models, metrics=['cindex', 'mse'], - fig_scale=(10,5), add_stats=True, title_postfix=" test set performance", box=True, fold_labels=True) -plt.xticks(rotation=45) - -fig, axes = fig_combined(dfv, datasets=['davis'], fig_callable=custom_fig, - models=models, metrics=['cindex', 'mse'], - fig_scale=(10,5), add_stats=True, title_postfix=" validation set performance", box=True, fold_labels=True) -plt.xticks(rotation=45) \ No newline at end of file +cfg.logger.setLevel(logging.DEBUG) + +dbs = [cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba] +splits = ['davis', 'kiba'] +splits = ['/cluster/home/t122995uhn/projects/MutDTA/splits/' + s for s in splits] +print(splits) + +#%% +for split, db in zip(splits, dbs): + print('\n',split, db) + create_datasets(db, + feat_opt=cfg.PRO_FEAT_OPT.nomsa, + edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow], + ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp], + ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=False, + k_folds=5, + + test_prots_csv=f'{split}/test.csv', + val_prots_csv=[f'{split}/val{i}.csv' for i in range(5)]) + +#%% TEST INFERENCE +from src import cfg +from src.utils.loader import Loader + +# db2 = Loader.load_dataset(cfg.DATA_OPT.davis, +# cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow, +# path='/cluster/home/t122995uhn/projects/data/', +# subset="full") + +db2 = Loader.load_DataLoaders(cfg.DATA_OPT.davis, + cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow, + 
path='/cluster/home/t122995uhn/projects/data/v131', + training_fold=0, + batch_train=2) +for b2 in db2['test']: break + + +# %% +m = Loader.init_model(cfg.MODEL_OPT.DG, cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow, + dropout=0.3480, output_dim=256, + ) + +#%% +# m(b['protein'], b['ligand']) +m(b2['protein'], b2['ligand']) +#%% +model = m +loaders = db2 +device = 'cpu' +NUM_EPOCHS = 1 +LEARNING_RATE = 0.001 +from src.train_test.training import train + +logs = train(model, loaders['train'], loaders['val'], device, + epochs=NUM_EPOCHS, lr_0=LEARNING_RATE) diff --git a/results/v103/model_media/model_stats.csv b/results/v103/model_media/model_stats.csv new file mode 100644 index 0000000..d4227f5 --- /dev/null +++ b/results/v103/model_media/model_stats.csv @@ -0,0 +1,22 @@ +run,cindex,pearson,spearman,mse,mae,rmse +DGM_davis0D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8174366987963239,0.6808973439070014,0.5780986864623106,0.374029119754687,0.3416232488841833,0.6115792015386781 +DGM_davis1D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8359138385559401,0.7212884148849212,0.6093121108415754,0.3444294398275105,0.3380570360012467,0.5868811121747832 +DGM_davis2D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.811306156371881,0.6771836874485692,0.5650256869521153,0.3933000326926663,0.333361968167426,0.6271363748760442 +DGM_davis3D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8148631243802541,0.717113315384429,0.571925536761479,0.3422128815756367,0.3177703711270548,0.5849896422806448 +DGM_davis4D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8196459665927316,0.694403802004145,0.5825760745508323,0.3702764201890446,0.33563001218595,0.6085034266041931 +DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7936510071795485,0.628767072325098,0.5217398281378556,0.3566859747000747,0.3591853688744937,0.597231927060229 +DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8009815205097928,0.6035635252189794,0.5304746622864567,0.4253406250688673,0.364227359902625,0.6521814356978182 +DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7783955876418098,0.5816462981556966,0.4961723044095886,0.4376154312774337,0.3656365177210639,0.6615250798552038 +DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8181456735871336,0.6918684941945846,0.56229516172368,0.3071043302279289,0.2969707269294589,0.554169947063109 +DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8014907154138579,0.6425965261636467,0.5354462017864902,0.3606209315377456,0.3375259168007795,0.6005172200176657 +DGM_kiba0D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7476189416085207,0.7148917008987766,0.6299877614860792,0.3746319657859179,0.3958356694230301,0.6120718632529336 +DGM_kiba1D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7073391610819149,0.624956249151526,0.5401876728173656,0.4451318825041403,0.458846963456725,0.667182045999546 +DGM_kiba2D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7401141841678894,0.6795148074510864,0.6127459332278625,0.4004100160666026,0.4095781581139723,0.6327795951724444 
+DGM_kiba3D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7396234368040389,0.6913457932090825,0.6201197126448974,0.3934219012917641,0.4068530834848238,0.6272335301080165 +DGM_kiba4D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.752441708545282,0.7025492844189518,0.6449954833411846,0.3728163774990898,0.4045171920082104,0.6105869123221442 +DGM_kiba0D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7599929872587803,0.7067412429690916,0.6593355592769512,0.3962219319168832,0.4099100126533609,0.6294616206861887 +DGM_kiba2D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7149604681753393,0.6152047008431843,0.5597795125500629,0.4741719822054008,0.4603646989542154,0.6886014683439187 +DGM_kiba1D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7140873472783476,0.6102548954720128,0.5558196740209606,0.4781851688315759,0.4659358458753446,0.6915093411021834 +DGM_kiba3D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7164547158247304,0.6084847523640808,0.5607065445063388,0.4802083760845744,0.4646035882672965,0.6929706891958523 +DGM_kiba4D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7687577053257117,0.7532822502738942,0.6745267167129126,0.3466135049736077,0.3887611475832294,0.5887389107011765 +EDIM_davis0D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.8114566611493259,0.7317647125777735,0.6044949818493646,0.3736163373086704,0.3493916191183746,0.611241635778086 diff --git a/results/v103/model_media/model_stats_val.csv b/results/v103/model_media/model_stats_val.csv new file mode 100644 index 0000000..1aba5af --- /dev/null +++ b/results/v103/model_media/model_stats_val.csv @@ -0,0 +1,22 @@ +run,cindex,pearson,spearman,mse,mae,rmse +DGM_davis0D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8507053734550806,0.7688628504779598,0.6689225345680122,0.3760747658599554,0.3388000398874283,0.6132493504765867 +DGM_davis1D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8718442303414345,0.8308115505911805,0.7173863620955029,0.323120446450846,0.3234809194096876,0.5684368447337365 +DGM_davis2D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8818149976678145,0.834014760655388,0.7187113282294693,0.3071136635927556,0.2922643621762593,0.5541783680303262 +DGM_davis3D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8686054788183053,0.828507036778059,0.7018974086625753,0.3046836030153428,0.3018857493804209,0.5519815241612194 +DGM_davis4D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8510324353139875,0.7912085758636695,0.6660120299194481,0.3556152282825756,0.3246624081237138,0.5963348290034514 +DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8224153103333314,0.7079363892542606,0.598653291929885,0.390209011583234,0.3662272181520597,0.6246671206196417 +DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8493586014591634,0.7889115161443931,0.673101173187512,0.385586498014305,0.3309167098727579,0.6209561160132856 +DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8337160086816762,0.736812322417127,0.6264321347273434,0.3912511533576103,0.3474602306165372,0.6255007221079847 
+DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8463969226855363,0.7444417955240732,0.6410258059445946,0.3682857548365584,0.3208404985420844,0.6068655162690977 +DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8485013700858233,0.7861129348915608,0.6464621130340457,0.3603118364352048,0.334368295190004,0.6002598074460799 +DGM_kiba0D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.76838489785244,0.7349294201300529,0.6720189966892385,0.2870000493840723,0.36449329645463846,0.5357238555301344 +DGM_kiba1D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6830895888292509,0.6256216928279031,0.4862834063605662,0.4635991200460103,0.4637030257199048,0.6808811350346038 +DGM_kiba2D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7424018620084226,0.7064897658791767,0.6113374073010096,0.3869704646200874,0.4214332353778002,0.6220695014386153 +DGM_kiba3D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7887989201757725,0.7493602745702617,0.7146442736012206,0.2624049771770561,0.3196535898269914,0.5122547971244936 +DGM_kiba4D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.805961202163743,0.797223082421482,0.7422315375449509,0.2469041691088263,0.3146771252445765,0.4968945251346872 +DGM_kiba0D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7414150463245769,0.7180513866946735,0.6154427990455673,0.2477156183512984,0.352769788125809,0.4977103759731139 +DGM_kiba2D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7017734270681899,0.6190248117265895,0.5234448165080476,0.4732107505692587,0.4709162900019082,0.6879031549348054 +DGM_kiba1D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.715296095350676,0.6760788124615275,0.5590996884326331,0.4113607007891719,0.446394580254481,0.6413740724329071 +DGM_kiba3D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.6897289677863744,0.552288313673736,0.4932783355544295,0.4163241510325705,0.4502416662950819,0.6452318583521511 +DGM_kiba4D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7785830290333009,0.772583639636063,0.6834004931220337,0.2542728701347933,0.3402214839171678,0.504254767091788 +EDIM_davis0D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.8460899419942509,0.7821818481200006,0.6773536752793916,0.3864269594875424,0.3440388107583636,0.6216324955209006 diff --git a/src/models/esm_models.py b/src/models/esm_models.py index 3f62b05..5b4504a 100644 --- a/src/models/esm_models.py +++ b/src/models/esm_models.py @@ -18,7 +18,7 @@ class EsmDTA(BaseModel): def __init__(self, esm_head:str='facebook/esm2_t6_8M_UR50D', num_features_pro=320, pro_emb_dim=54, num_features_mol=78, output_dim=128, dropout=0.2, pro_feat='esm_only', edge_weight_opt='binary', - dropout_prot=0.0, pro_extra_fc_lyr=False): + dropout_prot=0.0, pro_extra_fc_lyr=False, **kwargs): super(EsmDTA, self).__init__(pro_feat, edge_weight_opt) @@ -86,6 +86,13 @@ def forward_pro(self, data): esm_emb = esm_emb.flatten(0,1) # to [B*L_max+1, emb_dim] esm_emb = esm_emb[mask] # [B*L, emb_dim] + # applying pocket mask if relevant + if "pocket_mask" in data: + # pocket mask must be a flattened mask from [B,L] to [B*L] + mask = [torch.tensor(d, dtype=torch.bool) for d in data.pocket_mask]# [B, L] + mask = torch.cat(mask) + esm_emb = 
esm_emb[mask] + if self.esm_only: target_x = esm_emb # [B*L, emb_dim] else: diff --git a/src/train_test/splitting.py b/src/train_test/splitting.py index ed7460c..7430dda 100644 --- a/src/train_test/splitting.py +++ b/src/train_test/splitting.py @@ -343,19 +343,22 @@ def resplit(dataset:str|BaseDataset, split_files:dict|str=None, use_train_set=Fa split_files = split_files.copy() test_prots = set(pd.read_csv(split_files['test'])['prot_id']) test_idxs = [i for i in range(len(dataset.df)) if dataset.df.iloc[i]['prot_id'] in test_prots] + assert len(test_idxs) > 100, f"Error in splitting, not enough entries in test split - {split_files['test']}" dataset.save_subset(test_idxs, 'test') del split_files['test'] # Building the folds for k, v in tqdm(split_files.items(), desc="Building folds from split files"): prots = set(pd.read_csv(v)['prot_id']) - val_idxs = [i for i in range(len(dataset.df)) if dataset.df.iloc[i]['prot_id'] in prots] + val_idxs = [i for i in range(len(dataset.df)) if dataset.df.iloc[i]['prot_id'] in prots] + assert len(val_idxs) > 100, f"Error in splitting, not enough entries in {k} split - {v}" dataset.save_subset(val_idxs, k) if not use_train_set: # Build training set from all proteins not in the val/test set idxs = set(val_idxs + test_idxs) train_idxs = [i for i in range(len(dataset.df)) if i not in idxs] + assert len(train_idxs) > 100, f"Error in splitting, not enough entries in train split" dataset.save_subset(train_idxs, k.replace('val', 'train')) return dataset diff --git a/src/utils/config.py b/src/utils/config.py index 8fa711f..2cfb9d6 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -89,8 +89,7 @@ class LIG_FEAT_OPT(StringEnum): from pathlib import Path # Model save paths -issue_number = None - +issue_number = 103 DATA_BASENAME = f'data/{f"v{issue_number}" if issue_number else ""}' RESULTS_PATH = os.path.abspath(f'results/{f"v{issue_number}/" if issue_number else ""}') MEDIA_SAVE_DIR = f'{RESULTS_PATH}/model_media/' @@ -111,7 +110,7 @@ class LIG_FEAT_OPT(StringEnum): SLURM_ACCOUNT = None SLURM_GPU_NAME = 'v100' -if ('uhnh4h' in DOMAIN_NAME or 'h4h' in DOMAIN_NAME): +if ('uhnh4h' in DOMAIN_NAME) or ('h4h' in DOMAIN_NAME): CLUSTER = 'h4h' SLURM_PARTITION = 'gpu' SLURM_CONSTRAINT = 'gpu32g' diff --git a/src/utils/loader.py b/src/utils/loader.py index 48ead68..1743cdd 100644 --- a/src/utils/loader.py +++ b/src/utils/loader.py @@ -338,7 +338,7 @@ def load_DataLoaders(data:str=None, pro_feature:str=None, edge_opt:str=None, pat bs = 1 if d == 'test' else batch_train loader = DataLoader(dataset=loaded_datasets[d], batch_size=bs, - shuffle=False) + shuffle=True) loaders[d] = loader return loaders @@ -365,7 +365,7 @@ def load_distributed_DataLoaders(num_replicas:int, rank:int, seed:int, data:str, loaders = {} for d in loaded_datasets: dataset = loaded_datasets[d] - sampler = DistributedSampler(dataset, shuffle=True, + sampler = DistributedSampler(dataset, shuffle=False, num_replicas=num_replicas, rank=rank, seed=seed) @@ -374,7 +374,7 @@ def load_distributed_DataLoaders(num_replicas:int, rank:int, seed:int, data:str, sampler=sampler, batch_size=bs, # should be per gpu batch size (local batch size) num_workers=num_workers, - shuffle=False, + shuffle=False, # mut exclusive with DDP pin_memory=True, drop_last=True) # drop last batch if not divisible by batch size loaders[d] = loader @@ -465,4 +465,4 @@ def wrapper(*args, **kwargs): # Return the function call output return func(*args, **kwargs) return wrapper - return decorator \ No newline at end of file + return 
decorator diff --git a/src/utils/pocket_alignment.py b/src/utils/pocket_alignment.py index 8ff30eb..c0bbec9 100644 --- a/src/utils/pocket_alignment.py +++ b/src/utils/pocket_alignment.py @@ -9,6 +9,7 @@ from Bio import Align from Bio.Align import substitution_matrices +import numpy as np import pandas as pd import torch @@ -78,26 +79,34 @@ def mask_graph(data, mask: list[bool]): additional attributes: -pocket_mask : list[bool] The mask specified by the mask parameter of dimension [full_seuqence_length] - -pocket_mask_x : torch.Tensor + -x : torch.Tensor The nodes of only the pocket of the protein sequence of dimension [pocket_sequence_length, num_features] - -pocket_mask_edge_index : torch.Tensor + -edge_index : torch.Tensor The edge connections in COO format only relating to the pocket nodes of the protein sequence of dimension [2, num_pocket_edges] """ + # node map for updating edge indicies after mask + node_map = np.cumsum(mask) - 1 + nodes = data.x[mask] - edges = data.edge_index + edges = [] edge_mask = [] - for i in range(edges.shape[1]): - # Throw out edges that are connected to at least one node not in the - # binding pocket - node_1, node_2 = edges[:,i][0], edges[:,i][1] - edge_mask.append(True) if mask[node_1] and mask[node_2] else edge_mask.append(False) - edges = torch.transpose(torch.transpose(edges, 0, 1)[edge_mask], 0, 1) + for i in range(data.edge_index.shape[1]): + # Throw out edges that are not part of connecting two nodes in the pocket... + node_1, node_2 = data.edge_index[:,i][0], data.edge_index[:,i][1] + if mask[node_1] and mask[node_2]: + # append mapped index: + edges.append([node_map[node_1], node_map[node_2]]) + edge_mask.append(True) + else: + edge_mask.append(False) + data.x = nodes data.pocket_mask = mask - data.pocket_mask_x = nodes - data.pocket_mask_edge_index = edges + data.edge_index = torch.tensor(edges).T # reshape to (2, E) + if 'edge_weight' in data: + data.edge_weight = data.edge_weight[edge_mask] return data @@ -122,7 +131,8 @@ def _parse_json(json_path: str) -> str: def get_dataset_binding_pockets( dataset_path: str = 'data/DavisKibaDataset/kiba/nomsa_binary_original_binary/full', - pockets_path: str = 'data/DavisKibaDataset/kiba_pocket' + pockets_path: str = 'data/DavisKibaDataset/kiba_pocket', + skip_download: bool = False, ) -> tuple[dict[str, str], set[str]]: """ Get all binding pocket sequences for a dataset @@ -149,14 +159,14 @@ def get_dataset_binding_pockets( # Strip out mutations and '-(alpha, beta, gamma)' tags if they are present, # the binding pocket sequence will be the same for mutated and non-mutated genes prot_ids = [id.split('(')[0].split('-')[0] for id in prot_ids] - dl = Downloader() seq_save_dir = os.path.join(pockets_path, 'pockets') - os.makedirs(seq_save_dir, exist_ok=True) - download_check = dl.download_pocket_seq(prot_ids, seq_save_dir) + + if not skip_download: # to use cached downloads only! (useful when on compute node) + dl = Downloader() + os.makedirs(seq_save_dir, exist_ok=True) + dl.download_pocket_seq(prot_ids, seq_save_dir) + download_errors = set() - for key, val in download_check.items(): - if val == 400: - download_errors.add(key) sequences = {} for file in os.listdir(seq_save_dir): pocket_seq = _parse_json(os.path.join(seq_save_dir, file)) @@ -164,6 +174,12 @@ def get_dataset_binding_pockets( download_errors.add(file.split('.')[0]) else: sequences[file.split('.')[0]] = pocket_seq + + # adding any remainder prots not downloaded. 
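    # (editor's note) with skip_download=True only cached pocket JSONs under
    # {pockets_path}/pockets are read, so any prot_id without a cached file, or whose
    # cached file yields no usable pocket sequence, ends up in download_errors and is
    # later dropped from cleaned_XY.csv by binding_pocket_filter().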
+ for p in prot_ids: + if p not in sequences: + download_errors.add(p) + return (sequences, download_errors) @@ -197,7 +213,7 @@ def create_binding_pocket_dataset( new_data = mask_graph(data, mask) new_dataset[id] = new_data os.makedirs(os.path.dirname(new_dataset_path), exist_ok=True) - torch.save(dataset, new_dataset_path) + torch.save(new_dataset, new_dataset_path) def binding_pocket_filter(dataset_csv_path: str, download_errors: set[str], csv_save_path: str): @@ -215,8 +231,8 @@ def binding_pocket_filter(dataset_csv_path: str, download_errors: set[str], csv_ csv_save_path : str The path to save the new CSV file to. """ - df = pd.read_csv(dataset_csv_path) - df = df[~df['prot_id'].isin(download_errors)] + df = pd.read_csv(dataset_csv_path, index_col=0) + df = df[~df.prot_id.str.split('(').str[0].str.split('-').str[0].isin(download_errors)] os.makedirs(os.path.dirname(csv_save_path), exist_ok=True) df.to_csv(csv_save_path) @@ -224,7 +240,8 @@ def binding_pocket_filter(dataset_csv_path: str, download_errors: set[str], csv_ def pocket_dataset_full( dataset_dir: str, pocket_dir: str, - save_dir: str + save_dir: str, + skip_download: bool = False ) -> None: """ Create all elements of a dataset that includes binding pockets. This @@ -240,7 +257,7 @@ def pocket_dataset_full( save_dir : str The path to where the new dataset is to be saved """ - pocket_map, download_errors = get_dataset_binding_pockets(dataset_dir, pocket_dir) + pocket_map, download_errors = get_dataset_binding_pockets(dataset_dir, pocket_dir, skip_download) print(f'Binding pocket sequences were not found for the following {len(download_errors)} protein IDs:') print(','.join(list(download_errors))) create_binding_pocket_dataset( @@ -254,7 +271,9 @@ def pocket_dataset_full( download_errors, os.path.join(save_dir, 'cleaned_XY.csv') ) - shutil.copy2(os.path.join(dataset_dir, 'data_mol.pt'), os.path.join(save_dir, 'data_mol.pt')) + if dataset_dir != save_dir: + shutil.copy2(os.path.join(dataset_dir, 'data_mol.pt'), os.path.join(save_dir, 'data_mol.pt')) + shutil.copy2(os.path.join(dataset_dir, 'XY.csv'), os.path.join(save_dir, 'XY.csv')) if __name__ == '__main__': diff --git a/train_test.py b/train_test.py index 8bc9bba..0b53a5b 100644 --- a/train_test.py +++ b/train_test.py @@ -2,7 +2,24 @@ from src.utils.arg_parse import parse_train_test_args args, unknown_args = parse_train_test_args(verbose=True, - jyp_args='-m DG -d PDBbind -f nomsa -e binary -bs 64') + jyp_args='--model_opt EDI \ + --data_opt davis \ + --fold_selection 0 \ + \ + --feature_opt nomsa \ + --edge_opt binary \ + --ligand_feature_opt original \ + --ligand_edge_opt binary \ + \ + --learning_rate 0.0001 \ + --batch_size 12 \ + \ + --dropout 0.4 \ + --dropout_prot 0.0 \ + --output_dim 128 \ + --pro_emb_dim 512 \ + --pro_extra_fc_lyr False\ + --debug') FORCE_TRAINING = args.train DEBUG = args.debug @@ -46,6 +63,7 @@ cp_saver = CheckpointSaver(model=None, save_path=None, train_all=False, # forces full training + min_delta=0.2, patience=100) # %% Training loop @@ -88,14 +106,13 @@ ligand_feature=ligand_feature, ligand_edge=ligand_edge, **unknown_args).to(device) cp_saver.new_model(model, save_path=model_save_p) - cp_saver.min_delta = 0.2 if DATA == cfg.DATA_OPT.PDBbind else 0.05 + cp_saver.min_delta = 0.2 if DATA == cfg.DATA_OPT.PDBbind else 0.03 if DEBUG: # run single batch through model debug(model, loaders['train'], device) continue # skip training - # ==== TRAINING ==== # check if model has already been trained: logs = None
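#%% (editor's sketch, not part of the diff above) A minimal, self-contained illustration
# of the node re-indexing now done in mask_graph() (src/utils/pocket_alignment.py):
# np.cumsum(mask) - 1 maps each original node index to its position among the kept
# (pocket) nodes, and only edges whose two endpoints are both kept are retained, with
# their endpoints remapped. The toy mask and edge_index below are made up for
# illustration only.
import numpy as np
import torch

mask = [True, False, True, True]            # keep nodes 0, 2, 3; drop node 1
edge_index = torch.tensor([[0, 1, 2, 3],
                           [2, 2, 3, 0]])   # edges 0-2, 1-2, 2-3, 3-0

node_map = np.cumsum(mask) - 1              # [0, 0, 1, 2]: old index -> new index
edges = []
for i in range(edge_index.shape[1]):
    n1, n2 = int(edge_index[0, i]), int(edge_index[1, i])
    if mask[n1] and mask[n2]:               # drop edges touching a masked-out node
        edges.append([node_map[n1], node_map[n2]])

print(torch.tensor(edges).T)                # tensor([[0, 1, 2], [1, 2, 0]])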