feat(resplit): resplit stub for #113
jyaacoub committed Jul 3, 2024
1 parent f442919 commit 9ac093f
Showing 2 changed files with 106 additions and 3 deletions.
17 changes: 16 additions & 1 deletion src/train_test/splitting.py
@@ -9,6 +9,7 @@

from src.models.utils import BaseModel
from src.data_prep.datasets import BaseDataset
from src.utils.loader import init_dataset_object

# Creating data indices for training and validation splits:
def train_val_test_split(dataset: BaseDataset,
@@ -272,4 +273,18 @@ def balanced_kfold_split(dataset: BaseDataset,
    print(f'Dataset size: {dataset_size}')
    assert te_count > 0, 'Test set is empty'

    return train_loaders, val_loaders, test_loader


@init_dataset_object(strict=True)
def resplit(dataset: str | BaseDataset, split_files: list = None, **kwargs):
    """
    Stub for #113.
    - Takes the target dataset path (or an already-built dataset object) and a
      list defining the six splits: five validation folds plus one test set.
    - Deletes the existing splits.
    - Builds the new splits using Dataset.save_subset().
    """
    print("RESPLIT")
    print(kwargs)

    return dataset
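
For orientation, a hypothetical call to this stub might look like the sketch below; the dataset path and split-file names are illustrative placeholders, not artifacts from the repo:

    # Hypothetical usage -- the path and file names are placeholders.
    # Thanks to @init_dataset_object, the first argument can be a path string
    # or an already-built BaseDataset object.
    from src.train_test.splitting import resplit

    new_dataset = resplit(
        '/path/to/PDBbindDataset/nomsa_binary_original_binary/full/',
        split_files=[f'val{i}.csv' for i in range(5)] + ['test.csv'],  # 5 folds + 1 test set
    )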

92 changes: 90 additions & 2 deletions src/utils/loader.py
@@ -1,4 +1,4 @@
# TODO: create a trainer class for modularity
import os
from functools import wraps
from typing import Iterable
from torch.utils.data.distributed import DistributedSampler
Expand All @@ -10,7 +10,7 @@
from src.models.prior_work import DGraphDTA, DGraphDTAImproved
from src.models.ring3 import Ring3DTA
from src.models.gvp_models import GVPModel, GVPLigand_DGPro, GVPLigand_RNG3, GVPL_ESM
from src.data_prep.datasets import PDBbindDataset, DavisKibaDataset, PlatinumDataset, BaseDataset
from src.utils import config as cfg # sets up os env for HF

def validate_args(valid_options):
@@ -317,3 +317,91 @@ def load_distributed_DataLoaders(num_replicas:int, rank:int, seed:int, data:str,

    return loaders




##################################################
########## Extra related helpful methods #########
def parse_db_kwargs(db_path):
    """
    Parses dataset parameters out of a path string to the db you want to load up.
    If the subset folder is not included in the path, the subset defaults to 'full'.
    """
    kwargs = {
        'data': None,
        'subset': 'full',
    }
    # get db class/type
    db_path_s = [x for x in db_path.split('/') if x]
    if 'PDBbindDataset' in db_path_s:
        idx_cls = db_path_s.index('PDBbindDataset')
        kwargs['data'] = cfg.DATA_OPT.PDBbind
        if len(db_path_s) > idx_cls+2: # +2 to skip over db_params
            kwargs['subset'] = db_path_s[idx_cls+2]
        # drop the subset (and anything after it) from the path string
        db_path = '/'.join(db_path_s[:idx_cls+2])
    elif 'DavisKibaDataset' in db_path_s:
        idx_cls = db_path_s.index('DavisKibaDataset')
        kwargs['data'] = cfg.DATA_OPT.davis if db_path_s[idx_cls+1] == 'davis' else cfg.DATA_OPT.kiba
        if len(db_path_s) > idx_cls+3:
            kwargs['subset'] = db_path_s[idx_cls+3]
        db_path = '/'.join(db_path_s[:idx_cls+3])
    else:
        raise ValueError(f"Invalid path string, couldn't find db class info - {db_path_s}")

    # get db parameters (underscore-separated tokens in the directory name):
    kwargs_p = {
        'pro_feature': cfg.PRO_FEAT_OPT,
        'edge_opt': cfg.PRO_EDGE_OPT,
        'ligand_feature': cfg.LIG_FEAT_OPT,
        'ligand_edge': cfg.LIG_EDGE_OPT,
    }
    db_params = os.path.basename(db_path.strip('/')).split('_')
    for k, params in kwargs_p.items():
        if not db_params:
            raise ValueError(f'Ran out of tokens while parsing option for {k}')
        # Some options contain an underscore themselves, so try a two-token
        # match first before falling back to a single token.
        double = "_".join(db_params[:2])

        if double in params:
            kwargs_p[k] = double
            db_params = db_params[2:]
        elif db_params[0] in params:
            kwargs_p[k] = db_params[0]
            db_params = db_params[1:]
        else:
            raise ValueError(f'Invalid option, did not find {double} or {db_params[0]} in {params}')
    assert len(db_params) == 0, f"still some unparsed params - {db_params}"

    return {**kwargs, **kwargs_p}
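
As a sanity check, a hypothetical round trip through parse_db_kwargs is sketched below; the path and the four option tokens are illustrative and assume they are valid members of the corresponding cfg option enums:

    # Hypothetical example -- assumes 'nomsa', 'binary', and 'original' exist in
    # cfg.PRO_FEAT_OPT / cfg.PRO_EDGE_OPT / cfg.LIG_FEAT_OPT / cfg.LIG_EDGE_OPT.
    parsed = parse_db_kwargs('/data/DavisKibaDataset/davis/nomsa_binary_original_binary/full/')
    # Expected result:
    # {'data': cfg.DATA_OPT.davis, 'subset': 'full',
    #  'pro_feature': 'nomsa', 'edge_opt': 'binary',
    #  'ligand_feature': 'original', 'ligand_edge': 'binary'}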

# Decorator that lets the wrapped function take either the path to the dataset
# directory or an already-built dataset object as its dataset argument.
def init_dataset_object(strict=True):
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # Get dataset argument from args or kwargs
            dataset = kwargs.get('dataset', args[0] if args else None)

            # Check if dataset is a string (file path) or an actual DB object
            if isinstance(dataset, str):
                if strict and not os.path.exists(dataset):
                    raise FileNotFoundError(f'Dataset does not exist - {dataset}')

                # Parse the path and build the dataset; keep the parsed options
                # in a separate dict so the caller's kwargs are not clobbered
                db_kwargs = parse_db_kwargs(dataset)
                print('Loading dataset with', db_kwargs)
                built = Loader.load_dataset(**db_kwargs)
            elif isinstance(dataset, BaseDataset):
                built = dataset
            elif dataset is None:
                raise ValueError('Missing dataset in args/kwargs')
            else:
                raise TypeError('Invalid format for dataset')

            # Swap the built dataset object into args/kwargs
            if 'dataset' in kwargs:
                kwargs['dataset'] = built
            else:
                args = (built, *args[1:])

            # Return the function call output
            return func(*args, **kwargs)
        return wrapper
    return decorator
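
A minimal usage sketch for the decorator, assuming Loader.load_dataset can resolve the parsed options; the function name and paths below are hypothetical:

    @init_dataset_object(strict=False)
    def summarize(dataset, **kwargs):
        # by the time we get here, 'dataset' is a built BaseDataset object
        return len(dataset)

    # Both calls reach summarize() with the same kind of object:
    # summarize('/data/PDBbindDataset/nomsa_binary_original_binary/full/')  # built from path
    # summarize(dataset=already_built_dataset)                              # passed through as-is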
