-
Notifications
You must be signed in to change notification settings - Fork 330
/
Collect_training_data.py
122 lines (100 loc) · 4.53 KB
/
Collect_training_data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
#================================================================
#
# File name : Collect_training_data.py
# Author : PyLessons
# Created date: 2020-09-27
# Website : https://pylessons.com/
# GitHub : https://github.com/pythonlessons/TensorFlow-2.x-YOLOv3
# Description : YOLO detection to XML example script
#
#================================================================
import os
import subprocess
import time
from datetime import datetime
import cv2
import mss
import numpy as np
import tensorflow as tf
from yolov3.utils import *
from yolov3.configs import *
from yolov3.yolov4 import read_class_names
from tools.Detection_to_XML import CreateXMLfile
import random
def draw_enemy(image, bboxes, CLASSES=YOLO_COCO_CLASSES, show_label=True, show_confidence = True, Text_colors=(255,255,0), rectangle_colors='', tracking=False):
    """Draw detection boxes (and optional labels) onto ``image`` in place.

    Args:
        image: Image array whose ``shape`` unpacks as ``(height, width, channels)``.
            It is modified in place by the OpenCV drawing calls and also returned.
        bboxes: Iterable of detections; each item is read as
            ``[x1, y1, x2, y2, score, class_index, ...]``.
        CLASSES: Class-names source passed to ``read_class_names``.
        show_label: Draw a filled label box with the class name above each box.
        show_confidence: Append the score (2 decimals) to the label.
        Text_colors: Colour used for the label text.
        rectangle_colors: Fixed colour for every box; ``''`` means use a
            per-class colour.
        tracking: When true, the label suffix is ``str(score)`` verbatim instead
            of the formatted confidence.

    Returns:
        The same ``image`` object with the annotations drawn on it.
    """
    # Imported here because nothing at the top of this file imports colorsys;
    # previously it only worked if it leaked through a star import.
    import colorsys

    NUM_CLASS = read_class_names(CLASSES)
    num_classes = len(NUM_CLASS)
    image_h, image_w, _ = image.shape

    # One evenly spaced hue per class, converted to 0-255 integer triples.
    hsv_tuples = [(1.0 * x / num_classes, 1., 1.) for x in range(num_classes)]
    colors = [tuple(int(c * 255) for c in colorsys.hsv_to_rgb(*hsv)) for hsv in hsv_tuples]
    # Deterministic shuffle so adjacent class ids get dissimilar colours. A private
    # RNG seeded with 0 yields the same order as seeding the global one, but does
    # not reset the global `random` state (the old seed(0)/seed(None) pair did).
    random.Random(0).shuffle(colors)

    # Thickness and font scale depend only on the image size: compute once.
    bbox_thick = max(1, int(0.6 * (image_h + image_w) / 1000))
    fontScale = 0.75 * bbox_thick

    for bbox in bboxes:
        coor = np.array(bbox[:4], dtype=np.int32)
        # Plain Python ints: some OpenCV builds reject numpy scalars as point coordinates.
        x1, y1, x2, y2 = (int(v) for v in coor)
        score = bbox[4]
        class_ind = int(bbox[5])
        bbox_color = rectangle_colors if rectangle_colors != '' else colors[class_ind]

        # Box outline.
        cv2.rectangle(image, (x1, y1), (x2, y2), bbox_color, bbox_thick * 2)

        if not show_label:
            continue

        score_str = " {:.2f}".format(score) if show_confidence else ""
        if tracking:
            score_str = " " + str(score)
        label = "{}".format(NUM_CLASS[class_ind]) + score_str

        (text_width, text_height), baseline = cv2.getTextSize(
            label, cv2.FONT_HERSHEY_COMPLEX_SMALL, fontScale, thickness=bbox_thick)
        # Filled label background extending upward from the box's top-left corner.
        cv2.rectangle(image, (x1, y1), (x1 + text_width, y1 - text_height - baseline),
                      bbox_color, thickness=cv2.FILLED)
        # Label text just above the box.
        cv2.putText(image, label, (x1, y1 - 4), cv2.FONT_HERSHEY_COMPLEX_SMALL,
                    fontScale, Text_colors, bbox_thick, lineType=cv2.LINE_AA)

    return image
def detect_enemy(Yolo, original_image, input_size=416, CLASSES=YOLO_COCO_CLASSES, score_threshold=0.3, iou_threshold=0.45, rectangle_colors=''):
    """Run the YOLO model on one frame and draw the kept detections onto it.

    Args:
        Yolo: Loaded model. It is called via ``.predict`` when YOLO_FRAMEWORK is
            "tf", or called directly (returning a mapping of output tensors) when
            it is "trt".
        original_image: Frame to analyse. It is also drawn on in place by
            ``draw_enemy``, so pass a copy if the clean frame is still needed.
        input_size: Side length passed to ``image_preprocess`` and ``postprocess_boxes``.
        CLASSES: Class-names source forwarded to ``draw_enemy``.
        score_threshold: Forwarded to ``postprocess_boxes``.
        iou_threshold: Forwarded to ``nms``.
        rectangle_colors: Forwarded to ``draw_enemy`` ('' = per-class colours).

    Returns:
        ``(image, bboxes)``: the annotated frame and the boxes returned by ``nms``.

    Raises:
        ValueError: if YOLO_FRAMEWORK is neither "tf" nor "trt". Previously this
            case surfaced as an UnboundLocalError on ``pred_bbox``.
    """
    image_data = image_preprocess(original_image, [input_size, input_size])
    image_data = image_data[np.newaxis, ...].astype(np.float32)  # add batch axis

    if YOLO_FRAMEWORK == "tf":
        pred_bbox = Yolo.predict(image_data)
    elif YOLO_FRAMEWORK == "trt":
        # Converted TensorRT model: call it and collect its outputs as numpy arrays.
        result = Yolo(tf.constant(image_data))
        pred_bbox = [value.numpy() for _key, value in result.items()]
    else:
        raise ValueError(
            "Unsupported YOLO_FRAMEWORK {!r}; expected 'tf' or 'trt'".format(YOLO_FRAMEWORK))

    # Flatten every output to (rows, last_dim) and stack them into a single tensor.
    pred_bbox = [tf.reshape(x, (-1, tf.shape(x)[-1])) for x in pred_bbox]
    pred_bbox = tf.concat(pred_bbox, axis=0)

    bboxes = postprocess_boxes(pred_bbox, original_image, input_size, score_threshold)
    bboxes = nms(bboxes, iou_threshold, method='nms')

    image = draw_enemy(original_image, bboxes, CLASSES=CLASSES, rectangle_colors=rectangle_colors)
    return image, bboxes
# Vertical shift applied to the top edge of the captured region.
offset = 30
# Region passed to mss.grab each iteration. NOTE(review): "left": 1920 presumably
# targets a monitor to the right of a 1920-px-wide primary display; adjust as needed.
CAPTURE_REGION = {"top": 87 - offset, "left": 1920, "width": 1280, "height": 720, "mon": -1}
XML_OUTPUT_DIR = "XML_Detections"
# Pause after each save. Files are named by whole-second timestamps, so this also
# keeps consecutive saves from overwriting each other.
SAVE_COOLDOWN_S = 2
FPS_WINDOW = 20        # number of recent frames averaged for the FPS readout
SHOW_PREVIEW = False   # set True to display annotated frames; press "q" to stop


def main():
    """Capture screen frames forever, run detection, and save an XML file for every
    frame that contains at least one detection."""
    from collections import deque

    class_names = read_class_names(TRAIN_CLASSES)  # loop-invariant: read once
    times = deque(maxlen=FPS_WINDOW)               # keeps only the last FPS_WINDOW timings
    sct = mss.mss()
    yolo = Load_Yolo_model()

    while True:
        t1 = time.perf_counter()
        img = np.array(sct.grab(CAPTURE_REGION))
        # Drop the alpha channel. This is the same conversion code as the old
        # COLOR_RGBA2RGB, named to match the BGRA data mss returns.
        img = cv2.cvtColor(img, cv2.COLOR_BGRA2BGR)
        # Detect on a copy so `img` stays unannotated for the saved training sample.
        image, bboxes = detect_enemy(yolo, np.copy(img), input_size=YOLO_INPUT_SIZE,
                                     CLASSES=TRAIN_CLASSES, rectangle_colors=(255, 0, 0))
        # Time only capture + inference. Previously the save cooldown was included,
        # which made the FPS readout meaningless whenever something was detected.
        times.append(time.perf_counter() - t1)

        if len(bboxes) > 0:
            CreateXMLfile(XML_OUTPUT_DIR, str(int(time.time())), img, bboxes, class_names)
            print("got it")
            time.sleep(SAVE_COOLDOWN_S)

        ms = sum(times) / len(times) * 1000
        fps = 1000 / ms if ms > 0 else float("inf")  # avoid ZeroDivisionError on coarse clocks
        print("FPS", fps)

        if SHOW_PREVIEW:
            cv2.imshow("Detection image", image)
            if cv2.waitKey(25) & 0xFF == ord("q"):
                cv2.destroyAllWindows()
                break


if __name__ == "__main__":
    main()