-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
138 lines (120 loc) · 3.98 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
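"""
Utility collection: image batch generator, image-grid visualisation,
TensorFlow variable-scope decorator and a bounded data pool.
"""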
import functools
import os
import random
from glob import glob

import numpy as np
import tensorflow as tf
from scipy.misc import imread, imresize
# Value scalings understood by image_generator.
_VALUE_MODES = ('origin', 'sigmoid', 'tanh')


def _load_image(filename, resize, crop, flip):
    """Read one image as RGB and apply the optional resize, random crop and random flip."""
    img = imread(filename, mode='RGB')
    if resize:
        img = imresize(img, resize)
    if crop:
        # crop is (size along axis 0, size along axis 1); the offset is drawn
        # so that the window always fits inside the image.
        row = random.randint(0, img.shape[0] - crop[0])
        col = random.randint(0, img.shape[1] - crop[1])
        img = img[row:row + crop[0], col:col + crop[1]]
    if flip and random.random() < flip:
        # Reverse axis 1 (columns), i.e. mirror the image left-right.
        img = img[:, ::-1, :]
    return img


def _scale_batch(batch, value_mode):
    """Map a batch of pixel values to the range selected by value_mode."""
    if value_mode == 'sigmoid':
        return batch / 255.0
    if value_mode == 'tanh':
        return (batch / 127.5) - 1.0
    return batch


def image_generator(root: str, batch_size=1, resize: tuple = None, crop: tuple = None, flip: float = None,
                    value_mode: str = "origin", rand=True):
    """
    Generator yielding batches of the '*.jpg' images found directly in a directory.

    Images are read with the ``imread`` / ``imresize`` functions this file
    imports from ``scipy.misc``.

    :param root: directory that is searched (non-recursively) for '*.jpg' files
    :param batch_size: number of images per yielded batch
    :param resize: target size passed to ``imresize`` (None/falsy: no resize)
    :param crop: size of a randomly positioned crop window along the first two
                 axes (None/falsy: no crop)
    :param flip: probability of mirroring an image left-right (None/falsy: never)
    :param value_mode: 'origin' leaves values untouched, 'sigmoid' divides by
                       255.0 (0~1), 'tanh' maps to -1~1 via ``x / 127.5 - 1.0``
    :param rand: if True, sample files at random (with replacement) forever;
                 if False, walk the file list once in ``glob`` order and stop.
                 In that mode a trailing group smaller than batch_size is
                 not yielded.
    :return: an iterator of numpy arrays, one per batch
    :raises ValueError: on first iteration, if value_mode is not one of
                        'origin', 'sigmoid', 'tanh'
    """
    # An unknown mode used to yield nothing at all, which in random mode meant
    # spinning forever inside the loop; fail loudly instead.
    if value_mode not in _VALUE_MODES:
        raise ValueError(
            "value_mode must be one of %r, got %r" % (_VALUE_MODES, value_mode))

    img_list = glob(os.path.join(root, '*.jpg'))
    if rand:
        while True:
            batch = np.array([
                _load_image(random.choice(img_list), resize, crop, flip)
                for _ in range(batch_size)
            ])
            yield _scale_batch(batch, value_mode)
    else:
        pending = []
        for filename in img_list:
            pending.append(_load_image(filename, resize, crop, flip))
            if len(pending) == batch_size:
                batch = np.array(pending)
                pending = []
                # Same scaling as the random branch: 'tanh' previously used
                # (x / 128.0) - 128.0 here, which is not in the -1~1 range.
                yield _scale_batch(batch, value_mode)
def visual_grid(X: np.array, shape: tuple((int, int))):
    """
    Tile the images in X onto a single canvas, row by row, for visualisation.

    :param X: array of images; X[n] is one image and the tile height and width
              are taken from X.shape[1:3]
    :param shape: grid layout as (rows, columns), counted in images
    :return: array of shape (height * rows, width * columns, 3) created with
             np.zeros; grid cells without an image stay zero, and images
             beyond rows * columns are ignored
    """
    rows, cols = shape
    tile_h, tile_w = X.shape[1:3]
    capacity = rows * cols
    canvas = np.zeros((tile_h * rows, tile_w * cols, 3))
    for index, tile in enumerate(X):
        # Row-major placement: index -> (grid row, grid column).
        grid_row, grid_col = divmod(index, cols)
        if index >= capacity:
            break
        y0 = grid_row * tile_h
        x0 = grid_col * tile_w
        canvas[y0:y0 + tile_h, x0:x0 + tile_w, :] = tile
    return canvas
def namespace(default_name):
    """
    Build a decorator that runs the decorated function inside a
    ``tf.variable_scope``.

    The decorated function accepts an extra optional keyword argument
    ``name``. It is consumed by the wrapper (not forwarded to the function)
    and used as the scope name; when it is absent, ``default_name`` is used.

    :param default_name: scope name used when the call does not pass ``name``
    :return: a decorator that wraps a function as described above
    """
    def deco(fn):
        # wraps() keeps fn's __name__ / __doc__ on the wrapper so decorated
        # functions stay recognisable in tracebacks and help().
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            name = kwargs.pop('name', default_name)
            with tf.variable_scope(name):
                return fn(*args, **kwargs)
        return wrapper
    return deco
class DataPool:
    """
    Bounded data pool.

    Holds at most ``size`` items, keeping the most recently pushed ones, and
    offers access to the whole pool or to one randomly chosen item.

    ``size`` is the capacity given to the constructor. The number of items
    currently stored is ``len(pool)``.
    """

    def __init__(self, size=50):
        """
        :param size: maximum number of items kept in the pool
        """
        self._pool = []
        # Capacity. It stays a plain public attribute: an instance attribute
        # named ``size`` always shadowed the former ``size()`` method, so
        # ``pool.size`` has only ever been this number (and calling it failed).
        self.size = size

    def push(self, data):
        """
        Append every item of an iterable and drop the oldest items beyond
        the capacity.

        :param data: iterable of items to add
        """
        self._pool.extend(data)
        # ``lst[-0:]`` is the whole list, so a non-positive capacity needs an
        # explicit branch to actually keep nothing.
        self._pool = self._pool[-self.size:] if self.size > 0 else []

    def choice(self):
        """
        :return: one randomly chosen item
        :raises IndexError: if the pool is empty (raised by ``random.choice``)
        """
        return random.choice(self._pool)

    def all(self):
        """
        :return: the list holding the pooled items, oldest first
        """
        return self._pool

    def __len__(self):
        """
        :return: number of items currently stored
        """
        return len(self._pool)