diff --git a/docs/reqbasic.txt b/docs/reqbasic.txt deleted file mode 100644 index 77171a4c..00000000 --- a/docs/reqbasic.txt +++ /dev/null @@ -1,31 +0,0 @@ -numpy -pandas -tqdm -rdkit -scipy - -# for generating figures: -matplotlib -seaborn -statannotations - -lifelines -#biopython # used for cmap - -# model building -torch -torch-geometric -transformers - -# optional: -torchsummary -tabulate -ipykernel -plotly -requests - -submitit -ProDy - -# for chemgpt -selfies diff --git a/docs/requirements.txt b/docs/requirements.txt index c4ac33f4..e7b6cbf8 100644 --- a/docs/requirements.txt +++ b/docs/requirements.txt @@ -1,32 +1,31 @@ -numpy==1.23.5 -pandas==1.5.3 -tqdm==4.65.0 -rdkit==2023.3.1 -scipy==1.10.1 +numpy +pandas +tqdm +rdkit +scipy # for generating figures: -matplotlib==3.7.1 -seaborn==0.11.2 -statannotations==0.6.0 +matplotlib +seaborn +statannotations -lifelines==0.27.7 # used for concordance index calc -#biopython # used for cmap +lifelines # model building -torch==2.0.1 -torch-geometric==2.3.1 -transformers==4.31.0 # huggingface needed for esm +torch +torch-geometric +transformers # optional: -torchsummary==1.5.1 -tabulate==0.9.0 # for torch_geometric.nn.summary -ipykernel==6.23.1 -plotly==5.14.1 -requests==2.31.0 -#ray[tune] +torchsummary +tabulate +ipykernel +plotly +requests +ray[tune] -submitit==1.4.5 -ProDy==2.4.1 +submitit +ProDy # for chemgpt -selfies==1.0.4 +selfies diff --git a/docs/requirements_versions.txt b/docs/requirements_versions.txt new file mode 100644 index 00000000..c4ac33f4 --- /dev/null +++ b/docs/requirements_versions.txt @@ -0,0 +1,32 @@ +numpy==1.23.5 +pandas==1.5.3 +tqdm==4.65.0 +rdkit==2023.3.1 +scipy==1.10.1 + +# for generating figures: +matplotlib==3.7.1 +seaborn==0.11.2 +statannotations==0.6.0 + +lifelines==0.27.7 # used for concordance index calc +#biopython # used for cmap + +# model building +torch==2.0.1 +torch-geometric==2.3.1 +transformers==4.31.0 # huggingface needed for esm + +# optional: +torchsummary==1.5.1 +tabulate==0.9.0 # for torch_geometric.nn.summary +ipykernel==6.23.1 +plotly==5.14.1 +requests==2.31.0 +#ray[tune] + +submitit==1.4.5 +ProDy==2.4.1 + +# for chemgpt +selfies==1.0.4 diff --git a/playground.py b/playground.py index 7613312b..5bc24243 100644 --- a/playground.py +++ b/playground.py @@ -1,16 +1,9 @@ # %% -import pandas as pd - -df = pd.read_csv('/cluster/home/t122995uhn/projects/data/DavisKibaDataset/davis/nomsa_binary_original_binary/train/XY.csv', index_col=0) -dft = pd.read_csv('/cluster/home/t122995uhn/projects/data/DavisKibaDataset/davis/nomsa_binary_original_binary/test/XY.csv', index_col=0) -dfv = pd.read_csv('/cluster/home/t122995uhn/projects/data/DavisKibaDataset/davis/nomsa_binary_original_binary/val/XY.csv', index_col=0) - -trainp = df['prot_id'].drop_duplicates() -testp = dft['prot_id'].drop_duplicates() -valp = dfv['prot_id'].drop_duplicates() - -overlap_train_test = trainp[trainp.isin(testp)] -overlap_train_val = trainp[trainp.isin(valp)] -overlap_test_val = testp[testp.isin(valp)] - -# %% +from src.data_processing.init_dataset import create_datasets + +create_datasets( + data_opt=['davis'], + feat_opt=['foldseek'], + edge_opt=['binary'] +) +# %% \ No newline at end of file diff --git a/results/model_media/model_stats.csv b/results/model_media/model_stats.csv index c6c3f78a..e30c0cdc 100644 --- a/results/model_media/model_stats.csv +++ b/results/model_media/model_stats.csv @@ -55,3 +55,5 @@ DGM_davis4D_nomsaF_af2E_64B_0.0001LR_0.4D_2000E,0.8289131436553395,0.72401732620 
EDIM_davis2D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8554472810106364,0.7147960047167372,0.6480432041534627,0.3808700198664558,0.3378217575814264,0.6171466761366019 EDIM_davis3D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8432469960431263,0.7538742839566113,0.6228109253699023,0.3432475435749076,0.3027376194169844,0.5858733170019843 EDIM_davis4D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8554137020365195,0.7430289570015053,0.6449800902544639,0.3552904408465783,0.2941160608088106,0.5960624471031356 +EDIM_davis0D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8421343521980615,0.7442023086590929,0.6221393745119952,0.3652477400883954,0.3289030617467355,0.6043572950568193 +EDIM_davis1D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8421023614186662,0.7223580257923967,0.6225670625924141,0.3758262699500786,0.3181417466188423,0.6130467110670105 diff --git a/results/model_media/model_stats_val.csv b/results/model_media/model_stats_val.csv index ac26b4a5..1767a925 100644 --- a/results/model_media/model_stats_val.csv +++ b/results/model_media/model_stats_val.csv @@ -33,3 +33,5 @@ DGM_davis4D_nomsaF_af2E_64B_0.0001LR_0.4D_2000E,0.8257259216950036,0.73150924245 EDIM_davis3D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8366344407519283,0.7604833093809331,0.5981203556939766,0.3109887031416362,0.2938859237605942,0.5576636110969015 EDIM_davis2D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8371103489279001,0.7312114712420225,0.5906166464692194,0.3223543817280097,0.3226479165694292,0.5677626103645869 EDIM_davis4D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8333774894628959,0.759476307205861,0.6095430076677542,0.3768294961878906,0.31885009590497027,0.6138643956020666 +EDIM_davis0D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8289380911435942,0.7355891133257101,0.6157672298130831,0.4139988652707585,0.3623009506861369,0.6434274359014842 +EDIM_davis1D_nomsaF_binaryE_48B_0.0001LR_0.4D_2000E,0.8526864235317578,0.7780004172307035,0.6560782026672488,0.37383884772108467,0.3317426473954145,0.6114236237839398 diff --git a/src/data_analysis/figures.py b/src/data_analysis/figures.py index 4ccb3c44..859d2ccb 100644 --- a/src/data_analysis/figures.py +++ b/src/data_analysis/figures.py @@ -156,34 +156,6 @@ def fig2_pro_feat(df, verbose=False, sel_col='cindex', exclude=[], show=True, ad # reset stylesheet back to defaults mpl.rcParams.update(mpl.rcParamsDefault) -# similar to above but only for one dataset to show significance between ESM and other features -# this will be plotted as a jitter plot with error bars -def fig4_pro_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex', exclude=[], - show=True, add_labels=True, add_stats=False): - # Extract relevant data - filtered_df = df[(df['edge'] == 'binary') & (~df['overlap']) & (df['lig_feat'].isna())] - - # show all with fold info - filtered_df = filtered_df[(filtered_df['data'] == sel_dataset) & (filtered_df['fold'] != '')] - nomsa = filtered_df[(filtered_df['feat'] == 'nomsa')][sel_col] - msa = filtered_df[(filtered_df['feat'] == 'msa')][sel_col] - shannon = filtered_df[(filtered_df['feat'] == 'shannon')][sel_col] - esm = filtered_df[(filtered_df['feat'] == 'ESM')][sel_col] - - # Might have different length for each feature - ax = sns.violinplot(data=[nomsa, msa, shannon, esm]) - ax.set_xticklabels(['nomsa', 'msa', 'shannon', 'esm']) - ax.set_ylabel(sel_col) - ax.set_xlabel('Features') - ax.set_title(f'Feature {sel_col} for {sel_dataset}') - - # %% Annotation for stats - if add_stats: - pairs=[('nomsa', 'msa'), ('nomsa', 'shannon'), ('msa', 'shannon')] - annotator = Annotator(ax, pairs, 
data=filtered_df, x='feat', y=sel_col, verbose=verbose) - annotator.configure(test='Mann-Whitney', text_format='star', loc='outside') - annotator.apply_and_annotate() - # Figure 3 - Edge type cindex difference # Edges -> binary, simple, anm, af2 def fig3_edge_feat(df, verbose=False, sel_col='cindex', exclude=[], show=True, add_labels=True): @@ -275,6 +247,70 @@ def fig3_edge_feat(df, verbose=False, sel_col='cindex', exclude=[], show=True, a # reset stylesheet back to defaults mpl.rcParams.update(mpl.rcParamsDefault) +# Figure 4: violin plot with error bars for Cross-validation results to show significance among pro feats +def fig4_pro_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex', exclude=[], + show=True, add_labels=True, add_stats=False): + # Extract relevant data + filtered_df = df[(df['edge'] == 'binary') & (~df['overlap']) & (df['lig_feat'].isna())] + + # show all with fold info + filtered_df = filtered_df[(filtered_df['data'] == sel_dataset) & (filtered_df['fold'] != '')] + nomsa = filtered_df[(filtered_df['feat'] == 'nomsa')][sel_col] + msa = filtered_df[(filtered_df['feat'] == 'msa')][sel_col] + shannon = filtered_df[(filtered_df['feat'] == 'shannon')][sel_col] + esm = filtered_df[(filtered_df['feat'] == 'ESM')][sel_col] + + # Get values for each node feature + ax = sns.violinplot(data=[nomsa, msa, shannon, esm]) + ax.set_xticklabels(['nomsa', 'msa', 'shannon', 'esm']) + ax.set_ylabel(sel_col) + ax.set_xlabel('Features') + ax.set_title(f'Feature {sel_col} for {sel_dataset}') + + # Annotation for stats + if add_stats: + pairs=[('nomsa', 'msa'), ('nomsa', 'shannon'), ('msa', 'shannon')] + annotator = Annotator(ax, pairs, data=filtered_df, x='feat', y=sel_col, verbose=verbose) + annotator.configure(test='Mann-Whitney', text_format='star', loc='inside', + hide_non_significant=True) + annotator.apply_and_annotate() + + if show: + plt.show() + +# Figure 5: violin plot with error bars for Cross-validation results to show significance among edge feats +def fig5_edge_feat_violin(df, sel_dataset='davis', verbose=False, sel_col='cindex', exclude=[], + show=True, add_labels=True, add_stats=False): + filtered_df = df[(df['feat'] == 'nomsa') & (~df['overlap']) & (df['lig_feat'].isna())] + filtered_df = filtered_df[(filtered_df['data'] == sel_dataset) & (filtered_df['fold'] != '')] + + filtered_df.sort_values(by=['edge'], inplace=True) + + # Get values for each edge type + binary = filtered_df[filtered_df['edge'] == 'binary'][sel_col] + simple = filtered_df[filtered_df['edge'] == 'simple'][sel_col] + anm = filtered_df[filtered_df['edge'] == 'anm'][sel_col] + af2 = filtered_df[filtered_df['edge'] == 'af2'][sel_col] + + # plot violin plot with annotations + ax = sns.violinplot(data=[binary, simple, anm, af2]) + ax.set_xticklabels(['binary', 'simple', 'anm', 'af2']) + ax.set_ylabel(sel_col) + ax.set_xlabel('Edge type') + ax.set_title(f'Edge type {sel_col} for {sel_dataset}') + + if add_stats: + pairs = [('binary', 'simple'), ('binary', 'anm'), ('binary', 'af2'), + ('simple', 'anm'), ('simple', 'af2'), ('anm', 'af2')] + annotator = Annotator(ax, pairs, data=filtered_df, x='edge', y=sel_col, verbose=verbose) + annotator.configure(test='Mann-Whitney', text_format='star', loc='inside', hide_non_significant=True) + annotator.apply_and_annotate() + + if show: + plt.show() + + + def prepare_df(csv_p:str, old_csv_p='results/model_media/old_model_stats.csv') -> pd.DataFrame: df = pd.read_csv(csv_p) @@ -304,6 +340,9 @@ def prepare_df(csv_p:str, 
old_csv_p='results/model_media/old_model_stats.csv') - return df + + + if __name__ == '__main__': df = prepare_df('results/model_media/model_stats.csv') diff --git a/src/data_processing/datasets.py b/src/data_processing/datasets.py index 6dba0ee1..9ace735b 100644 --- a/src/data_processing/datasets.py +++ b/src/data_processing/datasets.py @@ -100,15 +100,10 @@ def __init__(self, save_root:str, data_root:str, aln_dir:str, assert feature_opt in self.FEATURE_OPTIONS, \ f"Invalid feature_opt '{feature_opt}', choose from {self.FEATURE_OPTIONS}" - self.shannon = False self.feature_opt = feature_opt - if feature_opt == 'nomsa': - self.aln_dir = None # none treats it as np.zeros - elif feature_opt == 'msa': + self.aln_dir = None # none treats it as np.zeros + if feature_opt in ['msa', 'shannon']: self.aln_dir = aln_dir # path to sequence alignments - elif feature_opt == 'shannon': - self.aln_dir = aln_dir - self.shannon = True assert edge_opt in self.EDGE_OPTIONS, \ f"Invalid edge_opt '{edge_opt}', choose from {self.EDGE_OPTIONS}" @@ -150,6 +145,10 @@ def pdb_p(self, code) -> str: """path to pdbfile for a particular protein""" raise NotImplementedError + def pddlt_p(self, code) -> str: + """path to plddt file for a particular protein""" + return None + @abc.abstractmethod def cmap_p(self, code) -> str: raise NotImplementedError @@ -332,10 +331,15 @@ def process(self): # extra_feat is Lx54 or Lx34 (if shannon=True) try: pro_cmap = np.load(self.cmap_p(code)) - extra_feat, edge_idx = target_to_graph(pro_seq, pro_cmap, - threshold=self.cmap_threshold, - aln_file=self.aln_p(code), - shannon=self.shannon) + # updated_seq is for updated foldseek 3di combined seq + updated_seq, extra_feat, edge_idx = target_to_graph(target_sequence=pro_seq, + contact_map=pro_cmap, + threshold=self.cmap_threshold, + pro_feat=self.feature_opt, + aln_file=self.aln_p(code), + # for foldseek feats + pdb_fp=self.pdb_p(code), + pddlt_fp=self.pddlt_p(code)) except Exception as e: raise Exception(f"error on protein graph creation for code {code}") from e @@ -364,7 +368,7 @@ def process(self): pro = torchg.data.Data(x=torch.Tensor(pro_feat), edge_index=torch.LongTensor(edge_idx), - pro_seq=pro_seq, # protein sequence for downstream esm model + pro_seq=updated_seq, # Protein sequence for downstream esm model prot_id=prot_id, edge_weight=pro_edge_weight) processed_prots[prot_id] = pro @@ -591,13 +595,22 @@ def pdb_p(self, code, safe=True): code = re.sub(r'[()]', '_', code) # davis and kiba dont have their own structures so this must be made using # af or some other method beforehand. - if self.edge_opt not in cfg.STRUCT_EDGE_OPT: return None + if (self.edge_opt not in cfg.STRUCT_EDGE_OPT) and \ + (self.feature_opt not in cfg.STRUCT_PRO_FEAT_OPT): return None file = glob(os.path.join(self.af_conf_dir, f'highQ/{code}_unrelaxed_rank_001*.pdb')) # should only be one file assert not safe or len(file) == 1, f'Incorrect pdb pathing, {len(file)}# of structures for {code}.' 
return file[0] if len(file) >= 1 else None + def pddlt_p(self, code, safe=True): + # this contains confidence scores for each predicted residue position in the protein + pdb_p = self.pdb_p(code, safe=safe) + if pdb_p is None: return None + # from O00141_unrelaxed_rank_001_alphafold2_ptm_model_1_seed_000.pdb + # to O00141_scores_rank_001_alphafold2_ptm_model_1_seed_000.json + return pdb_p.replace('unrelaxed', 'scores').replace('.pdb', '.json') + def cmap_p(self, code): return os.path.join(self.data_root, 'pconsc4', f'{code}.npy') @@ -718,17 +731,21 @@ def pre_process(self): no_cmap = [c for c in codes if not os.path.isfile(self.cmap_p(c))] print(f'Number of codes without cmap files: {len(no_cmap)} out of {len(codes)}') - # Checking that structure and af_confs files are present if edgeW is anm or af2 + # Checking that structure and af_confs files are present if required: no_confs = [] - if self.edge_opt in cfg.STRUCT_EDGE_OPT: - no_confs = [c for c in codes if ( - (self.pdb_p(c, safe=False) is None) or # no highQ structure - (len(self.af_conf_files(c)) < 2))] # not enough af confirmations. + if self.edge_opt in cfg.STRUCT_EDGE_OPT or self.feature_opt in cfg.STRUCT_PRO_FEAT_OPT: + if self.feature_opt == 'foldseek': + # we only need HighQ structures for foldseek + no_confs = [c for c in codes if (self.pdb_p(c, safe=False) is None)] + else: + no_confs = [c for c in codes if ( + (self.pdb_p(c, safe=False) is None) or # no highQ structure + (len(self.af_conf_files(c)) < 2))] # only if not for foldseek - # WARNING: TEMPORARY FIX FOR DAVIS (TESK1 highQ structure is mismatched...) - no_confs.append('TESK1') + # WARNING: TEMPORARY FIX FOR DAVIS (TESK1 highQ structure is mismatched...) + no_confs.append('TESK1') - print(f'Number of codes missing af2 configurations: {len(no_confs)} / {len(codes)}') + print(f'Number of codes missing af2 configurations: {len(no_confs)} / {len(codes)}') invalid_codes = set(no_aln + no_cmap + no_confs) # filtering out invalid codes: diff --git a/src/feature_extraction/protein.py b/src/feature_extraction/protein.py index 8d07392b..3067376e 100644 --- a/src/feature_extraction/protein.py +++ b/src/feature_extraction/protein.py @@ -1,14 +1,13 @@ from multiprocessing import Pool from typing import Callable, Iterable import numpy as np -import matplotlib.pyplot as plt from tqdm import tqdm import os, math import pandas as pd from src.utils import config as cfg from src.utils.residue import ResInfo, Chain -from src.feature_extraction.protein_nodes import get_pfm, target_to_feature +from src.feature_extraction.protein_nodes import get_pfm, target_to_feature, get_foldseek_onehot, run_foldseek from src.feature_extraction.protein_edges import get_target_edge ######################################################################## @@ -16,7 +15,7 @@ ######################################################################## def target_to_graph(target_sequence:str, contact_map:str or np.array, threshold=10.5, pro_feat='nomsa', aln_file:str=None, - pdb_fp:str=None) -> tuple[np.array,np.array]: + pdb_fp:str=None, pddlt_fp:str=None) -> tuple[np.array,np.array]: """ Feature extraction for protein sequence using contact map to generate edge index and node features. 
@@ -39,11 +38,13 @@ def target_to_graph(target_sequence:str, contact_map:str or np.array, Path to alignment file for PSSM matrix, by default None `pdb_fp` : str, optional Path to pdb file for foldseek feature, by default None + `pddlt_fp` : str, optional + Path to pddlt (confidence) file for foldseek if pdb is a predicted structure, by default None Returns ------- Tuple[np.array] - tuple of (target_feature, target_edge_index) + tuple of (target_sequence, target_feature, target_edge_index) """ assert pro_feat in cfg.PRO_FEAT_OPT, \ f'Invalid protein feature option: {pro_feat}, must be one of {cfg.PRO_FEAT_OPT}' @@ -77,17 +78,31 @@ def entropy(col): pssm = np.apply_along_axis(entropy, axis=1, arr=pssm) pssm = pssm.reshape((len(target_sequence),1)) else: # normal pssm - pseudocount = 0.8 + pseudocount = 0.8 # pseudocount to avoid divide by 0 pssm = (pssm + pseudocount / 4) / (float(line_count) + pseudocount) target_feature = np.concatenate((pssm, pro_hot, pro_property), axis=1) elif pro_feat == 'foldseek': - # include foldseek token in the feature vector - # foldseek = get_foldseek(target_sequence, pdb_file) - target_feature = np.concatenate((pro_hot, pro_property), axis=1) + # returns {chain: [seq, struct_seq, combined_seq]} dict + seq_dict = run_foldseek(pdb_fp, plddt_fp=pddlt_fp) + + # use matching sequence from foldseek + combined_seq = None + for c in seq_dict: + if seq_dict[c][0] == target_sequence: + combined_seq = seq_dict[c][2] + break + assert combined_seq is not None, f'Could not find matching foldseek 3Di sequence for {pdb_fp}' + + # input sequences should now include 3di tokens + pro_hot_3di = get_foldseek_onehot(combined_seq) + target_feature = np.concatenate((pro_hot, pro_hot_3di), axis=1) + + # updating target sequence to include 3di tokens + target_sequence = combined_seq else: raise NotImplementedError(f'Invalid protein feature option: {pro_feat}') - return target_feature, edge_index + return target_sequence, target_feature, edge_index ###################################################################### diff --git a/src/feature_extraction/protein_nodes.py b/src/feature_extraction/protein_nodes.py index 81a2e892..31c78d6d 100644 --- a/src/feature_extraction/protein_nodes.py +++ b/src/feature_extraction/protein_nodes.py @@ -6,6 +6,7 @@ from multiprocessing import Pool from src.utils.residue import ResInfo, one_hot +from src.utils import config as cfg def get_pfm(aln_file: str, target_seq: str=None, overwrite=False) -> Tuple[np.array, int]: """ Returns position frequency matrix of amino acids based on MSA for each node in sequence""" @@ -55,8 +56,6 @@ def target_to_feature(target_seq): pro_hot = np.zeros((len(target_seq), len(ResInfo.amino_acids))) pro_property = np.zeros((len(target_seq), 12)) for i in range(len(target_seq)): - # if 'X' in pro_seq: - # print(pro_seq) pro_hot[i,] = one_hot(target_seq[i], ResInfo.amino_acids) pro_property[i,] = residue_features(target_seq[i]) @@ -73,44 +72,58 @@ def residue_features(residue): ResInfo.hydrophobic_ph7[residue]] return np.array(feats) -def run_foldseek(foldseek, - path, - chains: list = None, - process_id: int = 0, - plddt_path: str = None, - plddt_threshold: float = 70.) 
-> dict: +def get_foldseek_onehot(combined_seq): + # the target sequence includes 3Di foldseek tokens alternating with the actual sequence + # so we divide by 2 to get the length of the actual sequence + fld_hot = np.zeros((len(combined_seq)//2, len(ResInfo.foldseek_tokens))) + for i in range(1, len(combined_seq), 2): # step of 2 since 3Di tokens sit at odd indices + fld_hot[i // 2,] = one_hot(combined_seq[i], ResInfo.foldseek_tokens) + return fld_hot + +def run_foldseek(pdb_fp:str, foldseek_bin:str=cfg.FOLDSEEK_BIN, + chains: list = None, + plddt_fp: str = None, + plddt_threshold: float = 70.) -> dict: """ + Adapted from https://github.com/westlake-repl/SaProt/blob/main/utils/foldseek_util.py Args: - foldseek: Binary executable file of foldseek path: Path to pdb file + foldseek: Binary executable file of foldseek chains: Chains to be extracted from pdb file. If None, all chains will be extracted. - process_id: Process ID for temporary files. This is used for parallel processing. plddt_path: Path to plddt file. If None, plddt will not be used. + Example: colabfold/kiba_af2_out/highQ/O00141_scores_rank_001_alphafold2_ptm_model_1_seed_000.json plddt_threshold: Threshold for plddt. If plddt is lower than this value, the structure will be masked. Returns: seq_dict: A dict of structural seqs. The keys are chain IDs. The values are tuples of (seq, struc_seq, combined_seq). """ - assert os.path.exists(foldseek), f"Foldseek not found: {foldseek}" - assert os.path.exists(path), f"Pdb file not found: {path}" - assert plddt_path is None or os.path.exists(plddt_path), f"Plddt file not found: {plddt_path}" + assert os.path.exists(foldseek_bin), f"Foldseek not found: {foldseek_bin}" + assert os.path.exists(pdb_fp), f"Pdb file not found: {pdb_fp}" + assert plddt_fp is None or os.path.exists(plddt_fp), f"Plddt file not found: {plddt_fp}" - tmp_save_path = f"get_struc_seq_{process_id}.tsv" - cmd = f"{foldseek} structureto3didescriptor -v 0 --threads 1 --chain-name-mode 1 {path} {tmp_save_path}" - os.system(cmd) + # save in same location as pdb file + tmp_save_path = f"{pdb_fp}.foldseek.txt" + + # run foldseek only if the output file doesn't already exist + if not os.path.exists(tmp_save_path): + cmd = f"{foldseek_bin} structureto3didescriptor -v 0 --threads 1 --chain-name-mode 1 {pdb_fp} {tmp_save_path}" + os.system(cmd) # this is a blocking call + # remove dbtype file which is created by foldseek for some reason #TODO: why?
+ os.remove(tmp_save_path + ".dbtype") + # extract seqs from foldseek output seq_dict = {} - name = os.path.basename(path) + name = os.path.basename(pdb_fp) with open(tmp_save_path, "r") as r: for i, line in enumerate(r): desc, seq, struc_seq = line.split("\t")[:3] # Mask low plddt - if plddt_path is not None: - with open(plddt_path, "r") as r: - plddts = np.array(json.load(r)["confidenceScore"]) + if plddt_fp is not None: + with open(plddt_fp, "r") as r: + plddts = np.array(json.load(r)["plddt"]) # NOTE: updated from "confidenceScore" # Mask regions with plddt < threshold indices = np.where(plddts < plddt_threshold)[0] @@ -125,9 +138,7 @@ def run_foldseek(foldseek, if chain not in seq_dict: combined_seq = "".join([a + b.lower() for a, b in zip(seq, struc_seq)]) seq_dict[chain] = (seq, struc_seq, combined_seq) - - os.remove(tmp_save_path) - os.remove(tmp_save_path + ".dbtype") + return seq_dict if __name__ == '__main__': diff --git a/src/train_test/training.py b/src/train_test/training.py index 49f0861c..cd12f3f2 100644 --- a/src/train_test/training.py +++ b/src/train_test/training.py @@ -1,5 +1,3 @@ -import itertools -import gc from typing import Tuple from tqdm import tqdm @@ -10,14 +8,8 @@ from torch_geometric.loader import DataLoader from src.data_analysis.metrics import concordance_index -from src.data_processing.datasets import BaseDataset - from src.models.utils import BaseModel -from src.models.prior_work import DGraphDTA - -from src.train_test.utils import train_val_test_split, CheckpointSaver - -from src.utils.loader import Loader +from src.train_test.utils import CheckpointSaver def train(model: BaseModel, train_loader:DataLoader, val_loader:DataLoader, @@ -189,71 +181,4 @@ def test(model, test_loader, device, CRITERION=None) -> Tuple[float, np.ndarray, # Compute average test loss test_loss /= len(test_loader) - return test_loss, pred, actual - - -def train_tune(config, model:str, pro_feature:str, train_dataset:BaseDataset, val_dataset:BaseDataset): - from ray.air import session - device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - model = Loader.init_model(model, pro_feature, config['edge'], config['dropout']) - model.to(device) - - train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], - shuffle=True, - num_workers=2) - val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], - shuffle=True, - num_workers=2) - - saver = CheckpointSaver(model, debug=True) - for i in range(10): # 10 epochs - logs = train(model, train_loader, val_loader, device, epochs=1, - lr_0=config['lr'], silent=True, saver=saver) - val_loss = logs['val_loss'][0] - - # Send the current training result back to Tune - session.report({"val_loss": val_loss}) - - if i % 5 == 0: - # This saves the model to the trial directory - torch.save(model.state_dict(), "./model.pth") - - - -def grid_search(pdb_dataset, TRAIN_SPLIT=0.8, VAL_SPLIT=0.1, RAND_SEED=42, - epoch_opt = [5], - weight_opt = ['kiba', 'davis', 'random'], - batch_size_opt = [32, 64, 128], - lr_opt = [0.0001, 0.001, 0.01], - dropout_opt = [0.1, 0.2, 0.3]): - device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') - print(f'Device: {device}') - model_results = {} - for WEIGHTS, BATCH_SIZE, \ - LEARNING_RATE, DROPOUT, NUM_EPOCHS in \ - itertools.product(weight_opt, batch_size_opt, \ - lr_opt, dropout_opt, epoch_opt): - MODEL_KEY = f'{WEIGHTS}W_{BATCH_SIZE}B_{LEARNING_RATE}LR_{DROPOUT}DO_{NUM_EPOCHS}E' - print(f'\n\n{MODEL_KEY}') - - - model = DGraphDTA(dropout=DROPOUT) - model.to(device) - 
assert WEIGHTS in weight_opt, 'WEIGHTS must be one of: kiba, davis, random' - if WEIGHTS != 'random': - model_file_name = f'results/model_checkpoints/prior_work/DGraphDTA_{WEIGHTS}_t2.model' - model.safe_load_state_dict(torch.load(model_file_name, map_location=device)) - - train_loader, val_loader, test_loader = train_val_test_split(pdb_dataset, - train_split=TRAIN_SPLIT, val_split=VAL_SPLIT, - shuffle_dataset=True, random_seed=RAND_SEED, - batch_size=BATCH_SIZE) - logs = train(model, train_loader, val_loader, device, - epochs=NUM_EPOCHS, lr_0=LEARNING_RATE) - - loss, pred, actual = test(model, test_loader, device) - model_results[MODEL_KEY] = {'test_loss': loss, - 'train_loss': logs['train_loss'][-1], - 'val_loss': logs['val_loss'][-1]} - return model_results - + return test_loss, pred, actual \ No newline at end of file diff --git a/src/utils/config.py b/src/utils/config.py index 1abc10fb..a42cb257 100644 --- a/src/utils/config.py +++ b/src/utils/config.py @@ -5,12 +5,14 @@ from prody import confProDy confProDy(verbosity='none') # stop printouts from prody - +# model and data options MODEL_OPT = ['DG', 'DGI', 'ED', 'EDA', 'EDI', 'EDAI', 'EAT', 'CD', 'CED'] STRUCT_EDGE_OPT = ['anm', 'af2', 'af2-anm'] # edge options that require structural info (pdbs) EDGE_OPT = ['simple', 'binary'] + STRUCT_EDGE_OPT -PRO_FEAT_OPT = ['nomsa', 'msa', 'shannon', 'foldseek'] + +STRUCT_PRO_FEAT_OPT = ['foldseek'] # requires structural info (pdbs) +PRO_FEAT_OPT = ['nomsa', 'msa', 'shannon'] + STRUCT_PRO_FEAT_OPT LIG_FEAT_OPT = [None, 'original'] LIG_EDGE_OPT = [None, 'binary'] @@ -47,4 +49,8 @@ SLURM_CONSTRAINT = 'cascade,v100' elif 'cedar' in DOMAIN_NAME: CLUSTER = 'cedar' - SLURM_GPU_NAME = 'v100l' \ No newline at end of file + SLURM_GPU_NAME = 'v100l' + +# bin paths +from pathlib import Path +FOLDSEEK_BIN = f'{Path.home()}/lib/foldseek/bin/foldseek' \ No newline at end of file diff --git a/src/utils/loader.py b/src/utils/loader.py index 439ea4f9..9d3f146b 100644 --- a/src/utils/loader.py +++ b/src/utils/loader.py @@ -149,12 +149,11 @@ def load_dataset(data:str, pro_feature:str, edge_opt:str, subset:str=None, path: @staticmethod @validate_args({'data': data_opt, 'pro_feature': pro_feature_opt, 'edge_opt': edge_opt, 'ligand_feature':cfg.LIG_FEAT_OPT, 'ligand_edge':cfg.LIG_EDGE_OPT}) - def load_DataLoaders(data:str, pro_feature:str, edge_opt:str, path:str=cfg.DATA_ROOT, - batch_train:int=64, datasets:Iterable[str]=['train', 'test', 'val'], + def load_datasets(data:str, pro_feature:str, edge_opt:str, path:str=cfg.DATA_ROOT, + datasets:Iterable[str]=['train', 'test', 'val'], training_fold:int=None, # for cross-val. 
None for no cross-val protein_overlap:bool=False, ligand_feature:str=None, ligand_edge:str=None): - loaders = {} # no overlap or cross-val subsets = datasets @@ -172,15 +171,37 @@ def load_DataLoaders(data:str, pro_feature:str, edge_opt:str, path:str=cfg.DATA_ if protein_overlap: subsets = [d+'-overlap' for d in subsets] - + loaded_datasets = {} for d, s in zip(datasets, subsets): dataset = Loader.load_dataset(data, pro_feature, edge_opt, subset=s, path=path, ligand_feature=ligand_feature, ligand_edge=ligand_edge) - + loaded_datasets[d] = dataset + return loaded_datasets + + @staticmethod + @validate_args({'data': data_opt, 'pro_feature': pro_feature_opt, 'edge_opt': edge_opt, + 'ligand_feature':cfg.LIG_FEAT_OPT, 'ligand_edge':cfg.LIG_EDGE_OPT}) + def load_DataLoaders(data:str, pro_feature:str, edge_opt:str, path:str=cfg.DATA_ROOT, + batch_train:int=64, datasets:Iterable[str]=['train', 'test', 'val'], + training_fold:int=None, # for cross-val. None for no cross-val + protein_overlap:bool=False, + ligand_feature:str=None, ligand_edge:str=None, + loaded_datasets=None): + # loaded_datasets is used to avoid loading the same dataset multiple times when we just want + # to create a new dataloader (e.g.: for testing with different batch size) + if loaded_datasets is None: + loaded_datasets = Loader.load_datasets(data=data, pro_feature=pro_feature, edge_opt=edge_opt, + path=path, datasets=datasets, training_fold=training_fold, + protein_overlap=protein_overlap, ligand_feature=ligand_feature, + ligand_edge=ligand_edge) + + loaders = {} + for d in loaded_datasets: bs = 1 if d == 'test' else batch_train - loader = DataLoader(dataset=dataset, batch_size=bs, + loader = DataLoader(dataset=loaded_datasets[d], + batch_size=bs, shuffle=False) loaders[d] = loader @@ -199,30 +220,15 @@ def load_distributed_DataLoaders(num_replicas:int, rank:int, seed:int, data:str, ligand_feature:str=None, ligand_edge:str=None, num_workers:int=4): - loaders = {} - # no overlap or cross-val - subsets = datasets - # training folds are identified by train1, train2, etc. 
- # (see model_key fn above) - if training_fold is not None: - subsets = [d+str(training_fold) for d in subsets] - try: - # making sure test set is not renamed - subsets[datasets.index('test')] = 'test' - except ValueError: - pass - - # Overlap is identified by adding '-overlap' to the subset name (after cross-val) - if protein_overlap: - subsets = [d+'-overlap' for d in subsets] + loaded_datasets = Loader.load_datasets(data=data, pro_feature=pro_feature, edge_opt=edge_opt, + path=path, datasets=datasets, training_fold=training_fold, + protein_overlap=protein_overlap, ligand_feature=ligand_feature, + ligand_edge=ligand_edge) - - for d, s in zip(datasets, subsets): - dataset = Loader.load_dataset(data, pro_feature, edge_opt, - subset=s, path=path, - ligand_feature=ligand_feature, - ligand_edge=ligand_edge) + loaders = {} + for d in loaded_datasets: + dataset = loaded_datasets[d] sampler = DistributedSampler(dataset, shuffle=True, num_replicas=num_replicas, rank=rank, seed=seed) diff --git a/src/utils/residue.py b/src/utils/residue.py index a58deceb..255e9cd2 100644 --- a/src/utils/residue.py +++ b/src/utils/residue.py @@ -6,10 +6,11 @@ def one_hot(x, allowable_set, cap=False): """Return the one-hot encoding of x as a numpy array.""" if x not in allowable_set: - if not cap: - raise Exception('input {0} not in allowable set{1}:'.format(x, allowable_set)) - else: + if cap: # last element is the catch all/unknown x = allowable_set[-1] + else: + raise Exception('input {0} not in allowable set{1}:'.format(x, allowable_set)) + return np.eye(len(allowable_set))[allowable_set.index(x)] @@ -36,7 +37,12 @@ class ResInfo(): amino_acids = ['A', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'K', 'L', 'M', 'N', 'P', 'Q', - 'R', 'S', 'T', 'V', 'W', 'Y', 'X'] + 'R', 'S', 'T', 'V', 'W', 'Y', 'X'] # X is unknown + + foldseek_tokens = ['a', 'c', 'd', 'e', 'f', 'g', 'h', + 'i', 'k', 'l', 'm', 'n', 'p', 'q', + 'r', 's', 't', 'v', 'w', 'y', '#'] # '#' is mask/unknown + res_to_i = {k: i for i, k in enumerate(amino_acids)} @@ -103,6 +109,7 @@ def normalize_dict(dictionary): # TODO: why not this instead? 
pl = normalize_add_x(pl) hydrophobic_ph2 = normalize_add_x(hydrophobic_ph2) hydrophobic_ph7 = normalize_add_x(hydrophobic_ph7) + from collections import OrderedDict diff --git a/src/utils/tuning.py b/src/utils/tuning.py new file mode 100644 index 00000000..eeb9fcb7 --- /dev/null +++ b/src/utils/tuning.py @@ -0,0 +1,38 @@ +from ray.air import session + +import torch +from torch_geometric.loader import DataLoader + +from src.data_processing.datasets import BaseDataset +from src.train_test.utils import CheckpointSaver +from src.utils.loader import Loader +from src.train_test.training import train + + +def train_tune(config, model:str, pro_feature:str, train_dataset:BaseDataset, val_dataset:BaseDataset): + device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') + model = Loader.init_model(model, pro_feature, config['edge'], config['dropout']) + model.to(device) + + train_loader = DataLoader(train_dataset, batch_size=config['batch_size'], + shuffle=True, + num_workers=2) + val_loader = DataLoader(val_dataset, batch_size=config['batch_size'], + shuffle=True, + num_workers=2) + + saver = CheckpointSaver(model, debug=True) + for i in range(10): # 10 epochs + logs = train(model, train_loader, val_loader, device, epochs=1, + lr_0=config['lr'], silent=True, saver=saver) + val_loss = logs['val_loss'][0] + + # Send the current training result back to Tune + session.report({"val_loss": val_loss}) + + if i % 5 == 0: + # This saves the model to the trial directory + torch.save(model.state_dict(), "./model.pth") + + + diff --git a/train_test.py b/train_test.py index beca4ba8..7826333e 100644 --- a/train_test.py +++ b/train_test.py @@ -73,16 +73,7 @@ os.makedirs(f'{media_save_p}/train_log/', exist_ok=True) - # ==== LOAD DATA ==== - # WARNING: Deprecating use of split to ensure all models train on same dataset splits. - # dataset = Loader.load_dataset(DATA, FEATURE, EDGEW, subset='full') - # print(f'# Number of samples: {len(dataset)}') - # train_loader, val_loader, test_loader = train_val_test_split(dataset, - # train_split=TRAIN_SPLIT, val_split=VAL_SPLIT, - # shuffle_dataset=True, random_seed=args.rand_seed, - # batch_train=BATCH_SIZE, use_refined=False, - # split_by_prot=not args.protein_overlap) - + # ==== LOAD DATA ==== loaders = Loader.load_DataLoaders(data=DATA, pro_feature=FEATURE, edge_opt=EDGEW, path=cfg.DATA_ROOT, ligand_feature=ligand_feature, ligand_edge=ligand_edge, batch_train=BATCH_SIZE,