diff --git a/.gitignore b/.gitignore index 6620b40..b0caed1 100644 --- a/.gitignore +++ b/.gitignore @@ -224,4 +224,4 @@ results/model_media/*/train_set_pred/* results/model_media/*/test_set_pred/* results/model_media/test_set_pred -splits/*.csv +splits/**/*.csv diff --git a/splits/davis_splits.zip b/splits/davis_splits.zip index 1169a59..88f3b8b 100644 Binary files a/splits/davis_splits.zip and b/splits/davis_splits.zip differ diff --git a/src/data_prep/init_dataset.py b/src/data_prep/init_dataset.py index ff2ce16..586954e 100644 --- a/src/data_prep/init_dataset.py +++ b/src/data_prep/init_dataset.py @@ -10,7 +10,7 @@ from src.utils import config as cfg from src.data_prep.feature_extraction.protein_nodes import create_pfm_np_files from src.data_prep.datasets import DavisKibaDataset, PDBbindDataset, PlatinumDataset -from src.train_test.splitting import train_val_test_split, balanced_kfold_split +from src.train_test.splitting import train_val_test_split, balanced_kfold_split, resplit def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:list[str]|str, ligand_features:list[str]=['original'], @@ -22,6 +22,7 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis val_split:float=0.1, overwrite=True, test_prots_csv:str=None, + val_prots_csv:list[str]=None, **kwargs) -> None: """ Creates the datasets for the given data, feature, and edge options. @@ -123,36 +124,43 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis # saving training, validation, and test sets test_split = 1 - train_split - val_split - if k_folds is None: - train_loader, val_loader, test_loader = train_val_test_split(dataset, - train_split=train_split, val_split=val_split, - random_seed=random_seed, split_by_prot=not pro_overlap) + if val_prots_csv: + assert k_folds is None or len(val_prots_csv) == k_folds, "Mismatch between number of val_prot_csvs provided and k_folds selected." + + split_files = {os.path.basename(v).split('.')[0]: v for v in val_prots_csv} + split_files['test'] = test_prots_csv + dataset = resplit(dataset, split_files=split_files) else: - assert test_split > 0, f"Invalid train/val/test split: {train_split}/{val_split}/{test_split}" - assert not pro_overlap, f"No support for overlapping proteins with k-folds rn." - if test_prots_csv is not None: - df = pd.read_csv(test_prots_csv) - test_prots = set(df['prot_id'].tolist()) + if k_folds is None: + train_loader, val_loader, test_loader = train_val_test_split(dataset, + train_split=train_split, val_split=val_split, + random_seed=random_seed, split_by_prot=not pro_overlap) else: - test_prots = None + assert test_split > 0, f"Invalid train/val/test split: {train_split}/{val_split}/{test_split}" + assert not pro_overlap, f"No support for overlapping proteins with k-folds rn." + if test_prots_csv is not None: + df = pd.read_csv(test_prots_csv) + test_prots = set(df['prot_id'].tolist()) + else: + test_prots = None - train_loader, val_loader, test_loader = balanced_kfold_split(dataset, - k_folds=k_folds, test_split=test_split, test_prots=test_prots, - random_seed=random_seed) # only non-overlapping splits for k-folds + train_loader, val_loader, test_loader = balanced_kfold_split(dataset, + k_folds=k_folds, test_split=test_split, test_prots=test_prots, + random_seed=random_seed) # only non-overlapping splits for k-folds + + subset_names = ['train', 'val', 'test'] + if pro_overlap: + subset_names = [s+'-overlap' for s in subset_names] - subset_names = ['train', 'val', 'test'] - if pro_overlap: - subset_names = [s+'-overlap' for s in subset_names] - - if test_split < 1: # for datasets that are purely for testing we skip this section - if k_folds is None: - dataset.save_subset(train_loader, subset_names[0]) - dataset.save_subset(val_loader, subset_names[1]) - else: - # loops through all k folds and saves as train1, train2, etc. - dataset.save_subset_folds(train_loader, subset_names[0]) - dataset.save_subset_folds(val_loader, subset_names[1]) - - dataset.save_subset(test_loader, subset_names[2]) + if test_split < 1: # for datasets that are purely for testing we skip this section + if k_folds is None: + dataset.save_subset(train_loader, subset_names[0]) + dataset.save_subset(val_loader, subset_names[1]) + else: + # loops through all k folds and saves as train1, train2, etc. + dataset.save_subset_folds(train_loader, subset_names[0]) + dataset.save_subset_folds(val_loader, subset_names[1]) + + dataset.save_subset(test_loader, subset_names[2]) del dataset # free up memory diff --git a/src/train_test/splitting.py b/src/train_test/splitting.py index a9f36c9..951610a 100644 --- a/src/train_test/splitting.py +++ b/src/train_test/splitting.py @@ -314,7 +314,9 @@ def resplit(dataset:str|BaseDataset, split_files:dict|str=None, **kwargs): csv_files[split] = f'{split_files}/{split}/cleaned_XY.csv' split_files = csv_files print('Using split files from:', split_files) - + + assert 'test' in split_files, 'Missing test csv from split files.' + # Check if split files exist and are in the correct format if split_files is None: raise ValueError('split_files must be provided')