Skip to content

Commit

Permalink
fix(init_dataset): adding resplit to create_datasets #113
Browse files: browse the repository at this point in the history
  • Loading branch information
jyaacoub committed Jul 8, 2024
1 parent 256563c commit b15e83d
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 30 deletions.
2 changes: 1 addition & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -224,4 +224,4 @@ results/model_media/*/train_set_pred/*
results/model_media/*/test_set_pred/*
results/model_media/test_set_pred

splits/*.csv
splits/**/*.csv
Binary file modified splits/davis_splits.zip
Binary file not shown.
64 changes: 36 additions & 28 deletions src/data_prep/init_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from src.utils import config as cfg
from src.data_prep.feature_extraction.protein_nodes import create_pfm_np_files
from src.data_prep.datasets import DavisKibaDataset, PDBbindDataset, PlatinumDataset
from src.train_test.splitting import train_val_test_split, balanced_kfold_split
from src.train_test.splitting import train_val_test_split, balanced_kfold_split, resplit

def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:list[str]|str,
ligand_features:list[str]=['original'],
Expand All @@ -22,6 +22,7 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis
val_split:float=0.1,
overwrite=True,
test_prots_csv:str=None,
val_prots_csv:list[str]=None,
**kwargs) -> None:
"""
Creates the datasets for the given data, feature, and edge options.
Expand Down Expand Up @@ -123,36 +124,43 @@ def create_datasets(data_opt:list[str]|str, feat_opt:list[str]|str, edge_opt:lis

# saving training, validation, and test sets
test_split = 1 - train_split - val_split
if k_folds is None:
train_loader, val_loader, test_loader = train_val_test_split(dataset,
train_split=train_split, val_split=val_split,
random_seed=random_seed, split_by_prot=not pro_overlap)
if val_prots_csv:
assert k_folds is None or len(val_prots_csv) == k_folds, "Mismatch between number of val_prot_csvs provided and k_folds selected."

split_files = {os.path.basename(v).split('.')[0]: v for v in val_prots_csv}
split_files['test'] = test_prots_csv
dataset = resplit(dataset, split_files=split_files)
else:
assert test_split > 0, f"Invalid train/val/test split: {train_split}/{val_split}/{test_split}"
assert not pro_overlap, f"No support for overlapping proteins with k-folds rn."
if test_prots_csv is not None:
df = pd.read_csv(test_prots_csv)
test_prots = set(df['prot_id'].tolist())
if k_folds is None:
train_loader, val_loader, test_loader = train_val_test_split(dataset,
train_split=train_split, val_split=val_split,
random_seed=random_seed, split_by_prot=not pro_overlap)
else:
test_prots = None
assert test_split > 0, f"Invalid train/val/test split: {train_split}/{val_split}/{test_split}"
assert not pro_overlap, f"No support for overlapping proteins with k-folds rn."
if test_prots_csv is not None:
df = pd.read_csv(test_prots_csv)
test_prots = set(df['prot_id'].tolist())
else:
test_prots = None

train_loader, val_loader, test_loader = balanced_kfold_split(dataset,
k_folds=k_folds, test_split=test_split, test_prots=test_prots,
random_seed=random_seed) # only non-overlapping splits for k-folds
train_loader, val_loader, test_loader = balanced_kfold_split(dataset,
k_folds=k_folds, test_split=test_split, test_prots=test_prots,
random_seed=random_seed) # only non-overlapping splits for k-folds

subset_names = ['train', 'val', 'test']
if pro_overlap:
subset_names = [s+'-overlap' for s in subset_names]

subset_names = ['train', 'val', 'test']
if pro_overlap:
subset_names = [s+'-overlap' for s in subset_names]

if test_split < 1: # for datasets that are purely for testing we skip this section
if k_folds is None:
dataset.save_subset(train_loader, subset_names[0])
dataset.save_subset(val_loader, subset_names[1])
else:
# loops through all k folds and saves as train1, train2, etc.
dataset.save_subset_folds(train_loader, subset_names[0])
dataset.save_subset_folds(val_loader, subset_names[1])

dataset.save_subset(test_loader, subset_names[2])
if test_split < 1: # for datasets that are purely for testing we skip this section
if k_folds is None:
dataset.save_subset(train_loader, subset_names[0])
dataset.save_subset(val_loader, subset_names[1])
else:
# loops through all k folds and saves as train1, train2, etc.
dataset.save_subset_folds(train_loader, subset_names[0])
dataset.save_subset_folds(val_loader, subset_names[1])

dataset.save_subset(test_loader, subset_names[2])

del dataset # free up memory
4 changes: 3 additions & 1 deletion src/train_test/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,9 @@ def resplit(dataset:str|BaseDataset, split_files:dict|str=None, **kwargs):
csv_files[split] = f'{split_files}/{split}/cleaned_XY.csv'
split_files = csv_files
print('Using split files from:', split_files)


assert 'test' in split_files, 'Missing test csv from split files.'

# Check if split files exist and are in the correct format
if split_files is None:
raise ValueError('split_files must be provided')
Expand Down

0 comments on commit b15e83d

Please sign in to comment.