Skip to content

Commit

Permalink
Merge branch 'development' of https://github.com/jyaacoub/MutDTA into…
Browse files Browse the repository at this point in the history
… development
  • Loading branch information
jyaacoub committed Nov 13, 2023
2 parents e9f4be5 + 6fc396a commit 3f09148
Show file tree
Hide file tree
Showing 16 changed files with 323 additions and 271 deletions.
31 changes: 0 additions & 31 deletions docs/reqbasic.txt

This file was deleted.

43 changes: 21 additions & 22 deletions docs/requirements.txt
Original file line number Diff line number Diff line change
@@ -1,32 +1,31 @@
numpy==1.23.5
pandas==1.5.3
tqdm==4.65.0
rdkit==2023.3.1
scipy==1.10.1
numpy
pandas
tqdm
rdkit
scipy

# for generating figures:
matplotlib==3.7.1
seaborn==0.11.2
statannotations==0.6.0
matplotlib
seaborn
statannotations

lifelines==0.27.7 # used for concordance index calc
#biopython # used for cmap
lifelines

# model building
torch==2.0.1
torch-geometric==2.3.1
transformers==4.31.0 # huggingface needed for esm
torch
torch-geometric
transformers

# optional:
torchsummary==1.5.1
tabulate==0.9.0 # for torch_geometric.nn.summary
ipykernel==6.23.1
plotly==5.14.1
requests==2.31.0
#ray[tune]
torchsummary
tabulate
ipykernel
plotly
requests
ray[tune]

submitit==1.4.5
ProDy==2.4.1
submitit
ProDy

# for chemgpt
selfies==1.0.4
selfies
32 changes: 32 additions & 0 deletions docs/requirements_versions.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
numpy==1.23.5
pandas==1.5.3
tqdm==4.65.0
rdkit==2023.3.1
scipy==1.10.1

# for generating figures:
matplotlib==3.7.1
seaborn==0.11.2
statannotations==0.6.0

lifelines==0.27.7 # used for concordance index calc
#biopython # used for cmap

# model building
torch==2.0.1
torch-geometric==2.3.1
transformers==4.31.0 # huggingface needed for esm

# optional:
torchsummary==1.5.1
tabulate==0.9.0 # for torch_geometric.nn.summary
ipykernel==6.23.1
plotly==5.14.1
requests==2.31.0
#ray[tune]

submitit==1.4.5
ProDy==2.4.1

# for chemgpt
selfies==1.0.4
23 changes: 8 additions & 15 deletions playground.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
# %%
import pandas as pd

df = pd.read_csv('/cluster/home/t122995uhn/projects/data/DavisKibaDataset/davis/nomsa_binary_original_binary/train/XY.csv', index_col=0)
dft = pd.read_csv('/cluster/home/t122995uhn/projects/data/DavisKibaDataset/davis/nomsa_binary_original_binary/test/XY.csv', index_col=0)
dfv = pd.read_csv('/cluster/home/t122995uhn/projects/data/DavisKibaDataset/davis/nomsa_binary_original_binary/val/XY.csv', index_col=0)

trainp = df['prot_id'].drop_duplicates()
testp = dft['prot_id'].drop_duplicates()
valp = dfv['prot_id'].drop_duplicates()

overlap_train_test = trainp[trainp.isin(testp)]
overlap_train_val = trainp[trainp.isin(valp)]
overlap_test_val = testp[testp.isin(valp)]

# %%
from src.data_processing.init_dataset import create_datasets

create_datasets(
data_opt=['davis'],
feat_opt=['foldseek'],
edge_opt=['binary']
)
# %%
2 changes: 2 additions & 0 deletions results/model_media/model_stats.csv
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,5 @@ DGM_davis4D_nomsaF_af2E_64B_0.0001LR_0.4D_2000E,0.8289131436553395,0.72401732620
EDIM_davis2D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8554472810106364,0.7147960047167372,0.6480432041534627,0.3808700198664558,0.3378217575814264,0.6171466761366019
EDIM_davis3D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8432469960431263,0.7538742839566113,0.6228109253699023,0.3432475435749076,0.3027376194169844,0.5858733170019843
EDIM_davis4D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8554137020365195,0.7430289570015053,0.6449800902544639,0.3552904408465783,0.2941160608088106,0.5960624471031356
EDIM_davis0D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8421343521980615,0.7442023086590929,0.6221393745119952,0.3652477400883954,0.3289030617467355,0.6043572950568193
EDIM_davis1D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8421023614186662,0.7223580257923967,0.6225670625924141,0.3758262699500786,0.3181417466188423,0.6130467110670105
2 changes: 2 additions & 0 deletions results/model_media/model_stats_val.csv
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,5 @@ DGM_davis4D_nomsaF_af2E_64B_0.0001LR_0.4D_2000E,0.8257259216950036,0.73150924245
EDIM_davis3D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8366344407519283,0.7604833093809331,0.5981203556939766,0.3109887031416362,0.2938859237605942,0.5576636110969015
EDIM_davis2D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8371103489279001,0.7312114712420225,0.5906166464692194,0.3223543817280097,0.3226479165694292,0.5677626103645869
EDIM_davis4D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8333774894628959,0.759476307205861,0.6095430076677542,0.3768294961878906,0.31885009590497027,0.6138643956020666
EDIM_davis0D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8289380911435942,0.7355891133257101,0.6157672298130831,0.4139988652707585,0.3623009506861369,0.6434274359014842
EDIM_davis1D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8526864235317578,0.7780004172307035,0.6560782026672488,0.37383884772108467,0.3317426473954145,0.6114236237839398
95 changes: 67 additions & 28 deletions src/data_analysis/figures.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,34 +156,6 @@ def fig2_pro_feat(df, verbose=False, sel_col='cindex', exclude=[], show=True, ad
# reset stylesheet back to defaults
mpl.rcParams.update(mpl.rcParamsDefault)

# similar to above but only for one dataset to show significance between ESM and other features
# this will be plotted as a jitter plot with error bars
def fig4_pro_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex', exclude=[],
show=True, add_labels=True, add_stats=False):
# Extract relevant data
filtered_df = df[(df['edge'] == 'binary') & (~df['overlap']) & (df['lig_feat'].isna())]

# show all with fold info
filtered_df = filtered_df[(filtered_df['data'] == sel_dataset) & (filtered_df['fold'] != '')]
nomsa = filtered_df[(filtered_df['feat'] == 'nomsa')][sel_col]
msa = filtered_df[(filtered_df['feat'] == 'msa')][sel_col]
shannon = filtered_df[(filtered_df['feat'] == 'shannon')][sel_col]
esm = filtered_df[(filtered_df['feat'] == 'ESM')][sel_col]

# Might have different length for each feature
ax = sns.violinplot(data=[nomsa, msa, shannon, esm])
ax.set_xticklabels(['nomsa', 'msa', 'shannon', 'esm'])
ax.set_ylabel(sel_col)
ax.set_xlabel('Features')
ax.set_title(f'Feature {sel_col} for {sel_dataset}')

# %% Annotation for stats
if add_stats:
pairs=[('nomsa', 'msa'), ('nomsa', 'shannon'), ('msa', 'shannon')]
annotator = Annotator(ax, pairs, data=filtered_df, x='feat', y=sel_col, verbose=verbose)
annotator.configure(test='Mann-Whitney', text_format='star', loc='outside')
annotator.apply_and_annotate()

# Figure 3 - Edge type cindex difference
# Edges -> binary, simple, anm, af2
def fig3_edge_feat(df, verbose=False, sel_col='cindex', exclude=[], show=True, add_labels=True):
Expand Down Expand Up @@ -275,6 +247,70 @@ def fig3_edge_feat(df, verbose=False, sel_col='cindex', exclude=[], show=True, a
# reset stylesheet back to defaults
mpl.rcParams.update(mpl.rcParamsDefault)

# Figure 4: violin plot with error bars for Cross-validation results to show significance among pro feats
def fig4_pro_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex', exclude=[],
show=True, add_labels=True, add_stats=False):
# Extract relevant data
filtered_df = df[(df['edge'] == 'binary') & (~df['overlap']) & (df['lig_feat'].isna())]

# show all with fold info
filtered_df = filtered_df[(filtered_df['data'] == sel_dataset) & (filtered_df['fold'] != '')]
nomsa = filtered_df[(filtered_df['feat'] == 'nomsa')][sel_col]
msa = filtered_df[(filtered_df['feat'] == 'msa')][sel_col]
shannon = filtered_df[(filtered_df['feat'] == 'shannon')][sel_col]
esm = filtered_df[(filtered_df['feat'] == 'ESM')][sel_col]

# Get values for each node feature
ax = sns.violinplot(data=[nomsa, msa, shannon, esm])
ax.set_xticklabels(['nomsa', 'msa', 'shannon', 'esm'])
ax.set_ylabel(sel_col)
ax.set_xlabel('Features')
ax.set_title(f'Feature {sel_col} for {sel_dataset}')

# Annotation for stats
if add_stats:
pairs=[('nomsa', 'msa'), ('nomsa', 'shannon'), ('msa', 'shannon')]
annotator = Annotator(ax, pairs, data=filtered_df, x='feat', y=sel_col, verbose=verbose)
annotator.configure(test='Mann-Whitney', text_format='star', loc='inside',
hide_non_significant=True)
annotator.apply_and_annotate()

if show:
plt.show()

# Figure 5: violin plot with error bars for Cross-validation results to show significance among edge feats
def fig5_edge_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex', exclude=[],
show=True, add_labels=True, add_stats=False):
filtered_df = df[(df['feat'] == 'nomsa') & (~df['overlap']) & (df['lig_feat'].isna())]
filtered_df = filtered_df[(filtered_df['data'] == sel_dataset) & (filtered_df['fold'] != '')]

filtered_df.sort_values(by=['edge'], inplace=True)

# Get values for each edge type
binary = filtered_df[filtered_df['edge'] == 'binary'][sel_col]
simple = filtered_df[filtered_df['edge'] == 'simple'][sel_col]
anm = filtered_df[filtered_df['edge'] == 'anm'][sel_col]
af2 = filtered_df[filtered_df['edge'] == 'af2'][sel_col]

# plot violin plot with annotations
ax = sns.violinplot(data=[binary, simple, anm, af2])
ax.set_xticklabels(['binary', 'simple', 'anm', 'af2'])
ax.set_ylabel(sel_col)
ax.set_xlabel('Edge type')
ax.set_title(f'Edge type {sel_col} for {sel_dataset}')

if add_stats:
pairs = [('binary', 'simple'), ('binary', 'anm'), ('binary', 'af2'),
('simple', 'anm'), ('simple', 'af2'), ('anm', 'af2')]
annotator = Annotator(ax, pairs, data=filtered_df, x='edge', y=sel_col, verbose=verbose)
annotator.configure(test='Mann-Whitney', text_format='star', loc='inside', hide_non_significant=True)
annotator.apply_and_annotate()

if show:
plt.show()



def prepare_df(csv_p:str, old_csv_p='results/model_media/old_model_stats.csv') -> pd.DataFrame:
df = pd.read_csv(csv_p)

Expand Down Expand Up @@ -304,6 +340,9 @@ def prepare_df(csv_p:str, old_csv_p='results/model_media/old_model_stats.csv') -

return df




if __name__ == '__main__':
df = prepare_df('results/model_media/model_stats.csv')

Expand Down
59 changes: 38 additions & 21 deletions src/data_processing/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,15 +100,10 @@ def __init__(self, save_root:str, data_root:str, aln_dir:str,
assert feature_opt in self.FEATURE_OPTIONS, \
f"Invalid feature_opt '{feature_opt}', choose from {self.FEATURE_OPTIONS}"

self.shannon = False
self.feature_opt = feature_opt
if feature_opt == 'nomsa':
self.aln_dir = None # none treats it as np.zeros
elif feature_opt == 'msa':
self.aln_dir = None # none treats it as np.zeros
if feature_opt in ['msa', 'shannon']:
self.aln_dir = aln_dir # path to sequence alignments
elif feature_opt == 'shannon':
self.aln_dir = aln_dir
self.shannon = True

assert edge_opt in self.EDGE_OPTIONS, \
f"Invalid edge_opt '{edge_opt}', choose from {self.EDGE_OPTIONS}"
Expand Down Expand Up @@ -150,6 +145,10 @@ def pdb_p(self, code) -> str:
"""path to pdbfile for a particular protein"""
raise NotImplementedError

def pddlt_p(self, code) -> str:
"""path to plddt file for a particular protein"""
return None

@abc.abstractmethod
def cmap_p(self, code) -> str:
raise NotImplementedError
Expand Down Expand Up @@ -332,10 +331,15 @@ def process(self):
# extra_feat is Lx54 or Lx34 (if shannon=True)
try:
pro_cmap = np.load(self.cmap_p(code))
extra_feat, edge_idx = target_to_graph(pro_seq, pro_cmap,
threshold=self.cmap_threshold,
aln_file=self.aln_p(code),
shannon=self.shannon)
# updated_seq is for updated foldseek 3di combined seq
updated_seq, extra_feat, edge_idx = target_to_graph(target_sequence=pro_seq,
contact_map=pro_cmap,
threshold=self.cmap_threshold,
pro_feat=self.feature_opt,
aln_file=self.aln_p(code),
# for foldseek feats
pdb_fp=self.pdb_p(code),
pddlt_fp=self.pddlt_p(code))
except Exception as e:
raise Exception(f"error on protein graph creation for code {code}") from e

Expand Down Expand Up @@ -364,7 +368,7 @@ def process(self):

pro = torchg.data.Data(x=torch.Tensor(pro_feat),
edge_index=torch.LongTensor(edge_idx),
pro_seq=pro_seq, # protein sequence for downstream esm model
pro_seq=updated_seq, # Protein sequence for downstream esm model
prot_id=prot_id,
edge_weight=pro_edge_weight)
processed_prots[prot_id] = pro
Expand Down Expand Up @@ -591,13 +595,22 @@ def pdb_p(self, code, safe=True):
code = re.sub(r'[()]', '_', code)
# davis and kiba dont have their own structures so this must be made using
# af or some other method beforehand.
if self.edge_opt not in cfg.STRUCT_EDGE_OPT: return None
if (self.edge_opt not in cfg.STRUCT_EDGE_OPT) and \
(self.feature_opt not in cfg.STRUCT_PRO_FEAT_OPT): return None

file = glob(os.path.join(self.af_conf_dir, f'highQ/{code}_unrelaxed_rank_001*.pdb'))
# should only be one file
assert not safe or len(file) == 1, f'Incorrect pdb pathing, {len(file)}# of structures for {code}.'
return file[0] if len(file) >= 1 else None

def pddlt_p(self, code, safe=True):
# this contains confidence scores for each predicted residue position in the protein
pdb_p = self.pdb_p(code, safe=safe)
if pdb_p is None: return None
# from O00141_unrelaxed_rank_001_alphafold2_ptm_model_1_seed_000.pdb
# to O00141_scores_rank_001_alphafold2_ptm_model_1_seed_000.json
return pdb_p.replace('unrelaxed', 'scores').replace('.pdb', '.json')

def cmap_p(self, code):
return os.path.join(self.data_root, 'pconsc4', f'{code}.npy')

Expand Down Expand Up @@ -718,17 +731,21 @@ def pre_process(self):
no_cmap = [c for c in codes if not os.path.isfile(self.cmap_p(c))]
print(f'Number of codes without cmap files: {len(no_cmap)} out of {len(codes)}')

# Checking that structure and af_confs files are present if edgeW is anm or af2
# Checking that structure and af_confs files are present if required:
no_confs = []
if self.edge_opt in cfg.STRUCT_EDGE_OPT:
no_confs = [c for c in codes if (
(self.pdb_p(c, safe=False) is None) or # no highQ structure
(len(self.af_conf_files(c)) < 2))] # not enough af confirmations.
if self.edge_opt in cfg.STRUCT_EDGE_OPT or self.feature_opt in cfg.STRUCT_PRO_FEAT_OPT:
if self.feature_opt == 'foldseek':
# we only need HighQ structures for foldseek
no_confs = [c for c in codes if (self.pdb_p(c, safe=False) is None)]
else:
no_confs = [c for c in codes if (
(self.pdb_p(c, safe=False) is None) or # no highQ structure
(len(self.af_conf_files(c)) < 2))] # only if not for foldseek

# WARNING: TEMPORARY FIX FOR DAVIS (TESK1 highQ structure is mismatched...)
no_confs.append('TESK1')
# WARNING: TEMPORARY FIX FOR DAVIS (TESK1 highQ structure is mismatched...)
no_confs.append('TESK1')

print(f'Number of codes missing af2 configurations: {len(no_confs)} / {len(codes)}')
print(f'Number of codes missing af2 configurations: {len(no_confs)} / {len(codes)}')

invalid_codes = set(no_aln + no_cmap + no_confs)
# filtering out invalid codes:
Expand Down
Loading

0 comments on commit 3f09148

Please sign in to comment.