-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_trans.py
40 lines (35 loc) · 1.36 KB
/
data_trans.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
"""
# Author: zhengyu
# Date: 25/04/2021
Purpose of file:
1) Read TFRecords and build the tf.data dataset, its iterator, and the batch tensors.
"""
import tensorflow as tf
from utils.augmentation import random_augmenattion
def parse_function(filenames, img_shape, mask_shape, is_training):
    """Decode one serialized ``tf.train.Example`` into an ``(image, mask)`` pair.

    Args:
        filenames: A scalar string tensor holding ONE serialized example read
            from a TFRecord file (the name is historical; it is a record, not
            a file name).
        img_shape: 3-element shape of the image. The "image" feature must hold
            ``img_shape[0] * img_shape[1] * img_shape[2]`` float32 values.
        mask_shape: 3-element shape of the mask, same convention for "mask".
        is_training: If True, ``random_augmenattion`` is applied to the pair.

    Returns:
        ``(image, mask)`` float32 tensors reshaped to ``img_shape`` and
        ``mask_shape`` respectively.
    """
    features = {
        "image": tf.io.FixedLenFeature(
            [img_shape[0] * img_shape[1] * img_shape[2]], tf.float32),
        "mask": tf.io.FixedLenFeature(
            [mask_shape[0] * mask_shape[1] * mask_shape[2]], tf.float32),
    }
    # Decode the serialized record into the flat feature vectors declared above.
    parsed_example = tf.io.parse_single_example(filenames, features)

    # Restore the flat vectors to their declared shapes.
    image = tf.reshape(parsed_example["image"], img_shape)
    mask = tf.reshape(parsed_example["mask"], mask_shape)

    if is_training:
        # Rebind `image` (not a throwaway name) so the augmented image is used.
        image, mask = random_augmenattion(image, mask)

    # Guarantee float32 outputs in case augmentation returns another dtype.
    image = tf.cast(image, tf.float32)
    mask = tf.cast(mask, tf.float32)
    return image, mask
def make_batch_iterator(tfrecord_path, img_shape, mask_shape, nums,
                        is_training=True, batch_size=16):
    """Build a batched, prefetched input pipeline over a TFRecord file.

    Args:
        tfrecord_path: File path(s) accepted by ``tf.data.TFRecordDataset``.
        img_shape: Image shape forwarded to ``parse_function``.
        mask_shape: Mask shape forwarded to ``parse_function``.
        nums: Shuffle buffer size (only used when ``is_training``). Presumably
            the number of records, so the whole file is shuffled; confirm
            against callers.
        is_training: If True, records are shuffled and augmented.
        batch_size: Number of examples per batch.

    Returns:
        ``(dataset, iterator, next_batch)``:
        - ``iterator`` is an initializable iterator whose ``initializer`` must
          be run in a session before use. This requires graph mode, e.g.
          ``tf.compat.v1.disable_eager_execution()`` on TF 2.
        - ``next_batch`` is the ``(images, masks)`` tensor tuple it yields.
    """
    def _parse_fn(tfrecord):
        return parse_function(tfrecord, img_shape, mask_shape, is_training)

    # Source of raw serialized records; everything below transforms this.
    dataset = tf.data.TFRecordDataset(tfrecord_path)
    if is_training:
        # Shuffle the cheap serialized records before the decode/augment step.
        dataset = dataset.shuffle(nums)
    dataset = dataset.map(_parse_fn, num_parallel_calls=4)
    dataset = dataset.batch(batch_size)
    # NOTE(review): after .batch() the prefetch buffer counts batches, not
    # examples, so this buffers up to 3*batch_size batches. Confirm intended.
    dataset = dataset.prefetch(buffer_size=3 * batch_size)
    # Dataset.make_initializable_iterator() no longer exists in TF 2. The
    # compat function is available on TF >= 1.13 and on TF 2.
    iterator = tf.compat.v1.data.make_initializable_iterator(dataset)
    next_batch = iterator.get_next()
    return dataset, iterator, next_batch