Update run_param with new CallableDict
nyLiao committed Oct 6, 2024
1 parent 0c85e54 commit 8a178a9
Showing 12 changed files with 208 additions and 163 deletions.
108 changes: 34 additions & 74 deletions benchmark/run_param.py
@@ -8,6 +8,10 @@
import uuid
from copy import deepcopy

from pyg_spectral.nn import get_model_regi, get_conv_regi
from pyg_spectral.nn.parse_args import compose_param
from pyg_spectral.utils import CallableDict

from trainer import (
SingleGraphLoader_Trial,
ModelLoader_Trial,
@@ -34,75 +38,34 @@ def __init__(self, data_loader, model_loader, args, res_logger = None):
self.fmt_logger = {}
self.metric = None

self.data, self.model, self.trn_cls, self.trn = None, None, None, None
self.data, self.model, self.trn = None, None, None
self.trn_cls = {
TrnFullbatch: TrnFullbatch_Trial,
TrnMinibatch: TrnMinibatch_Trial,
}[self.model_loader.get_trn(args)]

def _get_suggest(self, trial, key):
list2str = lambda x: ','.join(map(str, x))
nofmt = lambda x: x
# >>>>>>>>>>
theta_dct = {
# "impulse": (lambda x, _: x[0], (self.args.num_hops,), {}),
"ones": (trial.suggest_float, (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
"impulse": (trial.suggest_float, (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
"appr": (trial.suggest_float, (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
"nappr": (trial.suggest_float, (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
"mono": (trial.suggest_float, (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
"hk": (trial.suggest_float, (1e-2, 10), {'log': True}, lambda x: float(f'{x:.3e}')),
"gaussian": (trial.suggest_float, (1e-2, 10), {'log': True}, lambda x: float(f'{x:.3e}')),
}
suggest_dct = {
# critical
'num_hops': (trial.suggest_int, (2, 30), {'step': 2}, nofmt),
'in_layers': (trial.suggest_int, (1, 3), {}, nofmt),
'out_layers': (trial.suggest_int, (1, 3), {}, nofmt),
'hidden_channels': (trial.suggest_categorical, ([16, 32, 64, 128, 256],), {}, nofmt),
'combine': (trial.suggest_categorical, (["sum", "sum_weighted", "cat"],), {}, nofmt),
# secondary
'theta_param': theta_dct.get(self.args.theta_scheme, None),
'normg': (trial.suggest_float, (0.0, 1.0), {'step': 0.05}, lambda x: round(x, 2)),
'dropout_lin': (trial.suggest_float, (0.0, 1.0), {'step': 0.1}, lambda x: round(x, 2)),
'dropout_conv': (trial.suggest_float, (0.0, 1.0), {'step': 0.1}, lambda x: round(x, 2)),
'lr_lin': (trial.suggest_float, (1e-5, 5e-1), {'log': True}, lambda x: float(f'{x:.3e}')),
'lr_conv': (trial.suggest_float, (1e-5, 5e-1), {'log': True}, lambda x: float(f'{x:.3e}')),
'wd_lin': (trial.suggest_float, (1e-7, 1e-3), {'log': True}, lambda x: float(f'{x:.3e}')),
'wd_conv': (trial.suggest_float, (1e-7, 1e-3), {'log': True}, lambda x: float(f'{x:.3e}')),
'alpha': (trial.suggest_float, (0.01, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
'beta': (trial.suggest_float, (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
}

# Model/conv-specific
# if self.args.model in ['Iterative']:
# suggest_dct['in_layers'][1] = (1, 3)
# suggest_dct['out_layers'][1] = (1, 3)
# <<<<<<<<<<

if 'Compose' in self.args.model:
convs = self.args.conv.split(',')
if key == 'theta_param':
schemes = self.args.theta_scheme.split(',')
lst = []
for i,c in enumerate(convs):
func, fargs, fkwargs, fmt = theta_dct.get(schemes[i], None)
lst.append(func(key+'-'+str(i), *fargs, **fkwargs))
return lst, fmt
elif key == 'beta':
func, fargs, fkwargs, fmt = suggest_dct[key]
beta_c = {
'AdjiConv': [(0.0, 1.0), (0.0, 1.0)], # FAGNN
'AdjSkipConv': [(0.0, 1.0), (0.0, 1.0)], # FAGNN
'Adji2Conv': [(1.0, 2.0), (0.0, 1.0)], # G2CN
'AdjSkip2Conv': [(1.0, 2.0), (0.0, 1.0)], # G2CN
'AdjDiffConv': [(0.0, 1.0), (-1.0, 0.0)], # GNN-LF/HF
}
lst = [func(key+'-'+str(i), *beta_i, **fkwargs) for i,beta_i in enumerate(beta_c[convs[0]])]
# return list2str(lst), str
return lst, fmt
else:
func, fargs, fkwargs, fmt = suggest_dct[key]
return func(key, *fargs, **fkwargs), fmt
else:
func, fargs, fkwargs, fmt = suggest_dct[key]
return func(key, *fargs, **fkwargs), fmt
def parse_param(val):
if isinstance(val, list):
fmt = val[0][-1]
val = [getattr(trial, 'suggest_'+func)(key+'-'+str(i), *fargs, **fkwargs) for i, (func, fargs, fkwargs, _) in enumerate(val)]
return val, fmt
func, fargs, fkwargs, fmt = val
return getattr(trial, 'suggest_'+func)(key, *fargs, **fkwargs), fmt

# Alias compose models
if (self.args.model in compose_param and
self.model_loader.conv_repr in compose_param[self.args.model] and
key in compose_param[self.args.model][self.model_loader.conv_repr]):
return parse_param(compose_param[self.args.model][self.model_loader.conv_repr](key, self.args))

single_param = SingleGraphLoader_Trial.param | ModelLoader_Trial.param | self.trn_cls.param
single_param = CallableDict(single_param)
single_param |= get_model_regi(self.args.model, 'param')
if key in single_param:
return parse_param(single_param(key, self.args))

return parse_param(get_conv_regi(self.args.conv, 'param')(key, self.args))

def __call__(self, trial):
args = deepcopy(self.args)
@@ -115,18 +78,15 @@ def __call__(self, trial):

if self.data is None:
self.data = self.data_loader.get(args)
self.model, trn_cls = self.model_loader.get(args)
self.trn_cls = {
TrnFullbatch: TrnFullbatch_Trial,
TrnMinibatch: TrnMinibatch_Trial,
}[trn_cls]
self.model, _ = self.model_loader.get(args)

for key in ['in_channels', 'out_channels', 'metric', 'multi', 'criterion']:
for key in SingleGraphLoader_Trial.args_out + ModelLoader_Trial.args_out:
self.args.__dict__[key] = args.__dict__[key]
self.metric = args.metric
else:
self.data = self.data_loader.update(args, self.data)
self.model = self.model_loader.update(args, self.model)

res_logger = deepcopy(self.res_logger)
for key in self.args.param:
val = args.__dict__[key]
@@ -220,7 +180,7 @@ def main(args):
best_params = {k: trn.fmt_logger[k](v) for k, v in study.best_params.items()}
save_args(args.logpath, best_params)
axes = optuna.visualization.matplotlib.plot_parallel_coordinate(
study, params=best_params.keys())
study, params=study.best_params.keys())
axes.get_figure().savefig(args.logpath.joinpath('parallel_coordinate.png'))
clear_logger(logger)

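For readers of the new `_get_suggest`: each tunable hyperparameter is now described by a `(suggest_type, args, kwargs, formatter)` tuple, and `parse_param` resolves it against the Optuna trial via `getattr(trial, 'suggest_' + suggest_type)`. Below is a minimal standalone sketch of that resolution, with spec tuples copied from this diff and a toy objective standing in for an actual training run:

```python
import optuna

# Spec format used throughout this commit:
# (suggest type, positional args, keyword args, formatter for logging).
param = {
    'lr_lin':   ('float', (1e-5, 5e-1), {'log': True}, lambda x: float(f'{x:.3e}')),
    'num_hops': ('int',   (2, 30),      {'step': 2},   lambda x: x),
}

def parse_param(trial: optuna.Trial, key: str):
    # Mirrors parse_param in _get_suggest: dispatch on the suggest type.
    func, fargs, fkwargs, fmt = param[key]
    return getattr(trial, 'suggest_' + func)(key, *fargs, **fkwargs), fmt

def objective(trial):
    lr, _ = parse_param(trial, 'lr_lin')
    hops, _ = parse_param(trial, 'num_hops')
    return lr * hops  # toy objective

study = optuna.create_study(direction='minimize')
study.optimize(objective, n_trials=5)
# Format best values with each spec's formatter, as main() does above.
print({k: param[k][-1](v) for k, v in study.best_params.items()})
```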
6 changes: 6 additions & 0 deletions benchmark/trainer/base.py
@@ -50,6 +50,12 @@ class TrnBase(object):
run: Run the training process.
"""
name: str
param = {
'lr_lin': ('float', (1e-5, 5e-1), {'log': True}, lambda x: float(f'{x:.3e}')),
'lr_conv': ('float', (1e-5, 5e-1), {'log': True}, lambda x: float(f'{x:.3e}')),
'wd_lin': ('float', (1e-7, 1e-3), {'log': True}, lambda x: float(f'{x:.3e}')),
'wd_conv': ('float', (1e-7, 1e-3), {'log': True}, lambda x: float(f'{x:.3e}')),
}

def __init__(self,
model: nn.Module,
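`run_param.py` above wraps these class-level `param` tables in `CallableDict` and invokes them as `single_param(key, self.args)`. Judging purely from the call sites in this commit, `CallableDict` looks up the key and, if the stored value is callable, applies it to the extra arguments; the following is a minimal reconstruction under that assumption (the real `pyg_spectral.utils.CallableDict` may differ):

```python
from argparse import Namespace

class CallableDict(dict):
    """Assumed behavior: calling the dict resolves callable values."""
    def __call__(self, key, *args):
        val = self[key]
        # Plain specs pass through; callables compute the spec from args.
        return val(*args) if callable(val) else val

theta_ranges = {
    'appr': ('float', (0.0, 1.0), {'step': 0.01}),
    'hk':   ('float', (1e-2, 10), {'log': True}),
}
specs = CallableDict({
    'normg': ('float', (0.0, 1.0), {'step': 0.05}),
    'theta_param': lambda args: theta_ranges[args.theta_scheme],
})

print(specs('normg'))                                      # static entry
print(specs('theta_param', Namespace(theta_scheme='hk')))  # depends on args
```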
1 change: 1 addition & 0 deletions benchmark/trainer/fullbatch.py
@@ -194,3 +194,4 @@ def run(self) -> ResLogger:

def update(self, *args, **kwargs):
self.__init__(*args, **kwargs)
return self
5 changes: 5 additions & 0 deletions benchmark/trainer/load_data.py
@@ -32,6 +32,11 @@ class SingleGraphLoader(object):
* args.data_split (str): Index of dataset split.
res_logger: Logger for results.
"""
args_out = ['in_channels', 'out_channels', 'multi', 'metric']
param = {
'normg': ('float', (0.0, 1.0), {'step': 0.05}, lambda x: round(x, 2)),
}

def __init__(self, args: Namespace, res_logger: ResLogger = None) -> None:
# Assigning dataset identity.
self.seed = args.seed
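The new `args_out` attribute declares which arguments a loader writes back after `get()`, which is what lets `run_param.py` above replace its hard-coded key list with a generic copy loop. A small sketch of the convention, with stand-in loader classes and placeholder values:

```python
from argparse import Namespace

class SingleGraphLoader:
    # Declares the args attributes that get() fills in (as in this commit).
    args_out = ['in_channels', 'out_channels', 'multi', 'metric']

class ModelLoader:
    args_out = ['criterion']

def copy_outputs(src: Namespace, dst: Namespace) -> None:
    # Generic replacement for run_param.py's old hard-coded key list.
    for key in SingleGraphLoader.args_out + ModelLoader.args_out:
        setattr(dst, key, getattr(src, key))

src = Namespace(in_channels=1433, out_channels=7, multi=False,
                metric='f1micro', criterion='CrossEntropyLoss')  # placeholders
dst = Namespace()
copy_outputs(src, dst)
print(dst)
```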
37 changes: 28 additions & 9 deletions benchmark/trainer/load_model.py
@@ -26,6 +26,9 @@ class ModelLoader(object):
* args.conv (str): Convolution layer name.
res_logger: Logger for results.
"""
args_out = ['criterion']
param = {}

def __init__(self, args: Namespace, res_logger: ResLogger = None) -> None:
r"""Assigning model identity.
"""
@@ -52,24 +55,36 @@ def get_name(args: Namespace) -> Tuple[str]:
"""
return get_nn_name(args.model, args.conv, args)

def _resolve_import(self, args: Namespace) -> Tuple[str, str, dict, TrnBase]:
class_name = self.model
module_name = get_model_regi(self.model, 'module', args)
kwargs = set_pargs(self.model, self.conv, args)
trn = {
@staticmethod
def get_trn(args: Namespace) -> TrnBase:
r"""Get trainer class from model name.
Args:
args: Configuration arguments.
Returns:
trn (TrnBase): Trainer class.
"""
model_repr, _ = get_nn_name(args.model, args.conv, args)
return {
'DecoupledFixed': TrnFullbatch,
'DecoupledVar': TrnFullbatch,
'Iterative': TrnFullbatch,
'IterativeFixed': TrnFullbatch,
'PrecomputedVar': TrnMinibatch,
'PrecomputedFixed': TrnMinibatch,
'CppPrecFixed': TrnMinibatch,
}[self.model_repr]
}[model_repr]

def _resolve_import(self, args: Namespace) -> Tuple[str, str, dict]:
class_name = self.model
module_name = get_model_regi(self.model, 'module', args)
kwargs = set_pargs(self.model, self.conv, args)

# >>>>>>>>>>
if module_name == 'torch_geometric.nn.models':
args.criterion = 'BCEWithLogitsLoss' if args.out_channels == 1 else 'CrossEntropyLoss'

del kwargs['conv']
kwargs.setdefault('num_layers', kwargs.pop('num_hops'))
kwargs.setdefault('dropout', kwargs.pop('dropout_lin'))

@@ -86,7 +101,7 @@ def _resolve_import(self, args: Namespace) -> Tuple[str, str, dict, TrnBase]:
kwargs['num_hops'] = int(kwargs['num_hops'] / 2)
# Parse model args
# <<<<<<<<<<
return class_name, module_name, kwargs, trn
return class_name, module_name, kwargs

def get(self, args: Namespace) -> Tuple[nn.Module, TrnBase]:
r"""Load model with specified arguments.
@@ -99,10 +114,13 @@ def get(self, args: Namespace) -> Tuple[nn.Module, TrnBase]:
args.out_channels (int): Number of output classes.
args.hidden_channels (int): Number of hidden units.
args.dropout_[lin/conv] (float): Dropout rate for linear/conv.
Updates:
args.criterion (str): Loss function name.
"""
self.logger.debug('-'*20 + f" Loading model: {self} " + '-'*20)

class_name, module_name, kwargs, trn = self._resolve_import(args)
trn = self.get_trn(args)
class_name, module_name, kwargs = self._resolve_import(args)
model = load_import(class_name, module_name)(**kwargs)
if hasattr(model, 'reset_parameters'):
model.reset_parameters()
@@ -129,7 +147,8 @@ def get(self, args: Namespace) -> Tuple[nn.Module, TrnBase]:
self.signature_lst = ['num_hops', 'in_layers', 'out_layers', 'hidden_channels', 'dropout_lin', 'dropout_conv']
self.signature = {key: getattr(args, key) for key in self.signature_lst}

class_name, module_name, kwargs, trn = self._resolve_import(args)
trn = self.get_trn(args)
class_name, module_name, kwargs = self._resolve_import(args)
model = load_import(class_name, module_name)(**kwargs)
if hasattr(model, 'reset_parameters'):
model.reset_parameters()
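Moving the trainer choice out of `_resolve_import` into a `get_trn` staticmethod is what allows the trial wrapper in `run_param.py` to pick its `*_Trial` trainer once in `__init__`, before any model is built. A sketch of that flow with stub classes (simplified here to key on the model representation string directly, whereas the real `get_trn` takes the args namespace):

```python
class TrnFullbatch: ...
class TrnMinibatch: ...
class TrnFullbatch_Trial(TrnFullbatch): ...
class TrnMinibatch_Trial(TrnMinibatch): ...

class ModelLoader:
    @staticmethod
    def get_trn(model_repr: str):
        # Static lookup: no loader instance or constructed model required.
        return {
            'DecoupledFixed': TrnFullbatch,
            'PrecomputedFixed': TrnMinibatch,
        }[model_repr]

# As in the trial wrapper's __init__ above: map base trainer -> trial subclass.
trn_cls = {
    TrnFullbatch: TrnFullbatch_Trial,
    TrnMinibatch: TrnMinibatch_Trial,
}[ModelLoader.get_trn('DecoupledFixed')]
assert trn_cls is TrnFullbatch_Trial
```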
1 change: 1 addition & 0 deletions benchmark/trainer/minibatch.py
@@ -272,3 +272,4 @@ def update(self,
if not hasattr(self, 'signature') or self.signature != signature:
self.signature = signature
self.embed = None
return self
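Both the fullbatch and minibatch trainers' `update` now end with `return self`, so callers can refresh and rebind in one expression, the same pattern the loaders' `update` follows in `run_param.py` above. A toy illustration:

```python
class Trainer:
    def update(self, model, data):
        # Refresh state in place, then return self so the caller can rebind.
        self.model, self.data = model, data
        return self

trn = Trainer()
trn = trn.update('model_v2', 'graph_v1')  # one-line refresh-and-rebind
print(trn.model)
```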
1 change: 1 addition & 0 deletions benchmark/utils/logger.py
@@ -25,6 +25,7 @@
warnings.filterwarnings('ignore', '.*No positive samples found in target.*')
warnings.filterwarnings('ignore', '.*No negative samples found in target.*')
warnings.filterwarnings('ignore', '.*is( an)? experimental.*')
warnings.filterwarnings('ignore', '.*Attempting to set identical low and high ylims.*')


def setup_logger(logpath: Union[Path, str] = LOGPATH,
2 changes: 1 addition & 1 deletion pyg_spectral/nn/__init__.py
@@ -1,4 +1,4 @@
from .parse_args import (
get_model_regi, get_conv_regi,
get_nn_name, set_pargs, get_param
get_nn_name, set_pargs
)
16 changes: 8 additions & 8 deletions pyg_spectral/nn/models/decoupled.py
@@ -9,11 +9,11 @@
from torch_geometric.nn.inits import reset

from pyg_spectral.nn.models.base_nn import BaseNN, BaseNNCompose
from pyg_spectral.utils import load_import
from pyg_spectral.utils import load_import, CallableDict


theta_param = {
# "impulse": (lambda x, _: x[0], (self.args.num_hops,), {}),
theta_param = CallableDict({
"impulse": ('float', (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
"ones": ('float', (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
"impulse": ('float', (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
"appr": ('float', (0.0, 1.0), {'step': 0.01}, lambda x: round(x, 2)),
@@ -22,7 +22,7 @@
"hk": ('float', (1e-2, 10), {'log': True}, lambda x: float(f'{x:.2e}')),
"gaussian": ('float', (1e-2, 10), {'log': True}, lambda x: float(f'{x:.2e}')),
"log": ('float', (1e-2, 10), {'log': True}, lambda x: float(f'{x:.2e}')),
}
})


def gen_theta(num_hops: int, scheme: str, param: Union[float, List[float]] = None) -> Tensor:
@@ -147,7 +147,7 @@ class DecoupledFixed(BaseNN):
name = 'DecoupledFixed'
conv_name = lambda x, args: '-'.join([x, args.theta_scheme])
pargs = ['theta_scheme', 'theta_param']
param = {'theta_param': lambda x: theta_param.get(x, None)}
param = {'theta_param': lambda args: theta_param(args.theta_scheme)}

def init_conv(self,
conv: str,
@@ -192,7 +192,7 @@ class DecoupledVar(BaseNN):
"""
name = 'DecoupledVar'
pargs = ['theta_scheme', 'theta_param']
param = {'theta_param': lambda x: theta_param.get(x, None)}
param = {'theta_param': lambda args: theta_param(args.theta_scheme)}

def init_conv(self,
conv: str,
@@ -246,7 +246,7 @@ class DecoupledFixedCompose(BaseNNCompose):
name = 'DecoupledFixed'
conv_name = lambda x, args: '-'.join([x, args.theta_scheme])
pargs = ['theta_scheme', 'theta_param']
param = {'theta_param': lambda x: theta_param.get(x, None)}
param = {'theta_param': lambda args: theta_param(args.theta_scheme)}

def init_conv(self,
conv: str,
@@ -305,7 +305,7 @@ class DecoupledVarCompose(BaseNNCompose):
"""
name = 'DecoupledVar'
pargs = ['theta_scheme', 'theta_param']
param = {'theta_param': lambda x: theta_param.get(x, None)}
param = {'theta_param': lambda args: theta_param(args.theta_scheme)}

def init_conv(self,
conv: str,
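For context on the schemes named in `theta_param`: `gen_theta(num_hops, scheme, param)` generates fixed polynomial coefficients per scheme. Below is a hedged sketch of what two such schemes conventionally compute ('appr' as personalized-PageRank weights, 'hk' as heat-kernel weights); the package's actual formulas may differ:

```python
import torch

def gen_theta_sketch(num_hops: int, scheme: str, param: float) -> torch.Tensor:
    k = torch.arange(num_hops + 1, dtype=torch.float)
    if scheme == 'appr':
        # Personalized PageRank: theta_k = alpha * (1 - alpha)^k
        return param * (1 - param) ** k
    if scheme == 'hk':
        # Heat kernel: theta_k = exp(-t) * t^k / k!
        return torch.exp(torch.tensor(-param)) * param ** k / torch.exp(torch.lgamma(k + 1))
    raise ValueError(f'unknown scheme: {scheme}')

print(gen_theta_sketch(4, 'appr', 0.5))  # tensor([0.5000, 0.2500, ...])
```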
(Diffs for the remaining 3 changed files are not shown.)
