Skip to content

Commit

Permalink
feat(resplit): extract csvs from "like_dataset" #112 #113
Browse files Browse the repository at this point in the history
  • Loading branch information
jyaacoub committed Jul 4, 2024
1 parent c47be94 commit ef0106c
Showing 1 changed file with 28 additions and 6 deletions.
34 changes: 28 additions & 6 deletions src/train_test/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,17 +283,39 @@ def balanced_kfold_split(dataset: BaseDataset |str,


@init_dataset_object(strict=True)
def resplit(dataset:str|BaseDataset, split_files:dict=None, **kwargs):
def resplit(dataset:str|BaseDataset, split_files:dict|str=None, **kwargs):
"""
- Takes as input the target dataset path or dataset object, and a dict defining the 6 splits for all 5 folds + 1 test set.
1.Takes as input the target dataset path or dataset object, and a dict defining the 6 splits for all 5 folds +
1 test set.
- Decorator will automatically convert the dataset path to a dataset object
- split files should be a dict to csv files, each containing the proteins for the splits, where the keys are:
- val0, val1, val2, val3, val4, test
- training sets will be built from the remaining proteins (i.e.: proteins not in any of the val/test sets)
- Deletes existing splits
- Builds new splits using Dataset.save_subset()
2.Deletes existing splits
3.Builds new splits using Dataset.save_subset()
Args:
dataset (str | BaseDataset): path to full dataset directory or dataset object
split_files (dict | str, optional): Dictionary of csvs for each of the n folds + the test set, where keys are
val0, val1, val2, val3, val4, test and the values are the path to the csvs with a "prot_id" column. OR path to
another dataset directory that you want to match in terms of dataset split where we extract the csvs from
Defaults to None.
Raises:
ValueError: no split_files provided
ValueError: split_files must contain 6 files for the 5 folds and test set
ValueError: split file does not exist
Returns:
BaseDataset: dataset object for "full" dataset
"""

if isinstance(split_files, str):
csv_files = {}
for split in ['test'] + [f'val{i}' for i in range(5)]:
csv_files[split] = f'{split_files}/{split}/cleaned_XY.csv'
split_files = csv_files
print('Using split files from:', split_files)

# Check if split files exist and are in the correct format
if split_files is None:
raise ValueError('split_files must be provided')
Expand All @@ -315,7 +337,7 @@ def resplit(dataset:str|BaseDataset, split_files:dict=None, **kwargs):
val_idxs = [i for i in range(len(dataset.df)) if dataset.df.iloc[i]['prot_id'] in prots]
dataset.save_subset(val_idxs, k)

# build training set from all proteins not in the val/test set
# Build training set from all proteins not in the val/test set
idxs = set(val_idxs + test_idxs)
train_idxs = [i for i in range(len(dataset.df)) if i not in idxs]
dataset.save_subset(train_idxs, k.replace('val', 'train'))
Expand Down

0 comments on commit ef0106c

Please sign in to comment.