Skip to content

Commit

Permalink
fix(splitting): created davis splits #113
Browse files Browse the repository at this point in the history
  • Loading branch information
jyaacoub committed Jul 4, 2024
1 parent c4c7741 commit 099c3a3
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 76 deletions.
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -222,4 +222,6 @@ results/model_checkpoints/*/*.model
results/model_media/*/train_log/*
results/model_media/*/train_set_pred/*
results/model_media/*/test_set_pred/*
results/model_media/test_set_pred
results/model_media/test_set_pred

splits/*.csv
111 changes: 37 additions & 74 deletions playground.py
Original file line number Diff line number Diff line change
@@ -1,92 +1,55 @@
#%% now based on this test set we can create the splits that will be used for all models
# 5-fold cross validation + test set
import pandas as pd
from src import cfg
from src.train_test.splitting import balanced_kfold_split
from src.utils.loader import Loader

test_df = pd.read_csv('/home/jean/projects/data/splits/davis_test_genes_oncoG.csv')
test_prots = set(test_df.prot_id)

db = Loader.load_dataset(f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/nomsa_binary_original_binary/full/')

#%%
train, val, test = balanced_kfold_split(db,
k_folds=5, test_split=0.1, val_split=0.1,
test_prots=test_prots, random_seed=0, verbose=True
)


#%%
db.save_subset_folds(train, 'train')
db.save_subset_folds(val, 'val')
db.save_subset(test, 'test')

#%%
# %%
import os
from src.train_test.splitting import resplit
from src import cfg

# Map each split name ('test', 'val0' ... 'val4') to the CSV file for that split.
csv_files = {}
for split in ['test'] + [f'val{i}' for i in range(5)]:
# NOTE(review): this path is overwritten by the assignment on the next line,
# so only the './splits/' location is actually checked and used.
csv_files[split] = f'{cfg.DATA_ROOT}/splits/davis/davis_{split}.csv'
csv_files[split] = f'./splits/davis_{split}.csv'
# Fail early, reporting the offending path, if a split file is missing.
assert os.path.exists(csv_files[split]), csv_files[split]

print(csv_files)

#%%
for d in os.listdir(f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/'):
if len(d.split('_')) < 4:
print('skipping:', d)
continue
print('resplitting:', d)
resplit(f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/{d}', split_files=csv_files)

db = resplit(f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/nomsa_binary_original_binary',
split_files=csv_files)

#%% Checking for overlap
import pandas as pd

# Define file paths
file_paths = {
'test': 'test/cleaned_XY.csv',
'val0': 'val0/cleaned_XY.csv',
'val1': 'val1/cleaned_XY.csv',
'val2': 'val2/cleaned_XY.csv',
'val3': 'val3/cleaned_XY.csv',
'val4': 'val4/cleaned_XY.csv',
'train0': 'train0/cleaned_XY.csv',
'train1': 'train1/cleaned_XY.csv',
'train2': 'train2/cleaned_XY.csv',
'train3': 'train3/cleaned_XY.csv',
'train4': 'train4/cleaned_XY.csv'
}
file_paths = {name: f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/nomsa_binary_original_binary/{path}' for name, path in file_paths.items()}

# Load CSV files into dataframes
dataframes = {name: pd.read_csv(path) for name, path in file_paths.items()}

# Function to check for overlap
def check_overlap(df1, df2, name1, name2):
    """Report whether two split dataframes share any ``prot_id`` values.

    Prints a one-line verdict and, when an overlap exists, one merged row
    per shared protein.

    Args:
        df1: First dataframe; must have a ``prot_id`` column.
        df2: Second dataframe; must have a ``prot_id`` column.
        name1: Label for ``df1`` used in the printed message.
        name2: Label for ``df2`` used in the printed message.

    Returns:
        True if at least one ``prot_id`` appears in both frames, else False.
    """
    # Collapse each frame to one row per protein before joining: an inner
    # merge on a repeated key yields rows(df1) * rows(df2) per shared key,
    # which floods the output when a protein has many entries in each split.
    unique1 = df1.drop_duplicates(subset='prot_id')
    unique2 = df2.drop_duplicates(subset='prot_id')
    overlap = unique1.merge(unique2, on='prot_id', how='inner')

    if overlap.empty:
        print(f'No overlap between {name1} and {name2}')
        return False

    print(f'Overlap found between {name1} and {name2}')
    print(overlap)
    return True

# Check for overlaps
# Test should not overlap with any other CSV
for name in file_paths:
if name != 'test':
check_overlap(dataframes['test'], dataframes[name], 'test', name)

# valX should not overlap with corresponding trainX
for i in range(5):
val_name = f'val{i}'
train_name = f'train{i}'
check_overlap(dataframes[val_name], dataframes[train_name], val_name, train_name)

# Note: Overlaps between val0 and train1, val1 and train2, etc. are allowed as they are different folds



#%%
from src.utils.loader import Loader
# For every dataset directory under davis/, print the row counts of its
# test/val0/train0 CSVs, their sum, and the row count of the full CSV.
# NOTE(review): indentation is not visible here, so which of the statements
# below run per directory versus once after the loop should be confirmed.
for d in os.listdir(f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/'):
# Directory names with fewer than four '_'-separated parts are skipped.
if len(d.split('_')) < 4:
print('skipping:', d)
continue
# Define file paths
file_paths = {
'test': 'test/cleaned_XY.csv',
'val0': 'val0/cleaned_XY.csv',
'train0': 'train0/cleaned_XY.csv',
}
# Expand the relative paths to absolute paths inside this dataset directory.
file_paths = {name: f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/{d}/{path}' for name, path in file_paths.items()}
# Running total of rows across the three subset CSVs.
count = 0
print(f"\n{'-'*10}{d}{'-'*10}")
for k, v in file_paths.items():
df = pd.read_csv(v)
print(f"{k:>12}: {len(df):>6d}")
count += len(df)

print(f' = {count:>6d}')

# `df` is first bound to the path string, then rebound to the loaded frame.
df = f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/{d}/full/cleaned_XY.csv'
df = pd.read_csv(df)
# print(f' = {count:>6d}')
# Row count of the full CSV, printed for comparison with the subset total above.
print(f'Dataset Size: {len(df):>6d}')



db_train = Loader.load_dataset(f'{cfg.DATA_ROOT}/DavisKibaDataset/davis/nomsa_binary_original_binary/train0/')

# %%
########################################################################
Expand Down
3 changes: 2 additions & 1 deletion src/train_test/splitting.py
Original file line number Diff line number Diff line change
Expand Up @@ -325,7 +325,8 @@ def resplit(dataset:str|BaseDataset, split_files:dict|str=None, **kwargs):
if not os.path.exists(f):
raise ValueError(f'{f} does not exist')

# Getting indices for each split based on db.df
# Getting indices for each split based on db.df
split_files = split_files.copy()
test_prots = set(pd.read_csv(split_files['test'])['prot_id'])
test_idxs = [i for i in range(len(dataset.df)) if dataset.df.iloc[i]['prot_id'] in test_prots]
dataset.save_subset(test_idxs, 'test')
Expand Down

0 comments on commit 099c3a3

Please sign in to comment.