forked from dgriffiths3/ml_segmentation
-
Notifications
You must be signed in to change notification settings - Fork 0
/
inference.py
75 lines (53 loc) · 2.26 KB
/
inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import cv2
import numpy as np
import pylab as plt
from glob import glob
import argparse
import os
import pickle as pkl
import train
import math
def check_args(args):
    """Validate the filesystem paths carried by the parsed arguments.

    Args:
        args: Namespace exposing ``image_dir`` and ``output_dir`` and,
            optionally, ``model_path``.

    Returns:
        The same ``args`` object, unchanged.

    Raises:
        ValueError: If ``image_dir`` or ``output_dir`` is not an existing
            directory, or if ``model_path`` is present but is not an
            existing file.
    """
    # isdir rather than exists: a regular file passed where a directory is
    # expected would otherwise pass validation and only fail later on.
    if not os.path.isdir(args.image_dir):
        raise ValueError("Image directory does not exist")
    if not os.path.isdir(args.output_dir):
        raise ValueError("Output directory does not exist")
    # Fail fast on a bad model path instead of after the images are globbed.
    # getattr keeps namespaces that carry no model_path attribute working.
    model_path = getattr(args, "model_path", None)
    if model_path is not None and not os.path.isfile(model_path):
        raise ValueError("Model file does not exist")
    return args
def parse_args():
    """Parse the command line and return the validated arguments.

    Three required options are registered: ``-i/--image_dir``,
    ``-m/--model_path`` and ``-o/--output_dir``.

    Returns:
        Whatever ``check_args`` returns for the parsed namespace.
    """
    cli = argparse.ArgumentParser()
    # (short flag, long flag, help text) for every required option.
    required_options = (
        ("-i", "--image_dir", "Path to images"),
        ("-m", "--model_path", "Path to .p model"),
        ("-o", "--output_dir", "Path to output directory"),
    )
    for short_flag, long_flag, description in required_options:
        cli.add_argument(short_flag, long_flag, help=description, required=True)
    return check_args(cli.parse_args())
def create_features(img):
    """Build the inference-time feature array for one image.

    Args:
        img: Image array accepted by ``cv2.cvtColor`` with
            ``cv2.COLOR_BGR2GRAY``.

    Returns:
        The first element of the pair returned by
        ``train.create_features`` when called with ``label=None`` and
        ``train=False``.
    """
    grayscale = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    # The second element of the pair is not needed here and is discarded.
    extracted, _ = train.create_features(
        img,
        grayscale,
        label=None,
        train=False,
    )
    return extracted
def compute_prediction(img, model):
    """Predict a 2-D label image for ``img`` using ``model``.

    The image is zero-padded by ``border`` pixels on every side, features
    are computed from the padded image with ``create_features``, and the
    output of ``model.predict`` is reshaped into a 2-D array.

    Args:
        img: Image array (the caller passes the result of ``cv2.imread``).
        model: Object exposing ``predict(features)`` that returns an array
            with one prediction per feature row.

    Returns:
        The predictions reshaped to ``(height, width)`` of the unpadded
        input when there is exactly one feature row per input pixel;
        otherwise reshaped to a square of side
        ``int(sqrt(number of feature rows))``.
    """
    border = 5  # (haralick neighbourhood - 1) / 2
    # Remember the unpadded geometry so non-square images keep their shape.
    height, width = img.shape[:2]
    padded = cv2.copyMakeBorder(
        img,
        top=border,
        bottom=border,
        left=border,
        right=border,
        borderType=cv2.BORDER_CONSTANT,
        value=[0, 0, 0],
    )
    features = create_features(padded)
    predictions = model.predict(features.reshape(-1, features.shape[1]))

    # assumes the feature extractor trims the padded border again, yielding
    # one row per original pixel -- TODO confirm against train.create_features
    if features.shape[0] == height * width:
        return predictions.reshape(height, width)

    # Row count does not match the input geometry: keep the historical
    # square reshape (only meaningful when the row count is a perfect square).
    pred_size = int(math.sqrt(features.shape[0]))
    return predictions.reshape(pred_size, pred_size)
def infer_images(image_dir, model_path, output_dir):
    """Run the pickled model over every ``*.jpg`` file in ``image_dir``.

    Each prediction image is written to ``output_dir`` under the same base
    name as its input. Files that ``cv2.imread`` cannot decode are reported
    and skipped instead of aborting the whole run.

    Args:
        image_dir: Directory searched (non-recursively) for ``*.jpg`` files.
        model_path: Path to the pickled model file.
        output_dir: Directory that receives the prediction images.
    """
    filelist = glob(os.path.join(image_dir, '*.jpg'))
    print('[INFO] Running inference on %s test images' % len(filelist))

    # NOTE: unpickling can execute arbitrary code; only load model files
    # that come from a trusted source.
    # The context manager closes the file handle once the model is loaded.
    with open(model_path, "rb") as model_file:
        model = pkl.load(model_file)

    for image_path in filelist:
        image_name = os.path.basename(image_path)
        print('[INFO] Processing images:', image_name)
        img = cv2.imread(image_path, 1)
        if img is None:
            # cv2.imread returns None rather than raising when a file is
            # missing, unreadable or not a decodable image.
            print('[WARN] Could not read image, skipping:', image_name)
            continue
        inference_img = compute_prediction(img, model)
        # NOTE(review): the output keeps the input's .jpg name, so the
        # prediction is presumably stored with lossy JPEG compression --
        # confirm this is acceptable for label images.
        cv2.imwrite(os.path.join(output_dir, image_name), inference_img)
def main(image_dir, model_path, output_dir):
    """Entry point: forward the three paths unchanged to ``infer_images``.

    Args:
        image_dir: Path to the directory of input images.
        model_path: Path to the pickled model.
        output_dir: Path to the directory that receives the results.
    """
    infer_images(
        image_dir=image_dir,
        model_path=model_path,
        output_dir=output_dir,
    )
if __name__ == '__main__':
    # Script mode: parse and validate the command line, then run inference
    # with the three paths it supplies.
    args = parse_args()
    main(args.image_dir, args.model_path, args.output_dir)