-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathReadData_tfrecord.py
126 lines (95 loc) · 3.54 KB
/
ReadData_tfrecord.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python
#-*- coding:utf-8 -*-
import tensorflow as tf
import numpy as np
import os
import json
def read_and_decode(filename, image_shape=(256, 256, 3)):
    """Build the graph ops that read one (label, image) record from a TFRecord file.

    Args:
        filename: Path of the TFRecord file. It is printed and then wrapped
            in a single-entry filename queue.
        image_shape: Shape the decoded ``uint8`` bytes are reshaped to.
            Defaults to ``(256, 256, 3)``, the value that used to be
            hard-coded here.

    Returns:
        A ``(label, img)`` pair of tensors: ``label`` is the ``'image_label'``
        feature cast to ``int32``; ``img`` is the ``'image_raw'`` bytes
        decoded as ``uint8`` and reshaped to ``image_shape``.

    Note:
        The returned tensors are fed by a queue-based reader, so evaluating
        them blocks until the graph's queue runners have been started
        (``tf.train.start_queue_runners``).
    """
    print(filename)
    filename_queue = tf.train.string_input_producer([filename])
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_label': tf.FixedLenFeature([], tf.int64),
            'image_raw': tf.FixedLenFeature([], tf.string),
        },
    )
    label = tf.cast(features['image_label'], tf.int32)
    img = tf.decode_raw(features['image_raw'], tf.uint8)
    # decode_raw yields a flat byte vector; restore the image layout.
    img = tf.reshape(img, list(image_shape))
    print(label, img)
    return label, img
# Module-level placeholders for the annotation splits; each starts out as its
# own empty list. The functions in this file bind local names and never
# declare these `global`, so calling them leaves these lists untouched.
test_images_name, test_sketchs_name, test_triplets = [], [], []
train_images_name, train_images_sketch, train_triplets = [], [], []
def LoadTriplets(filename):
    """Read the JSON annotation file and return its test and train entries.

    Args:
        filename: Path of a UTF-8 encoded JSON file whose top-level object
            has ``'test'`` and ``'train'`` entries, each providing
            ``'images'``, ``'sketches'`` and ``'triplets'``.

    Returns:
        A 6-tuple in the order: test images, test sketches, test triplets,
        train images, train sketches, train triplets.

    Raises:
        KeyError: If any of the expected keys is missing.
    """
    with open(filename, encoding='utf-8') as handle:
        annotation = json.load(handle)
    # Same lookup order as the flat result: every 'test' field, then 'train'.
    return tuple(
        annotation[split][field]
        for split in ('test', 'train')
        for field in ('images', 'sketches', 'triplets')
    )
def ReadData(sess, batch_size = 128):
    """Yield training batches of (sketch, first image, second image) triplets.

    Loads the triplet annotation from ``../shoes_annotation.json``, reads
    every training image and sketch from the two TFRecord files into memory,
    then walks ``train_triplets`` sketch by sketch and groups the referenced
    arrays into batches.

    Args:
        sess: Session used to evaluate the record-reading ops.
        batch_size: Number of triplets per yielded batch.

    Yields:
        ``(s, ipos, ineg)``: three tensors, each the ``tf.concat`` along
        axis 0 of ``batch_size`` arrays. ``s`` holds the sketch of each
        triplet, ``ipos`` the image selected by the triplet's first index and
        ``ineg`` the image selected by its second index (presumably the
        positive / negative examples - confirm against the annotation
        format). Items are concatenated along axis 0, not stacked on a new
        batch axis. A trailing group smaller than ``batch_size`` is not
        yielded.
    """
    shoes_annotation = r'../shoes_annotation.json'
    print(r'Loading ' + shoes_annotation + r'...')
    # Only the training split is used here; the three test entries are ignored.
    (_, _, _,
     train_images_name, train_sketchs_name, train_triplets) = LoadTriplets(shoes_annotation)
    print(len(train_triplets))

    filename1 = r'./shoes_images_train.tfrecords'
    filename2 = r'./shoes_sketches_train.tfrecords'
    print(r'Loading ' + filename1 + r'...')
    label_images, shoes_images_train = read_and_decode(filename1)
    print(r'Loading ' + filename2 + r'...')
    label_sketches, shoes_sketches_train = read_and_decode(filename2)

    # The readers are fed by filename queues. Without running queue runners
    # every sess.run below would block forever, so start them for the loading
    # phase and shut them down once all records are in memory.
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    try:
        # One sess.run per annotated file name; each run reads the next record.
        shoes_images = [
            sess.run([label_images, shoes_images_train])[1]
            for _ in train_images_name
        ]
        print(len(shoes_images))
        shoes_sketchs = [
            sess.run([label_sketches, shoes_sketches_train])[1]
            for _ in train_sketchs_name
        ]
        print(len(shoes_sketchs))
    finally:
        coord.request_stop()
        coord.join(threads)
    print(shoes_images_train.get_shape())

    sketch_batch, pos_batch, neg_batch = [], [], []
    seen = 0
    # Iterate over the annotation itself instead of a fixed 304 * 45 grid.
    for sketch_index, sketch_triplets in enumerate(train_triplets):
        for triplet in sketch_triplets:
            sketch_batch.append(shoes_sketchs[sketch_index])
            pos_batch.append(shoes_images[triplet[0]])
            neg_batch.append(shoes_images[triplet[1]])
            seen += 1
            # Emit once a full batch has been collected. Counting from 1
            # avoids the degenerate one-element batch a zero-based
            # `i % batch_size == 0` test emits at i == 0.
            if seen % batch_size == 0:
                # NOTE(review): (axis, values) is the pre-1.0 tf.concat
                # argument order; TF >= 1.0 expects (values, axis) - confirm
                # the targeted TensorFlow version before upgrading.
                # A single concat per batch keeps the number of ops added to
                # the graph independent of batch_size.
                yield (tf.concat(0, sketch_batch),
                       tf.concat(0, pos_batch),
                       tf.concat(0, neg_batch))
                sketch_batch, pos_batch, neg_batch = [], [], []
if __name__ == '__main__':
    # Smoke run: open a session, run the variable initializer and pull the
    # first two batches from the ReadData generator.
    init_op = tf.global_variables_initializer()
    with tf.Session() as session:
        session.run(init_op)
        batch_iter = ReadData(session)
        # The yielded batches are discarded; this only exercises the pipeline.
        for _ in range(2):
            next(batch_iter)