
Commit 33784d1

init
0 parents  commit 33784d1

391 files changed: +12477 -0 lines changed


.gitignore (+17)

@@ -0,0 +1,17 @@
+.idea/**
+storage
+docs/build
+/docs/source/_autosummary/
+**/__pycache__/
+/GOOD/kernel/auto_launch.py
+/GOOD/kernel/auto_chart.py
+/GOOD/kernel/auto_chart_final.py
+/GOOD/kernel/auto_chart_paper.py
+/GOOD/kernel/auto_chart_tvt.py
+/GOOD/kernel/auto_checkpoint.py
+/GOOD/kernel/auto_curve.py
+/GOOD/kernel/auto_curve_chart.py
+/configs/auto_config_finetune/
+/graphEx/
+/debug_log.py
+/GOOD/kernel/launchers/ada_launcher.py

GOOD/__init__.py (+3)

@@ -0,0 +1,3 @@
+from .utils import config_summoner, args_parser
+from .utils.register import register
+from . import data, networks, ood_algorithms
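These three imports define the package's public entry points: argument parsing, config construction, and the global register that the data and network modules hook into. A minimal usage sketch of how they are presumably wired together in an entry script; the CLI flags and the YAML config they point at are assumptions, not shown in this commit:

# Minimal sketch, assuming the usual parse-args-then-build-config flow;
# the concrete CLI flags (e.g. a path to a YAML config) are assumptions here.
from GOOD import args_parser, config_summoner

args = args_parser()            # parse command-line arguments
config = config_summoner(args)  # merge the parsed args and config file into one config object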

GOOD/data/__init__.py (+6)

@@ -0,0 +1,6 @@
+r"""
+This data module includes 11 GOOD datasets and a dataloader for an organized data loading process.
+"""
+from GOOD.data.dataset_manager import load_dataset, create_dataloader
+from .good_datasets import *
+from .good_loaders import *

GOOD/data/dataset_manager.py (+129)

@@ -0,0 +1,129 @@
+r"""A module that consists of a dataset loading function and a PyTorch dataloader loading function.
+"""
+
+from torch_geometric.loader import DataLoader, GraphSAINTRandomWalkSampler
+
+from GOOD import register
+from GOOD.utils.config_reader import Union, CommonArgs, Munch
+from GOOD.utils.initial import reset_random_seed
+
+
+def read_meta_info(meta_info, config: Union[CommonArgs, Munch]):
+    config.dataset.dataset_type = meta_info.dataset_type
+    config.model.model_level = meta_info.model_level
+    config.dataset.dim_node = meta_info.dim_node
+    config.dataset.dim_edge = meta_info.dim_edge
+    config.dataset.num_envs = meta_info.num_envs
+    config.dataset.num_classes = meta_info.num_classes
+    config.dataset.num_train_nodes = meta_info.get('num_train_nodes')
+    config.dataset.num_domains = meta_info.get('num_domains')
+    config.dataset.feat_dims = meta_info.get('feat_dims')
+    config.dataset.edge_feat_dims = meta_info.get('edge_feat_dims')
+
+
+def load_dataset(name: str, config: Union[CommonArgs, Munch]):
+    r"""
+    Load a dataset given the dataset name.
+
+    Args:
+        name (str): Dataset name.
+        config (Union[CommonArgs, Munch]): Required configs:
+            ``config.dataset.dataset_root``
+            ``config.dataset.domain``
+            ``config.dataset.shift_type``
+            ``config.dataset.generate``
+
+    Returns:
+        A dataset object and new configs
+            - config.dataset.dataset_type
+            - config.model.model_level
+            - config.dataset.dim_node
+            - config.dataset.dim_edge
+            - config.dataset.num_envs
+            - config.dataset.num_classes
+
+    """
+    try:
+        reset_random_seed(config)
+        dataset, meta_info = register.datasets[name].load(dataset_root=config.dataset.dataset_root,
+                                                          domain=config.dataset.domain,
+                                                          shift=config.dataset.shift_type,
+                                                          generate=config.dataset.generate)
+    except KeyError as e:
+        print('Dataset not found.')
+        raise e
+    read_meta_info(meta_info, config)
+
+    config.metric.set_score_func(dataset['metric'] if type(dataset) is dict else getattr(dataset, 'metric'))
+    config.metric.set_loss_func(dataset['task'] if type(dataset) is dict else getattr(dataset, 'task'))
+
+    return dataset
+
+
+def create_dataloader(dataset, config: Union[CommonArgs, Munch]):
+    r"""
+    Create a PyG data loader.
+
+    Args:
+        dataset: A GOOD dataset.
+        config: Required configs:
+            ``config.dataset.dataloader_name``
+            ``config.train.train_bs``
+            ``config.train.val_bs``
+            ``config.train.test_bs``
+            ``config.model.model_layer``
+            ``config.train.num_steps`` (for node prediction)
+
+    Returns:
+        A PyG dataset loader.
+
+    """
+    loader_name = config.dataset.dataloader_name
+    try:
+        reset_random_seed(config)
+        loader = register.dataloader[loader_name].setup(dataset, config)
+    except KeyError as e:
+        print(f'DataLoader {loader_name} not found.')
+        raise e
+
+    return loader
+
+
+def domain_pair_dataloader(dataset, config: Union[CommonArgs, Munch]):
+    r"""
+    Create a PyG domain_pair data loader.
+
+    Args:
+        dataset: A GOOD dataset.
+        config: Required configs:
+            ``config.train.train_bs``
+            ``config.train.val_bs``
+            ``config.train.test_bs``
+            ``config.model.model_layer``
+            ``config.train.num_steps`` (for node prediction)
+
+    Returns:
+        A PyG domain_pair dataset loader.
+
+    """
+    reset_random_seed(config)
+    if config.model.model_level == 'node':
+        graph = dataset[0]
+        loader = GraphSAINTRandomWalkSampler(graph, batch_size=config.train.train_bs,
+                                             walk_length=config.model.model_layer,
+                                             num_steps=config.train.num_steps, sample_coverage=100,
+                                             save_dir=dataset.processed_dir)
+        loader = {'train': loader, 'eval_train': [graph], 'id_val': [graph], 'id_test': [graph], 'val': [graph],
+                  'test': [graph]}
+    else:
+        loader = {'train': DataLoader(dataset['train'], batch_size=config.train.train_bs, shuffle=True),
+                  'eval_train': DataLoader(dataset['train'], batch_size=config.train.val_bs, shuffle=False),
+                  'id_val': DataLoader(dataset['id_val'], batch_size=config.train.val_bs, shuffle=False) if dataset.get(
+                      'id_val') else None,
+                  'id_test': DataLoader(dataset['id_test'], batch_size=config.train.test_bs,
+                                        shuffle=False) if dataset.get(
+                      'id_test') else None,
+                  'val': DataLoader(dataset['val'], batch_size=config.train.val_bs, shuffle=False),
+                  'test': DataLoader(dataset['test'], batch_size=config.train.test_bs, shuffle=False)}
+
+    return loader
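Putting the two public functions above together, a hedged usage sketch: it assumes config_summoner returns a config carrying the fields listed in the docstrings plus the config.metric helpers that load_dataset calls, and that the dataset name is stored under config.dataset.dataset_name.

# Usage sketch only: the field names follow the docstrings above; the dataset-name
# key and the split keys of the returned loader dict are assumptions.
from GOOD import args_parser, config_summoner
from GOOD.data import load_dataset, create_dataloader

args = args_parser()
config = config_summoner(args)

dataset = load_dataset(config.dataset.dataset_name, config)  # also writes meta info back into config
loaders = create_dataloader(dataset, config)                 # typically a dict keyed by split

for data in loaders['train']:
    pass  # training step goes here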

GOOD/data/good_datasets/__init__.py (+15)

@@ -0,0 +1,15 @@
+r"""
+This module includes 11 GOOD datasets.
+
+- Graph prediction datasets: GOOD-HIV, GOOD-PCBA, GOOD-ZINC, GOOD-SST2, GOOD-CMNIST, GOOD-Motif.
+- Node prediction datasets: GOOD-Cora, GOOD-Arxiv, GOOD-Twitch, GOOD-WebKB, GOOD-CBAS.
+"""
+
+import glob
+from os.path import dirname, basename, isfile, join
+
+modules = glob.glob(join(dirname(__file__), "*.py"))
+__all__ = [basename(f)[:-3] for f in modules if isfile(f) and not f.endswith('__init__.py')]
+
+from . import *
+
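Because __all__ is built from a glob and then wildcard-imported, any dataset module dropped into this package is imported automatically, which is presumably what populates the register.datasets lookup used by load_dataset in dataset_manager.py. A hedged sketch of the shape such a module might take; the decorator name and the class are assumptions, not part of this commit:

# Hypothetical dataset module that the wildcard import above would pick up.
# `register.dataset_register` is an assumed decorator name; this commit only shows
# that registered datasets are looked up via register.datasets[name] and must
# expose a static load() returning (dataset, meta_info).
from GOOD import register
from torch_geometric.data import InMemoryDataset


@register.dataset_register
class GOODToy(InMemoryDataset):  # hypothetical dataset, for illustration only
    @staticmethod
    def load(dataset_root, domain, shift='no_shift', generate=False):
        dataset, meta_info = ..., ...  # build the splits and the meta-info object here
        return dataset, meta_info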
