Skip to content

Commit 37702b0

Browse files
committed
LECI: beta
1 parent 6592e21 commit 37702b0

File tree

273 files changed

+2912
-22
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

273 files changed

+2912
-22
lines changed
+356
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,356 @@
1+
"""
2+
The GOOD-Motif dataset motivated by `Spurious-Motif
3+
<https://arxiv.org/abs/2201.12872>`_.
4+
"""
5+
import math
6+
import os
7+
import os.path as osp
8+
import random
9+
10+
import gdown
11+
import torch
12+
from munch import Munch
13+
from torch_geometric.data import InMemoryDataset, extract_zip
14+
from torch_geometric.utils import from_networkx
15+
from tqdm import tqdm
16+
17+
from GOOD import register
18+
from GOOD.utils.synthetic_data.BA3_loc import *
19+
from GOOD.utils.synthetic_data import synthetic_structsim
20+
21+
22+
@register.dataset_register
class FPIIFMotif(InMemoryDataset):
    r"""
    Synthetic motif dataset with concept, FIIF and PIIF splits, motivated by `Spurious-Motif
    <https://arxiv.org/abs/2201.12872>`_.

    Every graph is built by :func:`synthetic_structsim.build_graph` from one basis type (``all_basis``) and one
    motif (``all_motifs``). The label is the motif index, with 10% label noise.

    Args:
        root (str): The dataset saving root.
        domain (str): The domain selection. Only 'basis' can be generated; :meth:`process` raises ``ValueError``
            for any other value.
        shift (str): The distributional shift we pick. Recognised: 'concept', 'FIIF', and 'PIIF'. Any other value
            (including the default 'no_shift') selects the 'concept' files.
        subset (str): The split set. Recognised: 'train', 'val', 'test', 'id_val', and 'id_test'. Any other value
            selects 'id_test'.
        transform: Forwarded unchanged to :class:`InMemoryDataset`.
        pre_transform: Forwarded unchanged to :class:`InMemoryDataset`.
        generate (bool): The flag for regenerating dataset. True: skip the download step. False: download.
    """

    def __init__(self, root: str, domain: str, shift: str = 'no_shift', subset: str = 'train', transform=None,
                 pre_transform=None, generate: bool = False):

        # Everything below must be set before super().__init__(), which reads these attributes through
        # _download(), processed_dir and process().
        self.name = self.__class__.__name__
        self.domain = domain
        self.metric = 'Accuracy'
        self.task = 'Multi-label classification'
        self.url = ''

        self.generate = generate

        self.all_basis = ["wheel", "tree", "ladder", "circular_ladder", "dorogovtsev_goltsev_mendes", "star", "path"]
        # Role ids up to (and including) this value are relabelled as non-motif nodes in gen_data().
        self.basis_role_end = {'wheel': 0, 'tree': 0, 'ladder': 0, 'circular_ladder': 0,
                               'dorogovtsev_goltsev_mendes': 0, 'star': 1, 'path': 1}
        self.all_motifs = [[["house"]], [["dircycle"]], [["crane"]]]
        self.num_data = 3000

        super().__init__(root, transform, pre_transform)

        # processed_file_names is laid out as 3 shifts x 5 subsets; unknown values keep their historical
        # fallbacks (shift -> 'concept', subset -> 'id_test').
        shift_offset = {'concept': 0, 'FIIF': 5, 'PIIF': 10}.get(shift, 0)
        subset_offset = {'train': 0, 'val': 1, 'test': 2, 'id_val': 3}.get(subset, 4)
        subset_pt = shift_offset + subset_offset

        self.data, self.slices = torch.load(self.processed_paths[subset_pt])

    @property
    def raw_dir(self):
        """Directory holding the raw files; this is the dataset root itself."""
        return osp.join(self.root)

    def _download(self):
        """Call :meth:`download` unless ``<raw_dir>/<name>`` already exists or ``generate`` is set."""
        if os.path.exists(osp.join(self.raw_dir, self.name)) or self.generate:
            return
        if not os.path.exists(self.raw_dir):
            os.makedirs(self.raw_dir)
        self.download()

    def download(self):
        """
        Fetch ``<name>.zip`` from ``self.url`` with gdown, extract it into ``raw_dir`` and delete the archive.

        Raises:
            RuntimeError: If no URL is configured or gdown does not return a file path.
        """
        if not self.url:
            # self.url is hard-coded to '' in __init__, so fail with an actionable message instead of an
            # opaque error from gdown.
            raise RuntimeError(f'No download URL is configured for {self.name}; '
                               f'instantiate the dataset with generate=True to build it locally.')
        path = gdown.download(self.url, output=osp.join(self.raw_dir, self.name + '.zip'), fuzzy=True)
        if path is None:
            raise RuntimeError(f'Downloading {self.name} from "{self.url}" failed.')
        extract_zip(path, self.raw_dir)
        os.unlink(path)

    @property
    def processed_dir(self):
        """Directory of the processed ``.pt`` files: ``<root>/<name>/<domain>/processed``."""
        return osp.join(self.root, self.name, self.domain, 'processed')

    @property
    def processed_file_names(self):
        """The 15 processed files, grouped by shift and ordered train, val, test, id_val, id_test."""
        return [f'{shift}_{subset}.pt'
                for shift in ('concept', 'FIIF', 'PIIF')
                for subset in ('train', 'val', 'test', 'id_val', 'id_test')]

    def gen_data(self, basis_id, width_basis, motif_id, y=None):
        """
        Generate one graph sample.

        Args:
            basis_id (int): Index into ``self.all_basis``.
            width_basis (int): Requested basis width; rescaled logarithmically for the 'tree' and
                'dorogovtsev_goltsev_mendes' bases.
            motif_id (int): Index into ``self.all_motifs``.
            y (int, optional): Label to assign. If None, the label is ``motif_id`` with 10% label noise.

        Returns:
            A PyG data object with ``x``, ``node_gt``, ``edge_gt``, ``basis_id``, ``motif_id`` and ``y`` set.
        """
        basis_type = self.all_basis[basis_id]
        if basis_type == 'tree':
            width_basis = max(int(math.log2(width_basis)) - 1, 1)
        elif basis_type == 'dorogovtsev_goltsev_mendes':
            width_basis = max(math.ceil(math.log(width_basis, 3)), 1)

        list_shapes = self.all_motifs[motif_id]
        G, role_id, _ = synthetic_structsim.build_graph(
            width_basis, basis_type, list_shapes, start=0, rdm_basis_plugins=True
        )
        G = perturb([G], 0.05, id=role_id)[0]

        # --- Convert networkx graph into pyg data ---
        data = from_networkx(G)
        data.x = torch.ones((data.num_nodes, 1))

        # Binarise the roles: 0 for roles up to the basis' role end, 1 for everything else.
        role_id = torch.tensor(role_id, dtype=torch.long)
        role_id[role_id <= self.basis_role_end[basis_type]] = 0
        role_id[role_id != 0] = 1

        # An edge is ground truth only when both of its endpoints are marked 1.
        edge_gt = torch.stack([role_id[data.edge_index[0]], role_id[data.edge_index[1]]]).sum(0) > 1.5

        data.node_gt = role_id
        data.edge_gt = edge_gt
        data.basis_id = basis_id
        data.motif_id = motif_id

        # --- noisy labels ---
        if y is not None:
            data.y = y
        elif random.random() < 0.1:
            data.y = random.randint(0, 2)
        else:
            data.y = motif_id

        return data

    @staticmethod
    def _split_in_distribution(train_list, num_held_out):
        """
        Shuffle ``train_list`` in place and carve id_val / id_test (``num_held_out`` samples each) off its tail.

        Explicit non-negative indices are used because ``lst[:-0]`` is empty: negative slicing would empty the
        training set (and move everything into id_test) whenever ``num_held_out`` is 0.

        Returns:
            Tuple ``(train_list, id_val_list, id_test_list)``.
        """
        random.shuffle(train_list)
        total = len(train_list)
        val_start = max(total - 2 * num_held_out, 0)
        test_start = max(total - num_held_out, 0)
        return train_list[:val_start], train_list[val_start:test_start], train_list[test_start:]

    def _gen_concept_sample(self, spurious_ratio):
        """
        Draw one concept-shift sample whose basis equals its motif index with probability ``spurious_ratio``.

        Returns:
            Tuple ``(data, basis_id)``.
        """
        motif_id = random.randint(0, 2)
        # np.random.random_integers(-5, 5) is deprecated; its warning recommends randint(low, high + 1),
        # which keeps the same inclusive range.
        width_basis = 10 + np.random.randint(-5, 5 + 1)
        if random.random() < spurious_ratio:
            basis_id = motif_id
        else:
            basis_id = random.randint(0, 2)
        data = self.gen_data(basis_id=basis_id, width_basis=width_basis, motif_id=motif_id)
        return data, basis_id

    def get_basis_concept_list(self, num_data=60000):
        """
        Build the concept-shift splits, where the basis type is spuriously correlated with the motif.

        Train / val / test take 60% / 20% / 20% of ``num_data``. The training part is generated in equal shares
        with spurious ratios 0.99, 0.97 and 0.95; validation uses 0.3 and test 0.0. After shuffling, 15% of the
        training samples go to id_val and another 15% to id_test. Only training-derived samples carry ``env_id``
        (set to the basis index).

        Returns:
            ``[train_list, val_list, test_list, id_val_list, id_test_list]``.
        """
        train_ratio = 0.6
        val_ratio = 0.2
        test_ratio = 0.2
        num_train = int(num_data * train_ratio)
        num_val = int(num_data * val_ratio)
        num_test = int(num_data * test_ratio)
        train_spurious_ratio = [0.99, 0.97, 0.95]
        val_spurious_ratio = [0.3]
        test_spurious_ratio = [0.0]

        train_list = []
        num_per_ratio = num_train // len(train_spurious_ratio)
        for spurious_ratio in tqdm(train_spurious_ratio):
            for _ in range(num_per_ratio):
                data, basis_id = self._gen_concept_sample(spurious_ratio)
                data.env_id = torch.LongTensor([basis_id])
                train_list.append(data)

        val_list = [self._gen_concept_sample(val_spurious_ratio[0])[0] for _ in range(num_val)]
        test_list = [self._gen_concept_sample(test_spurious_ratio[0])[0] for _ in range(num_test)]

        id_test_ratio = 0.15
        num_id_test = int(len(train_list) * id_test_ratio)
        train_list, id_val_list, id_test_list = self._split_in_distribution(train_list, num_id_test)

        all_env_list = [train_list, val_list, test_list, id_val_list, id_test_list]

        return all_env_list

    def _get_basis_width_shift_list(self, num_data, label_controls_width):
        """
        Shared generator behind :meth:`get_basis_FIIF_list` and :meth:`get_basis_PIIF_list`.

        Train / val / test take 80% / 10% / 10% of ``num_data``. Training samples draw their basis from indices
        0-2, validation always uses index 3 and test index 4. The basis width is centred on 10, 20 or 30
        (+/- 5), selected by the motif index, or by the noisy label when ``label_controls_width`` is True.

        Args:
            num_data (int): Total number of samples to generate.
            label_controls_width (bool): False: the motif picks the width and gen_data() draws the noisy label.
                True: the noisy label is drawn first and picks the width.

        Returns:
            ``[train_list, ood_val_list, ood_test_list, id_val_list, id_test_list]``.
        """
        train_ratio = 0.8
        val_ratio = 0.1
        test_ratio = 0.1
        split_num = [int(num_data * train_ratio), int(num_data * val_ratio), int(num_data * test_ratio)]
        all_width_basis = [10, 20, 30]
        all_split_list = [[] for _ in split_num]

        for split_id, num_samples in enumerate(split_num):
            for _ in range(num_samples):
                motif_id = random.randint(0, 2)
                basis_id = random.randint(0, 2) if split_id == 0 else split_id + 2

                if label_controls_width:
                    # --- get y (10% label noise), then let y control G_S's width ---
                    data_y = random.randint(0, 2) if random.random() < 0.1 else motif_id
                    width_key = data_y
                else:
                    # --- G_C controls G_S's width; gen_data() draws the noisy label itself ---
                    data_y = None
                    width_key = motif_id
                width_basis = all_width_basis[width_key] + random.randint(-5, 5)

                data = self.gen_data(basis_id=basis_id, width_basis=width_basis, motif_id=motif_id, y=data_y)
                data.env_id = torch.LongTensor([basis_id])
                all_split_list[split_id].append(data)

        num_id_test = int(num_data * test_ratio)
        train_list, id_val_list, id_test_list = self._split_in_distribution(all_split_list[0], num_id_test)

        ood_val_list = all_split_list[1]
        ood_test_list = all_split_list[2]

        return [train_list, ood_val_list, ood_test_list, id_val_list, id_test_list]

    def get_basis_FIIF_list(self, num_data=60000):
        """
        Build the FIIF splits, in which the motif index controls the basis width.

        Returns:
            ``[train_list, ood_val_list, ood_test_list, id_val_list, id_test_list]``.
        """
        return self._get_basis_width_shift_list(num_data, label_controls_width=False)

    def get_basis_PIIF_list(self, num_data=60000):
        """
        Build the PIIF splits, in which the (noisy) label controls the basis width.

        Returns:
            ``[train_list, ood_val_list, ood_test_list, id_val_list, id_test_list]``.
        """
        return self._get_basis_width_shift_list(num_data, label_controls_width=True)

    def process(self):
        """
        Generate all concept, FIIF and PIIF splits and save them to ``processed_paths``.

        Raises:
            ValueError: If ``self.domain`` is not 'basis'.
        """
        if self.domain != 'basis':
            raise ValueError(f'Dataset domain cannot be "{self.domain}"')

        concept_shift_list = self.get_basis_concept_list(self.num_data)
        print("#IN#concept shift done!")
        FIIF_shift_list = self.get_basis_FIIF_list(self.num_data)
        print("#IN#FIIF shift done!")
        PIIF_shift_list = self.get_basis_PIIF_list(self.num_data)
        print("#IN#PIIF shift done!")

        # The concatenation order matches processed_file_names: 5 concept, 5 FIIF, then 5 PIIF files.
        all_shift_list = concept_shift_list + FIIF_shift_list + PIIF_shift_list
        for i, final_data_list in enumerate(all_shift_list):
            data, slices = self.collate(final_data_list)
            torch.save((data, slices), self.processed_paths[i])

    @staticmethod
    def load(dataset_root: str, domain: str, shift: str = 'no_shift', generate: bool = False):
        r"""
        A staticmethod for dataset loading. This method instantiates dataset class, constructing train, id_val,
        id_test, ood_val (val), and ood_test (test) splits. Besides, it collects several dataset meta information
        for further utilization.

        Args:
            dataset_root (str): The dataset saving root.
            domain (str): The domain selection. Only 'basis' can be generated.
            shift (str): The distributional shift we pick. Recognised: 'concept', 'FIIF', and 'PIIF'. With
                'no_shift' the id_val / id_test splits are None and the other splits come from the 'concept' files.
            generate (bool): The flag for regenerating dataset. True: regenerate. False: download.

        Returns:
            dataset or dataset splits.
            dataset meta info.
        """
        meta_info = Munch()
        meta_info.dataset_type = 'syn'
        meta_info.model_level = 'graph'

        def build(subset):
            return FPIIFMotif(root=dataset_root, domain=domain, shift=shift, subset=subset, generate=generate)

        has_id_splits = shift != 'no_shift'
        train_dataset = build('train')
        id_val_dataset = build('id_val') if has_id_splits else None
        id_test_dataset = build('id_test') if has_id_splits else None
        val_dataset = build('val')
        test_dataset = build('test')

        meta_info.dim_node = train_dataset.num_node_features
        meta_info.dim_edge = train_dataset.num_edge_features

        meta_info.num_envs = torch.unique(train_dataset.data.env_id).shape[0]

        # Define networks' output shape.
        if train_dataset.task == 'Binary classification':
            meta_info.num_classes = train_dataset.data.y.shape[1]
        elif train_dataset.task == 'Regression':
            meta_info.num_classes = 1
        elif train_dataset.task == 'Multi-label classification':
            meta_info.num_classes = torch.unique(train_dataset.data.y).shape[0]

        # --- clear buffer dataset._data_list ---
        # Compare against None explicitly: truthiness would call __len__ and skip empty datasets.
        for dataset in (train_dataset, id_val_dataset, id_test_dataset, val_dataset, test_dataset):
            if dataset is not None:
                dataset._data_list = None

        return {'train': train_dataset, 'id_val': id_val_dataset, 'id_test': id_test_dataset,
                'val': val_dataset, 'test': test_dataset, 'task': train_dataset.task,
                'metric': train_dataset.metric}, meta_info

GOOD/kernel/launch.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ def make_list_cmds(auto_args, conda_goodtg):
8888
cmd_args = [' '.join(args_set) for args_set in cmd_args_product]
8989
else:
9090
cmd_args = [
91-
f'{conda_goodtg} --exp_round {round} --config_path \"{ood_config_path}\" --log_file default'
91+
f'{conda_goodtg} --exp_round {round} --config_path \"{ood_config_path}\"'
9292
for round in auto_args.allow_rounds]
9393

9494
args_group += cmd_args

GOOD/kernel/launchers/harvest_launcher.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class HarvestLauncher(Launcher):
2323
def __init__(self):
2424
super(HarvestLauncher, self).__init__()
2525
self.watch = True
26-
self.pick_reference = [-1, -2]
26+
self.pick_reference = [-1]
2727

2828
def __call__(self, jobs_group, auto_args: AutoArgs):
2929
result_dict = self.harvest_all_fruits(jobs_group)

0 commit comments

Comments
 (0)