Skip to content

Commit

Permalink
feat(init_dataset): test_prots_csv for test set consistency
Browse files Browse the repository at this point in the history
  • Loading branch information
jyaacoub committed Jun 20, 2024
1 parent d939e46 commit 11876db
Showing 1 changed file with 17 additions and 5 deletions.
22 changes: 17 additions & 5 deletions src/data_prep/init_dataset.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
import os
import sys
import itertools
from src.utils import config as cfg
import pandas as pd

# Add the project root directory to Python path so imports work if file is run
PROJECT_ROOT = os.path.abspath(os.path.join(os.path.dirname(__file__), '../../'))
sys.path.append(PROJECT_ROOT)

from src.utils import config as cfg
from src.data_prep.feature_extraction.protein_nodes import create_pfm_np_files
from src.data_prep.datasets import DavisKibaDataset, PDBbindDataset, PlatinumDataset
from src.train_test.splitting import train_val_test_split, balanced_kfold_split
Expand All @@ -19,7 +20,8 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis
random_seed:int=0,
train_split:float=0.8,
val_split:float=0.1,
overwrite=True,
overwrite=True,
test_prots_csv:str=None,
**kwargs) -> None:
"""
Creates the datasets for the given data, feature, and edge options.
Expand All @@ -44,6 +46,10 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis
`k_folds` : int, optional
If not None, the number of folds to split the final training set into for
cross validation, by default None
`test_prots_csv` : str, optional
If not None, the path to a csv file containing the test proteins to use,
by default None. The csv file should have a 'prot_id' column.
"""
if isinstance(data_opt, str): data_opt = [data_opt]
if isinstance(feat_opt, str): feat_opt = [feat_opt]
Expand Down Expand Up @@ -124,8 +130,14 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis
else:
assert test_split > 0, f"Invalid train/val/test split: {train_split}/{val_split}/{test_split}"
assert not pro_overlap, f"No support for overlapping proteins with k-folds rn."
if test_prots_csv is not None:
df = pd.read_csv(test_prots_csv)
test_prots = set(df['prot_id'].tolist())
else:
test_prots = None

train_loader, val_loader, test_loader = balanced_kfold_split(dataset,
k_folds=k_folds, test_split=test_split,
k_folds=k_folds, test_split=test_split, test_prots=test_prots,
random_seed=random_seed) # only non-overlapping splits for k-folds

subset_names = ['train', 'val', 'test']
Expand All @@ -140,7 +152,7 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis
# loops through all k folds and saves as train1, train2, etc.
dataset.save_subset_folds(train_loader, subset_names[0])
dataset.save_subset_folds(val_loader, subset_names[1])

dataset.save_subset(test_loader, subset_names[2])

del dataset # free up memory

0 comments on commit 11876db

Please sign in to comment.