diff --git a/benchmark/dataset/linkx.py b/benchmark/dataset/linkx.py
index e9c0f7a..913c358 100644
--- a/benchmark/dataset/linkx.py
+++ b/benchmark/dataset/linkx.py
@@ -1,5 +1,5 @@
 import os.path as osp
-from typing import Any, Callable, Optional
+from typing import Any, Callable
 from argparse import Namespace
 
 import numpy as np
@@ -46,8 +46,8 @@ def __init__(
         self,
         root: str,
         name: str,
-        transform: Optional[Callable] = None,
-        pre_transform: Optional[Callable] = None,
+        transform: Callable | None = None,
+        pre_transform: Callable | None = None,
         force_reload: bool = False,
     ) -> None:
         self.name = name.lower()
@@ -156,8 +156,8 @@ def __init__(
         self,
         root: str,
         name: str,
-        transform: Optional[Callable] = None,
-        pre_transform: Optional[Callable] = None,
+        transform: Callable | None = None,
+        pre_transform: Callable | None = None,
         force_reload: bool = False,
     ) -> None:
         self.name = name.lower()
diff --git a/benchmark/dataset/utils.py b/benchmark/dataset/utils.py
index 829a9b8..2e51b3e 100644
--- a/benchmark/dataset/utils.py
+++ b/benchmark/dataset/utils.py
@@ -1,4 +1,3 @@
-from typing import Tuple
 from argparse import Namespace
 import numpy as np
 from sklearn.model_selection import train_test_split
@@ -107,7 +106,7 @@ def split_crossval(label: torch.Tensor,
                    r_val: float,
                    seed: int = None,
                    ignore_neg: bool =True,
-                   stratify: bool =False) -> Tuple[torch.Tensor]:
+                   stratify: bool =False) -> tuple[torch.Tensor]:
     r"""Split index by cross-validation"""
     node_labeled = torch.where(label >= 0)[0] if ignore_neg else np.arange(label.shape[0])
 
diff --git a/benchmark/dataset/yandex.py b/benchmark/dataset/yandex.py
index 86cf97d..214a5cb 100644
--- a/benchmark/dataset/yandex.py
+++ b/benchmark/dataset/yandex.py
@@ -1,5 +1,5 @@
 import os.path as osp
-from typing import Callable, Optional
+from typing import Callable
 from argparse import Namespace
 
 import numpy as np
@@ -26,8 +26,8 @@ def __init__(
         self,
         root: str,
         name: str,
-        transform: Optional[Callable] = None,
-        pre_transform: Optional[Callable] = None,
+        transform: Callable | None = None,
+        pre_transform: Callable | None = None,
         force_reload: bool = False,
     ) -> None:
         self.name = name.lower()
diff --git a/benchmark/trainer/base.py b/benchmark/trainer/base.py
index cab14dd..a20d59a 100755
--- a/benchmark/trainer/base.py
+++ b/benchmark/trainer/base.py
@@ -3,7 +3,6 @@
 Author: nyLiao
 File Created: 2024-03-03
 """
-from typing import List
 import logging
 from argparse import Namespace
 
@@ -145,8 +144,8 @@ def wrapper(self, *args, **kwargs):
     # ===== Run block
     @_log_memory(split='train')
     def train_val(self,
-                  split_train: List[str] = ['train'],
-                  split_val: List[str] = ['val']) -> ResLogger:
+                  split_train: list[str] = ['train'],
+                  split_val: list[str] = ['val']) -> ResLogger:
         r"""Pipeline for iterative training.
 
         Args:
@@ -187,7 +186,7 @@ def train_val(self,
 
     @_log_memory(split='eval')
     def test(self,
-             split_test: List[str] = ['train', 'val', 'test']) -> ResLogger:
+             split_test: list[str] = ['train', 'val', 'test']) -> ResLogger:
         r"""Pipeline for testing.
         Args:
             split_test (list): Testing splits.
@@ -206,11 +205,11 @@ def _fetch_input(self) -> tuple:
         r"""Process each sample of model input and label."""
         raise NotImplementedError
 
-    def _learn_split(self, split: List[str] = ['train']) -> ResLogger:
+    def _learn_split(self, split: list[str] = ['train']) -> ResLogger:
         r"""Actual train iteration on the given splits."""
         raise NotImplementedError
 
-    def _eval_split(self, split: List[str]) -> ResLogger:
+    def _eval_split(self, split: list[str]) -> ResLogger:
         r"""Actual test on the given splits."""
         raise NotImplementedError
 
@@ -258,8 +257,8 @@ def clear(self):
         if self.optimizer: del self.optimizer
 
     def train_val(self,
-                  split_train: List[str] = ['train'],
-                  split_val: List[str] = ['val']) -> ResLogger:
+                  split_train: list[str] = ['train'],
+                  split_val: list[str] = ['val']) -> ResLogger:
         import optuna
 
         time_learn = profile.Accumulator()
diff --git a/benchmark/trainer/fullbatch.py b/benchmark/trainer/fullbatch.py
index f1682f0..dbd5bdb 100755
--- a/benchmark/trainer/fullbatch.py
+++ b/benchmark/trainer/fullbatch.py
@@ -3,7 +3,6 @@
 Author: nyLiao
 File Created: 2024-02-26
 """
-from typing import Tuple
 from argparse import Namespace
 
 import torch
@@ -57,7 +56,7 @@ def clear(self):
         del self.mask, self.data
         return super().clear()
 
-    def _fetch_data(self) -> Tuple[Data, dict]:
+    def _fetch_data(self) -> tuple[Data, dict]:
         r"""Process the single graph data."""
         t_to_device = T.ToDevice(self.device, attrs=['x', 'y', 'adj_t', 'edge_index'] + [f'{k}_mask' for k in self.splits])
         self.data = t_to_device(self.data)
diff --git a/benchmark/trainer/load_metric.py b/benchmark/trainer/load_metric.py
index d2f78e3..64e655c 100755
--- a/benchmark/trainer/load_metric.py
+++ b/benchmark/trainer/load_metric.py
@@ -3,7 +3,7 @@
 Author: nyLiao
 File Created: 2024-03-03
 """
-from typing import Tuple, List, Callable, Any
+from typing import Callable, Any
 from argparse import Namespace
 from torchmetrics import MetricCollection
 from torchmetrics.classification import (
@@ -15,7 +15,7 @@
 
 
 class ResCollection(MetricCollection):
-    def compute(self) -> List[Tuple[str, Any, Callable]]:
+    def compute(self) -> list[tuple[str, Any, Callable]]:
         r"""Wrap compute output to :class:`ResLogger` style."""
         dct = self._compute_and_reduce("compute")
         return [(k, v.cpu().numpy(), (lambda x: format(x*100, '.3f'))) for k, v in dct.items()]
diff --git a/benchmark/trainer/load_model.py b/benchmark/trainer/load_model.py
index f5b757f..82a4765 100755
--- a/benchmark/trainer/load_model.py
+++ b/benchmark/trainer/load_model.py
@@ -3,7 +3,6 @@
 Author: nyLiao
 File Created: 2024-02-26
 """
-from typing import Tuple
 from argparse import Namespace
 import logging
 import torch.nn as nn
@@ -40,7 +39,7 @@ def __init__(self, args: Namespace, res_logger: ResLogger = None) -> None:
         self.res_logger = res_logger or ResLogger()
 
     @staticmethod
-    def get_name(args: Namespace) -> Tuple[str]:
+    def get_name(args: Namespace) -> tuple[str]:
         """Get model+conv name for logging path from argparse input without instantiation.
 
         Wrapper for :func:`pyg_spectral.nn.get_nn_name()`.
@@ -51,7 +50,7 @@ def get_name(args: Namespace) -> Tuple[str]:
             * args.conv (str): Convolution layer name.
             * other args specified in module :attr:`name` function.
         Returns:
-            nn_name (Tuple[str]): Name strings ``(model_name, conv_name)``.
+            nn_name (tuple[str]): Name strings ``(model_name, conv_name)``.
""" return get_nn_name(args.model, args.conv, args) @@ -75,7 +74,7 @@ def get_trn(args: Namespace) -> TrnBase: 'CppPrecFixed': TrnMinibatch, }[model_repr] - def _resolve_import(self, args: Namespace) -> Tuple[str, str, dict]: + def _resolve_import(self, args: Namespace) -> tuple[str, str, dict]: class_name = self.model module_name = get_model_regi(self.model, 'module', args) kwargs = set_pargs(self.model, self.conv, args) @@ -103,7 +102,7 @@ def _resolve_import(self, args: Namespace) -> Tuple[str, str, dict]: # <<<<<<<<<< return class_name, module_name, kwargs - def get(self, args: Namespace) -> Tuple[nn.Module, TrnBase]: + def get(self, args: Namespace) -> tuple[nn.Module, TrnBase]: r"""Load model with specified arguments. Args: @@ -143,7 +142,7 @@ def __str__(self) -> str: class ModelLoader_Trial(ModelLoader): r"""Reuse necessary data for multiple runs. """ - def get(self, args: Namespace) -> Tuple[nn.Module, TrnBase]: + def get(self, args: Namespace) -> tuple[nn.Module, TrnBase]: self.signature_lst = ['num_hops', 'in_layers', 'out_layers', 'hidden_channels', 'dropout_lin', 'dropout_conv'] self.signature = {key: getattr(args, key) for key in self.signature_lst} diff --git a/benchmark/trainer/minibatch.py b/benchmark/trainer/minibatch.py index 45e41b5..b573664 100755 --- a/benchmark/trainer/minibatch.py +++ b/benchmark/trainer/minibatch.py @@ -3,7 +3,7 @@ Author: nyLiao File Created: 2024-03-03 """ -from typing import Tuple, Generator +from typing import Generator import logging from argparse import Namespace diff --git a/benchmark/trainer/regression.py b/benchmark/trainer/regression.py index e5f65f9..b25150f 100644 --- a/benchmark/trainer/regression.py +++ b/benchmark/trainer/regression.py @@ -1,4 +1,3 @@ -from typing import Tuple from argparse import Namespace import logging @@ -63,7 +62,7 @@ def __init__(self, args: Namespace, res_logger: ResLogger = None) -> None: ]) # ===== Data acquisition - def _resolve_import(self, args: Namespace) -> Tuple[str, str, dict]: + def _resolve_import(self, args: Namespace) -> tuple[str, str, dict]: assert self.data in ['2dgrid'] module_name = 'dataset' class_name = 'Grid2D' diff --git a/benchmark/utils/checkpoint.py b/benchmark/utils/checkpoint.py index e815ff2..4b09768 100755 --- a/benchmark/utils/checkpoint.py +++ b/benchmark/utils/checkpoint.py @@ -3,7 +3,7 @@ Author: nyLiao File Created: 2023-03-20 """ -from typing import Union, Callable +from typing import Callable from pathlib import Path import copy @@ -26,12 +26,12 @@ class CkptLogger(object): metric_cmp: Comparison function for the metric. Can be 'max' or 'min'. 
""" def __init__(self, - logpath: Union[Path, str], + logpath: Path | str, patience: int = -1, period: int = 0, prefix: str = 'model', storage: str = 'state_gpu', - metric_cmp: Union[Callable[[float, float], bool], str]='max'): + metric_cmp: Callable[[float, float], bool] | str='max'): self.logpath = Path(logpath) self.prefix = prefix self.filetype = 'pth' diff --git a/benchmark/utils/logger.py b/benchmark/utils/logger.py index 216c5d4..480d5b3 100755 --- a/benchmark/utils/logger.py +++ b/benchmark/utils/logger.py @@ -3,7 +3,7 @@ Author: nyLiao File Created: 2023-03-20 """ -from typing import Tuple, List, Dict, Callable, Union, Any +from typing import Callable, Any from pathlib import Path import os @@ -28,7 +28,7 @@ warnings.filterwarnings('ignore', '.*Attempting to set identical low and high ylims.*') -def setup_logger(logpath: Union[Path, str] = LOGPATH, +def setup_logger(logpath: Path | str = LOGPATH, level_console: int = logging.DEBUG, level_file: int = 15, quiet: bool = True, @@ -71,8 +71,8 @@ def clear_logger(logger: logging.Logger): logger.info(f"[time]: {datetime.now()}") -def setup_logpath(dir: Union[Path, str] = LOGPATH, - folder_args: Tuple=None, +def setup_logpath(dir: Path | str = LOGPATH, + folder_args: tuple=None, quiet: bool = True): r"""Resolve log path for saving. @@ -103,7 +103,7 @@ class ResLogger(object): quiet: Quiet run without saving file. """ def __init__(self, - logpath: Union[Path, str] = LOGPATH, + logpath: Path | str = LOGPATH, prefix: str = 'summary', suffix: str = None, quiet: bool = True): @@ -163,7 +163,7 @@ def _set(self, data: DataFrame, fmt: Series): self.fmt = self.fmt.combine_first(fmt) def concat(self, - vals: Union[List[Tuple[str, Any, Callable]], Dict], + vals: list[tuple[str, Any, Callable]] | dict, row: int = 0, suffix: str = None): r"""Concatenate data entries of a single row to data. @@ -203,7 +203,7 @@ def __call__(self, *args, **kwargs) -> 'ResLogger': def merge(self, logger: 'ResLogger', - rows: List[int] = None, + rows: list[int] = None, suffix: str = None): r"""Merge from another logger. @@ -222,7 +222,7 @@ def merge(self, self._set(logger.data, logger.fmt) return self - def del_col(self, col: Union[List, str]) -> 'ResLogger': + def del_col(self, col: list | str) -> 'ResLogger': r"""Delete columns from data. Args: @@ -234,8 +234,8 @@ def del_col(self, col: Union[List, str]) -> 'ResLogger': # ===== Output def _get(self, - col: Union[List, str]=None, - row: Union[List, str]=None) -> Union[DataFrame, Series, str]: + col: list | str=None, + row: list | str=None) -> DataFrame | Series | str: r"""Retrieve one or sliced data and apply string format. Args: @@ -289,8 +289,8 @@ def save(self): mode='a', header=f.tell()==0) def get_str(self, - col: Union[List, str] = None, - row: Union[List, int] = None, + col: list | str = None, + row: list | int = None, maxlen: int = -1) -> str: r"""Get formatted long string for printing of the specified columns and rows. 
diff --git a/pyg_spectral/nn/conv/base_mp.py b/pyg_spectral/nn/conv/base_mp.py
index 20530db..09154d6 100644
--- a/pyg_spectral/nn/conv/base_mp.py
+++ b/pyg_spectral/nn/conv/base_mp.py
@@ -1,5 +1,5 @@
-from typing import Any, Callable, Dict, List, Tuple, Optional
-type ParamTuple = Tuple[str, tuple, Dict[str, Any], Callable[[Any], str]]
+from typing import Any, Callable, NewType
+ParamTuple = NewType('ParamTuple', tuple[str, tuple, dict[str, Any], Callable[[Any], str]])
 
 import re
 import torch
@@ -27,25 +27,28 @@ class BaseMP(MessagePassing):
     supports_batch: bool = True
     supports_norm_batch: bool = True
     name: Callable[[Any], str]  # args -> str
-    pargs: List[str] = []
-    param: Dict[str, ParamTuple] = {}
-    _cache: Optional[Any]
+    pargs: list[str] = []
+    param: dict[str, ParamTuple] = {}
+    _cache: Any | None
 
     @classmethod
-    def register_classes(cls, registry: Dict[str, Dict[str, Any]] = CONV_REGI_INIT):
+    def register_classes(cls, registry: dict[str, dict[str, Any]] = None) -> dict:
         r"""Register args for all subclass.
 
         Args:
-            name (Dict[str, str]): Conv class logging path name.
-            pargs (Dict[str, List[str]]): Conv arguments from argparse.
-            pargs_default (Dict[str, Dict[str, Any]]): Default values for model arguments. Not recommended.
-            param (Dict[str, Dict[str, ParamTuple]]): Conv parameters to tune.
+            name (dict[str, str]): Conv class logging path name.
+            pargs (dict[str, list[str]]): Conv arguments from argparse.
+            pargs_default (dict[str, dict[str, Any]]): Default values for model arguments. Not recommended.
+            param (dict[str, dict[str, ParamTuple]]): Conv parameters to tune.
 
                 * (str) parameter type,
                 * (tuple) args for :func:`optuna.trial.suggest_`,
                 * (dict) kwargs for :func:`optuna.trial.suggest_`,
                 * (callable) format function to str.
         """
+        if registry is None:
+            registry = CONV_REGI_INIT
+
         for subcls in cls.__subclasses__():
             subname = subcls.__name__
             # Traverse the MRO and accumulate args from parent classes
@@ -186,7 +189,7 @@ def _get_convolute_mat(self, x: Tensor, edge_index: Adj) -> dict:
     def get_forward_mat(self,
         x: Tensor,
         edge_index: Adj,
-        comp_scheme: Optional[str] = None
+        comp_scheme: str | None = None
     ) -> dict:
         r"""Get matrices for :meth:`forward()`. Called during :meth:`forward()`.
 
diff --git a/pyg_spectral/nn/conv/favard_conv.py b/pyg_spectral/nn/conv/favard_conv.py
index 6a46eb9..55bc4c1 100644
--- a/pyg_spectral/nn/conv/favard_conv.py
+++ b/pyg_spectral/nn/conv/favard_conv.py
@@ -1,4 +1,3 @@
-from typing import Union
 import copy
 import torch
 import torch.nn as nn
@@ -83,7 +82,7 @@ def _forward(self,
         x: Tensor,
         x_1: Tensor,
         prop: Adj,
-        alpha_1: Union[nn.Parameter, nn.Module]
+        alpha_1: nn.Parameter | nn.Module
     ) -> dict:
         r"""
         Returns:
diff --git a/pyg_spectral/nn/models/base_nn.py b/pyg_spectral/nn/models/base_nn.py
index 501f63b..b7d6d05 100644
--- a/pyg_spectral/nn/models/base_nn.py
+++ b/pyg_spectral/nn/models/base_nn.py
@@ -1,5 +1,5 @@
-from typing import Any, Callable, Dict, Final, List, Tuple, Optional, Union
-type ParamTuple = Tuple[str, tuple, Dict[str, Any], Callable[[Any], str]]
+from typing import Any, Callable, Final, NewType
+ParamTuple = NewType('ParamTuple', tuple[str, tuple, dict[str, Any], Callable[[Any], str]])
 
 import torch
 import torch.nn as nn
@@ -45,35 +45,38 @@ class BaseNN(nn.Module):
     supports_batch: bool
     name: str
     conv_name: Callable[[str, Any], str] = lambda x, args: x
-    pargs: List[str] = ['conv', 'num_hops', 'in_layers', 'out_layers',
+    pargs: list[str] = ['conv', 'num_hops', 'in_layers', 'out_layers',
         'in_channels', 'hidden_channels', 'out_channels',
         'dropout_lin', 'dropout_conv',]
-    param: Dict[str, ParamTuple] = {
+    param: dict[str, ParamTuple] = {
         'num_hops': ('int', (2, 30), {'step': 2}, lambda x: x),
         'in_layers': ('int', (1, 3), {}, lambda x: x),
         'out_layers': ('int', (1, 3), {}, lambda x: x),
         'hidden_channels': ('categorical', ([16, 32, 64, 128, 256],), {}, lambda x: x),
-        'dropout_lin': ('float', (0.0, 1.0), {'step': 0.1}, lambda x: round(x, 2)),
-        'dropout_conv': ('float', (0.0, 1.0), {'step': 0.1}, lambda x: round(x, 2)),
+        'dropout_lin': ('float', (0.0, 0.9), {'step': 0.1}, lambda x: round(x, 2)),
+        'dropout_conv': ('float', (0.0, 0.9), {'step': 0.1}, lambda x: round(x, 2)),
     }
 
     @classmethod
-    def register_classes(cls, registry: Dict[str, Dict[str, Any]] = MODEL_REGI_INIT):
+    def register_classes(cls, registry: dict[str, dict[str, Any]] = None) -> dict:
         r"""Register args for all subclass.
 
         Args:
-            name (Dict[str, str]): Model class logging path name.
-            conv_name (Dict[str, Callable[[str, Any], str]]): Wrap conv logging path name.
-            module (Dict[str, str]): Module for importing the model.
-            pargs (Dict[str, List[str]]): Model arguments from argparse.
-            pargs_default (Dict[str, Dict[str, Any]]): Default values for model arguments. Not recommended.
-            param (Dict[str, Dict[str, ParamTuple]]): Model parameters to tune.
+            name (dict[str, str]): Model class logging path name.
+            conv_name (dict[str, Callable[[str, Any], str]]): Wrap conv logging path name.
+            module (dict[str, str]): Module for importing the model.
+            pargs (dict[str, list[str]]): Model arguments from argparse.
+            pargs_default (dict[str, dict[str, Any]]): Default values for model arguments. Not recommended.
+            param (dict[str, dict[str, ParamTuple]]): Model parameters to tune.
 
                 * (str) parameter type,
                 * (tuple) args for :func:`optuna.trial.suggest_`,
                 * (dict) kwargs for :func:`optuna.trial.suggest_`,
                 * (callable) format function to str.
""" + if registry is None: + registry = MODEL_REGI_INIT + for subcls in cls.__subclasses__(): subname = subcls.__name__ # Traverse the MRO and accumulate args from parent classes @@ -97,20 +100,20 @@ def register_classes(cls, registry: Dict[str, Dict[str, Any]] = MODEL_REGI_INIT) def __init__(self, conv: str, num_hops: int = 0, - in_channels: Optional[int] = None, - hidden_channels: Optional[int] = None, - out_channels: Optional[int] = None, - in_layers: Optional[int] = None, - out_layers: Optional[int] = None, - dropout_lin: Union[float, List[float]] = 0., + in_channels: int | None = None, + hidden_channels: int | None = None, + out_channels: int | None = None, + in_layers: int | None = None, + out_layers: int | None = None, + dropout_lin: float | list[float] = 0., dropout_conv: float = 0., - act: Union[str, Callable, None] = "relu", + act: str | Callable | None = "relu", act_first: bool = False, - act_kwargs: Optional[Dict[str, Any]] = None, - norm: Union[str, Callable, None] = "batch_norm", - norm_kwargs: Optional[Dict[str, Any]] = None, + act_kwargs: dict[str, Any | None] = None, + norm: str | Callable | None = "batch_norm", + norm_kwargs: dict[str, Any | None] = None, plain_last: bool = False, - bias: Union[bool, List[bool]] = True, + bias: bool | list[bool] = True, **kwargs): super(BaseNN, self).__init__() @@ -165,7 +168,7 @@ def __init__(self, self.reset_parameters() - def init_channel_list(self, conv: str, in_channels: int, hidden_channels: int, out_channels: int, **kwargs) -> List[int]: + def init_channel_list(self, conv: str, in_channels: int, hidden_channels: int, out_channels: int, **kwargs) -> list[int]: # assert (self.in_layers+self.out_layers > 0) or (self.in_channels == self.out_channels) total_layers = self.in_layers + self.conv_layers + self.out_layers channel_list = [in_channels] + [None] * (total_layers - 1) + [out_channels] @@ -230,7 +233,7 @@ def convolute(self, x: Tensor, edge_index: Adj, batch: OptTensor = None, - batch_size: Optional[int] = None, + batch_size: int | None = None, ) -> Tensor: r"""Decoupled propagation step for calling the convolutional module. """ @@ -248,7 +251,7 @@ def forward(self, x: Tensor, edge_index: Adj, batch: OptTensor = None, - batch_size: Optional[int] = None, + batch_size: int | None = None, ) -> Tensor: r""" Args: @@ -297,7 +300,7 @@ class BaseNNCompose(BaseNN): pargs = ['combine'] param = {'combine': ('categorical', ['sum', 'sum_weighted', 'cat'], {}, lambda x: x)} - def init_channel_list(self, conv: str, in_channels: int, hidden_channels: int, out_channels: int, **kwargs) -> List[int]: + def init_channel_list(self, conv: str, in_channels: int, hidden_channels: int, out_channels: int, **kwargs) -> list[int]: """ Attributes: channel_list: width for each conv channel @@ -330,7 +333,7 @@ def init_channel_list(self, conv: str, in_channels: int, hidden_channels: int, o self.gamma = nn.Parameter(torch.ones(n_conv, channel_list[self.in_layers + self.conv_layers])) return channel_list - def _set_conv_attr(self, key: str) -> List[Callable]: + def _set_conv_attr(self, key: str) -> list[Callable]: # NOTE: return a list, not callable if hasattr(self.convs[0][0], key): lst = [getattr(channel[0], key) for channel in self.convs] @@ -380,7 +383,7 @@ def convolute(self, x: Tensor, edge_index: Adj, batch: OptTensor = None, - batch_size: Optional[int] = None, + batch_size: int | None = None, ) -> Tensor: r"""Decoupled propagation step for calling the convolutional module. 
""" diff --git a/pyg_spectral/nn/models/decoupled.py b/pyg_spectral/nn/models/decoupled.py index 0ed52da..3cfc459 100644 --- a/pyg_spectral/nn/models/decoupled.py +++ b/pyg_spectral/nn/models/decoupled.py @@ -1,5 +1,3 @@ -from typing import List, Union - import numpy as np import torch import torch.nn as nn @@ -25,7 +23,7 @@ }) -def gen_theta(num_hops: int, scheme: str, param: Union[float, List[float]] = None) -> Tensor: +def gen_theta(num_hops: int, scheme: str, param: float | list[float] = None) -> Tensor: r"""Generate list of hop parameters based on given scheme. Args: @@ -231,8 +229,8 @@ class DecoupledFixedCompose(BaseNNCompose): Fixed scalar propagation parameters. Args: - theta_scheme (List[str]): Method to generate decoupled parameters. - theta_param (List[float], optional): Hyperparameter for the scheme. + theta_scheme (list[str]): Method to generate decoupled parameters. + theta_param (list[float], optional): Hyperparameter for the scheme. combine: How to combine different channels of convs. (:obj:`sum`, :obj:`sum_weighted`, or :obj:`cat`). conv, num_hops, in_channels, hidden_channels, out_channels: @@ -291,8 +289,8 @@ class DecoupledVarCompose(BaseNNCompose): Learnable scalar propagation parameters. Args: - theta_scheme (List[str]): Method to generate decoupled parameters. - theta_param (List[float], optional): Hyperparameter for the scheme. + theta_scheme (list[str]): Method to generate decoupled parameters. + theta_param (list[float], optional): Hyperparameter for the scheme. combine: How to combine different channels of convs. (:obj:`sum`, :obj:`sum_weighted`, or :obj:`cat`). conv, num_hops, in_channels, hidden_channels, out_channels: diff --git a/pyg_spectral/nn/models/iterative.py b/pyg_spectral/nn/models/iterative.py index 5a023d6..f570617 100644 --- a/pyg_spectral/nn/models/iterative.py +++ b/pyg_spectral/nn/models/iterative.py @@ -10,11 +10,11 @@ class Iterative(BaseNN): r"""Iterative structure with matrix transformation each hop of propagation. Args: - bias (Optional[bool]): whether learn an additive bias in conv. - weight_initializer (Optional[str]): The initializer for the weight + bias (bool | None): whether learn an additive bias in conv. + weight_initializer (str | None): The initializer for the weight matrix (:obj:`"glorot"`, :obj:`"uniform"`, :obj:`"kaiming_uniform"`, or :obj:`None`). - bias_initializer (Optional[str]): The initializer for the bias vector + bias_initializer (str | None): The initializer for the bias vector (:obj:`"zeros"` or :obj:`None`). conv, num_hops, in_channels, hidden_channels, out_channels: args for :class:`BaseNN` @@ -59,11 +59,11 @@ class IterativeCompose(BaseNNCompose): r"""Iterative structure with matrix transformation each hop of propagation. Args: - bias (Optional[bool]): whether learn an additive bias in conv. - weight_initializer (Optional[str]): The initializer for the weight + bias (bool | None): whether learn an additive bias in conv. + weight_initializer (str | None): The initializer for the weight matrix (:obj:`"glorot"`, :obj:`"uniform"`, :obj:`"kaiming_uniform"`, or :obj:`None`). - bias_initializer (Optional[str]): The initializer for the bias vector + bias_initializer (str | None): The initializer for the bias vector (:obj:`"zeros"` or :obj:`None`). combine: How to combine different channels of convs. (:obj:`sum`, :obj:`sum_weighted`, or :obj:`cat`). 
diff --git a/pyg_spectral/nn/models/precomputed.py b/pyg_spectral/nn/models/precomputed.py
index b048c8b..7a5946c 100644
--- a/pyg_spectral/nn/models/precomputed.py
+++ b/pyg_spectral/nn/models/precomputed.py
@@ -1,5 +1,3 @@
-from typing import Optional
-
 import torch
 from torch import Tensor
 from torch_geometric.typing import Adj, OptTensor
@@ -16,7 +14,7 @@ class PrecomputedFixed(DecoupledFixed):
 
     Args:
         theta_scheme (str): Method to generate decoupled parameters.
-        theta_param (Optional[float]): Hyperparameter for the scheme.
+        theta_param (float | None): Hyperparameter for the scheme.
         conv, num_hops, in_channels, hidden_channels, out_channels:
             args for :class:`BaseNN`
         in_layers, out_layers, dropout_lin, dropout_conv, lib_conv:
@@ -28,7 +26,7 @@ class PrecomputedFixed(DecoupledFixed):
     name = 'PrecomputedFixed'
     param = {'in_layers': ('int', (0, 0), {}, lambda x: x),}
 
-    def __init__(self, in_layers: Optional[int] = None, **kwargs):
+    def __init__(self, in_layers: int | None = None, **kwargs):
         assert in_layers is None or in_layers == 0, "PrecomputedFixed does not support in_layers."
         super(PrecomputedFixed, self).__init__(in_layers=in_layers, **kwargs)
 
@@ -50,7 +48,7 @@ def convolute(self,
     def forward(self,
         x: Tensor,
         batch: OptTensor = None,
-        batch_size: Optional[int] = None,
+        batch_size: int | None = None,
     ) -> Tensor:
         r"""
         Args:
@@ -75,7 +73,7 @@ class PrecomputedVar(DecoupledVar):
 
     Args:
         theta_scheme (str): Method to generate decoupled parameters.
-        theta_param (Optional[float]): Hyperparameter for the scheme.
+        theta_param (float | None): Hyperparameter for the scheme.
         conv, num_hops, in_channels, hidden_channels, out_channels:
             args for :class:`BaseNN`
         in_layers, out_layers, dropout_lin, dropout_conv, lib_conv:
@@ -87,7 +85,7 @@ class PrecomputedVar(DecoupledVar):
     name = 'PrecomputedVar'
     param = {'in_layers': ('int', (0, 0), {}, lambda x: x),}
 
-    def __init__(self, in_layers: Optional[int] = None, **kwargs):
+    def __init__(self, in_layers: int | None = None, **kwargs):
         assert in_layers is None or in_layers == 0, "PrecomputedVar does not support in_layers."
         super(PrecomputedVar, self).__init__(in_layers=in_layers, **kwargs)
 
@@ -117,7 +115,7 @@ def convolute(self,
     def forward(self,
         xs: Tensor,
         batch: OptTensor = None,
-        batch_size: Optional[int] = None,
+        batch_size: int | None = None,
     ) -> Tensor:
         r"""
         Args:
@@ -145,8 +143,8 @@ class PrecomputedFixedCompose(DecoupledFixedCompose):
     Fixed scalar propagation parameters and accumulating precompute results.
 
     Args:
-        theta_scheme (List[str]): Method to generate decoupled parameters.
-        theta_param (List[float], optional): Hyperparameter for the scheme.
+        theta_scheme (list[str]): Method to generate decoupled parameters.
+        theta_param (list[float], optional): Hyperparameter for the scheme.
         combine: How to combine different channels of convs.
             (:obj:`sum`, :obj:`sum_weighted`, or :obj:`cat`).
         conv, num_hops, in_channels, hidden_channels, out_channels:
@@ -182,7 +180,7 @@ def convolute(self,
     def forward(self,
         x: Tensor,
         batch: OptTensor = None,
-        batch_size: Optional[int] = None,
+        batch_size: int | None = None,
     ) -> Tensor:
         r"""
         Args:
@@ -214,8 +212,8 @@ class PrecomputedVarCompose(DecoupledVarCompose):
     Learnable scalar propagation parameters and storing all intermediate precompute results.
 
     Args:
-        theta_scheme (List[str]): Method to generate decoupled parameters.
-        theta_param (List[float], optional): Hyperparameter for the scheme.
+        theta_scheme (list[str]): Method to generate decoupled parameters.
+        theta_param (list[float], optional): Hyperparameter for the scheme.
         combine: How to combine different channels of convs.
             (:obj:`sum`, :obj:`sum_weighted`, or :obj:`cat`).
         conv, num_hops, in_channels, hidden_channels, out_channels:
@@ -259,7 +257,7 @@ def convolute(self,
     def forward(self,
         xs: Tensor,
         batch: OptTensor = None,
-        batch_size: Optional[int] = None,
+        batch_size: int | None = None,
     ) -> Tensor:
         r"""
         Args:
diff --git a/pyg_spectral/nn/models_pyg/iterative.py b/pyg_spectral/nn/models_pyg/iterative.py
index b4cc5ed..f52d52c 100644
--- a/pyg_spectral/nn/models_pyg/iterative.py
+++ b/pyg_spectral/nn/models_pyg/iterative.py
@@ -1,4 +1,4 @@
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import Any, Callable
 
 import torch.nn as nn
 import torch.nn.functional as F
@@ -20,16 +20,16 @@ class ChebNet(nn.Module):
     def __init__(self,
         conv: str,
         num_hops: int = 0,
-        in_channels: Optional[int] = None,
-        hidden_channels: Optional[int] = None,
-        out_channels: Optional[int] = None,
+        in_channels: int | None = None,
+        hidden_channels: int | None = None,
+        out_channels: int | None = None,
         dropout_lin: float = 0.,
-        act: Union[str, Callable, None] = "relu",
+        act: str | Callable | None = "relu",
         act_first: bool = False,
-        act_kwargs: Optional[Dict[str, Any]] = None,
-        norm: Union[str, Callable, None] = "batch_norm",
-        norm_kwargs: Optional[Dict[str, Any]] = None,
-        bias: Union[bool, List[bool]] = True,
+        act_kwargs: dict[str, Any] | None = None,
+        norm: str | Callable | None = "batch_norm",
+        norm_kwargs: dict[str, Any] | None = None,
+        bias: bool | list[bool] = True,
         **kwargs):
         super().__init__()
         self.conv1 = ChebConv(
diff --git a/pyg_spectral/nn/norm/standard_scale.py b/pyg_spectral/nn/norm/standard_scale.py
index 2853883..445589d 100755
--- a/pyg_spectral/nn/norm/standard_scale.py
+++ b/pyg_spectral/nn/norm/standard_scale.py
@@ -1,5 +1,3 @@
-from typing import Tuple
-
 import torch
 import torch.nn as nn
 
@@ -17,7 +15,7 @@ def __init__(self, dim: int = 0):
         self.mean, self.std = None, None
 
     @torch.no_grad()
-    def fit(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def fit(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         """
         Compute the mean and std to be used for later scaling.
 
diff --git a/pyg_spectral/nn/parse_args.py b/pyg_spectral/nn/parse_args.py
index b49d220..bceb08c 100644
--- a/pyg_spectral/nn/parse_args.py
+++ b/pyg_spectral/nn/parse_args.py
@@ -15,21 +15,21 @@ def update_regi(regi, new_regi):
 model_regi = BaseNN.register_classes()
 model_regi = update_regi(model_regi, model_regi_pyg)
 
-conv_regi = CallableDict.to_subcallableVal(conv_regi, ['pargs_default', 'param'])
+conv_regi = CallableDict.to_callableVal(conv_regi, ['pargs_default', 'param'])
 r'''Fields:
     * name (CallableDict[str, str]): Conv class logging path name.
-    * pargs (CallableDict[str, List[str]]): Conv arguments from argparse.
-    * pargs_default (Dict[str, CallableDict[str, Any]]): Default values for model arguments. Not recommended.
-    * param (Dict[str, CallableDict[str, ParamTuple]]): Conv parameters to tune.
+    * pargs (CallableDict[str, list[str]]): Conv arguments from argparse.
+    * pargs_default (dict[str, CallableDict[str, Any]]): Default values for model arguments. Not recommended.
+    * param (dict[str, CallableDict[str, ParamTuple]]): Conv parameters to tune.
 '''
-model_regi = CallableDict.to_subcallableVal(model_regi, ['pargs_default', 'param'])
+model_regi = CallableDict.to_callableVal(model_regi, ['pargs_default', 'param'])
 r'''Fields:
     name (CallableDict[str, str]): Model class logging path name.
    conv_name (CallableDict[str, Callable[[str, Any], str]]): Wrap conv logging path name.
    module (CallableDict[str, str]): Module for importing the model.
-    pargs (CallableDict[str, List[str]]): Model arguments from argparse.
-    pargs_default (Dict[str, CallableDict[str, Any]]): Default values for model arguments. Not recommended.
-    param (Dict[str, CallableDict[str, ParamTuple]]): Model parameters to tune.
+    pargs (CallableDict[str, list[str]]): Model arguments from argparse.
+    pargs_default (dict[str, CallableDict[str, Any]]): Default values for model arguments. Not recommended.
+    param (dict[str, CallableDict[str, ParamTuple]]): Model parameters to tune.
 '''
 
 full_pargs = set(v for pargs in conv_regi['pargs'].values() for v in pargs)
@@ -112,7 +112,7 @@ def get_nn_name(model: str, conv: str, args) -> str:
         conv: Input argparse string of conv. Can be composed.
         args: Additional arguments specified in module :attr:`name` functions.
     Returns:
-        nn_name (Tuple[str]): Name strings ``(model_name, conv_name)``.
+        nn_name (tuple[str]): Name strings ``(model_name, conv_name)``.
     """
     model_name = model_regi['name'](model, args)
     conv_name = [conv_regi['name'](channel, args) for channel in conv.split(',')]
diff --git a/pyg_spectral/utils/dropout.py b/pyg_spectral/utils/dropout.py
index 5a81f36..dc419db 100644
--- a/pyg_spectral/utils/dropout.py
+++ b/pyg_spectral/utils/dropout.py
@@ -1,5 +1,3 @@
-from typing import Tuple
-
 import torch
 from torch import Tensor
 
@@ -9,7 +7,7 @@
 
 def dropout_edge(edge_index: Tensor, p: float = 0.5,
                  force_undirected: bool = False,
-                 training: bool = True) -> Tuple[Tensor, Tensor]:
+                 training: bool = True) -> tuple[Tensor, Tensor]:
     r"""Random inplace edge dropout for the adjacency matrix
     :obj:`edge_index` with probability :obj:`p` using samples from
     a Bernoulli distribution.
diff --git a/pyg_spectral/utils/laplacian.py b/pyg_spectral/utils/laplacian.py
index b00f190..d9ae3ba 100755
--- a/pyg_spectral/utils/laplacian.py
+++ b/pyg_spectral/utils/laplacian.py
@@ -1,5 +1,3 @@
-from typing import Optional, Tuple, Union
-
 import torch
 from torch import Tensor
 
@@ -11,11 +9,11 @@
 def get_laplacian(
     edge_index: Adj,
     edge_weight: OptTensor = None,
-    normalization: Optional[bool] = None,
+    normalization: bool | None = None,
     diag: float = 1.0,
-    dtype: Optional[torch.dtype] = None,
-    num_nodes: Optional[int] = None,
-) -> Union[Tuple[Tensor, Tensor], SparseTensor]:
+    dtype: torch.dtype | None = None,
+    num_nodes: int | None = None,
+) -> tuple[Tensor, Tensor] | SparseTensor:
     r"""Computes the graph Laplacian of the graph given by :obj:`edge_index`
     and optional :obj:`edge_weight`.
     Remove the normalization of graph adjacency matrix in
diff --git a/pyg_spectral/utils/loader.py b/pyg_spectral/utils/loader.py
index bad9576..4956b9b 100755
--- a/pyg_spectral/utils/loader.py
+++ b/pyg_spectral/utils/loader.py
@@ -3,7 +3,7 @@
 
 
 def load_import(class_name, module_name):
-    r"""Simple dynamic import for 'module.class'"""
+    r"""Simple dynamic import for ``module.class``"""
     module = importlib.import_module(module_name)
     class_obj = getattr(module, class_name)
     if isinstance(class_obj, type):
@@ -16,14 +16,16 @@ def resolve_func(nargs=1):
     Args:
         nargs: The number of arguments to pass to the decorated function.
             Arguments beyond this number will be passed to the return function.
-    Examples:
-    ```python
+    Examples::
+
     @resolve_func(1)
     def foo(bar):
         return bar
-    foo(1) # 1
-    foo(lambda x: x+1, 2) # 3
-    ```
+
+    >>> foo(1)
+    1
+    >>> foo(lambda x: x+1, 2)
+    3
     """
     def decorator(func):
         @wraps(func)
@@ -40,7 +42,20 @@ def wrapper(*inputs):
 
 
 class CallableDict(dict):
+    """
+    A dictionary subclass that allows its values to be called as functions.
+    """
+
     def __call__(self, key, *args):
+        r"""Get key value and call it with args if it is callable.
+
+        Args:
+            key: The key to get the value from.
+            *args: Arguments to pass to the indexed value if it is callable.
+        Returns:
+            ret (Any | list): If the value is callable, returns the result of
+                calling it with args. Otherwise, returns the value.
+        """
         def _get_callable(key):
             ret = self.get(key, None)
             if callable(ret):
@@ -54,19 +69,28 @@ def _get_callable(key):
         return _get_callable(key)
 
     @classmethod
-    def to_callableVal(cls, dct, keys=None):
-        keys = keys or dct.keys()
-        for key in keys:
-            if isinstance(dct[key], dict):
-                dct[key] = cls(dct[key])
-        return dct
+    def to_callableVal(cls, dct, keys:list=None, reckeys:list=[]):
+        r"""Converts the sub-dictionaries of the specified keys in the
+        dictionary to :class:`CallableDict`.
 
-    @classmethod
-    def to_subcallableVal(cls, dct, keys=[]):
+        Args:
+            dct (dict): The dictionary to convert.
+            keys (list[str]): The keys to convert. If None, converts all sub-dictionaries.
+            reckeys (list[str]): The keys to recursively convert sub-sub-dictionaries.
+        Returns:
+            dct (dict): The dictionary with the specified keys converted to :class:`CallableDict`.
+        Examples::
+
+            dct = {'key0': Dict0, 'key1': Dict1, 'key2': Dict2}
+            dct = CallableDict.to_callableVal(dct, keys=['key1'], reckeys=['key2'])
+            # dct = {'key0': Dict0, 'key1': CallableDict1, 'key2': Dict2},
+            # and each sub-dictionary in 'key2' is converted to CallableDict.
+        """
+        keys = keys or dct.keys()
         for key in dct:
-            if key in keys:
+            if key in reckeys:
                 dct[key] = cls.to_callableVal(dct[key])
-            else:
+            elif key in keys:
                 if isinstance(dct[key], dict):
                     dct[key] = cls(dct[key])
         return dct