-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathTF_Helper.py
67 lines (47 loc) · 2.22 KB
/
TF_Helper.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import tensorflow as tf
import os
def read_and_decode(filename_queue, imshape=50176, normalize=False):
    """Read one serialized example from a TFRecord filename queue and decode it.

    Each record must contain two features:
      * ``'image_raw'`` -- the image pixels as raw uint8 bytes (a string feature);
      * ``'label'``     -- an int64 class label.

    Args:
        filename_queue: a queue of TFRecord filenames, e.g. the result of
            ``tf.train.string_input_producer``.
        imshape: number of uint8 values in one decoded image; the decoded
            vector is given the static shape ``[imshape]``. The default
            50176 equals 224 * 224 -- presumably a single 224x224 channel;
            confirm against whatever wrote the records.
        normalize: if True, rescale pixel values from [0, 255] to
            [-0.5, 0.5]. Defaults to False, which keeps the raw 0..255
            values (cast to float32), matching the previous behavior.

    Returns:
        A tuple ``(image, label)``: ``image`` is a float32 tensor of shape
        ``[imshape]``; ``label`` is a scalar int32 tensor.
    """
    reader = tf.TFRecordReader()
    _, serialized_example = reader.read(filename_queue)
    features = tf.parse_single_example(
        serialized_example,
        features={
            'image_raw': tf.FixedLenFeature([], tf.string),
            'label': tf.FixedLenFeature([], tf.int64),
        })
    image = tf.decode_raw(features['image_raw'], tf.uint8)
    # decode_raw cannot know the byte count statically; pin it so batching
    # ops downstream (which need fully defined shapes) can use this tensor.
    image.set_shape([imshape])
    image = tf.cast(image, tf.float32)
    if normalize:
        image = image * (1. / 255) - 0.5
    label = tf.cast(features['label'], tf.int32)
    return image, label
def inputs(train_dir, file, batch_size, num_epochs, n_classes,
           one_hot_labels=False, imshape=50176,
           capacity=None, min_after_dequeue=10, num_threads=1):
    """Build a shuffled, batched (image, label) input pipeline from one TFRecord file.

    Args:
        train_dir: directory containing the TFRecord file.
        file: TFRecord file name inside ``train_dir``; also used as the
            name of the shuffle_batch op.
        batch_size: number of examples per batch.
        num_epochs: number of passes over the file. A falsy value (0 or
            None) means no limit -- the filename queue cycles forever.
            When set, TF1's string_input_producer keeps an epoch counter
            in a local variable, so callers must run
            ``tf.local_variables_initializer()``.
        n_classes: one-hot depth; only used when ``one_hot_labels`` is True.
        one_hot_labels: if True, labels are converted to int32 one-hot
            vectors of length ``n_classes``.
        imshape: number of values per image (see ``read_and_decode``).
        capacity: maximum number of examples buffered in the shuffle
            queue. Defaults to 1000, enlarged when needed so that it
            always exceeds ``min_after_dequeue + batch_size``.
        min_after_dequeue: minimum number of examples left in the queue
            after a dequeue; larger values give better shuffling.
        num_threads: number of threads enqueuing examples.

    Returns:
        ``(example_batch, label_batch)``: float32 images of shape
        ``[batch_size, imshape]`` and int32 labels of shape
        ``[batch_size]`` (or ``[batch_size, n_classes]`` when one-hot).
    """
    # Treat 0 the same as None: no epoch limit.
    if not num_epochs:
        num_epochs = None
    if capacity is None:
        # A batch can only be dequeued while at least min_after_dequeue
        # examples would remain, so capacity must exceed
        # min_after_dequeue + batch_size or the pipeline blocks forever.
        # Keep the original 1000 for small batches, otherwise follow the
        # TF-recommended min_after_dequeue + (num_threads + margin) * batch_size.
        capacity = max(1000, min_after_dequeue + (num_threads + 2) * batch_size)
    filename = os.path.join(train_dir, file)
    with tf.name_scope('input'):
        filename_queue = tf.train.string_input_producer(
            [filename], num_epochs=num_epochs)
        image, label = read_and_decode(filename_queue, imshape)
        if one_hot_labels:
            label = tf.one_hot(label, n_classes, dtype=tf.int32)
        example_batch, label_batch = tf.train.shuffle_batch(
            [image, label], batch_size=batch_size, num_threads=num_threads,
            capacity=capacity, enqueue_many=False,
            # Ensures a minimum amount of shuffling of examples.
            min_after_dequeue=min_after_dequeue, name=file)
        return example_batch, label_batch
def inputs2(train_dir, file, batch_size, num_epochs, n_classes, one_hot_labels=False, imshape=50176):
    """Build a shuffled, batched (image, label) pipeline from one TFRecord file.

    This function's body was an exact copy of ``inputs``, so it now simply
    delegates to ``inputs`` with the same arguments. The two can no longer
    drift apart by accident.

    NOTE(review): the separate name suggests it may have been meant to
    diverge (e.g. an unshuffled evaluation pipeline) -- confirm with the
    callers before relying on it being identical.

    Returns:
        ``(example_batch, label_batch)`` exactly as returned by ``inputs``.
    """
    return inputs(train_dir, file, batch_size, num_epochs, n_classes,
                  one_hot_labels, imshape)