From 11876db6f3fdc7e7bfe1bfdb951d975a6a611413 Mon Sep 17 00:00:00 2001 From: jyaacoub Date: Thu, 20 Jun 2024 13:58:11 -0400 Subject: [PATCH] feat(init_dataset): test_prots_csv for test set consistency --- src/data_prep/init_dataset.py | 22 +++++++++++++++++----- 1 file changed, 17 insertions(+), 5 deletions(-) diff --git a/src/data_prep/init_dataset.py b/src/data_prep/init_dataset.py index d961304..ff2ce16 100644 --- a/src/data_prep/init_dataset.py +++ b/src/data_prep/init_dataset.py @@ -1,12 +1,13 @@ import os import sys import itertools -from src.utils import config as cfg +import pandas as pd # Add the project root directory to Python path so imports work if file is run PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../')) sys.path.append(PROJECT_ROOT) +from src.utils import config as cfg from src.data_prep.feature_extraction.protein_nodes import create_pfm_np_files from src.data_prep.datasets import DavisKibaDataset, PDBbindDataset, PlatinumDataset from src.train_test.splitting import train_val_test_split, balanced_kfold_split @@ -19,7 +20,8 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis random_seed:int=0, train_split:float=0.8, val_split:float=0.1, - overwrite=True, + overwrite=True, + test_prots_csv:str=None, **kwargs) -> None: """ Creates the datasets for the given data, feature, and edge options. @@ -44,6 +46,10 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis `k_folds` : int, optional If not None, the number of folds to split the final training set into for cross validation, by default None + `test_prots_csv` : str, optional + If not None, the path to a csv file containing the test proteins to use, + by default None. The csv file should have a 'prot_id' column. + """ if isinstance(data_opt, str): data_opt = [data_opt] if isinstance(feat_opt, str): feat_opt = [feat_opt] @@ -124,8 +130,14 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis else: assert test_split > 0, f"Invalid train/val/test split: {train_split}/{val_split}/{test_split}" assert not pro_overlap, f"No support for overlapping proteins with k-folds rn." + if test_prots_csv is not None: + df = pd.read_csv(test_prots_csv) + test_prots = set(df['prot_id'].tolist()) + else: + test_prots = None + train_loader, val_loader, test_loader = balanced_kfold_split(dataset, - k_folds=k_folds, test_split=test_split, + k_folds=k_folds, test_split=test_split, test_prots=test_prots, random_seed=random_seed) # only non-overlapping splits for k-folds subset_names = ['train', 'val', 'test'] @@ -140,7 +152,7 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis # loops through all k folds and saves as train1, train2, etc. dataset.save_subset_folds(train_loader, subset_names[0]) dataset.save_subset_folds(val_loader, subset_names[1]) - + dataset.save_subset(test_loader, subset_names[2]) - + del dataset # free up memory