diff --git a/benchmark/run_param.py b/benchmark/run_param.py
index f0b9eeb..909f455 100644
--- a/benchmark/run_param.py
+++ b/benchmark/run_param.py
@@ -8,6 +8,10 @@
 import uuid
 from copy import deepcopy
 
+from pyg_spectral.nn import get_model_regi, get_conv_regi
+from pyg_spectral.nn.parse_args import compose_param
+from pyg_spectral.utils import CallableDict
+
 from trainer import (
     SingleGraphLoader_Trial,
     ModelLoader_Trial,
@@ -34,75 +38,34 @@ def __init__(self, data_loader, model_loader, args, res_logger = None):
         self.fmt_logger = {}
         self.metric = None
 
-        self.data, self.model, self.trn_cls, self.trn = None, None, None, None
+        self.data, self.model, self.trn = None, None, None
+        self.trn_cls = {
+            TrnFullbatch: TrnFullbatch_Trial,
+            TrnMinibatch: TrnMinibatch_Trial,
+        }[self.model_loader.get_trn(args)]
 
     def _get_suggest(self, trial, key):
-        list2str = lambda x: ','.join(map(str, x))
-        nofmt = lambda x: x
-        # >>>>>>>>>>
-        theta_dct = {
-            # "impulse": (lambda x, _: x[0], (self.args.num_hops,), {}),
-            "ones": (trial.suggest_float, (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
-            "impulse": (trial.suggest_float, (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
-            "appr": (trial.suggest_float, (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
-            "nappr": (trial.suggest_float, (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
-            "mono": (trial.suggest_float, (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
-            "hk": (trial.suggest_float, (1e-2, 10), {'log': True}, lambda x: float(f'{x:.3e}')),
-            "gaussian": (trial.suggest_float, (1e-2, 10), {'log': True}, lambda x: float(f'{x:.3e}')),
-        }
-        suggest_dct = {
-            # critical
-            'num_hops': (trial.suggest_int, (2, 30), {'step': 2}, nofmt),
-            'in_layers': (trial.suggest_int, (1, 3), {}, nofmt),
-            'out_layers': (trial.suggest_int, (1, 3), {}, nofmt),
-            'hidden_channels': (trial.suggest_categorical, ([16, 32, 64, 128, 256],), {}, nofmt),
-            'combine': (trial.suggest_categorical, (["sum", "sum_weighted", "cat"],), {}, nofmt),
-            # secondary
-            'theta_param': theta_dct.get(self.args.theta_scheme, None),
-            'normg': (trial.suggest_float, (0.0, 1.0), {'step': 0.05}, lambda x: round(x, 2)),
-            'dropout_lin': (trial.suggest_float, (0.0, 1.0), {'step': 0.1}, lambda x: round(x, 2)),
-            'dropout_conv': (trial.suggest_float, (0.0, 1.0), {'step': 0.1}, lambda x: round(x, 2)),
-            'lr_lin': (trial.suggest_float, (1e-5, 5e-1), {'log': True}, lambda x: float(f'{x:.3e}')),
-            'lr_conv': (trial.suggest_float, (1e-5, 5e-1), {'log': True}, lambda x: float(f'{x:.3e}')),
-            'wd_lin': (trial.suggest_float, (1e-7, 1e-3), {'log': True}, lambda x: float(f'{x:.3e}')),
-            'wd_conv': (trial.suggest_float, (1e-7, 1e-3), {'log': True}, lambda x: float(f'{x:.3e}')),
-            'alpha': (trial.suggest_float, (0.01, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
-            'beta': (trial.suggest_float, (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
-        }
-
-        # Model/conv-specific
-        # if self.args.model in ['Iterative']:
-        #     suggest_dct['in_layers'][1] = (1, 3)
-        #     suggest_dct['out_layers'][1] = (1, 3)
-        # <<<<<<<<<<
-
-        if 'Compose' in self.args.model:
-            convs = self.args.conv.split(',')
-            if key == 'theta_param':
-                schemes = self.args.theta_scheme.split(',')
-                lst = []
-                for i, c in enumerate(convs):
-                    func, fargs, fkwargs, fmt = theta_dct.get(schemes[i], None)
-                    lst.append(func(key+'-'+str(i), *fargs, **fkwargs))
-                return lst, fmt
-            elif key == 'beta':
-                func, fargs, fkwargs, fmt = suggest_dct[key]
-                beta_c = {
-                    'AdjiConv': [(0.0, 1.0), (0.0, 1.0)],  # FAGNN
-                    'AdjSkipConv': [(0.0, 1.0), (0.0, 1.0)],  # FAGNN
-                    'Adji2Conv': [(1.0, 2.0), (0.0, 1.0)],  # G2CN
-                    'AdjSkip2Conv': [(1.0, 2.0), (0.0, 1.0)],  # G2CN
-                    'AdjDiffConv': [(0.0, 1.0), (-1.0, 0.0)],  # GNN-LF/HF
-                }
-                lst = [func(key+'-'+str(i), *beta_i, **fkwargs) for i, beta_i in enumerate(beta_c[convs[0]])]
-                # return list2str(lst), str
-                return lst, fmt
-            else:
-                func, fargs, fkwargs, fmt = suggest_dct[key]
-                return func(key, *fargs, **fkwargs), fmt
-        else:
-            func, fargs, fkwargs, fmt = suggest_dct[key]
-            return func(key, *fargs, **fkwargs), fmt
+        def parse_param(val):
+            if isinstance(val, list):
+                fmt = val[0][-1]
+                val = [getattr(trial, 'suggest_'+func)(key+'-'+str(i), *fargs, **fkwargs)
+                       for i, (func, fargs, fkwargs, _) in enumerate(val)]
+                return val, fmt
+            func, fargs, fkwargs, fmt = val
+            return getattr(trial, 'suggest_'+func)(key, *fargs, **fkwargs), fmt
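+        # Registry values follow the ParamTuple layout:
+        #   (suggest type, range args, kwargs, fmt), e.g.
+        #   ('float', (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2))
+        # resolves to trial.suggest_float(key, 0.0, 1.0, step=0.01), replacing
+        # the inlined trial.suggest_* tables removed above; a list of tuples
+        # yields one suggestion per channel, named key+'-'+i.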
+
+        # Alias compose models
+        if (self.args.model in compose_param and
+            self.model_loader.conv_repr in compose_param[self.args.model] and
+            key in compose_param[self.args.model][self.model_loader.conv_repr]):
+            return parse_param(compose_param[self.args.model][self.model_loader.conv_repr](key, self.args))
+
+        single_param = SingleGraphLoader_Trial.param | ModelLoader_Trial.param | self.trn_cls.param
+        single_param = CallableDict(single_param)
+        single_param |= get_model_regi(self.args.model, 'param')
+        if key in single_param:
+            return parse_param(single_param(key, self.args))
+
+        return parse_param(get_conv_regi(self.args.conv, 'param')(key, self.args))
 
     def __call__(self, trial):
         args = deepcopy(self.args)
@@ -115,18 +78,15 @@ def __call__(self, trial):
 
         if self.data is None:
             self.data = self.data_loader.get(args)
-            self.model, trn_cls = self.model_loader.get(args)
-            self.trn_cls = {
-                TrnFullbatch: TrnFullbatch_Trial,
-                TrnMinibatch: TrnMinibatch_Trial,
-            }[trn_cls]
+            self.model, _ = self.model_loader.get(args)
 
-            for key in ['in_channels', 'out_channels', 'metric', 'multi', 'criterion']:
+            for key in SingleGraphLoader_Trial.args_out + ModelLoader_Trial.args_out:
                 self.args.__dict__[key] = args.__dict__[key]
             self.metric = args.metric
         else:
             self.data = self.data_loader.update(args, self.data)
             self.model = self.model_loader.update(args, self.model)
+
         res_logger = deepcopy(self.res_logger)
         for key in self.args.param:
             val = args.__dict__[key]
@@ -220,7 +180,7 @@ def main(args):
         best_params = {k: trn.fmt_logger[k](v) for k, v in study.best_params.items()}
         save_args(args.logpath, best_params)
         axes = optuna.visualization.matplotlib.plot_parallel_coordinate(
-            study, params=best_params.keys())
+            study, params=study.best_params.keys())
         axes.get_figure().savefig(args.logpath.joinpath('parallel_coordinate.png'))
 
     clear_logger(logger)
diff --git a/benchmark/trainer/base.py b/benchmark/trainer/base.py
index abfc2a6..cab14dd 100755
--- a/benchmark/trainer/base.py
+++ b/benchmark/trainer/base.py
@@ -50,6 +50,12 @@ class TrnBase(object):
         run: Run the training process.
     """
     name: str
+    param = {
+        'lr_lin': ('float', (1e-5, 5e-1), {'log': True}, lambda x: float(f'{x:.3e}')),
+        'lr_conv': ('float', (1e-5, 5e-1), {'log': True}, lambda x: float(f'{x:.3e}')),
+        'wd_lin': ('float', (1e-7, 1e-3), {'log': True}, lambda x: float(f'{x:.3e}')),
+        'wd_conv': ('float', (1e-7, 1e-3), {'log': True}, lambda x: float(f'{x:.3e}')),
+    }
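+    # ParamTuple entries merged into the tuning search space by _get_suggest
+    # in benchmark/run_param.py (via self.trn_cls.param); shared by all trainers.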
""" name: str + param = { + 'lr_lin': ('float', (1e-5, 5e-1), {'log': True}, lambda x: float(f'{x:.3e}')), + 'lr_conv': ('float', (1e-5, 5e-1), {'log': True}, lambda x: float(f'{x:.3e}')), + 'wd_lin': ('float', (1e-7, 1e-3), {'log': True}, lambda x: float(f'{x:.3e}')), + 'wd_conv': ('float', (1e-7, 1e-3), {'log': True}, lambda x: float(f'{x:.3e}')), + } def __init__(self, model: nn.Module, diff --git a/benchmark/trainer/fullbatch.py b/benchmark/trainer/fullbatch.py index a3dc667..f1682f0 100755 --- a/benchmark/trainer/fullbatch.py +++ b/benchmark/trainer/fullbatch.py @@ -194,3 +194,4 @@ def run(self) -> ResLogger: def update(self, *args, **kwargs): self.__init__(*args, **kwargs) + return self diff --git a/benchmark/trainer/load_data.py b/benchmark/trainer/load_data.py index fefcc05..8b0373e 100755 --- a/benchmark/trainer/load_data.py +++ b/benchmark/trainer/load_data.py @@ -32,6 +32,11 @@ class SingleGraphLoader(object): * args.data_split (str): Index of dataset split. res_logger: Logger for results. """ + args_out = ['in_channels', 'out_channels', 'multi', 'metric'] + param = { + 'normg': ('float', (0.0, 1.0), {'step': 0.05}, lambda x: round(x, 2)), + } + def __init__(self, args: Namespace, res_logger: ResLogger = None) -> None: # Assigning dataset identity. self.seed = args.seed diff --git a/benchmark/trainer/load_model.py b/benchmark/trainer/load_model.py index 9340d03..f5b757f 100755 --- a/benchmark/trainer/load_model.py +++ b/benchmark/trainer/load_model.py @@ -26,6 +26,9 @@ class ModelLoader(object): * args.conv (str): Convolution layer name. res_logger: Logger for results. """ + args_out = ['criterion'] + param = {} + def __init__(self, args: Namespace, res_logger: ResLogger = None) -> None: r"""Assigning model identity. """ @@ -52,11 +55,17 @@ def get_name(args: Namespace) -> Tuple[str]: """ return get_nn_name(args.model, args.conv, args) - def _resolve_import(self, args: Namespace) -> Tuple[str, str, dict, TrnBase]: - class_name = self.model - module_name = get_model_regi(self.model, 'module', args) - kwargs = set_pargs(self.model, self.conv, args) - trn = { + @staticmethod + def get_trn(args: Namespace) -> TrnBase: + r"""Get trainer class from model name. + + Args: + args: Configuration arguments. + Returns: + trn (TrnBase): Trainer class. + """ + model_repr, _ = get_nn_name(args.model, args.conv, args) + return { 'DecoupledFixed': TrnFullbatch, 'DecoupledVar': TrnFullbatch, 'Iterative': TrnFullbatch, @@ -64,12 +73,18 @@ def _resolve_import(self, args: Namespace) -> Tuple[str, str, dict, TrnBase]: 'PrecomputedVar': TrnMinibatch, 'PrecomputedFixed': TrnMinibatch, 'CppPrecFixed': TrnMinibatch, - }[self.model_repr] + }[model_repr] + + def _resolve_import(self, args: Namespace) -> Tuple[str, str, dict]: + class_name = self.model + module_name = get_model_regi(self.model, 'module', args) + kwargs = set_pargs(self.model, self.conv, args) # >>>>>>>>>> if module_name == 'torch_geometric.nn.models': args.criterion = 'BCEWithLogitsLoss' if args.out_channels == 1 else 'CrossEntropyLoss' + del kwargs['conv'] kwargs.setdefault('num_layers', kwargs.pop('num_hops')) kwargs.setdefault('dropout', kwargs.pop('dropout_lin')) @@ -86,7 +101,7 @@ def _resolve_import(self, args: Namespace) -> Tuple[str, str, dict, TrnBase]: kwargs['num_hops'] = int(kwargs['num_hops'] / 2) # Parse model args # <<<<<<<<<< - return class_name, module_name, kwargs, trn + return class_name, module_name, kwargs def get(self, args: Namespace) -> Tuple[nn.Module, TrnBase]: r"""Load model with specified arguments. 
+
+    def _resolve_import(self, args: Namespace) -> Tuple[str, str, dict]:
+        class_name = self.model
+        module_name = get_model_regi(self.model, 'module', args)
+        kwargs = set_pargs(self.model, self.conv, args)
 
         # >>>>>>>>>>
         if module_name == 'torch_geometric.nn.models':
             args.criterion = 'BCEWithLogitsLoss' if args.out_channels == 1 else 'CrossEntropyLoss'
+            del kwargs['conv']
             kwargs.setdefault('num_layers', kwargs.pop('num_hops'))
             kwargs.setdefault('dropout', kwargs.pop('dropout_lin'))
@@ -86,7 +101,7 @@ def _resolve_import(self, args: Namespace) -> Tuple[str, str, dict, TrnBase]:
             kwargs['num_hops'] = int(kwargs['num_hops'] / 2)  # Parse model args
         # <<<<<<<<<<
 
-        return class_name, module_name, kwargs, trn
+        return class_name, module_name, kwargs
 
     def get(self, args: Namespace) -> Tuple[nn.Module, TrnBase]:
         r"""Load model with specified arguments.
@@ -99,10 +114,13 @@ def get(self, args: Namespace) -> Tuple[nn.Module, TrnBase]:
             args.out_channels (int): Number of output classes.
             args.hidden_channels (int): Number of hidden units.
             args.dropout_[lin/conv] (float): Dropout rate for linear/conv.
+        Updates:
+            args.criterion (str): Loss function name.
         """
         self.logger.debug('-'*20 + f" Loading model: {self} " + '-'*20)
 
-        class_name, module_name, kwargs, trn = self._resolve_import(args)
+        trn = self.get_trn(args)
+        class_name, module_name, kwargs = self._resolve_import(args)
         model = load_import(class_name, module_name)(**kwargs)
         if hasattr(model, 'reset_parameters'):
             model.reset_parameters()
@@ -129,7 +147,8 @@ def get(self, args: Namespace) -> Tuple[nn.Module, TrnBase]:
         self.signature_lst = ['num_hops', 'in_layers', 'out_layers', 'hidden_channels', 'dropout_lin', 'dropout_conv']
         self.signature = {key: getattr(args, key) for key in self.signature_lst}
 
-        class_name, module_name, kwargs, trn = self._resolve_import(args)
+        trn = self.get_trn(args)
+        class_name, module_name, kwargs = self._resolve_import(args)
         model = load_import(class_name, module_name)(**kwargs)
         if hasattr(model, 'reset_parameters'):
             model.reset_parameters()
diff --git a/benchmark/trainer/minibatch.py b/benchmark/trainer/minibatch.py
index 54c43ab..45e41b5 100755
--- a/benchmark/trainer/minibatch.py
+++ b/benchmark/trainer/minibatch.py
@@ -272,3 +272,4 @@ def update(self,
         if not hasattr(self, 'signature') or self.signature != signature:
             self.signature = signature
             self.embed = None
+        return self
diff --git a/benchmark/utils/logger.py b/benchmark/utils/logger.py
index 437087f..216c5d4 100755
--- a/benchmark/utils/logger.py
+++ b/benchmark/utils/logger.py
@@ -25,6 +25,7 @@
 warnings.filterwarnings('ignore', '.*No positive samples found in target.*')
 warnings.filterwarnings('ignore', '.*No negative samples found in target.*')
 warnings.filterwarnings('ignore', '.*is( an)? experimental.*')
+warnings.filterwarnings('ignore', '.*Attempting to set identical low and high ylims.*')
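+# Assumption: matplotlib emits this warning when optuna's parallel-coordinate
+# plot (see main in benchmark/run_param.py) receives a parameter whose sampled
+# values are all identical.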
 
 
 def setup_logger(logpath: Union[Path, str] = LOGPATH,
diff --git a/pyg_spectral/nn/__init__.py b/pyg_spectral/nn/__init__.py
index 93e990e..5478bff 100755
--- a/pyg_spectral/nn/__init__.py
+++ b/pyg_spectral/nn/__init__.py
@@ -1,4 +1,4 @@
 from .parse_args import (
     get_model_regi, get_conv_regi,
-    get_nn_name, set_pargs, get_param
+    get_nn_name, set_pargs
 )
diff --git a/pyg_spectral/nn/models/decoupled.py b/pyg_spectral/nn/models/decoupled.py
index 6953278..0ed52da 100644
--- a/pyg_spectral/nn/models/decoupled.py
+++ b/pyg_spectral/nn/models/decoupled.py
@@ -9,11 +9,10 @@
 from torch_geometric.nn.inits import reset
 
 from pyg_spectral.nn.models.base_nn import BaseNN, BaseNNCompose
-from pyg_spectral.utils import load_import
+from pyg_spectral.utils import load_import, CallableDict
 
 
-theta_param = {
-    # "impulse": (lambda x, _: x[0], (self.args.num_hops,), {}),
+theta_param = CallableDict({
     "ones": ('float', (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
     "impulse": ('float', (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
     "appr": ('float', (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
@@ -22,7 +22,7 @@
     "hk": ('float', (1e-2, 10), {'log': True}, lambda x: float(f'{x:.2e}')),
     "gaussian": ('float', (1e-2, 10), {'log': True}, lambda x: float(f'{x:.2e}')),
     "log": ('float', (1e-2, 10), {'log': True}, lambda x: float(f'{x:.2e}')),
-}
+})
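+# Each scheme maps to a ParamTuple for the tuner: e.g. theta_param('appr')
+# returns ('float', (0.0, 1.0), {'step': 0.01}, <fmt>), and a comma-joined key
+# such as 'appr,mono' falls back to a list with one tuple per scheme.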
 
 
 def gen_theta(num_hops: int, scheme: str, param: Union[float, List[float]] = None) -> Tensor:
@@ -147,7 +147,7 @@ class DecoupledFixed(BaseNN):
     name = 'DecoupledFixed'
     conv_name = lambda x, args: '-'.join([x, args.theta_scheme])
     pargs = ['theta_scheme', 'theta_param']
-    param = {'theta_param': lambda x: theta_param.get(x, None)}
+    param = {'theta_param': lambda args: theta_param(args.theta_scheme)}
 
     def init_conv(self,
                   conv: str,
@@ -192,7 +192,7 @@ class DecoupledVar(BaseNN):
     """
     name = 'DecoupledVar'
     pargs = ['theta_scheme', 'theta_param']
-    param = {'theta_param': lambda x: theta_param.get(x, None)}
+    param = {'theta_param': lambda args: theta_param(args.theta_scheme)}
 
     def init_conv(self,
                   conv: str,
@@ -246,7 +246,7 @@ class DecoupledFixedCompose(BaseNNCompose):
     name = 'DecoupledFixed'
     conv_name = lambda x, args: '-'.join([x, args.theta_scheme])
     pargs = ['theta_scheme', 'theta_param']
-    param = {'theta_param': lambda x: theta_param.get(x, None)}
+    param = {'theta_param': lambda args: theta_param(args.theta_scheme)}
 
     def init_conv(self,
                   conv: str,
@@ -305,7 +305,7 @@ class DecoupledVarCompose(BaseNNCompose):
     """
     name = 'DecoupledVar'
    pargs = ['theta_scheme', 'theta_param']
-    param = {'theta_param': lambda x: theta_param.get(x, None)}
+    param = {'theta_param': lambda args: theta_param(args.theta_scheme)}
 
     def init_conv(self,
                   conv: str,
diff --git a/pyg_spectral/nn/parse_args.py b/pyg_spectral/nn/parse_args.py
index 34c1165..b49d220 100644
--- a/pyg_spectral/nn/parse_args.py
+++ b/pyg_spectral/nn/parse_args.py
@@ -1,8 +1,7 @@
-from functools import wraps
-
 from .conv.base_mp import BaseMP
 from .models.base_nn import BaseNN
 from .models_pyg import model_regi_pyg, conv_regi_pyg
+from pyg_spectral.utils import CallableDict
 
 
 def update_regi(regi, new_regi):
@@ -16,6 +15,23 @@ def update_regi(regi, new_regi):
 model_regi = BaseNN.register_classes()
 model_regi = update_regi(model_regi, model_regi_pyg)
 
+conv_regi = CallableDict.to_subcallableVal(conv_regi, ['pargs_default', 'param'])
+r'''Fields:
+    * name (CallableDict[str, str]): Conv class logging path name.
+    * pargs (CallableDict[str, List[str]]): Conv arguments from argparse.
+    * pargs_default (Dict[str, CallableDict[str, Any]]): Default values for conv arguments. Not recommended.
+    * param (Dict[str, CallableDict[str, ParamTuple]]): Conv parameters to tune.
+'''
+model_regi = CallableDict.to_subcallableVal(model_regi, ['pargs_default', 'param'])
+r'''Fields:
+    * name (CallableDict[str, str]): Model class logging path name.
+    * conv_name (CallableDict[str, Callable[[str, Any], str]]): Wrap conv logging path name.
+    * module (CallableDict[str, str]): Module for importing the model.
+    * pargs (CallableDict[str, List[str]]): Model arguments from argparse.
+    * pargs_default (Dict[str, CallableDict[str, Any]]): Default values for model arguments. Not recommended.
+    * param (Dict[str, CallableDict[str, ParamTuple]]): Model parameters to tune.
+'''
+
 full_pargs = set(v for pargs in conv_regi['pargs'].values() for v in pargs)
 full_pargs.update(v for pargs in model_regi['pargs'].values() for v in pargs)
@@ -28,50 +44,34 @@
     'DecoupledFixedCompose': {
         'AdjiConv,AdjiConv-ones,ones': 'FAGNN',
         'Adji2Conv,Adji2Conv-gaussian,gaussian': 'G2CN',
-        'AdjDiffConv,AdjDiffConv-appr,appr': 'GNN-LFHF',},
+        'AdjDiffConv,AdjDiffConv-appr,appr': 'GNN_LFHF',},
     'DecoupledVarCompose': {
         'AdjConv,ChebConv,BernConv': 'FiGURe',},
     'PrecomputedFixedCompose': {
         'AdjSkipConv,AdjSkipConv-ones,ones': 'FAGNN',
         'AdjSkip2Conv,AdjSkip2Conv-gaussian,gaussian': 'G2CN',
-        'AdjDiffConv,AdjDiffConv-appr,appr': 'GNN-LFHF',},
+        'AdjDiffConv,AdjDiffConv-appr,appr': 'GNN_LFHF',},
     'PrecomputedVarCompose': {
         'AdjConv,ChebConv,BernConv': 'FiGURe',}
 }
-
-
-def resolve_func(nargs=1):
-    """Enable calling the return function with additional arguments.
-
-    Args:
-        nargs: The number of arguments to pass to the decorated function. Arguments
-            beyond this number will be passed to the return function.
-    Examples:
-        ```python
-        @resolve_func(1)
-        def foo(bar):
-            return bar
-        foo(1)  # 1
-        foo(lambda x: x+1, 2)  # 3
-        ```
-    """
-    def decorator(func):
-        @wraps(func)
-        def wrapper(*inputs):
-            if len(inputs) <= nargs:
-                return func(*inputs[:nargs])
-            else:
-                ret = func(*inputs[:nargs])
-                if callable(ret):
-                    return ret(*inputs[nargs:])
-                return ret
-        return wrapper
-    return decorator
-
-
-@resolve_func(2)
-def get_dct(dct: dict, k: str) -> str:
-    return dct[k]
+compose_param = {
+    'DecoupledFixedCompose': {
+        'G2CN': CallableDict({
+            'beta': [('float', (1.00, 2.00), {'step': 0.01}, lambda x: round(x, 2)),
+                     ('float', (0.01, 1.00), {'step': 0.01}, lambda x: round(x, 2))],}),
+        'GNN_LFHF': CallableDict({
+            'beta': [('float', ( 0.01, 1.00), {'step': 0.01}, lambda x: round(x, 2)),
+                     ('float', (-1.00, -0.01), {'step': 0.01}, lambda x: round(x, 2))],}),
+    },
+    'PrecomputedFixedCompose': {
+        'G2CN': CallableDict({
+            'beta': [('float', (1.00, 2.00), {'step': 0.01}, lambda x: round(x, 2)),
+                     ('float', (0.01, 1.00), {'step': 0.01}, lambda x: round(x, 2))],}),
+        'GNN_LFHF': CallableDict({
+            'beta': [('float', ( 0.01, 1.00), {'step': 0.01}, lambda x: round(x, 2)),
+                     ('float', (-1.00, -0.01), {'step': 0.01}, lambda x: round(x, 2))],}),
+    },
+}
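+# Example: compose_param['DecoupledFixedCompose']['GNN_LFHF']('beta', args)
+# returns one ParamTuple per channel; _get_suggest in benchmark/run_param.py
+# expands the list into per-channel trial params 'beta-0' and 'beta-1'.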
""" - return get_dct(model_regi[k], model, args) + if not model in model_regi[k]: + return None + return model_regi[k](model, args) if args else model_regi[k][model] def get_conv_regi(conv: str, k: str, args=None) -> str: @@ -97,7 +99,9 @@ def get_conv_regi(conv: str, k: str, args=None) -> str: Returns: value (str): The value of the convolution registry. """ - return get_dct(conv_regi[k], conv, args) + if not conv in conv_regi[k]: + return None + return conv_regi[k](conv, args) if args else conv_regi[k][conv] def get_nn_name(model: str, conv: str, args) -> str: @@ -110,13 +114,12 @@ def get_nn_name(model: str, conv: str, args) -> str: Returns: nn_name (Tuple[str]): Name strings ``(model_name, conv_name)``. """ - model_name = get_dct(model_regi['name'], model, args) - conv_name = [get_dct(conv_regi['name'], channel, args) for channel in conv.split(',')] + model_name = model_regi['name'](model, args) + conv_name = [conv_regi['name'](channel, args) for channel in conv.split(',')] conv_name = ','.join(conv_name) - conv_name = get_dct(model_regi['conv_name'], model, conv_name, args) - if model in compose_name: - if conv_name in compose_name[model]: - conv_name = compose_name[model][conv_name] + conv_name = model_regi['conv_name'](model, conv_name, args) + if model in compose_name and conv_name in compose_name[model]: + conv_name = compose_name[model][conv_name] return (model_name, conv_name) @@ -132,36 +135,22 @@ def set_pargs(model: str, conv: str, args): kwargs (dict): Arguments for importing the model. """ valid_pargs = model_regi['pargs'][model] - valid_pargs.extend(conv_regi['pargs'][channel] for channel in conv.split(',')) + for channel in conv.split(','): + valid_pargs.extend(conv_regi['pargs'][channel]) kwargs = {} for parg in full_pargs: - if parg in valid_pargs and hasattr(args, parg): - kwargs[parg] = getattr(args, parg) - else: - delattr(args, parg) + if hasattr(args, parg): + if parg in valid_pargs: + kwargs[parg] = getattr(args, parg) + else: + delattr(args, parg) if model in model_regi['pargs_default']: for parg in model_regi['pargs_default'][model]: - kwargs.setdefault(parg, get_dct(model_regi['pargs_default'][model], parg, kwargs)) + kwargs.setdefault(parg, model_regi['pargs_default'][model](parg, kwargs)) for channel in conv.split(','): if channel in conv_regi['pargs_default']: for parg in conv_regi['pargs_default'][channel]: - kwargs.setdefault(parg, get_dct(conv_regi['pargs_default'][channel], parg, kwargs)) + kwargs.setdefault(parg, conv_regi['pargs_default'][channel](parg, kwargs)) return kwargs - - -def get_param(model: str, conv: str, parg: str, args) -> tuple: - r"""Query parameter settings for model+conv. - - Args: - model: The name of the model. - conv: The type of convolution. - parg: The name key of the parameter. - args: Configuration arguments. - Returns: - tune_tuple (dict): Configurations for tuning the model. 
- """ - if parg in model_regi['param'][model]: - return get_dct(model_regi['param'][model], parg, args) - return get_dct(conv_regi['param'][conv], parg, args) diff --git a/pyg_spectral/utils/__init__.py b/pyg_spectral/utils/__init__.py index 5c40409..101b6db 100755 --- a/pyg_spectral/utils/__init__.py +++ b/pyg_spectral/utils/__init__.py @@ -1,9 +1,9 @@ -from .loader import load_import +from .loader import load_import, CallableDict from .laplacian import get_laplacian from .dropout import dropout_edge __all__ = [ - 'load_import', + 'load_import', 'CallableDict', 'get_laplacian', 'dropout_edge' ] diff --git a/pyg_spectral/utils/loader.py b/pyg_spectral/utils/loader.py index 51a1759..bad9576 100755 --- a/pyg_spectral/utils/loader.py +++ b/pyg_spectral/utils/loader.py @@ -1,4 +1,5 @@ import importlib +from functools import wraps def load_import(class_name, module_name): @@ -7,3 +8,65 @@ def load_import(class_name, module_name): class_obj = getattr(module, class_name) if isinstance(class_obj, type): return class_obj + + +def resolve_func(nargs=1): + """Enable calling the return function with additional arguments. + + Args: + nargs: The number of arguments to pass to the decorated function. Arguments + beyond this number will be passed to the return function. + Examples: + ```python + @resolve_func(1) + def foo(bar): + return bar + foo(1) # 1 + foo(lambda x: x+1, 2) # 3 + ``` + """ + def decorator(func): + @wraps(func) + def wrapper(*inputs): + if len(inputs) <= nargs: + return func(*inputs[:nargs]) + else: + ret = func(*inputs[:nargs]) + if callable(ret): + return ret(*inputs[nargs:]) + return ret + return wrapper + return decorator + + +class CallableDict(dict): + def __call__(self, key, *args): + def _get_callable(key): + ret = self.get(key, None) + if callable(ret): + return ret(*args) + return ret + + if not key in self: + if ',' in key: + return [_get_callable(k) for k in key.split(',')] + raise ValueError(f"Key '{key}' not found in {self.keys()}.") + return _get_callable(key) + + @classmethod + def to_callableVal(cls, dct, keys=None): + keys = keys or dct.keys() + for key in keys: + if isinstance(dct[key], dict): + dct[key] = cls(dct[key]) + return dct + + @classmethod + def to_subcallableVal(cls, dct, keys=[]): + for key in dct: + if key in keys: + dct[key] = cls.to_callableVal(dct[key]) + else: + if isinstance(dct[key], dict): + dct[key] = cls(dct[key]) + return dct