Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Integrate binding pocket tool into DAVIS and KIBA datasets #129

Merged
merged 6 commits into from
Jul 31, 2024
181 changes: 169 additions & 12 deletions src/utils/pocket_alignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,20 +3,26 @@
mask from a binding pocket sequence.
"""

import json
import os

from Bio import Align
from Bio.Align import substitution_matrices
import pandas as pd
import torch

from src.data_prep.downloaders import Downloader


def create_pocket_mask(target_seq: str, query_seq: str) -> list[bool]:
def create_pocket_mask(target_seq: str, pocket_seq: str) -> list[bool]:
"""
Return an index mask of a pocket on a protein sequence.

Parameters
----------
target_seq : str
The protein sequence you want to query in
query_seq : str
pocket_seq : str
The binding pocket sequence for the protein

Returns
Expand All @@ -25,6 +31,8 @@ def create_pocket_mask(target_seq: str, query_seq: str) -> list[bool]:
A boolean list of indices that are True if the residue at that
position is part of the binding pocket and false otherwise
"""
# Ensure that no '-' characters are present in the query sequence
query_seq = pocket_seq.replace('-', 'X')
# Taken from tutorial https://biopython.org/docs/dev/Tutorial/chapter_pairwise.html
aligner = Align.PairwiseAligner()
# Pairwise alignment parameters as specified in paragraph 2
Expand Down Expand Up @@ -92,15 +100,164 @@ def mask_graph(data, mask: list[bool]):
return data


def _parse_json(json_path: str) -> str:
"""
Parse a JSON file that holds binding pocket data downloaded from KLIFS.

Parameters
----------
json_path : str
The path to the JSON file

Returns
-------
str
The binding pocket sequence
"""
with open(json_path, 'r') as json_file:
data = json.load(json_file)
return data[0]['pocket']


def get_dataset_binding_pockets(
dataset_path: str = 'data/DavisKibaDataset/kiba/nomsa_binary_original_binary/full',
pockets_path: str = 'data/DavisKibaDataset/kiba_pocket'
) -> tuple[dict[str, str], set[str]]:
"""
Get all binding pocket sequences for a dataset

Parameters
----------
dataset_path : str
The path to the directory containing the dataset (as of July 24, 2024,
only expecting Kiba dataset). Specify only the path to one of 'davis', 'kiba',
or 'PDBbind' (e.g., 'data/DavisKibaDataset/kiba')
pockets_path: str
The path to the new dataset directory after all the binding pockets have been found

Returns
-------
tuple[dict[str, str], set[str]]
A tuple consisting of:
-A map of protein ID, binding pocket sequence pairs
-A set of protein IDs with no KLIFS binding pockets
"""
csv_path = os.path.join(dataset_path, 'cleaned_XY.csv')
df = pd.read_csv(csv_path, usecols=['prot_id'])
prot_ids = list(set(df['prot_id']))
# Strip out mutations and '-(alpha, beta, gamma)' tags if they are present,
# the binding pocket sequence will be the same for mutated and non-mutated genes
prot_ids = [id.split('(')[0].split('-')[0] for id in prot_ids]
dl = Downloader()
seq_save_dir = os.path.join(pockets_path, 'pockets')
os.makedirs(seq_save_dir, exist_ok=True)
download_check = dl.download_pocket_seq(prot_ids, seq_save_dir)
download_errors = set()
for key, val in download_check.items():
if val == 400:
download_errors.add(key)
sequences = {}
for file in os.listdir(seq_save_dir):
pocket_seq = _parse_json(os.path.join(seq_save_dir, file))
if pocket_seq == 0 or len(pocket_seq) == 0:
download_errors.add(file.split('.')[0])
else:
sequences[file.split('.')[0]] = pocket_seq
return (sequences, download_errors)


def create_binding_pocket_dataset(
dataset_path: str,
pocket_sequences: dict[str, str],
download_errors: set[str],
new_dataset_path: str
) -> None:
"""
Apply the graph mask based on binding pocket sequence for each
Data object in a PyTorch dataset.

dataset_path : str
The path to the PyTorch dataset object to be transformed
pocket_sequences : dict[str, str]
A map of protein ID, binding pocket sequence pairs
download_errors : set[str]
A set of protein IDs that have no binding pocket sequence
to be downloaded from KLIFS
new_dataset_path : str
A path to where the new dataset should be saved
"""
dataset = torch.load(dataset_path)
new_dataset = {}
for id, data in dataset.items():
# If there are any mutations or (-alpha,beta,gamma) tags, strip them
stripped_id = id.split('(')[0].split('-')[0]
if stripped_id not in download_errors:
mask = create_pocket_mask(data.pro_seq, pocket_sequences[stripped_id])
new_data = mask_graph(data, mask)
new_dataset[id] = new_data
os.makedirs(os.path.dirname(new_dataset_path), exist_ok=True)
torch.save(dataset, new_dataset_path)


def binding_pocket_filter(dataset_csv_path: str, download_errors: set[str], csv_save_path: str):
"""
Filter out protein IDs that do not have a corresponding KLIFS
binding pocket sequence from the dataset.

Parameters
----------
dataset_csv_path : str
The path to the original cleaned CSV. Will probably be a CSV named cleaned_XY.csv
or something like that.
download_errors : set[str]
A set of protein IDs with no KLIFS binding pocket sequences.
csv_save_path : str
The path to save the new CSV file to.
"""
df = pd.read_csv(dataset_csv_path)
df = df[~df['prot_id'].isin(download_errors)]
os.makedirs(os.path.dirname(csv_save_path), exist_ok=True)
df.to_csv(csv_save_path)


def pocket_dataset_full(
dataset_dir: str,
pocket_dir: str,
save_dir: str
) -> None:
"""
Create all elements of a dataset that includes binding pockets. This
function assumes the PyTorch object holding the dataset is named 'data_pro.pt'
and the CSV holding the cleaned data is named 'cleaned_XY.csv'.

Parameters
----------
dataset_dir : str
The path to the dataset to be transformed
pocket_dir : str
The path to where the dataset raw pocket sequences are to be saved
save_dir : str
The path to where the new dataset is to be saved
"""
pocket_map, download_errors = get_dataset_binding_pockets(dataset_dir, pocket_dir)
print(f'Binding pocket sequences were not found for the following {len(download_errors)} protein IDs:')
print(','.join(list(download_errors)))
create_binding_pocket_dataset(
os.path.join(dataset_dir, 'data_pro.pt'),
pocket_map,
download_errors,
os.path.join(save_dir, 'data_pro.pt')
)
binding_pocket_filter(
os.path.join(dataset_dir, 'cleaned_XY.csv'),
download_errors,
os.path.join(save_dir, 'cleaned_XY.csv')
)
jyaacoub marked this conversation as resolved.
Show resolved Hide resolved


if __name__ == '__main__':
graph_data = torch.load('sample_pro_data.torch')
seq = graph_data.pro_seq
seq = seq[:857] + 'R' + seq[858:]
graph_data.pro_seq = seq
torch.save(graph_data, 'sample_pro_data_unmutated.torch')
binding_pocket_sequence = 'KVLGSGAFGTVYKVAIKELEILDEAYVMASVDPHVCRLLGIQLITQLMPFGCLLDYVREYLEDRRLVHRDLAARNVLVITDFGLA'
mask = create_pocket_mask(
graph_data.pro_seq,
binding_pocket_sequence
pocket_dataset_full(
'data/DavisKibaDataset/kiba/nomsa_binary_original_binary/full/',
'data/DavisKibaDataset/kiba_pocket',
'data/DavisKibaDataset/kiba_pocket/nomsa_binary_original_binary/full/'
)
masked_data = mask_graph(graph_data, mask)
Loading