diff --git a/results/v128/model_media/model_stats.csv b/results/v128/model_media/model_stats.csv new file mode 100644 index 0000000..cf15971 --- /dev/null +++ b/results/v128/model_media/model_stats.csv @@ -0,0 +1,11 @@ +run,cindex,pearson,spearman,mse,mae,rmse +DGM_PDBbind1D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6700303728768234,0.5066764643691913,0.4877355421142711,2.8944340651943885,1.3393119536913358,1.7013036369779466 +DGM_PDBbind0D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6920256384153837,0.5521166195053873,0.5492833432385198,2.6901810056403943,1.2902274471429678,1.6401771263008134 +DGM_PDBbind2D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6860817317553821,0.5349021289716281,0.5292561754408733,2.783316931934153,1.3007997966913076,1.668327585318349 +DGM_PDBbind4D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6814616054321716,0.5209332896962908,0.5205006093028853,2.791164015918997,1.3129914098152748,1.6706777115646805 +DGM_PDBbind3D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6886407871725875,0.5377497788348156,0.536984917834154,2.7892612584388665,1.3068153142929078,1.6701081577068195 +DGM_PDBbind0D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6791215578161572,0.5426811555689641,0.5087142971057145,2.7253219488991136,1.2940673363869717,1.6508549145515827 +DGM_PDBbind2D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6935974149775556,0.5590003871276702,0.5527753938595599,2.613210370957274,1.2817467006748329,1.616542721661656 +DGM_PDBbind1D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6776316790562161,0.5208748506408137,0.5056665537333501,2.79682269540488,1.3061289907536988,1.672370382243383 +DGM_PDBbind3D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6772535031619085,0.5129702979170532,0.5079117445737773,2.876293932539623,1.343678560827124,1.69596401274898 +DGM_PDBbind4D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6803930344028177,0.5261288742927207,0.5152003489955188,2.7888457682850523,1.3061685765992026,1.669983762880661 diff --git a/results/v128/model_media/model_stats_val.csv b/results/v128/model_media/model_stats_val.csv new file mode 100644 index 0000000..d6661c5 --- /dev/null +++ b/results/v128/model_media/model_stats_val.csv @@ -0,0 +1,11 @@ +run,cindex,pearson,spearman,mse,mae,rmse +DGM_PDBbind1D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6865570848103539,0.5217828035285198,0.5342788004080268,3.12362316791697,1.4128894441393012,1.7673774831418922 +DGM_PDBbind0D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6856114868115144,0.5126870149038275,0.5317248936648343,3.3695982839888123,1.4536717794904257,1.8356465574801737 +DGM_PDBbind2D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6727143444861741,0.5022869639192042,0.5008407950347942,3.421695656587546,1.4982888305121282,1.849782597114468 +DGM_PDBbind4D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.6923138709195878,0.5430234761140276,0.5501681361051465,2.9685919834473102,1.3589257854182952,1.7229602384986458 +DGM_PDBbind3D_nomsaF_binaryE_64B_0.0001LR_0.4D_2000E_originalLF_binaryLE,0.7221237704365324,0.5997226273753726,0.6282960957077218,3.071004525264972,1.3901743548518686,1.7524281797737025 +DGM_PDBbind0D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6765066849227523,0.5001545914726073,0.5133105406519389,3.044526199847986,1.3962691384923704,1.7448570714668827 +DGM_PDBbind2D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.7083260415364561,0.5949781068309761,0.5895379695932953,2.563278684581813,1.274137923604572,1.601024261084701 +DGM_PDBbind1D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6854165476207483,0.5218685304168357,0.535423559993022,2.837139616188374,1.326142185326017,1.6843810780783468 +DGM_PDBbind3D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6752795639801628,0.5022118247529487,0.5077603820327848,3.2012135060951,1.425228337672172,1.7891935351143822 +DGM_PDBbind4D_nomsaF_aflowE_128B_0.0009185598967356679LR_0.22880989869337157D_2000E_originalLF_binaryLE,0.6843307790913993,0.5107948247428088,0.5298336871787266,2.9007170257795094,1.3341726068459485,1.7031491495989155 diff --git a/src/utils/pocket_alignment.py b/src/utils/pocket_alignment.py index 550eb6e..8ff30eb 100644 --- a/src/utils/pocket_alignment.py +++ b/src/utils/pocket_alignment.py @@ -3,12 +3,19 @@ mask from a binding pocket sequence. """ +import json +import os +import shutil + from Bio import Align from Bio.Align import substitution_matrices +import pandas as pd import torch +from src.data_prep.downloaders import Downloader + -def create_pocket_mask(target_seq: str, query_seq: str) -> list[bool]: +def create_pocket_mask(target_seq: str, pocket_seq: str) -> list[bool]: """ Return an index mask of a pocket on a protein sequence. @@ -16,7 +23,7 @@ def create_pocket_mask(target_seq: str, query_seq: str) -> list[bool]: ---------- target_seq : str The protein sequence you want to query in - query_seq : str + pocket_seq : str The binding pocket sequence for the protein Returns @@ -25,6 +32,8 @@ def create_pocket_mask(target_seq: str, query_seq: str) -> list[bool]: A boolean list of indices that are True if the residue at that position is part of the binding pocket and false otherwise """ + # Ensure that no '-' characters are present in the query sequence + query_seq = pocket_seq.replace('-', 'X') # Taken from tutorial https://biopython.org/docs/dev/Tutorial/chapter_pairwise.html aligner = Align.PairwiseAligner() # Pairwise alignment parameters as specified in paragraph 2 @@ -92,15 +101,165 @@ def mask_graph(data, mask: list[bool]): return data +def _parse_json(json_path: str) -> str: + """ + Parse a JSON file that holds binding pocket data downloaded from KLIFS. + + Parameters + ---------- + json_path : str + The path to the JSON file + + Returns + ------- + str + The binding pocket sequence + """ + with open(json_path, 'r') as json_file: + data = json.load(json_file) + return data[0]['pocket'] + + +def get_dataset_binding_pockets( + dataset_path: str = 'data/DavisKibaDataset/kiba/nomsa_binary_original_binary/full', + pockets_path: str = 'data/DavisKibaDataset/kiba_pocket' + ) -> tuple[dict[str, str], set[str]]: + """ + Get all binding pocket sequences for a dataset + + Parameters + ---------- + dataset_path : str + The path to the directory containing the dataset (as of July 24, 2024, + only expecting Kiba dataset). Specify only the path to one of 'davis', 'kiba', + or 'PDBbind' (e.g., 'data/DavisKibaDataset/kiba') + pockets_path: str + The path to the new dataset directory after all the binding pockets have been found + + Returns + ------- + tuple[dict[str, str], set[str]] + A tuple consisting of: + -A map of protein ID, binding pocket sequence pairs + -A set of protein IDs with no KLIFS binding pockets + """ + csv_path = os.path.join(dataset_path, 'cleaned_XY.csv') + df = pd.read_csv(csv_path, usecols=['prot_id']) + prot_ids = list(set(df['prot_id'])) + # Strip out mutations and '-(alpha, beta, gamma)' tags if they are present, + # the binding pocket sequence will be the same for mutated and non-mutated genes + prot_ids = [id.split('(')[0].split('-')[0] for id in prot_ids] + dl = Downloader() + seq_save_dir = os.path.join(pockets_path, 'pockets') + os.makedirs(seq_save_dir, exist_ok=True) + download_check = dl.download_pocket_seq(prot_ids, seq_save_dir) + download_errors = set() + for key, val in download_check.items(): + if val == 400: + download_errors.add(key) + sequences = {} + for file in os.listdir(seq_save_dir): + pocket_seq = _parse_json(os.path.join(seq_save_dir, file)) + if pocket_seq == 0 or len(pocket_seq) == 0: + download_errors.add(file.split('.')[0]) + else: + sequences[file.split('.')[0]] = pocket_seq + return (sequences, download_errors) + + +def create_binding_pocket_dataset( + dataset_path: str, + pocket_sequences: dict[str, str], + download_errors: set[str], + new_dataset_path: str +) -> None: + """ + Apply the graph mask based on binding pocket sequence for each + Data object in a PyTorch dataset. + + dataset_path : str + The path to the PyTorch dataset object to be transformed + pocket_sequences : dict[str, str] + A map of protein ID, binding pocket sequence pairs + download_errors : set[str] + A set of protein IDs that have no binding pocket sequence + to be downloaded from KLIFS + new_dataset_path : str + A path to where the new dataset should be saved + """ + dataset = torch.load(dataset_path) + new_dataset = {} + for id, data in dataset.items(): + # If there are any mutations or (-alpha,beta,gamma) tags, strip them + stripped_id = id.split('(')[0].split('-')[0] + if stripped_id not in download_errors: + mask = create_pocket_mask(data.pro_seq, pocket_sequences[stripped_id]) + new_data = mask_graph(data, mask) + new_dataset[id] = new_data + os.makedirs(os.path.dirname(new_dataset_path), exist_ok=True) + torch.save(dataset, new_dataset_path) + + +def binding_pocket_filter(dataset_csv_path: str, download_errors: set[str], csv_save_path: str): + """ + Filter out protein IDs that do not have a corresponding KLIFS + binding pocket sequence from the dataset. + + Parameters + ---------- + dataset_csv_path : str + The path to the original cleaned CSV. Will probably be a CSV named cleaned_XY.csv + or something like that. + download_errors : set[str] + A set of protein IDs with no KLIFS binding pocket sequences. + csv_save_path : str + The path to save the new CSV file to. + """ + df = pd.read_csv(dataset_csv_path) + df = df[~df['prot_id'].isin(download_errors)] + os.makedirs(os.path.dirname(csv_save_path), exist_ok=True) + df.to_csv(csv_save_path) + + +def pocket_dataset_full( + dataset_dir: str, + pocket_dir: str, + save_dir: str +) -> None: + """ + Create all elements of a dataset that includes binding pockets. This + function assumes the PyTorch object holding the dataset is named 'data_pro.pt' + and the CSV holding the cleaned data is named 'cleaned_XY.csv'. + + Parameters + ---------- + dataset_dir : str + The path to the dataset to be transformed + pocket_dir : str + The path to where the dataset raw pocket sequences are to be saved + save_dir : str + The path to where the new dataset is to be saved + """ + pocket_map, download_errors = get_dataset_binding_pockets(dataset_dir, pocket_dir) + print(f'Binding pocket sequences were not found for the following {len(download_errors)} protein IDs:') + print(','.join(list(download_errors))) + create_binding_pocket_dataset( + os.path.join(dataset_dir, 'data_pro.pt'), + pocket_map, + download_errors, + os.path.join(save_dir, 'data_pro.pt') + ) + binding_pocket_filter( + os.path.join(dataset_dir, 'cleaned_XY.csv'), + download_errors, + os.path.join(save_dir, 'cleaned_XY.csv') + ) + shutil.copy2(os.path.join(dataset_dir, 'data_mol.pt'), os.path.join(save_dir, 'data_mol.pt')) + + if __name__ == '__main__': - graph_data = torch.load('sample_pro_data.torch') - seq = graph_data.pro_seq - seq = seq[:857] + 'R' + seq[858:] - graph_data.pro_seq = seq - torch.save(graph_data, 'sample_pro_data_unmutated.torch') - binding_pocket_sequence = 'KVLGSGAFGTVYKVAIKELEILDEAYVMASVDPHVCRLLGIQLITQLMPFGCLLDYVREYLEDRRLVHRDLAARNVLVITDFGLA' - mask = create_pocket_mask( - graph_data.pro_seq, - binding_pocket_sequence + pocket_dataset_full( + 'data/DavisKibaDataset/kiba/nomsa_binary_original_binary/full/', + 'data/DavisKibaDataset/kiba_pocket', + 'data/DavisKibaDataset/kiba_pocket/nomsa_binary_original_binary/full/' ) - masked_data = mask_graph(graph_data, mask)