-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
71 lines (55 loc) · 2.07 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
import os
import random
import cv2
import numpy as np
import math
import errno
def get_ids(dir):
    """Return a generator of the entries in *dir* with their extension removed.

    Args:
        dir: Directory to list. Entries come back in ``os.listdir`` order,
            which is arbitrary.

    Returns:
        A generator yielding each file name without its final extension.
    """
    # splitext strips an extension of any length ('.jpg', '.jpeg', '.tiff');
    # slicing off a fixed 4 characters is only correct for 3-letter extensions.
    return (os.path.splitext(name)[0] for name in os.listdir(dir))
def split_train_val(iddataset, val_percent=0.05):
    """Randomly split ids into training and validation subsets.

    Args:
        iddataset: Iterable of sample ids. It is copied into a new list before
            shuffling, so the caller's object is not reordered.
        val_percent: Fraction of ids that go to the validation set; the count
            is rounded up with ``math.ceil``.

    Returns:
        dict with keys ``'train'`` and ``'val'``, each a list of ids.
    """
    ids = list(iddataset)
    n_val = math.ceil(len(ids) * val_percent)
    random.shuffle(ids)
    # Split at an explicit non-negative index. Using ``ids[:-n_val]`` breaks
    # when n_val == 0: ``ids[:-0]`` is empty and ``ids[-0:]`` is everything.
    n_train = max(len(ids) - n_val, 0)
    return {'train': ids[:n_train], 'val': ids[n_train:]}
def resize_and_crop(img, scale):
    """Trim the sides of a landscape image towards a square, then rescale it.

    When the image is wider than tall, ``(w - h) // 2`` columns are removed
    from each side; with an odd width/height difference one extra column
    remains. Images that are not wider than tall are not cropped.

    Args:
        img: Array indexed as ``img[row, col, ...]``.
        scale: Resize factor applied to both axes.

    Returns:
        The cropped and resized array. Nearest-neighbour interpolation is
        used, so no new intermediate values are introduced.
    """
    h, w = img.shape[:2]
    # Clamp at 0: for a portrait image (h > w) the raw value is negative and
    # the slice img[:, -k:w + k] would keep only the last k columns.
    crop_w = max((w - h) // 2, 0)
    img = img[:, crop_w:w - crop_w]
    img = cv2.resize(img, dsize=None, fx=scale, fy=scale,
                     interpolation=cv2.INTER_NEAREST)
    return img
def normalize(x):
    """Scale values by 1/256 using true division.

    Accepts anything that supports ``/`` with an int (Python numbers, NumPy
    arrays). Inputs in 0..255 therefore land in [0, 1).
    """
    divisor = 256
    scaled = x / divisor
    return scaled
# Classes: ['coal', 'gangue'] -- in the mask, red marks coal and green marks gangue (waste rock)
def _imread_or_raise(path):
    """Load *path* with ``cv2.imread``, raising instead of returning None."""
    image = cv2.imread(path)
    if image is None:
        # cv2.imread reports a missing/unreadable file by returning None,
        # which would otherwise surface later as an unrelated error.
        raise FileNotFoundError('could not read image: {}'.format(path))
    return image


def get_imgs_and_masks(ids, dir_img, dir_mask, scale=0.5, num_classes=2):
    """Yield ``(image, mask)`` pairs for each id in *ids*.

    For every id, ``<dir_img>/<id>.jpg`` is read and converted to grayscale,
    and ``<dir_mask>/<id>.png`` is read as a colour mask. Mask pixels whose
    red channel equals 128 form the first class plane; pixels whose green
    channel equals 128 form the second.

    The grayscale image is then re-weighted from the mask. First-class pixels
    are scaled by 0.8. Second-class pixels are scaled by 1.2, clipped to 255.
    Pixels in neither class are scaled by 0.2.

    Args:
        ids: Iterable of sample ids (file names without extension).
        dir_img: Directory holding the ``.jpg`` input images.
        dir_mask: Directory holding the ``.png`` label masks.
        scale: Resize factor passed to ``resize_and_crop``.
        num_classes: Not used by this function; the mask always has two planes.

    Yields:
        ``(img, true_mask)``. ``img`` is the cropped, resized image after
        ``normalize``. ``true_mask`` is a float32 array shaped
        ``(2, H, W)``, with the red plane first and the green plane second.

    Raises:
        FileNotFoundError: if an image or mask file cannot be read.
    """
    for sample_id in ids:
        img = cv2.cvtColor(
            _imread_or_raise(os.path.join(dir_img, sample_id + '.jpg')),
            cv2.COLOR_BGR2GRAY)
        mask = _imread_or_raise(os.path.join(dir_mask, sample_id + '.png'))

        # OpenCV loads colour images in BGR order: channel 2 is red, 1 is green.
        red_mask = mask[:, :, 2] == 128
        green_mask = mask[:, :, 1] == 128
        background = ~(red_mask | green_mask)
        true_mask = np.stack([red_mask.astype(np.float32),
                              green_mask.astype(np.float32)])

        # NOTE(review): the network input is altered using the ground-truth
        # mask, which leaks label information into the image -- confirm this
        # is intentional.
        img[red_mask] = img[red_mask] * 0.8
        # Clip before writing back into the 8-bit image: 1.2 * v exceeds 255
        # for v >= 213, and converting such floats to uint8 is undefined
        # (it typically wraps around to a dark value).
        img[green_mask] = np.clip(img[green_mask] * 1.2, 0, 255)
        img[background] = img[background] * 0.2

        img = normalize(resize_and_crop(img, scale=scale))
        # cv2.resize works on channels-last arrays, so move the class axis to
        # the end for the resize and back to the front afterwards.
        true_mask = resize_and_crop(true_mask.transpose([1, 2, 0]), scale=scale)
        true_mask = true_mask.transpose([2, 0, 1])
        yield (img, true_mask)
def _iter_batches(iterable, batch_size):
    """Generator behind ``batch``; expects an already-validated *batch_size*."""
    chunk = []
    for item in iterable:
        chunk.append(item)
        if len(chunk) == batch_size:
            yield chunk
            chunk = []
    # Emit the leftover partial batch, if any.
    if chunk:
        yield chunk


def batch(iterable, batch_size):
    """Group *iterable* into consecutive lists of *batch_size* items.

    The last list holds any leftover items and may be shorter. Nothing is
    yielded for an empty iterable.

    Args:
        iterable: Items to group; consumed lazily.
        batch_size: Number of items per batch; must be at least 1.

    Returns:
        An iterator of lists.

    Raises:
        ValueError: if *batch_size* is smaller than 1. This is raised
            immediately, rather than as a ZeroDivisionError or a silent
            single batch.
    """
    if batch_size < 1:
        raise ValueError('batch_size must be >= 1, got {!r}'.format(batch_size))
    return _iter_batches(iterable, batch_size)
def mkdir_p(path):
    """Create directory *path* and any missing parents, like ``mkdir -p``.

    Does nothing if *path* already exists as a directory.

    Raises:
        FileExistsError: if *path* exists but is not a directory.
        OSError: for any other failure, e.g. permission denied.
    """
    # Unlike catching errno.EEXIST, exist_ok=True still raises when a regular
    # file occupies the path, instead of letting callers fail later on write.
    os.makedirs(path, exist_ok=True)