Commit 34a5c9b: Update docs

nyLiao committed Oct 6, 2024
1 parent 8a178a9 commit 34a5c9b
Showing 23 changed files with 176 additions and 162 deletions.
10 changes: 5 additions & 5 deletions benchmark/dataset/linkx.py
@@ -1,5 +1,5 @@
import os.path as osp
from typing import Any, Callable, Optional
from typing import Any, Callable
from argparse import Namespace

import numpy as np
@@ -46,8 +46,8 @@ def __init__(
self,
root: str,
name: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
transform: Callable | None = None,
pre_transform: Callable | None = None,
force_reload: bool = False,
) -> None:
self.name = name.lower()
@@ -156,8 +156,8 @@ def __init__(
self,
root: str,
name: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
transform: Callable | None = None,
pre_transform: Callable | None = None,
force_reload: bool = False,
) -> None:
self.name = name.lower()
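The change above is the template for most files in this commit: `Optional[X]` from `typing` is replaced by the PEP 604 union `X | None`, which requires Python 3.10 or newer (or `from __future__ import annotations`). A minimal sketch of the equivalence; the `load_dataset_*` helpers below are hypothetical and only illustrate the annotation style, they are not part of the repository:

```python
from typing import Callable, Optional


# Old style, as removed in this commit: Optional[...] from the typing module.
def load_dataset_old(root: str, transform: Optional[Callable] = None) -> str:
    return f"{root}: {'raw' if transform is None else 'transformed'}"


# New style, as introduced in this commit: built-in `X | None` (PEP 604, Python >= 3.10).
def load_dataset_new(root: str, transform: Callable | None = None) -> str:
    return f"{root}: {'raw' if transform is None else 'transformed'}"


print(load_dataset_old("data/linkx"))       # data/linkx: raw
print(load_dataset_new("data/linkx", abs))  # data/linkx: transformed
```

At runtime the two annotations are interchangeable; the newer form just drops the extra `typing` import, which is why `Optional` disappears from the first hunk.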
3 changes: 1 addition & 2 deletions benchmark/dataset/utils.py
@@ -1,4 +1,3 @@
from typing import Tuple
from argparse import Namespace
import numpy as np
from sklearn.model_selection import train_test_split
@@ -107,7 +106,7 @@ def split_crossval(label: torch.Tensor,
r_val: float,
seed: int = None,
ignore_neg: bool =True,
stratify: bool =False) -> Tuple[torch.Tensor]:
stratify: bool =False) -> tuple[torch.Tensor]:
r"""Split index by cross-validation"""
node_labeled = torch.where(label >= 0)[0] if ignore_neg else np.arange(label.shape[0])

6 changes: 3 additions & 3 deletions benchmark/dataset/yandex.py
@@ -1,5 +1,5 @@
import os.path as osp
from typing import Callable, Optional
from typing import Callable
from argparse import Namespace

import numpy as np
@@ -26,8 +26,8 @@ def __init__(
self,
root: str,
name: str,
transform: Optional[Callable] = None,
pre_transform: Optional[Callable] = None,
transform: Callable | None = None,
pre_transform: Callable | None = None,
force_reload: bool = False,
) -> None:
self.name = name.lower()
15 changes: 7 additions & 8 deletions benchmark/trainer/base.py
@@ -3,7 +3,6 @@
Author: nyLiao
File Created: 2024-03-03
"""
from typing import List
import logging
from argparse import Namespace

@@ -145,8 +144,8 @@ def wrapper(self, *args, **kwargs):
# ===== Run block
@_log_memory(split='train')
def train_val(self,
split_train: List[str] = ['train'],
split_val: List[str] = ['val']) -> ResLogger:
split_train: list[str] = ['train'],
split_val: list[str] = ['val']) -> ResLogger:
r"""Pipeline for iterative training.
Args:
@@ -187,7 +186,7 @@ def train_val(self,

@_log_memory(split='eval')
def test(self,
split_test: List[str] = ['train', 'val', 'test']) -> ResLogger:
split_test: list[str] = ['train', 'val', 'test']) -> ResLogger:
r"""Pipeline for testing.
Args:
split_test (list): Testing splits.
@@ -206,11 +205,11 @@ def _fetch_input(self) -> tuple:
r"""Process each sample of model input and label."""
raise NotImplementedError

def _learn_split(self, split: List[str] = ['train']) -> ResLogger:
def _learn_split(self, split: list[str] = ['train']) -> ResLogger:
r"""Actual train iteration on the given splits."""
raise NotImplementedError

def _eval_split(self, split: List[str]) -> ResLogger:
def _eval_split(self, split: list[str]) -> ResLogger:
r"""Actual test on the given splits."""
raise NotImplementedError

@@ -258,8 +257,8 @@ def clear(self):
if self.optimizer: del self.optimizer

def train_val(self,
split_train: List[str] = ['train'],
split_val: List[str] = ['val']) -> ResLogger:
split_train: list[str] = ['train'],
split_val: list[str] = ['val']) -> ResLogger:
import optuna

time_learn = profile.Accumulator()
3 changes: 1 addition & 2 deletions benchmark/trainer/fullbatch.py
@@ -3,7 +3,6 @@
Author: nyLiao
File Created: 2024-02-26
"""
from typing import Tuple
from argparse import Namespace

import torch
@@ -57,7 +56,7 @@ def clear(self):
del self.mask, self.data
return super().clear()

def _fetch_data(self) -> Tuple[Data, dict]:
def _fetch_data(self) -> tuple[Data, dict]:
r"""Process the single graph data."""
t_to_device = T.ToDevice(self.device, attrs=['x', 'y', 'adj_t', 'edge_index'] + [f'{k}_mask' for k in self.splits])
self.data = t_to_device(self.data)
4 changes: 2 additions & 2 deletions benchmark/trainer/load_metric.py
@@ -3,7 +3,7 @@
Author: nyLiao
File Created: 2024-03-03
"""
from typing import Tuple, List, Callable, Any
from typing import Callable, Any
from argparse import Namespace
from torchmetrics import MetricCollection
from torchmetrics.classification import (
@@ -15,7 +15,7 @@


class ResCollection(MetricCollection):
def compute(self) -> List[Tuple[str, Any, Callable]]:
def compute(self) -> list[tuple[str, Any, Callable]]:
r"""Wrap compute output to :class:`ResLogger` style."""
dct = self._compute_and_reduce("compute")
return [(k, v.cpu().numpy(), (lambda x: format(x*100, '.3f'))) for k, v in dct.items()]
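The new return annotation spells out that `compute()` yields `(name, value, formatter)` triples rather than the plain dict of a stock `MetricCollection`. A small, self-contained sketch of how entries in that shape might be consumed downstream; the entry names and values are made up for illustration and are not repository output:

```python
from typing import Any, Callable

# Hypothetical entries in the (name, value, formatter) shape of ResCollection.compute().
entries: list[tuple[str, Any, Callable]] = [
    ("accuracy", 0.8731, lambda x: format(x * 100, ".3f")),
    ("f1", 0.8412, lambda x: format(x * 100, ".3f")),
]

for name, value, fmt in entries:
    print(f"{name}: {fmt(value)}")  # accuracy: 87.310, then f1: 84.120
```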
11 changes: 5 additions & 6 deletions benchmark/trainer/load_model.py
@@ -3,7 +3,6 @@
Author: nyLiao
File Created: 2024-02-26
"""
from typing import Tuple
from argparse import Namespace
import logging
import torch.nn as nn
@@ -40,7 +39,7 @@ def __init__(self, args: Namespace, res_logger: ResLogger = None) -> None:
self.res_logger = res_logger or ResLogger()

@staticmethod
def get_name(args: Namespace) -> Tuple[str]:
def get_name(args: Namespace) -> tuple[str]:
"""Get model+conv name for logging path from argparse input without instantiation.
Wrapper for :func:`pyg_spectral.nn.get_nn_name()`.
@@ -51,7 +50,7 @@ def get_name(args: Namespace) -> Tuple[str]:
* args.conv (str): Convolution layer name.
* other args specified in module :attr:`name` function.
Returns:
nn_name (Tuple[str]): Name strings ``(model_name, conv_name)``.
nn_name (tuple[str]): Name strings ``(model_name, conv_name)``.
"""
return get_nn_name(args.model, args.conv, args)

@@ -75,7 +74,7 @@ def get_trn(args: Namespace) -> TrnBase:
'CppPrecFixed': TrnMinibatch,
}[model_repr]

def _resolve_import(self, args: Namespace) -> Tuple[str, str, dict]:
def _resolve_import(self, args: Namespace) -> tuple[str, str, dict]:
class_name = self.model
module_name = get_model_regi(self.model, 'module', args)
kwargs = set_pargs(self.model, self.conv, args)
@@ -103,7 +102,7 @@ def _resolve_import(self, args: Namespace) -> Tuple[str, str, dict]:
# <<<<<<<<<<
return class_name, module_name, kwargs

def get(self, args: Namespace) -> Tuple[nn.Module, TrnBase]:
def get(self, args: Namespace) -> tuple[nn.Module, TrnBase]:
r"""Load model with specified arguments.
Args:
@@ -143,7 +142,7 @@ def __str__(self) -> str:
class ModelLoader_Trial(ModelLoader):
r"""Reuse necessary data for multiple runs.
"""
def get(self, args: Namespace) -> Tuple[nn.Module, TrnBase]:
def get(self, args: Namespace) -> tuple[nn.Module, TrnBase]:
self.signature_lst = ['num_hops', 'in_layers', 'out_layers', 'hidden_channels', 'dropout_lin', 'dropout_conv']
self.signature = {key: getattr(args, key) for key in self.signature_lst}

2 changes: 1 addition & 1 deletion benchmark/trainer/minibatch.py
@@ -3,7 +3,7 @@
Author: nyLiao
File Created: 2024-03-03
"""
from typing import Tuple, Generator
from typing import Generator
import logging
from argparse import Namespace

3 changes: 1 addition & 2 deletions benchmark/trainer/regression.py
@@ -1,4 +1,3 @@
from typing import Tuple
from argparse import Namespace
import logging

@@ -63,7 +62,7 @@ def __init__(self, args: Namespace, res_logger: ResLogger = None) -> None:
])

# ===== Data acquisition
def _resolve_import(self, args: Namespace) -> Tuple[str, str, dict]:
def _resolve_import(self, args: Namespace) -> tuple[str, str, dict]:
assert self.data in ['2dgrid']
module_name = 'dataset'
class_name = 'Grid2D'
6 changes: 3 additions & 3 deletions benchmark/utils/checkpoint.py
@@ -3,7 +3,7 @@
Author: nyLiao
File Created: 2023-03-20
"""
from typing import Union, Callable
from typing import Callable
from pathlib import Path

import copy
@@ -26,12 +26,12 @@ class CkptLogger(object):
metric_cmp: Comparison function for the metric. Can be 'max' or 'min'.
"""
def __init__(self,
logpath: Union[Path, str],
logpath: Path | str,
patience: int = -1,
period: int = 0,
prefix: str = 'model',
storage: str = 'state_gpu',
metric_cmp: Union[Callable[[float, float], bool], str]='max'):
metric_cmp: Callable[[float, float], bool] | str='max'):
self.logpath = Path(logpath)
self.prefix = prefix
self.filetype = 'pth'
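The updated `metric_cmp` annotation documents that `CkptLogger` accepts either a comparison callable or one of the strings `'max'`/`'min'`. A minimal sketch of how such a dual-typed argument can be resolved; `resolve_cmp` is a hypothetical helper for illustration, not the implementation used inside `CkptLogger`:

```python
import operator
from typing import Callable


def resolve_cmp(metric_cmp: Callable[[float, float], bool] | str = 'max') -> Callable[[float, float], bool]:
    # Pass a ready-made comparator through, or map 'max'/'min' to a default one.
    if callable(metric_cmp):
        return metric_cmp
    return {'max': operator.gt, 'min': operator.lt}[metric_cmp]


is_better = resolve_cmp('max')
print(is_better(0.92, 0.90))  # True: with 'max', a higher metric counts as an improvement
```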
24 changes: 12 additions & 12 deletions benchmark/utils/logger.py
@@ -3,7 +3,7 @@
Author: nyLiao
File Created: 2023-03-20
"""
from typing import Tuple, List, Dict, Callable, Union, Any
from typing import Callable, Any
from pathlib import Path

import os
@@ -28,7 +28,7 @@
warnings.filterwarnings('ignore', '.*Attempting to set identical low and high ylims.*')


def setup_logger(logpath: Union[Path, str] = LOGPATH,
def setup_logger(logpath: Path | str = LOGPATH,
level_console: int = logging.DEBUG,
level_file: int = 15,
quiet: bool = True,
@@ -71,8 +71,8 @@ def clear_logger(logger: logging.Logger):
logger.info(f"[time]: {datetime.now()}")


def setup_logpath(dir: Union[Path, str] = LOGPATH,
folder_args: Tuple=None,
def setup_logpath(dir: Path | str = LOGPATH,
folder_args: tuple=None,
quiet: bool = True):
r"""Resolve log path for saving.
@@ -103,7 +103,7 @@ class ResLogger(object):
quiet: Quiet run without saving file.
"""
def __init__(self,
logpath: Union[Path, str] = LOGPATH,
logpath: Path | str = LOGPATH,
prefix: str = 'summary',
suffix: str = None,
quiet: bool = True):
@@ -163,7 +163,7 @@ def _set(self, data: DataFrame, fmt: Series):
self.fmt = self.fmt.combine_first(fmt)

def concat(self,
vals: Union[List[Tuple[str, Any, Callable]], Dict],
vals: list[tuple[str, Any, Callable]] | dict,
row: int = 0,
suffix: str = None):
r"""Concatenate data entries of a single row to data.
@@ -203,7 +203,7 @@ def __call__(self, *args, **kwargs) -> 'ResLogger':

def merge(self,
logger: 'ResLogger',
rows: List[int] = None,
rows: list[int] = None,
suffix: str = None):
r"""Merge from another logger.
@@ -222,7 +222,7 @@ def merge(self,
self._set(logger.data, logger.fmt)
return self

def del_col(self, col: Union[List, str]) -> 'ResLogger':
def del_col(self, col: list | str) -> 'ResLogger':
r"""Delete columns from data.
Args:
@@ -234,8 +234,8 @@ def del_col(self, col: Union[List, str]) -> 'ResLogger':

# ===== Output
def _get(self,
col: Union[List, str]=None,
row: Union[List, str]=None) -> Union[DataFrame, Series, str]:
col: list | str=None,
row: list | str=None) -> DataFrame | Series | str:
r"""Retrieve one or sliced data and apply string format.
Args:
@@ -289,8 +289,8 @@ def save(self):
mode='a', header=f.tell()==0)

def get_str(self,
col: Union[List, str] = None,
row: Union[List, int] = None,
col: list | str = None,
row: list | int = None,
maxlen: int = -1) -> str:
r"""Get formatted long string for printing of the specified columns
and rows.
25 changes: 14 additions & 11 deletions pyg_spectral/nn/conv/base_mp.py
@@ -1,5 +1,5 @@
from typing import Any, Callable, Dict, List, Tuple, Optional
type ParamTuple = Tuple[str, tuple, Dict[str, Any], Callable[[Any], str]]
from typing import Any, Callable, NewType
ParamTuple = NewType('ParamTuple', tuple[str, tuple, dict[str, Any], Callable[[Any], str]])
import re

import torch
@@ -27,25 +27,28 @@ class BaseMP(MessagePassing):
supports_batch: bool = True
supports_norm_batch: bool = True
name: Callable[[Any], str] # args -> str
pargs: List[str] = []
param: Dict[str, ParamTuple] = {}
_cache: Optional[Any]
pargs: list[str] = []
param: dict[str, ParamTuple] = {}
_cache: Any | None

@classmethod
def register_classes(cls, registry: Dict[str, Dict[str, Any]] = CONV_REGI_INIT):
def register_classes(cls, registry: dict[str, dict[str, Any]] = None) -> dict:
r"""Register args for all subclass.
Args:
name (Dict[str, str]): Conv class logging path name.
pargs (Dict[str, List[str]]): Conv arguments from argparse.
pargs_default (Dict[str, Dict[str, Any]]): Default values for model arguments. Not recommended.
param (Dict[str, Dict[str, ParamTuple]]): Conv parameters to tune.
name (dict[str, str]): Conv class logging path name.
pargs (dict[str, list[str]]): Conv arguments from argparse.
pargs_default (dict[str, dict[str, Any]]): Default values for model arguments. Not recommended.
param (dict[str, dict[str, ParamTuple]]): Conv parameters to tune.
* (str) parameter type,
* (tuple) args for :func:`optuna.trial.suggest_<type>`,
* (dict) kwargs for :func:`optuna.trial.suggest_<type>`,
* (callable) format function to str.
"""
if registry is None:
registry = CONV_REGI_INIT

for subcls in cls.__subclasses__():
subname = subcls.__name__
# Traverse the MRO and accumulate args from parent classes
@@ -186,7 +189,7 @@ def _get_convolute_mat(self, x: Tensor, edge_index: Adj) -> dict:
def get_forward_mat(self,
x: Tensor,
edge_index: Adj,
comp_scheme: Optional[str] = None
comp_scheme: str | None = None
) -> dict:
r"""Get matrices for :meth:`forward()`. Called during :meth:`forward()`.
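Two things change in `base_mp.py` beyond the usual typing cleanup: the `ParamTuple` alias moves off the PEP 695 `type` statement (which needs Python 3.12) onto a `NewType` call, presumably for compatibility with older interpreters, and `register_classes` swaps its module-level dict default for `None` plus an explicit guard. A generic sketch of what the second change can buy, assuming the motivation is late binding of the default; the registry and functions below are stand-ins, not the repository's actual `CONV_REGI_INIT` or `register_classes`:

```python
from typing import Any

CONV_REGI_INIT: dict[str, Any] = {}  # stand-in for a module-level registry


def register_early(name: str, registry: dict[str, Any] = CONV_REGI_INIT) -> dict:
    # The default object is captured once, when the function is defined.
    registry[name] = True
    return registry


def register_late(name: str, registry: dict[str, Any] | None = None) -> dict:
    # None-plus-guard: the default is looked up at call time instead.
    if registry is None:
        registry = CONV_REGI_INIT
    registry[name] = True
    return registry


CONV_REGI_INIT = {}       # rebind the registry, e.g. to reset state between runs
register_early("ConvA")   # writes into the stale, pre-rebind dict
register_late("ConvB")    # writes into the current CONV_REGI_INIT
print(CONV_REGI_INIT)     # {'ConvB': True}
```

The sketch shows one plausible benefit; when the registry is never rebound, the commit's behaviour is unchanged.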