fix(splitting): include train set csv for #115
Allows us to get exactly the same proteins as in another dataset's splits, e.g. the alphaflow dataset, which is limited since longer proteins are excluded.
jyaacoub committed Jul 9, 2024
1 parent a0e4405 commit bb6ff22
Showing 3 changed files with 36 additions and 13 deletions.
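
For context, a minimal sketch (directory path hypothetical) of the per-split CSV layout that resplit reads when split_files is given as a directory string; with use_train_set=True, the five train CSVs this commit starts reading are expected as well:

from pathlib import Path

split_dir = Path("/path/to/splits")  # hypothetical split-files directory
use_train_set = True

splits = ["test"] + [f"val{i}" for i in range(5)]
if use_train_set:
    splits += [f"train{i}" for i in range(5)]  # the CSVs added by this commit

for split in splits:
    # resplit looks for one cleaned_XY.csv per split subdirectory
    print(split_dir / split / "cleaned_XY.csv")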
14 changes: 12 additions & 2 deletions playground.py
@@ -1,3 +1,13 @@
# %%
from src.train_test.splitting import resplit
from src import cfg

db_p = lambda x: f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/nomsa_{x}_original_binary'

db = resplit(dataset=db_p('binary'), split_files=db_p('aflow'), use_train_set=True)



# %%
########################################################################
########################## VIOLIN PLOTTING #############################
@@ -24,13 +34,13 @@
# 'gvpl_esm_aflow': ('ESM', 'aflow', 'gvp', 'binary'),
}

df = prepare_df('/cluster/home/t122995uhn/projects/MutDTA/results/v113/model_media/model_stats.csv')
df = prepare_df('./results/v113/model_media/model_stats.csv')
fig, axes = fig_combined(df, datasets=['davis'], fig_callable=custom_fig,
models=models, metrics=['cindex', 'mse'],
fig_scale=(10,5), add_stats=True, title_postfix=" test set performance")
plt.xticks(rotation=45)

df = prepare_df('/cluster/home/t122995uhn/projects/MutDTA/results/v113/model_media/model_stats_val.csv')
df = prepare_df('./results/v113/model_media/model_stats_val.csv')
fig, axes = fig_combined(df, datasets=['davis'], fig_callable=custom_fig,
models=models, metrics=['cindex', 'mse'],
fig_scale=(10,5), add_stats=True, title_postfix=" validation set performance")
33 changes: 23 additions & 10 deletions src/train_test/splitting.py
@@ -282,7 +282,7 @@ def balanced_kfold_split(dataset: BaseDataset |str,


@init_dataset_object(strict=True)
def resplit(dataset:str|BaseDataset, split_files:dict|str=None, **kwargs):
def resplit(dataset:str|BaseDataset, split_files:dict|str=None, use_train_set=False, **kwargs):
"""
1. Takes as input the target dataset path or dataset object, and a dict defining the 6 splits for all 5 folds +
1 test set.
@@ -308,24 +308,36 @@ def resplit(dataset:str|BaseDataset, split_files:dict|str=None, **kwargs):
Returns:
BaseDataset: dataset object for "full" dataset
"""
key_names = ['val0', 'val1', 'val2', 'val3', 'val4', 'test']
if use_train_set: key_names += ['train0', 'train1', 'train2', 'train3', 'train4']

# getting split files from directory
if isinstance(split_files, str):
csv_files = {}
for split in ['test'] + [f'val{i}' for i in range(5)]:
for split in key_names:
csv_files[split] = f'{split_files}/{split}/cleaned_XY.csv'
split_files = csv_files
print('Using split files from:', split_files)

assert 'test' in split_files, 'Missing test csv from split files.'

##### Validate split files and if they exist #####
missing = []
for k in key_names:
if k not in split_files:
missing.append(k)
assert len(missing) == 0, f'Missing split files for: {missing}'

# Check if split files exist and are in the correct format
if split_files is None:
raise ValueError('split_files must be provided')
if len(split_files) != 6:
raise ValueError('split_files must contain 6 files for the 5 folds and test set')
num_files = 5*(1+use_train_set) + 1
if len(split_files) != num_files:
raise ValueError(f'split_files must contain {num_files} files for the 5 folds +1 for the test set.')

for f in split_files.values():
if not os.path.exists(f):
raise ValueError(f'{f} does not exist')

##### perform the resplitting #####
# Getting indices for each split based on db.df
split_files = split_files.copy()
test_prots = set(pd.read_csv(split_files['test'])['prot_id'])
@@ -339,10 +351,11 @@ def resplit(dataset:str|BaseDataset, split_files:dict|str=None, **kwargs):
val_idxs = [i for i in range(len(dataset.df)) if dataset.df.iloc[i]['prot_id'] in prots]
dataset.save_subset(val_idxs, k)

# Build training set from all proteins not in the val/test set
idxs = set(val_idxs + test_idxs)
train_idxs = [i for i in range(len(dataset.df)) if i not in idxs]
dataset.save_subset(train_idxs, k.replace('val', 'train'))
if not use_train_set:
# Build training set from all proteins not in the val/test set
idxs = set(val_idxs + test_idxs)
train_idxs = [i for i in range(len(dataset.df)) if i not in idxs]
dataset.save_subset(train_idxs, k.replace('val', 'train'))

return dataset
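
A minimal usage sketch of the updated resplit, assuming hypothetical paths; with use_train_set=True the validation above expects 11 CSVs (5 val + 5 train + 1 test, i.e. 5*(1+True) + 1):

from src.train_test.splitting import resplit

split_files = {
    **{f"val{i}": f"/path/to/splits/val{i}/cleaned_XY.csv" for i in range(5)},
    **{f"train{i}": f"/path/to/splits/train{i}/cleaned_XY.csv" for i in range(5)},
    "test": "/path/to/splits/test/cleaned_XY.csv",
}  # 11 entries, matching num_files = 5*(1+use_train_set) + 1

dataset = resplit(
    dataset="/path/to/DavisKibaDataset/davis/nomsa_binary_original_binary",
    split_files=split_files,
    use_train_set=True,
)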

2 changes: 1 addition & 1 deletion src/utils/config.py
@@ -87,7 +87,7 @@ class LIG_FEAT_OPT(StringEnum):
DATA_ROOT = os.path.abspath('../data/')

# Model save paths
issue_number = 113 # 113 is for unifying all splits for cross validation so that we are more confident
issue_number = 115 # 113 is for unifying all splits for cross validation so that we are more confident
# when comparing results that they were trained in the same manner.
RESULTS_PATH = os.path.abspath(f'results/v{issue_number}/')
MEDIA_SAVE_DIR = f'{RESULTS_PATH}/model_media/'
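
As a quick illustration (values copied from the diff; the resulting absolute path depends on the working directory), bumping issue_number re-roots all result paths under results/v115/:

import os

issue_number = 115
RESULTS_PATH = os.path.abspath(f'results/v{issue_number}/')
MEDIA_SAVE_DIR = f'{RESULTS_PATH}/model_media/'
print(MEDIA_SAVE_DIR)  # ends with results/v115/model_media/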
