-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathvec.py
68 lines (54 loc) · 2.18 KB
/
vec.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import pickle
import numpy as np
from PIL import Image
def vectorize_imgs(img_path):
    """Read one image file and return its pixel data as a float32 array.

    The image is opened with ``Image.open`` inside a ``with`` block, so the
    image object is closed as soon as the pixels have been converted.
    """
    with Image.open(img_path) as picture:
        return np.asarray(picture, dtype='float32')
def read_csv_file(csv_file):
    """Load a labelled image list into a pair of NumPy arrays.

    Each non-blank line of *csv_file* holds two whitespace-separated fields:
    an image path followed by an integer label. Despite the parameter name,
    fields are split on whitespace, not on commas.

    Args:
        csv_file: Path of the text file listing ``<image path> <label>``.

    Returns:
        A tuple ``(x, y)``. ``x`` is a float32 array built from the arrays
        returned by ``vectorize_imgs`` and ``y`` is an int32 array of the
        labels, both in file order. A file without records yields two
        empty arrays.

    Raises:
        ValueError: If a non-blank line does not have exactly two fields or
            its label cannot be parsed as an integer.
    """
    images, labels = [], []
    with open(csv_file, "r") as f:
        # Stream the file line by line instead of materialising it with
        # readlines().
        for line in f:
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline at end of file) carry
                # no record; skipping them avoids a spurious unpack error.
                continue
            path, label = fields
            images.append(vectorize_imgs(path))
            labels.append(int(label))
    return np.asarray(images, dtype='float32'), np.asarray(labels, dtype='int32')
def read_csv_pair_file(csv_file):
    """Load a labelled list of image pairs into three NumPy arrays.

    Each non-blank line of *csv_file* holds three whitespace-separated
    fields: two image paths followed by an integer label. Despite the
    parameter name, fields are split on whitespace, not on commas.

    Args:
        csv_file: Path of the text file listing
            ``<image path 1> <image path 2> <label>``.

    Returns:
        A tuple ``(x1, x2, y)``. ``x1`` and ``x2`` are float32 arrays built
        from the arrays returned by ``vectorize_imgs`` for the first and
        second path of every line, and ``y`` is an int32 array of the
        labels, all in file order. A file without records yields three
        empty arrays.

    Raises:
        ValueError: If a non-blank line does not have exactly three fields
            or its label cannot be parsed as an integer.
    """
    first_images, second_images, labels = [], [], []
    with open(csv_file, "r") as f:
        # Stream the file line by line instead of materialising it with
        # readlines().
        for line in f:
            fields = line.split()
            if not fields:
                # Blank lines (e.g. a trailing newline at end of file) carry
                # no record; skipping them avoids a spurious unpack error.
                continue
            p1, p2, label = fields
            first_images.append(vectorize_imgs(p1))
            second_images.append(vectorize_imgs(p2))
            labels.append(int(label))
    return (
        np.asarray(first_images, dtype='float32'),
        np.asarray(second_images, dtype='float32'),
        np.asarray(labels, dtype='int32'),
    )
def load_data(path='data/dataset.pkl'):
    """Load the seven objects pickled back-to-back in the dataset file.

    Args:
        path: Pickle file to read. Defaults to ``'data/dataset.pkl'``, the
            location that used to be hard-coded, so existing zero-argument
            calls behave as before.

    Returns:
        A 7-tuple ``(testX1, testX2, testY, validX, validY, trainX, trainY)``
        holding the objects in the order they are stored in the file.

    Raises:
        EOFError: If the file contains fewer than seven pickled objects.
    """
    # NOTE: unpickling can execute arbitrary code; only load files that come
    # from a trusted source.
    with open(path, 'rb') as f:
        # The objects were dumped one after another into the same stream, so
        # they are read back with seven consecutive load calls.
        return tuple(pickle.load(f) for _ in range(7))
def load_jbdata(path='data/JBdata.pkl'):
    """Load the two objects pickled back-to-back in the JB data file.

    Args:
        path: Pickle file to read. Defaults to ``'data/JBdata.pkl'``, the
            location that used to be hard-coded, so existing zero-argument
            calls behave as before.

    Returns:
        A 2-tuple ``(validX, validY)`` holding the objects in the order they
        are stored in the file.

    Raises:
        EOFError: If the file contains fewer than two pickled objects.
    """
    # NOTE: unpickling can execute arbitrary code; only load files that come
    # from a trusted source.
    with open(path, 'rb') as f:
        validX = pickle.load(f)
        validY = pickle.load(f)
    return validX, validY
if __name__ == '__main__':
    # Script entry point: vectorise the image pairs listed in
    # data/test_set.csv and serialise the three resulting arrays
    # (first images, second images, labels) into data/JBtestData.pkl.
    #
    # The validation and training splits are currently disabled. To include
    # them, load them with
    #   validX, validY = read_csv_file('data/valid_set.csv')
    #   trainX, trainY = read_csv_file('data/train_set.csv')
    # print their shapes, and dump validX, validY, trainX, trainY (in that
    # order) after the test arrays below.
    pair_arrays = read_csv_pair_file('data/test_set.csv')

    # Quick sanity check of what was loaded: one shape per array on one line.
    print(*(arr.shape for arr in pair_arrays))

    with open('data/JBtestData.pkl', 'wb') as out:
        # Each array is written as its own pickle record, back-to-back, in
        # the order testX1, testX2, testY.
        for arr in pair_arrays:
            pickle.dump(arr, out, pickle.HIGHEST_PROTOCOL)