-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdataloading.py
executable file
·92 lines (69 loc) · 2.62 KB
/
dataloading.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import os
import ssl
from six.moves import urllib
import torch
import numpy as np
import dgl
from torch.utils.data import Dataset, DataLoader
def download_file(dataset):
    """Download ``dataset`` from the DGL S3 bucket into ``./data/``.

    The payload is streamed into ``./data/<dataset>.part`` and renamed to
    ``./data/<dataset>`` only after the transfer finished, so an interrupted
    download never leaves a truncated file under the final name.

    Args:
        dataset: File name appended to the bucket URL; also used as the
            local file name inside ``./data/``. The ``./data`` directory
            must already exist.

    Raises:
        Whatever ``urllib.request.urlopen`` or the file operations raise;
        the partial file is removed before the exception propagates.
    """
    import shutil  # local import: only needed by this function

    print("Start Downloading data: {}".format(dataset))
    url = "https://s3.us-west-2.amazonaws.com/dgl-data/dataset/{}".format(
        dataset)
    target_path = "./data/{}".format(dataset)
    partial_path = target_path + ".part"
    print("Start Downloading File....")
    # NOTE(review): certificate verification is disabled by this context,
    # so the download is not protected against man-in-the-middle tampering.
    # Kept as-is to avoid breaking existing environments; consider
    # ssl.create_default_context() instead.
    context = ssl._create_unverified_context()
    try:
        response = urllib.request.urlopen(url, context=context)
        try:
            with open(partial_path, "wb") as handle:
                # Stream in chunks instead of buffering the whole file.
                shutil.copyfileobj(response, handle)
        finally:
            response.close()
        os.replace(partial_path, target_path)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a half-written file.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
class SnapShotDataset(Dataset):
    """Dataset over the ``x`` and ``y`` arrays stored in an ``.npz`` archive.

    If ``path/npz_file`` does not exist, ``path`` is created when missing and
    ``download_file(npz_file)`` is called before the archive is loaded.

    Args:
        path: Directory expected to contain the archive.
        npz_file: File name of the ``.npz`` archive inside ``path``.

    NOTE(review): ``download_file`` receives only the file name, so the
    download location is decided by that function and not by ``path`` --
    confirm both agree before passing a non-default ``path``.
    """

    def __init__(self, path, npz_file):
        archive_path = os.path.join(path, npz_file)
        if not os.path.exists(archive_path):
            # exist_ok avoids the check-then-create race and also handles
            # nested directories.
            os.makedirs(path, exist_ok=True)
            download_file(npz_file)
        # The object returned by np.load for .npz files keeps the file open
        # until closed; read both arrays and release the handle right away.
        with np.load(archive_path) as archive:
            self.x = archive['x']
            self.y = archive['y']

    def __len__(self):
        """Return the length of ``x`` along its first axis."""
        return len(self.x)

    def __getitem__(self, idx):
        """Return the ``(x, y)`` pair selected by ``idx`` on the first axis.

        A torch tensor index is converted with ``tolist()`` before indexing.
        """
        if torch.is_tensor(idx):
            idx = idx.tolist()
        return self.x[idx, ...], self.y[idx, ...]
def METR_LAGraphDataset():
    """Return the first graph stored in ``data/graph_la.bin``.

    When the file is absent, the ``data`` directory is created if needed and
    ``download_file('graph_la.bin')`` is called before loading.
    """
    graph_file = 'data/graph_la.bin'
    already_cached = os.path.exists(graph_file)
    if not already_cached:
        data_dir_present = os.path.exists('data')
        if not data_dir_present:
            os.mkdir('data')
        download_file('graph_la.bin')
    # load_graphs yields a pair; only the graph list is used here.
    graph_list, _unused = dgl.load_graphs(graph_file)
    return graph_list[0]
class METR_LATrainDataset(SnapShotDataset):
    """Snapshots loaded from ``data/metr_la_train.npz``.

    Besides the inherited ``x`` and ``y`` arrays, the instance stores
    ``mean`` and ``std``: the mean and standard deviation taken over
    index 0 of the last axis of ``x``.
    """

    def __init__(self):
        super(METR_LATrainDataset, self).__init__(
            path='data', npz_file='metr_la_train.npz')
        # Statistics use only index 0 of the last axis of ``x``.
        first_channel = self.x[..., 0]
        self.mean = first_channel.mean()
        self.std = first_channel.std()
class METR_LATestDataset(SnapShotDataset):
    """Snapshots loaded from ``data/metr_la_test.npz``."""

    def __init__(self):
        super(METR_LATestDataset, self).__init__(
            path='data', npz_file='metr_la_test.npz')
class METR_LAValidDataset(SnapShotDataset):
    """Snapshots loaded from ``data/metr_la_valid.npz``."""

    def __init__(self):
        super(METR_LAValidDataset, self).__init__(
            path='data', npz_file='metr_la_valid.npz')
def PEMS_BAYGraphDataset():
    """Return the first graph stored in ``data/graph_bay.bin``.

    When the file is absent, the ``data`` directory is created if needed and
    ``download_file('graph_bay.bin')`` is called before loading.
    """
    graph_file = 'data/graph_bay.bin'
    needs_download = not os.path.exists(graph_file)
    if needs_download:
        if not os.path.exists('data'):
            os.mkdir('data')
        download_file('graph_bay.bin')
    # load_graphs yields a pair; only the graph list is used here.
    graph_list, _unused = dgl.load_graphs(graph_file)
    first_graph = graph_list[0]
    return first_graph
class PEMS_BAYTrainDataset(SnapShotDataset):
    """Snapshots loaded from ``data/pems_bay_train.npz``.

    Besides the inherited ``x`` and ``y`` arrays, the instance stores
    ``mean`` and ``std``: the mean and standard deviation taken over
    index 0 of the last axis of ``x``.
    """

    def __init__(self):
        super(PEMS_BAYTrainDataset, self).__init__(
            path='data', npz_file='pems_bay_train.npz')
        # Statistics use only index 0 of the last axis of ``x``.
        first_channel = self.x[..., 0]
        self.mean = first_channel.mean()
        self.std = first_channel.std()
class PEMS_BAYTestDataset(SnapShotDataset):
    """Snapshots loaded from ``data/pems_bay_test.npz``."""

    def __init__(self):
        super(PEMS_BAYTestDataset, self).__init__(
            path='data', npz_file='pems_bay_test.npz')
class PEMS_BAYValidDataset(SnapShotDataset):
    """Snapshots loaded from ``data/pems_bay_valid.npz``."""

    def __init__(self):
        super(PEMS_BAYValidDataset, self).__init__(
            path='data', npz_file='pems_bay_valid.npz')