Skip to content

Commit

Permalink
Merge pull request #135 from jyaacoub/pocket-training-v103
Browse files Browse the repository at this point in the history
Training new pocket representation
  • Loading branch information
jyaacoub committed Sep 15, 2024
2 parents af725c6 + 018a9ee commit ef9ac0f
Show file tree
Hide file tree
Showing 9 changed files with 377 additions and 68 deletions.
284 changes: 252 additions & 32 deletions playground.py
Original file line number Diff line number Diff line change
@@ -1,34 +1,254 @@
#%%
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

new = '/cluster/home/t122995uhn/projects/splits/new/pdbbind/'

train_df = pd.concat([pd.read_csv(f'{new}train0.csv'),
pd.read_csv(f'{new}val0.csv')], axis=0)
test_df = pd.read_csv(f'{new}test.csv')

all_df = pd.concat([train_df, test_df], axis=0)
print(len(all_df))


#%%
old = '/cluster/home/t122995uhn/projects/splits/old/pdbbind/'
old_test_df = pd.read_csv(f'{old}test.csv')
old_train_df = all_df[~all_df['code'].isin(old_test_df['code'])]

# %%
# this will give us an estimate to how well targeted the training proteins are vs the test proteins
def proteins_targeted(train_df, test_df, split='new', min_freq=0, normalized=False):
# protein count comparison (number of diverse proteins)
plt.figure(figsize=(18,8))
# x-axis is the normalized frequency, y-axis is the number of proteins that have that frequency (also normalized)
vc = train_df.prot_id.value_counts()
vc = vc[vc > min_freq]
train_counts = list(vc/len(test_df)) if normalized else vc.values
vc = test_df.prot_id.value_counts()
vc = vc[vc > min_freq]
test_counts = list(vc/len(test_df)) if normalized else vc.values

sns.histplot(train_counts,
bins=50, stat='density', color='green', alpha=0.4)
sns.histplot(test_counts,
bins=50,stat='density', color='blue', alpha=0.4)

sns.kdeplot(train_counts, color='green', alpha=0.8)
sns.kdeplot(test_counts, color='blue', alpha=0.8)

plt.xlabel(f"{'normalized ' if normalized else ''} frequency")
plt.ylabel("normalized number of proteins with that frequency")
plt.title(f"Targeted differences for {split} split{f' (> {min_freq})' if min_freq else ''}")
if not normalized:
plt.xlim(-8,100)

# proteins_targeted(old_train_df, old_test_df, split='oncoKB')
# plt.show()
# proteins_targeted(train_df, test_df, split='random')
# plt.show()


proteins_targeted(old_test_df, test_df, split='oncoKB(green) vs random(blue) test')
plt.show()
proteins_targeted(old_test_df, test_df, split='oncoKB(green) vs random(blue) test', min_freq=5)
plt.show()
proteins_targeted(old_test_df, test_df, split='oncoKB(green) vs random(blue) test', min_freq=10)
plt.show()
proteins_targeted(old_test_df, test_df, split='oncoKB(green) vs random(blue) test', min_freq=15)
plt.show()
proteins_targeted(old_test_df, test_df, split='oncoKB(green) vs random(blue) test', min_freq=20)
plt.show()
# proteins_targeted(old_train_df, train_df, split='oncoKB(green) vs random train')
# plt.show()
#%% sequence length comparison
def seq_kde(all_df, train_df, test_df, split='new'):
plt.figure(figsize=(12, 8))

sns.kdeplot(all_df.prot_seq.str.len().reset_index()['prot_seq'], label='All', color='blue')
sns.kdeplot(train_df.prot_seq.str.len().reset_index()['prot_seq'], label='Train', color='green')
sns.kdeplot(test_df.prot_seq.str.len().reset_index()['prot_seq'], label='Test', color='red')

plt.xlabel('Sequence Length')
plt.ylabel('Density')
plt.title(f'Sequence Length Distribution ({split} split)')
plt.legend()

seq_kde(all_df,train_df,test_df, split='new')
plt.show()
seq_kde(all_df,old_train_df,old_test_df, split='old')

# %%
from Bio import pairwise2
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
from Bio.Align import substitution_matrices

from tqdm import tqdm
import random

def get_group_similarity(group1, group2):
# Choose a substitution matrix (e.g., BLOSUM62)
matrix = substitution_matrices.load("BLOSUM62")

# Define gap penalties
gap_open = -10
gap_extend = -0.5

# Function to calculate pairwise similarity score
def calculate_similarity(seq1, seq2):
alignments = pairwise2.align.globalds(seq1, seq2, matrix, gap_open, gap_extend)
return alignments[0][2] # Return the score of the best alignment

# Compute pairwise similarity between all sequences in group1 and group2
similarity_scores = []
for seq1 in group1:
for seq2 in group2:
score = calculate_similarity(seq1, seq2)
similarity_scores.append(score)

# Calculate the average similarity score
average_similarity = sum(similarity_scores) / len(similarity_scores)
return similarity_scores, average_similarity


# sample 10 sequences randomly 100x
train_seq = old_train_df.prot_seq.drop_duplicates().to_list()
test_seq = old_test_df.prot_seq.drop_duplicates().to_list()
sample_size = 5
trials = 100

est_similarity = 0
for _ in tqdm(range(trials)):
_, avg = get_group_similarity(random.sample(train_seq, sample_size),
random.sample(test_seq, sample_size))
est_similarity += avg

print(est_similarity/1000)




# %%
########################################################################
########################## VIOLIN PLOTTING #############################
########################################################################
# building pocket datasets:
from src.utils.pocket_alignment import pocket_dataset_full
import shutil
import os

data_dir = '/cluster/home/t122995uhn/projects/data/'
db_type = ['kiba', 'davis']
db_feat = ['nomsa_binary_original_binary', 'nomsa_aflow_original_binary',
'nomsa_binary_gvp_binary', 'nomsa_aflow_gvp_binary']

for t in db_type:
for f in db_feat:
print(f'\n---{t}-{f}---\n')
dataset_dir= f"{data_dir}/DavisKibaDataset/{t}/{f}/full"
save_dir = f"{data_dir}/v131/DavisKibaDataset/{t}/{f}/full"

pocket_dataset_full(
dataset_dir= dataset_dir,
pocket_dir = f"{data_dir}/{t}/",
save_dir = save_dir,
skip_download=True
)

#%%
import pandas as pd

def get_test_oncokbs(train_df=pd.read_csv('/cluster/home/t122995uhn/projects/data/test/PDBbindDataset/nomsa_binary_original_binary/full/cleaned_XY.csv'),
oncokb_fp='/cluster/home/t122995uhn/projects/data/tcga/mart_export.tsv',
biomart='/cluster/home/t122995uhn/projects/downloads/oncoKB_DrugGenePairList.csv'):
#Get gene names for PDBbind
dfbm = pd.read_csv(oncokb_fp, sep='\t')
dfbm['PDB ID'] = dfbm['PDB ID'].str.lower()
train_df.reset_index(names='idx',inplace=True)

df_uni = train_df.merge(dfbm, how='inner', left_on='prot_id', right_on='UniProtKB/Swiss-Prot ID')
df_pdb = train_df.merge(dfbm, how='inner', left_on='code', right_on='PDB ID')

# identifying ovelap with oncokb
# df_all will have duplicate entries for entries with multiple gene names...
df_all = pd.concat([df_uni, df_pdb]).drop_duplicates(['idx', 'Gene name'])[['idx', 'code', 'Gene name']]

dfkb = pd.read_csv(biomart)
df_all_kb = df_all.merge(dfkb.drop_duplicates('gene'), left_on='Gene name', right_on='gene', how='inner')

trained_genes = set(df_all_kb.gene)

#Identify non-trained genes
return dfkb[~dfkb['gene'].isin(trained_genes)], dfkb[dfkb['gene'].isin(trained_genes)], dfkb


train_df = pd.read_csv('/cluster/home/t122995uhn/projects/data/test/PDBbindDataset/nomsa_binary_original_binary/train0/cleaned_XY.csv')
val_df = pd.read_csv('/cluster/home/t122995uhn/projects/data/test/PDBbindDataset/nomsa_binary_original_binary/val0/cleaned_XY.csv')

train_df = pd.concat([train_df, val_df])

get_test_oncokbs(train_df=train_df)

#%%
##############################################################################
########################## BUILD/SPLIT DATASETS ##############################
##############################################################################
import os
from src.data_prep.init_dataset import create_datasets
from src import cfg
import logging
from matplotlib import pyplot as plt

from src.analysis.figures import prepare_df, fig_combined, custom_fig

dft = prepare_df('./results/v115/model_media/model_stats.csv')
dfv = prepare_df('./results/v115/model_media/model_stats_val.csv')

models = {
'DG': ('nomsa', 'binary', 'original', 'binary'),
'esm': ('ESM', 'binary', 'original', 'binary'), # esm model
'aflow': ('nomsa', 'aflow', 'original', 'binary'),
# 'gvpP': ('gvp', 'binary', 'original', 'binary'),
'gvpL': ('nomsa', 'binary', 'gvp', 'binary'),
# 'aflow_ring3': ('nomsa', 'aflow_ring3', 'original', 'binary'),
'gvpL_aflow': ('nomsa', 'aflow', 'gvp', 'binary'),
# 'gvpL_aflow_rng3': ('nomsa', 'aflow_ring3', 'gvp', 'binary'),
#GVPL_ESMM_davis3D_nomsaF_aflowE_48B_0.00010636872718329864LR_0.23282479481785903D_2000E_gvpLF_binaryLE
# 'gvpl_esm_aflow': ('ESM', 'aflow', 'gvp', 'binary'),
}

fig, axes = fig_combined(dft, datasets=['davis'], fig_callable=custom_fig,
models=models, metrics=['cindex', 'mse'],
fig_scale=(10,5), add_stats=True, title_postfix=" test set performance", box=True, fold_labels=True)
plt.xticks(rotation=45)

fig, axes = fig_combined(dfv, datasets=['davis'], fig_callable=custom_fig,
models=models, metrics=['cindex', 'mse'],
fig_scale=(10,5), add_stats=True, title_postfix=" validation set performance", box=True, fold_labels=True)
plt.xticks(rotation=45)
cfg.logger.setLevel(logging.DEBUG)

dbs = [cfg.DATA_OPT.davis, cfg.DATA_OPT.kiba]
splits = ['davis', 'kiba']
splits = ['/cluster/home/t122995uhn/projects/MutDTA/splits/' + s for s in splits]
print(splits)

#%%
for split, db in zip(splits, dbs):
print('\n',split, db)
create_datasets(db,
feat_opt=cfg.PRO_FEAT_OPT.nomsa,
edge_opt=[cfg.PRO_EDGE_OPT.binary, cfg.PRO_EDGE_OPT.aflow],
ligand_features=[cfg.LIG_FEAT_OPT.original, cfg.LIG_FEAT_OPT.gvp],
ligand_edges=cfg.LIG_EDGE_OPT.binary, overwrite=False,
k_folds=5,

test_prots_csv=f'{split}/test.csv',
val_prots_csv=[f'{split}/val{i}.csv' for i in range(5)])

#%% TEST INFERENCE
from src import cfg
from src.utils.loader import Loader

# db2 = Loader.load_dataset(cfg.DATA_OPT.davis,
# cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
# path='/cluster/home/t122995uhn/projects/data/',
# subset="full")

db2 = Loader.load_DataLoaders(cfg.DATA_OPT.davis,
cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
path='/cluster/home/t122995uhn/projects/data/v131',
training_fold=0,
batch_train=2)
for b2 in db2['test']: break


# %%
m = Loader.init_model(cfg.MODEL_OPT.DG, cfg.PRO_FEAT_OPT.nomsa, cfg.PRO_EDGE_OPT.aflow,
dropout=0.3480, output_dim=256,
)

#%%
# m(b['protein'], b['ligand'])
m(b2['protein'], b2['ligand'])
#%%
model = m
loaders = db2
device = 'cpu'
NUM_EPOCHS = 1
LEARNING_RATE = 0.001
from src.train_test.training import train

logs = train(model, loaders['train'], loaders['val'], device,
epochs=NUM_EPOCHS, lr_0=LEARNING_RATE)
22 changes: 22 additions & 0 deletions results/v103/model_media/model_stats.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
run,cindex,pearson,spearman,mse,mae,rmse
DGM_davis0D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8174366987963239,0.6808973439070014,0.5780986864623106,0.374029119754687,0.3416232488841833,0.6115792015386781
DGM_davis1D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8359138385559401,0.7212884148849212,0.6093121108415754,0.3444294398275105,0.3380570360012467,0.5868811121747832
DGM_davis2D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.811306156371881,0.6771836874485692,0.5650256869521153,0.3933000326926663,0.333361968167426,0.6271363748760442
DGM_davis3D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8148631243802541,0.717113315384429,0.571925536761479,0.3422128815756367,0.3177703711270548,0.5849896422806448
DGM_davis4D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8196459665927316,0.694403802004145,0.5825760745508323,0.3702764201890446,0.33563001218595,0.6085034266041931
DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7936510071795485,0.628767072325098,0.5217398281378556,0.3566859747000747,0.3591853688744937,0.597231927060229
DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8009815205097928,0.6035635252189794,0.5304746622864567,0.4253406250688673,0.364227359902625,0.6521814356978182
DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.7783955876418098,0.5816462981556966,0.4961723044095886,0.4376154312774337,0.3656365177210639,0.6615250798552038
DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8181456735871336,0.6918684941945846,0.56229516172368,0.3071043302279289,0.2969707269294589,0.554169947063109
DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8014907154138579,0.6425965261636467,0.5354462017864902,0.3606209315377456,0.3375259168007795,0.6005172200176657
DGM_kiba0D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7476189416085207,0.7148917008987766,0.6299877614860792,0.3746319657859179,0.3958356694230301,0.6120718632529336
DGM_kiba1D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7073391610819149,0.624956249151526,0.5401876728173656,0.4451318825041403,0.458846963456725,0.667182045999546
DGM_kiba2D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7401141841678894,0.6795148074510864,0.6127459332278625,0.4004100160666026,0.4095781581139723,0.6327795951724444
DGM_kiba3D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7396234368040389,0.6913457932090825,0.6201197126448974,0.3934219012917641,0.4068530834848238,0.6272335301080165
DGM_kiba4D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.752441708545282,0.7025492844189518,0.6449954833411846,0.3728163774990898,0.4045171920082104,0.6105869123221442
DGM_kiba0D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7599929872587803,0.7067412429690916,0.6593355592769512,0.3962219319168832,0.4099100126533609,0.6294616206861887
DGM_kiba2D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7149604681753393,0.6152047008431843,0.5597795125500629,0.4741719822054008,0.4603646989542154,0.6886014683439187
DGM_kiba1D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7140873472783476,0.6102548954720128,0.5558196740209606,0.4781851688315759,0.4659358458753446,0.6915093411021834
DGM_kiba3D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7164547158247304,0.6084847523640808,0.5607065445063388,0.4802083760845744,0.4646035882672965,0.6929706891958523
DGM_kiba4D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7687577053257117,0.7532822502738942,0.6745267167129126,0.3466135049736077,0.3887611475832294,0.5887389107011765
EDIM_davis0D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.8114566611493259,0.7317647125777735,0.6044949818493646,0.3736163373086704,0.3493916191183746,0.611241635778086
22 changes: 22 additions & 0 deletions results/v103/model_media/model_stats_val.csv
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
run,cindex,pearson,spearman,mse,mae,rmse
DGM_davis0D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8507053734550806,0.7688628504779598,0.6689225345680122,0.3760747658599554,0.3388000398874283,0.6132493504765867
DGM_davis1D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8718442303414345,0.8308115505911805,0.7173863620955029,0.323120446450846,0.3234809194096876,0.5684368447337365
DGM_davis2D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8818149976678145,0.834014760655388,0.7187113282294693,0.3071136635927556,0.2922643621762593,0.5541783680303262
DGM_davis3D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8686054788183053,0.828507036778059,0.7018974086625753,0.3046836030153428,0.3018857493804209,0.5519815241612194
DGM_davis4D_nomsaF_binaryE_128B_0.00012LR_0.24D_2000E_originalLF_binaryLE,0.8510324353139875,0.7912085758636695,0.6660120299194481,0.3556152282825756,0.3246624081237138,0.5963348290034514
DGM_davis0D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8224153103333314,0.7079363892542606,0.598653291929885,0.390209011583234,0.3662272181520597,0.6246671206196417
DGM_davis1D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8493586014591634,0.7889115161443931,0.673101173187512,0.385586498014305,0.3309167098727579,0.6209561160132856
DGM_davis3D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8337160086816762,0.736812322417127,0.6264321347273434,0.3912511533576103,0.3474602306165372,0.6255007221079847
DGM_davis4D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8463969226855363,0.7444417955240732,0.6410258059445946,0.3682857548365584,0.3208404985420844,0.6068655162690977
DGM_davis2D_nomsaF_aflowE_128B_0.0008279387625584954LR_0.3480347297724069D_2000E_originalLF_binaryLE,0.8485013700858233,0.7861129348915608,0.6464621130340457,0.3603118364352048,0.334368295190004,0.6002598074460799
DGM_kiba0D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.76838489785244,0.7349294201300529,0.6720189966892385,0.2870000493840723,0.36449329645463846,0.5357238555301344
DGM_kiba1D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6830895888292509,0.6256216928279031,0.4862834063605662,0.4635991200460103,0.4637030257199048,0.6808811350346038
DGM_kiba2D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7424018620084226,0.7064897658791767,0.6113374073010096,0.3869704646200874,0.4214332353778002,0.6220695014386153
DGM_kiba3D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7887989201757725,0.7493602745702617,0.7146442736012206,0.2624049771770561,0.3196535898269914,0.5122547971244936
DGM_kiba4D_nomsaF_binaryE_128B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.805961202163743,0.797223082421482,0.7422315375449509,0.2469041691088263,0.3146771252445765,0.4968945251346872
DGM_kiba0D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7414150463245769,0.7180513866946735,0.6154427990455673,0.2477156183512984,0.352769788125809,0.4977103759731139
DGM_kiba2D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7017734270681899,0.6190248117265895,0.5234448165080476,0.4732107505692587,0.4709162900019082,0.6879031549348054
DGM_kiba1D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.715296095350676,0.6760788124615275,0.5590996884326331,0.4113607007891719,0.446394580254481,0.6413740724329071
DGM_kiba3D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.6897289677863744,0.552288313673736,0.4932783355544295,0.4163241510325705,0.4502416662950819,0.6452318583521511
DGM_kiba4D_nomsaF_aflowE_64B_0.0001139464546302261LR_0.4321620419748407D_2000E_originalLF_binaryLE,0.7785830290333009,0.772583639636063,0.6834004931220337,0.2542728701347933,0.3402214839171678,0.504254767091788
EDIM_davis0D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.8460899419942509,0.7821818481200006,0.6773536752793916,0.3864269594875424,0.3440388107583636,0.6216324955209006
Loading

0 comments on commit ef9ac0f

Please sign in to comment.