Skip to content

Commit 48f3963

Browse files
authored
Merge pull request #5 from BioinfoMachineLearning/develop
Allow input PDBs with custom filenames and make feature imputation function name unique
2 parents cad4ec8 + 17784a7 commit 48f3963

File tree

3 files changed

+18
-14
lines changed

3 files changed

+18
-14
lines changed

project/datasets/builder/impute_missing_feature_values.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44

55
import click
66
from parallel import submit_jobs
7-
from project.utils.dips_plus_utils import impute_missing_feature_values
7+
from project.utils.dips_plus_utils import impute_postprocessed_missing_feature_values
88

99

1010
# -------------------------------------------------------------------------------------------------------------------------------------
@@ -29,7 +29,7 @@ def main(output_dir: str, impute_atom_features: bool, advanced_logging: bool, nu
2929
inputs = [(pair_filename.as_posix(), pair_filename.as_posix(), impute_atom_features, advanced_logging)
3030
for pair_filename in Path(output_dir).rglob('*.dill')]
3131
# Without impute_atom_features set to True, non-CA atoms will be filtered out after writing updated pairs
32-
submit_jobs(impute_missing_feature_values, inputs, num_cpus)
32+
submit_jobs(impute_postprocessed_missing_feature_values, inputs, num_cpus)
3333

3434

3535
if __name__ == '__main__':

project/utils/deepinteract_utils.py

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@
3030
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
3131

3232
from project.utils.deepinteract_constants import FEAT_COLS, ALLOWABLE_FEATS, D3TO1
33-
from project.utils.dips_plus_utils import postprocess_pruned_pairs
33+
from project.utils.dips_plus_utils import postprocess_pruned_pairs, impute_postprocessed_missing_feature_values
3434
from project.utils.graph_utils import prot_df_to_dgl_graph_feats
3535
from project.utils.protein_feature_utils import GeometricProteinFeatures
3636

@@ -573,9 +573,11 @@ def create_input_dir_struct(input_dataset_dir: str, pdb_code: str):
573573
_, _ = dir_struct_create_proc.communicate() # Wait until the directory structure creation cmd is finished
574574

575575

576-
def copy_input_to_raw_dir(input_dataset_dir: str, pdb_filepath: str, pdb_code: str):
576+
def copy_input_to_raw_dir(input_dataset_dir: str, pdb_filepath: str, pdb_code: str, chain_indic: str):
577577
"""Make a copy of the input PDB file in the newly-created raw directory."""
578-
input_copy_cmd = f'cp {pdb_filepath} {os.path.join(input_dataset_dir, "raw", pdb_code)}'
578+
filename = db.get_pdb_code(pdb_filepath) + f'_{chain_indic}.pdb' \
579+
if chain_indic not in pdb_filepath else db.get_pdb_name(pdb_filepath)
580+
input_copy_cmd = f'cp {pdb_filepath} {os.path.join(input_dataset_dir, "raw", pdb_code, filename)}'
579581
input_copy_proc = subprocess.Popen(input_copy_cmd.split(), stdout=subprocess.PIPE, cwd=os.getcwd())
580582
_, _ = input_copy_proc.communicate() # Wait until the input copy cmd is finished
581583

@@ -590,6 +592,7 @@ def make_dataset(input_dataset_dir='datasets/Input/raw', output_dir='datasets/In
590592
pa.parse_all(input_dataset_dir, parsed_dir, num_cpus)
591593

592594
complexes_dill = os.path.join(output_dir, 'complexes/complexes.dill')
595+
os.remove(complexes_dill) # Ensure that pairs are made everytime this function is called
593596
comp.complexes(parsed_dir, complexes_dill, source_type)
594597
complexes = comp.read_complexes(complexes_dill)
595598
pairs_dir = os.path.join(output_dir, 'pairs')
@@ -697,7 +700,7 @@ def impute_missing_feature_values(output_dir='datasets/Input/final/raw',
697700
inputs = [(pair_filename.as_posix(), pair_filename.as_posix(), impute_atom_features, advanced_logging)
698701
for pair_filename in Path(output_dir).rglob('*.dill')]
699702
# Without impute_atom_features set to True, non-CA atoms will be filtered out after writing updated pairs
700-
par.submit_jobs(impute_missing_feature_values, inputs, num_cpus)
703+
par.submit_jobs(impute_postprocessed_missing_feature_values, inputs, num_cpus)
701704

702705

703706
def convert_input_pdb_files_to_pair(left_pdb_filepath: str, right_pdb_filepath: str, input_dataset_dir: str,
@@ -707,8 +710,8 @@ def convert_input_pdb_files_to_pair(left_pdb_filepath: str, right_pdb_filepath:
707710
pdb_code = db.get_pdb_group(list(ca.get_complex_pdb_codes([left_pdb_filepath, right_pdb_filepath]))[0])
708711
# Iteratively execute the PDB file feature generation process
709712
create_input_dir_struct(input_dataset_dir, pdb_code)
710-
copy_input_to_raw_dir(input_dataset_dir, left_pdb_filepath, pdb_code)
711-
copy_input_to_raw_dir(input_dataset_dir, right_pdb_filepath, pdb_code)
713+
copy_input_to_raw_dir(input_dataset_dir, left_pdb_filepath, pdb_code, 'l_u')
714+
copy_input_to_raw_dir(input_dataset_dir, right_pdb_filepath, pdb_code, 'r_u')
712715
make_dataset(os.path.join(input_dataset_dir, 'raw'), os.path.join(input_dataset_dir, 'interim'))
713716
generate_psaia_features(psaia_dir=psaia_dir,
714717
psaia_config=psaia_config,

project/utils/dips_plus_utils.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -394,9 +394,10 @@ def __should_keep_postprocessed(raw_pdb_dir: str, pair_filename: str, source_typ
394394
# Identify if a given complex contains DSSP-derivable secondary structure features
395395
raw_pdb_filenames.append(get_raw_pdb_filename_from_interim_filename(interim_filename, raw_pdb_dir, source_type))
396396
pair_dssp_dict = get_dssp_dict_for_pdb_file(raw_pdb_filenames[i])
397-
if not pair_dssp_dict and source_type not in ['input']:
397+
if source_type.lower() not in ['input'] and not pair_dssp_dict:
398398
return pair, raw_pdb_filenames[i], False # Discard pair missing DSSP-derivable secondary structure features
399-
if pair.df0.shape[0] > ATOM_COUNT_LIMIT or pair.df1.shape[0] > ATOM_COUNT_LIMIT:
399+
if source_type.lower() not in ['input'] \
400+
and (pair.df0.shape[0] > ATOM_COUNT_LIMIT or pair.df1.shape[0] > ATOM_COUNT_LIMIT):
400401
return pair, raw_pdb_filenames[i], False # Discard pair exceeding atom count limit to reduce comp. complex.
401402
return pair, raw_pdb_filenames, True
402403

@@ -458,8 +459,8 @@ def postprocess_pruned_pair(raw_pdb_filenames: List[str], external_feats_dir: st
458459
rd_dict = get_msms_rd_dict_for_pdb_model(structure[0]) # RD only retrieved for first model
459460

460461
# Get protrusion indices using PSAIA
461-
psaia_filepath = os.path.relpath(os.path.splitext(os.path.split(raw_pdb_filename)[-1])[0])
462-
psaia_filename = [path for path in Path(external_feats_dir).rglob(f'{psaia_filepath}*.tbl')][0] # 1st path
462+
pdb_code = db.get_pdb_code(raw_pdb_filename)
463+
psaia_filename = [path for path in Path(external_feats_dir).rglob(f'{pdb_code}*.tbl')][0] # 1st path
463464
psaia_df = get_df_from_psaia_tbl_file(psaia_filename)
464465

465466
# Extract half-sphere exposure (HSE) statistics for each PDB model (including HSAAC and CN values)
@@ -836,8 +837,8 @@ def determine_nan_fill_value(column: pd.Series, imputation_method='median'):
836837
return imputation_value if column.isna().sum().sum() <= NUM_ALLOWABLE_NANS else 0
837838

838839

839-
def impute_missing_feature_values(input_pair_filename: str, output_pair_filename: str,
840-
impute_atom_features: bool, advanced_logging: bool):
840+
def impute_postprocessed_missing_feature_values(input_pair_filename: str, output_pair_filename: str,
841+
impute_atom_features: bool, advanced_logging: bool):
841842
"""Impute missing feature values in a postprocessed dataset."""
842843
# Look at a .dill file in the given output directory
843844
postprocessed_pair: pa.Pair = pd.read_pickle(input_pair_filename)

0 commit comments

Comments
 (0)